braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,475 @@
+ # Authors: Yonghao Song <eeyhsong@gmail.com>
+ #
+ # License: BSD (3-clause)
+ import warnings
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from torch import Tensor, nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import FeedForwardBlock, MultiHeadAttention
+
+
+ class EEGConformer(EEGModuleMixin, nn.Module):
+     r"""EEG Conformer from Song et al. (2022) [song2022]_.
+
+     :bdg-success:`Convolution` :bdg-info:`Attention/Transformer`
+
+     .. figure:: https://raw.githubusercontent.com/eeyhsong/EEG-Conformer/refs/heads/main/visualization/Fig1.png
+        :align: center
+        :alt: EEGConformer Architecture
+        :width: 600px
+
+
+     .. rubric:: Architectural Overview
+
+     EEG-Conformer is a *convolution-first* model augmented with a *lightweight transformer
+     encoder*. The end-to-end flow is:
+
+     - (i) :class:`_PatchEmbedding` converts the continuous EEG into a compact sequence of tokens via a
+       :class:`ShallowFBCSPNet`-style temporal–spatial conv stem and temporal pooling;
+     - (ii) :class:`_TransformerEncoder` applies small multi-head self-attention to integrate
+       longer-range temporal context across tokens;
+     - (iii) :class:`_ClassificationHead` aggregates the sequence and performs a linear readout.
+
+     This preserves the strong inductive biases of shallow CNN filter banks while adding
+     just enough attention to capture dependencies beyond the pooling horizon [song2022]_.
+
+     .. rubric:: Macro Components
+
+     - :class:`_PatchEmbedding` **(Shallow conv stem → tokens)**
+
+       - *Operations.*
+
+         - A temporal convolution (:class:`torch.nn.Conv2d`) with kernel ``(1 x L_t)`` forms a data-driven "filter bank";
+         - A spatial convolution (:class:`torch.nn.Conv2d`) with kernel ``(n_chans x 1)`` projects across electrodes,
+           collapsing the channel axis into a virtual channel.
+         - **Normalization function** :class:`torch.nn.BatchNorm2d`
+         - **Activation function** :class:`torch.nn.ELU`
+         - **Average Pooling** :class:`torch.nn.AvgPool2d` along time (kernel ``(1, P)`` with stride ``(1, S)``)
+         - a final ``1x1`` convolutional projection (:class:`torch.nn.Conv2d`).
+
+       The result is rearranged to a token sequence ``(B, S_tokens, D)``, where ``D = n_filters_time``.
+
+       *Interpretability/robustness.* Temporal kernels can be inspected as FIR filters;
+       the spatial conv yields channel projections analogous to :class:`ShallowFBCSPNet`'s learned
+       spatial filters. Temporal pooling stabilizes statistics and reduces sequence length.
+
+     - :class:`_TransformerEncoder` **(context over temporal tokens)**
+
+       - *Operations.*
+
+         - A stack of ``num_layers`` encoder blocks (:class:`_TransformerEncoderBlock`).
+         - Each block applies :class:`torch.nn.LayerNorm`,
+         - Multi-Head Self-Attention (``num_heads``) with dropout + residual (:class:`MultiHeadAttention`, :class:`torch.nn.Dropout`),
+         - :class:`torch.nn.LayerNorm`,
+         - a 2-layer feed-forward (≈4x expansion, :class:`torch.nn.GELU`) with dropout + residual.
+
+       Shapes remain ``(B, S_tokens, D)`` throughout.
+
+       *Role.* Small attention focuses on interactions among *temporal patches* (not channels),
+       extending effective receptive fields at modest cost.
+
+     - :class:`_ClassificationHead` **(aggregation + readout)**
+
+       - *Operations.*
+
+         - Flatten (:class:`torch.nn.Flatten`) the sequence to ``(B, S_tokens·D)``,
+         - MLP (:class:`torch.nn.Linear` → activation (default: :class:`torch.nn.ELU`) → :class:`torch.nn.Dropout` → :class:`torch.nn.Linear`),
+         - a final :class:`torch.nn.Linear` to classes.
+
+       With ``return_features=True``, features before the last Linear can be exported for
+       linear probing or downstream tasks.
+
+     .. rubric:: Convolutional Details
+
+     - **Temporal (where time-domain patterns are learned).**
+       The initial ``(1 x L_t)`` conv per channel acts as a *learned filter bank* for oscillatory
+       bands and transients. Subsequent **AvgPool** along time performs local integration,
+       converting activations into “patches” (tokens). Pool length/stride control the
+       token rate and set the lower bound on temporal context within each token.
+
+     - **Spatial (how electrodes are processed).**
+       A single conv with kernel ``(n_chans x 1)`` spans the full montage to learn spatial
+       projections for each temporal feature map, collapsing the channel axis into a
+       virtual channel before tokenization. This mirrors the shallow spatial step in
+       :class:`ShallowFBCSPNet` (temporal filters → spatial projection → temporal condensation).
+
+     - **Spectral (how frequency content is captured).**
+       No explicit Fourier/wavelet stage is used. Spectral selectivity emerges implicitly
+       from the learned temporal kernels; pooling further smooths high-frequency noise.
+       The effective spectral resolution is thus governed by ``L_t`` and the pooling
+       configuration.
+
+     .. rubric:: Attention / Sequential Modules
+
+     - **Type.** Standard multi-head self-attention (MHA) with ``num_heads`` heads over the token sequence.
+     - **Shapes.** Input/Output: ``(B, S_tokens, D)``; attention operates along the ``S_tokens`` axis.
+     - **Role.** Re-weights and integrates evidence across pooled windows, capturing dependencies
+       longer than any single token while leaving channel relationships to the convolutional stem.
+       The design is intentionally *small*—attention refines rather than replaces convolutional feature extraction.
+
+     .. rubric:: Additional Mechanisms
+
+     - **Parallel with ShallowFBCSPNet.** Both begin with a learned temporal filter bank,
+       spatial projection across electrodes, and early temporal condensation.
+       :class:`ShallowFBCSPNet` then computes band-power (via squaring/log-variance), whereas
+       EEG-Conformer applies BN/ELU and **continues with attention** over tokens to
+       refine temporal context before classification.
+
+     - **Tokenization knob.** ``pool_time_length`` and especially ``pool_time_stride`` set
+       the number of tokens ``S_tokens``. Smaller strides → more tokens and higher attention
+       capacity (but higher compute); larger strides → fewer tokens and stronger inductive bias.
+
+     - **Embedding dimension = filters.** ``n_filters_time`` serves double duty as both the
+       number of temporal filters in the stem and the transformer's embedding size ``D``,
+       simplifying dimensional alignment.
+
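+     As a rough illustration of the tokenization arithmetic (a sketch with assumed
+     example values; the count follows from the padding-free temporal convolution
+     and average pooling used in :class:`_PatchEmbedding`):
+
+     .. code-block:: python
+
+         n_times, filter_time_length = 1000, 25       # e.g. 4 s at 250 Hz (assumed)
+         pool_time_length, pool_time_stride = 75, 15  # default pooling settings
+         conv_out = n_times - filter_time_length + 1  # length after the temporal conv
+         s_tokens = (conv_out - pool_time_length) // pool_time_stride + 1
+         # s_tokens == 61 tokens, each of dimension D = n_filters_time
+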
+     .. rubric:: Usage and Configuration
+
+     - **Instantiation.** Choose ``n_filters_time`` (embedding size ``D``) and
+       ``filter_time_length`` to match the rhythms of interest. Tune
+       ``pool_time_length/stride`` to trade temporal resolution for sequence length.
+       Keep ``num_layers`` modest (e.g., 4–6) and set ``num_heads`` to divide ``D``.
+       ``final_fc_length="auto"`` infers the flattened size from PatchEmbedding.
+
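+     A minimal usage sketch (illustrative only; 22 channels, 4 classes and 1000
+     time samples are assumed here, not prescribed by the model):
+
+     .. code-block:: python
+
+         import torch
+
+         from braindecode.models import EEGConformer
+
+         model = EEGConformer(
+             n_outputs=4,
+             n_chans=22,
+             n_times=1000,
+             final_fc_length="auto",
+         )
+         x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
+         logits = model(x)             # -> shape (8, 4)
+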
+     Notes
+     -----
+     The authors recommend using data augmentation before using Conformer,
+     e.g. segmentation and recombination.
+     Please refer to the original paper and code for more details [ConformerCode]_.
+
+     The model was initially tuned on 4 seconds of 250 Hz data.
+     Please adjust the scale of the temporal convolutional layer
+     and the pooling layer for better performance.
+
+     .. versionadded:: 0.8
+
+     We aggregate the parameters based on the parts of the model, or on
+     where the parameters are first used, e.g. ``n_filters_time``.
+
+     .. versionadded:: 1.1
+
+
+     Parameters
+     ----------
+     n_filters_time: int
+         Number of temporal filters; also defines the embedding size.
+     filter_time_length: int
+         Length of the temporal filter.
+     pool_time_length: int
+         Length of the temporal pooling filter.
+     pool_time_stride: int
+         Length of the stride between temporal pooling filters.
+     drop_prob: float
+         Dropout rate of the convolutional layer.
+     num_layers: int
+         Number of self-attention layers.
+     num_heads: int
+         Number of attention heads.
+     att_drop_prob: float
+         Dropout rate of the self-attention layer.
+     final_fc_length: int | str
+         The dimension of the fully connected layer.
+     return_features: bool
+         If True, the forward method returns the features before the
+         last classification layer. Defaults to False.
+     activation: nn.Module
+         Activation function as parameter. Default is :class:`torch.nn.ELU`.
+     activation_transfor: nn.Module
+         Activation function as parameter, applied in the FeedForwardBlock module
+         inside the transformer. Default is :class:`torch.nn.GELU`.
+
+     References
+     ----------
+     .. [song2022] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
+        conformer: Convolutional transformer for EEG decoding and visualization.
+        IEEE Transactions on Neural Systems and Rehabilitation Engineering,
+        31, pp.710-719. https://ieeexplore.ieee.org/document/9991178
+     .. [ConformerCode] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
+        conformer: Convolutional transformer for EEG decoding and visualization.
+        https://github.com/eeyhsong/EEG-Conformer.
+     """
+
+     def __init__(
+         self,
+         n_outputs=None,
+         n_chans=None,
+         n_filters_time=40,
+         filter_time_length=25,
+         pool_time_length=75,
+         pool_time_stride=15,
+         drop_prob=0.5,
+         num_layers=6,
+         num_heads=10,
+         att_drop_prob=0.5,
+         final_fc_length="auto",
+         return_features=False,
+         activation: type[nn.Module] = nn.ELU,
+         activation_transfor: type[nn.Module] = nn.GELU,
+         n_times=None,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         self.mapping = {
+             "classification_head.fc.6.weight": "final_layer.final_layer.0.weight",
+             "classification_head.fc.6.bias": "final_layer.final_layer.0.bias",
+         }
+
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+         if not (self.n_chans <= 64):
+             warnings.warn(
+                 "This model has only been tested with no more "
+                 + "than 64 channels; there is no guarantee that it "
+                 + "will work with more channels.",
+                 UserWarning,
+             )
+
+         self.return_features = return_features
+
+         self.patch_embedding = _PatchEmbedding(
+             n_filters_time=n_filters_time,
+             filter_time_length=filter_time_length,
+             n_channels=self.n_chans,
+             pool_time_length=pool_time_length,
+             stride_avg_pool=pool_time_stride,
+             drop_prob=drop_prob,
+             activation=activation,
+         )
+
+         if final_fc_length == "auto":
+             assert self.n_times is not None
+             self.final_fc_length = self.get_fc_size()
+         else:
+             self.final_fc_length = final_fc_length
+
+         self.transformer = _TransformerEncoder(
+             num_layers=num_layers,
+             emb_size=n_filters_time,
+             num_heads=num_heads,
+             att_drop=att_drop_prob,
+             activation=activation_transfor,
+         )
+
+         self.fc = _FullyConnected(
+             final_fc_length=self.final_fc_length, activation=activation
+         )
+
+         self.final_layer = nn.Linear(self.fc.hidden_channels, self.n_outputs)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = torch.unsqueeze(x, dim=1)  # add one extra dimension
+         x = self.patch_embedding(x)
+         feature = self.transformer(x)
+
+         if self.return_features:
+             return feature
+
+         x = self.fc(feature)
+         x = self.final_layer(x)
+         return x
+
+     def get_fc_size(self):
+         out = self.patch_embedding(torch.ones((1, 1, self.n_chans, self.n_times)))
+         size_embedding_1 = out.cpu().data.numpy().shape[1]
+         size_embedding_2 = out.cpu().data.numpy().shape[2]
+
+         return size_embedding_1 * size_embedding_2
+
+
+ class _PatchEmbedding(nn.Module):
+     r"""Patch Embedding.
+
+     The authors used a convolution module to capture local features,
+     instead of position embedding.
+
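+     For orientation, a minimal shape sketch (illustrative values assumed:
+     22 channels, 1000 time samples, default filter and pooling settings):
+
+     .. code-block:: python
+
+         import torch
+
+         emb = _PatchEmbedding(
+             n_filters_time=40, filter_time_length=25, n_channels=22,
+             pool_time_length=75, stride_avg_pool=15, drop_prob=0.5,
+         )
+         tokens = emb(torch.zeros(8, 1, 22, 1000))  # -> shape (8, 61, 40)
+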
+     Parameters
+     ----------
+     n_filters_time: int
+         Number of temporal filters; also defines the embedding size.
+     filter_time_length: int
+         Length of the temporal filter.
+     n_channels: int
+         Number of channels to be used as number of spatial filters.
+     pool_time_length: int
+         Length of the temporal pooling filter.
+     stride_avg_pool: int
+         Length of the stride between temporal pooling filters.
+     drop_prob: float
+         Dropout rate of the convolutional layer.
+
+     Returns
+     -------
+     x: torch.Tensor
+         The output tensor of the patch embedding layer.
+     """
+
+     def __init__(
+         self,
+         n_filters_time,
+         filter_time_length,
+         n_channels,
+         pool_time_length,
+         stride_avg_pool,
+         drop_prob,
+         activation: type[nn.Module] = nn.ELU,
+     ):
+         super().__init__()
+
+         self.shallownet = nn.Sequential(
+             nn.Conv2d(1, n_filters_time, (1, filter_time_length), (1, 1)),
+             nn.Conv2d(n_filters_time, n_filters_time, (n_channels, 1), (1, 1)),
+             nn.BatchNorm2d(num_features=n_filters_time),
+             activation(),
+             nn.AvgPool2d(
+                 kernel_size=(1, pool_time_length), stride=(1, stride_avg_pool)
+             ),
+             # pooling acts as slicing to obtain 'patch' along the
+             # time dimension as in ViT
+             nn.Dropout(p=drop_prob),
+         )
+
+         self.projection = nn.Sequential(
+             nn.Conv2d(
+                 n_filters_time, n_filters_time, (1, 1), stride=(1, 1)
+             ),  # transpose; the conv could enhance fitting ability slightly
+             Rearrange("b d_model 1 seq -> b seq d_model"),
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.shallownet(x)
+         x = self.projection(x)
+         return x
+
+
+ class _ResidualAdd(nn.Module):
+     def __init__(self, fn):
+         super().__init__()
+         self.fn = fn
+
+     def forward(self, x):
+         res = x
+         x = self.fn(x)
+         x += res
+         return x
+
+
+ class _TransformerEncoderBlock(nn.Sequential):
+     def __init__(
+         self,
+         emb_size,
+         num_heads,
+         att_drop,
+         forward_expansion=4,
+         activation: type[nn.Module] = nn.GELU,
+     ):
+         super().__init__(
+             _ResidualAdd(
+                 nn.Sequential(
+                     nn.LayerNorm(emb_size),
+                     MultiHeadAttention(emb_size, num_heads, att_drop),
+                     nn.Dropout(att_drop),
+                 )
+             ),
+             _ResidualAdd(
+                 nn.Sequential(
+                     nn.LayerNorm(emb_size),
+                     FeedForwardBlock(
+                         emb_size,
+                         expansion=forward_expansion,
+                         drop_p=att_drop,
+                         activation=activation,
+                     ),
+                     nn.Dropout(att_drop),
+                 )
+             ),
+         )
+
+
+ class _TransformerEncoder(nn.Sequential):
+     r"""Transformer encoder module.
+
+     Similar to the layers used in ViT.
+
+     Parameters
+     ----------
+     num_layers : int
+         Number of transformer encoder blocks.
+     emb_size : int
+         Embedding size of the transformer encoder.
+     num_heads : int
+         Number of attention heads.
+     att_drop : float
+         Dropout probability for the attention layers.
+     activation : nn.Module
+         Activation class used in the feed-forward blocks. Default is nn.GELU.
+
+     """
+
+     def __init__(
+         self,
+         num_layers,
+         emb_size,
+         num_heads,
+         att_drop,
+         activation: type[nn.Module] = nn.GELU,
+     ):
+         super().__init__(
+             *[
+                 _TransformerEncoderBlock(
+                     emb_size, num_heads, att_drop, activation=activation
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+
+ class _FullyConnected(nn.Module):
+     def __init__(
+         self,
+         final_fc_length,
+         drop_prob_1=0.5,
+         drop_prob_2=0.3,
+         out_channels=256,
+         hidden_channels=32,
+         activation: type[nn.Module] = nn.ELU,
+     ):
+         """Fully-connected head applied after the transformer encoder.
+
+         Parameters
+         ----------
+         final_fc_length : int
+             Length of the flattened input to the fully connected layers.
+         drop_prob_1 : float
+             Dropout probability for the first dropout layer.
+         drop_prob_2 : float
+             Dropout probability for the second dropout layer.
+         out_channels : int
+             Number of output channels for the first linear layer.
+         hidden_channels : int
+             Number of output channels for the second linear layer.
+         activation : nn.Module
+             Activation class used between the linear layers. Default is nn.ELU.
+         """
+
+         super().__init__()
+         self.hidden_channels = hidden_channels
+         self.fc = nn.Sequential(
+             nn.Linear(final_fc_length, out_channels),
+             activation(),
+             nn.Dropout(drop_prob_1),
+             nn.Linear(out_channels, hidden_channels),
+             activation(),
+             nn.Dropout(drop_prob_2),
+         )
+
+     def forward(self, x):
+         x = x.contiguous().view(x.size(0), -1)
+         out = self.fc(x)
+         return out