braindecode-1.3.0.dev177069446-py3-none-any.whl

Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/eeginception_mi.py
@@ -0,0 +1,379 @@
+ # Authors: Cedric Rommel <cedric.rommel@inria.fr>
+ #
+ # License: BSD (3-clause)
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import Ensure4d
+
+
+ class EEGInceptionMI(EEGModuleMixin, nn.Module):
+     r"""EEG Inception for Motor Imagery, as proposed in Zhang et al. (2021) [1]_
+
+     :bdg-success:`Convolution`
+
+     .. figure:: https://content.cld.iop.org/journals/1741-2552/18/4/046014/revision3/jneabed81f1_hr.jpg
+         :align: center
+         :alt: EEGInceptionMI Architecture
+
+     The model is strongly based on the original InceptionNet for computer
+     vision. The main goal is to extract features in parallel at different
+     scales. The network has two blocks made of 3 inception modules with a
+     skip connection.
+
+     The model is fully described in [1]_.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors; it was only reimplemented based on the
+     paper [1]_.
+
+     Parameters
+     ----------
+     input_window_seconds : float, optional
+         Size of the input, in seconds. Set to 4.5 s as in [1]_ for the
+         BCI IV 2a dataset.
+     sfreq : float, optional
+         EEG sampling frequency in Hz. Defaults to 250 Hz as in [1]_ for the
+         BCI IV 2a dataset.
+     n_convs : int, optional
+         Number of parallel convolutional layers per inception module.
+         Defaults to 5 as in [1]_ for the BCI IV 2a dataset.
+     n_filters : int, optional
+         Number of convolutional filters for all layers of this type. Set to
+         48 as in [1]_ for the BCI IV 2a dataset.
+     kernel_unit_s : float, optional
+         Size in seconds of the basic 1D convolutional kernel used in
+         inception modules. Each convolutional layer in such modules has
+         kernels of increasing size, odd multiples of this value (e.g. 0.1,
+         0.3, 0.5, 0.7, 0.9 here for ``n_convs=5``). Defaults to 0.1 s.
+     activation : nn.Module class, optional
+         Activation function class. Default is ``nn.ReLU``.
+
+     References
+     ----------
+     .. [1] Zhang, C., Kim, Y. K., & Eskandarian, A. (2021).
+         EEG-inception: an accurate and robust end-to-end neural network
+         for EEG-based motor imagery classification.
+         Journal of Neural Engineering, 18(4), 046014.
+     """
+
+     def __init__(
+         self,
+         n_chans=None,
+         n_outputs=None,
+         input_window_seconds=None,
+         sfreq=250,
+         n_convs: int = 5,
+         n_filters: int = 48,
+         kernel_unit_s: float = 0.1,
+         activation: type[nn.Module] = nn.ReLU,
+         chs_info=None,
+         n_times=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.n_convs = n_convs
+         self.n_filters = n_filters
+         self.kernel_unit_s = kernel_unit_s
+         self.activation = activation
+
+         self.ensuredims = Ensure4d()
+         self.dimshuffle = Rearrange("batch C T 1 -> batch C 1 T")
+
+         self.mapping = {
+             "fc.weight": "final_layer.fc.weight",
+             "fc.bias": "final_layer.fc.bias",
+         }
+
+         # ======== Inception branches ========================
+
+         self.initial_inception_module = _InceptionModuleMI(
+             in_channels=self.n_chans,
+             n_filters=self.n_filters,
+             n_convs=self.n_convs,
+             kernel_unit_s=self.kernel_unit_s,
+             sfreq=self.sfreq,
+             activation=self.activation,
+         )
+
+         intermediate_in_channels = (self.n_convs + 1) * self.n_filters
+
+         self.intermediate_inception_modules_1 = nn.ModuleList(
+             [
+                 _InceptionModuleMI(
+                     in_channels=intermediate_in_channels,
+                     n_filters=self.n_filters,
+                     n_convs=self.n_convs,
+                     kernel_unit_s=self.kernel_unit_s,
+                     sfreq=self.sfreq,
+                     activation=self.activation,
+                 )
+                 for _ in range(2)
+             ]
+         )
+
+         self.residual_block_1 = _ResidualModuleMI(
+             in_channels=self.n_chans,
+             n_filters=intermediate_in_channels,
+             activation=self.activation,
+         )
+
+         self.intermediate_inception_modules_2 = nn.ModuleList(
+             [
+                 _InceptionModuleMI(
+                     in_channels=intermediate_in_channels,
+                     n_filters=self.n_filters,
+                     n_convs=self.n_convs,
+                     kernel_unit_s=self.kernel_unit_s,
+                     sfreq=self.sfreq,
+                     activation=self.activation,
+                 )
+                 for _ in range(3)
+             ]
+         )
+
+         self.residual_block_2 = _ResidualModuleMI(
+             in_channels=intermediate_in_channels,
+             n_filters=intermediate_in_channels,
+             activation=self.activation,
+         )
+
+         # XXX The paper mentions a final average pooling but does not
+         # indicate the kernel size... The only info available is Figure 1,
+         # showing a final AveragePooling layer, and Table 3, indicating that
+         # the spatial and channel dimensions are unchanged by this layer...
+         # This could indicate a stride of 1, as for the MaxPooling layers.
+         # However, when we look at the number of parameters of the linear
+         # layer following the average pooling, we see a small number of
+         # parameters, potentially indicating that the whole time dimension
+         # is averaged at this stage for each channel. We follow this last
+         # hypothesis here to comply with the number of parameters reported
+         # in the paper.
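+         # For example, with the defaults n_convs=5 and n_filters=48, the
+         # concatenated inception output has (5 + 1) * 48 = 288 channels per
+         # time step, so collapsing the whole time dimension leaves a linear
+         # layer with only 288 * n_outputs + n_outputs parameters.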
+         self.ave_pooling = nn.AvgPool2d(
+             kernel_size=(1, self.n_times),
+         )
+
+         self.flat = nn.Flatten()
+
+         module = nn.Sequential()
+         module.add_module(
+             "fc",
+             nn.Linear(
+                 in_features=intermediate_in_channels,
+                 out_features=self.n_outputs,
+                 bias=True,
+             ),
+         )
+         module.add_module("out_fun", nn.Identity())
+         self.final_layer = module
+
+     def forward(
+         self,
+         X: torch.Tensor,
+     ) -> torch.Tensor:
+         X = self.ensuredims(X)
+         X = self.dimshuffle(X)
+
+         res1 = self.residual_block_1(X)
+
+         out = self.initial_inception_module(X)
+         for layer in self.intermediate_inception_modules_1:
+             out = layer(out)
+
+         out = out + res1
+
+         res2 = self.residual_block_2(out)
+
+         for layer in self.intermediate_inception_modules_2:
+             out = layer(out)
+
+         out = res2 + out
+
+         out = self.ave_pooling(out)
+         out = self.flat(out)
+         out = self.final_layer(out)
+         return out
+
+
+ class _InceptionModuleMI(nn.Module):
+     r"""Inception module.
+
+     This module implements an inception-like architecture that processes
+     input feature maps through multiple convolutional paths with different
+     kernel sizes, allowing the network to capture features at multiple
+     scales. It includes bottleneck layers, convolutional layers, and pooling
+     layers, followed by batch normalization and an activation function.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input channels in the input tensor.
+     n_filters : int
+         Number of filters (output channels) for each convolutional layer.
+     n_convs : int
+         Number of convolutional layers in the module (excluding the
+         bottleneck and pooling paths).
+     kernel_unit_s : float, optional
+         Base size (in seconds) for the convolutional kernels. The actual
+         kernel size is computed as ``(2 * n_units + 1) * kernel_unit``, where
+         ``n_units`` ranges from 0 to ``n_convs - 1``. Default is 0.1 seconds.
+     sfreq : float, optional
+         Sampling frequency of the input data, used to convert kernel sizes
+         from seconds to samples. Default is 250 Hz.
+     activation : nn.Module class, optional
+         Activation function class to apply after batch normalization. Should
+         be a PyTorch activation module class like ``nn.ReLU`` or ``nn.ELU``.
+         Default is ``nn.ReLU``.
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         n_filters,
+         n_convs,
+         kernel_unit_s=0.1,
+         sfreq=250,
+         activation: type[nn.Module] = nn.ReLU,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.n_filters = n_filters
+         self.n_convs = n_convs
+         self.kernel_unit_s = kernel_unit_s
+         self.sfreq = sfreq
+
+         self.bottleneck = nn.Conv2d(
+             in_channels=self.in_channels,
+             out_channels=self.n_filters,
+             kernel_size=1,
+             bias=True,
+         )
+
+         kernel_unit = int(self.kernel_unit_s * self.sfreq)
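+         # e.g. with the defaults kernel_unit_s=0.1 and sfreq=250, this gives
+         # kernel_unit = 25 samples, and parallel kernels of 25, 75, 125, 175
+         # and 225 samples for n_convs=5.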
+
+         # XXX Max pooling is usually used to reduce spatial resolution, with
+         # a stride equal to the kernel size... But it seems the authors use
+         # stride=1 in their paper according to the output shapes from
+         # Table 3, although this is not clearly specified in the paper text.
+         self.pooling = nn.MaxPool2d(
+             kernel_size=(1, kernel_unit),
+             stride=1,
+             padding=(0, int(kernel_unit // 2)),
+         )
+
+         self.pooling_conv = nn.Conv2d(
+             in_channels=self.in_channels,
+             out_channels=self.n_filters,
+             kernel_size=1,
+             bias=True,
+         )
+
+         self.conv_list = nn.ModuleList(
+             [
+                 nn.Conv2d(
+                     in_channels=self.n_filters,
+                     out_channels=self.n_filters,
+                     kernel_size=(1, (n_units * 2 + 1) * kernel_unit),
+                     padding="same",
+                     bias=True,
+                 )
+                 for n_units in range(self.n_convs)
+             ]
+         )
+
+         self.bn = nn.BatchNorm2d(self.n_filters * (self.n_convs + 1))
+
+         self.activation = activation()
+
+     def forward(
+         self,
+         X: torch.Tensor,
+     ) -> torch.Tensor:
+         X1 = self.bottleneck(X)
+
+         X1 = [conv(X1) for conv in self.conv_list]
+
+         X2 = self.pooling(X)
+         X2 = self.pooling_conv(X2)
+         # Get the target length from one of the conv branches
+         target_len = X1[0].shape[-1]
+
+         # Crop the pooling output if its length does not match
+         if X2.shape[-1] != target_len:
+             X2 = X2[..., :target_len]
+
+         out = torch.cat(X1 + [X2], 1)
+
+         out = self.bn(out)
+         return self.activation(out)
+
+
+ class _ResidualModuleMI(nn.Module):
+     r"""Residual module.
+
+     This module performs a 1x1 convolution followed by batch normalization
+     and an activation function. It transforms input feature maps into output
+     feature maps, as used in the residual connections of this architecture.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input channels in the input tensor.
+     n_filters : int
+         Number of filters (output channels) for the convolutional layer.
+     activation : nn.Module class, optional
+         Activation function class to apply after batch normalization. Should
+         be a PyTorch activation module class like ``nn.ReLU`` or ``nn.ELU``.
+         Default is ``nn.ReLU``.
+     """
+
+     def __init__(self, in_channels, n_filters, activation: type[nn.Module] = nn.ReLU):
+         super().__init__()
+         self.in_channels = in_channels
+         self.n_filters = n_filters
+         self.activation = activation()
+
+         self.bn = nn.BatchNorm2d(self.n_filters)
+         self.conv = nn.Conv2d(
+             in_channels=self.in_channels,
+             out_channels=self.n_filters,
+             kernel_size=1,
+             bias=True,
+         )
+
+     def forward(
+         self,
+         X: torch.Tensor,
+     ) -> torch.Tensor:
+         """Forward pass of the residual module.
+
+         Parameters
+         ----------
+         X : torch.Tensor
+             Input tensor of shape ``(batch_size, n_chans, 1, n_times)``.
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor after convolution, batch normalization, and the
+             activation function.
+         """
+         out = self.conv(X)
+         out = self.bn(out)
+         return self.activation(out)
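For orientation, here is a minimal usage sketch for the EEGInceptionMI model added above. It is an illustration, not part of the diff: it assumes the class is re-exported from braindecode.models (as the package __init__ files in the list above suggest) and uses BCI IV 2a-like dimensions taken from the docstring defaults.

    import torch

    from braindecode.models import EEGInceptionMI

    # 22 channels, 4 classes, 4.5 s windows at 250 Hz -> n_times = 1125
    model = EEGInceptionMI(
        n_chans=22,
        n_outputs=4,
        input_window_seconds=4.5,
        sfreq=250,
    )
    x = torch.randn(8, 22, 1125)  # (batch, n_chans, n_times)
    logits = model(x)             # shape: (8, n_outputs) == (8, 4)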
braindecode/models/eegitnet.py
@@ -0,0 +1,302 @@
+ # Authors: Ghaith Bouallegue <ghaithbouallegue@gmail.com>
+ #
+ # License: BSD-3
+ from einops.layers.torch import Rearrange
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock
+
+
+ class EEGITNet(EEGModuleMixin, nn.Sequential):
+     r"""EEG-ITNet from Salami et al. (2022) [Salami2022]_
+
+     :bdg-success:`Convolution` :bdg-secondary:`Recurrent`
+
+     .. figure:: https://braindecode.org/dev/_static/model/eegitnet.jpg
+         :align: center
+         :alt: EEG-ITNet Architecture
+
+     EEG-ITNet: An Explainable Inception Temporal Convolutional Network for
+     motor imagery classification, from Salami et al., 2022.
+
+     See [Salami2022]_ for details.
+
+     Code adapted from https://github.com/abbassalami/eeg-itnet
+
+     Parameters
+     ----------
+     n_filters_time : int, optional
+         Number of temporal filters in the first inception branch; the second
+         and third branches use two and four times this number, respectively.
+         Default is 2.
+     kernel_length : int, optional
+         Kernel length for the inception branches. Determines the temporal
+         receptive field. Default is 16.
+     pool_kernel : int, optional
+         Pooling kernel size for the average pooling layer. Default is 4.
+     tcn_in_channel : int, optional
+         Number of input channels for the Temporal Convolutional (TC) blocks.
+         Default is 14.
+     tcn_kernel_size : int, optional
+         Kernel size for the TC blocks. Determines the temporal receptive
+         field. Default is 4.
+     tcn_padding : int, optional
+         Padding size for the TC blocks to maintain the input dimensions.
+         Default is 3.
+     drop_prob : float, optional
+         Dropout probability applied after certain layers to prevent
+         overfitting. Default is 0.4.
+     tcn_dilatation : int, optional
+         Dilation rate for the first TC block. Subsequent blocks have
+         dilation rates multiplied by powers of 2. Default is 1.
+     activation : nn.Module class, optional
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors; it was reimplemented from the paper
+     and the authors' implementation.
+
+     References
+     ----------
+     .. [Salami2022] A. Salami, J. Andreu-Perez and H. Gillmeister, "EEG-ITNet:
+         An Explainable Inception Temporal Convolutional Network for motor
+         imagery classification," in IEEE Access,
+         doi: 10.1109/ACCESS.2022.3161489.
+     """
+
+     def __init__(
+         self,
+         # Braindecode parameters
+         n_outputs=None,
+         n_chans=None,
+         n_times=None,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # Model parameters
+         n_filters_time: int = 2,
+         kernel_length: int = 16,
+         pool_kernel: int = 4,
+         tcn_in_channel: int = 14,
+         tcn_kernel_size: int = 4,
+         tcn_padding: int = 3,
+         drop_prob: float = 0.4,
+         tcn_dilatation: int = 1,
+         activation: type[nn.Module] = nn.ELU,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         self.mapping = {
+             "classification.1.weight": "final_layer.clf.weight",
+             "classification.1.bias": "final_layer.clf.bias",
+         }
+
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         # ======== Handling EEG input ========================
+         self.add_module(
+             "input_preprocess",
+             nn.Sequential(Ensure4d(), Rearrange("ba ch t 1 -> ba 1 ch t")),
+         )
+         # ======== Inception branches ========================
+         block11 = self._get_inception_branch(
+             in_channels=self.n_chans,
+             out_channels=n_filters_time,
+             kernel_length=kernel_length,
+             activation=activation,
+         )
+         block12 = self._get_inception_branch(
+             in_channels=self.n_chans,
+             out_channels=n_filters_time * 2,
+             kernel_length=kernel_length * 2,
+             activation=activation,
+         )
+         block13 = self._get_inception_branch(
+             in_channels=self.n_chans,
+             out_channels=n_filters_time * 4,
+             kernel_length=kernel_length * 4,
+             activation=activation,
+         )
+         self.add_module("inception_block", InceptionBlock((block11, block12, block13)))
+ self.pool1 = self.add_module(
129
+ "pooling",
130
+ nn.Sequential(
131
+ nn.AvgPool2d(kernel_size=(1, pool_kernel)), nn.Dropout(drop_prob)
132
+ ),
133
+ )
134
+ # =========== TC blocks =====================
135
+ self.add_module(
136
+ "TC_block1",
137
+ _TCBlock(
138
+ in_ch=tcn_in_channel,
139
+ kernel_length=tcn_kernel_size,
140
+ dilatation=tcn_dilatation,
141
+ padding=tcn_padding,
142
+ drop_prob=drop_prob,
143
+ activation=activation,
144
+ ),
145
+ )
146
+ # ================================
147
+ self.add_module(
148
+ "TC_block2",
149
+ _TCBlock(
150
+ in_ch=tcn_in_channel,
151
+ kernel_length=tcn_kernel_size,
152
+ dilatation=tcn_dilatation * 2,
153
+ padding=tcn_padding * 2,
154
+ drop_prob=drop_prob,
155
+ activation=activation,
156
+ ),
157
+ )
158
+ # ================================
159
+ self.add_module(
160
+ "TC_block3",
161
+ _TCBlock(
162
+ in_ch=tcn_in_channel,
163
+ kernel_length=tcn_kernel_size,
164
+ dilatation=tcn_dilatation * 4,
165
+ padding=tcn_padding * 4,
166
+ drop_prob=drop_prob,
167
+ activation=activation,
168
+ ),
169
+ )
170
+ # ================================
171
+ self.add_module(
172
+ "TC_block4",
173
+ _TCBlock(
174
+ in_ch=tcn_in_channel,
175
+ kernel_length=tcn_kernel_size,
176
+ dilatation=tcn_dilatation * 8,
177
+ padding=tcn_padding * 8,
178
+ drop_prob=drop_prob,
179
+ activation=activation,
180
+ ),
181
+ )
182
+
183
+ # ============= Dimensionality reduction ===================
184
+ self.add_module(
185
+ "dim_reduction",
186
+ nn.Sequential(
187
+ nn.Conv2d(tcn_in_channel, tcn_in_channel * 2, kernel_size=(1, 1)),
188
+ nn.BatchNorm2d(tcn_in_channel * 2),
189
+ activation(),
190
+ nn.AvgPool2d((1, tcn_kernel_size)),
191
+ nn.Dropout(drop_prob),
192
+ ),
193
+ )
194
+ # ============== Classifier ==================
195
+ # Moved flatten to another layer
196
+ self.add_module("flatten", nn.Flatten())
197
+
198
+ num_features = self.get_output_shape()[-1]
199
+
200
+ self.add_module("final_layer", nn.Linear(num_features, self.n_outputs))
201
+
202
+ @staticmethod
203
+ def _get_inception_branch(
204
+ in_channels,
205
+ out_channels,
206
+ kernel_length,
207
+ depth_multiplier=1,
208
+ activation: type[nn.Module] = nn.ELU,
209
+ ):
210
+ return nn.Sequential(
211
+ nn.Conv2d(
212
+ 1,
213
+ out_channels,
214
+ kernel_size=(1, kernel_length),
215
+ padding="same",
216
+ bias=False,
217
+ ),
218
+ nn.BatchNorm2d(out_channels),
219
+ DepthwiseConv2d(
220
+ out_channels,
221
+ kernel_size=(in_channels, 1),
222
+ depth_multiplier=depth_multiplier,
223
+ bias=False,
224
+ padding="valid",
225
+ ),
226
+ nn.BatchNorm2d(out_channels),
227
+ activation(),
228
+ )
229
+
230
+
231
+ class _TCBlock(nn.Module):
232
+ r"""
233
+ Temporal Convolutional (TC) block.
234
+
235
+ This module applies two depthwise separable convolutions with dilation and residual
236
+ connections, commonly used in temporal convolutional networks to capture long-range
237
+ dependencies in time-series data.
238
+
239
+ Parameters
240
+ ----------
241
+ in_ch : int
242
+ Number of input channels.
243
+ kernel_length : int
244
+ Length of the convolutional kernels.
245
+ dilatation : int
246
+ Dilatation rate for the convolutions.
247
+ padding : int
248
+ Amount of padding to add to the input.
249
+ drop_prob : float, optional
250
+ Dropout probability. Default is 0.4.
251
+ activation : nn.Module class, optional
252
+ Activation function class to use. Should be a PyTorch activation module class
253
+ like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
254
+ """
255
+
256
+ def __init__(
257
+ self,
258
+ in_ch,
259
+ kernel_length,
260
+ dilatation,
261
+ padding,
262
+ drop_prob=0.4,
263
+ activation: type[nn.Module] = nn.ELU,
264
+ ):
265
+ super().__init__()
266
+ self.pad = padding
267
+ self.tc1 = nn.Sequential(
268
+ DepthwiseConv2d(
269
+ in_ch,
270
+ kernel_size=(1, kernel_length),
271
+ depth_multiplier=1,
272
+ dilation=(1, dilatation),
273
+ bias=False,
274
+ padding="valid",
275
+ ),
276
+ nn.BatchNorm2d(in_ch),
277
+ activation(),
278
+ nn.Dropout(drop_prob),
279
+ )
280
+
281
+ self.tc2 = nn.Sequential(
282
+ DepthwiseConv2d(
283
+ in_ch,
284
+ kernel_size=(1, kernel_length),
285
+ depth_multiplier=1,
286
+ dilation=(1, dilatation),
287
+ bias=False,
288
+ padding="valid",
289
+ ),
290
+ nn.BatchNorm2d(in_ch),
291
+ activation(),
292
+ nn.Dropout(drop_prob),
293
+ )
294
+
295
+ def forward(self, x):
296
+ residual = x
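+         # With the defaults, self.pad equals (kernel_length - 1) * dilatation,
+         # so this left-only (causal) padding preserves the temporal length and
+         # keeps the residual addition below shape-compatible.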
+         paddings = (self.pad, 0, 0, 0, 0, 0, 0, 0)
+         x = nn.functional.pad(x, paddings)
+         x = self.tc1(x)
+         x = nn.functional.pad(x, paddings)
+         x = self.tc2(x) + residual
+         return x
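Similarly, a minimal sketch for the EEGITNet model added above (again not part of the diff), under the same assumption that the class is importable from braindecode.models; the shapes are illustrative.

    import torch

    from braindecode.models import EEGITNet

    model = EEGITNet(n_chans=22, n_outputs=4, n_times=1125)

    x = torch.randn(8, 22, 1125)  # (batch, n_chans, n_times)
    out = model(x)                # shape: (8, n_outputs) == (8, 4)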