braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/labram.py
@@ -0,0 +1,1316 @@
1
+ """
2
+ Labram module.
3
+ Authors: Wei-Bang Jiang
4
+ Bruno Aristimunha <b.aristimunha@gmail.com>
5
+ Matthew Chen <matt.chen4260@gmail.com>
6
+ License: BSD 3 clause
7
+ """
8
+
9
+ from collections import OrderedDict
10
+ from warnings import warn
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from einops import rearrange
15
+ from einops.layers.torch import Rearrange
16
+ from torch.nn.init import trunc_normal_
17
+
18
+ from braindecode.functional import rescale_parameter
19
+ from braindecode.models.base import EEGModuleMixin
20
+ from braindecode.modules import MLP, DropPath
21
+
22
+
23
+ class Labram(EEGModuleMixin, nn.Module):
24
+ r"""Labram from Jiang, W B et al (2024) [Jiang2024]_.
25
+
26
+ :bdg-success:`Convolution` :bdg-danger:`Foundation Model`
27
+
28
+ .. figure:: https://arxiv.org/html/2405.18765v1/x1.png
29
+ :align: center
30
+ :alt: Labram Architecture.
31
+
32
+ Large Brain Model for Learning Generic Representations with Tremendous
33
+ EEG Data in BCI from [Jiang2024]_.
34
+
35
+ This is an **adaptation** of the code [Code2024]_ from the Labram model.
36
+
37
+ The model is a transformer architecture with **strong** inspiration from
38
+ BEiTv2 [BeiTv2]_.
39
+
40
+ The models can be used in two modes:
41
+
42
+ - Neural Tokenizer: designed to produce embeddings (e.g., for classification).
43
+ - Neural Decoder: extracts the amplitude and phase outputs with a VQSNP.
44
+
45
+ Braindecode's modification is to allow the model to be used
46
+ with an input shape of (batch, n_chans, n_times) when ``neural_tokenizer``
47
+ is True. The original implementation uses (batch, n_chans, n_patches,
48
+ patch_size) as input with static segmentation of the input data.
49
+
50
+ The models have the following sequence of steps (see the sketch below)::
51
+
52
+ if neural tokenizer:
53
+ - SegmentPatch: Segment the input data into patches;
54
+ - TemporalConv: Apply a temporal convolution to the segmented data;
55
+ - Residually add the cls, temporal and position embeddings (optional);
56
+ - WindowsAttentionBlock: Apply a windows attention block to the data;
57
+ - LayerNorm: Apply layer normalization to the data;
58
+ - Linear: A linear head layer to transform the data into classes.
59
+
60
+ else:
61
+ - PatchEmbed: Apply a patch embedding to the input data;
62
+ - Residually add the cls, temporal and position embeddings (optional);
63
+ - WindowsAttentionBlock: Apply a windows attention block to the data;
64
+ - LayerNorm: Apply layer normalization to the data;
65
+ - Linear: A linear head layer to transform the data into classes.
66
+
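+ A minimal sketch of the default tokenizer mode on raw EEG
+ (shapes are illustrative)::
+
+ >>> import torch
+ >>> from braindecode.models import Labram
+ >>> model = Labram(n_chans=64, n_times=1600, n_outputs=4)
+ >>> model(torch.zeros(2, 64, 1600)).shape
+ torch.Size([2, 4])
+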
67
+ .. important::
68
+ **Pre-trained Weights Available**
69
+
70
+ This model has pre-trained weights available on the Hugging Face Hub.
71
+ You can load them using:
72
+
73
+ .. code-block:: python
74
+
75
+ from braindecode.models import Labram
76
+
77
+ # Load pre-trained model from Hugging Face Hub
78
+ model = Labram.from_pretrained("braindecode/labram-pretrained")
79
+
80
+ To push your own trained model to the Hub:
81
+
82
+ .. code-block:: python
83
+
84
+ # After training your model
85
+ model.push_to_hub(
86
+ repo_id="username/my-labram-model", commit_message="Upload trained Labram model"
87
+ )
88
+
89
+ Requires installing ``braindecode[hug]`` for Hub integration.
90
+
91
+ .. versionadded:: 0.9
92
+
93
+
94
+ Examples
95
+ --------
96
+ Load pre-trained weights::
97
+
98
+ >>> import torch
99
+ >>> from braindecode.models import Labram
100
+ >>> model = Labram(n_times=1600, n_chans=64, n_outputs=4)
101
+ >>> url = "https://huggingface.co/braindecode/Labram-Braindecode/blob/main/braindecode_labram_base.pt"
102
+ >>> state = torch.hub.load_state_dict_from_url(url, progress=True)
103
+ >>> model.load_state_dict(state)
104
+
105
+
106
+ Parameters
107
+ ----------
108
+ patch_size : int
109
+ The size of the patch to be used in the patch embedding.
110
+ embed_dim : int
111
+ The dimension of the embedding.
112
+ conv_in_channels : int
113
+ The number of convolutional input channels.
114
+ conv_out_channels : int
115
+ The number of convolutional output channels.
116
+ num_layers : int (default=12)
117
+ The number of attention layers of the model.
118
+ num_heads : int (default=10)
119
+ The number of attention heads.
120
+ mlp_ratio : float (default=4.0)
121
+ The expansion ratio of the MLP layer.
122
+ qkv_bias : bool (default=False)
123
+ If True, add a learnable bias to the query, key, and value tensors.
124
+ qk_norm : Pytorch Normalize layer (default=nn.LayerNorm)
125
+ If not None, apply LayerNorm to the query and key tensors.
126
+ Default is nn.LayerNorm for better weight transfer from original LaBraM.
127
+ Set to None to disable Q,K normalization.
128
+ qk_scale : float (default=None)
129
+ If not None, use this value as the scale factor. If None,
130
+ use head_dim**-0.5, where head_dim = dim // num_heads.
131
+ drop_prob : float (default=0.0)
132
+ Dropout rate applied to the embeddings and to the projection/MLP outputs in each block.
133
+ attn_drop_prob : float (default=0.0)
134
+ Dropout rate for the attention weights.
135
+ drop_path_prob : float (default=0.0)
136
+ Drop rate for stochastic depth (DropPath) inside each block.
137
+ norm_layer : Pytorch Normalize layer (default=nn.LayerNorm)
138
+ The normalization layer to be used.
139
+ init_values : float (default=0.1)
140
+ If not None, use this value to initialize the gamma_1 and gamma_2
141
+ parameters for residual scaling. Default is 0.1 for better weight
142
+ transfer from original LaBraM. Set to None to disable.
143
+ use_abs_pos_emb : bool (default=True)
144
+ If True, use absolute position embedding.
145
+ use_mean_pooling : bool (default=True)
146
+ If True, use mean pooling.
147
+ init_scale : float (default=0.001)
148
+ The initial scale to be used in the parameters of the model.
149
+ neural_tokenizer : bool (default=True)
150
+ The model can be used in two modes: Neural Tokenizer or Neural Decoder.
151
+ attn_head_dim : int (default=None)
152
+ The head dimension to be used in the attention layer, to be used only
153
+ during pre-training.
154
+ activation: nn.Module, default=nn.GELU
155
+ Activation function class to apply. Should be a PyTorch activation
156
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.GELU``.
157
+
158
+ References
159
+ ----------
160
+ .. [Jiang2024] Wei-Bang Jiang, Li-Ming Zhao, Bao-Liang Lu. 2024, May.
161
+ Large Brain Model for Learning Generic Representations with Tremendous
162
+ EEG Data in BCI. The Twelfth International Conference on Learning
163
+ Representations, ICLR.
164
+ .. [Code2024] Wei-Bang Jiang, Li-Ming Zhao, Bao-Liang Lu. 2024. Labram
165
+ Large Brain Model for Learning Generic Representations with Tremendous
166
+ EEG Data in BCI. GitHub https://github.com/935963004/LaBraM
167
+ (accessed 2024-03-02)
168
+ .. [BeiTv2] Zhiliang Peng, Li Dong, Hangbo Bao, Qixiang Ye, Furu Wei. 2024.
169
+ BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers.
170
+ arXiv:2208.06366 [cs.CV]
171
+ """
172
+
173
+ def __init__(
174
+ self,
175
+ n_times=None,
176
+ n_outputs=None,
177
+ chs_info=None,
178
+ n_chans=None,
179
+ sfreq=None,
180
+ input_window_seconds=None,
181
+ patch_size=200,
182
+ embed_dim=200,
183
+ conv_in_channels=1,
184
+ conv_out_channels=8,
185
+ num_layers=12,
186
+ num_heads=10,
187
+ mlp_ratio=4.0,
188
+ qkv_bias=False,
189
+ qk_norm: type[nn.Module] = nn.LayerNorm,
190
+ qk_scale=None,
191
+ drop_prob=0.0,
192
+ attn_drop_prob=0.0,
193
+ drop_path_prob=0.0,
194
+ norm_layer: type[nn.Module] = nn.LayerNorm,
195
+ init_values=0.1,
196
+ use_abs_pos_emb=True,
197
+ use_mean_pooling=True,
198
+ init_scale=0.001,
199
+ neural_tokenizer=True,
200
+ attn_head_dim=None,
201
+ activation: type[nn.Module] = nn.GELU,
202
+ ):
203
+ super().__init__(
204
+ n_outputs=n_outputs,
205
+ n_chans=n_chans,
206
+ chs_info=chs_info,
207
+ n_times=n_times,
208
+ input_window_seconds=input_window_seconds,
209
+ sfreq=sfreq,
210
+ )
211
+ del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
212
+
213
+ self.patch_size = patch_size
214
+ self.num_features = self.embed_dim = embed_dim
215
+ self.neural_tokenizer = neural_tokenizer
216
+ self.init_scale = init_scale
217
+
218
+ if patch_size > self.n_times:
219
+ warn(
220
+ f"patch_size ({patch_size}) > n_times ({self.n_times}); "
221
+ f"setting patch_size = {self.n_times}.",
222
+ UserWarning,
223
+ )
224
+ self.patch_size = self.n_times
225
+ self.num_features = None
226
+ self.embed_dim = None
227
+ else:
228
+ self.patch_size = patch_size
229
+ self.n_path = self.n_times // self.patch_size
230
+
231
+ if neural_tokenizer and conv_in_channels != 1:
232
+ warn(
233
+ "The model is in Neural Tokenizer mode, but the variable "
234
+ + "`conv_in_channels` is different from the default values."
235
+ + "`conv_in_channels` is only needed for the Neural Decoder mode."
236
+ + "conv_in_channels is not used in the Neural Tokenizer mode.",
237
+ UserWarning,
238
+ )
239
+ conv_in_channels = 1
240
+ # If the model is used in Neural Tokenizer mode, a temporal conv
241
+ # layer is applied over the patched data.
242
+ if neural_tokenizer:
243
+ self.patch_embed = nn.Sequential(
244
+ OrderedDict(
245
+ [
246
+ (
247
+ "segment_patch",
248
+ _SegmentPatch(
249
+ n_times=self.n_times,
250
+ patch_size=self.patch_size,
251
+ n_chans=self.n_chans,
252
+ emb_dim=self.patch_size,
253
+ ),
254
+ ),
255
+ (
256
+ "temporal_conv",
257
+ _TemporalConv(
258
+ out_channels=conv_out_channels, activation=activation
259
+ ),
260
+ ),
261
+ ]
262
+ )
263
+ )
264
+ else:
265
+ # Otherwise, the model is used in Neural Decoder mode,
266
+ # so the input here comes after the VQVAE encoder,
267
+ # to be used to extract the amplitude and phase outputs.
268
+ # Wrapped in a Sequential to keep the same convention as the
269
+ # Neural Tokenizer mode.
270
+ self.patch_embed = nn.Sequential()
271
+ self.patch_embed.add_module(
272
+ "segment_patch",
273
+ _PatchEmbed(
274
+ n_times=self.n_times,
275
+ patch_size=patch_size,
276
+ in_channels=conv_in_channels,
277
+ emb_dim=self.embed_dim,
278
+ ),
279
+ )
280
+
281
+ with torch.no_grad():
282
+ dummy = torch.zeros(1, self.n_chans, self.n_times)
283
+ out = self.patch_embed(dummy)
284
+ # out.shape for tokenizer: (1, n_chans * n_patches, emb_dim)
285
+ # for decoder: (1, n_patches, emb_dim); either way we want the last dim
286
+ self.embed_dim = out.shape[-1]
287
+ self.num_features = self.embed_dim
288
+
289
+ # Defining the parameters
290
+ # Creating a parameter for the cls token
291
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
292
+ # Positional embedding and time embedding are complementary
293
+ # one is for the spatial information and the other is for the temporal
294
+ # information.
295
+ # The time embedding is used to encode the position along the
296
+ # patch axis, and the position embedding is used to encode the channels'
297
+ # information.
298
+ if use_abs_pos_emb:
299
+ self.position_embedding = nn.Parameter(
300
+ torch.zeros(1, self.n_chans + 1, self.embed_dim),
301
+ requires_grad=True,
302
+ )
303
+ else:
304
+ self.position_embedding = None
305
+
306
+ self.temporal_embedding = nn.Parameter(
307
+ torch.zeros(1, self.patch_embed[0].n_patchs + 1, self.embed_dim),
308
+ requires_grad=True,
309
+ )
310
+ self.pos_drop = nn.Dropout(p=drop_prob)
311
+
312
+ dpr = [
313
+ x.item() for x in torch.linspace(0, drop_path_prob, num_layers)
314
+ ] # stochastic depth decay rule
315
+ self.blocks = nn.ModuleList(
316
+ [
317
+ _WindowsAttentionBlock(
318
+ dim=self.embed_dim,
319
+ num_heads=num_heads,
320
+ mlp_ratio=mlp_ratio,
321
+ qkv_bias=qkv_bias,
322
+ qk_norm=qk_norm,
323
+ qk_scale=qk_scale,
324
+ drop=drop_prob,
325
+ attn_drop=attn_drop_prob,
326
+ drop_path=dpr[i],
327
+ norm_layer=norm_layer,
328
+ init_values=init_values,
329
+ window_size=(
330
+ self.patch_embed[0].patch_shape
331
+ if not neural_tokenizer
332
+ else None
333
+ ),
334
+ attn_head_dim=attn_head_dim,
335
+ activation=activation,
336
+ )
337
+ for i in range(num_layers)
338
+ ]
339
+ )
340
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(self.embed_dim)
341
+ self.fc_norm = norm_layer(self.embed_dim) if use_mean_pooling else None
342
+
343
+ if self.n_outputs > 0:
344
+ self.final_layer = nn.Linear(self.embed_dim, self.n_outputs)
345
+ else:
346
+ self.final_layer = nn.Identity()
347
+
348
+ self.apply(self._init_weights)
349
+ self.fix_init_weight_and_init_embedding()
350
+
351
+ def fix_init_weight_and_init_embedding(self):
352
+ """
353
+ Fix the initial weight and the initial embedding.
354
+ Initializing with truncated normal distribution.
355
+ """
356
+ trunc_normal_(self.cls_token, std=0.02)
357
+ trunc_normal_(self.temporal_embedding, std=0.02)
358
+
359
+ if self.position_embedding is not None:
360
+ trunc_normal_(self.position_embedding, std=0.02)
361
+
362
+ if isinstance(self.final_layer, nn.Linear):
363
+ trunc_normal_(self.final_layer.weight, std=0.02)
364
+
365
+ for layer_id, layer in enumerate(self.blocks):
366
+ rescale_parameter(layer.attn.proj.weight.data, layer_id + 1)
367
+ rescale_parameter(layer.mlp[-2].weight.data, layer_id + 1)
368
+
369
+ if isinstance(self.final_layer, nn.Linear):
370
+ self.final_layer.weight.data.mul_(self.init_scale)
371
+ self.final_layer.bias.data.mul_(self.init_scale)
372
+
373
+ @staticmethod
374
+ def _init_weights(layer):
375
+ """
376
+ Initialize the weights of the model for each layer.
377
+
378
+ If the layer is a linear layer, the weight will be initialized
379
+ with a truncated normal distribution with std=0.02.
380
+
381
+ If m.bias is not None, the bias will be initialized with a constant
382
+ value of 0.
383
+
384
+ If the layer is a layer normalization layer, the bias will be
385
+ initialized with a constant value of 0, and the weight will be
386
+ initialized with a constant value of 1.
387
+
388
+ Parameters
389
+ ----------
390
+ layer : torch.nn.Module
391
+ The layer of the PyTorch model.
392
+ """
393
+
394
+ if isinstance(layer, nn.Linear):
395
+ trunc_normal_(layer.weight, std=0.02)
396
+ if layer.bias is not None:
397
+ nn.init.constant_(layer.bias, 0)
398
+ elif isinstance(layer, nn.LayerNorm):
399
+ nn.init.constant_(layer.bias, 0)
400
+ nn.init.constant_(layer.weight, 1.0)
401
+
402
+ def get_num_layers(self):
403
+ """
404
+ Convenience method to get the number of layers in the model.
405
+ """
406
+ return len(self.blocks)
407
+
408
+ def forward_features(
409
+ self,
410
+ x,
411
+ input_chans=None,
412
+ return_patch_tokens=False,
413
+ return_all_tokens=False,
414
+ ):
415
+ """
416
+ Forward the features of the model.
417
+
418
+ Parameters
419
+ ----------
420
+ x : torch.Tensor
421
+ The input data with shape (batch, n_chans, n_times).
422
+ input_chans : list of int | None
423
+ Indices of the channels used to select the matching position embeddings.
424
+ return_patch_tokens : bool
425
+ Whether to return the patch tokens.
426
+ return_all_tokens : bool
427
+ Whether to return all the tokens.
428
+
429
+ Returns
430
+ -------
431
+ x : torch.Tensor
432
+ The output of the model.
433
+ """
434
+ batch_size = x.shape[0]
435
+
436
+ if self.neural_tokenizer:
437
+ # For neural tokenizer: input is (batch, n_chans, n_times)
438
+ # patch_embed returns (batch, n_chans * n_patches, emb_dim)
439
+ x = self.patch_embed(x)
440
+ # x shape: (batch, n_chans * n_patches, emb_dim)
441
+ n_patch = self.n_chans
442
+ temporal = self.embed_dim
443
+ else:
444
+ # For neural decoder: input is (batch, n_chans, n_times)
445
+ # patch_embed returns (batch, n_patchs, emb_dim)
446
+ x = self.patch_embed(x)
447
+ # x shape: (batch, n_patchs, emb_dim)
448
+ batch_size, n_patch, temporal = x.shape
449
+
450
+ # add the [CLS] token to the embedded patch tokens
451
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
452
+
453
+ # Concatenate cls token with patch/channel embeddings
454
+ x = torch.cat((cls_tokens, x), dim=1)
455
+
456
+ # Positional Embedding
457
+ if self.position_embedding is not None:
458
+ if self.neural_tokenizer:
459
+ # In tokenizer mode, use channel-based position embedding
460
+ if input_chans is not None:
461
+ pos_embed_used = self.position_embedding[:, input_chans]
462
+ else:
463
+ pos_embed_used = self.position_embedding
464
+
465
+ pos_embed = self._adj_position_embedding(
466
+ pos_embed_used=pos_embed_used, batch_size=batch_size
467
+ )
468
+ else:
469
+ # In decoder mode, we have different number of patches
470
+ # Adapt position embedding for n_patch patches
471
+ # Use the first n_patch+1 positions from position_embedding
472
+ n_pos = min(self.position_embedding.shape[1], n_patch + 1)
473
+ pos_embed_used = self.position_embedding[:, :n_pos, :]
474
+ pos_embed = pos_embed_used.expand(batch_size, -1, -1)
475
+
476
+ x += pos_embed
477
+
478
+ # The time embedding is added across the channels after the [CLS] token
479
+ if self.neural_tokenizer:
480
+ num_ch = self.n_chans
481
+ time_embed = self._adj_temporal_embedding(
482
+ num_ch=num_ch, batch_size=batch_size, dim_embed=temporal
483
+ )
484
+ x[:, 1:, :] += time_embed
485
+ else:
486
+ # In decoder mode, we have n_patch patches and don't need to expand
487
+ # Just broadcast the temporal embedding
488
+ if temporal is None:
489
+ temporal = self.embed_dim
490
+
491
+ # Get temporal embeddings for n_patch patches
492
+ n_time_tokens = min(n_patch, self.temporal_embedding.shape[1] - 1)
493
+ time_embed = self.temporal_embedding[
494
+ :, 1 : n_time_tokens + 1, :
495
+ ] # (1, n_patch, emb_dim)
496
+ time_embed = time_embed.expand(
497
+ batch_size, -1, -1
498
+ ) # (batch, n_patch, emb_dim)
499
+ x[:, 1:, :] += time_embed
500
+
501
+ x = self.pos_drop(x)
502
+
503
+ for blk in self.blocks:
504
+ x = blk(x)
505
+
506
+ x = self.norm(x)
507
+ if self.fc_norm is not None:
508
+ if return_all_tokens:
509
+ return self.fc_norm(x)
510
+ tokens = x[:, 1:, :]
511
+ if return_patch_tokens:
512
+ return self.fc_norm(tokens)
513
+ return self.fc_norm(tokens.mean(1))
514
+ else:
515
+ if return_all_tokens:
516
+ return x
517
+ elif return_patch_tokens:
518
+ return x[:, 1:]
519
+ return x[:, 0]
520
+
521
+ def forward(
522
+ self,
523
+ x,
524
+ input_chans=None,
525
+ return_patch_tokens=False,
526
+ return_all_tokens=False,
527
+ ):
528
+ """
529
+ Forward the input EEG data through the model.
530
+
531
+ Parameters
532
+ ----------
533
+ x: torch.Tensor
534
+ The input data with shape (batch, n_chans, n_times)
535
+ or (batch, n_chans, n_patches, patch size).
536
+ input_chans: list of int | None
537
+ Channel indices used to select the matching position embeddings.
538
+ return_patch_tokens: bool
539
+ Return the patch tokens
540
+ return_all_tokens: bool
541
+ Return all the tokens
542
+
543
+ Returns
544
+ -------
545
+ torch.Tensor
546
+ The output of the model with dimensions (batch, n_outputs)
547
+ """
548
+ x = self.forward_features(
549
+ x,
550
+ input_chans=input_chans,
551
+ return_patch_tokens=return_patch_tokens,
552
+ return_all_tokens=return_all_tokens,
553
+ )
554
+ x = self.final_layer(x)
555
+ return x
556
+
557
+ def get_classifier(self):
558
+ """
559
+ Get the classifier of the model.
560
+
561
+ Returns
562
+ -------
563
+ torch.nn.Module
564
+ The classifier of the head model.
565
+ """
566
+ return self.final_layer
567
+
568
+ def reset_classifier(self, n_outputs):
569
+ """
570
+ Reset the classifier with the new number of classes.
571
+
572
+ Parameters
573
+ ----------
574
+ n_outputs : int
575
+ The new number of classes.
576
+ """
577
+ self.n_outputs = n_outputs
578
+ self.final_layer = (
579
+ nn.Linear(self.embed_dim, self.n_outputs)
580
+ if self.n_outputs > 0
581
+ else nn.Identity()
582
+ )
583
+
584
+ def _adj_temporal_embedding(self, num_ch, batch_size, dim_embed=None):
585
+ """
586
+ Adjust the dimensions of the time embedding to match the
587
+ number of channels or patches.
588
+
589
+ Parameters
590
+ ----------
591
+ num_ch : int
592
+ The number of channels or number of patches.
593
+ batch_size : int
594
+ Batch size of the input data.
595
+ dim_embed : int
596
+ The embedding dimension (temporal feature dimension).
597
+
598
+ Returns
599
+ -------
600
+ temporal_embedding : torch.Tensor
601
+ The adjusted time embedding to be added across the channels
602
+ after the [CLS] token. (x[:, 1:, :] += time_embed)
603
+ """
604
+ if dim_embed is None:
605
+ cut_dimension = self.patch_size
606
+ else:
607
+ cut_dimension = min(dim_embed, self.temporal_embedding.shape[1] - 1)
608
+
609
+ # Get the temporal embedding: (1, temporal_embedding_dim, emb_size)
610
+ # Slice to cut_dimension: (1, cut_dimension, emb_size)
611
+ temporal_embedding = self.temporal_embedding[:, 1 : cut_dimension + 1, :]
612
+
613
+ # Add a new dimension to the time embedding
614
+ # e.g. (1, 5, 200) -> (1, 1, 5, 200)
615
+ temporal_embedding = temporal_embedding.unsqueeze(1)
616
+
617
+ # Expand the time embedding to match the number of channels or patches
618
+ # (1, 1, cut_dimension, 200) -> (batch_size, num_ch, cut_dimension, 200)
619
+ temporal_embedding = temporal_embedding.expand(batch_size, num_ch, -1, -1)
620
+
621
+ # Flatten the intermediate dimensions
622
+ # (batch_size, num_ch, cut_dimension, 200) -> (batch_size, num_ch * cut_dimension, 200)
623
+ temporal_embedding = temporal_embedding.flatten(1, 2)
624
+
625
+ return temporal_embedding
626
+
627
+ def _adj_position_embedding(self, pos_embed_used, batch_size):
628
+ """
629
+ Adjust the dimensions of position embedding to match the
630
+ number of patches.
631
+
632
+ Parameters
633
+ ----------
634
+ pos_embed_used : torch.Tensor
635
+ The position embedding to be adjusted.
636
+ batch_size : int
637
+ The number of batches.
638
+
639
+ Returns
640
+ -------
641
+ pos_embed : torch.Tensor
642
+ The adjusted position embedding
643
+ """
644
+ # [CLS] token has no position embedding
645
+ pos_embed = pos_embed_used[:, 1:, :]
646
+ # Adding a new dimension to the position embedding
647
+ pos_embed = pos_embed.unsqueeze(2)
648
+ # Need to expand the position embedding to match the number of
649
+ # n_patches
650
+ pos_embed = pos_embed.expand(batch_size, -1, self.patch_embed[0].n_patchs, -1)
651
+ # Flatten the intermediate dimensions,
652
+ # such as the number of patches and the "channels" dim
653
+ pos_embed = pos_embed.flatten(1, 2)
654
+ # Get the base position embedding
655
+ # This is the position embedding for the [CLS] token
656
+ base_pos = pos_embed[:, 0:1, :].expand(batch_size, -1, -1)
657
+ # Concatenate the base position embedding with the
658
+ # position embedding
659
+ pos_embed = torch.cat((base_pos, pos_embed), dim=1)
660
+ return pos_embed
661
+
662
+
663
+ class _SegmentPatch(nn.Module):
664
+ r"""Segment and Patch for EEG data.
665
+
666
+ Patch embedding adapted from the Vision Transformer approach
667
+ to learn the segmentation; the expected input shape is:
668
+ (Batch, Number of Channels, Number of time points).
669
+
670
+ We apply a 1D convolution with a kernel size of patch_size
671
+ and a stride of patch_size.
672
+
673
+ The resulting output shape will be:
674
+ (Batch, Number of Channels, Number of patches, patch size).
675
+
676
+ This way, a convolution is learned to segment the input signal.
677
+
678
+ The number of patches is calculated as the number of samples divided
679
+ by the patch size.
680
+
681
+ Parameters:
682
+ -----------
683
+ n_times: int (default=2000)
684
+ Number of temporal components of the input tensor.
685
+ n_chans: int (default=1)
686
+ Number of electrodes (channels) in the EEG signal.
687
+ emb_dim: int (default=200)
688
+ Number of output features of the convolution; here,
689
+ we use the same value as patch_size.
690
+ patch_size: int (default=200)
691
+ Size of the patch; the default is 1 second at 200 Hz.
692
+ Returns:
693
+ --------
694
+ x_patched: torch.Tensor
695
+ Output tensor of shape (batch, n_chans, num_patches, emb_dim).
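+
+ A minimal shape sketch with illustrative values, following the
+ description above::
+
+ >>> import torch
+ >>> layer = _SegmentPatch(n_times=1000, patch_size=200, n_chans=64, emb_dim=200)
+ >>> layer(torch.zeros(2, 64, 1000)).shape
+ torch.Size([2, 64, 5, 200])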
696
+ """
697
+
698
+ def __init__(
699
+ self, n_times=2000, patch_size=200, n_chans=1, emb_dim=200, learned_patcher=True
700
+ ):
701
+ super().__init__()
702
+
703
+ self.n_times = n_times
704
+ self.patch_size = patch_size
705
+ self.n_patchs = n_times // patch_size
706
+ self.emb_dim = emb_dim
707
+ self.n_chans = n_chans
708
+ self.learned_patcher = learned_patcher
709
+
710
+ self.patcher = nn.Conv1d(
711
+ in_channels=1,
712
+ out_channels=self.emb_dim,
713
+ kernel_size=self.patch_size,
714
+ stride=self.patch_size,
715
+ )
716
+
717
+ self.adding_extra_dim = Rearrange(
718
+ pattern="batch nchans temporal -> (batch nchans) 1 temporal"
719
+ )
720
+
721
+ def forward(self, x):
722
+ """
723
+ Using an 1D convolution to generate segments of EEG signal.
724
+
725
+ Parameters:
726
+ -----------
727
+ X: Tensor
728
+ [batch, n_chans, n_times]
729
+
730
+ Returns:
731
+ --------
732
+ X_patch: Tensor
733
+ [batch, n_chans, n_times//patch_size, patch_size]
734
+ """
735
+ batch_size, _, _ = x.shape
736
+ # Input shape: [batch, n_chs, n_times]
737
+
738
+ # First, rearrange input to treat the channel dimension 'n_chs' as
739
+ # separate 'dimension' in batch for Conv1d
740
+ # This requires reshaping x to a single input channel per EEG trace.
741
+ if self.learned_patcher:
742
+ x = self.adding_extra_dim(x)
743
+
744
+ # Apply the convolution along the temporal dimension
745
+ # Conv1d output shape: [(batch*n_chs), emb_dim, n_patches]
746
+ x = self.patcher(x)
747
+
748
+ # Now, rearrange output to get back to a batch-first format,
749
+ # combining embedded patches with channel information
750
+ # Assuming you want [batch, n_chs, n_patches, emb_dim]
751
+ # as output, which keeps channel information
752
+ # This treats each patch embedding as a feature alongside channels
753
+ x = rearrange(
754
+ x,
755
+ pattern="(batch nchans) embed npatchs -> batch nchans npatchs embed",
756
+ batch=batch_size,
757
+ nchans=self.n_chans,
758
+ )
759
+ else:
760
+ x = x.view(
761
+ batch_size,
762
+ self.n_chans,
763
+ self.n_times // self.patch_size,
764
+ self.patch_size,
765
+ )
766
+ return x
767
+
768
+
769
+ class _PatchEmbed(nn.Module):
770
+ r"""EEG to Patch Embedding for Neural Decoder mode.
771
+
772
+ This code is used when we want to apply the patch embedding
773
+ after the codebook layer (Neural Decoder mode).
774
+
775
+ The input is expected to be in the format (Batch, n_channels, n_times),
776
+ but the original LaBraM expects pre-patched data (Batch, n_channels, n_patches, patch_size).
777
+ This class reshapes the input to the pre-patched format, then applies a 2D
778
+ convolution to project this pre-patched data to the embedding dimension,
779
+ and finally flattens across channels to produce a unified embedding.
780
+
781
+ Parameters:
782
+ -----------
783
+ n_times: int (default=2000)
784
+ Number of temporal components of the input tensor (used for dimension calculation).
785
+ patch_size: int (default=200)
786
+ Size of the patch; the default is 1 second at 200 Hz.
787
+ in_channels: int (default=1)
788
+ Number of input channels (from VQVAE codebook).
789
+ emb_dim: int (default=200)
790
+ Dimension of the output embedding.
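+
+ A minimal shape sketch with illustrative values (62 channels,
+ 1000 samples, default patch size)::
+
+ >>> import torch
+ >>> embed = _PatchEmbed(n_times=1000, patch_size=200, in_channels=1, emb_dim=200)
+ >>> embed(torch.zeros(2, 62, 1000)).shape
+ torch.Size([2, 5, 200])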
791
+ """
792
+
793
+ def __init__(
794
+ self, n_times=2000, patch_size=200, in_channels=1, emb_dim=200, n_codebooks=62
795
+ ):
796
+ super().__init__()
797
+ self.n_times = n_times
798
+ self.patch_size = patch_size
799
+ self.patch_shape = (1, self.n_times // self.patch_size)
800
+ self.n_patchs = self.n_times // self.patch_size
801
+ self.emb_dim = emb_dim
802
+ self.in_channels = in_channels
803
+
804
+ # 2D Conv to project the pre-patched data
805
+ # Input: (Batch, in_channels, n_patches, patch_size)
806
+ # After proj: (Batch, emb_dim, n_patches, 1)
807
+ self.proj = nn.Conv2d(
808
+ in_channels=in_channels,
809
+ out_channels=emb_dim,
810
+ kernel_size=(1, self.patch_size),
811
+ stride=(1, self.patch_size),
812
+ )
813
+
814
+ def forward(self, x):
815
+ """
816
+ Apply the temporal projection to the input tensor after grouping channels.
817
+
818
+ Parameters
819
+ ----------
820
+ x : torch.Tensor
821
+ Input tensor of shape (Batch, n_channels, n_times) or
822
+ (Batch, n_channels, n_patches, patch_size).
823
+
824
+ Returns
825
+ -------
826
+ torch.Tensor
827
+ Output tensor of shape (Batch, n_patchs, emb_dim).
828
+ """
829
+ if x.ndim == 4:
830
+ batch_size, n_channels, n_patchs, patch_len = x.shape
831
+ if patch_len != self.patch_size:
832
+ raise ValueError(
833
+ "When providing a 4D tensor, the last dimension "
834
+ f"({patch_len}) must match patch_size ({self.patch_size})."
835
+ )
836
+ n_times = n_patchs * patch_len
837
+ x = x.reshape(batch_size, n_channels, n_times)
838
+ elif x.ndim == 3:
839
+ batch_size, n_channels, n_times = x.shape
840
+ else:
841
+ raise ValueError(
842
+ "Input must be either 3D (batch, channels, times) or "
843
+ "4D (batch, channels, n_patches, patch_size)."
844
+ )
845
+
846
+ if n_times % self.patch_size != 0:
847
+ raise ValueError(
848
+ f"n_times ({n_times}) must be divisible by patch_size ({self.patch_size})."
849
+ )
850
+ if n_channels % self.in_channels != 0:
851
+ raise ValueError(
852
+ "The input channel dimension "
853
+ f"({n_channels}) must be divisible by in_channels ({self.in_channels})."
854
+ )
855
+
856
+ group_size = n_channels // self.in_channels
857
+
858
+ # Reshape so Conv2d sees `in_channels` feature maps and uses the grouped
859
+ # EEG channels as the spatial height dimension.
860
+ # Shape after view: (Batch, in_channels, group_size, n_times)
861
+ x = x.view(batch_size, self.in_channels, group_size, n_times)
862
+
863
+ # Apply the temporal projection per group.
864
+ # Output shape: (Batch, emb_dim, group_size, n_patchs)
865
+ x = self.proj(x)
866
+
867
+ # THIS IS braindecode's MODIFICATION:
868
+ # Average over the grouped channel dimension and permute to (Batch, n_patchs, emb_dim)
869
+ x = x.mean(dim=2)
870
+ x = x.transpose(1, 2).contiguous()
871
+
872
+ return x
873
+
874
+
875
+ class _Attention(nn.Module):
876
+ r"""
877
+ Attention with the options of Window-based multi-head self attention (W-MSA).
878
+
879
+ This code is strongly inspired by:
880
+ https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L77
881
+
882
+ Basically, the attention module is a linear layer that takes the input
883
+ tensor and returns the output tensor. The input tensor is first passed
884
+ through a linear layer to get the query, key, and value tensors. Then,
885
+ the query tensor is multiplied by the scale factor and the result is
886
+ multiplied by the transpose of the key tensor.
887
+
888
+ The flag window_size is used to determine if the attention is
889
+ window-based or not.
890
+
891
+ Parameters:
892
+ -----------
893
+ dim: int
894
+ Number of input features.
895
+ num_heads: int (default=8)
896
+ Number of attention heads.
897
+ qkv_bias: bool (default=False)
898
+ If True, add a learnable bias to the query, key, and value tensors.
899
+ qk_norm: nn.LayerNorm (default=None)
900
+ If not None, apply LayerNorm to the query and key tensors.
901
+ qk_scale: float (default=None)
902
+ If not None, use this value as the scale factor. If None,
903
+ use head_dim**-0.5, where head_dim = dim // num_heads.
904
+ attn_drop: float (default=0.0)
905
+ Dropout rate for the attention weights.
906
+ proj_drop: float (default=0.0)
907
+ Dropout rate for the output tensor.
908
+ window_size: bool (default=None)
909
+ If not None, use window-based multi-head self attention based on Swin Transformer.
910
+ attn_head_dim: int (default=None)
911
+ If not None, use this value as the head_dim. If None, use dim // num_heads.
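+
+ A minimal sketch on a token sequence (Batch, N, C), assuming the
+ default non-windowed attention::
+
+ >>> import torch
+ >>> attn = _Attention(dim=200, num_heads=10)
+ >>> attn(torch.zeros(2, 65, 200)).shape
+ torch.Size([2, 65, 200])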
912
+ """
913
+
914
+ def __init__(
915
+ self,
916
+ dim,
917
+ num_heads=8,
918
+ qkv_bias=False,
919
+ qk_norm=None,
920
+ qk_scale=None,
921
+ attn_drop=0.0,
922
+ proj_drop=0.0,
923
+ window_size=None,
924
+ attn_head_dim=None,
925
+ ):
926
+ super().__init__()
927
+ self.num_heads = num_heads
928
+ head_dim = dim // num_heads
929
+ if attn_head_dim is not None:
930
+ head_dim = attn_head_dim
931
+ all_head_dim = head_dim * self.num_heads
932
+ self.scale = qk_scale or head_dim**-0.5
933
+
934
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
935
+ if qkv_bias:
936
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
937
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
938
+ else:
939
+ self.q_bias = None
940
+ self.v_bias = None
941
+
942
+ if qk_norm is not None:
943
+ self.q_norm = qk_norm(head_dim)
944
+ self.k_norm = qk_norm(head_dim)
945
+ else:
946
+ self.q_norm = None
947
+ self.k_norm = None
948
+
949
+ if window_size:
950
+ self.window_size = window_size
951
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
952
+ 2 * window_size[1] - 1
953
+ ) + 3
954
+ self.relative_position_bias_table = nn.Parameter(
955
+ torch.zeros(self.num_relative_distance, num_heads)
956
+ ) # 2*Wh-1 * 2*Ww-1, nH
957
+ # cls to token & token 2 cls & cls to cls
958
+
959
+ # get pair-wise relative position index for each token inside the window
960
+ coords_h = torch.arange(window_size[0])
961
+ coords_w = torch.arange(window_size[1])
962
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
963
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
964
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
965
+ # 2, Wh*Ww, Wh*Ww
966
+ relative_coords = relative_coords.permute(
967
+ 1, 2, 0
968
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
969
+ relative_coords[:, :, 0] += window_size[0] - 1
970
+ # shift to start from 0
971
+ relative_coords[:, :, 1] += window_size[1] - 1
972
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
973
+ relative_position_index = torch.zeros(
974
+ size=(window_size[0] * window_size[1] + 1,) * 2,
975
+ dtype=relative_coords.dtype,
976
+ )
977
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
978
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
979
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
980
+ relative_position_index[0, 0] = self.num_relative_distance - 1
981
+
982
+ self.register_buffer("relative_position_index", relative_position_index)
983
+ else:
984
+ self.window_size = None
985
+ self.relative_position_bias_table = None
986
+ self.relative_position_index = None
987
+
988
+ self.attn_drop = nn.Dropout(attn_drop)
989
+ self.proj = nn.Linear(all_head_dim, dim)
990
+ self.proj_drop = nn.Dropout(proj_drop)
991
+
992
+ def forward(
993
+ self,
994
+ x: torch.Tensor,
995
+ return_attention=False,
996
+ return_qkv=False,
997
+ ):
998
+ """
999
+ Apply the attention mechanism to the input tensor.
1000
+
1001
+ Parameters:
1002
+ -----------
1003
+ x: torch.Tensor
1004
+ Input tensor of shape (Batch, N, C).
1005
+ return_attention: bool (default=False)
1006
+ If True, return the attention weights.
1007
+ return_qkv: bool (default=False)
1008
+ If True, return the query, key, and value tensors together with
1009
+ the output tensor.
1010
+ Returns:
1011
+ --------
1012
+ x: torch.Tensor
1013
+ Output tensor of shape (Batch, N, C).
1014
+ qkv: torch.Tensor (optional)
1015
+ Query, key, and value tensors of shape
1016
+ (Batch, N, 3, num_heads, C // num_heads).
1017
+ """
1018
+ B, N, _ = x.shape
1019
+ qkv_bias = None
1020
+ if self.q_bias is not None:
1021
+ qkv_bias = torch.cat(
1022
+ (
1023
+ self.q_bias,
1024
+ torch.zeros_like(self.v_bias, requires_grad=False),
1025
+ self.v_bias,
1026
+ )
1027
+ )
1028
+ qkv = nn.functional.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
1029
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
1030
+ q, k, v = (
1031
+ qkv[0],
1032
+ qkv[1],
1033
+ qkv[2],
1034
+ ) # make torchscript happy (cannot use tensor as tuple) (B, H, N, C)
1035
+ if self.q_norm is not None:
1036
+ q = self.q_norm(q).type_as(v)
1037
+ if self.k_norm is not None:
1038
+ k = self.k_norm(k).type_as(v)
1039
+
1040
+ q = q * self.scale
1041
+ attn = q @ k.transpose(-2, -1)
1042
+
1043
+ if self.relative_position_bias_table is not None:
1044
+ relative_position_bias = self.relative_position_bias_table[
1045
+ self.relative_position_index.view(-1)
1046
+ ].view(
1047
+ self.window_size[0] * self.window_size[1] + 1,
1048
+ self.window_size[0] * self.window_size[1] + 1,
1049
+ -1,
1050
+ ) # Wh*Ww,Wh*Ww,nH
1051
+ relative_position_bias = relative_position_bias.permute(
1052
+ 2, 0, 1
1053
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
1054
+ attn = attn + relative_position_bias.unsqueeze(0)
1055
+
1056
+ attn = attn.softmax(dim=-1)
1057
+ attn = self.attn_drop(attn)
1058
+
1059
+ if return_attention:
1060
+ return attn
1061
+
1062
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
1063
+
1064
+ x = self.proj(x)
1065
+ x = self.proj_drop(x)
1066
+
1067
+ if return_qkv:
1068
+ return x, qkv
1069
+
1070
+ return x
1071
+
1072
+
1073
+ class _WindowsAttentionBlock(nn.Module):
1074
+ r"""Blocks of Windows Attention with Layer norm and MLP.
1075
+
1076
+ Notes: This code is strongly inspired by:
1077
+ BeiTv2 from Microsoft.
1078
+
1079
+ Parameters:
1080
+ -----------
1081
+ dim: int
1082
+ Number of input features.
1083
+ num_heads: int (default=8)
1084
+ Number of attention heads.
1085
+ mlp_ratio: float (default=4.0)
1086
+ Ratio to increase the hidden features from input features in the MLP layer
1087
+ qkv_bias: bool (default=False)
1088
+ If True, add a learnable bias to the query, key, and value tensors.
1089
+ qk_norm: nn.LayerNorm (default=None)
1090
+ If not None, apply LayerNorm to the query and key tensors.
1091
+ qk_scale: float (default=None)
1092
+ If not None, use this value as the scale factor. If None,
1093
+ use head_dim**-0.5, where head_dim = dim // num_heads.
1094
+ drop: float (default=0.0)
1095
+ Dropout rate for the output tensor.
1096
+ attn_drop: float (default=0.0)
1097
+ Dropout rate for the attention weights.
1098
+ drop_path: float (default=0.0)
1099
+ Drop rate for stochastic depth (DropPath).
1100
+ init_values: float (default=None)
1101
+ If not None, use this value to initialize the gamma_1 and gamma_2
1102
+ parameters.
1103
+ activation: nn.GELU (default)
1104
+ Activation function.
1105
+ norm_layer: nn.LayerNorm (default)
1106
+ Normalization layer.
1107
+ window_size: bool (default=None)
1108
+ If not None, use window-based multi-head self attention based on
1109
+ Swin Transformer.
1110
+ attn_head_dim: int (default=None)
1111
+ If not None, use this value as the head_dim. If None,
1112
+ the classes use dim // num_heads
1113
+
1114
+ Returns:
1115
+ --------
1116
+ x: torch.Tensor
1117
+ Output tensor of shape (Batch, N, C).
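+
+ A minimal sketch on a token sequence, assuming the default
+ (non-windowed, no residual scaling) configuration::
+
+ >>> import torch
+ >>> block = _WindowsAttentionBlock(dim=200, num_heads=10)
+ >>> block(torch.zeros(2, 65, 200)).shape
+ torch.Size([2, 65, 200])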
1118
+
1119
+ """
1120
+
1121
+ def __init__(
1122
+ self,
1123
+ dim: int,
1124
+ num_heads: int,
1125
+ mlp_ratio=4.0,
1126
+ qkv_bias=False,
1127
+ qk_norm=None,
1128
+ qk_scale=None,
1129
+ drop=0.0,
1130
+ attn_drop=0.0,
1131
+ drop_path=0.0,
1132
+ init_values=None,
1133
+ activation: type[nn.Module] = nn.GELU,
1134
+ norm_layer=nn.LayerNorm,
1135
+ window_size=None,
1136
+ attn_head_dim=None,
1137
+ ):
1138
+ super().__init__()
1139
+ self.norm1 = norm_layer(dim)
1140
+ self.attn = _Attention(
1141
+ dim,
1142
+ num_heads=num_heads,
1143
+ qkv_bias=qkv_bias,
1144
+ qk_norm=qk_norm,
1145
+ qk_scale=qk_scale,
1146
+ attn_drop=attn_drop,
1147
+ proj_drop=drop,
1148
+ window_size=window_size,
1149
+ attn_head_dim=attn_head_dim,
1150
+ )
1151
+
1152
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
1153
+ self.norm2 = norm_layer(dim)
1154
+ mlp_hidden_dim = int(dim * mlp_ratio)
1155
+ self.mlp = MLP(
1156
+ in_features=dim,
1157
+ hidden_features=[mlp_hidden_dim],
1158
+ activation=activation,
1159
+ drop=drop,
1160
+ )
1161
+
1162
+ if init_values is not None and init_values > 0:
1163
+ self.gamma_1 = nn.Parameter(
1164
+ init_values * torch.ones((dim)), requires_grad=True
1165
+ )
1166
+ self.gamma_2 = nn.Parameter(
1167
+ init_values * torch.ones((dim)), requires_grad=True
1168
+ )
1169
+ else:
1170
+ self.gamma_1, self.gamma_2 = None, None
1171
+
1172
+ def forward(self, x, return_attention=False, return_qkv=False):
1173
+ """
1174
+ Apply the attention mechanism to the input tensor.
1175
+ Parameters
1176
+ ----------
1177
+ x: torch.Tensor
1178
+ Input tensor of shape (Batch, N, C).
1179
+ return_attention: bool (default=False)
1180
+ If True, return the attention weights.
1181
+ return_qkv: bool (default=False)
1182
+ If True, return the query, key, and value tensors together with
1183
+ the output tensor.
1184
+
1185
+ Returns
1186
+ -------
1187
+ torch.Tensor
1188
+ Output tensor of shape (Batch, N, C).
1189
+ """
1190
+
1191
+ if return_attention:
1192
+ return self.attn(self.norm1(x), return_attention=True)
1193
+ if return_qkv:
1194
+ y, qkv = self.attn(self.norm1(x), return_qkv=return_qkv)
1195
+ x = x + self.drop_path(self.gamma_1 * y)
1196
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
1197
+ return x, qkv
1198
+
1199
+ if self.gamma_1 is None:
1200
+ x = x + self.drop_path(self.attn(self.norm1(x)))
1201
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
1202
+ else:
1203
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
1204
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
1205
+ return x
1206
+
1207
+
1208
+ class _TemporalConv(nn.Module):
1209
+ r"""
1210
+ Temporal Convolutional Module inspired by Visual Transformer.
1211
+
1212
+ In this module we apply the following steps three times
1213
+ to the input tensor, reducing the temporal dimension only in the first step:
1214
+ - Apply a 2D convolution.
1215
+ - Apply a GELU activation function.
1216
+ - Apply a GroupNorm with 4 groups.
1217
+
1218
+ Parameters:
1219
+ -----------
1220
+ in_channels: int (default=1)
1221
+ Number of input channels.
1222
+ out_channels: int (default=8)
1223
+ Number of output channels.
1224
+ num_groups: int (default=4)
1225
+ Number of groups for GroupNorm.
1226
+ kernel_size_1: tuple (default=(1, 15))
1227
+ Kernel size for the first convolution.
1228
+ kernel_size_2: tuple (default=(1, 3))
1229
+ Kernel size for the second and third convolutions.
1230
+ stride_1: tuple (default=(1, 8))
1231
+ Stride for the first convolution.
1232
+ padding_1: tuple (default=(0, 7))
1233
+ Padding for the first convolution.
1234
+ padding_2: tuple (default=(0, 1))
1235
+ Padding for the second and third convolutions.
1236
+ activation: nn.Module, default=nn.GELU
1237
+ Activation function class to apply. Should be a PyTorch activation
1238
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.GELU``.
1239
+
1240
+ Returns:
1241
+ --------
1242
+ x: torch.Tensor
1243
+ Output tensor of shape (Batch, NA, Temporal Channel).
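+
+ A minimal shape sketch on an already-segmented input of shape
+ (batch, n_chans, n_patches, patch_size), with illustrative values::
+
+ >>> import torch
+ >>> conv = _TemporalConv(out_channels=8)
+ >>> conv(torch.zeros(2, 64, 5, 200)).shape
+ torch.Size([2, 320, 200])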
1244
+ """
1245
+
1246
+ def __init__(
1247
+ self,
1248
+ in_channels=1,
1249
+ out_channels=8,
1250
+ num_groups=4,
1251
+ kernel_size_1=(1, 15),
1252
+ stride_1=(1, 8),
1253
+ padding_1=(0, 7),
1254
+ kernel_size_2=(1, 3),
1255
+ padding_2=(0, 1),
1256
+ activation: type[nn.Module] = nn.GELU,
1257
+ ):
1258
+ super().__init__()
1259
+
1260
+ # Here, we use the Rearrange layer from einops to flatten the input
1261
+ # tensor to a 2D tensor, so we can apply 2D convolutions.
1262
+ self.channel_patch_flatten = Rearrange(
1263
+ "Batch chs npat spatch -> Batch () (chs npat) spatch"
1264
+ )
1265
+
1266
+ self.conv1 = nn.Conv2d(
1267
+ in_channels=in_channels,
1268
+ out_channels=out_channels,
1269
+ kernel_size=kernel_size_1,
1270
+ stride=stride_1,
1271
+ padding=padding_1,
1272
+ )
1273
+ self.act_layer_1 = activation()
1274
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
1275
+
1276
+ self.conv2 = nn.Conv2d(
1277
+ in_channels=out_channels,
1278
+ out_channels=out_channels,
1279
+ kernel_size=kernel_size_2,
1280
+ padding=padding_2,
1281
+ )
1282
+ self.act_layer_2 = activation()
1283
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
1284
+
1285
+ self.conv3 = nn.Conv2d(
1286
+ in_channels=out_channels,
1287
+ out_channels=out_channels,
1288
+ kernel_size=kernel_size_2,
1289
+ padding=padding_2,
1290
+ )
1291
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
1292
+ self.act_layer_3 = activation()
1293
+
1294
+ self.transpose_temporal_channel = Rearrange("Batch C NA T -> Batch NA (T C)")
1295
+
1296
+ def forward(self, x):
1297
+ """
1298
+ Apply three steps of 2D convolution, GELU activation,
1299
+ and GroupNorm.
1300
+
1301
+ Parameters:
1302
+ -----------
1303
+ x: torch.Tensor
1304
+ Input tensor of shape (Batch, Channels, n_patchs, size_patch).
1305
+
1306
+ Returns:
1307
+ --------
1308
+ x: torch.Tensor
1309
+ Output tensor of shape (Batch, NA, Temporal Channel).
1310
+ """
1311
+ x = self.channel_patch_flatten(x)
1312
+ x = self.act_layer_1(self.norm1(self.conv1(x)))
1313
+ x = self.act_layer_2(self.norm2(self.conv2(x)))
1314
+ x = self.act_layer_3(self.norm3(self.conv3(x)))
1315
+ x = self.transpose_temporal_channel(x)
1316
+ return x