braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/ctnet.py
@@ -0,0 +1,541 @@
"""
CTNet: a convolutional transformer network for EEG-based motor imagery
classification from Wei Zhao et al. (2024).
"""

# Authors: Wei Zhao <zhaowei701@163.com>
#          Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
# License: MIT

from __future__ import annotations

import math
from typing import Optional

import torch
from einops.layers.torch import Rearrange
from mne.utils import warn
from torch import Tensor, nn

from braindecode.models.base import EEGModuleMixin
from braindecode.modules import (
    FeedForwardBlock,
    MultiHeadAttention,
)


class CTNet(EEGModuleMixin, nn.Module):
    r"""CTNet from Zhao et al. (2024) [ctnet]_.

    :bdg-success:`Convolution` :bdg-info:`Attention/Transformer`

    A Convolutional Transformer Network for EEG-Based Motor Imagery Classification.

    .. figure:: https://raw.githubusercontent.com/snailpt/CTNet/main/architecture.png
       :align: center
       :alt: CTNet Architecture

    CTNet is an end-to-end neural network architecture designed for classifying
    motor imagery (MI) tasks from EEG signals. The model combines a convolutional
    neural network (CNN) with a Transformer encoder to capture both local and
    global temporal dependencies in the EEG data.

    The architecture consists of three main components:

    1. **Convolutional Module**:

       - Applies an :class:`EEGNet`-style convolutional block for feature
         extraction, implemented here as the ``_PatchEmbeddingEEGNet`` module.

    2. **Transformer Encoder Module**:

       - Uses multi-head self-attention, as in :class:`EEGConformer`, but
         with residual blocks.

    3. **Classifier Module**:

       - Combines features from both the convolutional module
         and the Transformer encoder.
       - Flattens the combined features and applies dropout for regularization.
       - Uses a fully connected layer to produce the final classification output.

    Parameters
    ----------
    activation_patch : nn.Module, default=nn.ELU
        Activation function used in the convolutional (patch embedding) module.
    activation_transformer : nn.Module, default=nn.GELU
        Activation function used in the Transformer encoder.
    num_heads : int, default=4
        Number of attention heads in the Transformer encoder.
    embed_dim : int or None, default=40
        Embedding size (dimensionality) for the Transformer encoder.
    num_layers : int, default=6
        Number of encoder layers in the Transformer.
    n_filters_time : int or None, default=None
        Number of temporal filters in the first convolutional layer. If ``None``,
        it is inferred from ``embed_dim`` and ``depth_multiplier``.
    kernel_size : int, default=64
        Kernel size for the temporal convolutional layer.
    depth_multiplier : int or None, default=2
        Multiplier for the number of depth-wise convolutional filters.
    pool_size_1 : int, default=8
        Pooling size for the first average pooling layer.
    pool_size_2 : int, default=8
        Pooling size for the second average pooling layer.
    cnn_drop_prob : float, default=0.3
        Dropout probability after the convolutional layers.
    att_positional_drop_prob : float, default=0.1
        Dropout probability for the positional encoding in the Transformer.
    final_drop_prob : float, default=0.5
        Dropout probability before the final classification layer.

    Notes
    -----
    This implementation is adapted from the original CTNet source code
    [ctnetcode]_ to comply with Braindecode's model standards.

    References
    ----------
    .. [ctnet] Zhao, W., Jiang, X., Zhang, B., Xiao, S., & Weng, S. (2024).
        CTNet: a convolutional transformer network for EEG-based motor imagery
        classification. Scientific Reports, 14(1), 20237.
    .. [ctnetcode] Zhao, W., Jiang, X., Zhang, B., Xiao, S., & Weng, S. (2024).
        CTNet source code: https://github.com/snailpt/CTNet
    """

    def __init__(
        self,
        # Base arguments
        n_outputs=None,
        n_chans=None,
        sfreq=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        # Model-specific arguments
        activation_patch: type[nn.Module] = nn.ELU,
        activation_transformer: type[nn.Module] = nn.GELU,
        cnn_drop_prob: float = 0.3,
        att_positional_drop_prob: float = 0.1,
        final_drop_prob: float = 0.5,
        # Other parameters
        num_heads: int = 4,
        embed_dim: Optional[int] = 40,
        num_layers: int = 6,
        n_filters_time: Optional[int] = None,
        kernel_size: int = 64,
        depth_multiplier: Optional[int] = 2,
        pool_size_1: int = 8,
        pool_size_2: int = 8,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.activation_patch = activation_patch
        self.activation_transformer = activation_transformer
        self.cnn_drop_prob = cnn_drop_prob
        self.pool_size_1 = pool_size_1
        self.pool_size_2 = pool_size_2
        self.kernel_size = kernel_size
        self.att_positional_drop_prob = att_positional_drop_prob
        self.final_drop_prob = final_drop_prob
        self.num_heads = num_heads
        self.num_layers = num_layers
        # Token sequence length after the two average-pooling stages
        self.sequence_length = math.floor(
            (
                math.floor((self.n_times - self.pool_size_1) / self.pool_size_1 + 1)
                - self.pool_size_2
            )
            / self.pool_size_2
            + 1
        )
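        # Example: with n_times=1000 and the default pool sizes (8, 8), this
        # gives floor((1000 - 8) / 8 + 1) = 125 tokens after the first pooling
        # and floor((125 - 8) / 8 + 1) = 15 after the second, so
        # sequence_length = 15.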

        self.depth_multiplier, self.n_filters_time, self.embed_dim = self._resolve_dims(
            depth_multiplier, n_filters_time, embed_dim
        )

        # Layers
        self.ensuredim = Rearrange("batch nchans time -> batch 1 nchans time")
        self.flatten = nn.Flatten()

        self.cnn = _PatchEmbeddingEEGNet(
            n_filters_time=self.n_filters_time,
            kernel_size=self.kernel_size,
            depth_multiplier=self.depth_multiplier,
            pool_size_1=self.pool_size_1,
            pool_size_2=self.pool_size_2,
            drop_prob=self.cnn_drop_prob,
            n_chans=self.n_chans,
            activation=self.activation_patch,
        )

        self.position = _PositionalEncoding(
            emb_size=self.embed_dim,
            drop_prob=self.att_positional_drop_prob,
            n_times=self.n_times,
            pool_size=self.pool_size_1,
        )

        self.trans = _TransformerEncoder(
            self.num_heads,
            self.num_layers,
            self.embed_dim,
            activation=self.activation_transformer,
        )

        self.flatten_drop_layer = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=self.final_drop_prob),
        )

        self.final_layer = nn.Linear(
            in_features=int(self.embed_dim * self.sequence_length),
            out_features=self.n_outputs,
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the CTNet model.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, n_channels, n_times).

        Returns
        -------
        Tensor
            Output with shape (batch_size, n_outputs).
        """
213
+ """
214
+ x = self.ensuredim(x)
215
+ cnn = self.cnn(x)
216
+ cnn = cnn * math.sqrt(self.embed_dim)
217
+ cnn = self.position(cnn)
218
+ trans = self.trans(cnn)
219
+ features = cnn + trans
220
+ flatten_feature = self.flatten(features)
221
+ out = self.final_layer(flatten_feature)
222
+ return out
223
+
224
+ @staticmethod
225
+ def _resolve_dims(
226
+ depth_multiplier: Optional[int],
227
+ n_filters_time: Optional[int],
228
+ emb_size: Optional[int],
229
+ ) -> tuple[int, int, int]:
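        # Invariant: depth_multiplier * n_filters_time == emb_size. Any one of
        # the three may be left as None and is inferred from the other two.
        # With the constructor defaults (emb_size=40, depth_multiplier=2,
        # n_filters_time=None), n_filters_time is inferred as 40 // 2 = 20.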
        # Basic type/positivity checks for provided values
        for name, val in (
            ("depth_multiplier", depth_multiplier),
            ("n_filters_time", n_filters_time),
            ("emb_size", emb_size),
        ):
            if val is not None:
                if not isinstance(val, int):
                    raise TypeError(f"{name} must be int, got {type(val).__name__}")
                if val <= 0:
                    raise ValueError(f"{name} must be > 0, got {val}")

        missing = [
            k
            for k, v in {
                "depth_multiplier": depth_multiplier,
                "n_filters_time": n_filters_time,
                "emb_size": emb_size,
            }.items()
            if v is None
        ]

        if len(missing) >= 2:
            # Too many unknowns → ambiguous
            raise ValueError(
                "Specify at least two of {depth_multiplier, n_filters_time, "
                "emb_size}; a single missing value will be inferred."
            )

        if len(missing) == 1:
            # Infer the missing one
            if missing[0] == "emb_size":
                assert depth_multiplier is not None and n_filters_time is not None
                emb_size = depth_multiplier * n_filters_time
            elif missing[0] == "n_filters_time":
                assert emb_size is not None and depth_multiplier is not None
                if emb_size % depth_multiplier != 0:
                    raise ValueError(
                        f"emb_size={emb_size} must be divisible by "
                        f"depth_multiplier={depth_multiplier}"
                    )
                n_filters_time = emb_size // depth_multiplier
            else:  # missing depth_multiplier
                assert emb_size is not None and n_filters_time is not None
                if emb_size % n_filters_time != 0:
                    raise ValueError(
                        f"emb_size={emb_size} must be divisible by "
                        f"n_filters_time={n_filters_time}"
                    )
                depth_multiplier = emb_size // n_filters_time

        else:
            # All provided: enforce consistency
            assert (
                depth_multiplier is not None
                and n_filters_time is not None
                and emb_size is not None
            )
            prod = depth_multiplier * n_filters_time
            if prod != emb_size:
                raise ValueError(
                    "`depth_multiplier * n_filters_time` must equal `emb_size`, "
                    f"but got {depth_multiplier} * {n_filters_time} = {prod} != {emb_size}. "
                    "Fix by setting one of: "
                    f"emb_size={prod}, "
                    f"n_filters_time={emb_size // depth_multiplier if emb_size % depth_multiplier == 0 else 'not integer'}, "
                    f"depth_multiplier={emb_size // n_filters_time if emb_size % n_filters_time == 0 else 'not integer'}."
                )

        # Ensure plain ints for the return type
        assert (
            depth_multiplier is not None
            and n_filters_time is not None
            and emb_size is not None
        )
        return depth_multiplier, n_filters_time, emb_size


class _PatchEmbeddingEEGNet(nn.Module):
    def __init__(
        self,
        n_filters_time: int = 16,
        kernel_size: int = 64,
        depth_multiplier: int = 2,
        pool_size_1: int = 8,
        pool_size_2: int = 8,
        drop_prob: float = 0.3,
        n_chans: int = 22,
        activation: type[nn.Module] = nn.ELU,
    ):
        super().__init__()
        n_filters_out = depth_multiplier * n_filters_time
        self.eegnet_module = nn.Sequential(
            # Temporal convolution
            nn.Conv2d(
                in_channels=1,
                out_channels=n_filters_time,
                kernel_size=(1, kernel_size),
                stride=(1, 1),
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(n_filters_time),
            # Channel depth-wise convolution
            nn.Conv2d(
                in_channels=n_filters_time,
                out_channels=n_filters_out,
                kernel_size=(n_chans, 1),
                stride=(1, 1),
                groups=n_filters_time,
                padding="valid",
                bias=False,
            ),
            nn.BatchNorm2d(n_filters_out),
            activation(),
            # First average pooling
            nn.AvgPool2d(kernel_size=(1, pool_size_1)),
            nn.Dropout(drop_prob),
            # Second temporal convolution
            nn.Conv2d(
                in_channels=n_filters_out,
                out_channels=n_filters_out,
                kernel_size=(1, 16),
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(n_filters_out),
            activation(),
            # Second average pooling
            nn.AvgPool2d(kernel_size=(1, pool_size_2)),
            nn.Dropout(drop_prob),
        )

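        # After the feature extractor the tensor has shape
        # (batch, n_filters_out, 1, ~n_times / (pool_size_1 * pool_size_2));
        # the Rearrange below turns it into a token sequence of shape
        # (batch, sequence_length, n_filters_out), where n_filters_out serves
        # as the embedding dimension.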
        self.projection = nn.Sequential(
            Rearrange("b e h w -> b (h w) e"),
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the Patch Embedding CNN.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, 1, n_channels, n_times).

        Returns
        -------
        Tensor
            Embedded patches of shape (batch_size, num_patches, embedding_dim).
        """
        x = self.eegnet_module(x)
        x = self.projection(x)
        return x


class _ResidualAdd(nn.Module):
    def __init__(self, module: nn.Module, emb_size: int, drop_p: float):
        super().__init__()
        self.module = module
        self.drop = nn.Dropout(drop_p)
        self.layernorm = nn.LayerNorm(emb_size)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass with residual connection.

        Parameters
        ----------
        x : Tensor
            Input tensor.

        Returns
        -------
        Tensor
            Output tensor after applying the residual connection.
        """
        res = self.module(x)
        out = self.layernorm(self.drop(res) + x)
        return out


class _TransformerEncoderBlock(nn.Module):
    def __init__(
        self,
        dim_feedforward: int,
        num_heads: int = 4,
        drop_prob: float = 0.5,
        forward_expansion: int = 4,
        forward_drop_p: float = 0.5,
        activation: type[nn.Module] = nn.GELU,
    ):
        super().__init__()
        self.attention = _ResidualAdd(
            nn.Sequential(
                MultiHeadAttention(dim_feedforward, num_heads, drop_prob),
            ),
            dim_feedforward,
            drop_prob,
        )
        self.feed_forward = _ResidualAdd(
            nn.Sequential(
                FeedForwardBlock(
                    dim_feedforward,
                    expansion=forward_expansion,
                    drop_p=forward_drop_p,
                    activation=activation,
                ),
            ),
            dim_feedforward,
            drop_prob,
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the transformer encoder block.

        Parameters
        ----------
        x : Tensor
            Input tensor.

        Returns
        -------
        Tensor
            Output tensor after the transformer encoder block.
        """
        x = self.attention(x)
        x = self.feed_forward(x)
        return x


class _TransformerEncoder(nn.Module):
    def __init__(
        self,
        nheads: int,
        depth: int,
        dim_feedforward: int,
        activation: type[nn.Module] = nn.GELU,
    ):
        super().__init__()
        self.layers = nn.Sequential(
            *[
                _TransformerEncoderBlock(
                    dim_feedforward=dim_feedforward,
                    num_heads=nheads,
                    activation=activation,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the transformer encoder.

        Parameters
        ----------
        x : Tensor
            Input tensor.

        Returns
        -------
        Tensor
            Output tensor after the transformer encoder.
        """
        return self.layers(x)


class _PositionalEncoding(nn.Module):
    def __init__(
        self,
        n_times: int,
        emb_size: int,
        length: int = 100,
        drop_prob: float = 0.1,
        pool_size: int = 8,
    ):
        super().__init__()
        self.pool_size = pool_size
        self.n_times = n_times

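        # The positional table is a learned parameter with `length` rows.
        # If the pooled sequence is longer than the default of 100 positions,
        # grow the table so forward() can always slice enough rows.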
513
+ if int(n_times / (pool_size * pool_size)) > length:
514
+ warn(
515
+ "the temporal dimensional is too long for this default length. "
516
+ "The length parameter will be automatically adjusted to "
517
+ "avoid inference issues."
518
+ )
519
+ length = int(n_times / (pool_size * pool_size))
520
+
521
+ self.dropout = nn.Dropout(drop_prob)
522
+ self.encoding = nn.Parameter(torch.randn(1, length, emb_size))
523
+
524
+ def forward(self, x: Tensor) -> Tensor:
525
+ """
526
+ Forward pass of the positional encoding.
527
+
528
+ Parameters
529
+ ----------
530
+ x : Tensor
531
+ Input tensor of shape (batch_size, sequence_length, embedding_dim).
532
+
533
+ Returns
534
+ -------
535
+ Tensor
536
+ Tensor with positional encoding added.
537
+ """
538
+ seq_length = x.size(1)
539
+ encoding = self.encoding[:, :seq_length, :]
540
+ x = x + encoding
541
+ return self.dropout(x)
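
For reference, a minimal usage sketch based on the constructor shown above. The example values for n_outputs, n_chans, and n_times are illustrative, and the import path assumes CTNet is re-exported from braindecode.models (see braindecode/models/__init__.py in the file list):

import torch
from braindecode.models import CTNet

# Four-class motor imagery, 22 EEG channels, 1000 samples per window
model = CTNet(n_outputs=4, n_chans=22, n_times=1000)
x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
logits = model(x)             # -> shape (8, 4)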