braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,493 @@
+ import copy
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class BENDR(EEGModuleMixin, nn.Module):
+     r"""BENDR (BErt-inspired Neural Data Representations) from Kostas et al. (2021) [bendr]_.
+
+     :bdg-success:`Convolution` :bdg-danger:`Foundation Model`
+
+     .. figure:: https://www.frontiersin.org/files/Articles/653659/fnhum-15-653659-HTML/image_m/fnhum-15-653659-g001.jpg
+         :align: center
+         :alt: BENDR Architecture
+         :width: 1000px
+
+     The **BENDR** architecture adapts techniques used for language modeling (LM) toward the
+     development of encephalography modeling (EM) [bendr]_. It uses a self-supervised
+     training objective to learn compressed representations of raw EEG signals [bendr]_. The
+     model is capable of modeling completely novel raw EEG sequences recorded with differing
+     hardware and subjects, aiming for transferable performance across a variety of downstream
+     BCI and EEG classification tasks [bendr]_.
+
+     .. rubric:: Architectural Overview
+
+     BENDR is adapted from wav2vec 2.0 [wav2vec2]_ and is composed of two main stages: a
+     feature extractor (Convolutional stage) that produces BErt-inspired Neural Data
+     Representations (BENDR), followed by a transformer encoder (Contextualizer) [bendr]_.
+
+     .. rubric:: Macro Components
+
+     - `BENDR.encoder` **(Convolutional Stage/Feature Extractor)**
+         - *Operations.* A stack of six short-receptive-field 1D convolutions [bendr]_. Each
+           block consists of a 1D convolution, GroupNorm, and GELU activation.
+         - *Role.* Takes raw data :math:`X_{raw}` and dramatically downsamples it to a new
+           sequence of vectors (BENDR) [bendr]_. Each resulting vector has a length of 512.
+     - `BENDR.contextualizer` **(Transformer Encoder)**
+         - *Operations.* A transformer encoder that uses layered, multi-head self-attention
+           [bendr]_. It employs T-Fixup weight initialization [tfixup]_ and uses 8 layers
+           and 8 heads.
+         - *Role.* Maps the sequence of BENDR vectors to a contextualized sequence. The output
+           of a fixed start token is typically used as the aggregate representation for
+           downstream classification [bendr]_.
+     - `Contextualizer.position_encoder` **(Positional Encoding)**
+         - *Operations.* An additive (grouped) convolution layer with a receptive field of 25
+           and 16 groups [bendr]_.
+         - *Role.* Encodes position information before the input enters the transformer.
+
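+     A minimal shape sketch of these two stages (the batch size, channel count, window
+     length, and sampling rate below are illustrative assumptions, not requirements of the
+     model):
+
+     .. code-block:: python
+
+         import torch
+
+         from braindecode.models import BENDR
+
+         # hypothetical input: 8 windows, 20 channels, 60 s at 256 Hz
+         x = torch.randn(8, 20, 60 * 256)
+         model = BENDR(n_chans=20, n_times=60 * 256, n_outputs=2, sfreq=256)
+
+         bendr = model.encoder(x)  # (8, 512, 160), i.e. downsampled by a factor of 96
+         context = model.contextualizer(bendr)  # (8, 512, 161), start token prepended
+         y = model(x)  # (8, 2), start-token output fed to the final linear layer
+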
+     .. rubric:: How the information is encoded temporally, spatially, and spectrally
+
+     * **Temporal.**
+       The convolutional encoder uses a stack of blocks where the stride matches the receptive
+       field (e.g., 3 for the first block, 2 for subsequent blocks) [bendr]_. This process
+       downsamples the raw data by a factor of 96, resulting in an effective sampling frequency
+       of approximately 2.67 Hz.
+     * **Spatial.**
+       To keep the first stage simple, the convolutional stage uses **1D convolutions** and
+       does not mix EEG channels [bendr]_. The input includes 20 channels (19 EEG channels
+       and one relative amplitude channel).
+     * **Spectral.**
+       The convolution operations implicitly extract features from the raw EEG signal [bendr]_.
+       The representations (BENDR) are derived from the raw waveform using convolutional
+       operations followed by sequence modeling [wav2vec2]_.
+
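+     For reference, the downsampling arithmetic implied by the default strides (the 256 Hz
+     input rate below is an assumption matching the paper's setup, not something the model
+     enforces):
+
+     .. code-block:: python
+
+         import math
+
+         strides = (3, 2, 2, 2, 2, 2)  # default ``enc_downsample``
+         factor = math.prod(strides)  # 3 * 2 * 2 * 2 * 2 * 2 = 96
+         effective_sfreq = 256 / factor  # ~2.67 Hz for 256 Hz recordings
+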
+     .. rubric:: Additional Mechanisms
+
+     - **Self-Supervision (Pre-training).** Uses a masked sequence learning approach (adapted
+       from wav2vec 2.0 [wav2vec2]_) where contiguous spans of BENDR sequences are masked, and
+       the model attempts to reconstruct the original underlying encoded vector based on the
+       transformer output and a set of negative distractors [bendr]_.
+     - **Regularization.** LayerDrop [layerdrop]_ and Dropout (at probabilities 0.01 and 0.15,
+       respectively) are used during pre-training [bendr]_. The implementation also uses T-Fixup
+       scaling for parameter initialization [tfixup]_.
+     - **Input Conditioning.** A fixed token (a vector filled with the value **-5**) is
+       prepended to the BENDR sequence before input to the transformer, serving as the aggregate
+       representation token [bendr]_.
+
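+     The span masking used for self-supervised pre-training is not implemented in this
+     module; the rough sketch below (all values are illustrative, not taken from the paper)
+     only conveys the idea of masking contiguous spans of a BENDR sequence:
+
+     .. code-block:: python
+
+         import torch
+
+         def mask_spans(bendr, mask_vector, p_start=0.05, span=10):
+             # bendr: (batch, features, seq_len); mask_vector: (features,)
+             masked = bendr.clone()
+             starts = torch.rand(bendr.shape[0], bendr.shape[2]) < p_start
+             for b, t in starts.nonzero():
+                 masked[b, :, t : t + span] = mask_vector.unsqueeze(-1)
+             return masked
+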
+     .. important::
+         **Pre-trained Weights Available**
+
+         This model has pre-trained weights available on the Hugging Face Hub.
+         You can load them using:
+
+         .. code-block:: python
+
+             from braindecode.models import BENDR
+
+             # Load the pre-trained model from the Hugging Face Hub.
+             # You can specify `n_outputs` for your downstream task.
+             model = BENDR.from_pretrained("braindecode/braindecode-bendr", n_outputs=2)
+
+         To push your own trained model to the Hub:
+
+         .. code-block:: python
+
+             # After training your model
+             model.push_to_hub(
+                 repo_id="username/my-bendr-model", commit_message="Upload trained BENDR model"
+             )
+
+         Requires installing ``braindecode[hug]`` for Hub integration.
+
+     Notes
+     -----
+     * The full BENDR architecture contains a large number of parameters; configuration (1)
+       involved training over **one billion parameters** [bendr]_.
+     * A randomly initialized full BENDR architecture was generally ineffective at solving
+       downstream tasks without prior self-supervised training [bendr]_.
+     * The pre-training task (contrastive predictive coding via masking) is generalizable,
+       exhibiting strong uniformity of performance across novel subjects, hardware, and
+       tasks [bendr]_.
+
+     .. warning::
+
+         **Important:** To utilize the full potential of BENDR, the model requires
+         **self-supervised pre-training** on large, unlabeled EEG datasets (like TUEG) followed
+         by subsequent fine-tuning on the specific downstream classification task [bendr]_.
+
+     References
+     ----------
+     .. [bendr] Kostas, D., Aroca-Ouellette, S., & Rudzicz, F. (2021).
+        BENDR: Using transformers and a contrastive self-supervised learning task to learn from
+        massive amounts of EEG data.
+        Frontiers in Human Neuroscience, 15, 653659.
+        https://doi.org/10.3389/fnhum.2021.653659
+     .. [wav2vec2] Baevski, A., Zhou, Y., Mohamed, A., & Auli, M. (2020).
+        wav2vec 2.0: A framework for self-supervised learning of speech representations.
+        In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, & H. Lin (Eds.),
+        Advances in Neural Information Processing Systems (Vol. 33, pp. 12449-12460).
+        https://dl.acm.org/doi/10.5555/3495724.3496768
+     .. [tfixup] Huang, T. K., Liang, S., Jha, A., & Salakhutdinov, R. (2020).
+        Improving transformer optimization through better initialization.
+        In International Conference on Machine Learning (pp. 4475-4483). PMLR.
+        https://dl.acm.org/doi/10.5555/3524938.3525354
+     .. [layerdrop] Fan, A., Grave, E., & Joulin, A. (2020).
+        Reducing transformer depth on demand with structured dropout.
+        International Conference on Learning Representations.
+        Retrieved from https://openreview.net/forum?id=SylO2yStDr
+
+     Parameters
+     ----------
+     encoder_h : int, default=512
+         Hidden size (number of output channels) of the convolutional encoder. This determines
+         the dimensionality of the BENDR feature vectors produced by the encoder.
+     contextualizer_hidden : int, default=3076
+         Hidden size of the feedforward layer within each transformer block. The paper uses
+         approximately 2x the transformer dimension (3076 ~ 2 x 1536).
+     projection_head : bool, default=False
+         If True, adds a projection layer at the end of the encoder to project back to the
+         input feature size. This is used during self-supervised pre-training but typically
+         disabled during fine-tuning.
+     drop_prob : float, default=0.1
+         Dropout probability applied throughout the model. The paper recommends 0.15 for
+         pre-training and 0.0 for fine-tuning. Default is 0.1 as a compromise.
+     layer_drop : float, default=0.0
+         Probability of dropping entire transformer layers during training (LayerDrop
+         regularization [layerdrop]_). The paper uses 0.01 for pre-training and 0.0 for
+         fine-tuning.
+     activation : :class:`torch.nn.Module`, default=:class:`torch.nn.GELU`
+         Activation function used in the encoder convolutional blocks. The paper uses GELU
+         activation throughout.
+     transformer_layers : int, default=8
+         Number of transformer encoder layers in the contextualizer. The paper uses 8 layers.
+     transformer_heads : int, default=8
+         Number of attention heads in each transformer layer. The paper uses 8 heads with
+         a head dimension of 192 (1536 / 8).
+     position_encoder_length : int, default=25
+         Kernel size for the convolutional positional encoding layer. The paper uses a
+         receptive field of 25 with 16 groups.
+     enc_width : tuple of int, default=(3, 2, 2, 2, 2, 2)
+         Kernel sizes for each of the 6 convolutional blocks in the encoder. Each value
+         corresponds to one block.
+     enc_downsample : tuple of int, default=(3, 2, 2, 2, 2, 2)
+         Stride values for each of the 6 convolutional blocks in the encoder. The total
+         downsampling factor is the product of all strides (3 x 2 x 2 x 2 x 2 x 2 = 96).
+     start_token : int or float, default=-5
+         Value used to fill the start token embedding that is prepended to the BENDR sequence
+         before input to the transformer. This token's output is used as the aggregate
+         representation for classification.
+     final_layer : bool, default=True
+         If True, includes a final linear classification layer that maps from encoder_h to
+         n_outputs. If False, the model outputs the contextualized features directly.
+     """
+
+     def __init__(
+         self,
+         # Signal related parameters
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # Model parameters
+         encoder_h=512,  # Hidden size of the encoder convolutional layers
+         contextualizer_hidden=3076,  # Feedforward hidden size in the transformer
+         projection_head=False,  # Whether the encoder projects back to the input feature size (unused in the original fine-tuning)
+         drop_prob=0.1,  # General dropout probability (paper: 0.15 for pre-training, 0.0 for fine-tuning)
+         layer_drop=0.0,  # Probability of dropping transformer layers during training (paper: 0.01 for pre-training)
+         activation: type[nn.Module] = nn.GELU,  # Activation function
+         # Transformer specific parameters
+         transformer_layers=8,
+         transformer_heads=8,
+         position_encoder_length=25,  # Kernel size for the positional encoding conv
+         enc_width=(3, 2, 2, 2, 2, 2),
+         enc_downsample=(3, 2, 2, 2, 2, 2),
+         start_token=-5,  # Value for the start token embedding
+         final_layer=True,  # Whether to include the final linear layer
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+
+         # The EEGModuleMixin stores these; delete the local names to avoid accidental reuse.
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.encoder_h = encoder_h
+         self.contextualizer_hidden = contextualizer_hidden
+         self.include_final_layer = final_layer
+
+         # Encoder: use parameters from __init__
+         self.encoder = _ConvEncoderBENDR(
+             in_features=self.n_chans,
+             encoder_h=encoder_h,
+             dropout=drop_prob,
+             projection_head=projection_head,
+             enc_width=enc_width,
+             enc_downsample=enc_downsample,
+             activation=activation,
+         )
+
+         self.contextualizer = _BENDRContextualizer(
+             in_features=self.encoder.encoder_h,  # Output feature size of the encoder
+             hidden_feedforward=contextualizer_hidden,
+             heads=transformer_heads,
+             layers=transformer_layers,
+             dropout=drop_prob,  # Use the general dropout probability
+             activation=activation,
+             position_encoder=position_encoder_length,  # Kernel size of the positional encoding conv
+             layer_drop=layer_drop,
+             start_token=start_token,  # Fixed start token value
+         )
+         in_features = self.encoder.encoder_h  # Input size of the final layer
+
+         self.final_layer = None
+         if self.include_final_layer:
+             # Input to Linear will be [batch_size, encoder_h] after taking the start-token output
+             linear = nn.Linear(in_features=in_features, out_features=self.n_outputs)
+             self.final_layer = nn.utils.parametrizations.weight_norm(
+                 linear, name="weight", dim=1
+             )
+
+     def forward(self, x):
+         # Input x: [batch_size, n_chans, n_times]
+         encoded = self.encoder(x)
+         # encoded: [batch_size, encoder_h, n_encoded_times]
+
+         context = self.contextualizer(encoded)
+         # context: [batch_size, encoder_h, n_encoded_times + 1] (due to the start token)
+
+         # Extract features - use the output corresponding to the start token (index 0).
+         # The output has shape [batch_size, features, seq_len + 1].
+         # The start token was prepended at position 0 in the contextualizer before the
+         # transformer layers. After permutation back to [batch, features, seq_len + 1],
+         # the start token output is at index 0.
+         # According to the paper (Kostas et al. 2021): "The transformer output of this
+         # initial position was not modified during pre-training, and only used for
+         # downstream tasks." This follows BERT's [CLS] token convention, where the
+         # first token aggregates sequence information via self-attention.
+         feature = context[:, :, 0]
+         # feature: [batch_size, encoder_h]
+
+         if self.final_layer is not None:
+             feature = self.final_layer(feature)
+             # feature: [batch_size, n_outputs]
+
+         return feature
+
+
+ class _ConvEncoderBENDR(nn.Module):
+     def __init__(
+         self,
+         in_features,
+         encoder_h=512,
+         enc_width=(3, 2, 2, 2, 2, 2),
+         dropout=0.0,
+         projection_head=False,
+         enc_downsample=(3, 2, 2, 2, 2, 2),
+         activation=nn.GELU,
+     ):
+         super().__init__()
+         self.encoder_h = encoder_h
+         self.in_features = in_features
+
+         if not isinstance(enc_width, (list, tuple)):
+             enc_width = [enc_width]
+         if not isinstance(enc_downsample, (list, tuple)):
+             enc_downsample = [enc_downsample]
+         if len(enc_downsample) != len(enc_width):
+             raise ValueError(
+                 "Encoder width and downsampling factors must have the same length."
+             )
+
+         # Centerable convolutions make life simpler
+         enc_width = [
+             e if e % 2 != 0 else e + 1 for e in enc_width
+         ]  # Ensure odd kernel size
+         self._downsampling = enc_downsample
+         self._width = enc_width
+
+         current_in_features = in_features
+         self.encoder = nn.Sequential()
+         for i, (width, downsample) in enumerate(zip(enc_width, enc_downsample)):
+             self.encoder.add_module(
+                 "Encoder_{}".format(i),
+                 nn.Sequential(
+                     nn.Conv1d(
+                         current_in_features,
+                         encoder_h,
+                         width,
+                         stride=downsample,
+                         padding=width
+                         // 2,  # Correct padding for 'same' output length before stride
+                     ),
+                     nn.Dropout2d(dropout),  # 2D dropout (matches paper specification)
+                     nn.GroupNorm(
+                         encoder_h // 2, encoder_h
+                     ),  # Consider making num_groups configurable or ensure encoder_h is divisible by 2
+                     activation(),
+                 ),
+             )
+             current_in_features = encoder_h
+         if projection_head:
+             self.encoder.add_module(
+                 "projection_head",
+                 nn.Conv1d(encoder_h, encoder_h, 1),
+             )
+
+     def forward(self, x):
+         # x: [batch_size, in_features, n_times] -> [batch_size, encoder_h, n_times downsampled
+         # by the product of the strides]
+         return self.encoder(x)
+
+
+ class _BENDRContextualizer(nn.Module):
+     r"""Transformer-based contextualizer for BENDR."""
+
+     def __init__(
+         self,
+         in_features,
+         hidden_feedforward=3076,
+         heads=8,
+         layers=8,
+         dropout=0.1,  # Default dropout
+         activation=nn.GELU,  # Activation for the transformer feedforward layer
+         position_encoder=25,  # Kernel size for the conv positional encoding
+         layer_drop=0.0,  # Probability of dropping a whole layer
+         start_token=-5,  # Value for the start token embedding
+     ):
+         super().__init__()
+         self.dropout = dropout
+         self.layer_drop = layer_drop
+         self.start_token = start_token  # Store the start token value
+
+         # Paper specification: transformer_dim = 3 * encoder_h (1536 for encoder_h=512)
+         self.transformer_dim = 3 * in_features
+         self.in_features = in_features
+         # --- Positional Encoding --- (applied before projection)
+         self.position_encoder = None
+         if position_encoder > 0:
+             conv = nn.Conv1d(
+                 in_features,
+                 in_features,
+                 kernel_size=position_encoder,
+                 padding=position_encoder // 2,
+                 groups=16,  # Number of groups for depthwise separation
+             )
+             # T-Fixup positional encoding initialization (paper specification)
+             nn.init.normal_(conv.weight, mean=0, std=2 / self.transformer_dim)
+             nn.init.constant_(conv.bias, 0)
+
+             conv = nn.utils.parametrizations.weight_norm(conv, name="weight", dim=2)
+             self.relative_position = nn.Sequential(conv, activation())
+
+         # --- Input Conditioning --- (includes projection up to transformer_dim)
+         # Rearrange, Norm, Dropout, Project, Rearrange
+         self.input_conditioning = nn.Sequential(
+             Rearrange(
+                 "batch channel time -> batch time channel"
+             ),  # Batch, Time, Channels
+             nn.LayerNorm(in_features),
+             nn.Dropout(dropout),
+             nn.Linear(in_features, self.transformer_dim),  # Project up using Linear
+             # nn.Conv1d(in_features, self.transformer_dim, 1),  # Alternative: project up using Conv1d
+             Rearrange(
+                 "batch time channel -> time batch channel"
+             ),  # Time, Batch, Channels (format expected by the transformer)
+         )
+
+         # --- Transformer Encoder Layers ---
+         # Paper uses T-Fixup: remove internal LayerNorm layers
+
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=self.transformer_dim,  # Use the projected dimension
+             nhead=heads,
+             dim_feedforward=hidden_feedforward,
+             dropout=dropout,  # Dropout within the transformer layer
+             activation=activation(),
+             batch_first=False,  # Expects (T, B, C)
+             norm_first=False,  # Standard post-norm architecture
+         )
+
+         # T-Fixup: replace internal LayerNorms with Identity
+         encoder_layer.norm1 = nn.Identity()
+         encoder_layer.norm2 = nn.Identity()
+
+         self.transformer_layers = nn.ModuleList(
+             [copy.deepcopy(encoder_layer) for _ in range(layers)]
+         )
+
+         # Add final LayerNorm after all transformer layers (paper specification)
+         self.norm = nn.LayerNorm(self.transformer_dim)
+
+         # --- Output Layer --- (project back down to in_features)
+         # Paper uses Conv1d with kernel=1 (equivalent to Linear but expects (B, C, T) input)
+         self.output_layer = nn.Conv1d(self.transformer_dim, in_features, 1)
+
+         # Initialize parameters with T-Fixup (paper specification)
+         self.apply(self._init_bert_params)
+
+     def _init_bert_params(self, module):
+         """Initialize linear layers and apply T-Fixup scaling."""
+         if isinstance(module, nn.Linear):
+             # Standard init
+             nn.init.xavier_uniform_(module.weight.data)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+             # T-Fixup scaling: 0.67 * N^(-1/4), with N the number of transformer layers
+             module.weight.data = (
+                 0.67 * len(self.transformer_layers) ** (-0.25) * module.weight.data
+             )
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Input x: [batch_size, in_features, seq_len]
+
+         # Apply relative positional encoding
+         if hasattr(self, "relative_position"):
+             pos_enc = self.relative_position(x)
+             x = x + pos_enc
+
+         # Apply input conditioning (includes projection up and rearrange)
+         x = self.input_conditioning(x)
+         # x: [seq_len, batch_size, transformer_dim]
+
+         # Prepend the start token
+         if self.start_token is not None:
+             token_emb = torch.full(
+                 (1, x.shape[1], x.shape[2]),
+                 float(self.start_token),
+                 device=x.device,
+                 requires_grad=False,
+             )
+             x = torch.cat([token_emb, x], dim=0)
+             # x: [seq_len + 1, batch_size, transformer_dim]
+
+         # Apply transformer layers with layer drop
+         for layer in self.transformer_layers:
+             if not self.training or torch.rand(1) > self.layer_drop:
+                 x = layer(x)
+         # x: [seq_len + 1, batch_size, transformer_dim]
+
+         # Apply the final LayerNorm (T-Fixup: norm only at the end)
+         x = self.norm(x)
+         # x: [seq_len + 1, batch_size, transformer_dim]
+
+         # Permute to (B, C, T) format for the Conv1d output layer
+         x = Rearrange("time batch channel -> batch channel time")(x)
+         # x: [batch_size, transformer_dim, seq_len + 1]
+
+         # Apply the output projection (Conv1d expects B, C, T)
+         x = self.output_layer(x)
+         # x: [batch_size, in_features, seq_len + 1]
+
+         return x