braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,377 @@
# Authors: Tao Yang <sheeptao@outlook.com>
#          Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
#
from typing import Type, Union

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

from braindecode.models.base import EEGModuleMixin


class MSVTNet(EEGModuleMixin, nn.Module):
    r"""MSVTNet model from Liu et al. (2024) [msvt2024]_.

    :bdg-success:`Convolution` :bdg-secondary:`Recurrent` :bdg-info:`Attention/Transformer`

    This model implements a multi-scale convolutional transformer network
    for EEG signal classification, as described in [msvt2024]_.

    .. figure:: https://raw.githubusercontent.com/SheepTAO/MSVTNet/refs/heads/main/MSVTNet_Arch.png
        :align: center
        :alt: MSVTNet Architecture

    Parameters
    ----------
    n_filters_list : tuple[int, ...], optional
        Number of filters for each TSConv block, by default (9, 9, 9, 9).
    conv1_kernels_size : tuple[int, ...], optional
        Kernel sizes for the first convolution in each TSConv block,
        by default (15, 31, 63, 125).
    conv2_kernel_size : int, optional
        Kernel size for the second convolution in TSConv blocks, by default 15.
    depth_multiplier : int, optional
        Depth multiplier for the depthwise convolution, by default 2.
    pool1_size : int, optional
        Pooling size for the first pooling layer in TSConv blocks, by default 8.
    pool2_size : int, optional
        Pooling size for the second pooling layer in TSConv blocks, by default 7.
    drop_prob : float, optional
        Dropout probability for the convolutional layers, by default 0.3.
    num_heads : int, optional
        Number of attention heads in the transformer encoder, by default 8.
    ffn_expansion_factor : float, optional
        Ratio used to compute the feedforward dimension in the transformer,
        by default 1.
    att_drop_prob : float, optional
        Dropout probability for the transformer, by default 0.5.
    num_layers : int, optional
        Number of transformer encoder layers, by default 2.
    activation : Type[nn.Module], optional
        Activation function class to use, by default nn.ELU.
    return_features : bool, optional
        Whether to return predictions from the branch classifiers,
        by default False.

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been
    checked by the original authors; it is a reimplementation based on the
    original code [msvt2024code]_.

    References
    ----------
    .. [msvt2024] Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision
        Transformer Neural Network for EEG-Based Motor Imagery Decoding.
        IEEE Journal of Biomedical and Health Informatics.
    .. [msvt2024code] Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision
        Transformer Neural Network for EEG-Based Motor Imagery Decoding.
        Source code: https://github.com/SheepTAO/MSVTNet
    """

    def __init__(
        self,
        # braindecode parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        chs_info=None,
        # Model's parameters
        n_filters_list: tuple[int, ...] = (9, 9, 9, 9),
        conv1_kernels_size: tuple[int, ...] = (15, 31, 63, 125),
        conv2_kernel_size: int = 15,
        depth_multiplier: int = 2,
        pool1_size: int = 8,
        pool2_size: int = 7,
        drop_prob: float = 0.3,
        num_heads: int = 8,
        ffn_expansion_factor: float = 1,
        att_drop_prob: float = 0.5,
        num_layers: int = 2,
        activation: Type[nn.Module] = nn.ELU,
        return_features: bool = False,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.return_features = return_features
        assert len(n_filters_list) == len(conv1_kernels_size), (
            "The lengths of n_filters_list and conv1_kernels_size should be equal."
        )

        self.ensure_dim = Rearrange("batch chans time -> batch 1 chans time")
        # One temporal-spatial convolution branch per kernel size.
        self.mstsconv = nn.ModuleList(
            [
                nn.Sequential(
                    _TSConv(
                        self.n_chans,
                        n_filters_list[b],
                        conv1_kernels_size[b],
                        conv2_kernel_size,
                        depth_multiplier,
                        pool1_size,
                        pool2_size,
                        drop_prob,
                        activation,
                    ),
                    Rearrange("batch channels 1 time -> batch time channels"),
                )
                for b in range(len(n_filters_list))
            ]
        )
        # Infer the input dimension of each branch classifier with a dummy
        # forward pass through the convolutional branches.
        branch_linear_in = self._forward_flatten(cat=False)
        self.branch_head = nn.ModuleList(
            [
                _DenseLayers(branch_linear_in[b].shape[1], self.n_outputs)
                for b in range(len(n_filters_list))
            ]
        )

        seq_len, d_model = self._forward_mstsconv().shape[1:3]  # type: ignore
        self.transformer = _Transformer(
            seq_len,
            d_model,
            num_heads,
            ffn_expansion_factor,
            att_drop_prob,
            num_layers,
        )

        linear_in = self._forward_flatten().shape[1]  # type: ignore
        self.flatten_layer = nn.Flatten()
        self.final_layer = nn.Linear(linear_in, self.n_outputs)

    def _forward_mstsconv(
        self, cat: bool = True
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        # Dummy forward pass used at build time to infer feature shapes.
        x = torch.randn(1, 1, self.n_chans, self.n_times)
        x = [tsconv(x) for tsconv in self.mstsconv]
        if cat:
            x = torch.cat(x, dim=2)
        return x

    def _forward_flatten(
        self, cat: bool = True
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        x = self._forward_mstsconv(cat)
        if cat:
            x = self.transformer(x)
            x = x.flatten(start_dim=1, end_dim=-1)
        else:
            x = [feat.flatten(start_dim=1, end_dim=-1) for feat in x]
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: (batch, n_chans, n_times)
        x = self.ensure_dim(x)
        # x shape: (batch, 1, n_chans, n_times)
        x_list = [tsconv(x) for tsconv in self.mstsconv]
        # x_list holds one tensor per branch, each of shape
        # (batch_size, seq_len, branch_embed_dim)
        branch_preds = [
            branch(x_list[idx]) for idx, branch in enumerate(self.branch_head)
        ]
        # branch_preds holds one tensor per branch, each of shape
        # (batch_size, num_classes)
        x = torch.stack(x_list, dim=2)
        x = x.view(x.size(0), x.size(1), -1)
        # x shape after concatenating the branches:
        # (batch_size, seq_len, total_embed_dim)
        x = self.transformer(x)
        # x shape after the transformer: (batch_size, embed_dim)

        x = self.final_layer(x)
        # x shape after the final layer: (batch_size, num_classes)
        if self.return_features:
            # stacked branch predictions: (n_branches, batch_size, num_classes)
            return torch.stack(branch_preds)
        return x


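# A minimal usage sketch (illustrative only; the channel, class, and window
# sizes below are assumptions, not values from this file). With the defaults,
# MSVTNet maps a (batch, n_chans, n_times) EEG window to (batch, n_outputs)
# class scores:
#
#     model = MSVTNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)
#     x = torch.randn(8, 22, 1000)  # hypothetical batch of 8 windows
#     scores = model(x)             # -> shape (8, 4)
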
class _TSConv(nn.Sequential):
    r"""
    Time-Distributed Separable Convolution block.

    The architecture consists of:

    - **Temporal Convolution**
    - **Batch Normalization**
    - **Depthwise Spatial Convolution**
    - **Batch Normalization**
    - **Activation Function**
    - **First Pooling Layer**
    - **Dropout**
    - **Depthwise Temporal Convolution**
    - **Batch Normalization**
    - **Activation Function**
    - **Second Pooling Layer**
    - **Dropout**

    Parameters
    ----------
    n_channels : int
        Number of input channels (EEG channels).
    n_filters : int
        Number of filters for the convolution layers.
    conv1_kernel_size : int
        Kernel size for the first convolution layer.
    conv2_kernel_size : int
        Kernel size for the second convolution layer.
    depth_multiplier : int
        Depth multiplier for the depthwise convolution.
    pool1_size : int
        Kernel size for the first pooling layer.
    pool2_size : int
        Kernel size for the second pooling layer.
    drop_prob : float
        Dropout probability.
    activation : Type[nn.Module], optional
        Activation function class to use, by default nn.ELU.
    """

    def __init__(
        self,
        n_channels: int,
        n_filters: int,
        conv1_kernel_size: int,
        conv2_kernel_size: int,
        depth_multiplier: int,
        pool1_size: int,
        pool2_size: int,
        drop_prob: float,
        activation: Type[nn.Module] = nn.ELU,
    ):
        super().__init__(
            # Temporal convolution along the time axis.
            nn.Conv2d(
                in_channels=1,
                out_channels=n_filters,
                kernel_size=(1, conv1_kernel_size),
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(n_filters),
            # Depthwise spatial convolution across all EEG channels.
            nn.Conv2d(
                in_channels=n_filters,
                out_channels=n_filters * depth_multiplier,
                kernel_size=(n_channels, 1),
                groups=n_filters,
                bias=False,
            ),
            nn.BatchNorm2d(n_filters * depth_multiplier),
            activation(),
            nn.AvgPool2d(kernel_size=(1, pool1_size)),
            nn.Dropout(drop_prob),
            # Depthwise temporal convolution.
            nn.Conv2d(
                in_channels=n_filters * depth_multiplier,
                out_channels=n_filters * depth_multiplier,
                kernel_size=(1, conv2_kernel_size),
                padding="same",
                groups=n_filters * depth_multiplier,
                bias=False,
            ),
            nn.BatchNorm2d(n_filters * depth_multiplier),
            activation(),
            nn.AvgPool2d(kernel_size=(1, pool2_size)),
            nn.Dropout(drop_prob),
        )


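# Shape sketch for one _TSConv branch under the model defaults (an assumed
# n_times=1000 for illustration): average pooling with pool1_size=8 and
# pool2_size=7 shrinks the time axis to 1000 // 8 // 7 = 17 tokens, and the
# per-branch embedding dimension is n_filters * depth_multiplier = 9 * 2 = 18.
# Concatenating the four default branches gives d_model = 4 * 18 = 72 for the
# transformer.
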
class _PositionalEncoding(nn.Module):
    r"""
    Positional encoding module that adds learnable positional embeddings.

    Parameters
    ----------
    seq_length : int
        Sequence length.
    d_model : int
        Dimensionality of the model.
    """

    def __init__(self, seq_length: int, d_model: int) -> None:
        super().__init__()
        self.seq_length = seq_length
        self.d_model = d_model
        self.pe = nn.Parameter(torch.zeros(1, seq_length, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe
        return x


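# Broadcast sketch (an illustration with assumed sizes, continuing the default
# shapes above; 18 = 17 time tokens + 1 class token):
#
#     pe = _PositionalEncoding(seq_length=18, d_model=72)
#     tokens = torch.randn(8, 18, 72)
#     out = pe(tokens)  # same shape; pe.pe broadcasts over the batch dim
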
class _Transformer(nn.Module):
    r"""
    Transformer encoder module with a learnable class token and positional
    encoding.

    Parameters
    ----------
    seq_length : int
        Sequence length of the input.
    d_model : int
        Dimensionality of the model.
    num_heads : int
        Number of heads in the multihead attention.
    ffn_expansion_factor : float
        Ratio used to compute the dimension of the feedforward network.
    drop_prob : float, optional
        Dropout probability, by default 0.5.
    num_layers : int, optional
        Number of transformer encoder layers, by default 4.
    """

    def __init__(
        self,
        seq_length: int,
        d_model: int,
        num_heads: int,
        ffn_expansion_factor: float,
        drop_prob: float = 0.5,
        num_layers: int = 4,
    ) -> None:
        super().__init__()
        self.cls_embedding = nn.Parameter(torch.zeros(1, 1, d_model))
        # One extra position for the prepended class token.
        self.pos_embedding = _PositionalEncoding(seq_length + 1, d_model)

        dim_ff = int(d_model * ffn_expansion_factor)
        self.dropout = nn.Dropout(drop_prob)
        self.trans = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model,
                num_heads,
                dim_ff,
                drop_prob,
                batch_first=True,
                norm_first=True,
            ),
            num_layers,
            norm=nn.LayerNorm(d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        # Prepend the class token and add positional embeddings.
        x = torch.cat((self.cls_embedding.expand(batch_size, -1, -1), x), dim=1)
        x = self.pos_embedding(x)
        x = self.dropout(x)
        # Return only the class-token representation.
        return self.trans(x)[:, 0]


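# Readout sketch (an illustration with assumed sizes matching the defaults):
# the encoder consumes (batch, seq_len, d_model) tokens and returns only the
# class-token embedding.
#
#     trans = _Transformer(seq_length=17, d_model=72, num_heads=8,
#                          ffn_expansion_factor=1, drop_prob=0.5, num_layers=2)
#     tokens = torch.randn(8, 17, 72)
#     cls = trans(tokens)  # -> shape (8, 72)
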
class _DenseLayers(nn.Sequential):
    r"""
    Final classification layers.

    Parameters
    ----------
    linear_in : int
        Input dimension to the linear layer.
    n_classes : int
        Number of output classes.
    """

    def __init__(self, linear_in: int, n_classes: int):
        super().__init__(
            nn.Flatten(),
            nn.Linear(linear_in, n_classes),
        )