braindecode 0.8-py3-none-any.whl → 1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/sleep_stager_eldele_2021.py
@@ -0,0 +1,536 @@
+ # Authors: Divyesh Narayanan <divyesh.narayanan@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ import math
+ import warnings
+ from copy import deepcopy
+ from typing import Callable, Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import CausalConv1d
+
+
+ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
+     """Sleep Staging Architecture from Eldele et al. (2021) [Eldele2021]_.
+
+     .. figure:: https://raw.githubusercontent.com/emadeldeen24/AttnSleep/refs/heads/main/imgs/AttnSleep.png
+         :align: center
+         :alt: SleepStagerEldele2021 Architecture
+
+     Attention-based neural network for sleep staging as described in [Eldele2021]_.
+     The code for the paper and this model is also available at [1]_.
+     Takes single-channel EEG as input.
+     The first module performs feature extraction with a multi-resolution
+     convolutional neural network (MRCNN) followed by adaptive feature
+     recalibration (AFR).
+     The second module is the temporal context encoder (TCE), which leverages a
+     multi-head attention mechanism to capture temporal dependencies among the
+     extracted features.
+
+     Warning - This model was designed for signals of 30 seconds at 100 Hz or
+     125 Hz (in which case the reference architecture from [1]_, validated on
+     the SHHS dataset [2]_, is used); any other input is likely to make the
+     model perform in unintended ways.
+
+     Parameters
+     ----------
+     n_tce : int
+         Number of TCE clones.
+     d_model : int
+         Input dimension for the TCE.
+         Also the input dimension of the first FC layer in the feed forward
+         and the output of the second FC layer in the same.
+         Increase for higher sampling rate/signal length.
+         It should be divisible by n_attn_heads.
+     d_ff : int
+         Output dimension of the first FC layer in the feed forward and the
+         input dimension of the second FC layer in the same.
+     n_attn_heads : int
+         Number of attention heads. It should be a divisor of d_model.
+     drop_prob : float
+         Dropout rate in the PositionWiseFeedforward layer and the TCE layers.
+     after_reduced_cnn_size : int
+         Number of output channels produced by the convolution in the AFR module.
+     return_feats : bool
+         If True, return the features, i.e. the output of the feature extractor
+         (before the final linear layer). If False, pass the features through
+         the final linear layer.
+     n_classes : int
+         Alias for `n_outputs`.
+     input_size_s : float
+         Alias for `input_window_seconds`.
+     activation : nn.Module, default=nn.ReLU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
+     activation_mrcnn : nn.Module, default=nn.GELU
+         Activation function class to apply in the multi-resolution CNN (MRCNN)
+         module. Should be a PyTorch activation module class like ``nn.ReLU``
+         or ``nn.GELU``. Default is ``nn.GELU``.
+
+     References
+     ----------
+     .. [Eldele2021] E. Eldele et al., "An Attention-Based Deep Learning Approach for Sleep Stage
+         Classification With Single-Channel EEG," in IEEE Transactions on Neural Systems and
+         Rehabilitation Engineering, vol. 29, pp. 809-818, 2021, doi: 10.1109/TNSRE.2021.3076234.
+
+     .. [1] https://github.com/emadeldeen24/AttnSleep
+
+     .. [2] https://sleepdata.org/datasets/shhs
+     """
+
+     def __init__(
+         self,
+         sfreq=None,
+         n_tce=2,
+         d_model=80,
+         d_ff=120,
+         n_attn_heads=5,
+         drop_prob=0.1,
+         activation_mrcnn: nn.Module = nn.GELU,
+         activation: nn.Module = nn.ReLU,
+         input_window_seconds=None,
+         n_outputs=None,
+         after_reduced_cnn_size=30,
+         return_feats=False,
+         chs_info=None,
+         n_chans=None,
+         n_times=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.mapping = {
+             "fc.weight": "final_layer.weight",
+             "fc.bias": "final_layer.bias",
+         }
+
+         if not (
+             (self.input_window_seconds == 30 and self.sfreq == 100 and d_model == 80)
+             or (
+                 self.input_window_seconds == 30 and self.sfreq == 125 and d_model == 100
+             )
+         ):
+             warnings.warn(
+                 "This model was originally designed for input windows of 30 s at "
+                 "100 Hz (with d_model=80) or at 125 Hz (with d_model=100); using "
+                 "anything else may cause errors or make the model perform in "
+                 "unintended ways",
+                 UserWarning,
+             )
+
+         # the usual kernel size for the mrcnn, for sfreq 100
+         kernel_size = 7
+
+         if self.sfreq == 125:
+             kernel_size = 6
+
+         mrcnn = _MRCNN(
+             after_reduced_cnn_size,
+             kernel_size,
+             activation=activation_mrcnn,
+             activation_se=activation,
+         )
+         attn = _MultiHeadedAttention(n_attn_heads, d_model, after_reduced_cnn_size)
+         ff = _PositionwiseFeedForward(d_model, d_ff, drop_prob, activation=activation)
+         tce = _TCE(
+             _EncoderLayer(
+                 d_model, deepcopy(attn), deepcopy(ff), after_reduced_cnn_size, drop_prob
+             ),
+             n_tce,
+         )
+
+         self.feature_extractor = nn.Sequential(mrcnn, tce)
+         self.len_last_layer = self._len_last_layer(self.n_times)
+         self.return_feats = return_feats
+
+         # TODO: Add new way to handle return features
+         """if return_feats:
+             raise ValueError("return_feat == True is not accepted anymore")"""
+         if not return_feats:
+             self.final_layer = nn.Linear(
+                 d_model * after_reduced_cnn_size, self.n_outputs
+             )
+
+     def _len_last_layer(self, input_size):
+         self.feature_extractor.eval()
+         with torch.no_grad():
+             out = self.feature_extractor(torch.Tensor(1, 1, input_size))
+         self.feature_extractor.train()
+         return len(out.flatten())
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Batch of EEG windows of shape (batch_size, n_channels, n_times).
+         """
+
+         encoded_features = self.feature_extractor(x)
+         encoded_features = encoded_features.contiguous().view(
+             encoded_features.shape[0], -1
+         )
+
+         if self.return_feats:
+             return encoded_features
+
+         return self.final_layer(encoded_features)
+
+
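For orientation, here is a minimal usage sketch of the public class above. It is not part of the diff; it assumes braindecode 1.0.0 and PyTorch are installed, and uses the reference configuration (single-channel 30 s windows at 100 Hz, so the default ``d_model=80`` applies and no warning is raised):

```python
# Minimal smoke-test sketch (illustrative; assumes braindecode 1.0.0).
import torch
from braindecode.models import SleepStagerEldele2021

model = SleepStagerEldele2021(
    sfreq=100,
    input_window_seconds=30,
    n_chans=1,
    n_outputs=5,  # e.g. five sleep stages
)
model.eval()

x = torch.randn(4, 1, 3000)  # (batch, n_chans, n_times = 30 s * 100 Hz)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([4, 5])
```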
+ class _SELayer(nn.Module):
+     def __init__(self, channel, reduction=16, activation=nn.ReLU):
+         super(_SELayer, self).__init__()
+         self.avg_pool = nn.AdaptiveAvgPool1d(1)
+         self.fc = nn.Sequential(
+             nn.Linear(channel, channel // reduction, bias=False),
+             activation(inplace=True),
+             nn.Linear(channel // reduction, channel, bias=False),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the SE layer.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (batch_size, channel, length).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor after applying the SE recalibration.
+         """
+         b, c, _ = x.size()
+         y = self.avg_pool(x).view(b, c)
+         y = self.fc(y).view(b, c, 1)
+         return x * y.expand_as(x)
+
+
+ class _SEBasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(
+         self,
+         inplanes,
+         planes,
+         stride=1,
+         downsample=None,
+         activation: nn.Module = nn.ReLU,
+         *,
+         reduction=16,
+     ):
+         super(_SEBasicBlock, self).__init__()
+         self.conv1 = nn.Conv1d(inplanes, planes, stride)
+         self.bn1 = nn.BatchNorm1d(planes)
+         self.relu = activation(inplace=True)
+         self.conv2 = nn.Conv1d(planes, planes, 1)
+         self.bn2 = nn.BatchNorm1d(planes)
+         self.se = _SELayer(planes, reduction)
+         self.downsample = downsample
+         self.stride = stride
+         self.features = nn.Sequential(
+             self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.se
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the SE block.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (batch_size, n_chans, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor after applying the SE recalibration.
+         """
+         residual = x
+         out = self.features(x)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
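A self-contained sketch of the squeeze-and-excitation recalibration that `_SELayer` implements, with arbitrary illustrative tensor sizes (not part of the diff):

```python
# Standalone sketch of the squeeze-and-excitation step used by _SELayer.
import torch
from torch import nn

b, c, t = 2, 32, 100
x = torch.randn(b, c, t)

avg_pool = nn.AdaptiveAvgPool1d(1)
fc = nn.Sequential(
    nn.Linear(c, c // 16, bias=False),
    nn.ReLU(inplace=True),
    nn.Linear(c // 16, c, bias=False),
    nn.Sigmoid(),
)

y = avg_pool(x).view(b, c)   # squeeze: per-channel global average
y = fc(y).view(b, c, 1)      # excitation: per-channel gate in (0, 1)
out = x * y.expand_as(x)     # recalibrate channels
print(out.shape)             # torch.Size([2, 32, 100])
```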
+ class _MRCNN(nn.Module):
+     def __init__(
+         self,
+         after_reduced_cnn_size,
+         kernel_size=7,
+         activation: nn.Module = nn.GELU,
+         activation_se: nn.Module = nn.ReLU,
+     ):
+         super(_MRCNN, self).__init__()
+         drate = 0.5
+         self.GELU = activation()
+         self.features1 = nn.Sequential(
+             nn.Conv1d(1, 64, kernel_size=50, stride=6, bias=False, padding=24),
+             nn.BatchNorm1d(64),
+             self.GELU,
+             nn.MaxPool1d(kernel_size=8, stride=2, padding=4),
+             nn.Dropout(drate),
+             nn.Conv1d(64, 128, kernel_size=8, stride=1, bias=False, padding=4),
+             nn.BatchNorm1d(128),
+             self.GELU,
+             nn.Conv1d(128, 128, kernel_size=8, stride=1, bias=False, padding=4),
+             nn.BatchNorm1d(128),
+             self.GELU,
+             nn.MaxPool1d(kernel_size=4, stride=4, padding=2),
+         )
+
+         self.features2 = nn.Sequential(
+             nn.Conv1d(1, 64, kernel_size=400, stride=50, bias=False, padding=200),
+             nn.BatchNorm1d(64),
+             self.GELU,
+             nn.MaxPool1d(kernel_size=4, stride=2, padding=2),
+             nn.Dropout(drate),
+             nn.Conv1d(
+                 64, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3
+             ),
+             nn.BatchNorm1d(128),
+             self.GELU,
+             nn.Conv1d(
+                 128, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3
+             ),
+             nn.BatchNorm1d(128),
+             self.GELU,
+             nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
+         )
+
+         self.dropout = nn.Dropout(drate)
+         self.inplanes = 128
+         self.AFR = self._make_layer(
+             _SEBasicBlock, after_reduced_cnn_size, 1, activate=activation_se
+         )
+
+     def _make_layer(
+         self, block, planes, blocks, stride=1, activate: nn.Module = nn.ReLU
+     ):  # makes residual SE block
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv1d(
+                     self.inplanes,
+                     planes * block.expansion,
+                     kernel_size=1,
+                     stride=stride,
+                     bias=False,
+                 ),
+                 nn.BatchNorm1d(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes, activate=activate))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x1 = self.features1(x)
+         x2 = self.features2(x)
+         x_concat = torch.cat((x1, x2), dim=2)
+         x_concat = self.dropout(x_concat)
+         x_concat = self.AFR(x_concat)
+         return x_concat
+
+
+ ##########################################################################################
+
+
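The two branches operate at different temporal resolutions, and for the 30 s @ 100 Hz reference input their output lengths concatenate to the default ``d_model``. A sketch checking that arithmetic with PyTorch's documented output-length formula (the ``conv1d_len`` helper is ours, not part of braindecode):

```python
# Sketch: verify the two MRCNN branch lengths sum to d_model=80 for 3000 samples.
import math

def conv1d_len(l_in, kernel, stride=1, padding=0, dilation=1):
    # PyTorch Conv1d/MaxPool1d output-length formula
    return math.floor((l_in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)

l = 3000  # 30 s * 100 Hz
# features1: Conv(50,6,24) -> Pool(8,2,4) -> Conv(8,1,4) -> Conv(8,1,4) -> Pool(4,4,2)
for k, s, p in [(50, 6, 24), (8, 2, 4), (8, 1, 4), (8, 1, 4), (4, 4, 2)]:
    l = conv1d_len(l, k, s, p)
branch1 = l  # 64

l = 3000
# features2: Conv(400,50,200) -> Pool(4,2,2) -> Conv(7,1,3) -> Conv(7,1,3) -> Pool(2,2,1)
for k, s, p in [(400, 50, 200), (4, 2, 2), (7, 1, 3), (7, 1, 3), (2, 2, 1)]:
    l = conv1d_len(l, k, s, p)
branch2 = l  # 16

assert branch1 + branch2 == 80  # matches the default d_model at 100 Hz
```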
+ def _attention(
+     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """Implementation of scaled dot-product attention."""
+     # d_k - dimension of the query and key vectors
+     d_k = query.size(-1)
+     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
+     p_attn = F.softmax(scores, dim=-1)  # attention weights
+     output = torch.matmul(p_attn, value)  # (B, h, T, d_k)
+     return output, p_attn
+
+
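A quick numerical check of the scaled dot-product step in `_attention`, using the same ``(batch, heads, time, d_k)`` layout on random tensors (illustrative only, not part of the diff):

```python
# Sanity-check the scaled dot-product attention computation.
import math
import torch
import torch.nn.functional as F

B, h, T, d_k = 2, 5, 30, 16
q = torch.randn(B, h, T, d_k)
k = torch.randn(B, h, T, d_k)
v = torch.randn(B, h, T, d_k)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, h, T, T)
p_attn = F.softmax(scores, dim=-1)
out = torch.matmul(p_attn, v)                                   # (B, h, T, d_k)

assert out.shape == (B, h, T, d_k)
assert torch.allclose(p_attn.sum(-1), torch.ones(B, h, T))      # rows sum to 1
```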
+ class _MultiHeadedAttention(nn.Module):
+     def __init__(self, h, d_model, after_reduced_cnn_size, dropout=0.1):
+         """Take in model size and number of heads."""
+         super().__init__()
+         assert d_model % h == 0
+         self.d_per_head = d_model // h
+         self.h = h
+
+         base_conv = CausalConv1d(
+             in_channels=after_reduced_cnn_size,
+             out_channels=after_reduced_cnn_size,
+             kernel_size=7,
+             stride=1,
+         )
+         self.convs = nn.ModuleList([deepcopy(base_conv) for _ in range(3)])
+
+         self.linear = nn.Linear(d_model, d_model)
+         self.dropout = nn.Dropout(p=dropout)
+
+     def forward(self, query, key, value: torch.Tensor) -> torch.Tensor:
+         """Implements multi-head attention."""
+         nbatches = query.size(0)
+
+         query = query.view(nbatches, -1, self.h, self.d_per_head).transpose(1, 2)
+         key = (
+             self.convs[1](key)
+             .view(nbatches, -1, self.h, self.d_per_head)
+             .transpose(1, 2)
+         )
+         value = (
+             self.convs[2](value)
+             .view(nbatches, -1, self.h, self.d_per_head)
+             .transpose(1, 2)
+         )
+
+         x_raw, attn_weights = _attention(query, key, value)
+         # apply dropout to the *weights*
+         attn = self.dropout(attn_weights)
+         # recompute the weighted sum with dropped weights
+         x = torch.matmul(attn, value)
+
+         # stash the pre-dropout weights if you need them
+         self.attn = attn_weights
+
+         # merge heads and project
+         x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_per_head)
+
+         return self.linear(x)
+
+
+ class _ResidualLayerNormAttn(nn.Module):
+     """
+     A residual connection followed by a layer norm.
+     """
+
+     def __init__(self, size, dropout, fn_attn):
+         super().__init__()
+         self.norm = nn.LayerNorm(size, eps=1e-6)
+         self.dropout = nn.Dropout(dropout)
+         self.fn_attn = fn_attn
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+     ) -> torch.Tensor:
+         """Apply residual connection to any sublayer with the same size."""
+         x_norm = self.norm(x)
+
+         out = self.fn_attn(x_norm, key, value)
+
+         return x + self.dropout(out)
+
+
+ class _ResidualLayerNormFF(nn.Module):
+     def __init__(self, size, dropout, fn_ff):
+         super().__init__()
+         self.norm = nn.LayerNorm(size, eps=1e-6)
+         self.dropout = nn.Dropout(dropout)
+         self.fn_ff = fn_ff
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Apply residual connection to any sublayer with the same size."""
+         x_norm = self.norm(x)
+
+         out = self.fn_ff(x_norm)
+
+         return x + self.dropout(out)
+
+
+ class _TCE(nn.Module):
+     """
+     Transformer encoder: a stack of n identical layers.
+     """
+
+     def __init__(self, layer, n):
+         super().__init__()
+
+         self.layers = nn.ModuleList([deepcopy(layer) for _ in range(n)])
+
+         self.norm = nn.LayerNorm(layer.size, eps=1e-6)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for layer in self.layers:
+             x = layer(x)
+         return self.norm(x)
+
+
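Both residual wrappers above follow the pre-norm pattern ``x + dropout(sublayer(norm(x)))``, which keeps an unnormalized residual pathway through the stack. A minimal sketch with a linear layer standing in for the attention or feed-forward sublayer (illustrative, not part of the diff):

```python
# Pre-norm residual pattern used by the wrappers above.
import torch
from torch import nn

size = 80
norm = nn.LayerNorm(size, eps=1e-6)
drop = nn.Dropout(0.1)
sublayer = nn.Linear(size, size)   # stand-in for attention or the FFN

x = torch.randn(2, 30, size)
out = x + drop(sublayer(norm(x)))  # residual keeps the input pathway intact
assert out.shape == x.shape
```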
+ class _EncoderLayer(nn.Module):
+     """
+     An encoder layer made up of self-attention and a feed-forward layer.
+     Each of these sublayers has a residual connection and layer norm,
+     implemented by _ResidualLayerNormAttn and _ResidualLayerNormFF.
+     """
+
+     def __init__(self, size, self_attn, feed_forward, after_reduced_cnn_size, dropout):
+         super().__init__()
+         self.size = size
+         self.self_attn = self_attn
+         self.feed_forward = feed_forward
+
+         self.residual_self_attn = _ResidualLayerNormAttn(
+             size=size,
+             dropout=dropout,
+             fn_attn=self_attn,
+         )
+         self.residual_ff = _ResidualLayerNormFF(
+             size=size,
+             dropout=dropout,
+             fn_ff=feed_forward,
+         )
+
+         self.conv = CausalConv1d(
+             in_channels=after_reduced_cnn_size,
+             out_channels=after_reduced_cnn_size,
+             kernel_size=7,
+             stride=1,
+             dilation=1,
+         )
+
+     def forward(self, x_in: torch.Tensor) -> torch.Tensor:
+         """Transformer encoder layer."""
+         query = self.conv(x_in)
+         # Encoder self-attention
+         x = self.residual_self_attn(query, x_in, x_in)
+         x_ff = self.residual_ff(x)
+         return x_ff
+
+
524
+ class _PositionwiseFeedForward(nn.Module):
525
+ """Positionwise feed-forward network."""
526
+
527
+ def __init__(self, d_model, d_ff, dropout=0.1, activation: nn.Module = nn.ReLU):
528
+ super().__init__()
529
+ self.w_1 = nn.Linear(d_model, d_ff)
530
+ self.w_2 = nn.Linear(d_ff, d_model)
531
+ self.dropout = nn.Dropout(dropout)
532
+ self.activate = activation()
533
+
534
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
535
+ """Implements FFN equation."""
536
+ return self.w_2(self.dropout(self.activate(self.w_1(x))))
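Finally, the positionwise feed-forward network applies the same two-layer MLP independently at every position, so it is shape-preserving. A sketch using the 100 Hz defaults, ``d_model=80`` and ``d_ff=120`` (illustrative, not part of the diff):

```python
# Positionwise FFN sketch: same MLP applied at each position, shape preserved.
import torch
from torch import nn

d_model, d_ff = 80, 120
ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(d_ff, d_model),
)

x = torch.randn(2, 30, d_model)  # (batch, positions, d_model)
assert ffn(x).shape == x.shape   # shape-preserving
```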