braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,475 @@
# Authors: Yonghao Song <eeyhsong@gmail.com>
#
# License: BSD (3-clause)
import warnings

import torch
from einops.layers.torch import Rearrange
from torch import Tensor, nn

from braindecode.models.base import EEGModuleMixin
from braindecode.modules import FeedForwardBlock, MultiHeadAttention


class EEGConformer(EEGModuleMixin, nn.Module):
    r"""EEG Conformer from Song et al. (2022) [song2022]_.

    :bdg-success:`Convolution` :bdg-info:`Attention/Transformer`

    .. figure:: https://raw.githubusercontent.com/eeyhsong/EEG-Conformer/refs/heads/main/visualization/Fig1.png
        :align: center
        :alt: EEGConformer Architecture
        :width: 600px


    .. rubric:: Architectural Overview

    EEG-Conformer is a *convolution-first* model augmented with a *lightweight transformer
    encoder*. The end-to-end flow is:

    - (i) :class:`_PatchEmbedding` converts the continuous EEG into a compact sequence of tokens
      via a :class:`ShallowFBCSPNet`-style temporal–spatial conv stem and temporal pooling;
    - (ii) :class:`_TransformerEncoder` applies small multi-head self-attention to integrate
      longer-range temporal context across tokens;
    - (iii) a classification head (:class:`_FullyConnected` followed by a final linear layer)
      aggregates the sequence and performs the readout.

    This preserves the strong inductive biases of shallow CNN filter banks while adding
    just enough attention to capture dependencies beyond the pooling horizon [song2022]_.

    .. rubric:: Macro Components

    - :class:`_PatchEmbedding` **(shallow conv stem → tokens)**

      *Operations.*

      - A temporal convolution (:class:`torch.nn.Conv2d`) with kernel ``(1 x L_t)`` forms a
        data-driven "filter bank";
      - a spatial convolution (:class:`torch.nn.Conv2d`) with kernel ``(n_chans x 1)`` projects
        across electrodes, collapsing the channel axis into a virtual channel;
      - **Normalization:** :class:`torch.nn.BatchNorm2d`;
      - **Activation:** :class:`torch.nn.ELU`;
      - **Average pooling:** :class:`torch.nn.AvgPool2d` along time (kernel ``(1, P)`` with stride ``(1, S)``);
      - a final ``1x1`` :class:`torch.nn.Conv2d` projection.

      The result is rearranged to a token sequence ``(B, S_tokens, D)``, where ``D = n_filters_time``.

      *Interpretability/robustness.* Temporal kernels can be inspected as FIR filters;
      the spatial conv yields channel projections analogous to :class:`ShallowFBCSPNet`'s learned
      spatial filters. Temporal pooling stabilizes statistics and reduces sequence length.

    - :class:`_TransformerEncoder` **(context over temporal tokens)**

      *Operations.*

      - A stack of ``num_layers`` encoder blocks (:class:`_TransformerEncoderBlock`).
      - Each block applies LayerNorm (:class:`torch.nn.LayerNorm`),
      - multi-head self-attention (``num_heads`` heads) with dropout and a residual connection
        (:class:`MultiHeadAttention`, :class:`torch.nn.Dropout`),
      - LayerNorm (:class:`torch.nn.LayerNorm`),
      - and a 2-layer feed-forward block (≈4x expansion, :class:`torch.nn.GELU`) with dropout
        and a residual connection.

      Shapes remain ``(B, S_tokens, D)`` throughout.

      *Role.* Small attention focuses on interactions among *temporal patches* (not channels),
      extending effective receptive fields at modest cost.

    - :class:`_FullyConnected` **(aggregation + readout)**

      *Operations.*

      - Flatten the token sequence to ``(B, S_tokens·D)``;
      - MLP (:class:`torch.nn.Linear` → activation (default: :class:`torch.nn.ELU`) →
        :class:`torch.nn.Dropout` → :class:`torch.nn.Linear`);
      - a final :class:`torch.nn.Linear` maps to the classes.

      With ``return_features=True``, the token features produced by the transformer (before the
      classification head) are returned instead, e.g. for linear probing or downstream tasks.

    .. rubric:: Convolutional Details

    - **Temporal (where time-domain patterns are learned).**
      The initial ``(1 x L_t)`` conv per channel acts as a *learned filter bank* for oscillatory
      bands and transients. Subsequent **AvgPool** along time performs local integration,
      converting activations into "patches" (tokens). Pool length/stride control the
      token rate and set the lower bound on temporal context within each token.

    - **Spatial (how electrodes are processed).**
      A single conv with kernel ``(n_chans x 1)`` spans the full montage to learn spatial
      projections for each temporal feature map, collapsing the channel axis into a
      virtual channel before tokenization. This mirrors the shallow spatial step in
      :class:`ShallowFBCSPNet` (temporal filters → spatial projection → temporal condensation).

    - **Spectral (how frequency content is captured).**
      No explicit Fourier/wavelet stage is used. Spectral selectivity emerges implicitly
      from the learned temporal kernels; pooling further smooths high-frequency noise.
      The effective spectral resolution is thus governed by ``L_t`` and the pooling
      configuration.

    .. rubric:: Attention / Sequential Modules

    - **Type.** Standard multi-head self-attention (MHA) with ``num_heads`` heads over the token sequence.
    - **Shapes.** Input/output: ``(B, S_tokens, D)``; attention operates along the ``S_tokens`` axis.
    - **Role.** Re-weights and integrates evidence across pooled windows, capturing dependencies
      longer than any single token while leaving channel relationships to the convolutional stem.
      The design is intentionally *small*: attention refines rather than replaces convolutional
      feature extraction.

    .. rubric:: Additional Mechanisms

    - **Parallel with ShallowFBCSPNet.** Both begin with a learned temporal filter bank,
      spatial projection across electrodes, and early temporal condensation.
      :class:`ShallowFBCSPNet` then computes band-power (via squaring/log-variance), whereas
      EEG-Conformer applies BN/ELU and **continues with attention** over tokens to
      refine temporal context before classification.

    - **Tokenization knob.** ``pool_time_length`` and especially ``pool_time_stride`` set
      the number of tokens ``S_tokens``. Smaller strides → more tokens and higher attention
      capacity (but higher compute); larger strides → fewer tokens and stronger inductive bias.

    - **Embedding dimension = filters.** ``n_filters_time`` serves double duty as both the
      number of temporal filters in the stem and the transformer's embedding size ``D``,
      simplifying dimensional alignment.

    .. rubric:: Usage and Configuration

    - **Instantiation.** Choose ``n_filters_time`` (embedding size ``D``) and
      ``filter_time_length`` to match the rhythms of interest. Tune
      ``pool_time_length``/``pool_time_stride`` to trade temporal resolution for sequence length.
      Keep ``num_layers`` modest (e.g., 4–6) and set ``num_heads`` to divide ``D``.
      ``final_fc_length="auto"`` infers the flattened size from :class:`_PatchEmbedding`.

    Notes
    -----
    The authors recommend using data augmentation before using Conformer,
    e.g. segmentation and recombination.
    Please refer to the original paper and code for more details [ConformerCode]_.

    The model was initially tuned on 4 seconds of 250 Hz data.
    Please adjust the scale of the temporal convolutional layer
    and the pooling layer for better performance.

    .. versionadded:: 0.8

    The parameters are grouped by the part of the model in which they are
    first used, e.g. ``n_filters_time``.

    .. versionadded:: 1.1


    Parameters
    ----------
    n_filters_time : int
        Number of temporal filters; also defines the embedding size ``D``.
    filter_time_length : int
        Length of the temporal filter.
    pool_time_length : int
        Length of the temporal pooling filter.
    pool_time_stride : int
        Stride between temporal pooling windows.
    drop_prob : float
        Dropout rate of the convolutional layer.
    num_layers : int
        Number of self-attention layers.
    num_heads : int
        Number of attention heads.
    att_drop_prob : float
        Dropout rate of the self-attention layer.
    final_fc_length : int | str
        Input dimension of the fully connected layer. If ``"auto"``, it is
        inferred from ``n_times``.
    return_features : bool
        If True, the forward method returns the features before the
        classification head. Defaults to False.
    activation : type[nn.Module]
        Activation function class applied in the patch embedding and the
        fully connected head. Default is ``nn.ELU``.
    activation_transfor : type[nn.Module]
        Activation function class applied in the FeedForwardBlock module
        inside the transformer. Default is ``nn.GELU``.

    References
    ----------
    .. [song2022] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
       conformer: Convolutional transformer for EEG decoding and visualization.
       IEEE Transactions on Neural Systems and Rehabilitation Engineering,
       31, pp.710-719. https://ieeexplore.ieee.org/document/9991178
    .. [ConformerCode] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
       conformer: Convolutional transformer for EEG decoding and visualization.
       https://github.com/eeyhsong/EEG-Conformer.
    """

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        n_filters_time=40,
        filter_time_length=25,
        pool_time_length=75,
        pool_time_stride=15,
        drop_prob=0.5,
        num_layers=6,
        num_heads=10,
        att_drop_prob=0.5,
        final_fc_length="auto",
        return_features=False,
        activation: type[nn.Module] = nn.ELU,
        activation_transfor: type[nn.Module] = nn.GELU,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        self.mapping = {
            "classification_head.fc.6.weight": "final_layer.final_layer.0.weight",
            "classification_head.fc.6.bias": "final_layer.final_layer.0.bias",
        }

        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        if self.n_chans > 64:
            warnings.warn(
                "This model has only been tested with up to 64 channels; "
                "there is no guarantee that it will work with more channels.",
                UserWarning,
            )

        self.return_features = return_features

        self.patch_embedding = _PatchEmbedding(
            n_filters_time=n_filters_time,
            filter_time_length=filter_time_length,
            n_channels=self.n_chans,
            pool_time_length=pool_time_length,
            stride_avg_pool=pool_time_stride,
            drop_prob=drop_prob,
            activation=activation,
        )

        if final_fc_length == "auto":
            assert self.n_times is not None
            self.final_fc_length = self.get_fc_size()
        else:
            self.final_fc_length = final_fc_length

        self.transformer = _TransformerEncoder(
            num_layers=num_layers,
            emb_size=n_filters_time,
            num_heads=num_heads,
            att_drop=att_drop_prob,
            activation=activation_transfor,
        )

        self.fc = _FullyConnected(
            final_fc_length=self.final_fc_length, activation=activation
        )

        self.final_layer = nn.Linear(self.fc.hidden_channels, self.n_outputs)

    def forward(self, x: Tensor) -> Tensor:
        # add a singleton "virtual channel" dimension: (B, C, T) -> (B, 1, C, T)
        x = torch.unsqueeze(x, dim=1)
        x = self.patch_embedding(x)
        feature = self.transformer(x)

        if self.return_features:
            return feature

        x = self.fc(feature)
        x = self.final_layer(x)
        return x

    def get_fc_size(self):
        # run a dummy window through the patch embedding to infer the flattened size
        out = self.patch_embedding(torch.ones((1, 1, self.n_chans, self.n_times)))
        size_embedding_1 = out.shape[1]
        size_embedding_2 = out.shape[2]

        return size_embedding_1 * size_embedding_2


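Example (editor's addition, not part of the packaged file): a minimal sketch of instantiating the class above and running a forward pass, assuming the defaults shown in the signature and a made-up batch of 22-channel EEG sampled at 250 Hz for 4 s (the regime the docstring says the model was tuned on). The class numbers and shapes here are illustrative only.

import torch
from braindecode.models import EEGConformer

# dummy batch: 8 windows, 22 channels, 4 s at 250 Hz
x = torch.randn(8, 22, 1000)

model = EEGConformer(
    n_outputs=4,                # e.g. four motor-imagery classes
    n_chans=22,
    n_times=1000,
    final_fc_length="auto",     # flattened size inferred via get_fc_size()
)
scores = model(x)               # shape: (8, 4)

# with return_features=True the transformer tokens are returned instead
feature_model = EEGConformer(
    n_outputs=4, n_chans=22, n_times=1000, return_features=True
)
tokens = feature_model(x)       # shape: (8, S_tokens, n_filters_time)
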
class _PatchEmbedding(nn.Module):
    r"""Patch Embedding.

    The authors used a convolution module to capture local features,
    instead of a position embedding.

    Parameters
    ----------
    n_filters_time : int
        Number of temporal filters; also defines the embedding size.
    filter_time_length : int
        Length of the temporal filter.
    n_channels : int
        Number of channels to be used as number of spatial filters.
    pool_time_length : int
        Length of the temporal pooling filter.
    stride_avg_pool : int
        Length of stride between temporal pooling filters.
    drop_prob : float
        Dropout rate of the convolutional layer.

    Returns
    -------
    x : torch.Tensor
        The output tensor of the patch embedding layer.
    """

    def __init__(
        self,
        n_filters_time,
        filter_time_length,
        n_channels,
        pool_time_length,
        stride_avg_pool,
        drop_prob,
        activation: type[nn.Module] = nn.ELU,
    ):
        super().__init__()

        self.shallownet = nn.Sequential(
            nn.Conv2d(1, n_filters_time, (1, filter_time_length), (1, 1)),
            nn.Conv2d(n_filters_time, n_filters_time, (n_channels, 1), (1, 1)),
            nn.BatchNorm2d(num_features=n_filters_time),
            activation(),
            # pooling acts as slicing to obtain 'patches' along the
            # time dimension, as in ViT
            nn.AvgPool2d(
                kernel_size=(1, pool_time_length), stride=(1, stride_avg_pool)
            ),
            nn.Dropout(p=drop_prob),
        )

        self.projection = nn.Sequential(
            nn.Conv2d(
                n_filters_time, n_filters_time, (1, 1), stride=(1, 1)
            ),  # transpose; the conv could enhance fitting ability slightly
            Rearrange("b d_model 1 seq -> b seq d_model"),
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.shallownet(x)
        x = self.projection(x)
        return x


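Example (editor's addition): a short sketch of how the pooling parameters determine the token count ``S_tokens`` mentioned in the class docstring. The helper below is hypothetical (not part of braindecode); it simply applies the standard unpadded conv/pool output-length formula to the default EEGConformer settings.

def n_tokens(n_times, filter_time_length=25, pool_time_length=75, pool_time_stride=15):
    # the temporal conv (stride 1, no padding) shortens the time axis ...
    t_after_conv = n_times - filter_time_length + 1
    # ... and average pooling then slices it into S_tokens patches
    return (t_after_conv - pool_time_length) // pool_time_stride + 1

print(n_tokens(1000))  # 61 tokens for 4 s of 250 Hz data with the defaults
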
class _ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        # residual connection: out = fn(x) + x
        res = x
        x = self.fn(x)
        x += res
        return x


class _TransformerEncoderBlock(nn.Sequential):
    def __init__(
        self,
        emb_size,
        num_heads,
        att_drop,
        forward_expansion=4,
        activation: type[nn.Module] = nn.GELU,
    ):
        super().__init__(
            _ResidualAdd(
                nn.Sequential(
                    nn.LayerNorm(emb_size),
                    MultiHeadAttention(emb_size, num_heads, att_drop),
                    nn.Dropout(att_drop),
                )
            ),
            _ResidualAdd(
                nn.Sequential(
                    nn.LayerNorm(emb_size),
                    FeedForwardBlock(
                        emb_size,
                        expansion=forward_expansion,
                        drop_p=att_drop,
                        activation=activation,
                    ),
                    nn.Dropout(att_drop),
                )
            ),
        )


class _TransformerEncoder(nn.Sequential):
    r"""Transformer encoder module.

    Similar to the layers used in ViT.

    Parameters
    ----------
    num_layers : int
        Number of transformer encoder blocks.
    emb_size : int
        Embedding size of the transformer encoder.
    num_heads : int
        Number of attention heads.
    att_drop : float
        Dropout probability for the attention layers.

    """

    def __init__(
        self,
        num_layers,
        emb_size,
        num_heads,
        att_drop,
        activation: type[nn.Module] = nn.GELU,
    ):
        super().__init__(
            *[
                _TransformerEncoderBlock(
                    emb_size, num_heads, att_drop, activation=activation
                )
                for _ in range(num_layers)
            ]
        )


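Example (editor's addition): a minimal sketch showing that the encoder operates purely on the token sequence and preserves its shape, and that the embedding size should be divisible by the number of heads (the defaults use ``emb_size=40`` and ``num_heads=10``, as the class docstring recommends). Constructing the private _TransformerEncoder directly is done here for illustration only.

import torch

# assumes the classes defined in this file are in scope
encoder = _TransformerEncoder(num_layers=6, emb_size=40, num_heads=10, att_drop=0.5)

tokens = torch.randn(8, 61, 40)   # (B, S_tokens, D)
out = encoder(tokens)
print(out.shape)                  # torch.Size([8, 61, 40]); shape is preserved

assert 40 % 10 == 0               # emb_size must split evenly across the heads
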
class _FullyConnected(nn.Module):
    def __init__(
        self,
        final_fc_length,
        drop_prob_1=0.5,
        drop_prob_2=0.3,
        out_channels=256,
        hidden_channels=32,
        activation: type[nn.Module] = nn.ELU,
    ):
        """Fully-connected head applied after the transformer encoder.

        Parameters
        ----------
        final_fc_length : int
            Input length of the fully connected layer (flattened token sequence).
        drop_prob_1 : float
            Dropout probability for the first dropout layer.
        drop_prob_2 : float
            Dropout probability for the second dropout layer.
        out_channels : int
            Number of output channels for the first linear layer.
        hidden_channels : int
            Number of output channels for the second linear layer.
        activation : type[nn.Module]
            Activation function class. Default is ``nn.ELU``.
        """
        super().__init__()
        self.hidden_channels = hidden_channels
        self.fc = nn.Sequential(
            nn.Linear(final_fc_length, out_channels),
            activation(),
            nn.Dropout(drop_prob_1),
            nn.Linear(out_channels, hidden_channels),
            activation(),
            nn.Dropout(drop_prob_2),
        )

    def forward(self, x):
        # flatten the token sequence to (B, S_tokens * D) before the MLP
        x = x.contiguous().view(x.size(0), -1)
        out = self.fc(x)
        return out
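
Example (editor's addition, not part of the diff): a minimal sketch of training the model with braindecode's skorch-based EEGClassifier wrapper (braindecode/classifier.py in the file list above). The data, optimizer settings, and epoch counts are placeholder assumptions for illustration; consult the braindecode documentation for a full training workflow.

import numpy as np
import torch
from braindecode import EEGClassifier
from braindecode.models import EEGConformer

# dummy dataset: 128 windows of 22-channel EEG, 4 s at 250 Hz, 4 classes
X = np.random.randn(128, 22, 1000).astype(np.float32)
y = np.random.randint(0, 4, size=128).astype(np.int64)

model = EEGConformer(n_outputs=4, n_chans=22, n_times=1000)

clf = EEGClassifier(
    model,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,
    lr=1e-3,
    batch_size=16,
    max_epochs=2,        # kept short here; real training needs far more epochs
    train_split=None,    # no internal validation split in this sketch
)
clf.fit(X, y)

probabilities = clf.predict_proba(X[:4])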