braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/attn_sleep.py
@@ -0,0 +1,549 @@
+ # Authors: Divyesh Narayanan <divyesh.narayanan@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ import math
+ import warnings
+ from copy import deepcopy
+
+ import torch
+ import torch.nn.functional as F
+ from mne.utils import deprecated
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import CausalConv1d
+
+
+ class AttnSleep(EEGModuleMixin, nn.Module):
+     r"""Sleep Staging Architecture from Eldele et al. (2021) [Eldele2021]_.
+
+     :bdg-success:`Convolution` :bdg-info:`Attention/Transformer`
+
+     .. figure:: https://raw.githubusercontent.com/emadeldeen24/AttnSleep/refs/heads/main/imgs/AttnSleep.png
+        :align: center
+        :alt: AttnSleep Architecture
+
+     Attention-based neural network for sleep staging as described in
+     [Eldele2021]_. The code for the paper and this model is also available
+     at [1]_. Takes single-channel EEG as input.
+     The feature extraction module is based on a multi-resolution convolutional
+     neural network (MRCNN) and adaptive feature recalibration (AFR).
+     The second module is the temporal context encoder (TCE), which leverages a
+     multi-head attention mechanism to capture the temporal dependencies among
+     the extracted features.
+
+     Warning: this model was designed for 30-second signals sampled at 100 Hz
+     or 125 Hz (in which case the reference architecture from [1]_, validated
+     on the SHHS dataset [2]_, is used). Any other input is likely to make the
+     model perform in unintended ways.
+
+     Parameters
+     ----------
+     n_tce : int
+         Number of TCE clones.
+     d_model : int
+         Input dimension for the TCE.
+         Also the input dimension of the first FC layer in the feed-forward
+         block and the output dimension of its second FC layer.
+         Increase for higher sampling rate/signal length.
+         It should be divisible by ``n_attn_heads``.
+     d_ff : int
+         Output dimension of the first FC layer in the feed-forward block and
+         the input dimension of its second FC layer.
+     n_attn_heads : int
+         Number of attention heads. It should be a factor of ``d_model``.
+     drop_prob : float
+         Dropout rate in the positionwise feed-forward layer and the TCE layers.
+     after_reduced_cnn_size : int
+         Number of output channels produced by the convolution in the AFR module.
+     return_feats : bool
+         If True, return the features, i.e. the output of the feature extractor
+         (before the final linear layer). If False, pass the features through
+         the final linear layer.
+     n_classes : int
+         Alias for ``n_outputs``.
+     input_size_s : float
+         Alias for ``input_window_seconds``.
+     activation : nn.Module, default=nn.ReLU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
+     activation_mrcnn : nn.Module, default=nn.GELU
+         Activation function class to apply in the multi-resolution CNN (MRCNN)
+         module. Should be a PyTorch activation module class like ``nn.ReLU``
+         or ``nn.GELU``. Default is ``nn.GELU``.
+
+     References
+     ----------
+     .. [Eldele2021] E. Eldele et al., "An Attention-Based Deep Learning Approach
+        for Sleep Stage Classification With Single-Channel EEG," in IEEE
+        Transactions on Neural Systems and Rehabilitation Engineering, vol. 29,
+        pp. 809-818, 2021, doi: 10.1109/TNSRE.2021.3076234.
+
+     .. [1] https://github.com/emadeldeen24/AttnSleep
+
+     .. [2] https://sleepdata.org/datasets/shhs
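+
+     Examples
+     --------
+     Minimal usage sketch (illustrative values, assuming the reference
+     configuration of 30-second single-channel windows at 100 Hz):
+
+     >>> import torch
+     >>> model = AttnSleep(n_outputs=5, sfreq=100, input_window_seconds=30)
+     >>> x = torch.randn(4, 1, 3000)  # (batch, n_chans=1, n_times=30 s * 100 Hz)
+     >>> logits = model(x)  # shape: (4, 5)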
+     """
+
+     def __init__(
+         self,
+         sfreq=None,
+         n_tce=2,
+         d_model=80,
+         d_ff=120,
+         n_attn_heads=5,
+         drop_prob=0.1,
+         activation_mrcnn: type[nn.Module] = nn.GELU,
+         activation: type[nn.Module] = nn.ReLU,
+         input_window_seconds=None,
+         n_outputs=None,
+         after_reduced_cnn_size=30,
+         return_feats=False,
+         chs_info=None,
+         n_chans=None,
+         n_times=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.mapping = {
+             "fc.weight": "final_layer.weight",
+             "fc.bias": "final_layer.bias",
+         }
+
+         if not (
+             (self.input_window_seconds == 30 and self.sfreq == 100 and d_model == 80)
+             or (
+                 self.input_window_seconds == 30
+                 and self.sfreq == 125
+                 and d_model == 100
+             )
+         ):
+             warnings.warn(
+                 "This model was designed for input windows of 30 s at 100 Hz "
+                 "(with d_model=80) or at 125 Hz (with d_model=100); any other "
+                 "configuration may cause errors or make the model perform in "
+                 "unintended ways.",
+                 UserWarning,
+             )
+
+         # Default MRCNN kernel size, tuned for sfreq = 100 Hz.
+         kernel_size = 7
+
+         if self.sfreq == 125:
+             kernel_size = 6
+
+         mrcnn = _MRCNN(
+             after_reduced_cnn_size,
+             kernel_size,
+             activation=activation_mrcnn,
+             activation_se=activation,
+         )
+         attn = _MultiHeadedAttention(n_attn_heads, d_model, after_reduced_cnn_size)
+         ff = _PositionwiseFeedForward(d_model, d_ff, drop_prob, activation=activation)
+         tce = _TCE(
+             _EncoderLayer(
+                 d_model, deepcopy(attn), deepcopy(ff), after_reduced_cnn_size, drop_prob
+             ),
+             n_tce,
+         )
+
+         self.feature_extractor = nn.Sequential(mrcnn, tce)
+         self.len_last_layer = self._len_last_layer(self.n_times)
+         self.return_feats = return_feats
+
+         # TODO: Add new way to handle return features
+         """if return_feats:
+             raise ValueError("return_feat == True is not accepted anymore")"""
+         if not return_feats:
+             self.final_layer = nn.Linear(
+                 d_model * after_reduced_cnn_size, self.n_outputs
+             )
+
+     def _len_last_layer(self, input_size):
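+         # Dry-run a dummy single-channel window through the feature extractor
+         # to infer the flattened feature size without hard-coding it.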
+         self.feature_extractor.eval()
+         with torch.no_grad():
+             out = self.feature_extractor(torch.zeros(1, 1, input_size))
+         self.feature_extractor.train()
+         return len(out.flatten())
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Batch of EEG windows of shape (batch_size, n_channels, n_times).
+         """
+         encoded_features = self.feature_extractor(x)
+         encoded_features = encoded_features.contiguous().view(
+             encoded_features.shape[0], -1
+         )
+
+         if self.return_feats:
+             return encoded_features
+
+         return self.final_layer(encoded_features)
+
+
+ class _SELayer(nn.Module):
+     def __init__(self, channel, reduction=16, activation=nn.ReLU):
+         super().__init__()
+         self.avg_pool = nn.AdaptiveAvgPool1d(1)
+         self.fc = nn.Sequential(
+             nn.Linear(channel, channel // reduction, bias=False),
+             activation(inplace=True),
+             nn.Linear(channel // reduction, channel, bias=False),
+             nn.Sigmoid(),
+         )
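+         # The layers above implement squeeze-and-excitation (Hu et al., 2018):
+         # global average pooling squeezes each channel to a scalar, and the
+         # bottleneck MLP (reduction factor ``reduction``) with a sigmoid yields
+         # per-channel weights in (0, 1) that rescale the input in forward().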
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the SE layer.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (batch_size, channel, length).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor after applying the SE recalibration.
+         """
+         b, c, _ = x.size()
+         y = self.avg_pool(x).view(b, c)
+         y = self.fc(y).view(b, c, 1)
+         return x * y.expand_as(x)
+
+
+ class _SEBasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(
+         self,
+         inplanes,
+         planes,
+         stride=1,
+         downsample=None,
+         activation: type[nn.Module] = nn.ReLU,
+         *,
+         reduction=16,
+     ):
+         super().__init__()
+         # NB: `stride` is passed positionally as the kernel size here,
+         # matching the reference implementation.
+         self.conv1 = nn.Conv1d(inplanes, planes, stride)
+         self.bn1 = nn.BatchNorm1d(planes)
+         self.relu = activation(inplace=True)
+         self.conv2 = nn.Conv1d(planes, planes, 1)
+         self.bn2 = nn.BatchNorm1d(planes)
+         self.se = _SELayer(planes, reduction)
+         self.downsample = downsample
+         self.stride = stride
+         self.features = nn.Sequential(
+             self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.se
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the SE basic block.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (batch_size, n_chans, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor after the residual connection and SE recalibration.
+         """
+         residual = x
+         out = self.features(x)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class _MRCNN(nn.Module):
+     def __init__(
+         self,
+         after_reduced_cnn_size,
+         kernel_size=7,
+         activation: type[nn.Module] = nn.GELU,
+         activation_se: type[nn.Module] = nn.ReLU,
+     ):
+         super().__init__()
+         drate = 0.5
+         self.GELU = activation()
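+         # Two parallel feature branches at different temporal resolutions:
+         # features1 starts with a small kernel (50 samples, roughly 0.5 s at
+         # 100 Hz) and features2 with a wide kernel (400 samples, roughly 4 s),
+         # so the branches capture higher- and lower-frequency content.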
+         self.features1 = nn.Sequential(
+             nn.Conv1d(1, 64, kernel_size=50, stride=6, bias=False, padding=24),
+             nn.BatchNorm1d(64),
+             self.GELU,
+             nn.MaxPool1d(kernel_size=8, stride=2, padding=4),
+             nn.Dropout(drate),
+             nn.Conv1d(64, 128, kernel_size=8, stride=1, bias=False, padding=4),
+             nn.BatchNorm1d(128),
+             self.GELU,
+             nn.Conv1d(128, 128, kernel_size=8, stride=1, bias=False, padding=4),
+             nn.BatchNorm1d(128),
+             self.GELU,
+             nn.MaxPool1d(kernel_size=4, stride=4, padding=2),
+         )
+
+         self.features2 = nn.Sequential(
+             nn.Conv1d(1, 64, kernel_size=400, stride=50, bias=False, padding=200),
+             nn.BatchNorm1d(64),
+             self.GELU,
+             nn.MaxPool1d(kernel_size=4, stride=2, padding=2),
+             nn.Dropout(drate),
+             nn.Conv1d(
+                 64, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3
+             ),
+             nn.BatchNorm1d(128),
+             self.GELU,
+             nn.Conv1d(
+                 128, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3
+             ),
+             nn.BatchNorm1d(128),
+             self.GELU,
+             nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
+         )
+
+         self.dropout = nn.Dropout(drate)
+         self.inplanes = 128
+         self.AFR = self._make_layer(
+             _SEBasicBlock, after_reduced_cnn_size, 1, activate=activation_se
+         )
+
+     def _make_layer(
+         self, block, planes, blocks, stride=1, activate: type[nn.Module] = nn.ReLU
+     ):  # builds the residual SE block(s)
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv1d(
+                     self.inplanes,
+                     planes * block.expansion,
+                     kernel_size=1,
+                     stride=stride,
+                     bias=False,
+                 ),
+                 nn.BatchNorm1d(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(
+             block(self.inplanes, planes, stride, downsample, activation=activate)
+         )
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(block(self.inplanes, planes, activation=activate))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x1 = self.features1(x)
+         x2 = self.features2(x)
+         x_concat = torch.cat((x1, x2), dim=2)
+         x_concat = self.dropout(x_concat)
+         x_concat = self.AFR(x_concat)
+         return x_concat
+
+
+ ##########################################################################################
+
+
+ def _attention(
+     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """Implementation of scaled dot-product attention."""
+     # d_k - dimension of the query and key vectors
+     d_k = query.size(-1)
+     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
+     p_attn = F.softmax(scores, dim=-1)  # attention weights
+     output = torch.matmul(p_attn, value)  # (B, h, T, d_k)
+     return output, p_attn
+
+
+ class _MultiHeadedAttention(nn.Module):
+     def __init__(self, h, d_model, after_reduced_cnn_size, dropout=0.1):
+         """Take in model size and number of heads."""
+         super().__init__()
+         assert d_model % h == 0
+         self.d_per_head = d_model // h
+         self.h = h
+
+         base_conv = CausalConv1d(
+             in_channels=after_reduced_cnn_size,
+             out_channels=after_reduced_cnn_size,
+             kernel_size=7,
+             stride=1,
+         )
+         self.convs = nn.ModuleList([deepcopy(base_conv) for _ in range(3)])
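+         # Note: convs[0] is not applied in forward(): the query arrives
+         # already convolved from _EncoderLayer, so only convs[1] (key) and
+         # convs[2] (value) are used here.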
+
+         self.linear = nn.Linear(d_model, d_model)
+         self.dropout = nn.Dropout(p=dropout)
+
+     def forward(
+         self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+     ) -> torch.Tensor:
+         """Implements multi-head attention."""
+         nbatches = query.size(0)
+
+         query = query.view(nbatches, -1, self.h, self.d_per_head).transpose(1, 2)
+         key = (
+             self.convs[1](key)
+             .view(nbatches, -1, self.h, self.d_per_head)
+             .transpose(1, 2)
+         )
+         value = (
+             self.convs[2](value)
+             .view(nbatches, -1, self.h, self.d_per_head)
+             .transpose(1, 2)
+         )
+
+         # compute the attention weights (the pre-dropout output is discarded)
+         _, attn_weights = _attention(query, key, value)
+         # apply dropout to the *weights*
+         attn = self.dropout(attn_weights)
+         # recompute the weighted sum with the dropped weights
+         x = torch.matmul(attn, value)
+
+         # stash the pre-dropout weights in case they are needed later
+         self.attn = attn_weights
+
+         # merge heads and project
+         x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_per_head)
+
+         return self.linear(x)
+
+
+ class _ResidualLayerNormAttn(nn.Module):
+     r"""Pre-norm residual connection around an attention sublayer,
+     i.e. ``x + dropout(fn_attn(norm(x), key, value))``."""
+
+     def __init__(self, size, dropout, fn_attn):
+         super().__init__()
+         self.norm = nn.LayerNorm(size, eps=1e-6)
+         self.dropout = nn.Dropout(dropout)
+         self.fn_attn = fn_attn
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+     ) -> torch.Tensor:
+         """Apply the residual connection to a sublayer of the same size."""
+         x_norm = self.norm(x)
+
+         out = self.fn_attn(x_norm, key, value)
+
+         return x + self.dropout(out)
+
+
+ class _ResidualLayerNormFF(nn.Module):
+     def __init__(self, size, dropout, fn_ff):
+         super().__init__()
+         self.norm = nn.LayerNorm(size, eps=1e-6)
+         self.dropout = nn.Dropout(dropout)
+         self.fn_ff = fn_ff
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Apply the residual connection to a sublayer of the same size."""
+         x_norm = self.norm(x)
+
+         out = self.fn_ff(x_norm)
+
+         return x + self.dropout(out)
+
+
+ class _TCE(nn.Module):
+     r"""
+     Transformer encoder.
+
+     A stack of ``n`` identical encoder layers.
+     """
+
+     def __init__(self, layer, n):
+         super().__init__()
+
+         self.layers = nn.ModuleList([deepcopy(layer) for _ in range(n)])
+
+         self.norm = nn.LayerNorm(layer.size, eps=1e-6)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for layer in self.layers:
+             x = layer(x)
+         return self.norm(x)
+
+
+ class _EncoderLayer(nn.Module):
+     r"""
+     An encoder layer.
+
+     Made up of self-attention and a feed-forward layer.
+     Each of these sublayers has a residual connection and layer norm,
+     implemented by ``_ResidualLayerNormAttn`` and ``_ResidualLayerNormFF``.
+     """
+
+     def __init__(self, size, self_attn, feed_forward, after_reduced_cnn_size, dropout):
+         super().__init__()
+         self.size = size
+         self.self_attn = self_attn
+         self.feed_forward = feed_forward
+
+         self.residual_self_attn = _ResidualLayerNormAttn(
+             size=size,
+             dropout=dropout,
+             fn_attn=self_attn,
+         )
+         self.residual_ff = _ResidualLayerNormFF(
+             size=size,
+             dropout=dropout,
+             fn_ff=feed_forward,
+         )
+
+         self.conv = CausalConv1d(
+             in_channels=after_reduced_cnn_size,
+             out_channels=after_reduced_cnn_size,
+             kernel_size=7,
+             stride=1,
+             dilation=1,
+         )
+
+     def forward(self, x_in: torch.Tensor) -> torch.Tensor:
+         """Run one transformer encoder layer."""
+         # the query is produced by a causal convolution of the input
+         query = self.conv(x_in)
+         # encoder self-attention, followed by the feed-forward sublayer
+         x = self.residual_self_attn(query, x_in, x_in)
+         x_ff = self.residual_ff(x)
+         return x_ff
+
+
+ class _PositionwiseFeedForward(nn.Module):
+     r"""Position-wise feed-forward network."""
+
+     def __init__(
+         self, d_model, d_ff, dropout=0.1, activation: type[nn.Module] = nn.ReLU
+     ):
+         super().__init__()
+         self.w_1 = nn.Linear(d_model, d_ff)
+         self.w_2 = nn.Linear(d_ff, d_model)
+         self.dropout = nn.Dropout(dropout)
+         self.activate = activation()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Implements the FFN equation: ``w_2(dropout(act(w_1(x))))``."""
+         return self.w_2(self.dropout(self.activate(self.w_1(x))))
+
+
+ @deprecated(
+     "`SleepStagerEldele2021` was renamed to `AttnSleep` in v1.12 to follow "
+     "the original authors' naming; this alias will be removed in v1.14."
+ )
+ class SleepStagerEldele2021(AttnSleep):
+     r"""Deprecated alias for :class:`AttnSleep`."""
+
+     pass