braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1316 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Labram module.
|
|
3
|
+
Authors: Wei-Bang Jiang
|
|
4
|
+
Bruno Aristimunha <b.aristimunha@gmail.com>
|
|
5
|
+
Matthew Chen <matt.chen4260@gmail.com>
|
|
6
|
+
License: BSD 3 clause
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from collections import OrderedDict
|
|
10
|
+
from warnings import warn
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
from einops import rearrange
|
|
15
|
+
from einops.layers.torch import Rearrange
|
|
16
|
+
from torch.nn.init import trunc_normal_
|
|
17
|
+
|
|
18
|
+
from braindecode.functional import rescale_parameter
|
|
19
|
+
from braindecode.models.base import EEGModuleMixin
|
|
20
|
+
from braindecode.modules import MLP, DropPath
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Labram(EEGModuleMixin, nn.Module):
|
|
24
|
+
r"""Labram from Jiang, W B et al (2024) [Jiang2024]_.
|
|
25
|
+
|
|
26
|
+
:bdg-success:`Convolution` :bdg-danger:`Foundation Model`
|
|
27
|
+
|
|
28
|
+
.. figure:: https://arxiv.org/html/2405.18765v1/x1.png
|
|
29
|
+
:align: center
|
|
30
|
+
:alt: Labram Architecture.
|
|
31
|
+
|
|
32
|
+
Large Brain Model for Learning Generic Representations with Tremendous
|
|
33
|
+
EEG Data in BCI from [Jiang2024]_.
|
|
34
|
+
|
|
35
|
+
This is an **adaptation** of the code [Code2024]_ from the Labram model.
|
|
36
|
+
|
|
37
|
+
The model is a transformer architecture with **strong** inspiration from
|
|
38
|
+
BEiTv2 [BeiTv2]_.
|
|
39
|
+
|
|
40
|
+
The models can be used in two modes:
|
|
41
|
+
|
|
42
|
+
- Neural Tokenizer: Designed to produce embeddings (e.g. for classification).
|
|
43
|
+
- Neural Decoder: To extract the amplitude and phase outputs with a VQSNP.
|
|
44
|
+
|
|
45
|
+
The braindecode modification allows the model to be used
|
|
46
|
+
with an input shape of (batch, n_chans, n_times), if neural tokenizer
|
|
47
|
+
equals True. The original implementation uses (batch, n_chans, n_patches,
|
|
48
|
+
patch_size) as input with static segmentation of the input data.
|
|
49
|
+
|
|
50
|
+
The models have the following sequence of steps::
|
|
51
|
+
|
|
52
|
+
if neural tokenizer:
|
|
53
|
+
- SegmentPatch: Segment the input data in patches;
|
|
54
|
+
- TemporalConv: Apply a temporal convolution to the segmented data;
|
|
55
|
+
- Residual adding cls, temporal and position embeddings (optional);
|
|
56
|
+
- WindowsAttentionBlock: Apply a windows attention block to the data;
|
|
57
|
+
- LayerNorm: Apply layer normalization to the data;
|
|
58
|
+
- Linear: A linear head layer to transform the data into classes.
|
|
59
|
+
|
|
60
|
+
else:
|
|
61
|
+
- PatchEmbed: Apply a patch embedding to the input data;
|
|
62
|
+
- Residual adding cls, temporal and position embeddings (optional);
|
|
63
|
+
- WindowsAttentionBlock: Apply a windows attention block to the data;
|
|
64
|
+
- LayerNorm: Apply layer normalization to the data;
|
|
65
|
+
- Linear: A linear head layer to transform the data into classes.
|
|
66
|
+
|
|
67
|
+
.. important::
|
|
68
|
+
**Pre-trained Weights Available**
|
|
69
|
+
|
|
70
|
+
This model has pre-trained weights available on the Hugging Face Hub.
|
|
71
|
+
You can load them using:
|
|
72
|
+
|
|
73
|
+
.. code-block:: python
|
|
74
|
+
|
|
75
|
+
from braindecode.models import Labram
|
|
76
|
+
|
|
77
|
+
# Load pre-trained model from Hugging Face Hub
|
|
78
|
+
model = Labram.from_pretrained("braindecode/labram-pretrained")
|
|
79
|
+
|
|
80
|
+
To push your own trained model to the Hub:
|
|
81
|
+
|
|
82
|
+
.. code-block:: python
|
|
83
|
+
|
|
84
|
+
# After training your model
|
|
85
|
+
model.push_to_hub(
|
|
86
|
+
repo_id="username/my-labram-model", commit_message="Upload trained Labram model"
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
Requires installing ``braindecode[hug]`` for Hub integration.
|
|
90
|
+
|
|
91
|
+
.. versionadded:: 0.9
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
Examples
|
|
95
|
+
--------
|
|
96
|
+
Load pre-trained weights::
|
|
97
|
+
|
|
98
|
+
>>> import torch
|
|
99
|
+
>>> from braindecode.models import Labram
|
|
100
|
+
>>> model = Labram(n_times=1600, n_chans=64, n_outputs=4)
|
|
101
|
+
>>> url = "https://huggingface.co/braindecode/Labram-Braindecode/blob/main/braindecode_labram_base.pt"
|
|
102
|
+
>>> state = torch.hub.load_state_dict_from_url(url, progress=True)
|
|
103
|
+
>>> model.load_state_dict(state)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
Parameters
|
|
107
|
+
----------
|
|
108
|
+
patch_size : int
|
|
109
|
+
The size of the patch to be used in the patch embedding.
|
|
110
|
+
embed_dim : int
|
|
111
|
+
The dimension of the embedding.
|
|
112
|
+
conv_in_channels : int
|
|
113
|
+
The number of convolutional input channels.
|
|
114
|
+
conv_out_channels : int
|
|
115
|
+
The number of convolutional output channels.
|
|
116
|
+
num_layers : int (default=12)
|
|
117
|
+
The number of attention layers of the model.
|
|
118
|
+
num_heads : int (default=10)
|
|
119
|
+
The number of attention heads.
|
|
120
|
+
mlp_ratio : float (default=4.0)
|
|
121
|
+
The expansion ratio of the mlp layer
|
|
122
|
+
qkv_bias : bool (default=False)
|
|
123
|
+
If True, add a learnable bias to the query, key, and value tensors.
|
|
124
|
+
qk_norm : Pytorch Normalize layer (default=nn.LayerNorm)
|
|
125
|
+
If not None, apply LayerNorm to the query and key tensors.
|
|
126
|
+
Default is nn.LayerNorm for better weight transfer from original LaBraM.
|
|
127
|
+
Set to None to disable Q,K normalization.
|
|
128
|
+
qk_scale : float (default=None)
|
|
129
|
+
If not None, use this value as the scale factor. If None,
|
|
130
|
+
use head_dim**-0.5, where head_dim = dim // num_heads.
|
|
131
|
+
drop_prob : float (default=0.0)
|
|
132
|
+
Dropout rate applied to the embedded tokens and passed as ``drop`` to every attention block.
|
|
133
|
+
attn_drop_prob : float (default=0.0)
|
|
134
|
+
Dropout rate for the attention weights.
|
|
135
|
+
drop_path_prob : float (default=0.0)
|
|
136
|
+
Maximum DropPath (stochastic depth) rate; the per-layer rate grows linearly from 0 to this value.
|
|
137
|
+
norm_layer : Pytorch Normalize layer (default=nn.LayerNorm)
|
|
138
|
+
The normalization layer to be used.
|
|
139
|
+
init_values : float (default=0.1)
|
|
140
|
+
If not None, use this value to initialize the gamma_1 and gamma_2
|
|
141
|
+
parameters for residual scaling. Default is 0.1 for better weight
|
|
142
|
+
transfer from original LaBraM. Set to None to disable.
|
|
143
|
+
use_abs_pos_emb : bool (default=True)
|
|
144
|
+
If True, use absolute position embedding.
|
|
145
|
+
use_mean_pooling : bool (default=True)
|
|
146
|
+
If True, use mean pooling.
|
|
147
|
+
init_scale : float (default=0.001)
|
|
148
|
+
The initial scale to be used in the parameters of the model.
|
|
149
|
+
neural_tokenizer : bool (default=True)
|
|
150
|
+
The model can be used in two modes: Neural Tokenizer or Neural Decoder.
|
|
151
|
+
attn_head_dim : int (default=None)
|
|
152
|
+
The head dimension to be used in the attention layer, to be used only
|
|
153
|
+
during pre-training.
|
|
154
|
+
activation: nn.Module, default=nn.GELU
|
|
155
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
156
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.GELU``.
|
|
157
|
+
|
|
158
|
+
References
|
|
159
|
+
----------
|
|
160
|
+
.. [Jiang2024] Wei-Bang Jiang, Li-Ming Zhao, Bao-Liang Lu. 2024, May.
|
|
161
|
+
Large Brain Model for Learning Generic Representations with Tremendous
|
|
162
|
+
EEG Data in BCI. The Twelfth International Conference on Learning
|
|
163
|
+
Representations, ICLR.
|
|
164
|
+
.. [Code2024] Wei-Bang Jiang, Li-Ming Zhao, Bao-Liang Lu. 2024. Labram
|
|
165
|
+
Large Brain Model for Learning Generic Representations with Tremendous
|
|
166
|
+
EEG Data in BCI. GitHub https://github.com/935963004/LaBraM
|
|
167
|
+
(accessed 2024-03-02)
|
|
168
|
+
.. [BeiTv2] Zhiliang Peng, Li Dong, Hangbo Bao, Qixiang Ye, Furu Wei. 2024.
|
|
169
|
+
BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers.
|
|
170
|
+
arXiv:2208.06366 [cs.CV]
|
|
171
|
+
"""
|
|
172
|
+
|
|
173
|
+
def __init__(
    self,
    n_times=None,
    n_outputs=None,
    chs_info=None,
    n_chans=None,
    sfreq=None,
    input_window_seconds=None,
    patch_size=200,
    embed_dim=200,
    conv_in_channels=1,
    conv_out_channels=8,
    num_layers=12,
    num_heads=10,
    mlp_ratio=4.0,
    qkv_bias=False,
    qk_norm: type[nn.Module] = nn.LayerNorm,
    qk_scale=None,
    drop_prob=0.0,
    attn_drop_prob=0.0,
    drop_path_prob=0.0,
    norm_layer: type[nn.Module] = nn.LayerNorm,
    init_values=0.1,
    use_abs_pos_emb=True,
    use_mean_pooling=True,
    init_scale=0.001,
    neural_tokenizer=True,
    attn_head_dim=None,
    activation: type[nn.Module] = nn.GELU,
):
    """Assemble the patch embedding, attention blocks and output head.

    The parameters are described in the class docstring. ``patch_size``
    is clamped to ``n_times`` (with a warning) when it is larger, and
    ``embed_dim`` is re-derived from a dummy pass through the patch
    embedding, so the stored ``self.embed_dim`` may differ from the
    argument.
    """
    super().__init__(
        n_outputs=n_outputs,
        n_chans=n_chans,
        chs_info=chs_info,
        n_times=n_times,
        input_window_seconds=input_window_seconds,
        sfreq=sfreq,
    )
    # Drop the raw arguments so the code below reads them through ``self``.
    del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

    self.num_features = self.embed_dim = embed_dim
    self.neural_tokenizer = neural_tokenizer
    self.init_scale = init_scale

    if patch_size > self.n_times:
        warn(
            f"patch_size ({patch_size}) > n_times ({self.n_times}); "
            f"setting patch_size = {self.n_times}.",
            UserWarning,
        )
        self.patch_size = self.n_times
        # The embedding size is inferred from the dummy forward pass below.
        self.num_features = None
        self.embed_dim = None
    else:
        self.patch_size = patch_size
    self.n_path = self.n_times // self.patch_size

    if neural_tokenizer and conv_in_channels != 1:
        warn(
            "The model is in Neural Tokenizer mode, but the variable "
            + "`conv_in_channels` is different from the default values."
            + "`conv_in_channels` is only needed for the Neural Decoder mode."
            + "conv_in_channels is not used in the Neural Tokenizer mode.",
            UserWarning,
        )
        conv_in_channels = 1

    if neural_tokenizer:
        # Neural Tokenizer mode: segment the raw signal into patches and
        # run the temporal convolution over the patched data.
        self.patch_embed = nn.Sequential(
            OrderedDict(
                [
                    (
                        "segment_patch",
                        _SegmentPatch(
                            n_times=self.n_times,
                            patch_size=self.patch_size,
                            n_chans=self.n_chans,
                            emb_dim=self.patch_size,
                        ),
                    ),
                    (
                        "temporal_conv",
                        _TemporalConv(
                            out_channels=conv_out_channels, activation=activation
                        ),
                    ),
                ]
            )
        )
    else:
        # Neural Decoder mode: the input comes after the VQVAE encoder and
        # is used to extract the amplitude and phase outputs. The layer is
        # wrapped in a Sequential to follow the same convention as the
        # Neural Tokenizer mode.
        self.patch_embed = nn.Sequential()
        self.patch_embed.add_module(
            "segment_patch",
            _PatchEmbed(
                n_times=self.n_times,
                # Use the clamped patch size, as the tokenizer branch does;
                # the raw argument may exceed n_times.
                patch_size=self.patch_size,
                in_channels=conv_in_channels,
                # ``self.embed_dim`` is None after clamping, so pass the
                # requested value; the effective size is re-read below.
                emb_dim=embed_dim,
            ),
        )

    # Infer the effective embedding size from the last output dimension of
    # the patch embedding.
    with torch.no_grad():
        dummy = torch.zeros(1, self.n_chans, self.n_times)
        out = self.patch_embed(dummy)
        self.embed_dim = out.shape[-1]
        self.num_features = self.embed_dim

    # Learnable [CLS] token.
    self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

    # Position and temporal embeddings are complementary: the position
    # embedding encodes the channels (plus one slot for [CLS]) and the
    # temporal embedding encodes the patch index (plus one slot).
    if use_abs_pos_emb:
        self.position_embedding = nn.Parameter(
            torch.zeros(1, self.n_chans + 1, self.embed_dim),
            requires_grad=True,
        )
    else:
        self.position_embedding = None

    self.temporal_embedding = nn.Parameter(
        torch.zeros(1, self.patch_embed[0].n_patchs + 1, self.embed_dim),
        requires_grad=True,
    )
    self.pos_drop = nn.Dropout(p=drop_prob)

    # Stochastic depth decay rule: linearly increasing DropPath rate.
    dpr = [x.item() for x in torch.linspace(0, drop_path_prob, num_layers)]
    self.blocks = nn.ModuleList(
        [
            _WindowsAttentionBlock(
                dim=self.embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                qk_scale=qk_scale,
                drop=drop_prob,
                attn_drop=attn_drop_prob,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                init_values=init_values,
                window_size=(
                    self.patch_embed[0].patch_shape
                    if not neural_tokenizer
                    else None
                ),
                attn_head_dim=attn_head_dim,
                activation=activation,
            )
            for i in range(num_layers)
        ]
    )
    # With mean pooling the normalization is applied after pooling
    # (``fc_norm``); otherwise it is applied to all tokens (``norm``).
    self.norm = nn.Identity() if use_mean_pooling else norm_layer(self.embed_dim)
    self.fc_norm = norm_layer(self.embed_dim) if use_mean_pooling else None

    if self.n_outputs > 0:
        self.final_layer = nn.Linear(self.embed_dim, self.n_outputs)
    else:
        self.final_layer = nn.Identity()

    self.apply(self._init_weights)
    self.fix_init_weight_and_init_embedding()
|
|
350
|
+
|
|
351
|
+
def fix_init_weight_and_init_embedding(self):
|
|
352
|
+
"""
|
|
353
|
+
Fix the initial weight and the initial embedding.
|
|
354
|
+
Initializing with truncated normal distribution.
|
|
355
|
+
"""
|
|
356
|
+
trunc_normal_(self.cls_token, std=0.02)
|
|
357
|
+
trunc_normal_(self.temporal_embedding, std=0.02)
|
|
358
|
+
|
|
359
|
+
if self.position_embedding is not None:
|
|
360
|
+
trunc_normal_(self.position_embedding, std=0.02)
|
|
361
|
+
|
|
362
|
+
if isinstance(self.final_layer, nn.Linear):
|
|
363
|
+
trunc_normal_(self.final_layer.weight, std=0.02)
|
|
364
|
+
|
|
365
|
+
for layer_id, layer in enumerate(self.blocks):
|
|
366
|
+
rescale_parameter(layer.attn.proj.weight.data, layer_id + 1)
|
|
367
|
+
rescale_parameter(layer.mlp[-2].weight.data, layer_id + 1)
|
|
368
|
+
|
|
369
|
+
if isinstance(self.final_layer, nn.Linear):
|
|
370
|
+
self.final_layer.weight.data.mul_(self.init_scale)
|
|
371
|
+
self.final_layer.bias.data.mul_(self.init_scale)
|
|
372
|
+
|
|
373
|
+
@staticmethod
def _init_weights(layer):
    """
    Initialize the weights of a single layer of the model.

    Linear layers get their weight drawn from a truncated normal
    distribution with ``std=0.02`` and their bias, when present, set to 0.

    Layer normalization layers get their bias set to 0 and their weight
    set to 1.0. Layers created without affine parameters (``weight`` or
    ``bias`` is ``None``) are left untouched instead of raising.

    Any other layer type is ignored.

    Parameters
    ----------
    layer : torch.nn.Module
        The layer of the pytorch model.
    """
    if isinstance(layer, nn.Linear):
        trunc_normal_(layer.weight, std=0.02)
        if layer.bias is not None:
            nn.init.constant_(layer.bias, 0)
    elif isinstance(layer, nn.LayerNorm):
        # ``elementwise_affine=False`` (or ``bias=False``) leaves these None.
        if layer.bias is not None:
            nn.init.constant_(layer.bias, 0)
        if layer.weight is not None:
            nn.init.constant_(layer.weight, 1.0)
|
|
401
|
+
|
|
402
|
+
def get_num_layers(self):
    """
    Return how many attention blocks the model stacks.
    """
    n_blocks = len(self.blocks)
    return n_blocks
|
|
407
|
+
|
|
408
|
+
def forward_features(
    self,
    x,
    input_chans=None,
    return_patch_tokens=False,
    return_all_tokens=False,
):
    """
    Embed the input, add the embeddings and run the attention blocks.

    Parameters
    ----------
    x : torch.Tensor
        The input data with shape (batch, n_chans, n_times).
    input_chans : index-like, optional
        Index applied to dimension 1 of the position embedding
        (``position_embedding[:, input_chans]``) to select the entries to
        use. Only read in Neural Tokenizer mode; ``None`` uses the whole
        embedding. Presumably the selection must keep the [CLS] slot at
        index 0 -- TODO confirm against callers.
    return_patch_tokens : bool
        Whether to return every token except the [CLS] token.
    return_all_tokens : bool
        Whether to return all the tokens, [CLS] included. Takes
        precedence over ``return_patch_tokens``.

    Returns
    -------
    x : torch.Tensor
        With mean pooling (``fc_norm`` is set): the normalized tokens
        requested by the flags, or by default the normalized mean of the
        non-[CLS] tokens. Without it: the requested tokens, or by default
        the [CLS] token.
    """
    batch_size = x.shape[0]

    # Both modes take (batch, n_chans, n_times) and return a token sequence.
    # NOTE(review): in tokenizer mode the sequence presumably has
    # n_chans * n_patchs tokens -- confirm against _TemporalConv.
    x = self.patch_embed(x)
    if not self.neural_tokenizer:
        # Decoder mode requires a 3-D (batch, n_patchs, emb_dim) output.
        batch_size, n_patch, _ = x.shape

    # Prepend the [CLS] token to the embedded tokens.
    cls_tokens = self.cls_token.expand(batch_size, -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)

    # Positional embedding
    if self.position_embedding is not None:
        if self.neural_tokenizer:
            # Channel-based position embedding, optionally sub-selected.
            if input_chans is not None:
                pos_embed_used = self.position_embedding[:, input_chans]
            else:
                pos_embed_used = self.position_embedding

            pos_embed = self._adj_position_embedding(
                pos_embed_used=pos_embed_used, batch_size=batch_size
            )
        else:
            # Use at most the first n_patch + 1 positions.
            n_pos = min(self.position_embedding.shape[1], n_patch + 1)
            pos_embed_used = self.position_embedding[:, :n_pos, :]
            pos_embed = pos_embed_used.expand(batch_size, -1, -1)

        x += pos_embed

    # The time embedding is added to every token after the [CLS] token.
    if self.neural_tokenizer:
        time_embed = self._adj_temporal_embedding(
            num_ch=self.n_chans, batch_size=batch_size, dim_embed=self.embed_dim
        )
    else:
        # Slot 0 of the temporal embedding is skipped; the rest is
        # broadcast over the batch.
        n_time_tokens = min(n_patch, self.temporal_embedding.shape[1] - 1)
        time_embed = self.temporal_embedding[:, 1 : n_time_tokens + 1, :]
        time_embed = time_embed.expand(batch_size, -1, -1)
    x[:, 1:, :] += time_embed

    x = self.pos_drop(x)

    for blk in self.blocks:
        x = blk(x)

    x = self.norm(x)
    if self.fc_norm is not None:
        if return_all_tokens:
            return self.fc_norm(x)
        tokens = x[:, 1:, :]
        if return_patch_tokens:
            return self.fc_norm(tokens)
        return self.fc_norm(tokens.mean(1))

    if return_all_tokens:
        return x
    if return_patch_tokens:
        return x[:, 1:]
    return x[:, 0]
|
|
520
|
+
|
|
521
|
+
def forward(
    self,
    x,
    input_chans=None,
    return_patch_tokens=False,
    return_all_tokens=False,
):
    """
    Run the EEG input through the feature extractor and the output head.

    Parameters
    ----------
    x: torch.Tensor
        The input data with shape (batch, n_chans, n_times)
        or (batch, n_chans, n_patches, patch size).
    input_chans: index-like, optional
        Forwarded unchanged to ``forward_features``.
    return_patch_tokens: bool
        Forwarded unchanged to ``forward_features``.
    return_all_tokens: bool
        Forwarded unchanged to ``forward_features``.

    Returns
    -------
    torch.Tensor
        The result of applying ``final_layer`` to the extracted features.
    """
    features = self.forward_features(
        x,
        input_chans=input_chans,
        return_patch_tokens=return_patch_tokens,
        return_all_tokens=return_all_tokens,
    )
    return self.final_layer(features)
|
|
556
|
+
|
|
557
|
+
def get_classifier(self):
    """
    Return the output head of the model.

    Returns
    -------
    torch.nn.Module
        The module stored in ``final_layer``.
    """
    head = self.final_layer
    return head
|
|
567
|
+
|
|
568
|
+
def reset_classifier(self, n_outputs):
    """
    Replace the output head to match a new number of classes.

    Parameters
    ----------
    n_outputs : int
        The new number of classes. A positive value creates a fresh
        ``nn.Linear(embed_dim, n_outputs)``; otherwise the head becomes
        ``nn.Identity``.
    """
    self.n_outputs = n_outputs
    if self.n_outputs > 0:
        # The model stores its width as ``embed_dim``; ``emb_dim`` is never
        # set on this class.
        self.final_layer = nn.Linear(self.embed_dim, self.n_outputs)
    else:
        self.final_layer = nn.Identity()
|
|
583
|
+
|
|
584
|
+
def _adj_temporal_embedding(self, num_ch, batch_size, dim_embed=None):
    """
    Tile the temporal embedding over the batch and the channels.

    Parameters
    ----------
    num_ch : int
        The number of channels or number of patches to repeat over.
    batch_size : int
        Batch size of the input data.
    dim_embed : int, optional
        Upper bound on the number of temporal slots to keep; it is capped
        at the number of slots available after the first one. When
        ``None``, ``patch_size`` slots are requested instead.

    Returns
    -------
    temporal_embedding : torch.Tensor
        Tensor of shape (batch_size, num_ch * n_slots, emb_size) to be
        added to the tokens after the [CLS] token
        (x[:, 1:, :] += time_embed).
    """
    if dim_embed is None:
        n_slots = self.patch_size
    else:
        n_slots = min(dim_embed, self.temporal_embedding.shape[1] - 1)

    # Skip slot 0 and keep the next n_slots: (1, n_slots, emb_size).
    kept = self.temporal_embedding[:, 1 : n_slots + 1, :]

    # (1, n_slots, emb) -> (1, 1, n_slots, emb)
    #                   -> (batch_size, num_ch, n_slots, emb)
    #                   -> (batch_size, num_ch * n_slots, emb)
    return kept.unsqueeze(1).expand(batch_size, num_ch, -1, -1).flatten(1, 2)
|
|
626
|
+
|
|
627
|
+
def _adj_position_embedding(self, pos_embed_used, batch_size):
    """
    Expand the position embedding to one entry per token.

    Index 0 of ``pos_embed_used`` is the entry for the [CLS] token; every
    following entry is repeated once per patch so that all patches of a
    channel share the same position embedding.

    Parameters
    ----------
    pos_embed_used : torch.Tensor
        The position embedding to be adjusted, indexed as
        (1, 1 + n_channels, emb_size).
    batch_size : int
        The number of samples in the batch.

    Returns
    -------
    pos_embed : torch.Tensor
        The adjusted position embedding of shape
        (batch_size, 1 + n_channels * n_patchs, emb_size).
    """
    n_patchs = self.patch_embed[0].n_patchs

    # Entry for the [CLS] token. It must be read from ``pos_embed_used``
    # itself: after slicing and flattening, row 0 is the first channel.
    cls_pos = pos_embed_used[:, 0:1, :].expand(batch_size, -1, -1)

    # (1, n_channels, emb) -> (1, n_channels, 1, emb)
    #                      -> (batch_size, n_channels, n_patchs, emb)
    #                      -> (batch_size, n_channels * n_patchs, emb)
    channel_pos = (
        pos_embed_used[:, 1:, :]
        .unsqueeze(2)
        .expand(batch_size, -1, n_patchs, -1)
        .flatten(1, 2)
    )

    return torch.cat((cls_pos, channel_pos), dim=1)
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
class _SegmentPatch(nn.Module):
    r"""Segment an EEG recording into temporal patches.

    The input is expected with shape (batch, n_chans, n_times). Each channel
    is cut into consecutive windows of ``patch_size`` samples.

    With ``learned_patcher=True`` every window is embedded by a 1D
    convolution whose kernel size and stride both equal ``patch_size``, so
    the output has shape (batch, n_chans, n_patches, emb_dim).

    With ``learned_patcher=False`` the raw samples are only reshaped, so the
    output has shape (batch, n_chans, n_patches, patch_size).

    In both modes samples that do not fill a complete patch at the end of
    the recording are dropped.

    Parameters
    ----------
    n_times : int (default=2000)
        Number of time samples of the input tensor.
    patch_size : int (default=200)
        Number of samples per patch.
    n_chans : int (default=1)
        Number of EEG channels of the input tensor.
    emb_dim : int (default=200)
        Number of output channels of the convolution, i.e. the size of each
        embedded patch when ``learned_patcher`` is True.
    learned_patcher : bool (default=True)
        If True, patches are produced by the convolution; if False, the
        input is only reshaped into patches.
    """

    def __init__(
        self, n_times=2000, patch_size=200, n_chans=1, emb_dim=200, learned_patcher=True
    ):
        super().__init__()

        self.n_times = n_times
        self.patch_size = patch_size
        self.n_patchs = n_times // patch_size
        self.emb_dim = emb_dim
        self.n_chans = n_chans
        self.learned_patcher = learned_patcher

        # Created unconditionally (even when ``learned_patcher`` is False) so
        # the set of registered parameters does not depend on the flag.
        self.patcher = nn.Conv1d(
            in_channels=1,
            out_channels=self.emb_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        # Folds the channel axis into the batch axis so that the Conv1d sees
        # one single-channel signal per (sample, electrode) pair.
        self.adding_extra_dim = Rearrange(
            pattern="batch nchans temporal -> (batch nchans) 1 temporal"
        )

    def forward(self, x):
        """
        Cut the signal into patches, optionally embedding them with Conv1d.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (batch, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            (batch, n_chans, n_patches, emb_dim) when ``learned_patcher`` is
            True, (batch, n_chans, n_patches, patch_size) otherwise.
        """
        batch_size, _, _ = x.shape

        if self.learned_patcher:
            # (batch, n_chans, n_times) -> (batch * n_chans, 1, n_times)
            x = self.adding_extra_dim(x)
            # Conv1d output: (batch * n_chans, emb_dim, n_patches)
            x = self.patcher(x)
            # Restore the channel axis and move the embedding last.
            return rearrange(
                x,
                pattern="(batch nchans) embed npatchs -> batch nchans npatchs embed",
                batch=batch_size,
                nchans=self.n_chans,
            )

        # Non-learned path: keep only whole patches. The strided convolution
        # above implicitly discards a trailing remainder; do the same here
        # instead of failing on lengths that are not a multiple of patch_size.
        n_patchs = x.shape[-1] // self.patch_size
        x = x[..., : n_patchs * self.patch_size]
        return x.reshape(batch_size, self.n_chans, n_patchs, self.patch_size)
|
|
767
|
+
|
|
768
|
+
|
|
769
|
+
class _PatchEmbed(nn.Module):
    r"""Project EEG channels to patch embeddings (Neural Decoder mode).

    Accepts either a raw signal of shape (Batch, n_channels, n_times) or
    pre-patched data of shape (Batch, n_channels, n_patches, patch_size).
    The channels are split into ``in_channels`` groups, a 2D convolution
    with kernel and stride ``(1, patch_size)`` embeds every patch, and the
    result is averaged over the channels of each group, which yields one
    embedding per patch.

    Parameters
    ----------
    n_times : int (default=2000)
        Number of time samples; only used to derive ``n_patchs`` and
        ``patch_shape``.
    patch_size : int (default=200)
        Number of samples per patch.
    in_channels : int (default=1)
        Number of input feature maps given to the convolution.
    emb_dim : int (default=200)
        Size of the output embedding.
    n_codebooks : int (default=62)
        Accepted for interface compatibility; not used by this class.
    """

    def __init__(
        self, n_times=2000, patch_size=200, in_channels=1, emb_dim=200, n_codebooks=62
    ):
        super().__init__()
        n_patchs = n_times // patch_size

        self.n_times = n_times
        self.patch_size = patch_size
        self.patch_shape = (1, n_patchs)
        self.n_patchs = n_patchs
        self.emb_dim = emb_dim
        self.in_channels = in_channels

        # Kernel and stride span exactly one patch along the time axis:
        # (Batch, in_channels, group_size, n_times)
        #   -> (Batch, emb_dim, group_size, n_patchs)
        window = (1, patch_size)
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=emb_dim,
            kernel_size=window,
            stride=window,
        )

    def forward(self, x):
        """
        Embed every temporal patch and average over the grouped channels.

        Parameters
        ----------
        x : torch.Tensor
            Either (Batch, n_channels, n_times) or
            (Batch, n_channels, n_patches, patch_size).

        Returns
        -------
        torch.Tensor
            Tensor of shape (Batch, n_patchs, emb_dim).

        Raises
        ------
        ValueError
            If ``x`` is neither 3D nor 4D, if the last dimension of a 4D
            input differs from ``patch_size``, if the number of samples is
            not a multiple of ``patch_size``, or if the channel dimension is
            not a multiple of ``in_channels``.
        """
        if x.ndim not in (3, 4):
            raise ValueError(
                "Input must be either 3D (batch, channels, times) or "
                "4D (batch, channels, n_patches, patch_size)."
            )

        if x.ndim == 4:
            patch_len = x.shape[-1]
            if patch_len != self.patch_size:
                raise ValueError(
                    "When providing a 4D tensor, the last dimension "
                    f"({patch_len}) must match patch_size ({self.patch_size})."
                )
            # Merge (n_patches, patch_size) back into one time axis.
            x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * patch_len)

        batch_size, n_channels, n_times = x.shape

        if n_times % self.patch_size != 0:
            raise ValueError(
                f"n_times ({n_times}) must be divisible by patch_size ({self.patch_size})."
            )
        if n_channels % self.in_channels != 0:
            raise ValueError(
                "The input channel dimension "
                f"({n_channels}) must be divisible by in_channels ({self.in_channels})."
            )

        # The grouped EEG channels become the spatial "height" seen by Conv2d.
        channels_per_group = n_channels // self.in_channels
        grouped = x.view(batch_size, self.in_channels, channels_per_group, n_times)

        # (Batch, emb_dim, channels_per_group, n_patchs)
        embedded = self.proj(grouped)

        # braindecode-specific step: average over the grouped channels, then
        # move the patch axis ahead of the embedding axis.
        pooled = embedded.mean(dim=2)
        return pooled.transpose(1, 2).contiguous()
|
|
873
|
+
|
|
874
|
+
|
|
875
|
+
class _Attention(nn.Module):
    r"""
    Multi-head self-attention with an optional relative position bias.

    This code is strongly inspired by:
    https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L77

    The input is projected by a single linear layer into query, key and
    value tensors. The query is multiplied by the scale factor and then by
    the transpose of the key; the result goes through a softmax and is used
    to weight the values, followed by an output projection.

    When ``window_size`` is given, a learnable relative position bias is
    added to the attention logits before the softmax.

    Parameters
    ----------
    dim : int
        Number of input features.
    num_heads : int (default=8)
        Number of attention heads.
    qkv_bias : bool (default=False)
        If True, add a learnable bias to the query and value projections
        (the key bias is fixed to zero).
    qk_norm : callable (default=None)
        If not None, called as ``qk_norm(head_dim)`` to build the
        normalization layers applied to the query and key tensors
        (e.g. ``nn.LayerNorm``).
    qk_scale : float (default=None)
        If set (and non-zero), use this value as the scale factor.
        Otherwise use ``head_dim**-0.5``.
    attn_drop : float (default=0.0)
        Dropout rate for the attention weights.
    proj_drop : float (default=0.0)
        Dropout rate for the output tensor.
    window_size : tuple of two ints (default=None)
        ``(Wh, Ww)``. If given, a relative position bias is used and the
        input sequence must contain ``Wh * Ww + 1`` tokens.
    attn_head_dim : int (default=None)
        If not None, use this value as the head_dim. If None, use
        ``dim // num_heads``.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=None,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        window_size=None,
        attn_head_dim=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim**-0.5

        # The bias is handled manually in ``forward`` because only the query
        # and value parts are learnable.
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if qk_norm is not None:
            self.q_norm = qk_norm(head_dim)
            self.k_norm = qk_norm(head_dim)
        else:
            self.q_norm = None
            self.k_norm = None

        if window_size:
            self.window_size = window_size
            # (2*Wh-1) * (2*Ww-1) in-window offsets, plus 3 extra entries for
            # cls->token, token->cls and cls->cls.
            self.num_relative_distance = (2 * window_size[0] - 1) * (
                2 * window_size[1] - 1
            ) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads)
            )
            self.register_buffer(
                "relative_position_index",
                self._build_relative_position_index(
                    window_size, self.num_relative_distance
                ),
            )
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    @staticmethod
    def _build_relative_position_index(window_size, num_relative_distance):
        """Return the (Wh*Ww+1, Wh*Ww+1) lookup into the bias table."""
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        # ``indexing="ij"`` is the behaviour the implicit default provided;
        # passing it explicitly avoids PyTorch's deprecation warning.
        coords = torch.stack(
            torch.meshgrid([coords_h, coords_w], indexing="ij")
        )  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        # 2, Wh*Ww, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(
            1, 2, 0
        ).contiguous()  # Wh*Ww, Wh*Ww, 2
        # Shift both offsets to start from 0, then turn the (dh, dw) pair
        # into a single flat index.
        relative_coords[:, :, 0] += window_size[0] - 1
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1

        n_tokens = window_size[0] * window_size[1] + 1
        relative_position_index = torch.zeros(
            size=(n_tokens,) * 2,
            dtype=relative_coords.dtype,
        )
        relative_position_index[1:, 1:] = relative_coords.sum(-1)
        # Row 0 / column 0 belong to the extra leading token; the order of
        # the three assignments matters because they overlap at [0, 0].
        relative_position_index[0, 0:] = num_relative_distance - 3
        relative_position_index[0:, 0] = num_relative_distance - 2
        relative_position_index[0, 0] = num_relative_distance - 1
        return relative_position_index

    def forward(
        self,
        x: torch.Tensor,
        return_attention=False,
        return_qkv=False,
    ):
        """
        Apply the attention mechanism to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (Batch, N, C).
        return_attention : bool (default=False)
            If True, return only the attention weights (after softmax and
            dropout), of shape (Batch, num_heads, N, N).
        return_qkv : bool (default=False)
            If True, return the query, key, and value tensors together with
            the output tensor.

        Returns
        -------
        x : torch.Tensor
            Output tensor of shape (Batch, N, C).
        qkv : torch.Tensor (optional)
            Stacked query, key and value tensors of shape
            (3, Batch, num_heads, N, head_dim), as computed before any
            query/key normalization or scaling.
        """
        B, N, _ = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # Key bias is a constant zero vector.
            qkv_bias = torch.cat(
                (
                    self.q_bias,
                    torch.zeros_like(self.v_bias, requires_grad=False),
                    self.v_bias,
                )
            )
        qkv = nn.functional.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple) (B, H, N, C)
        if self.q_norm is not None:
            q = self.q_norm(q).type_as(v)
        if self.k_norm is not None:
            k = self.k_norm(k).type_as(v)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        if self.relative_position_bias_table is not None:
            n_tokens = self.window_size[0] * self.window_size[1] + 1
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)
            ].view(n_tokens, n_tokens, -1)  # Wh*Ww+1, Wh*Ww+1, nH
            relative_position_bias = relative_position_bias.permute(
                2, 0, 1
            ).contiguous()  # nH, Wh*Ww+1, Wh*Ww+1
            attn = attn + relative_position_bias.unsqueeze(0)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if return_attention:
            return attn

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)

        if return_qkv:
            return x, qkv

        return x
|
|
1071
|
+
|
|
1072
|
+
|
|
1073
|
+
class _WindowsAttentionBlock(nn.Module):
    r"""Attention block: LayerNorm -> attention and LayerNorm -> MLP.

    Both sub-layers are wrapped in residual connections, optionally scaled
    by the learnable vectors ``gamma_1`` / ``gamma_2`` and optionally
    regularized with ``DropPath``.

    Notes: This code is strongly inspired by BeiTv2 from Microsoft.

    Parameters
    ----------
    dim : int
        Number of input features.
    num_heads : int
        Number of attention heads.
    mlp_ratio : float (default=4.0)
        The hidden size of the MLP is ``int(dim * mlp_ratio)``.
    qkv_bias : bool (default=False)
        Forwarded to the attention layer.
    qk_norm : callable (default=None)
        Forwarded to the attention layer.
    qk_scale : float (default=None)
        Forwarded to the attention layer.
    drop : float (default=0.0)
        Dropout rate passed to the attention output projection and to the
        MLP.
    attn_drop : float (default=0.0)
        Dropout rate for the attention weights.
    drop_path : float (default=0.0)
        If greater than 0, ``DropPath(drop_path)`` is applied to both
        residual branches; otherwise an identity is used.
    init_values : float (default=None)
        If not None and greater than 0, ``gamma_1`` and ``gamma_2`` are
        created and initialized to this value; otherwise the residual
        branches are not scaled.
    activation : type[nn.Module] (default=nn.GELU)
        Activation class passed to the MLP.
    norm_layer : callable (default=nn.LayerNorm)
        Called as ``norm_layer(dim)`` to build both normalization layers.
    window_size : tuple of two ints (default=None)
        Forwarded to the attention layer.
    attn_head_dim : int (default=None)
        Forwarded to the attention layer.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_norm=None,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        init_values=None,
        activation: type[nn.Module] = nn.GELU,
        norm_layer=nn.LayerNorm,
        window_size=None,
        attn_head_dim=None,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = _Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            window_size=window_size,
            attn_head_dim=attn_head_dim,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=[mlp_hidden_dim],
            activation=activation,
            drop=drop,
        )

        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True
            )
            self.gamma_2 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, return_attention=False, return_qkv=False):
        """
        Apply the attention and MLP sub-layers with residual connections.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (Batch, N, C), as required by the
            attention layer.
        return_attention : bool (default=False)
            If True, return only the attention weights of the attention
            layer; the residual path and the MLP are skipped.
        return_qkv : bool (default=False)
            If True, return the query, key, and value tensors together with
            the output tensor.

        Returns
        -------
        torch.Tensor or tuple
            The attention weights if ``return_attention`` is True;
            ``(output, qkv)`` if ``return_qkv`` is True; otherwise the
            output tensor, which has the same shape as ``x``.
        """
        if return_attention:
            return self.attn(self.norm1(x), return_attention=True)

        qkv = None
        if return_qkv:
            branch, qkv = self.attn(self.norm1(x), return_qkv=return_qkv)
        else:
            branch = self.attn(self.norm1(x))

        # gamma_1 / gamma_2 only exist when ``init_values`` was given, so the
        # scaling has to be optional on every path (including return_qkv).
        if self.gamma_1 is not None:
            branch = self.gamma_1 * branch
        x = x + self.drop_path(branch)

        branch = self.mlp(self.norm2(x))
        if self.gamma_2 is not None:
            branch = self.gamma_2 * branch
        x = x + self.drop_path(branch)

        if return_qkv:
            return x, qkv
        return x
|
|
1206
|
+
|
|
1207
|
+
|
|
1208
|
+
class _TemporalConv(nn.Module):
    r"""
    Temporal convolutional encoder applied to patched EEG.

    The (channel, patch) axes of the input are merged, then three stages of
    2D convolution -> GroupNorm -> activation are applied. Only the first
    stage receives a stride argument. Finally the convolution channels and
    the remaining time axis are merged into a single feature axis.

    Parameters
    ----------
    in_channels : int (default=1)
        Number of input channels of the first convolution.
    out_channels : int (default=8)
        Number of output channels of every convolution.
    num_groups : int (default=4)
        Number of groups for each GroupNorm.
    kernel_size_1 : tuple (default=(1, 15))
        Kernel size of the first convolution.
    stride_1 : tuple (default=(1, 8))
        Stride of the first convolution.
    padding_1 : tuple (default=(0, 7))
        Padding of the first convolution.
    kernel_size_2 : tuple (default=(1, 3))
        Kernel size of the second and third convolutions.
    padding_2 : tuple (default=(0, 1))
        Padding of the second and third convolutions.
    activation : type[nn.Module] (default=nn.GELU)
        Activation class instantiated once per stage, e.g. ``nn.ReLU`` or
        ``nn.ELU``.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=8,
        num_groups=4,
        kernel_size_1=(1, 15),
        stride_1=(1, 8),
        padding_1=(0, 7),
        kernel_size_2=(1, 3),
        padding_2=(0, 1),
        activation: type[nn.Module] = nn.GELU,
    ):
        super().__init__()

        # Merge electrodes and patches into one "height" axis and add a
        # singleton channel axis so the tensor can be fed to Conv2d.
        self.channel_patch_flatten = Rearrange(
            "Batch chs npat spatch -> Batch () (chs npat) spatch"
        )

        # Stage 1: the only stage that is given a stride.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size_1,
            stride=stride_1,
            padding=padding_1,
        )
        self.act_layer_1 = activation()
        self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)

        # Stage 2.
        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size_2,
            padding=padding_2,
        )
        self.act_layer_2 = activation()
        self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)

        # Stage 3 (same hyper-parameters as stage 2).
        self.conv3 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size_2,
            padding=padding_2,
        )
        self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
        self.act_layer_3 = activation()

        # Fold the convolution channels into the feature axis.
        self.transpose_temporal_channel = Rearrange("Batch C NA T -> Batch NA (T C)")

    def forward(self, x):
        """
        Run the three convolution -> GroupNorm -> activation stages.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (Batch, Channels, n_patchs, size_patch).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (Batch, NA, T * C), where NA is
            Channels * n_patchs, C the number of convolution channels and T
            the length of the time axis after the convolutions.
        """
        stages = (
            (self.conv1, self.norm1, self.act_layer_1),
            (self.conv2, self.norm2, self.act_layer_2),
            (self.conv3, self.norm3, self.act_layer_3),
        )

        out = self.channel_patch_flatten(x)
        for conv, norm, act in stages:
            out = act(norm(conv(out)))
        return self.transpose_temporal_channel(out)
|