braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/biot.py
@@ -0,0 +1,537 @@
import math
from warnings import warn

import torch
import torch.nn as nn
from linear_attention_transformer import LinearAttentionTransformer

from braindecode.models.base import EEGModuleMixin


class BIOT(EEGModuleMixin, nn.Module):
+ r"""BIOT from Yang et al (2023) [Yang2023]_
13
+
14
+ :bdg-danger:`Foundation Model`
15
+
16
+ .. figure:: https://braindecode.org/dev/_static/model/biot.jpg
17
+ :align: center
18
+ :alt: BioT
19
+
20
+ BIOT: Cross-data Biosignal Learning in the Wild.
21
+
22
+ BIOT is a foundation model for biosignal classification. It is
23
+ a wrapper around the `BIOTEncoder` and `ClassificationHead` modules.
24
+
25
+ It is designed for N-dimensional biosignal data such as EEG, ECG, etc.
26
+ The method was proposed by Yang et al. [Yang2023]_ and the code is
27
+ available at [Code2023]_
28
+
29
+ The model is trained with a contrastive loss on large EEG datasets
30
+ TUH Abnormal EEG Corpus with 400K samples and Sleep Heart Health Study
31
+ 5M. Here, we only provide the model architecture, not the pre-trained
32
+ weights or contrastive loss training.
33
+
34
+ The architecture is based on the `LinearAttentionTransformer` and
35
+ `PatchFrequencyEmbedding` modules.
36
+ The `BIOTEncoder` is a transformer that takes the input data and outputs
37
+ a fixed-size representation of the input data. More details are
38
+ present in the `BIOTEncoder` class.
39
+
40
+ The `ClassificationHead` is an ELU activation layer, followed by a simple
41
+ linear layer that takes the output of the `BIOTEncoder` and outputs
42
+ the classification probabilities.
43
+
44

    .. important::
        **Pre-trained Weights Available**

        This model has pre-trained weights available on the Hugging Face Hub.
        You can load them using:

        .. code-block:: python

            from braindecode.models import BIOT

            # Load the original pre-trained models from the Hugging Face Hub
            # For 16-channel models:
            model = BIOT.from_pretrained("braindecode/biot-pretrained-prest-16chs")

            # For 18-channel models:
            model = BIOT.from_pretrained("braindecode/biot-pretrained-shhs-prest-18chs")
            model = BIOT.from_pretrained("braindecode/biot-pretrained-six-datasets-18chs")

        To push your own trained model to the Hub:

        .. code-block:: python

            # After training your model
            model.push_to_hub(
                repo_id="username/my-biot-model",
                commit_message="Upload trained BIOT model",
            )

        Requires installing ``braindecode[hug]`` for Hub integration.

    .. versionadded:: 0.9

    Parameters
    ----------
    embed_dim : int, optional
        The size of the embedding layer, by default 256.
    num_heads : int, optional
        The number of attention heads, by default 8.
    num_layers : int, optional
        The number of transformer layers, by default 4.
    activation : nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
    return_feature : bool, optional
        Controls the output of the network. If False (the default), the
        forward pass returns a single output tensor; if True, it also
        returns the embedding.
    hop_length : int, optional
        The hop length for the torch.stft transformation in the
        encoder. The default is 100.
    sfreq : int, optional
        The sampling frequency expected by the encoder. The default is 200.

    References
    ----------
    .. [Yang2023] Yang, C., Westover, M.B. and Sun, J., 2023, November. BIOT:
       Biosignal Transformer for Cross-data Learning in the Wild. In
       Thirty-seventh Conference on Neural Information Processing Systems
       (NeurIPS).
    .. [Code2023] Yang, C., Westover, M.B. and Sun, J., 2023. BIOT:
       Biosignal Transformer for Cross-data Learning in the Wild.
       GitHub: https://github.com/ycq091044/BIOT (accessed 2024-02-13).
    """

    def __init__(
        self,
        embed_dim=256,
        num_heads=8,
        num_layers=4,
        sfreq=200,
        hop_length=100,
        return_feature=False,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        activation: type[nn.Module] = nn.ELU,
        drop_prob: float = 0.5,
        # Parameters for the encoder
        max_seq_len: int = 1024,
        att_drop_prob=0.2,
        att_layer_drop_prob=0.2,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, sfreq
        self.embed_dim = embed_dim
        self.hop_length = hop_length
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.return_feature = return_feature
        if self.sfreq is not None and self.sfreq != 200:
            warn(
                "This model has only been trained on a dataset sampled at "
                "200 Hz. There is no guarantee that it will generalize well "
                "with the default parameters.",
                UserWarning,
            )
        if self.n_chans > embed_dim:
            warn(
                "The number of channels is larger than the embedding size. "
                "This may cause overfitting. Consider using a larger "
                "embedding size or a smaller number of channels.",
                UserWarning,
            )
        if self.hop_length > self.sfreq:
            warn(
                "The hop length is larger than the sampling frequency. "
                "This may cause aliasing. Consider using a smaller "
                "hop length. Setting hop_length to sfreq // 2.",
                UserWarning,
            )
            hop_length = self.sfreq // 2

        if self.input_window_seconds < 1.0:
            warning_msg = (
                "The input window is less than 1 second, which may not be "
                "sufficient for the model to learn meaningful "
                "representations. Changing `n_fft` to `n_times`."
            )
            warn(warning_msg, UserWarning)
            self.n_fft = self.n_times
        else:
            self.n_fft = int(self.sfreq)

        self.encoder = _BIOTEncoder(
            emb_size=self.embed_dim,
            num_heads=self.num_heads,
            n_layers=self.num_layers,
            n_chans=self.n_chans,
            n_fft=self.n_fft,
            hop_length=hop_length,
            drop_prob=drop_prob,
            max_seq_len=max_seq_len,
            attn_dropout=att_drop_prob,
            attn_layer_dropout=att_layer_drop_prob,
        )

        self.final_layer = _ClassificationHead(
            emb_size=self.embed_dim,
            n_outputs=self.n_outputs,
            activation=activation,
        )

    def forward(self, x):
        """
        Pass the input through the BIOT encoder, and then through the
        classification head.

        Parameters
        ----------
        x : Tensor
            (batch_size, n_channels, n_times)

        Returns
        -------
        out : Tensor
            (batch_size, n_outputs)
        (out, emb) : tuple of Tensor
            Returned instead when ``return_feature`` is True, with shapes
            (batch_size, n_outputs) and (batch_size, emb_size).
        """
        emb = self.encoder(x)
        x = self.final_layer(emb)

        if self.return_feature:
            return x, emb
        else:
            return x
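
# Illustrative sketch: a minimal end-to-end forward pass. The shapes and
# the 4-class task below are hypothetical (18 channels, 200 Hz, 10 s windows).
#
# >>> model = BIOT(n_outputs=4, n_chans=18, n_times=2000, sfreq=200)
# >>> x = torch.randn(2, 18, 2000)  # (batch, n_chans, n_times)
# >>> model(x).shape
# torch.Size([2, 4])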


class _PatchFrequencyEmbedding(nn.Module):
    r"""
    Patch Frequency Embedding.

    A simple linear layer learns a representation over the frequency
    axis; the input spectrogram is permuted so that frequency becomes
    the last dimension before the projection.

    Parameters
    ----------
    emb_size : int
        The size of the embedding layer.
    n_freq : int
        The number of frequency points after the Fourier transform.

    Returns
    -------
    out : Tensor
        (batch, time, emb_size)
    """

    def __init__(self, emb_size: int = 256, n_freq: int = 101):
        super().__init__()
        self.projection = nn.Linear(n_freq, emb_size)

    def forward(self, x):
        """
        Parameters
        ----------
        x : Tensor
            (batch, n_freq, time)

        Returns
        -------
        out : Tensor
            (batch, time, emb_size)
        """
        x = x.permute(0, 2, 1)
        x = self.projection(x)
        return x
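
# Illustrative sketch: the module maps a magnitude spectrogram from
# (batch, n_freq, time) to (batch, time, emb_size); the values are hypothetical.
#
# >>> patch = _PatchFrequencyEmbedding(emb_size=256, n_freq=101)
# >>> spec = torch.randn(2, 101, 19)  # e.g. n_fft=200 -> 101 bins, 19 frames
# >>> patch(spec).shape
# torch.Size([2, 19, 256])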


class _ClassificationHead(nn.Sequential):
    r"""
    Classification head for the BIOT model.

    An activation layer (ELU by default) followed by a simple linear layer.

    Parameters
    ----------
    emb_size : int
        The size of the embedding layer.
    n_outputs : int
        The number of classes.
    activation : nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

    Returns
    -------
    out : Tensor
        (batch, n_outputs)
    """

    def __init__(
        self, emb_size: int, n_outputs: int, activation: type[nn.Module] = nn.ELU
    ):
        super().__init__()
        self.activation_layer = activation()
        self.classification_head = nn.Linear(emb_size, n_outputs)

    def forward(self, x):
        x = self.activation_layer(x)
        out = self.classification_head(x)
        return out


class _PositionalEncoding(nn.Module):
    r"""
    Positional Encoding.

    We first create a ``pe`` zero matrix of shape (max_len, emb_size),
    where max_len is the maximum length of the sequence and emb_size is
    the size of the embedding.

    We then create a ``position`` tensor of shape (max_len, 1) with indices
    from 0 to max_len, and a ``div_term`` tensor of shape (emb_size // 2,)
    holding 10000^(-2i / emb_size) for the even indices 2i, so that

        pe[pos, 2i]     = sin(pos / 10000^(2i / emb_size))
        pe[pos, 2i + 1] = cos(pos / 10000^(2i / emb_size))

    For more details about the positional encoding, see the
    `Attention Is All You Need` paper.

    Parameters
    ----------
    emb_size : int
        The size of the embedding layer.
    drop_prob : float, default=0.1
        The dropout rate for regularization. Values should be between 0 and 1.
    max_len : int
        The maximum length of the sequence.

    Returns
    -------
    out : Tensor
        (batch, seq_len, emb_size)
    """

    def __init__(self, emb_size: int, drop_prob: float = 0.1, max_len: int = 1000):
        super().__init__()

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, emb_size)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, emb_size, 2).float() * -(math.log(10000.0) / emb_size)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """
        Parameters
        ----------
        x : FloatTensor
            Embeddings, shape (batch, seq_len, emb_size).

        Returns
        -------
        Tensor
            Encoder input, shape (batch, seq_len, emb_size).
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)
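
# Illustrative sketch: the encoding adds a fixed sinusoidal table (then
# dropout), so the input shape is preserved; the values are hypothetical.
#
# >>> pos = _PositionalEncoding(emb_size=256, drop_prob=0.0)
# >>> emb = torch.randn(2, 19, 256)  # (batch, seq_len, emb_size)
# >>> pos(emb).shape
# torch.Size([2, 19, 256])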


class _BIOTEncoder(nn.Module):
    r"""
    BIOT Encoder.

    The BIOT encoder is a transformer that takes the time-series input
    data and returns a fixed-size embedding representation of it.
    The architecture is based on the `LinearAttentionTransformer` and
    `PatchFrequencyEmbedding` modules.

    The input data is transformed into a spectrogram and then embedded
    using a "patch" embedding.
    The channel token is added to the patch embedding, and then
    positional encoding is applied (simple index positional).
    The resulting embeddings are concatenated and passed through a
    transformer layer. The mean across the channel embeddings is
    returned.

    Parameters
    ----------
    n_chans : int
        The number of channels.
    emb_size : int
        The size of the embedding layer.
    num_heads : int
        The number of attention heads.
    n_layers : int
        The number of transformer layers.
    n_fft : int
        The number of Fourier transform points.
    hop_length : int (default 100)
        The distance between neighboring sliding window frames.
    """

    def __init__(
        self,
        emb_size=256,  # the size of the embedding layer
        num_heads=8,  # the number of attention heads
        n_chans=16,  # the number of channels
        n_layers=4,  # the number of transformer layers
        n_fft=200,  # related to the frequency resolution
        hop_length=100,
        drop_prob: float = 0.1,
        max_seq_len: int = 1024,  # the maximum sequence length
        attn_dropout=0.2,  # dropout applied after attention
        attn_layer_dropout=0.2,  # dropout right after the self-attention layer
    ):
        super().__init__()

        self.n_fft = int(n_fft)
        self.hop_length = hop_length

        self.patch_embedding = _PatchFrequencyEmbedding(
            emb_size=emb_size, n_freq=int(self.n_fft // 2 + 1)
        )
        self.transformer = LinearAttentionTransformer(
            dim=emb_size,
            heads=num_heads,
            depth=n_layers,
            max_seq_len=max_seq_len,
            attn_layer_dropout=attn_layer_dropout,
            attn_dropout=attn_dropout,
        )
        self.positional_encoding = _PositionalEncoding(emb_size, drop_prob=drop_prob)

        # Channel tokens: one learnable embedding per channel
        # (num_embeddings must be >= the number of channels actually used).
        self.channel_tokens = nn.Embedding(
            num_embeddings=n_chans, embedding_dim=emb_size
        )
        self.index = nn.Parameter(
            torch.LongTensor(range(n_chans)), requires_grad=False
        )
425
+
426
+ def stft(self, sample):
427
+ """
428
+ Short-time Fourier transform.
429
+ For more details, see `torch.stft`.
430
+
431
+ The size of the Fourier transform is obtained by `n_fft`, and the distance
432
+ between neighboring sliding window frames `hop_length` defined in
433
+ the __init__ functions.
434
+
435
+ Parameters
436
+ ----------
437
+ sample: Tensor
438
+ channel representation with size (batch_size, n_times)
439
+ Returns
440
+ -------
441
+ spectral: Tensor
442
+ Absolute value of the Fourier transform with size
443
+ (batch_size, n_fft // 2 + 1, n_times // hop_length + 1)
444
+ """
445
+ spectral = torch.stft(
446
+ input=sample.squeeze(1),
447
+ n_fft=int(self.n_fft),
448
+ hop_length=self.hop_length,
449
+ center=False,
450
+ onesided=True,
451
+ return_complex=True,
452
+ )
453
+ return torch.abs(spectral)
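
    # Illustrative sketch: with center=False, torch.stft yields
    # n_fft // 2 + 1 frequency bins and (n_times - n_fft) // hop_length + 1
    # frames. For a hypothetical 10 s window at 200 Hz:
    #
    # >>> enc = _BIOTEncoder(emb_size=256, num_heads=8, n_chans=18)
    # >>> enc.stft(torch.randn(2, 1, 2000)).shape  # n_fft=200, hop_length=100
    # torch.Size([2, 101, 19])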

    def forward(self, x, n_channel_offset=0, perturb=False):
        """
        Forward pass of the BIOT encoder.

        Each channel is transformed into a spectrogram with the STFT;
        the spectrogram is permuted and passed through a linear layer to
        learn a representation over the frequency domain.

        The channel token is then added to each patch embedding, and
        positional encoding is applied.

        The resulting embeddings are concatenated across channels and
        passed through a transformer layer. The mean of the resulting
        embeddings over the sequence dimension is returned.

        Parameters
        ----------
        x : Tensor
            (batch_size, n_channels, n_times)
        n_channel_offset : int (default 0)
            The offset added to the channel-token indices.
        perturb : bool (default False)
            If True, randomly select a subset of time steps and reduce
            each channel embedding to those time steps.

        Returns
        -------
        emb : Tensor
            (batch_size, emb_size)
        """
        emb_seq = []
        for i in range(x.shape[1]):
            # Compute the spectrogram of the i-th channel.
            channel_spec_emb = self.stft(x[:, i : i + 1, :])
            # Linear layer to learn a representation over the frequency
            # domain (the permutation happens inside the patch embedding).
            channel_spec_emb = self.patch_embedding(channel_spec_emb)
            batch_size, ts, _ = channel_spec_emb.shape
            # (batch_size, ts, emb)
            # Step by step, the following lines do the following operations:
            # - self.channel_tokens(self.index[i + n_channel_offset]):
            #   fetches the embedding for the channel given by
            #   i + n_channel_offset, where i is the current index and
            #   n_channel_offset adjusts which channel's embedding is
            #   retrieved.
            # - .unsqueeze(0).unsqueeze(0):
            #   adds two singleton dimensions to the embedding tensor,
            #   [emb_size] to [1, 1, emb_size].
            # - .repeat(batch_size, ts, 1):
            #   replicates the embedding tensor to match the batch size
            #   and the number of time steps (ts).
            channel_token_emb = (
                self.channel_tokens(self.index[i + n_channel_offset])
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(batch_size, ts, 1)
            )
            # (batch_size, ts, emb)
            # The positional encoding is explained in more detail in the
            # _PositionalEncoding class.
            channel_emb = self.positional_encoding(
                channel_spec_emb + channel_token_emb
            )
            # If perturb is set, a random number of time steps is selected
            # and the channel embedding is reduced to those time steps.
            if perturb:
                ts = channel_emb.shape[1]
                ts_new = torch.randint(low=ts // 2, high=ts, size=(1,)).item()
                selected_ts = torch.randperm(ts)[:ts_new]
                channel_emb = channel_emb[:, selected_ts]
            emb_seq.append(channel_emb)

        # Concatenate across channels and apply the transformer
        # (batch_size, n_chans * ts, emb)
        emb = torch.cat(emb_seq, dim=1)
        # (batch_size, emb)
        emb = self.transformer(emb).mean(dim=1)
        return emb
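

# Illustrative sketch: exercising the encoder directly, including the
# `perturb` option, which randomly subsamples time steps per channel.
# The shapes are hypothetical and assume the defaults above.
if __name__ == "__main__":
    enc = _BIOTEncoder(emb_size=256, num_heads=8, n_chans=18)
    x = torch.randn(2, 18, 2000)  # (batch, n_chans, n_times)
    print(enc(x).shape)  # torch.Size([2, 256])
    print(enc(x, perturb=True).shape)  # still (2, 256): mean over the sequence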