braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,537 @@
import math
from warnings import warn

import torch
import torch.nn as nn
from linear_attention_transformer import LinearAttentionTransformer

from braindecode.models.base import EEGModuleMixin


class BIOT(EEGModuleMixin, nn.Module):
    r"""BIOT from Yang et al. (2023) [Yang2023]_

    :bdg-danger:`Foundation Model`

    .. figure:: https://braindecode.org/dev/_static/model/biot.jpg
        :align: center
        :alt: BioT

    BIOT: Cross-data Biosignal Learning in the Wild.

    BIOT is a foundation model for biosignal classification. It is
    a wrapper around the `BIOTEncoder` and `ClassificationHead` modules.

    It is designed for N-dimensional biosignal data such as EEG, ECG, etc.
    The method was proposed by Yang et al. [Yang2023]_ and the code is
    available at [Code2023]_.

    The model is trained with a contrastive loss on large EEG datasets:
    the TUH Abnormal EEG Corpus (about 400K samples) and the Sleep Heart
    Health Study (about 5M samples). Here, we only provide the model
    architecture, not the pre-trained weights or the contrastive-loss
    training.

    The architecture is based on the `LinearAttentionTransformer` and
    `PatchFrequencyEmbedding` modules.
    The `BIOTEncoder` is a transformer that takes the input data and outputs
    a fixed-size representation of it. More details are given in the
    `BIOTEncoder` class.

    The `ClassificationHead` is an ELU activation layer, followed by a simple
    linear layer that takes the output of the `BIOTEncoder` and outputs
    the classification output (logits).

    .. important::
        **Pre-trained Weights Available**

        This model has pre-trained weights available on the Hugging Face Hub.
        You can load them using:

        .. code-block:: python

            from braindecode.models import BIOT

            # Load the original pre-trained model from Hugging Face Hub
            # For 16-channel models:
            model = BIOT.from_pretrained("braindecode/biot-pretrained-prest-16chs")

            # For 18-channel models:
            model = BIOT.from_pretrained("braindecode/biot-pretrained-shhs-prest-18chs")
            model = BIOT.from_pretrained("braindecode/biot-pretrained-six-datasets-18chs")

        To push your own trained model to the Hub:

        .. code-block:: python

            # After training your model
            model.push_to_hub(
                repo_id="username/my-biot-model", commit_message="Upload trained BIOT model"
            )

        Requires installing ``braindecode[hug]`` for Hub integration.
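
    A minimal usage sketch (the sizes below are chosen only for
    illustration, not taken from the original paper):

    .. code-block:: python

        import torch

        from braindecode.models import BIOT

        # 18 channels, 10 s at 200 Hz -> 2000 time points
        model = BIOT(n_chans=18, n_outputs=2, n_times=2000, sfreq=200)
        x = torch.randn(8, 18, 2000)  # (batch_size, n_chans, n_times)
        out = model(x)  # (batch_size, n_outputs)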

    .. versionadded:: 0.9

    Parameters
    ----------
    embed_dim : int, optional
        The size of the embedding layer, by default 256.
    num_heads : int, optional
        The number of attention heads, by default 8.
    num_layers : int, optional
        The number of transformer layers, by default 4.
    activation : nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
    return_feature : bool, optional
        Controls the output of the network. If False (default), the forward
        pass returns a single tensor with the classification output; if True,
        it returns the embedding as well. Default is False.
    hop_length : int, optional
        The hop length for the torch.stft transformation in the
        encoder. The default is 100.
    sfreq : int, optional
        The sampling frequency of the input, in Hz. The default is 200.

    References
    ----------
    .. [Yang2023] Yang, C., Westover, M.B. and Sun, J., 2023, November. BIOT:
       Biosignal Transformer for Cross-data Learning in the Wild. In Thirty-seventh
       Conference on Neural Information Processing Systems, NeurIPS.
    .. [Code2023] Yang, C., Westover, M.B. and Sun, J., 2023. BIOT:
       Biosignal Transformer for Cross-data Learning in the Wild.
       GitHub https://github.com/ycq091044/BIOT (accessed 2024-02-13).
    """

    def __init__(
        self,
        embed_dim=256,
        num_heads=8,
        num_layers=4,
        sfreq=200,
        hop_length=100,
        return_feature=False,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        activation: type[nn.Module] = nn.ELU,
        drop_prob: float = 0.5,
        # Parameters for the encoder
        max_seq_len: int = 1024,
        att_drop_prob=0.2,
        att_layer_drop_prob=0.2,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, sfreq
        self.embed_dim = embed_dim
        self.hop_length = hop_length
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.return_feature = return_feature
        if self.sfreq is not None and self.sfreq != 200:
            warn(
                "This model has only been trained on a dataset sampled at "
                + "200 Hz; there is no guarantee that it generalizes well "
                + "with the default parameters.",
                UserWarning,
            )
        if self.n_chans > embed_dim:
            warn(
                "The number of channels is larger than the embedding size. "
                + "This may cause overfitting. Consider using a larger "
                + "embedding size or a smaller number of channels.",
                UserWarning,
            )
        if self.hop_length > self.sfreq:
            warn(
                "The hop length is larger than the sampling frequency. "
                + "This may cause aliasing. Consider using a smaller "
                + "hop length.",
                UserWarning,
            )
            hop_length = self.sfreq // 2

        if self.input_window_seconds < 1.0:
            warning_msg = (
                "The input window is less than 1 second, which may not be "
                "sufficient for the model to learn meaningful representations. "
                "Changing the `n_fft` to `n_times`."
            )
            warn(warning_msg, UserWarning)
            self.n_fft = self.n_times
        else:
            self.n_fft = int(self.sfreq)

        self.encoder = _BIOTEncoder(
            emb_size=self.embed_dim,
            num_heads=self.num_heads,
            n_layers=self.num_layers,
            n_chans=self.n_chans,
            n_fft=self.n_fft,
            hop_length=hop_length,
            drop_prob=drop_prob,
            max_seq_len=max_seq_len,
            attn_dropout=att_drop_prob,
            attn_layer_dropout=att_layer_drop_prob,
        )

        self.final_layer = _ClassificationHead(
            emb_size=self.embed_dim,
            n_outputs=self.n_outputs,
            activation=activation,
        )

    def forward(self, x):
        """
        Pass the input through the BIOT encoder, and then through the
        classification head.

        Parameters
        ----------
        x : Tensor
            (batch_size, n_channels, n_times)

        Returns
        -------
        out : Tensor
            (batch_size, n_outputs), when ``return_feature`` is False.
        (out, emb) : tuple of Tensor
            (batch_size, n_outputs) and (batch_size, emb_size), when
            ``return_feature`` is True.
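
        A sketch of the two output modes (shapes are illustrative):

        .. code-block:: python

            model = BIOT(n_chans=18, n_outputs=2, n_times=2000, return_feature=True)
            out, emb = model(torch.randn(8, 18, 2000))  # (8, 2), (8, 256)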
        """
        emb = self.encoder(x)
        x = self.final_layer(emb)

        if self.return_feature:
            return x, emb
        else:
            return x


class _PatchFrequencyEmbedding(nn.Module):
    r"""
    Patch Frequency Embedding.

    A simple linear layer used to learn a representation over the
    frequency domain, after moving the frequency axis to the last
    dimension by permutation.
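
    A shape sketch (sizes are illustrative):

    .. code-block:: python

        patch = _PatchFrequencyEmbedding(emb_size=256, n_freq=101)
        spec = torch.randn(2, 101, 19)  # (batch, n_freq, time), e.g. an STFT magnitude
        out = patch(spec)  # (2, 19, 256) == (batch, time, emb_size)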

    Parameters
    ----------
    emb_size : int
        The size of the embedding layer
    n_freq : int
        The number of frequency points after the Fourier transform

    Returns
    -------
    out : Tensor
        (batch, time, emb_size)
    """

    def __init__(self, emb_size: int = 256, n_freq: int = 101):
        super().__init__()
        self.projection = nn.Linear(n_freq, emb_size)

    def forward(self, x):
        """
        Parameters
        ----------
        x : Tensor
            (batch, n_freq, time)

        Returns
        -------
        out : Tensor
            (batch, time, emb_size)
        """
        x = x.permute(0, 2, 1)
        x = self.projection(x)
        return x


class _ClassificationHead(nn.Sequential):
    r"""
    Classification head for the BIOT model.

    A simple linear layer preceded by an ELU activation function.

    Parameters
    ----------
    emb_size : int
        The size of the embedding layer
    n_outputs : int
        The number of classes
    activation : nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

    Returns
    -------
    out : Tensor
        (batch, n_outputs)
    """

    def __init__(
        self, emb_size: int, n_outputs: int, activation: type[nn.Module] = nn.ELU
    ):
        super().__init__()
        self.activation_layer = activation()
        self.classification_head = nn.Linear(emb_size, n_outputs)

    def forward(self, x):
        x = self.activation_layer(x)
        out = self.classification_head(x)
        return out


class _PositionalEncoding(nn.Module):
    r"""
    Positional Encoding.

    We first create a `pe` zero matrix of shape (max_len, emb_size), where
    max_len is the maximum length of the sequence and emb_size is the size
    of the embedding.

    Then we create a `position` tensor of shape (max_len, 1) with indices
    from 0 to max_len, and a `div_term` tensor of shape (emb_size // 2) with
    the exponential of the even indices from 0 to emb_size multiplied by
    -log(10000.0) / emb_size.

    For more details about the positional encoding, see the
    `Attention is All You Need` paper.
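
    Concretely, the encoding computed below is the standard sinusoidal
    form (restated here for reference; :math:`pos` is the position,
    :math:`i` the dimension index, and :math:`d` is ``emb_size``):

    .. math::

        PE_{(pos,\, 2i)} = \sin\left(pos / 10000^{2i/d}\right)

        PE_{(pos,\, 2i+1)} = \cos\left(pos / 10000^{2i/d}\right)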

    Parameters
    ----------
    emb_size : int
        The size of the embedding layer
    drop_prob : float, default=0.1
        The dropout rate for regularization. Values should be between 0 and 1.
    max_len : int
        The maximum length of the sequence

    Returns
    -------
    out : Tensor
        (batch, max_len, emb_size)
    """

    def __init__(self, emb_size: int, drop_prob: float = 0.1, max_len: int = 1000):
        super().__init__()

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, emb_size)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, emb_size, 2).float() * -(math.log(10000.0) / emb_size)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """
        Parameters
        ----------
        x : FloatTensor
            `embeddings`, shape (batch, max_len, emb_size)

        Returns
        -------
        Tensor
            `encoder input`, shape (batch, max_len, emb_size)
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class _BIOTEncoder(nn.Module):
    r"""
    BIOT Encoder.

    The BIOT encoder is a transformer that takes the time-series input data
    and returns a fixed-size embedding representation of it.
    The architecture is based on the `LinearAttentionTransformer` and
    `PatchFrequencyEmbedding` modules.

    The input data is transformed into a spectrogram and then embedded using
    a "patch" embedding.
    The channel token is added to the patch embedding, and then positional
    encoding is applied (simple index positional).
    The resulting embeddings are concatenated and passed through a
    transformer layer. The mean across the channel embeddings is returned.
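
    A minimal sketch of the expected shapes (sizes are illustrative):

    .. code-block:: python

        encoder = _BIOTEncoder(
            emb_size=256, num_heads=8, n_chans=2, n_layers=4, n_fft=200, hop_length=100
        )
        x = torch.randn(8, 2, 2000)  # (batch_size, n_chans, n_times)
        emb = encoder(x)  # (batch_size, emb_size) == (8, 256)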

    Parameters
    ----------
    n_chans : int
        The number of channels
    emb_size : int
        The size of the embedding layer
    num_heads : int
        The number of attention heads
    n_layers : int
        The number of transformer layers
    n_fft : int
        The number of Fourier transform points
    hop_length : int (default 100)
        The distance between neighboring sliding window frames
    """

    def __init__(
        self,
        emb_size=256,  # The size of the embedding layer
        num_heads=8,  # The number of attention heads
        n_chans=16,  # The number of channels
        n_layers=4,  # The number of transformer layers
        n_fft=200,  # Related to the frequency resolution
        hop_length=100,
        drop_prob: float = 0.1,
        max_seq_len: int = 1024,  # The maximum sequence length
        attn_dropout=0.2,  # dropout post-attention
        attn_layer_dropout=0.2,  # dropout right after the self-attention layer
    ):
        super().__init__()

        self.n_fft = int(n_fft)
        self.hop_length = hop_length

        self.patch_embedding = _PatchFrequencyEmbedding(
            emb_size=emb_size, n_freq=int(self.n_fft // 2 + 1)
        )
        self.transformer = LinearAttentionTransformer(
            dim=emb_size,
            heads=num_heads,
            depth=n_layers,
            max_seq_len=max_seq_len,
            attn_layer_dropout=attn_layer_dropout,
            attn_dropout=attn_dropout,
        )
        self.positional_encoding = _PositionalEncoding(emb_size, drop_prob=drop_prob)

        # Channel tokens: num_embeddings (n_chans) must be >= the actual
        # number of channels passed to forward.
        self.channel_tokens = nn.Embedding(
            num_embeddings=n_chans, embedding_dim=emb_size
        )
        self.index = nn.Parameter(
            torch.LongTensor(range(n_chans)), requires_grad=False
        )

    def stft(self, sample):
        """
        Short-time Fourier transform.
        For more details, see `torch.stft`.

        The size of the Fourier transform is given by `n_fft`, and the
        distance between neighboring sliding window frames by `hop_length`,
        both defined in the __init__ method.

        Parameters
        ----------
        sample : Tensor
            channel representation with size (batch_size, n_times)

        Returns
        -------
        spectral : Tensor
            Absolute value of the Fourier transform with size
            (batch_size, n_fft // 2 + 1, (n_times - n_fft) // hop_length + 1)
        """
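        # Frame-count arithmetic for the shape documented above (illustrative
        # numbers, not from the original source): with center=False, torch.stft
        # produces 1 + (n_times - n_fft) // hop_length frames, e.g. n_times=2000,
        # n_fft=200, hop_length=100 -> 19 frames, each with
        # n_fft // 2 + 1 = 101 one-sided frequency bins.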
        spectral = torch.stft(
            input=sample.squeeze(1),
            n_fft=int(self.n_fft),
            hop_length=self.hop_length,
            center=False,
            onesided=True,
            return_complex=True,
        )
        return torch.abs(spectral)

    def forward(self, x, n_channel_offset=0, perturb=False):
        """
        Forward pass of the BIOT encoder.

        Each channel is transformed into a spectrogram with the STFT; the
        spectrogram representation is permuted and passed through a linear
        layer to learn a representation over the frequency domain (the
        patch embedding).

        The channel token is then added to each patch embedding, and
        positional encoding is applied.

        The resulting per-channel embeddings are concatenated along the
        time axis and passed through a transformer layer. The mean of the
        resulting embeddings over the sequence dimension is returned.

        Parameters
        ----------
        x : Tensor
            (batch_size, n_channels, n_times)
        n_channel_offset : int (default 0)
            The offset term to be added to the channel tokens
        perturb : bool (default False)
            Randomly select a number of time steps and reduce the
            channel embedding to those time steps.

        Returns
        -------
        emb : Tensor
            (batch_size, emb_size)
        """
        emb_seq = []
        for i in range(x.shape[1]):
            # Getting the spectrogram
            channel_spec_emb = self.stft(x[:, i : i + 1, :])
            # Linear layer to learn a representation over the frequency
            # domain, with permutation
            channel_spec_emb = self.patch_embedding(channel_spec_emb)
            batch_size, ts, _ = channel_spec_emb.shape
            # (batch_size, ts, emb)
            # Step by step, the following lines do the following operations:
            # - self.channel_tokens(self.index[i + n_channel_offset]):
            #   Fetches the embedding for the channel specified by
            #   i + n_channel_offset, where i is the current index and
            #   n_channel_offset adjusts which channel's embedding is
            #   retrieved.
            # - .unsqueeze(0).unsqueeze(0):
            #   Adds two singleton dimensions to the embedding tensor,
            #   [emb_size] to [1, 1, emb_size].
            # - .repeat(batch_size, ts, 1):
            #   Replicates the embedding tensor to match the batch size
            #   and time steps (ts).
            channel_token_emb = (
                self.channel_tokens(self.index[i + n_channel_offset])
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(batch_size, ts, 1)
            )
            # (batch_size, ts, emb)
            # The positional encoding is explained in more detail in the
            # _PositionalEncoding class.
            channel_emb = self.positional_encoding(channel_spec_emb + channel_token_emb)
            # In case of perturb, the time steps are randomly selected
            # and the channel embedding is reduced to a random number
            # of time steps.
            if perturb:
                ts = channel_emb.shape[1]
                ts_new = torch.randint(low=ts // 2, high=ts, size=(1,)).item()
                selected_ts = torch.randperm(ts)[:ts_new]
                channel_emb = channel_emb[:, selected_ts]
            emb_seq.append(channel_emb)

        # Concat and transformer
        # (batch_size, n_channels * ts, emb)
        emb = torch.cat(emb_seq, dim=1)
        # (batch_size, emb)
        emb = self.transformer(emb).mean(dim=1)
        return emb