braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/models/biot.py (new file)
@@ -0,0 +1,504 @@
+ import math
+ from warnings import warn
+
+ import torch
+ import torch.nn as nn
+ from linear_attention_transformer import LinearAttentionTransformer
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class BIOT(EEGModuleMixin, nn.Module):
+     """BIOT from Yang et al. (2023) [Yang2023]_
+
+     .. figure:: https://braindecode.org/dev/_static/model/biot.jpg
+        :align: center
+        :alt: BioT
+
+     BIOT: Cross-data Biosignal Learning in the Wild.
+
+     BIOT is a large transformer model for biosignal classification. It is
+     a wrapper around the `BIOTEncoder` and `ClassificationHead` modules.
+
+     It is designed for multivariate biosignal data such as EEG, ECG, etc.
+     The method was proposed by Yang et al. [Yang2023]_ and the code is
+     available at [Code2023]_.
+
+     The model was pre-trained with a contrastive loss on large EEG datasets:
+     the TUH Abnormal EEG Corpus (about 400K samples) and the Sleep Heart
+     Health Study (about 5M samples). Here, we only provide the model
+     architecture, not the pre-trained weights or the contrastive-loss
+     training.
+
+     The architecture is based on the `LinearAttentionTransformer` and
+     `PatchFrequencyEmbedding` modules.
+     The `BIOTEncoder` is a transformer that takes the input data and outputs
+     a fixed-size representation of it. More details are given in the
+     `BIOTEncoder` class.
+
+     The `ClassificationHead` is an ELU activation layer followed by a simple
+     linear layer that takes the output of the `BIOTEncoder` and outputs
+     the classification logits.
+
+     .. versionadded:: 0.9
+
+     Parameters
+     ----------
+     emb_size : int, optional
+         The size of the embedding layer, by default 256.
+     att_num_heads : int, optional
+         The number of attention heads, by default 8.
+     n_layers : int, optional
+         The number of transformer layers, by default 4.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+     return_feature : bool, optional
+         Changes the output of the network. If True, the forward pass returns
+         the embedding in addition to the output tensor; if False (default),
+         only the output tensor is returned.
+     hop_length : int, optional
+         The hop length for the torch.stft transformation in the
+         encoder. The default is 100.
+     sfreq : int, optional
+         The sampling frequency expected by the encoder. The default is 200.
+
+     References
+     ----------
+     .. [Yang2023] Yang, C., Westover, M.B. and Sun, J., 2023, November. BIOT:
+        Biosignal Transformer for Cross-data Learning in the Wild. In Thirty-seventh
+        Conference on Neural Information Processing Systems, NeurIPS.
+     .. [Code2023] Yang, C., Westover, M.B. and Sun, J., 2023. BIOT:
+        Biosignal Transformer for Cross-data Learning in the Wild.
+        GitHub https://github.com/ycq091044/BIOT (accessed 2024-02-13).
+     """
+
+     def __init__(
+         self,
+         emb_size=256,
+         att_num_heads=8,
+         n_layers=4,
+         sfreq=200,
+         hop_length=100,
+         return_feature=False,
+         n_outputs=None,
+         n_chans=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         activation: nn.Module = nn.ELU,
+         drop_prob: float = 0.5,
+         # Parameters for the encoder
+         max_seq_len: int = 1024,
+         attn_dropout=0.2,
+         attn_layer_dropout=0.2,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, sfreq
+         self.emb_size = emb_size
+         self.hop_length = hop_length
+         self.att_num_heads = att_num_heads
+         self.n_layers = n_layers
+         self.return_feature = return_feature
+         if self.sfreq is not None and self.sfreq != 200:
+             warn(
+                 "This model has only been trained on data sampled at 200 Hz; "
+                 + "there is no guarantee that it generalizes well with the "
+                 + "default parameters.",
+                 UserWarning,
+             )
+         if self.n_chans > emb_size:
+             warn(
+                 "The number of channels is larger than the embedding size. "
+                 + "This may cause overfitting. Consider using a larger "
+                 + "embedding size or a smaller number of channels.",
+                 UserWarning,
+             )
+         if self.hop_length > self.sfreq:
+             warn(
+                 "The hop length is larger than the sampling frequency. "
+                 + "This may cause aliasing. Consider using a smaller "
+                 + "hop length.",
+                 UserWarning,
+             )
+             hop_length = self.hop_length = self.sfreq // 2
+
+         if self.input_window_seconds < 1.0:
+             warning_msg = (
+                 "The input window is less than 1 second, which may not be "
+                 "sufficient for the model to learn meaningful representations. "
+                 "Setting `n_fft` to `n_times`."
+             )
+             warn(warning_msg, UserWarning)
+             self.n_fft = self.n_times
+         else:
+             self.n_fft = int(self.sfreq)
+
+         self.encoder = _BIOTEncoder(
+             emb_size=emb_size,
+             att_num_heads=att_num_heads,
+             n_layers=n_layers,
+             n_chans=self.n_chans,
+             n_fft=self.n_fft,
+             hop_length=hop_length,
+             drop_prob=drop_prob,
+             max_seq_len=max_seq_len,
+             attn_dropout=attn_dropout,
+             attn_layer_dropout=attn_layer_dropout,
+         )
+
+         self.final_layer = _ClassificationHead(
+             emb_size=emb_size,
+             n_outputs=self.n_outputs,
+             activation=activation,
+         )
+
+     def forward(self, x):
+         """
+         Pass the input through the BIOT encoder, and then through the
+         classification head.
+
+         Parameters
+         ----------
+         x: Tensor
+             (batch_size, n_channels, n_times)
+
+         Returns
+         -------
+         out: Tensor
+             (batch_size, n_outputs)
+         (out, emb): tuple of Tensors
+             (batch_size, n_outputs), (batch_size, emb_size);
+             returned instead when ``return_feature`` is True.
+         """
+         emb = self.encoder(x)
+         x = self.final_layer(emb)
+
+         if self.return_feature:
+             return x, emb
+         else:
+             return x
+
+
+ class _PatchFrequencyEmbedding(nn.Module):
+     """
+     Patch Frequency Embedding.
+
+     A simple linear layer learns a representation over the frequency
+     domain; the input is permuted so that the frequency axis is projected
+     to the embedding dimension.
+
+     Parameters
+     ----------
+     emb_size: int
+         The size of the embedding layer
+     n_freq: int
+         The number of frequency points after the Fourier transform
+
+     Returns
+     -------
+     out: Tensor
+         (batch, time, emb_size)
+     """
+
+     def __init__(self, emb_size: int = 256, n_freq: int = 101):
+         super().__init__()
+         self.projection = nn.Linear(n_freq, emb_size)
+
+     def forward(self, x):
+         """
+         Parameters
+         ----------
+         x: Tensor
+             (batch, n_freq, time)
+
+         Returns
+         -------
+         out: Tensor
+             (batch, time, emb_size)
+         """
+         x = x.permute(0, 2, 1)
+         x = self.projection(x)
+         return x
+
+
+ class _ClassificationHead(nn.Sequential):
+     """
+     Classification head for the BIOT model.
+
+     Simple linear layer with ELU activation function.
+
+     Parameters
+     ----------
+     emb_size: int
+         The size of the embedding layer
+     n_outputs: int
+         The number of classes
+     activation: nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+     Returns
+     -------
+     out: Tensor
+         (batch, n_outputs)
+     """
+
+     def __init__(self, emb_size: int, n_outputs: int, activation: nn.Module = nn.ELU):
+         super().__init__()
+         self.activation_layer = activation()
+         self.classification_head = nn.Linear(emb_size, n_outputs)
+
+     def forward(self, x):
+         x = self.activation_layer(x)
+         out = self.classification_head(x)
+         return out
+
+
+ class _PositionalEncoding(nn.Module):
+     """
+     Positional Encoding.
+
+     We first create a `pe` zero matrix of shape (max_len, emb_size), where
+     max_len is the maximum length of the sequence and emb_size is the size of
+     the embedding.
+
+     Then we create a `position` tensor of shape (max_len, 1) with indices from
+     0 to max_len, and a `div_term` tensor of shape (emb_size // 2,) containing
+     the exponential of the even indices from 0 to emb_size multiplied by
+     -log(10000.0) / emb_size.
+
+     For more details about the positional encoding, see the
+     `Attention Is All You Need` paper.
+
+     Parameters
+     ----------
+     emb_size: int
+         The size of the embedding layer
+     drop_prob: float, default=0.1
+         The dropout rate for regularization. Values should be between 0 and 1.
+     max_len: int, default=1000
+         The maximum length of the sequence
+
+     Returns
+     -------
+     out: Tensor
+         (batch, max_len, emb_size)
+     """
+
+     def __init__(self, emb_size: int, drop_prob: float = 0.1, max_len: int = 1000):
+         super(_PositionalEncoding, self).__init__()
+
+         # Compute the positional encodings once in log space.
+         pe = torch.zeros(max_len, emb_size)
+         position = torch.arange(0, max_len).unsqueeze(1).float()
+         div_term = torch.exp(
+             torch.arange(0, emb_size, 2).float() * -(math.log(10000.0) / emb_size)
+         )
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0)
+         self.register_buffer("pe", pe)
+         self.dropout = nn.Dropout(p=drop_prob)
+
+     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+         """
+         Parameters
+         ----------
+         x: FloatTensor
+             `embeddings`, shape (batch, max_len, emb_size)
+
+         Returns
+         -------
+         Tensor
+             `encoder input`, shape (batch, max_len, emb_size)
+         """
+         x = x + self.pe[:, : x.size(1)]
+         return self.dropout(x)
+
+
+ class _BIOTEncoder(nn.Module):
+     """
+     BIOT Encoder.
+
+     The BIOT encoder is a transformer that takes the time-series input data
+     and returns a fixed-size embedding representation of it.
+     The architecture is based on the `LinearAttentionTransformer` and
+     `PatchFrequencyEmbedding` modules.
+
+     The input data is transformed into a spectrogram and then embedded using a
+     "patch" embedding.
+     The channel token is added to the patch embedding, and then
+     positional encoding is applied (simple index positional encoding).
+     The resulting embeddings of all channels are concatenated
+     and passed through a transformer layer. The mean over the resulting
+     token embeddings is returned.
+
+     Parameters
+     ----------
+     n_chans: int
+         The number of channels
+     emb_size: int
+         The size of the embedding layer
+     att_num_heads: int
+         The number of attention heads
+     n_layers: int
+         The number of transformer layers
+     n_fft: int
+         The number of Fourier transform points
+     hop_length: int (default 100)
+         The distance between neighboring sliding window frames
+     """
+
+     def __init__(
+         self,
+         emb_size=256,  # The size of the embedding layer
+         att_num_heads=8,  # The number of attention heads
+         n_chans=16,  # The number of channels
+         n_layers=4,  # The number of transformer layers
+         n_fft=200,  # Related to the frequency resolution
+         hop_length=100,
+         drop_prob: float = 0.1,
+         max_seq_len: int = 1024,  # The maximum sequence length
+         attn_dropout=0.2,  # dropout post-attention
+         attn_layer_dropout=0.2,  # dropout right after the self-attention layer
+     ):
+         super().__init__()
+
+         self.n_fft = int(n_fft)
+         self.hop_length = hop_length
+
+         self.patch_embedding = _PatchFrequencyEmbedding(
+             emb_size=emb_size, n_freq=int(self.n_fft // 2 + 1)
+         )
+         self.transformer = LinearAttentionTransformer(
+             dim=emb_size,
+             heads=att_num_heads,
+             depth=n_layers,
+             max_seq_len=max_seq_len,
+             attn_layer_dropout=attn_layer_dropout,
+             attn_dropout=attn_dropout,
+         )
+         self.positional_encoding = _PositionalEncoding(emb_size, drop_prob=drop_prob)
+
+         # One learnable token embedding per channel
+         # (num_embeddings must be >= the actual number of channels).
+         self.channel_tokens = nn.Embedding(
+             num_embeddings=n_chans, embedding_dim=emb_size
+         )
+         self.index = nn.Parameter(torch.LongTensor(range(n_chans)), requires_grad=False)
+
+     def stft(self, sample):
+         """
+         Short-time Fourier transform.
+         For more details, see `torch.stft`.
+
+         The size of the Fourier transform is given by `n_fft`, and the distance
+         between neighboring sliding window frames by `hop_length`, both defined
+         in the __init__ function.
+
+         Parameters
+         ----------
+         sample: Tensor
+             channel representation with size (batch_size, n_times)
+
+         Returns
+         -------
+         spectral: Tensor
+             Absolute value of the Fourier transform with size
+             (batch_size, n_fft // 2 + 1, (n_times - n_fft) // hop_length + 1)
+         """
+         spectral = torch.stft(
+             input=sample.squeeze(1),
+             n_fft=int(self.n_fft),
+             hop_length=self.hop_length,
+             center=False,
+             onesided=True,
+             return_complex=True,
+         )
+         return torch.abs(spectral)
+
+     def forward(self, x, n_channel_offset=0, perturb=False):
+         """
+         Forward pass of the BIOT encoder.
+
+         For each channel, the input is transformed into a spectrogram
+         and then embedded using a patch embedding. The channel token
+         is added to the patch embedding, and then positional encoding
+         is applied. The resulting embeddings are concatenated and
+         passed through a transformer layer. The mean of the resulting
+         embeddings is returned.
+
+         In more detail: each channel is transformed into a
+         spectrogram with the STFT; the spectrogram representation is permuted
+         and passed through a linear layer to learn a representation over
+         the frequency domain.
+
+         For each embedding in the sequence, the channel token is added to
+         the patch embedding, and then positional encoding is applied.
+
+         The resulting embeddings are concatenated and passed through a
+         transformer layer. The mean of the resulting embeddings is returned.
+
+         Parameters
+         ----------
+         x: Tensor
+             (batch_size, n_channels, n_times)
+         n_channel_offset: int (default 0)
+             The offset added to the channel token indices
+         perturb: bool (default False)
+             If True, randomly select a subset of time steps and reduce the
+             channel embedding to those time steps.
+
+         Returns
+         -------
+         emb: Tensor
+             (batch_size, emb_size)
+         """
+         emb_seq = []
+         for i in range(x.shape[1]):
+             # Getting the spectrogram
+             channel_spec_emb = self.stft(x[:, i : i + 1, :])
+             # Linear layer to learn some representation over the frequency domain
+             # with permutation
+             channel_spec_emb = self.patch_embedding(channel_spec_emb)
+             batch_size, ts, _ = channel_spec_emb.shape
+             # (batch_size, ts, emb)
+             # Step by step, the following lines do the following operations:
+             # - self.channel_tokens(self.index[i + n_channel_offset]):
+             #   Fetches the embedding for the channel specified by i + n_channel_offset,
+             #   where i is the current index and n_channel_offset adjusts
+             #   which channel's embedding is retrieved.
+             # - .unsqueeze(0).unsqueeze(0):
+             #   Adds two singleton dimensions to the embedding tensor,
+             #   from [emb_size] to [1, 1, emb_size].
+             # - .repeat(batch_size, ts, 1):
+             #   Replicates the embedding tensor to match the batch size and
+             #   the number of time steps (ts).
+             channel_token_emb = (
+                 self.channel_tokens(self.index[i + n_channel_offset])
+                 .unsqueeze(0)
+                 .unsqueeze(0)
+                 .repeat(batch_size, ts, 1)
+             )
+             # (batch_size, ts, emb)
+             # The positional encoding is explained in more
+             # detail in the _PositionalEncoding class.
+             channel_emb = self.positional_encoding(channel_spec_emb + channel_token_emb)
+             # In case of perturb, a random number of time steps is selected
+             # and the channel embedding is reduced to those time steps.
+             if perturb:
+                 ts = channel_emb.shape[1]
+                 ts_new = torch.randint(low=ts // 2, high=ts, size=(1,)).item()
+                 selected_ts = torch.randperm(ts)[:ts_new]
+                 channel_emb = channel_emb[:, selected_ts]
+             emb_seq.append(channel_emb)
+
+         # Concatenate and pass through the transformer
+         # (batch_size, n_chans * ts, emb)
+         emb = torch.cat(emb_seq, dim=1)
+         # (batch_size, emb)
+         emb = self.transformer(emb).mean(dim=1)
+         return emb
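
For orientation, a minimal usage sketch of the new BIOT model against the 1.1.0 API shown above. The import path, the example shapes, and the parameter values are illustrative assumptions (they mirror the 16-channel, 200 Hz setting mentioned in the docstring), and the optional `linear_attention_transformer` dependency must be installed:

    import torch

    from braindecode.models import BIOT  # assumed export location in 1.1.0

    # Illustrative values: 16-channel EEG, 10-second windows at 200 Hz, 2 classes.
    model = BIOT(n_outputs=2, n_chans=16, n_times=2000, sfreq=200)

    x = torch.randn(4, 16, 2000)  # (batch_size, n_chans, n_times)
    logits = model(x)             # (batch_size, n_outputs) -> torch.Size([4, 2])

    # With return_feature=True, the forward pass also returns the encoder embedding.
    model_feat = BIOT(
        n_outputs=2, n_chans=16, n_times=2000, sfreq=200, return_feature=True
    )
    logits, emb = model_feat(x)   # emb: (batch_size, emb_size) -> torch.Size([4, 256])

Note that the model returns the raw linear-layer outputs; no softmax is applied inside `_ClassificationHead`.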