braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/eegsimpleconv.py
@@ -0,0 +1,199 @@
+ """
+ EEG-SimpleConv is a 1D Convolutional Neural Network from Yassine El Ouahidi et al. (2023).
+
+ Originally designed for Motor Imagery decoding from EEG signals.
+ The model offers competitive performance with low latency and is mainly
+ composed of 1D convolutional layers.
+ """
+
+ # Authors: Yassine El Ouahidi <eloua.yas@gmail.com>
+ #
+ # License: BSD-3
+
+ import torch
+ from torch import nn
+ from torchaudio.transforms import Resample
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class EEGSimpleConv(EEGModuleMixin, torch.nn.Module):
+     """EEGSimpleConv from Ouahidi, YE et al. (2023) [Yassine2023]_.
+
+     .. figure:: https://raw.githubusercontent.com/elouayas/EEGSimpleConv/refs/heads/main/architecture.png
+        :align: center
+        :alt: EEGSimpleConv Architecture
+
+     EEGSimpleConv is a 1D Convolutional Neural Network originally designed
+     for decoding motor imagery from EEG signals. The model aims to have a
+     very simple and straightforward architecture that allows a low latency,
+     while still achieving very competitive performance.
+
+     EEG-SimpleConv starts with a 1D convolutional layer, where each EEG channel
+     enters a separate 1D convolutional channel. This is followed by a series of
+     blocks of two 1D convolutional layers. Between the two convolutional layers
+     of each block is a max pooling layer, which downsamples the data by a factor
+     of 2. Each convolution is followed by a batch normalisation layer and a ReLU
+     activation function. Finally, a global average pooling (in the time domain)
+     is performed to obtain a single value per feature map, which is then fed
+     into a linear layer to obtain the final classification prediction output.
+
+     The paper and original code, with more details about the methodological
+     choices, are available at [Yassine2023]_ and [Yassine2023Code]_.
+
+     The input should be a three-dimensional tensor representing the EEG
+     signals, of shape ``(batch_size, n_channels, n_timesteps)``.
+
+     Notes
+     -----
+     The authors recommend using the default parameters for MI decoding.
+     Please refer to the original paper and code for more details.
+
+     Recommended ranges for the hyperparameters, depending on the
+     evaluation paradigm:
+
+     | Parameter        | Within-Subject | Cross-Subject |
+     | feature_maps     | [64-144]       | [64-144]      |
+     | n_convs          | 1              | [2-4]         |
+     | resampling_freq  | [70-100]       | [50-80]       |
+     | kernel_size      | [12-17]        | [5-8]         |
+
+     An extensive ablation study is included in the paper to understand the
+     impact of each parameter on the model performance.
+
+     .. versionadded:: 0.9
+
+     Parameters
+     ----------
+     feature_maps : int
+         Number of feature maps at the first convolution; the width of the model.
+     n_convs : int
+         Number of blocks of convolutions (2 convolutions per block); the depth
+         of the model.
+     resampling_freq : int
+         Resampling frequency.
+     kernel_size : int
+         Size of the convolution kernels.
+     return_feature : bool
+         If True, the forward pass returns the pooled feature vector instead
+         of the classification output.
+     activation : nn.Module, default=nn.ReLU
+         Activation function class to apply in the convolutional blocks. Should
+         be a PyTorch activation module class like ``nn.ReLU`` or ``nn.ELU``.
+         Default is ``nn.ReLU``.
+
+     References
+     ----------
+     .. [Yassine2023] Yassine El Ouahidi, V. Gripon, B. Pasdeloup, G. Bouallegue,
+        N. Farrugia, G. Lioi, 2023. A Strong and Simple Deep Learning Baseline for
+        BCI Motor Imagery Decoding. arXiv preprint. arxiv.org/abs/2309.07159
+     .. [Yassine2023Code] Yassine El Ouahidi, V. Gripon, B. Pasdeloup, G. Bouallegue,
+        N. Farrugia, G. Lioi, 2023. A Strong and Simple Deep Learning Baseline for
+        BCI Motor Imagery Decoding. GitHub repository.
+        https://github.com/elouayas/EEGSimpleConv.
+     """
+
+     def __init__(
+         self,
+         # Base arguments
+         n_outputs=None,
+         n_chans=None,
+         sfreq=None,
+         # Model-specific arguments
+         feature_maps=128,
+         n_convs=2,
+         resampling_freq=80,
+         kernel_size=8,
+         return_feature=False,
+         activation: nn.Module = nn.ReLU,
+         # Other ways to initialize the model
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds
+
+         self.return_feature = return_feature
+         self.resample = (
+             Resample(orig_freq=int(self.sfreq), new_freq=int(resampling_freq))
+             if self.sfreq != resampling_freq
+             else torch.nn.Identity()
+         )
+
+         self.conv = torch.nn.Conv1d(
+             self.n_chans,
+             feature_maps,
+             kernel_size=kernel_size,
+             padding=kernel_size // 2,
+             bias=False,
+         )
+         self.bn = torch.nn.BatchNorm1d(feature_maps)
+         self.blocks = []
+         new_feature_maps = feature_maps
+         old_feature_maps = feature_maps
+         for i in range(n_convs):
+             if i > 0:
+                 # 1.414 ~ sqrt(2): widening by sqrt(2) each time the temporal
+                 # resolution is halved keeps the FLOPs roughly constant.
+                 new_feature_maps = int(1.414 * new_feature_maps)
+             self.blocks.append(
+                 torch.nn.Sequential(
+                     torch.nn.Conv1d(
+                         old_feature_maps,
+                         new_feature_maps,
+                         kernel_size=kernel_size,
+                         padding=kernel_size // 2,
+                         bias=False,
+                     ),
+                     torch.nn.BatchNorm1d(new_feature_maps),
+                     # The original guard (i > 0 - 1) was always true, so every
+                     # block downsamples time by a factor of 2.
+                     torch.nn.MaxPool1d(2),
+                     activation(),
+                     torch.nn.Conv1d(
+                         new_feature_maps,
+                         new_feature_maps,
+                         kernel_size=kernel_size,
+                         padding=kernel_size // 2,
+                         bias=False,
+                     ),
+                     torch.nn.BatchNorm1d(new_feature_maps),
+                     activation(),
+                 )
+             )
+             old_feature_maps = new_feature_maps
+         self.blocks = torch.nn.ModuleList(self.blocks)
+         self.final_layer = torch.nn.Linear(old_feature_maps, self.n_outputs)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (batch_size, n_channels, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor of shape (batch_size, n_outputs), or the pooled
+             feature tensor if ``return_feature`` is True.
+         """
+         x_rs = self.resample(x.contiguous())
+         feat = torch.relu(self.bn(self.conv(x_rs)))
+         for seq in self.blocks:
+             feat = seq(feat)
+         # Global average pooling over the time dimension
+         feat = feat.mean(dim=2)
+         if self.return_feature:
+             return feat
+         return self.final_layer(feat)
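
Usage note (not part of the diff): a minimal sketch of how the new EEGSimpleConv model can be instantiated and shape-checked, based on the constructor shown above. The import path follows the new file braindecode/models/eegsimpleconv.py; the channel count, sampling rate, class count, and window length below are arbitrary example values.

    import torch

    from braindecode.models.eegsimpleconv import EEGSimpleConv

    # 22-channel EEG at 250 Hz, four output classes (example values only).
    model = EEGSimpleConv(
        n_outputs=4,
        n_chans=22,
        sfreq=250,
        feature_maps=128,    # width; the second block widens to int(1.414 * 128) = 180 maps
        n_convs=2,           # two blocks of two 1D convolutions each
        resampling_freq=80,  # input is resampled from 250 Hz to 80 Hz internally
        kernel_size=8,
    )
    x = torch.randn(16, 22, 1000)  # (batch_size, n_channels, n_timesteps)
    y = model(x)
    print(y.shape)  # expected: torch.Size([16, 4])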
braindecode/models/eegtcnet.py
@@ -0,0 +1,335 @@
+ # Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+ from einops.layers.torch import Rearrange
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import Chomp1d, MaxNormLinear
+
+
+ class EEGTCNet(EEGModuleMixin, nn.Module):
+     """EEGTCNet model from Ingolfsson et al. (2020) [ingolfsson2020]_.
+
+     .. figure:: https://braindecode.org/dev/_static/model/eegtcnet.jpg
+        :align: center
+        :alt: EEGTCNet Architecture
+
+     EEGTCNet combines an EEGNet-style convolutional feature extractor with a
+     temporal convolutional network (TCN) block.
+
+     Parameters
+     ----------
+     activation : nn.Module, optional
+         Activation function class to use. Default is ``nn.ELU``.
+     depth_multiplier : int, optional
+         Depth multiplier for the depthwise convolution. Default is 2.
+     filter_1 : int, optional
+         Number of temporal filters in the first convolutional layer. Default is 8.
+     kern_length : int, optional
+         Length of the temporal kernel in the first convolutional layer. Default is 64.
+     drop_prob : float, optional
+         Dropout probability. Default is 0.5.
+     depth : int, optional
+         Number of residual blocks in the TCN. Default is 2.
+     kernel_size : int, optional
+         Size of the temporal convolutional kernel in the TCN. Default is 4.
+     filters : int, optional
+         Number of filters in the TCN convolutional layers. Default is 12.
+     max_norm_const : float
+         Maximum L2-norm constraint imposed on weights of the last
+         fully-connected layer. Defaults to 0.25.
+
+     References
+     ----------
+     .. [ingolfsson2020] Ingolfsson, T. M., Hersche, M., Wang, X., Kobayashi, N.,
+        Cavigelli, L., & Benini, L. (2020). EEG-TCNet: An accurate temporal
+        convolutional network for embedded motor-imagery brain–machine interfaces.
+        https://doi.org/10.48550/arXiv.2006.00622
+     """
+
+     def __init__(
+         self,
+         # Signal-related parameters
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # Model parameters
+         activation: nn.Module = nn.ELU,
+         depth_multiplier: int = 2,
+         filter_1: int = 8,
+         kern_length: int = 64,
+         drop_prob: float = 0.5,
+         depth: int = 2,
+         kernel_size: int = 4,
+         filters: int = 12,
+         max_norm_const: float = 0.25,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.activation = activation
+         self.drop_prob = drop_prob
+         self.depth_multiplier = depth_multiplier
+         self.filter_1 = filter_1
+         self.kern_length = kern_length
+         self.depth = depth
+         self.kernel_size = kernel_size
+         self.filters = filters
+         self.max_norm_const = max_norm_const
+         self.filter_2 = self.filter_1 * self.depth_multiplier
+
+         self.arrange_dim_input = Rearrange(
+             "batch nchans ntimes -> batch 1 ntimes nchans"
+         )
+         # EEGNet_TC block
+         self.eegnet_tc = _EEGNetTC(
+             n_chans=self.n_chans,
+             filter_1=self.filter_1,
+             kern_length=self.kern_length,
+             depth_multiplier=self.depth_multiplier,
+             drop_prob=self.drop_prob,
+             activation=self.activation,
+         )
+         self.arrange_dim_eegnet = Rearrange(
+             "batch filter2 rtimes 1 -> batch rtimes filter2"
+         )
+
+         # TCN block
+         self.tcn_block = _TCNBlock(
+             input_dimension=self.filter_2,
+             depth=self.depth,
+             kernel_size=self.kernel_size,
+             filters=self.filters,
+             drop_prob=self.drop_prob,
+             activation=self.activation,
+         )
+
+         # Classification block
+         self.final_layer = MaxNormLinear(
+             in_features=self.filters,
+             out_features=self.n_outputs,
+             max_norm_val=self.max_norm_const,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the EEGTCNet model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (batch_size, n_chans, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor of shape (batch_size, n_outputs).
+         """
+         # x shape: (batch_size, n_chans, n_times)
+         x = self.arrange_dim_input(x)  # (batch_size, 1, n_times, n_chans)
+         x = self.eegnet_tc(x)  # (batch_size, filter_2, reduced_time, 1)
+
+         x = self.arrange_dim_eegnet(x)  # (batch_size, reduced_time, filter_2)
+         x = self.tcn_block(x)  # (batch_size, time_steps, filters)
+
+         # Select the last time step
+         x = x[:, -1, :]  # (batch_size, filters)
+
+         x = self.final_layer(x)  # (batch_size, n_outputs)
+
+         return x
+
+
+ class _EEGNetTC(nn.Module):
+     """EEGNet-style feature extractor used inside EEGTCNet.
+
+     The main difference from our EEGNetV4 (braindecode) implementation is the
+     kernel and dimension ordering. Because of this, we decided to keep this
+     separate implementation for now; in a future issue we will re-evaluate
+     whether it is necessary to maintain it.
+
+     Parameters
+     ----------
+     n_chans : int
+         Number of EEG channels.
+     filter_1 : int
+         Number of temporal filters in the first convolutional layer.
+     kern_length : int
+         Length of the temporal kernel in the first convolutional layer.
+     depth_multiplier : int
+         Depth multiplier for the depthwise convolution.
+     drop_prob : float
+         Dropout probability.
+     activation : nn.Module
+         Activation function class.
+     """
+
+     def __init__(
+         self,
+         n_chans: int,
+         filter_1: int = 8,
+         kern_length: int = 64,
+         depth_multiplier: int = 2,
+         drop_prob: float = 0.5,
+         activation: nn.Module = nn.ELU,
+     ):
+         super().__init__()
+         self.activation = activation()
+         self.drop_prob = drop_prob
+         self.n_chans = n_chans
+         self.filter_1 = filter_1
+         self.filter_2 = self.filter_1 * depth_multiplier
+
+         # First Conv2D layer
+         self.conv1 = nn.Conv2d(
+             in_channels=1,
+             out_channels=self.filter_1,
+             kernel_size=(kern_length, 1),
+             padding=(kern_length // 2, 0),
+             bias=False,
+         )
+         self.bn1 = nn.BatchNorm2d(self.filter_1)
+
+         # Depthwise convolution
+         self.depthwise_conv = nn.Conv2d(
+             in_channels=self.filter_1,
+             out_channels=self.filter_2,
+             kernel_size=(1, n_chans),
+             groups=self.filter_1,
+             bias=False,
+         )
+         self.bn2 = nn.BatchNorm2d(self.filter_2)
+         self.pool1 = nn.AvgPool2d(kernel_size=(8, 1))
+         self.drop1 = nn.Dropout(p=drop_prob)
+
+         # Separable convolution (depthwise + pointwise)
+         self.separable_conv_depthwise = nn.Conv2d(
+             in_channels=self.filter_2,
+             out_channels=self.filter_2,
+             kernel_size=(self.filter_2, 1),
+             groups=self.filter_2,
+             padding=(self.filter_2 // 2, 0),
+             bias=False,
+         )
+         self.separable_conv_pointwise = nn.Conv2d(
+             in_channels=self.filter_2,
+             out_channels=self.filter_2,
+             kernel_size=(1, 1),
+             bias=False,
+         )
+         self.bn3 = nn.BatchNorm2d(self.filter_2)
+         self.pool2 = nn.AvgPool2d(kernel_size=(self.filter_1, 1))
+         self.drop2 = nn.Dropout(p=drop_prob)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x shape: (batch_size, 1, n_times, n_chans)
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.activation(x)
+
+         x = self.depthwise_conv(x)
+         x = self.bn2(x)
+         x = self.activation(x)
+         x = self.pool1(x)
+         x = self.drop1(x)
+
+         x = self.separable_conv_depthwise(x)
+         x = self.separable_conv_pointwise(x)
+         x = self.bn3(x)
+         x = self.activation(x)
+         x = self.pool2(x)
+         x = self.drop2(x)
+
+         return x  # shape: (batch_size, filter_2, reduced_time, 1)
+
+
+ class _TCNBlock(nn.Module):
+     """Temporal convolutional network (TCN) block used inside EEGTCNet.
+
+     This block differs in many ways from our TemporalBlock (braindecode)
+     implementation. Because of this, we decided to keep this separate
+     implementation for now; in a future issue we will re-evaluate whether
+     it is necessary to maintain it.
+     """
+
+     def __init__(
+         self,
+         input_dimension: int,
+         depth: int,
+         kernel_size: int,
+         filters: int,
+         drop_prob: float,
+         activation: nn.Module = nn.ELU,
+     ):
+         super().__init__()
+         self.activation = activation()
+         self.drop_prob = drop_prob
+         self.depth = depth
+         self.filters = filters
+         self.kernel_size = kernel_size
+
+         self.layers = nn.ModuleList()
+         self.downsample = (
+             nn.Conv1d(input_dimension, filters, kernel_size=1, bias=False)
+             if input_dimension != filters
+             else None
+         )
+
+         for i in range(depth):
+             dilation = 2**i
+             # Pad by (kernel_size - 1) * dilation so that, after Chomp1d trims
+             # the right side, the convolution stays causal and length-preserving.
+             padding = (kernel_size - 1) * dilation
+             conv_block = nn.Sequential(
+                 nn.Conv1d(
+                     in_channels=input_dimension if i == 0 else filters,
+                     out_channels=filters,
+                     kernel_size=kernel_size,
+                     dilation=dilation,
+                     padding=padding,
+                     bias=False,
+                 ),
+                 Chomp1d(padding),
+                 self.activation,
+                 nn.Dropout(self.drop_prob),
+                 nn.Conv1d(
+                     in_channels=filters,
+                     out_channels=filters,
+                     kernel_size=kernel_size,
+                     dilation=dilation,
+                     padding=padding,
+                     bias=False,
+                 ),
+                 Chomp1d(padding),
+                 self.activation,
+                 nn.Dropout(self.drop_prob),
+             )
+             self.layers.append(conv_block)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x shape: (batch_size, time_steps, input_dimension)
+         x = x.permute(0, 2, 1)  # (batch_size, input_dimension, time_steps)
+
+         res = x if self.downsample is None else self.downsample(x)
+         for layer in self.layers:
+             out = layer(x)
+             out = out + res
+             out = self.activation(out)
+             res = out  # update residual
+             x = out  # update input for the next layer
+
+         out = out.permute(0, 2, 1)  # (batch_size, time_steps, filters)
+         return out
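
Usage note (not part of the diff): a minimal sketch of the new EEGTCNet model, again with arbitrary example sizes. With the defaults depth=2 and kernel_size=4, each residual block applies two dilated convolutions (dilations 1 and 2), so the TCN sees a receptive field of 1 + 2*(kernel_size - 1)*(2**depth - 1) = 19 steps of the EEGNet feature sequence, and the classifier reads only the last time step.

    import torch

    from braindecode.models.eegtcnet import EEGTCNet

    # 22-channel EEG, 1000 samples per window, four classes (example values only).
    model = EEGTCNet(
        n_chans=22,
        n_outputs=4,
        n_times=1000,
    )
    x = torch.randn(8, 22, 1000)  # (batch_size, n_chans, n_times)
    y = model(x)
    print(y.shape)  # expected: torch.Size([8, 4])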