braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/labram.py
@@ -0,0 +1,1166 @@
1
+ """
2
+ Labram module.
3
+ Authors: Wei-Bang Jiang
4
+ Bruno Aristimunha <b.aristimunha@gmail.com>
5
+ License: BSD 3 clause
6
+ """
7
+
8
+ from collections import OrderedDict
9
+ from warnings import warn
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from einops import rearrange
14
+ from einops.layers.torch import Rearrange
15
+ from torch.nn.init import trunc_normal_
16
+
17
+ from braindecode.functional import rescale_parameter
18
+ from braindecode.models.base import EEGModuleMixin
19
+ from braindecode.modules import MLP, DropPath
20
+
21
+
22
+ class Labram(EEGModuleMixin, nn.Module):
23
+ """Labram from Jiang, W B et al (2024) [Jiang2024]_.
24
+
25
+ .. figure:: https://arxiv.org/html/2405.18765v1/x1.png
26
+ :align: center
27
+ :alt: Labram Architecture.
28
+
29
+ Large Brain Model for Learning Generic Representations with Tremendous
30
+ EEG Data in BCI from [Jiang2024]_
31
+
32
+ This is an **adaptation** of the original Labram code [Code2024]_.
33
+
34
+ The model is a transformer architecture with **strong** inspiration from
35
+ BEiTv2 [BeiTv2]_.
36
+
37
+ The models can be used in two modes:
38
+ - Neural Tokenizer: designed to produce embeddings (e.g., for classification).
39
+ - Neural Decoder: to extract the amplitude and phase outputs with a VQSNP.
40
+
41
+ Braindecode's modification is to allow the model to be used
42
+ with an input shape of (batch, n_chans, n_times), if neural tokenizer
43
+ equals True. The original implementation uses (batch, n_chans, n_patches,
44
+ patch_size) as input with static segmentation of the input data.
45
+
46
+ The model has the following sequence of steps:
47
+ if neural tokenizer:
48
+ - SegmentPatch: Segment the input data into patches;
49
+ - TemporalConv: Apply a temporal convolution to the segmented data;
50
+ - Residual adding cls, temporal and position embeddings (optional);
51
+ - WindowsAttentionBlock: Apply a windows attention block to the data;
52
+ - LayerNorm: Apply layer normalization to the data;
53
+ - Linear: A linear head layer to transform the data into classes.
54
+
55
+ else:
56
+ - PatchEmbed: Apply a patch embedding to the input data;
57
+ - Residual adding cls, temporal and position embeddings (optional);
58
+ - WindowsAttentionBlock: Apply a windows attention block to the data;
59
+ - LayerNorm: Apply layer normalization to the data;
60
+ - Linear: A linear head layer to transform the data into classes.
61
+
62
+ .. versionadded:: 0.9
63
+
64
+ Parameters
65
+ ----------
66
+ patch_size : int
67
+ The size of the patch to be used in the patch embedding.
68
+ emb_size : int
69
+ The dimension of the embedding.
70
+ in_channels : int
71
+ The number of convolutional input channels.
72
+ out_channels : int
73
+ The number of convolutional output channels.
74
+ n_layers : int (default=12)
75
+ The number of attention layers of the model.
76
+ att_num_heads : int (default=10)
77
+ The number of attention heads.
78
+ mlp_ratio : float (default=4.0)
79
+ The expansion ratio of the MLP layer.
80
+ qkv_bias : bool (default=False)
81
+ If True, add a learnable bias to the query, key, and value tensors.
82
+ qk_norm : Pytorch Normalize layer (default=None)
83
+ If not None, apply LayerNorm to the query and key tensors
84
+ qk_scale : float (default=None)
85
+ If not None, use this value as the scale factor. If None,
86
+ use head_dim**-0.5, where head_dim = dim // num_heads.
87
+ drop_prob : float (default=0.0)
88
+ Dropout rate applied to the embeddings and to the projection and MLP layers.
89
+ attn_drop_prob : float (default=0.0)
90
+ Dropout rate for the attention weights.
91
+ drop_path_prob : float (default=0.0)
92
+ Drop path (stochastic depth) rate used in DropPath.
93
+ norm_layer : Pytorch Normalize layer (default=nn.LayerNorm)
94
+ The normalization layer to be used.
95
+ init_values : float (default=None)
96
+ If not None, use this value to initialize the gamma_1 and gamma_2
97
+ parameters.
98
+ use_abs_pos_emb : bool (default=True)
99
+ If True, use absolute position embedding.
100
+ use_mean_pooling : bool (default=True)
101
+ If True, use mean pooling.
102
+ init_scale : float (default=0.001)
103
+ The initial scale to be used in the parameters of the model.
104
+ neural_tokenizer : bool (default=True)
105
+ The model can be used in two modes: Neural Tokenizer or Neural Decoder.
106
+ attn_head_dim : int (default=None)
107
+ The head dimension to be used in the attention layer, to be used only
108
+ during pre-training.
109
+ activation: nn.Module, default=nn.GELU
110
+ Activation function class to apply. Should be a PyTorch activation
111
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.GELU``.
112
+
113
+ References
114
+ ----------
115
+ .. [Jiang2024] Wei-Bang Jiang, Li-Ming Zhao, Bao-Liang Lu. 2024, May.
116
+ Large Brain Model for Learning Generic Representations with Tremendous
117
+ EEG Data in BCI. The Twelfth International Conference on Learning
118
+ Representations, ICLR.
119
+ .. [Code2024] Wei-Bang Jiang, Li-Ming Zhao, Bao-Liang Lu. 2024. Labram
120
+ Large Brain Model for Learning Generic Representations with Tremendous
121
+ EEG Data in BCI. GitHub https://github.com/935963004/LaBraM
122
+ (accessed 2024-03-02)
123
+ .. [BeiTv2] Zhiliang Peng, Li Dong, Hangbo Bao, Qixiang Ye, Furu Wei. 2024.
124
+ BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers.
125
+ arXiv:2208.06366 [cs.CV]
126
+ """
127
+
128
+ def __init__(
129
+ self,
130
+ n_times=None,
131
+ n_outputs=None,
132
+ chs_info=None,
133
+ n_chans=None,
134
+ sfreq=None,
135
+ input_window_seconds=None,
136
+ patch_size=200,
137
+ emb_size=200,
138
+ in_channels=1,
139
+ out_channels=8,
140
+ n_layers=12,
141
+ att_num_heads=10,
142
+ mlp_ratio=4.0,
143
+ qkv_bias=False,
144
+ qk_norm=None,
145
+ qk_scale=None,
146
+ drop_prob=0.0,
147
+ attn_drop_prob=0.0,
148
+ drop_path_prob=0.0,
149
+ norm_layer=nn.LayerNorm,
150
+ init_values=None,
151
+ use_abs_pos_emb=True,
152
+ use_mean_pooling=True,
153
+ init_scale=0.001,
154
+ neural_tokenizer=True,
155
+ attn_head_dim=None,
156
+ activation: nn.Module = nn.GELU,
157
+ ):
158
+ super().__init__(
159
+ n_outputs=n_outputs,
160
+ n_chans=n_chans,
161
+ chs_info=chs_info,
162
+ n_times=n_times,
163
+ input_window_seconds=input_window_seconds,
164
+ sfreq=sfreq,
165
+ )
166
+ del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
167
+
168
+ self.patch_size = patch_size
169
+ self.n_path = self.n_times // patch_size
170
+ self.num_features = self.emb_size = emb_size
171
+ self.neural_tokenizer = neural_tokenizer
172
+ self.init_scale = init_scale
173
+
174
+ if neural_tokenizer and in_channels != 1:
175
+ warn(
176
+ "The model is in Neural Tokenizer mode, but the variable "
177
+ + "`in_channels` is different from the default values."
178
+ + "`in_channels` is only needed for the Neural Decoder mode."
179
+ + "in_channels is not used in the Neural Tokenizer mode.",
180
+ UserWarning,
181
+ )
182
+ in_channels = 1
183
+ # If the model is used in Neural Tokenizer mode, the
184
+ # temporal conv layer will be used over the patched data
185
+ if neural_tokenizer:
186
+ self.patch_embed = nn.Sequential(
187
+ OrderedDict(
188
+ [
189
+ (
190
+ "segment_patch",
191
+ _SegmentPatch(
192
+ n_times=self.n_times,
193
+ patch_size=self.patch_size,
194
+ n_chans=self.n_chans,
195
+ emb_dim=self.patch_size,
196
+ ),
197
+ ),
198
+ (
199
+ "temporal_conv",
200
+ _TemporalConv(
201
+ out_channels=out_channels, activation=activation
202
+ ),
203
+ ),
204
+ ]
205
+ )
206
+ )
207
+ else:
208
+ # If not, the model will be used as Neural Decoder mode
209
+ # So the input here will be after the VQVAE encoder
210
+ # To be used to extract the amplitude and phase outputs.
211
+ # Adding inside a Sequential to use the same convention as the
212
+ # Neural Tokenizer mode.
213
+ self.patch_embed = nn.Sequential()
214
+ self.patch_embed.add_module(
215
+ "segment_patch",
216
+ _PatchEmbed(
217
+ n_times=self.n_times,
218
+ patch_size=patch_size,
219
+ in_channels=in_channels,
220
+ emb_dim=self.emb_size,
221
+ ),
222
+ )
223
+ # Defining the parameters
224
+ # Creating a parameter list with cls token
225
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.emb_size))
226
+ # Positional embedding and time embedding are complementary
227
+ # one is for the spatial information and the other is for the temporal
228
+ # information.
229
+ # The time embedding is used to encode the index of the
230
+ # patches, and the position embedding is used to encode the channels'
231
+ # information.
232
+ if use_abs_pos_emb:
233
+ self.position_embedding = nn.Parameter(
234
+ torch.zeros(1, self.n_chans + 1, self.emb_size),
235
+ requires_grad=True,
236
+ )
237
+ else:
238
+ self.position_embedding = None
239
+
240
+ self.temporal_embedding = nn.Parameter(
241
+ torch.zeros(1, self.patch_embed[0].n_patchs + 1, self.emb_size),
242
+ requires_grad=True,
243
+ )
244
+ self.pos_drop = nn.Dropout(p=drop_prob)
245
+
246
+ dpr = [
247
+ x.item() for x in torch.linspace(0, drop_path_prob, n_layers)
248
+ ] # stochastic depth decay rule
249
+ self.blocks = nn.ModuleList(
250
+ [
251
+ _WindowsAttentionBlock(
252
+ dim=self.emb_size,
253
+ num_heads=att_num_heads,
254
+ mlp_ratio=mlp_ratio,
255
+ qkv_bias=qkv_bias,
256
+ qk_norm=qk_norm,
257
+ qk_scale=qk_scale,
258
+ drop=drop_prob,
259
+ attn_drop=attn_drop_prob,
260
+ drop_path=dpr[i],
261
+ norm_layer=norm_layer,
262
+ init_values=init_values,
263
+ window_size=(
264
+ self.patch_embed[0].patch_shape
265
+ if not neural_tokenizer
266
+ else None
267
+ ),
268
+ attn_head_dim=attn_head_dim,
269
+ activation=activation,
270
+ )
271
+ for i in range(n_layers)
272
+ ]
273
+ )
274
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(self.emb_size)
275
+ self.fc_norm = norm_layer(self.emb_size) if use_mean_pooling else None
276
+
277
+ if self.n_outputs > 0:
278
+ self.final_layer = nn.Linear(self.emb_size, self.n_outputs)
279
+ else:
280
+ self.final_layer = nn.Identity()
281
+
282
+ self.apply(self._init_weights)
283
+ self.fix_init_weight_and_init_embedding()
284
+
285
+ def fix_init_weight_and_init_embedding(self):
286
+ """
287
+ Fix the initial weight and the initial embedding.
288
+ Initializing with truncated normal distribution.
289
+ """
290
+ trunc_normal_(self.cls_token, std=0.02)
291
+ trunc_normal_(self.temporal_embedding, std=0.02)
292
+
293
+ if self.position_embedding is not None:
294
+ trunc_normal_(self.position_embedding, std=0.02)
295
+
296
+ if isinstance(self.final_layer, nn.Linear):
297
+ trunc_normal_(self.final_layer.weight, std=0.02)
298
+
299
+ for layer_id, layer in enumerate(self.blocks):
300
+ rescale_parameter(layer.attn.proj.weight.data, layer_id + 1)
301
+ rescale_parameter(layer.mlp[-2].weight.data, layer_id + 1)
302
+
303
+ if isinstance(self.final_layer, nn.Linear):
304
+ self.final_layer.weight.data.mul_(self.init_scale)
305
+ self.final_layer.bias.data.mul_(self.init_scale)
306
+
307
+ @staticmethod
308
+ def _init_weights(layer):
309
+ """
310
+ Initialize the weights of the model for each layer.
311
+
312
+ If the layer is a linear layer, the weight will be initialized
313
+ with a truncated normal distribution with std=0.02.
314
+
315
+ If the layer's bias is not None, it will be initialized with a constant
316
+ value of 0.
317
+
318
+ If the layer is a layer normalization layer, the bias will be
319
+ initialized with a constant value of 0, and the weight will be
320
+ initialized with a constant value of 1.
321
+
322
+ Parameters
323
+ ----------
324
+ layer : torch.nn.Module
325
+ The layer of the PyTorch model.
326
+ """
327
+
328
+ if isinstance(layer, nn.Linear):
329
+ trunc_normal_(layer.weight, std=0.02)
330
+ if layer.bias is not None:
331
+ nn.init.constant_(layer.bias, 0)
332
+ elif isinstance(layer, nn.LayerNorm):
333
+ nn.init.constant_(layer.bias, 0)
334
+ nn.init.constant_(layer.weight, 1.0)
335
+
336
+ def get_num_layers(self):
337
+ """
338
+ Convenience method to get the number of layers in the model.
339
+ """
340
+ return len(self.blocks)
341
+
342
+ def forward_features(
343
+ self,
344
+ x,
345
+ input_chans=None,
346
+ return_patch_tokens=False,
347
+ return_all_tokens=False,
348
+ ):
349
+ """
350
+ Forward the features of the model.
351
+
352
+ Parameters
353
+ ----------
354
+ x : torch.Tensor
355
+ The input data with shape (batch, n_chans, n_patches, patch size),
356
+ if neural decoder or (batch, n_chans, n_times), if neural tokenizer.
357
+ input_chans : int
358
+ The indices of the input channels, used to select the position embeddings.
359
+ return_patch_tokens : bool
360
+ Whether to return the patch tokens.
361
+ return_all_tokens : bool
362
+ Whether to return all the tokens.
363
+
364
+ Returns
365
+ -------
366
+ x : torch.Tensor
367
+ The output of the model.
368
+ """
369
+
370
+ if self.neural_tokenizer:
371
+ batch_size, nch, n_patch, temporal = self.patch_embed.segment_patch(x).shape
372
+ else:
373
+ batch_size, nch, n_patch = self.patch_embed(x).shape
374
+ x = self.patch_embed(x)
375
+ # add the [CLS] token to the embedded patch tokens
376
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
377
+
378
+ x = torch.cat((cls_tokens, x), dim=1)
379
+
380
+ # Positional Embedding
381
+ if input_chans is not None:
382
+ pos_embed_used = self.position_embedding[:, input_chans]
383
+ else:
384
+ pos_embed_used = self.position_embedding
385
+
386
+ if self.position_embedding is not None:
387
+ pos_embed = self._adj_position_embedding(
388
+ pos_embed_used=pos_embed_used, batch_size=batch_size
389
+ )
390
+ x += pos_embed
391
+
392
+ # The time embedding is added across the channels after the [CLS] token
393
+ if self.neural_tokenizer:
394
+ num_ch = self.n_chans
395
+ else:
396
+ num_ch = n_patch
397
+ time_embed = self._adj_temporal_embedding(
398
+ num_ch=num_ch, batch_size=batch_size, dim_embed=temporal
399
+ )
400
+ x[:, 1:, :] += time_embed
401
+
402
+ x = self.pos_drop(x)
403
+
404
+ for blk in self.blocks:
405
+ x = blk(x)
406
+
407
+ x = self.norm(x)
408
+ if self.fc_norm is not None:
409
+ if return_all_tokens:
410
+ return self.fc_norm(x)
411
+ temporal = x[:, 1:, :]
412
+ if return_patch_tokens:
413
+ return self.fc_norm(temporal)
414
+ return self.fc_norm(temporal.mean(1))
415
+ else:
416
+ if return_all_tokens:
417
+ return x
418
+ elif return_patch_tokens:
419
+ return x[:, 1:]
420
+ return x[:, 0]
421
+
422
+ def forward(
423
+ self,
424
+ x,
425
+ input_chans=None,
426
+ return_patch_tokens=False,
427
+ return_all_tokens=False,
428
+ ):
429
+ """
430
+ Forward the input EEG data through the model.
431
+
432
+ Parameters
433
+ ----------
434
+ x: torch.Tensor
435
+ The input data with shape (batch, n_chans, n_times)
436
+ or (batch, n_chans, n_patches, patch size).
437
+ input_chans: int
438
+ The indices of the input channels, used to select the position embeddings.
439
+ return_patch_tokens: bool
440
+ Return the patch tokens
441
+ return_all_tokens: bool
442
+ Return all the tokens
443
+
444
+ Returns
445
+ -------
446
+ torch.Tensor
447
+ The output of the model with dimensions (batch, n_outputs)
448
+ """
449
+ x = self.forward_features(
450
+ x,
451
+ input_chans=input_chans,
452
+ return_patch_tokens=return_patch_tokens,
453
+ return_all_tokens=return_all_tokens,
454
+ )
455
+ x = self.final_layer(x)
456
+ return x
457
+
458
+ def get_classifier(self):
459
+ """
460
+ Get the classifier of the model.
461
+
462
+ Returns
463
+ -------
464
+ torch.nn.Module
465
+ The classifier of the head model.
466
+ """
467
+ return self.final_layer
468
+
469
+ def reset_classifier(self, n_outputs):
470
+ """
471
+ Reset the classifier with the new number of classes.
472
+
473
+ Parameters
474
+ ----------
475
+ n_outputs : int
476
+ The new number of classes.
477
+ """
478
+ self.n_outputs = n_outputs
479
+ self.final_layer = (
480
+ nn.Linear(self.emb_size, self.n_outputs)
481
+ if self.n_outputs > 0
482
+ else nn.Identity()
483
+ )
484
+
485
+ def _adj_temporal_embedding(self, num_ch, batch_size, dim_embed=None):
486
+ """
487
+ Adjust the dimensions of the time embedding to match the
488
+ number of channels.
489
+
490
+ Parameters
491
+ ----------
492
+ num_ch : int
493
+ The number of channels or number of code books vectors.
494
+ batch_size : int
495
+ Batch size of the input data.
496
+
497
+ Returns
498
+ -------
499
+ temporal_embedding : torch.Tensor
500
+ The adjusted time embedding to be added across the channels
501
+ after the [CLS] token. (x[:, 1:, :] += time_embed)
502
+ """
503
+ if dim_embed is None:
504
+ cut_dimension = self.patch_size
505
+ else:
506
+ cut_dimension = dim_embed
507
+ # first step will be match the time_embed to the number of channels
508
+ temporal_embedding = self.temporal_embedding[:, 1:cut_dimension, :]
509
+ # Add a new dimension to the time embedding
510
+ # e.g. (batch, 62, 200) -> (batch, 1, 62, 200)
511
+ temporal_embedding = temporal_embedding.unsqueeze(1)
512
+ # Expand the time embedding to match the number of channels
513
+ # or number of patches from
514
+ temporal_embedding = temporal_embedding.expand(batch_size, num_ch, -1, -1)
515
+ # Flatten the intermediate dimensions
516
+ temporal_embedding = temporal_embedding.flatten(1, 2)
517
+ return temporal_embedding
518
+
519
+ def _adj_position_embedding(self, pos_embed_used, batch_size):
520
+ """
521
+ Adjust the dimensions of position embedding to match the
522
+ number of patches.
523
+
524
+ Parameters
525
+ ----------
526
+ pos_embed_used : torch.Tensor
527
+ The position embedding to be adjusted.
528
+ batch_size : int
529
+ The batch size of the input data.
530
+
531
+ Returns
532
+ -------
533
+ pos_embed : torch.Tensor
534
+ The adjusted position embedding
535
+ """
536
+ # [CLS] token has no position embedding
537
+ pos_embed = pos_embed_used[:, 1:, :]
538
+ # Adding a new dimension to the position embedding
539
+ pos_embed = pos_embed.unsqueeze(2)
540
+ # Need to expand the position embedding to match the number of
541
+ # n_patches
542
+ pos_embed = pos_embed.expand(batch_size, -1, self.patch_embed[0].n_patchs, -1)
543
+ # Flatten the intermediate dimensions,
544
+ # such as the number of patches and the "channels" dim
545
+ pos_embed = pos_embed.flatten(1, 2)
546
+ # Get the base position embedding
547
+ # This is the position embedding for the [CLS] token
548
+ base_pos = pos_embed[:, 0:1, :].expand(batch_size, -1, -1)
549
+ # Concatenate the base position embedding with the
550
+ # position embedding
551
+ pos_embed = torch.cat((base_pos, pos_embed), dim=1)
552
+ return pos_embed
553
+
554
+
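# Illustrative usage sketch for the Labram class above, assuming it is
# re-exported as braindecode.models.Labram; the example values (22 channels,
# 4 s at 200 Hz, 4 classes) are arbitrary and only illustrate the
# (batch, n_chans, n_times) convention described in the docstring.
import torch
from braindecode.models import Labram

model = Labram(n_chans=22, n_times=800, n_outputs=4)
x = torch.randn(8, 22, 800)   # (batch, n_chans, n_times)
out = model(x)                # -> (batch, n_outputs) == (8, 4)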
555
+ class _SegmentPatch(nn.Module):
556
+ """Segment and Patch for EEG data.
557
+
558
+ Adapted patch embedding inspired by the Vision Transformer approach
559
+ to learn the segmentation; the expected input shape is:
560
+ (Batch, Number of Channels, Number of Time Points).
561
+
562
+ We apply a 2D convolution with kernel size of (1, patch_size)
563
+ and a stride of (1, patch_size).
564
+
565
+ The resulting output shape will be:
566
+ (Batch, Number of Channels, Number of patches, patch size).
567
+
568
+ This way, we learn a convolution to segment the input signal.
569
+
570
+ The number of patches is calculated as the number of samples divided
571
+ by the patch size.
572
+
573
+ Parameters:
574
+ -----------
575
+ n_times: int (default=2000)
576
+ Number of temporal components of the input tensor.
577
+ n_chans: int (default=1)
578
+ Number of electrodes of the EEG signal.
579
+ emb_dim: int (default=200)
580
+ Number of output channels to be used in the convolution; here,
581
+ we used the same as patch_size.
582
+ patch_size: int (default=200)
583
+ Size of the patch; the default is 1 second at 200 Hz.
584
+ Returns:
585
+ --------
586
+ x_patched: torch.Tensor
587
+ Output tensor of shape (batch, n_chans, num_patches, emb_dim).
588
+ """
589
+
590
+ def __init__(
591
+ self, n_times=2000, patch_size=200, n_chans=1, emb_dim=200, learned_patcher=True
592
+ ):
593
+ super().__init__()
594
+
595
+ self.n_times = n_times
596
+ self.patch_size = patch_size
597
+ self.n_patchs = n_times // patch_size
598
+ self.emb_dim = emb_dim
599
+ self.n_chans = n_chans
600
+ self.learned_patcher = learned_patcher
601
+
602
+ self.patcher = nn.Conv1d(
603
+ in_channels=1,
604
+ out_channels=self.emb_dim,
605
+ kernel_size=self.patch_size,
606
+ stride=self.patch_size,
607
+ )
608
+
609
+ self.adding_extra_dim = Rearrange(
610
+ pattern="batch nchans temporal -> (batch nchans) 1 temporal"
611
+ )
612
+
613
+ def forward(self, x):
614
+ """
615
+ Use a 1D convolution to generate segments of the EEG signal.
616
+
617
+ Parameters:
618
+ -----------
619
+ X: Tensor
620
+ [batch, n_chans, n_times]
621
+
622
+ Returns:
623
+ --------
624
+ X_patch: Tensor
625
+ [batch, n_chans, n_times//patch_size, patch_size]
626
+ """
627
+ batch_size, _, _ = x.shape
628
+ # Input shape: [batch, n_chs, n_times]
629
+
630
+ # First, rearrange input to treat the channel dimension 'n_chs' as
631
+ # separate 'dimension' in batch for Conv1d
632
+ # This requires reshaping x to have a height of 1 for each EEG sample.
633
+ if self.learned_patcher:
634
+ x = self.adding_extra_dim(x)
635
+
636
+ # Apply the convolution along the temporal dimension
637
+ # Conv1d output shape: [(batch*n_chs), emb_dim, n_patches]
638
+ x = self.patcher(x)
639
+
640
+ # Now, rearrange output to get back to a batch-first format,
641
+ # combining embedded patches with channel information
642
+ # Assuming you want [batch, n_chs, n_patches, emb_dim]
643
+ # as output, which keeps channel information
644
+ # This treats each patch embedding as a feature alongside channels
645
+ x = rearrange(
646
+ x,
647
+ pattern="(batch nchans) embed npatchs -> batch nchans npatchs embed",
648
+ batch=batch_size,
649
+ nchans=self.n_chans,
650
+ )
651
+ else:
652
+ x = x.view(
653
+ batch_size,
654
+ self.n_chans,
655
+ self.n_times // self.patch_size,
656
+ self.patch_size,
657
+ )
658
+ return x
659
+
660
+
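# Shape sketch for the learned segmenter above (illustrative only; the class
# is a private helper, imported here directly for demonstration).
import torch
from braindecode.models.labram import _SegmentPatch

seg = _SegmentPatch(n_times=1000, patch_size=200, n_chans=22, emb_dim=200)
x = torch.randn(4, 22, 1000)   # (batch, n_chans, n_times)
seg(x).shape                   # torch.Size([4, 22, 5, 200]) = (batch, n_chans, n_patches, emb_dim)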
661
+ class _PatchEmbed(nn.Module):
662
+ """EEG to Patch Embedding.
663
+
664
+ This code is used when we want to apply the patch embedding
665
+ after the codebook layer.
666
+
667
+ Parameters:
668
+ -----------
669
+ n_times: int (default=2000)
670
+ Number of temporal components of the input tensor.
671
+ patch_size: int (default=200)
672
+ Size of the patch; the default is 1 second at 200 Hz.
673
+ in_channels: int (default=1)
674
+ Number of input channels to be used in the convolution.
675
+ emb_dim: int (default=200)
676
+ Number of output channels to be used in the convolution; here,
677
+ we used the same as patch_size.
678
+ n_codebooks: int (default=62)
679
+ Number of patches to be used in the convolution, here,
680
+ we used the same as n_times // patch_size.
681
+ """
682
+
683
+ def __init__(
684
+ self, n_times=2000, patch_size=200, in_channels=1, emb_dim=200, n_codebooks=62
685
+ ):
686
+ super().__init__()
687
+ self.n_times = n_times
688
+ self.patch_size = patch_size
689
+ self.patch_shape = (1, self.n_times // self.patch_size)
690
+ n_patchs = n_codebooks * (self.n_times // self.patch_size)
691
+
692
+ self.n_patchs = n_patchs
693
+
694
+ self.proj = nn.Conv2d(
695
+ in_channels=in_channels,
696
+ out_channels=emb_dim,
697
+ kernel_size=(1, self.patch_size),
698
+ stride=(1, self.patch_size),
699
+ )
700
+
701
+ self.merge_transpose = Rearrange(
702
+ "Batch ch patch spatch -> Batch patch spatch ch",
703
+ )
704
+
705
+ def forward(self, x):
706
+ """
707
+ Apply the convolution to the input tensor.
708
+ then merge the output tensor to the desired shape.
709
+
710
+ Parameters:
711
+ -----------
712
+ x: torch.Tensor
713
+ Input tensor of shape (Batch, Channels, n_patchs, patch_size).
714
+
715
+ Return:
716
+ -------
717
+ x: torch.Tensor
718
+ Output tensor of shape (Batch, n_patchs, patch_size, channels).
719
+ """
720
+ x = self.proj(x)
721
+ x = self.merge_transpose(x)
722
+ return x
723
+
724
+
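# Shape sketch for the decoder-mode patch embedding above (illustrative
# only). The 4D input layout is assumed just to show the kernel/stride
# arithmetic of the (1, patch_size) convolution along the last axis.
import torch
from braindecode.models.labram import _PatchEmbed

emb = _PatchEmbed(n_times=2000, patch_size=200, in_channels=1, emb_dim=200)
x = torch.randn(4, 1, 62, 2000)   # (batch, in_channels, codebooks, n_times)
emb(x).shape                      # torch.Size([4, 62, 10, 200]) after the final rearrange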
725
+ class _Attention(nn.Module):
726
+ """
727
+ Attention with the options of Window-based multi-head self attention (W-MSA).
728
+
729
+ This code is strongly inspired by:
730
+ https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L77
731
+
732
+ Basically, the attention module is a linear layer that takes the input
733
+ tensor and returns the output tensor. The input tensor is first passed
734
+ through a linear layer to get the query, key, and value tensors. Then,
735
+ the query tensor is multiplied by the scale factor and the result is
736
+ multiplied by the transpose of the key tensor.
737
+
738
+ The flag window_size is used to determine if the attention is
739
+ window-based or not.
740
+
741
+ Parameters:
742
+ -----------
743
+ dim: int
744
+ Number of input features.
745
+ num_heads: int (default=8)
746
+ Number of attention heads.
747
+ qkv_bias: bool (default=False)
748
+ If True, add a learnable bias to the query, key, and value tensors.
749
+ qk_norm: nn.LayerNorm (default=None)
750
+ If not None, apply LayerNorm to the query and key tensors.
751
+ qk_scale: float (default=None)
752
+ If not None, use this value as the scale factor. If None,
753
+ use head_dim**-0.5, where head_dim = dim // num_heads.
754
+ attn_drop: float (default=0.0)
755
+ Dropout rate for the attention weights.
756
+ proj_drop: float (default=0.0)
757
+ Dropout rate for the output tensor.
758
+ window_size: tuple (default=None)
759
+ If not None, use window-based multi-head self attention based on Swin Transformer.
760
+ attn_head_dim: int (default=None)
761
+ If not None, use this value as the head_dim. If None, use dim // num_heads.
762
+ """
763
+
764
+ def __init__(
765
+ self,
766
+ dim,
767
+ num_heads=8,
768
+ qkv_bias=False,
769
+ qk_norm=None,
770
+ qk_scale=None,
771
+ attn_drop=0.0,
772
+ proj_drop=0.0,
773
+ window_size=None,
774
+ attn_head_dim=None,
775
+ ):
776
+ super().__init__()
777
+ self.num_heads = num_heads
778
+ head_dim = dim // num_heads
779
+ if attn_head_dim is not None:
780
+ head_dim = attn_head_dim
781
+ all_head_dim = head_dim * self.num_heads
782
+ self.scale = qk_scale or head_dim**-0.5
783
+
784
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
785
+ if qkv_bias:
786
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
787
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
788
+ else:
789
+ self.q_bias = None
790
+ self.v_bias = None
791
+
792
+ if qk_norm is not None:
793
+ self.q_norm = qk_norm(head_dim)
794
+ self.k_norm = qk_norm(head_dim)
795
+ else:
796
+ self.q_norm = None
797
+ self.k_norm = None
798
+
799
+ if window_size:
800
+ self.window_size = window_size
801
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
802
+ 2 * window_size[1] - 1
803
+ ) + 3
804
+ self.relative_position_bias_table = nn.Parameter(
805
+ torch.zeros(self.num_relative_distance, num_heads)
806
+ ) # 2*Wh-1 * 2*Ww-1, nH
807
+ # cls to token & token 2 cls & cls to cls
808
+
809
+ # get pair-wise relative position index for each token inside the window
810
+ coords_h = torch.arange(window_size[0])
811
+ coords_w = torch.arange(window_size[1])
812
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
813
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
814
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
815
+ # 2, Wh*Ww, Wh*Ww
816
+ relative_coords = relative_coords.permute(
817
+ 1, 2, 0
818
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
819
+ relative_coords[:, :, 0] += window_size[0] - 1
820
+ # shift to start from 0
821
+ relative_coords[:, :, 1] += window_size[1] - 1
822
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
823
+ relative_position_index = torch.zeros(
824
+ size=(window_size[0] * window_size[1] + 1,) * 2,
825
+ dtype=relative_coords.dtype,
826
+ )
827
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
828
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
829
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
830
+ relative_position_index[0, 0] = self.num_relative_distance - 1
831
+
832
+ self.register_buffer("relative_position_index", relative_position_index)
833
+ else:
834
+ self.window_size = None
835
+ self.relative_position_bias_table = None
836
+ self.relative_position_index = None
837
+
838
+ self.attn_drop = nn.Dropout(attn_drop)
839
+ self.proj = nn.Linear(all_head_dim, dim)
840
+ self.proj_drop = nn.Dropout(proj_drop)
841
+
842
+ def forward(
843
+ self,
844
+ x: torch.Tensor,
845
+ return_attention=False,
846
+ return_qkv=False,
847
+ ):
848
+ """
849
+ Apply the attention mechanism to the input tensor.
850
+
851
+ Parameters:
852
+ -----------
853
+ x: torch.Tensor
854
+ Input tensor of shape (Batch, N, C).
855
+ return_attention: bool (default=False)
856
+ If True, return the attention weights.
857
+ return_qkv: bool (default=False)
858
+ If True, return the query, key, and value tensors together with
859
+ the output tensor.
860
+ Returns:
861
+ --------
862
+ x: torch.Tensor
863
+ Output tensor of shape (Batch, N, C).
864
+ qkv: torch.Tensor (optional)
865
+ Query, key, and value tensors of shape
866
+ (Batch, N, 3, num_heads, C // num_heads).
867
+ """
868
+ B, N, _ = x.shape
869
+ qkv_bias = None
870
+ if self.q_bias is not None:
871
+ qkv_bias = torch.cat(
872
+ (
873
+ self.q_bias,
874
+ torch.zeros_like(self.v_bias, requires_grad=False),
875
+ self.v_bias,
876
+ )
877
+ )
878
+ qkv = nn.functional.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
879
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
880
+ q, k, v = (
881
+ qkv[0],
882
+ qkv[1],
883
+ qkv[2],
884
+ ) # make torchscript happy (cannot use tensor as tuple) (B, H, N, C)
885
+ if self.q_norm is not None:
886
+ q = self.q_norm(q).type_as(v)
887
+ if self.k_norm is not None:
888
+ k = self.k_norm(k).type_as(v)
889
+
890
+ q = q * self.scale
891
+ attn = q @ k.transpose(-2, -1)
892
+
893
+ if self.relative_position_bias_table is not None:
894
+ relative_position_bias = self.relative_position_bias_table[
895
+ self.relative_position_index.view(-1)
896
+ ].view(
897
+ self.window_size[0] * self.window_size[1] + 1,
898
+ self.window_size[0] * self.window_size[1] + 1,
899
+ -1,
900
+ ) # Wh*Ww,Wh*Ww,nH
901
+ relative_position_bias = relative_position_bias.permute(
902
+ 2, 0, 1
903
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
904
+ attn = attn + relative_position_bias.unsqueeze(0)
905
+
906
+ attn = attn.softmax(dim=-1)
907
+ attn = self.attn_drop(attn)
908
+
909
+ if return_attention:
910
+ return attn
911
+
912
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
913
+
914
+ x = self.proj(x)
915
+ x = self.proj_drop(x)
916
+
917
+ if return_qkv:
918
+ return x, qkv
919
+
920
+ return x
921
+
922
+
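# Shape sketch for the attention module above (illustrative only). With
# window_size=None it behaves as plain multi-head self-attention over a
# token sequence; the values (dim=200, 10 heads, 89 tokens) are arbitrary.
import torch
from braindecode.models.labram import _Attention

attn = _Attention(dim=200, num_heads=10)
x = torch.randn(4, 89, 200)              # (batch, tokens, dim)
attn(x).shape                            # torch.Size([4, 89, 200])
attn(x, return_attention=True).shape     # torch.Size([4, 10, 89, 89]) attention weights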
923
+ class _WindowsAttentionBlock(nn.Module):
924
+ """Blocks of Windows Attention with Layer norm and MLP.
925
+
926
+ Notes: This code is strongly inspired by:
927
+ BeiTv2 from Microsoft.
928
+
929
+ Parameters:
930
+ -----------
931
+ dim: int
932
+ Number of input features.
933
+ num_heads: int (default=8)
934
+ Number of attention heads.
935
+ mlp_ratio: float (default=4.0)
936
+ Ratio to increase the hidden features from input features in the MLP layer
937
+ qkv_bias: bool (default=False)
938
+ If True, add a learnable bias to the query, key, and value tensors.
939
+ qk_norm: nn.LayerNorm (default=None)
940
+ If not None, apply LayerNorm to the query and key tensors.
941
+ qk_scale: float (default=None)
942
+ If not None, use this value as the scale factor. If None,
943
+ use head_dim**-0.5, where head_dim = dim // num_heads.
944
+ drop: float (default=0.0)
945
+ Dropout rate for the output tensor.
946
+ attn_drop: float (default=0.0)
947
+ Dropout rate for the attention weights.
948
+ drop_path: float (default=0.0)
949
+ Dropout rate for the output tensor.
950
+ init_values: float (default=None)
951
+ If not None, use this value to initialize the gamma_1 and gamma_2
952
+ parameters.
953
+ activation: nn.GELU (default)
954
+ Activation function.
955
+ norm_layer: nn.LayerNorm (default)
956
+ Normalization layer.
957
+ window_size: tuple (default=None)
958
+ If not None, use window-based multi-head self attention based on
959
+ Swin Transformer.
960
+ attn_head_dim: int (default=None)
961
+ If not None, use this value as the head_dim. If None,
962
+ the classes use dim // num_heads
963
+
964
+ Returns:
965
+ --------
966
+ x: torch.Tensor
967
+ Output tensor of shape (Batch, N, C).
968
+
969
+ """
970
+
971
+ def __init__(
972
+ self,
973
+ dim: int,
974
+ num_heads: int,
975
+ mlp_ratio=4.0,
976
+ qkv_bias=False,
977
+ qk_norm=None,
978
+ qk_scale=None,
979
+ drop=0.0,
980
+ attn_drop=0.0,
981
+ drop_path=0.0,
982
+ init_values=None,
983
+ activation: nn.Module = nn.GELU,
984
+ norm_layer=nn.LayerNorm,
985
+ window_size=None,
986
+ attn_head_dim=None,
987
+ ):
988
+ super().__init__()
989
+ self.norm1 = norm_layer(dim)
990
+ self.attn = _Attention(
991
+ dim,
992
+ num_heads=num_heads,
993
+ qkv_bias=qkv_bias,
994
+ qk_norm=qk_norm,
995
+ qk_scale=qk_scale,
996
+ attn_drop=attn_drop,
997
+ proj_drop=drop,
998
+ window_size=window_size,
999
+ attn_head_dim=attn_head_dim,
1000
+ )
1001
+
1002
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
1003
+ self.norm2 = norm_layer(dim)
1004
+ mlp_hidden_dim = int(dim * mlp_ratio)
1005
+ self.mlp = MLP(
1006
+ in_features=dim,
1007
+ hidden_features=[mlp_hidden_dim],
1008
+ activation=activation,
1009
+ drop=drop,
1010
+ )
1011
+
1012
+ if init_values is not None and init_values > 0:
1013
+ self.gamma_1 = nn.Parameter(
1014
+ init_values * torch.ones((dim)), requires_grad=True
1015
+ )
1016
+ self.gamma_2 = nn.Parameter(
1017
+ init_values * torch.ones((dim)), requires_grad=True
1018
+ )
1019
+ else:
1020
+ self.gamma_1, self.gamma_2 = None, None
1021
+
1022
+ def forward(self, x, return_attention=False, return_qkv=False):
1023
+ """
1024
+ Apply the attention mechanism to the input tensor.
1025
+ Parameters
1026
+ ----------
1027
+ x: torch.Tensor
1028
+ Input tensor of shape (Batch, N, C).
1029
+ return_attention: bool (default=False)
1030
+ If True, return the attention weights.
1031
+ return_qkv: bool (default=False)
1032
+ If True, return the query, key, and value tensors together with
1033
+ the output tensor.
1034
+
1035
+ Returns
1036
+ -------
1037
+ torch.Tensor
1038
+ Output tensor of shape (Batch, N, C).
1039
+ """
1040
+
1041
+ if return_attention:
1042
+ return self.attn(self.norm1(x), return_attention=True)
1043
+ if return_qkv:
1044
+ y, qkv = self.attn(self.norm1(x), return_qkv=return_qkv)
1045
+ x = x + self.drop_path(self.gamma_1 * y)
1046
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
1047
+ return x, qkv
1048
+
1049
+ if self.gamma_1 is None:
1050
+ x = x + self.drop_path(self.attn(self.norm1(x)))
1051
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
1052
+ else:
1053
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
1054
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
1055
+ return x
1056
+
1057
+
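# Shape sketch for the pre-norm attention block above (illustrative only):
# residual attention followed by a residual MLP, so the token shape is
# preserved end to end (assuming braindecode's MLP maps back to `dim`,
# as the residual connection in the block requires).
import torch
from braindecode.models.labram import _WindowsAttentionBlock

block = _WindowsAttentionBlock(dim=200, num_heads=10)
x = torch.randn(4, 89, 200)   # (batch, tokens, dim)
block(x).shape                # torch.Size([4, 89, 200])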
1058
+ class _TemporalConv(nn.Module):
1059
+ """
1060
+ Temporal Convolutional Module inspired by Visual Transformer.
1061
+
1062
+ In this module we apply the following steps three times
1063
+ to the input tensor, reducing the temporal dimension only in the first.
1064
+ - Apply a 2D convolution.
1065
+ - Apply a GELU activation function.
1066
+ - Apply a GroupNorm with 4 groups.
1067
+
1068
+ Parameters:
1069
+ -----------
1070
+ in_channels: int (default=1)
1071
+ Number of input channels.
1072
+ out_channels: int (default=8)
1073
+ Number of output channels.
1074
+ num_groups: int (default=4)
1075
+ Number of groups for GroupNorm.
1076
+ kernel_size_1: tuple (default=(1, 15))
1077
+ Kernel size for the first convolution.
1078
+ kernel_size_2: tuple (default=(1, 3))
1079
+ Kernel size for the second and third convolutions.
1080
+ stride_1: tuple (default=(1, 8))
1081
+ Stride for the first convolution.
1082
+ padding_1: tuple (default=(0, 7))
1083
+ Padding for the first convolution.
1084
+ padding_2: tuple (default=(0, 1))
1085
+ Padding for the second and third convolutions.
1086
+ activation: nn.Module, default=nn.GELU
1087
+ Activation function class to apply. Should be a PyTorch activation
1088
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.GELU``.
1089
+
1090
+ Returns:
1091
+ --------
1092
+ x: torch.Tensor
1093
+ Output tensor of shape (Batch, NA, Temporal Channel).
1094
+ """
1095
+
1096
+ def __init__(
1097
+ self,
1098
+ in_channels=1,
1099
+ out_channels=8,
1100
+ num_groups=4,
1101
+ kernel_size_1=(1, 15),
1102
+ stride_1=(1, 8),
1103
+ padding_1=(0, 7),
1104
+ kernel_size_2=(1, 3),
1105
+ padding_2=(0, 1),
1106
+ activation: nn.Module = nn.GELU,
1107
+ ):
1108
+ super().__init__()
1109
+
1110
+ # Here, we use the Rearrange layer from einops to flatten the input
1111
+ # tensor to a 2D tensor, so we can apply 2D convolutions.
1112
+ self.channel_patch_flatten = Rearrange(
1113
+ "Batch chs npat spatch -> Batch () (chs npat) spatch"
1114
+ )
1115
+
1116
+ self.conv1 = nn.Conv2d(
1117
+ in_channels=in_channels,
1118
+ out_channels=out_channels,
1119
+ kernel_size=kernel_size_1,
1120
+ stride=stride_1,
1121
+ padding=padding_1,
1122
+ )
1123
+ self.act_layer_1 = activation()
1124
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
1125
+
1126
+ self.conv2 = nn.Conv2d(
1127
+ in_channels=out_channels,
1128
+ out_channels=out_channels,
1129
+ kernel_size=kernel_size_2,
1130
+ padding=padding_2,
1131
+ )
1132
+ self.act_layer_2 = activation()
1133
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
1134
+
1135
+ self.conv3 = nn.Conv2d(
1136
+ in_channels=out_channels,
1137
+ out_channels=out_channels,
1138
+ kernel_size=kernel_size_2,
1139
+ padding=padding_2,
1140
+ )
1141
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
1142
+ self.act_layer_3 = activation()
1143
+
1144
+ self.transpose_temporal_channel = Rearrange("Batch C NA T -> Batch NA (T C)")
1145
+
1146
+ def forward(self, x):
1147
+ """
1148
+ Apply 3 steps of 2D convolution, GELU activation function,
1149
+ and GroupNorm.
1150
+
1151
+ Parameters:
1152
+ -----------
1153
+ x: torch.Tensor
1154
+ Input tensor of shape (Batch, Channels, n_patchs, size_patch).
1155
+
1156
+ Returns:
1157
+ --------
1158
+ x: torch.Tensor
1159
+ Output tensor of shape (Batch, NA, Temporal Channel).
1160
+ """
1161
+ x = self.channel_patch_flatten(x)
1162
+ x = self.act_layer_1(self.norm1(self.conv1(x)))
1163
+ x = self.act_layer_2(self.norm2(self.conv2(x)))
1164
+ x = self.act_layer_3(self.norm3(self.conv3(x)))
1165
+ x = self.transpose_temporal_channel(x)
1166
+ return x
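
# Shape sketch for the temporal convolution above (illustrative only).
# With the defaults, the flattened output per token is
# out_channels * ceil(patch_size / 8) = 8 * 25 = 200, which matches the
# 200-dimensional embedding used by Labram.
import torch
from braindecode.models.labram import _TemporalConv

tconv = _TemporalConv(in_channels=1, out_channels=8)
x = torch.randn(4, 22, 5, 200)   # (batch, n_chans, n_patches, patch_size)
# flatten -> (4, 1, 110, 200); conv1 with stride (1, 8) -> (4, 8, 110, 25);
# final rearrange "Batch C NA T -> Batch NA (T C)" -> (4, 110, 200)
tconv(x).shape                   # torch.Size([4, 110, 200])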