braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode might be problematic.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
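Given the moves in the file list above (braindecode/models/modules.py and braindecode/models/functions.py are removed, while new braindecode/modules/ and braindecode/functional/ packages are added), downstream imports will probably need updating when upgrading from 0.8.1 to 1.1.0. A hedged sketch, assuming the helpers moved with their files; the specific class and function names are illustrative and should be replaced by whatever your code actually imports:

    # Old layout (braindecode 0.8.1) -- hypothetical names, for illustration only:
    # from braindecode.models.modules import Ensure4d, Expression
    # from braindecode.models.functions import squeeze_final_output

    # New layout (braindecode 1.1.0), as suggested by the added packages:
    from braindecode.modules import Ensure4d, Expression        # assumed new location
    from braindecode.functional import squeeze_final_output     # assumed new location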
braindecode/models/ctnet.py (new file)
@@ -0,0 +1,536 @@
+ """
+ CTNet: a convolutional transformer network for EEG-based motor imagery
+ classification from Wei Zhao et al. (2024).
+ """
+
+ # Authors: Wei Zhao <zhaowei701@163.com>
+ #          Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
+ # License: MIT
+
+ from __future__ import annotations
+
+ import math
+ from typing import Optional
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from mne.utils import warn
+ from torch import Tensor, nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import (
+     FeedForwardBlock,
+     MultiHeadAttention,
+ )
+
+
+ class CTNet(EEGModuleMixin, nn.Module):
+     """CTNet from Zhao, W et al. (2024) [ctnet]_.
+
+     A Convolutional Transformer Network for EEG-Based Motor Imagery Classification.
+
+     .. figure:: https://raw.githubusercontent.com/snailpt/CTNet/main/architecture.png
+        :align: center
+        :alt: CTNet Architecture
+
+     CTNet is an end-to-end neural network architecture designed for classifying
+     motor imagery (MI) tasks from EEG signals. The model combines convolutional
+     neural networks (CNNs) with a Transformer encoder to capture both local and
+     global temporal dependencies in the EEG data.
+
+     The architecture consists of three main components:
+
+     1. **Convolutional Module**:
+        - Applies EEGNetV4-style feature extraction, denoted here as the
+          _PatchEmbeddingEEGNet module.
+
+     2. **Transformer Encoder Module**:
+        - Utilizes multi-head self-attention, as in EEGConformer, but with
+          residual blocks.
+
+     3. **Classifier Module**:
+        - Combines features from both the convolutional module
+          and the Transformer encoder.
+        - Flattens the combined features and applies dropout for regularization.
+        - Uses a fully connected layer to produce the final classification output.
+
+     Parameters
+     ----------
+     activation_patch : nn.Module, default=nn.ELU
+         Activation function used in the convolutional (patch embedding) module.
+     activation_transformer : nn.Module, default=nn.GELU
+         Activation function used in the Transformer encoder.
+     heads : int, default=4
+         Number of attention heads in the Transformer encoder.
+     emb_size : int or None, default=40
+         Embedding size (dimensionality) for the Transformer encoder.
+     depth : int, default=6
+         Number of encoder layers in the Transformer.
+     n_filters_time : int or None, default=None
+         Number of temporal filters in the first convolutional layer. If ``None``,
+         it is inferred from ``emb_size`` and ``depth_multiplier``.
+     kernel_size : int, default=64
+         Kernel size for the temporal convolutional layer.
+     depth_multiplier : int or None, default=2
+         Multiplier for the number of depth-wise convolutional filters.
+     pool_size_1 : int, default=8
+         Pooling size for the first average pooling layer.
+     pool_size_2 : int, default=8
+         Pooling size for the second average pooling layer.
+     drop_prob_cnn : float, default=0.3
+         Dropout probability after convolutional layers.
+     drop_prob_posi : float, default=0.1
+         Dropout probability for the positional encoding in the Transformer.
+     drop_prob_final : float, default=0.5
+         Dropout probability before the final classification layer.
+
+     Notes
+     -----
+     This implementation is adapted from the original CTNet source code
+     [ctnetcode]_ to comply with Braindecode's model standards.
+
+     References
+     ----------
+     .. [ctnet] Zhao, W., Jiang, X., Zhang, B., Xiao, S., & Weng, S. (2024).
+         CTNet: a convolutional transformer network for EEG-based motor imagery
+         classification. Scientific Reports, 14(1), 20237.
+     .. [ctnetcode] Zhao, W., Jiang, X., Zhang, B., Xiao, S., & Weng, S. (2024).
+         CTNet source code:
+         https://github.com/snailpt/CTNet
+     """
+
+     def __init__(
+         self,
+         # Base arguments
+         n_outputs=None,
+         n_chans=None,
+         sfreq=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         # Model specific arguments
+         activation_patch: nn.Module = nn.ELU,
+         activation_transformer: nn.Module = nn.GELU,
+         drop_prob_cnn: float = 0.3,
+         drop_prob_posi: float = 0.1,
+         drop_prob_final: float = 0.5,
+         # other parameters
+         heads: int = 4,
+         emb_size: Optional[int] = 40,
+         depth: int = 6,
+         n_filters_time: Optional[int] = None,
+         kernel_size: int = 64,
+         depth_multiplier: Optional[int] = 2,
+         pool_size_1: int = 8,
+         pool_size_2: int = 8,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.activation_patch = activation_patch
+         self.activation_transformer = activation_transformer
+         self.drop_prob_cnn = drop_prob_cnn
+         self.pool_size_1 = pool_size_1
+         self.pool_size_2 = pool_size_2
+         self.kernel_size = kernel_size
+         self.drop_prob_posi = drop_prob_posi
+         self.drop_prob_final = drop_prob_final
+         self.heads = heads
+         self.depth = depth
+         # Temporal length remaining after the two average-pooling stages
+         # (kernel size = stride = pool_size for each stage).
+         self.sequence_length = math.floor(
+             (
+                 math.floor((self.n_times - self.pool_size_1) / self.pool_size_1 + 1)
+                 - self.pool_size_2
+             )
+             / self.pool_size_2
+             + 1
+         )
+
+         self.depth_multiplier, self.n_filters_time, self.emb_size = self._resolve_dims(
+             depth_multiplier, n_filters_time, emb_size
+         )
+
+         # Layers
+         self.ensuredim = Rearrange("batch nchans time -> batch 1 nchans time")
+         self.flatten = nn.Flatten()
+
+         self.cnn = _PatchEmbeddingEEGNet(
+             n_filters_time=self.n_filters_time,
+             kernel_size=self.kernel_size,
+             depth_multiplier=self.depth_multiplier,
+             pool_size_1=self.pool_size_1,
+             pool_size_2=self.pool_size_2,
+             drop_prob=self.drop_prob_cnn,
+             n_chans=self.n_chans,
+             activation=self.activation_patch,
+         )
+
+         self.position = _PositionalEncoding(
+             emb_size=self.emb_size,
+             drop_prob=self.drop_prob_posi,
+             n_times=self.n_times,
+             pool_size=self.pool_size_1,
+         )
+
+         self.trans = _TransformerEncoder(
+             self.heads,
+             self.depth,
+             self.emb_size,
+             activation=self.activation_transformer,
+         )
+
+         self.flatten_drop_layer = nn.Sequential(
+             nn.Flatten(),
+             nn.Dropout(p=self.drop_prob_final),
+         )
+
+         self.final_layer = nn.Linear(
+             in_features=int(self.emb_size * self.sequence_length),
+             out_features=self.n_outputs,
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         """
+         Forward pass of the CTNet model.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor of shape (batch_size, n_channels, n_times).
+
+         Returns
+         -------
+         Tensor
+             Output with shape (batch_size, n_outputs).
+         """
+         x = self.ensuredim(x)
+         cnn = self.cnn(x)
+         cnn = cnn * math.sqrt(self.emb_size)
+         cnn = self.position(cnn)
+         trans = self.trans(cnn)
+         features = cnn + trans
+         flatten_feature = self.flatten(features)
+         out = self.final_layer(flatten_feature)
+         return out
+
+     @staticmethod
+     def _resolve_dims(
+         depth_multiplier: Optional[int],
+         n_filters_time: Optional[int],
+         emb_size: Optional[int],
+     ) -> tuple[int, int, int]:
+         # Basic type/positivity checks for provided values
+         for name, val in (
+             ("depth_multiplier", depth_multiplier),
+             ("n_filters_time", n_filters_time),
+             ("emb_size", emb_size),
+         ):
+             if val is not None:
+                 if not isinstance(val, int):
+                     raise TypeError(f"{name} must be int, got {type(val).__name__}")
+                 if val <= 0:
+                     raise ValueError(f"{name} must be > 0, got {val}")
+
+         missing = [
+             k
+             for k, v in {
+                 "depth_multiplier": depth_multiplier,
+                 "n_filters_time": n_filters_time,
+                 "emb_size": emb_size,
+             }.items()
+             if v is None
+         ]
+
+         if len(missing) >= 2:
+             # Too many unknowns → ambiguous
+             raise ValueError(
+                 "Specify exactly two of {depth_multiplier, n_filters_time, emb_size}; the third will be inferred."
+             )
+
+         if len(missing) == 1:
+             # Infer the missing one
+             if missing[0] == "emb_size":
+                 assert depth_multiplier is not None and n_filters_time is not None
+                 emb_size = depth_multiplier * n_filters_time
+             elif missing[0] == "n_filters_time":
+                 assert emb_size is not None and depth_multiplier is not None
+                 if emb_size % depth_multiplier != 0:
+                     raise ValueError(
+                         f"emb_size={emb_size} must be divisible by depth_multiplier={depth_multiplier}"
+                     )
+                 n_filters_time = emb_size // depth_multiplier
+             else:  # missing depth_multiplier
+                 assert emb_size is not None and n_filters_time is not None
+                 if emb_size % n_filters_time != 0:
+                     raise ValueError(
+                         f"emb_size={emb_size} must be divisible by n_filters_time={n_filters_time}"
+                     )
+                 depth_multiplier = emb_size // n_filters_time
+
+         else:
+             # All provided: enforce consistency
+             assert (
+                 depth_multiplier is not None
+                 and n_filters_time is not None
+                 and emb_size is not None
+             )
+             prod = depth_multiplier * n_filters_time
+             if prod != emb_size:
+                 raise ValueError(
+                     "`depth_multiplier * n_filters_time` must equal `emb_size`, "
+                     f"but got {depth_multiplier} * {n_filters_time} = {prod} != {emb_size}. "
+                     "Fix by setting one of: "
+                     f"emb_size={prod}, "
+                     f"n_filters_time={emb_size // depth_multiplier if emb_size % depth_multiplier == 0 else 'not integer'}, "
+                     f"depth_multiplier={emb_size // n_filters_time if emb_size % n_filters_time == 0 else 'not integer'}."
+                 )
+
+         # Ensure plain ints for the return type
+         assert (
+             depth_multiplier is not None
+             and n_filters_time is not None
+             and emb_size is not None
+         )
+         return depth_multiplier, n_filters_time, emb_size
+
+
+ class _PatchEmbeddingEEGNet(nn.Module):
+     def __init__(
+         self,
+         n_filters_time: int = 16,
+         kernel_size: int = 64,
+         depth_multiplier: int = 2,
+         pool_size_1: int = 8,
+         pool_size_2: int = 8,
+         drop_prob: float = 0.3,
+         n_chans: int = 22,
+         activation: nn.Module = nn.ELU,
+     ):
+         super().__init__()
+         n_filters_out = depth_multiplier * n_filters_time
+         self.eegnet_module = nn.Sequential(
+             # Temporal convolution
+             nn.Conv2d(
+                 in_channels=1,
+                 out_channels=n_filters_time,
+                 kernel_size=(1, kernel_size),
+                 stride=(1, 1),
+                 padding="same",
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_filters_time),
+             # Channel depth-wise convolution
+             nn.Conv2d(
+                 in_channels=n_filters_time,
+                 out_channels=n_filters_out,
+                 kernel_size=(n_chans, 1),
+                 stride=(1, 1),
+                 groups=n_filters_time,
+                 padding="valid",
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_filters_out),
+             activation(),
+             # First average pooling
+             nn.AvgPool2d(kernel_size=(1, pool_size_1)),
+             nn.Dropout(drop_prob),
+             # Second temporal convolution over the pooled features
+             nn.Conv2d(
+                 in_channels=n_filters_out,
+                 out_channels=n_filters_out,
+                 kernel_size=(1, 16),
+                 padding="same",
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_filters_out),
+             activation(),
+             # Second average pooling
+             nn.AvgPool2d(kernel_size=(1, pool_size_2)),
+             nn.Dropout(drop_prob),
+         )
+
+         self.projection = nn.Sequential(
+             Rearrange("b e h w -> b (h w) e"),
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         """
+         Forward pass of the Patch Embedding CNN.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor of shape (batch_size, 1, n_channels, n_times).
+
+         Returns
+         -------
+         Tensor
+             Embedded patches of shape (batch_size, num_patches, embedding_dim).
+         """
+         x = self.eegnet_module(x)
+         x = self.projection(x)
+         return x
+
+
+ class _ResidualAdd(nn.Module):
+     def __init__(self, module: nn.Module, emb_size: int, drop_p: float):
+         super().__init__()
+         self.module = module
+         self.drop = nn.Dropout(drop_p)
+         self.layernorm = nn.LayerNorm(emb_size)
+
+     def forward(self, x: Tensor) -> Tensor:
+         """
+         Forward pass with residual connection.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor.
+
+         Returns
+         -------
+         Tensor
+             Output tensor after applying the residual connection.
+         """
+         res = self.module(x)
+         out = self.layernorm(self.drop(res) + x)
+         return out
+
+
+ class _TransformerEncoderBlock(nn.Module):
+     def __init__(
+         self,
+         dim_feedforward: int,
+         num_heads: int = 4,
+         drop_prob: float = 0.5,
+         forward_expansion: int = 4,
+         forward_drop_p: float = 0.5,
+         activation: nn.Module = nn.GELU,
+     ):
+         super().__init__()
+         self.attention = _ResidualAdd(
+             nn.Sequential(
+                 MultiHeadAttention(dim_feedforward, num_heads, drop_prob),
+             ),
+             dim_feedforward,
+             drop_prob,
+         )
+         self.feed_forward = _ResidualAdd(
+             nn.Sequential(
+                 FeedForwardBlock(
+                     dim_feedforward,
+                     expansion=forward_expansion,
+                     drop_p=forward_drop_p,
+                     activation=activation,
+                 ),
+             ),
+             dim_feedforward,
+             drop_prob,
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         """
+         Forward pass of the transformer encoder block.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor.
+
+         Returns
+         -------
+         Tensor
+             Output tensor after the transformer encoder block.
+         """
+         x = self.attention(x)
+         x = self.feed_forward(x)
+         return x
+
+
+ class _TransformerEncoder(nn.Module):
+     def __init__(
+         self,
+         nheads: int,
+         depth: int,
+         dim_feedforward: int,
+         activation: nn.Module = nn.GELU,
+     ):
+         super().__init__()
+         self.layers = nn.Sequential(
+             *[
+                 _TransformerEncoderBlock(
+                     dim_feedforward=dim_feedforward,
+                     num_heads=nheads,
+                     activation=activation,
+                 )
+                 for _ in range(depth)
+             ]
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         """
+         Forward pass of the transformer encoder.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor.
+
+         Returns
+         -------
+         Tensor
+             Output tensor after the transformer encoder.
+         """
+         return self.layers(x)
+
+
+ class _PositionalEncoding(nn.Module):
+     def __init__(
+         self,
+         n_times: int,
+         emb_size: int,
+         length: int = 100,
+         drop_prob: float = 0.1,
+         pool_size: int = 8,
+     ):
+         super().__init__()
+         self.pool_size = pool_size
+         self.n_times = n_times
+
+         if int(n_times / (pool_size * pool_size)) > length:
+             warn(
+                 "The temporal dimension is too long for the default positional "
+                 "encoding length. The length parameter will be adjusted "
+                 "automatically to avoid inference issues."
+             )
+             length = int(n_times / (pool_size * pool_size))
+
+         self.dropout = nn.Dropout(drop_prob)
+         self.encoding = nn.Parameter(torch.randn(1, length, emb_size))
+
+     def forward(self, x: Tensor) -> Tensor:
+         """
+         Forward pass of the positional encoding.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor of shape (batch_size, sequence_length, embedding_dim).
+
+         Returns
+         -------
+         Tensor
+             Tensor with positional encoding added.
+         """
+         seq_length = x.size(1)
+         encoding = self.encoding[:, :seq_length, :]
+         x = x + encoding
+         return self.dropout(x)
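
A minimal usage sketch of the model defined above (not part of the diff). The instantiation values are illustrative, and the re-export of CTNet from braindecode.models is assumed; emb_size must equal depth_multiplier * n_filters_time, which _resolve_dims enforces or infers:

    import torch
    from braindecode.models import CTNet  # assumed export location in braindecode 1.1.0

    # emb_size=40 with depth_multiplier=2 lets _resolve_dims infer n_filters_time=20.
    model = CTNet(n_outputs=4, n_chans=22, n_times=1000, emb_size=40, depth_multiplier=2)

    x = torch.randn(8, 22, 1000)  # (batch_size, n_channels, n_times)
    out = model(x)                # shape: (8, 4)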