braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode might be problematic.
Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
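The list above also reflects a reorganisation of the package: the former braindecode/models/functions.py and braindecode/models/modules.py are removed, and new braindecode/functional/ and braindecode/modules/ subpackages are added. The new labram.py shown below already imports from these subpackages; here is a minimal sketch of the updated import paths (only the two imports visible in the diff are shown, the rest of each subpackage's public API is not implied):

    # Imports taken verbatim from braindecode/models/labram.py in 1.1.0;
    # other names exported by these subpackages are not shown here.
    from braindecode.functional import rescale_parameter
    from braindecode.modules import MLP, DropPath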
@@ -0,0 +1,1186 @@
+ """
+ Labram module.
+ Authors: Wei-Bang Jiang
+          Bruno Aristimunha <b.aristimunha@gmail.com>
+ License: BSD 3 clause
+ """
+
+ from collections import OrderedDict
+ from warnings import warn
+
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ from einops.layers.torch import Rearrange
+ from torch.nn.init import trunc_normal_
+
+ from braindecode.functional import rescale_parameter
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import MLP, DropPath
+
+
+ class Labram(EEGModuleMixin, nn.Module):
+     """Labram from Jiang, W B et al (2024) [Jiang2024]_.
+
+     .. figure:: https://arxiv.org/html/2405.18765v1/x1.png
+        :align: center
+        :alt: Labram Architecture.
+
+     Large Brain Model for Learning Generic Representations with Tremendous
+     EEG Data in BCI from [Jiang2024]_.
+
+     This is an **adaptation** of the code [Code2024]_ from the Labram model.
+
+     The model is a transformer architecture with **strong** inspiration from
+     BEiTv2 [BeiTv2]_.
+
+     The model can be used in two modes:
+
+     - Neural Tokenizer: designed to produce an embedding (e.g. for classification).
+     - Neural Decoder: to extract the amplitude and phase outputs with a VQ-NSP.
+
+     Braindecode's modification allows the model to be used with an input
+     shape of (batch, n_chans, n_times) when ``neural_tokenizer`` is True.
+     The original implementation uses (batch, n_chans, n_patches, patch_size)
+     as input, with a static segmentation of the input data.
+
+     The model applies the following sequence of steps, if neural tokenizer:
+
+     - SegmentPatch: segment the input data into patches;
+     - TemporalConv: apply a temporal convolution to the segmented data;
+     - Residual adding of cls, temporal and position embeddings (optional);
+     - WindowsAttentionBlock: apply a windows attention block to the data;
+     - LayerNorm: apply layer normalization to the data;
+     - Linear: a linear head layer that maps the features to classes.
+
+     else:
+
+     - PatchEmbed: apply a patch embedding to the input data;
+     - Residual adding of cls, temporal and position embeddings (optional);
+     - WindowsAttentionBlock: apply a windows attention block to the data;
+     - LayerNorm: apply layer normalization to the data;
+     - Linear: a linear head layer that maps the features to classes.
+
+     .. versionadded:: 0.9
+
+     Parameters
+     ----------
+     patch_size : int
+         The size of the patch to be used in the patch embedding.
+     emb_size : int
+         The dimension of the embedding.
+     in_channels : int
+         The number of convolutional input channels.
+     out_channels : int
+         The number of convolutional output channels.
+     n_layers : int (default=12)
+         The number of attention layers of the model.
+     att_num_heads : int (default=10)
+         The number of attention heads.
+     mlp_ratio : float (default=4.0)
+         The expansion ratio of the MLP layer.
+     qkv_bias : bool (default=False)
+         If True, add a learnable bias to the query, key, and value tensors.
+     qk_norm : PyTorch normalization layer (default=None)
+         If not None, apply LayerNorm to the query and key tensors.
+     qk_scale : float (default=None)
+         If not None, use this value as the scale factor. If None,
+         use head_dim**-0.5, where head_dim = dim // num_heads.
+     drop_prob : float (default=0.0)
+         Dropout rate for the embeddings and the projection/MLP outputs.
+     attn_drop_prob : float (default=0.0)
+         Dropout rate for the attention weights.
+     drop_path_prob : float (default=0.0)
+         Drop path (stochastic depth) rate used by DropPath.
+     norm_layer : PyTorch normalization layer (default=nn.LayerNorm)
+         The normalization layer to be used.
+     init_values : float (default=None)
+         If not None, use this value to initialize the gamma_1 and gamma_2
+         parameters.
+     use_abs_pos_emb : bool (default=True)
+         If True, use absolute position embedding.
+     use_mean_pooling : bool (default=True)
+         If True, use mean pooling.
+     init_scale : float (default=0.001)
+         The initial scale to be used in the parameters of the model.
+     neural_tokenizer : bool (default=True)
+         The model can be used in two modes: Neural Tokenizer or Neural Decoder.
+     attn_head_dim : int (default=None)
+         The head dimension to be used in the attention layer, to be used only
+         during pre-training.
+     activation : nn.Module, default=nn.GELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.GELU``.
+
+     References
+     ----------
+     .. [Jiang2024] Wei-Bang Jiang, Li-Ming Zhao, Bao-Liang Lu. 2024, May.
+        Large Brain Model for Learning Generic Representations with Tremendous
+        EEG Data in BCI. The Twelfth International Conference on Learning
+        Representations, ICLR.
+     .. [Code2024] Wei-Bang Jiang, Li-Ming Zhao, Bao-Liang Lu. 2024. LaBraM:
+        Large Brain Model for Learning Generic Representations with Tremendous
+        EEG Data in BCI. GitHub https://github.com/935963004/LaBraM
+        (accessed 2024-03-02)
+     .. [BeiTv2] Zhiliang Peng, Li Dong, Hangbo Bao, Qixiang Ye, Furu Wei. 2022.
+        BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers.
+        arXiv:2208.06366 [cs.CV]
+     """
+
+     def __init__(
+         self,
+         n_times=None,
+         n_outputs=None,
+         chs_info=None,
+         n_chans=None,
+         sfreq=None,
+         input_window_seconds=None,
+         patch_size=200,
+         emb_size=200,
+         in_channels=1,
+         out_channels=8,
+         n_layers=12,
+         att_num_heads=10,
+         mlp_ratio=4.0,
+         qkv_bias=False,
+         qk_norm=None,
+         qk_scale=None,
+         drop_prob=0.0,
+         attn_drop_prob=0.0,
+         drop_path_prob=0.0,
+         norm_layer=nn.LayerNorm,
+         init_values=None,
+         use_abs_pos_emb=True,
+         use_mean_pooling=True,
+         init_scale=0.001,
+         neural_tokenizer=True,
+         attn_head_dim=None,
+         activation: nn.Module = nn.GELU,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.patch_size = patch_size
+         self.num_features = self.emb_size = emb_size
+         self.neural_tokenizer = neural_tokenizer
+         self.init_scale = init_scale
+
+         if patch_size > self.n_times:
+             warn(
+                 f"patch_size ({patch_size}) > n_times ({self.n_times}); "
+                 f"setting patch_size = {self.n_times}.",
+                 UserWarning,
+             )
+             self.patch_size = self.n_times
+             self.num_features = None
+             self.emb_size = None
+         else:
+             self.patch_size = patch_size
+         self.n_path = self.n_times // self.patch_size
+
+         if neural_tokenizer and in_channels != 1:
+             warn(
+                 "The model is in Neural Tokenizer mode, but `in_channels` "
+                 "differs from its default value. `in_channels` is only used "
+                 "in Neural Decoder mode and is ignored here.",
+                 UserWarning,
+             )
+             in_channels = 1
+         # When the model is used in Neural Tokenizer mode, a temporal conv
+         # layer is applied over the patched data.
+         if neural_tokenizer:
+             self.patch_embed = nn.Sequential(
+                 OrderedDict(
+                     [
+                         (
+                             "segment_patch",
+                             _SegmentPatch(
+                                 n_times=self.n_times,
+                                 patch_size=self.patch_size,
+                                 n_chans=self.n_chans,
+                                 emb_dim=self.patch_size,
+                             ),
+                         ),
+                         (
+                             "temporal_conv",
+                             _TemporalConv(
+                                 out_channels=out_channels, activation=activation
+                             ),
+                         ),
+                     ]
+                 )
+             )
+         else:
+             # Otherwise, the model is used in Neural Decoder mode, so the
+             # input here comes after the VQ-VAE encoder and is used to
+             # extract the amplitude and phase outputs.
+             # Wrapped in a Sequential to keep the same convention as the
+             # Neural Tokenizer mode.
+             self.patch_embed = nn.Sequential()
+             self.patch_embed.add_module(
+                 "segment_patch",
+                 _PatchEmbed(
+                     n_times=self.n_times,
+                     patch_size=patch_size,
+                     in_channels=in_channels,
+                     emb_dim=self.emb_size,
+                 ),
+             )
+
+         with torch.no_grad():
+             dummy = torch.zeros(1, self.n_chans, self.n_times)
+             out = self.patch_embed(dummy)
+             # out.shape for tokenizer: (1, n_chans, emb_dim)
+             # for decoder: (1, n_patch, patch_size, emb_dim); we want the last dim
+             self.emb_size = out.shape[-1]
+             self.num_features = self.emb_size
+
+         # Defining the parameters
+         # Creating a parameter for the cls token
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, self.emb_size))
+         # Positional embedding and temporal embedding are complementary:
+         # one encodes the spatial information and the other the temporal
+         # information.
+         # The temporal embedding encodes the position along the patches,
+         # and the position embedding encodes the channels' information.
+         if use_abs_pos_emb:
+             self.position_embedding = nn.Parameter(
+                 torch.zeros(1, self.n_chans + 1, self.emb_size),
+                 requires_grad=True,
+             )
+         else:
+             self.position_embedding = None
+
+         self.temporal_embedding = nn.Parameter(
+             torch.zeros(1, self.patch_embed[0].n_patchs + 1, self.emb_size),
+             requires_grad=True,
+         )
+         self.pos_drop = nn.Dropout(p=drop_prob)
+
+         dpr = [
+             x.item() for x in torch.linspace(0, drop_path_prob, n_layers)
+         ]  # stochastic depth decay rule
+         self.blocks = nn.ModuleList(
+             [
+                 _WindowsAttentionBlock(
+                     dim=self.emb_size,
+                     num_heads=att_num_heads,
+                     mlp_ratio=mlp_ratio,
+                     qkv_bias=qkv_bias,
+                     qk_norm=qk_norm,
+                     qk_scale=qk_scale,
+                     drop=drop_prob,
+                     attn_drop=attn_drop_prob,
+                     drop_path=dpr[i],
+                     norm_layer=norm_layer,
+                     init_values=init_values,
+                     window_size=(
+                         self.patch_embed[0].patch_shape
+                         if not neural_tokenizer
+                         else None
+                     ),
+                     attn_head_dim=attn_head_dim,
+                     activation=activation,
+                 )
+                 for i in range(n_layers)
+             ]
+         )
+         self.norm = nn.Identity() if use_mean_pooling else norm_layer(self.emb_size)
+         self.fc_norm = norm_layer(self.emb_size) if use_mean_pooling else None
+
+         if self.n_outputs > 0:
+             self.final_layer = nn.Linear(self.emb_size, self.n_outputs)
+         else:
+             self.final_layer = nn.Identity()
+
+         self.apply(self._init_weights)
+         self.fix_init_weight_and_init_embedding()
+
+     def fix_init_weight_and_init_embedding(self):
+         """
+         Fix the initial weights and initialize the embeddings.
+
+         Initialization uses a truncated normal distribution.
+         """
+         trunc_normal_(self.cls_token, std=0.02)
+         trunc_normal_(self.temporal_embedding, std=0.02)
+
+         if self.position_embedding is not None:
+             trunc_normal_(self.position_embedding, std=0.02)
+
+         if isinstance(self.final_layer, nn.Linear):
+             trunc_normal_(self.final_layer.weight, std=0.02)
+
+         for layer_id, layer in enumerate(self.blocks):
+             rescale_parameter(layer.attn.proj.weight.data, layer_id + 1)
+             rescale_parameter(layer.mlp[-2].weight.data, layer_id + 1)
+
+         if isinstance(self.final_layer, nn.Linear):
+             self.final_layer.weight.data.mul_(self.init_scale)
+             self.final_layer.bias.data.mul_(self.init_scale)
+
+     @staticmethod
+     def _init_weights(layer):
+         """
+         Initialize the weights of the model for each layer.
+
+         If the layer is a linear layer, the weight is initialized with a
+         truncated normal distribution with std=0.02. If its bias is not
+         None, the bias is initialized with a constant value of 0.
+
+         If the layer is a layer normalization layer, the bias is
+         initialized with a constant value of 0, and the weight is
+         initialized with a constant value of 1.
+
+         Parameters
+         ----------
+         layer : torch.nn.Module
+             The layer of the pytorch model.
+         """
+         if isinstance(layer, nn.Linear):
+             trunc_normal_(layer.weight, std=0.02)
+             if layer.bias is not None:
+                 nn.init.constant_(layer.bias, 0)
+         elif isinstance(layer, nn.LayerNorm):
+             nn.init.constant_(layer.bias, 0)
+             nn.init.constant_(layer.weight, 1.0)
+
+     def get_num_layers(self):
+         """Convenience method to get the number of layers in the model."""
+         return len(self.blocks)
+
+     def forward_features(
+         self,
+         x,
+         input_chans=None,
+         return_patch_tokens=False,
+         return_all_tokens=False,
+     ):
+         """
+         Forward the features of the model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             The input data with shape (batch, n_chans, n_patches, patch_size)
+             for the neural decoder, or (batch, n_chans, n_times) for the
+             neural tokenizer.
+         input_chans : int
+             The number of input channels.
+         return_patch_tokens : bool
+             Whether to return the patch tokens.
+         return_all_tokens : bool
+             Whether to return all the tokens.
+
+         Returns
+         -------
+         x : torch.Tensor
+             The output of the model.
+         """
+         if self.neural_tokenizer:
+             batch_size, nch, n_patch, temporal = self.patch_embed.segment_patch(x).shape
+         else:
+             batch_size, nch, n_patch = self.patch_embed(x).shape
+         x = self.patch_embed(x)
+         # add the [CLS] token to the embedded patch tokens
+         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+
+         x = torch.cat((cls_tokens, x), dim=1)
+
+         # Positional Embedding
+         if input_chans is not None:
+             pos_embed_used = self.position_embedding[:, input_chans]
+         else:
+             pos_embed_used = self.position_embedding
+
+         if self.position_embedding is not None:
+             pos_embed = self._adj_position_embedding(
+                 pos_embed_used=pos_embed_used, batch_size=batch_size
+             )
+             x += pos_embed
+
+         # The temporal embedding is added across the channels after the [CLS] token
+         if self.neural_tokenizer:
+             num_ch = self.n_chans
+         else:
+             num_ch = n_patch
+         time_embed = self._adj_temporal_embedding(
+             num_ch=num_ch, batch_size=batch_size, dim_embed=temporal
+         )
+         x[:, 1:, :] += time_embed
+
+         x = self.pos_drop(x)
+
+         for blk in self.blocks:
+             x = blk(x)
+
+         x = self.norm(x)
+         if self.fc_norm is not None:
+             if return_all_tokens:
+                 return self.fc_norm(x)
+             temporal = x[:, 1:, :]
+             if return_patch_tokens:
+                 return self.fc_norm(temporal)
+             return self.fc_norm(temporal.mean(1))
+         else:
+             if return_all_tokens:
+                 return x
+             elif return_patch_tokens:
+                 return x[:, 1:]
+             return x[:, 0]
+
+     def forward(
+         self,
+         x,
+         input_chans=None,
+         return_patch_tokens=False,
+         return_all_tokens=False,
+     ):
+         """
+         Forward the input EEG data through the model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             The input data with shape (batch, n_chans, n_times)
+             or (batch, n_chans, n_patches, patch_size).
+         input_chans : int
+             An input channel selection used to index the position embedding.
+         return_patch_tokens : bool
+             Return the patch tokens.
+         return_all_tokens : bool
+             Return all the tokens.
+
+         Returns
+         -------
+         torch.Tensor
+             The output of the model with dimensions (batch, n_outputs).
+         """
+         x = self.forward_features(
+             x,
+             input_chans=input_chans,
+             return_patch_tokens=return_patch_tokens,
+             return_all_tokens=return_all_tokens,
+         )
+         x = self.final_layer(x)
+         return x
+
+     def get_classifier(self):
+         """
+         Get the classifier of the model.
+
+         Returns
+         -------
+         torch.nn.Module
+             The classification head of the model.
+         """
+         return self.final_layer
+
+     def reset_classifier(self, n_outputs):
+         """
+         Reset the classifier with the new number of classes.
+
+         Parameters
+         ----------
+         n_outputs : int
+             The new number of classes.
+         """
+         self.n_outputs = n_outputs
+         self.final_layer = (
+             nn.Linear(self.emb_size, self.n_outputs)
+             if self.n_outputs > 0
+             else nn.Identity()
+         )
+
+     def _adj_temporal_embedding(self, num_ch, batch_size, dim_embed=None):
+         """
+         Adjust the dimensions of the temporal embedding to match the
+         number of channels.
+
+         Parameters
+         ----------
+         num_ch : int
+             The number of channels or the number of codebook vectors.
+         batch_size : int
+             Batch size of the input data.
+         dim_embed : int (default=None)
+             Number of temporal embedding positions to keep; defaults to
+             ``patch_size`` when None.
+
+         Returns
+         -------
+         temporal_embedding : torch.Tensor
+             The adjusted temporal embedding to be added across the channels
+             after the [CLS] token (x[:, 1:, :] += time_embed).
+         """
+         if dim_embed is None:
+             cut_dimension = self.patch_size
+         else:
+             cut_dimension = dim_embed
+         # First step: match the temporal embedding to the number of channels
+         temporal_embedding = self.temporal_embedding[:, 1:cut_dimension, :]
+         # Add a new dimension to the temporal embedding,
+         # e.g. (batch, 62, 200) -> (batch, 1, 62, 200)
+         temporal_embedding = temporal_embedding.unsqueeze(1)
+         # Expand the temporal embedding to match the number of channels
+         # or the number of patches
+         temporal_embedding = temporal_embedding.expand(batch_size, num_ch, -1, -1)
+         # Flatten the intermediate dimensions
+         temporal_embedding = temporal_embedding.flatten(1, 2)
+         return temporal_embedding
+
+     def _adj_position_embedding(self, pos_embed_used, batch_size):
+         """
+         Adjust the dimensions of the position embedding to match the
+         number of patches.
+
+         Parameters
+         ----------
+         pos_embed_used : torch.Tensor
+             The position embedding to be adjusted.
+         batch_size : int
+             The number of batches.
+
+         Returns
+         -------
+         pos_embed : torch.Tensor
+             The adjusted position embedding.
+         """
+         # [CLS] token has no position embedding
+         pos_embed = pos_embed_used[:, 1:, :]
+         # Add a new dimension to the position embedding
+         pos_embed = pos_embed.unsqueeze(2)
+         # Expand the position embedding to match the number of patches
+         pos_embed = pos_embed.expand(batch_size, -1, self.patch_embed[0].n_patchs, -1)
+         # Flatten the intermediate dimensions,
+         # i.e. the number of patches and the "channels" dim
+         pos_embed = pos_embed.flatten(1, 2)
+         # Get the base position embedding;
+         # this is the position embedding for the [CLS] token
+         base_pos = pos_embed[:, 0:1, :].expand(batch_size, -1, -1)
+         # Concatenate the base position embedding with the
+         # position embedding
+         pos_embed = torch.cat((base_pos, pos_embed), dim=1)
+         return pos_embed
+
+
+ class _SegmentPatch(nn.Module):
+     """Segment and patch embedding for EEG data.
+
+     Adapted patch embedding inspired by the Vision Transformer approach.
+     To extract learned segments, we expect the input shape to be
+     (batch, number of channels, number of time points).
+
+     We apply a 1D convolution with a kernel size of patch_size
+     and a stride of patch_size.
+
+     The resulting output shape will be:
+     (batch, number of channels, number of patches, patch size).
+
+     In this way, the segmentation of the input is learned by a convolution.
+
+     The number of patches is calculated as the number of samples divided
+     by the patch size.
+
+     Parameters
+     ----------
+     n_times : int (default=2000)
+         Number of temporal components of the input tensor.
+     patch_size : int (default=200)
+         Size of the patch; the default is 1 second at 200 Hz.
+     n_chans : int (default=1)
+         Number of electrodes of the EEG signal.
+     emb_dim : int (default=200)
+         Number of output channels of the convolution; here,
+         we use the same value as patch_size.
+     learned_patcher : bool (default=True)
+         Whether to use the learned convolutional patcher or a static,
+         view-based segmentation.
+
+     Returns
+     -------
+     x_patched : torch.Tensor
+         Output tensor of shape (batch, n_chans, num_patches, emb_dim).
+     """
+
+     def __init__(
+         self, n_times=2000, patch_size=200, n_chans=1, emb_dim=200, learned_patcher=True
+     ):
+         super().__init__()
+
+         self.n_times = n_times
+         self.patch_size = patch_size
+         self.n_patchs = n_times // patch_size
+         self.emb_dim = emb_dim
+         self.n_chans = n_chans
+         self.learned_patcher = learned_patcher
+
+         self.patcher = nn.Conv1d(
+             in_channels=1,
+             out_channels=self.emb_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+         )
+
+         self.adding_extra_dim = Rearrange(
+             pattern="batch nchans temporal -> (batch nchans) 1 temporal"
+         )
+
+     def forward(self, x):
+         """
+         Use a 1D convolution to generate segments of the EEG signal.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             [batch, n_chans, n_times]
+
+         Returns
+         -------
+         x_patch : torch.Tensor
+             [batch, n_chans, n_times//patch_size, patch_size]
+         """
+         batch_size, _, _ = x.shape
+         # Input shape: [batch, n_chs, n_times]
+
+         # First, rearrange the input so that the channel dimension 'n_chs'
+         # is folded into the batch dimension for Conv1d.
+         # This requires reshaping x to have a height of 1 for each EEG sample.
+         if self.learned_patcher:
+             x = self.adding_extra_dim(x)
+
+             # Apply the convolution along the temporal dimension
+             # Conv1d output shape: [(batch*n_chs), emb_dim, n_patches]
+             x = self.patcher(x)
+
+             # Now, rearrange the output back to a batch-first format,
+             # combining embedded patches with channel information.
+             # We want [batch, n_chs, n_patches, emb_dim] as output, which
+             # keeps the channel information and treats each patch embedding
+             # as a feature alongside the channels.
+             x = rearrange(
+                 x,
+                 pattern="(batch nchans) embed npatchs -> batch nchans npatchs embed",
+                 batch=batch_size,
+                 nchans=self.n_chans,
+             )
+         else:
+             x = x.view(
+                 batch_size,
+                 self.n_chans,
+                 self.n_times // self.patch_size,
+                 self.patch_size,
+             )
+         return x
+
+
+ class _PatchEmbed(nn.Module):
+     """EEG to patch embedding.
+
+     This module is used when we want to apply the patch embedding
+     after the codebook layer.
+
+     Parameters
+     ----------
+     n_times : int (default=2000)
+         Number of temporal components of the input tensor.
+     patch_size : int (default=200)
+         Size of the patch; the default is 1 second at 200 Hz.
+     in_channels : int (default=1)
+         Number of input channels to be used in the convolution.
+     emb_dim : int (default=200)
+         Number of output channels of the convolution; here,
+         we use the same value as patch_size.
+     n_codebooks : int (default=62)
+         Number of codebook vectors, used to compute the number of patches
+         as n_codebooks * (n_times // patch_size).
+     """
+
+     def __init__(
+         self, n_times=2000, patch_size=200, in_channels=1, emb_dim=200, n_codebooks=62
+     ):
+         super().__init__()
+         self.n_times = n_times
+         self.patch_size = patch_size
+         self.patch_shape = (1, self.n_times // self.patch_size)
+         n_patchs = n_codebooks * (self.n_times // self.patch_size)
+
+         self.n_patchs = n_patchs
+
+         self.proj = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=emb_dim,
+             kernel_size=(1, self.patch_size),
+             stride=(1, self.patch_size),
+         )
+
+         self.merge_transpose = Rearrange(
+             "Batch ch patch spatch -> Batch patch spatch ch",
+         )
+
+     def forward(self, x):
+         """
+         Apply the convolution to the input tensor,
+         then rearrange the output to the desired shape.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (Batch, Channels, n_patchs, patch_size).
+
+         Returns
+         -------
+         x : torch.Tensor
+             Output tensor of shape (Batch, n_patchs, patch_size, channels).
+         """
+         x = self.proj(x)
+         x = self.merge_transpose(x)
+         return x
+
+
+ class _Attention(nn.Module):
+     """
+     Attention with the option of window-based multi-head self attention (W-MSA).
+
+     This code is strongly inspired by:
+     https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L77
+
+     Basically, the attention module is a linear layer that takes the input
+     tensor and returns the output tensor. The input tensor is first passed
+     through a linear layer to get the query, key, and value tensors. Then,
+     the query tensor is multiplied by the scale factor and the result is
+     multiplied by the transpose of the key tensor.
+
+     The flag window_size is used to determine whether the attention is
+     window-based or not.
+
+     Parameters
+     ----------
+     dim : int
+         Number of input features.
+     num_heads : int (default=8)
+         Number of attention heads.
+     qkv_bias : bool (default=False)
+         If True, add a learnable bias to the query, key, and value tensors.
+     qk_norm : nn.LayerNorm (default=None)
+         If not None, apply LayerNorm to the query and key tensors.
+     qk_scale : float (default=None)
+         If not None, use this value as the scale factor. If None,
+         use head_dim**-0.5, where head_dim = dim // num_heads.
+     attn_drop : float (default=0.0)
+         Dropout rate for the attention weights.
+     proj_drop : float (default=0.0)
+         Dropout rate for the output tensor.
+     window_size : tuple (default=None)
+         If not None, use window-based multi-head self attention based on
+         the Swin Transformer.
+     attn_head_dim : int (default=None)
+         If not None, use this value as the head_dim. If None, use
+         dim // num_heads.
+     """
+
+     def __init__(
+         self,
+         dim,
+         num_heads=8,
+         qkv_bias=False,
+         qk_norm=None,
+         qk_scale=None,
+         attn_drop=0.0,
+         proj_drop=0.0,
+         window_size=None,
+         attn_head_dim=None,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         if attn_head_dim is not None:
+             head_dim = attn_head_dim
+         all_head_dim = head_dim * self.num_heads
+         self.scale = qk_scale or head_dim**-0.5
+
+         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+         if qkv_bias:
+             self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+             self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+         else:
+             self.q_bias = None
+             self.v_bias = None
+
+         if qk_norm is not None:
+             self.q_norm = qk_norm(head_dim)
+             self.k_norm = qk_norm(head_dim)
+         else:
+             self.q_norm = None
+             self.k_norm = None
+
+         if window_size:
+             self.window_size = window_size
+             self.num_relative_distance = (2 * window_size[0] - 1) * (
+                 2 * window_size[1] - 1
+             ) + 3
+             self.relative_position_bias_table = nn.Parameter(
+                 torch.zeros(self.num_relative_distance, num_heads)
+             )  # 2*Wh-1 * 2*Ww-1, nH
+             # cls to token & token to cls & cls to cls
+
+             # get pair-wise relative position index for each token inside the window
+             coords_h = torch.arange(window_size[0])
+             coords_w = torch.arange(window_size[1])
+             coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+             coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+             relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+             # 2, Wh*Ww, Wh*Ww
+             relative_coords = relative_coords.permute(
+                 1, 2, 0
+             ).contiguous()  # Wh*Ww, Wh*Ww, 2
+             relative_coords[:, :, 0] += window_size[0] - 1
+             # shift to start from 0
+             relative_coords[:, :, 1] += window_size[1] - 1
+             relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+             relative_position_index = torch.zeros(
+                 size=(window_size[0] * window_size[1] + 1,) * 2,
+                 dtype=relative_coords.dtype,
+             )
+             relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+             relative_position_index[0, 0:] = self.num_relative_distance - 3
+             relative_position_index[0:, 0] = self.num_relative_distance - 2
+             relative_position_index[0, 0] = self.num_relative_distance - 1
+
+             self.register_buffer("relative_position_index", relative_position_index)
+         else:
+             self.window_size = None
+             self.relative_position_bias_table = None
+             self.relative_position_index = None
+
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(all_head_dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         return_attention=False,
+         return_qkv=False,
+     ):
+         """
+         Apply the attention mechanism to the input tensor.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (Batch, N, C).
+         return_attention : bool (default=False)
+             If True, return the attention weights.
+         return_qkv : bool (default=False)
+             If True, return the query, key, and value tensors together with
+             the output tensor.
+
+         Returns
+         -------
+         x : torch.Tensor
+             Output tensor of shape (Batch, N, C).
+         qkv : torch.Tensor (optional)
+             Query, key, and value tensors of shape
+             (Batch, N, 3, num_heads, C // num_heads).
+         """
+         B, N, _ = x.shape
+         qkv_bias = None
+         if self.q_bias is not None:
+             qkv_bias = torch.cat(
+                 (
+                     self.q_bias,
+                     torch.zeros_like(self.v_bias, requires_grad=False),
+                     self.v_bias,
+                 )
+             )
+         qkv = nn.functional.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+         q, k, v = (
+             qkv[0],
+             qkv[1],
+             qkv[2],
+         )  # make torchscript happy (cannot use tensor as tuple) (B, H, N, C)
+         if self.q_norm is not None:
+             q = self.q_norm(q).type_as(v)
+         if self.k_norm is not None:
+             k = self.k_norm(k).type_as(v)
+
+         q = q * self.scale
+         attn = q @ k.transpose(-2, -1)
+
+         if self.relative_position_bias_table is not None:
+             relative_position_bias = self.relative_position_bias_table[
+                 self.relative_position_index.view(-1)
+             ].view(
+                 self.window_size[0] * self.window_size[1] + 1,
+                 self.window_size[0] * self.window_size[1] + 1,
+                 -1,
+             )  # Wh*Ww,Wh*Ww,nH
+             relative_position_bias = relative_position_bias.permute(
+                 2, 0, 1
+             ).contiguous()  # nH, Wh*Ww, Wh*Ww
+             attn = attn + relative_position_bias.unsqueeze(0)
+
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         if return_attention:
+             return attn
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+
+         x = self.proj(x)
+         x = self.proj_drop(x)
+
+         if return_qkv:
+             return x, qkv
+
+         return x
+
+
+ class _WindowsAttentionBlock(nn.Module):
+     """Block of windows attention with layer norm and MLP.
+
+     Notes: this code is strongly inspired by BEiTv2 from Microsoft.
+
+     Parameters
+     ----------
+     dim : int
+         Number of input features.
+     num_heads : int (default=8)
+         Number of attention heads.
+     mlp_ratio : float (default=4.0)
+         Ratio used to increase the hidden features from the input features
+         in the MLP layer.
+     qkv_bias : bool (default=False)
+         If True, add a learnable bias to the query, key, and value tensors.
+     qk_norm : nn.LayerNorm (default=None)
+         If not None, apply LayerNorm to the query and key tensors.
+     qk_scale : float (default=None)
+         If not None, use this value as the scale factor. If None,
+         use head_dim**-0.5, where head_dim = dim // num_heads.
+     drop : float (default=0.0)
+         Dropout rate for the output tensor.
+     attn_drop : float (default=0.0)
+         Dropout rate for the attention weights.
+     drop_path : float (default=0.0)
+         Drop path (stochastic depth) rate.
+     init_values : float (default=None)
+         If not None, use this value to initialize the gamma_1 and gamma_2
+         parameters.
+     activation : nn.Module (default=nn.GELU)
+         Activation function.
+     norm_layer : nn.LayerNorm (default)
+         Normalization layer.
+     window_size : tuple (default=None)
+         If not None, use window-based multi-head self attention based on
+         the Swin Transformer.
+     attn_head_dim : int (default=None)
+         If not None, use this value as the head_dim. If None,
+         the class uses dim // num_heads.
+
+     Returns
+     -------
+     x : torch.Tensor
+         Output tensor of shape (Batch, N, C).
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_ratio=4.0,
+         qkv_bias=False,
+         qk_norm=None,
+         qk_scale=None,
+         drop=0.0,
+         attn_drop=0.0,
+         drop_path=0.0,
+         init_values=None,
+         activation: nn.Module = nn.GELU,
+         norm_layer=nn.LayerNorm,
+         window_size=None,
+         attn_head_dim=None,
+     ):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = _Attention(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             qk_norm=qk_norm,
+             qk_scale=qk_scale,
+             attn_drop=attn_drop,
+             proj_drop=drop,
+             window_size=window_size,
+             attn_head_dim=attn_head_dim,
+         )
+
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = MLP(
+             in_features=dim,
+             hidden_features=[mlp_hidden_dim],
+             activation=activation,
+             drop=drop,
+         )
+
+         if init_values is not None and init_values > 0:
+             self.gamma_1 = nn.Parameter(
+                 init_values * torch.ones((dim)), requires_grad=True
+             )
+             self.gamma_2 = nn.Parameter(
+                 init_values * torch.ones((dim)), requires_grad=True
+             )
+         else:
+             self.gamma_1, self.gamma_2 = None, None
+
+     def forward(self, x, return_attention=False, return_qkv=False):
+         """
+         Apply the attention mechanism to the input tensor.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (Batch, chs, npatchs, patch).
+         return_attention : bool (default=False)
+             If True, return the attention weights.
+         return_qkv : bool (default=False)
+             If True, return the query, key, and value tensors together with
+             the output tensor.
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor of shape (Batch, chs, npatchs, patch).
+         """
+         if return_attention:
+             return self.attn(self.norm1(x), return_attention=True)
+         if return_qkv:
+             y, qkv = self.attn(self.norm1(x), return_qkv=return_qkv)
+             x = x + self.drop_path(self.gamma_1 * y)
+             x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+             return x, qkv
+
+         if self.gamma_1 is None:
+             x = x + self.drop_path(self.attn(self.norm1(x)))
+             x = x + self.drop_path(self.mlp(self.norm2(x)))
+         else:
+             x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
+             x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+         return x
+
+
+ class _TemporalConv(nn.Module):
+     """
+     Temporal convolutional module inspired by the Vision Transformer.
+
+     In this module we apply the following steps three times to the input
+     tensor, reducing the temporal dimension only in the first:
+
+     - Apply a 2D convolution.
+     - Apply a GELU activation function.
+     - Apply a GroupNorm with 4 groups.
+
+     Parameters
+     ----------
+     in_channels : int (default=1)
+         Number of input channels.
+     out_channels : int (default=8)
+         Number of output channels.
+     num_groups : int (default=4)
+         Number of groups for GroupNorm.
+     kernel_size_1 : tuple (default=(1, 15))
+         Kernel size for the first convolution.
+     kernel_size_2 : tuple (default=(1, 3))
+         Kernel size for the second and third convolutions.
+     stride_1 : tuple (default=(1, 8))
+         Stride for the first convolution.
+     padding_1 : tuple (default=(0, 7))
+         Padding for the first convolution.
+     padding_2 : tuple (default=(0, 1))
+         Padding for the second and third convolutions.
+     activation : nn.Module, default=nn.GELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.GELU``.
+
+     Returns
+     -------
+     x : torch.Tensor
+         Output tensor of shape (Batch, NA, Temporal Channel).
+     """
+
+     def __init__(
+         self,
+         in_channels=1,
+         out_channels=8,
+         num_groups=4,
+         kernel_size_1=(1, 15),
+         stride_1=(1, 8),
+         padding_1=(0, 7),
+         kernel_size_2=(1, 3),
+         padding_2=(0, 1),
+         activation: nn.Module = nn.GELU,
+     ):
+         super().__init__()
+
+         # Here, we use the Rearrange layer from einops to flatten the input
+         # tensor into a 2D map, so we can apply 2D convolutions.
+         self.channel_patch_flatten = Rearrange(
+             "Batch chs npat spatch -> Batch () (chs npat) spatch"
+         )
+
+         self.conv1 = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size_1,
+             stride=stride_1,
+             padding=padding_1,
+         )
+         self.act_layer_1 = activation()
+         self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
+
+         self.conv2 = nn.Conv2d(
+             in_channels=out_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size_2,
+             padding=padding_2,
+         )
+         self.act_layer_2 = activation()
+         self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
+
+         self.conv3 = nn.Conv2d(
+             in_channels=out_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size_2,
+             padding=padding_2,
+         )
+         self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
+         self.act_layer_3 = activation()
+
+         self.transpose_temporal_channel = Rearrange("Batch C NA T -> Batch NA (T C)")
+
+     def forward(self, x):
+         """
+         Apply three blocks of 2D convolution, GELU activation,
+         and GroupNorm.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (Batch, Channels, n_patchs, patch_size).
+
+         Returns
+         -------
+         x : torch.Tensor
+             Output tensor of shape (Batch, NA, Temporal Channel).
+         """
+         x = self.channel_patch_flatten(x)
+         x = self.act_layer_1(self.norm1(self.conv1(x)))
+         x = self.act_layer_2(self.norm2(self.conv2(x)))
+         x = self.act_layer_3(self.norm3(self.conv3(x)))
+         x = self.transpose_temporal_channel(x)
+         return x
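For orientation, here is a minimal usage sketch of the new Labram model in its default neural-tokenizer mode. It assumes the class is exported as braindecode.models.Labram (models/__init__.py is updated in this release) and uses hypothetical shapes; per the class docstring, the forward pass accepts (batch, n_chans, n_times) and returns (batch, n_outputs).

    import torch
    from braindecode.models import Labram  # export path assumed from models/__init__.py

    # Hypothetical dimensions: 64 channels, 1000 samples, 4 classes.
    # With the default patch_size=200, each channel is split into
    # n_times // patch_size = 1000 // 200 = 5 patches.
    model = Labram(n_chans=64, n_times=1000, n_outputs=4)

    x = torch.randn(8, 64, 1000)  # (batch, n_chans, n_times)
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # expected: torch.Size([8, 4])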