braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/eeginception_mi.py
@@ -0,0 +1,371 @@
+ # Authors: Cedric Rommel <cedric.rommel@inria.fr>
+ #
+ # License: BSD (3-clause)
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import Ensure4d
+
+
+ class EEGInceptionMI(EEGModuleMixin, nn.Module):
+     """EEG Inception for Motor Imagery, as proposed in Zhang et al. (2021) [1]_
+
+     .. figure:: https://content.cld.iop.org/journals/1741-2552/18/4/046014/revision3/jneabed81f1_hr.jpg
+        :align: center
+        :alt: EEGInceptionMI Architecture
+
+     The model is strongly based on the original InceptionNet for computer
+     vision. The main goal is to extract features in parallel at different
+     scales. The network has two blocks made of 3 inception modules with a skip
+     connection.
+
+     The model is fully described in [1]_.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct: it has not been
+     checked by the original authors and was only reimplemented based on the
+     paper [1]_.
+
+     Parameters
+     ----------
+     input_window_seconds : float, optional
+         Size of the input, in seconds. Set to 4.5 s as in [1]_ for dataset
+         BCI IV 2a.
+     sfreq : float, optional
+         EEG sampling frequency in Hz. Defaults to 250 Hz as in [1]_ for dataset
+         BCI IV 2a.
+     n_convs : int, optional
+         Number of convolutions per inception wide branching. Defaults to 5 as
+         in [1]_ for dataset BCI IV 2a.
+     n_filters : int, optional
+         Number of convolutional filters for all layers of this type. Set to 48
+         as in [1]_ for dataset BCI IV 2a.
+     kernel_unit_s : float, optional
+         Size in seconds of the basic 1D convolutional kernel used in inception
+         modules. Each convolutional layer in such modules has kernels of
+         increasing size, odd multiples of this value (e.g. 0.1, 0.3, 0.5, 0.7,
+         0.9 here for ``n_convs=5``). Defaults to 0.1 s.
+     activation : nn.Module
+         Activation function class. Defaults to ``nn.ReLU``.
+
+     References
+     ----------
+     .. [1] Zhang, C., Kim, Y. K., & Eskandarian, A. (2021).
+        EEG-inception: an accurate and robust end-to-end neural network
+        for EEG-based motor imagery classification.
+        Journal of Neural Engineering, 18(4), 046014.
+     """
+
+     def __init__(
+         self,
+         n_chans=None,
+         n_outputs=None,
+         input_window_seconds=None,
+         sfreq=250,
+         n_convs=5,
+         n_filters=48,
+         kernel_unit_s=0.1,
+         activation: nn.Module = nn.ReLU,
+         chs_info=None,
+         n_times=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.n_convs = n_convs
+         self.n_filters = n_filters
+         self.kernel_unit_s = kernel_unit_s
+         self.activation = activation
+
+         self.ensuredims = Ensure4d()
+         self.dimshuffle = Rearrange("batch C T 1 -> batch C 1 T")
+
+         self.mapping = {
+             "fc.weight": "final_layer.fc.weight",
+             "fc.bias": "final_layer.fc.bias",
+         }
+
+         # ======== Inception branches ========================
+
+         self.initial_inception_module = _InceptionModuleMI(
+             in_channels=self.n_chans,
+             n_filters=self.n_filters,
+             n_convs=self.n_convs,
+             kernel_unit_s=self.kernel_unit_s,
+             sfreq=self.sfreq,
+             activation=self.activation,
+         )
+
+         intermediate_in_channels = (self.n_convs + 1) * self.n_filters
+
+         self.intermediate_inception_modules_1 = nn.ModuleList(
+             [
+                 _InceptionModuleMI(
+                     in_channels=intermediate_in_channels,
+                     n_filters=self.n_filters,
+                     n_convs=self.n_convs,
+                     kernel_unit_s=self.kernel_unit_s,
+                     sfreq=self.sfreq,
+                     activation=self.activation,
+                 )
+                 for _ in range(2)
+             ]
+         )
+
+         self.residual_block_1 = _ResidualModuleMI(
+             in_channels=self.n_chans,
+             n_filters=intermediate_in_channels,
+             activation=self.activation,
+         )
+
+         self.intermediate_inception_modules_2 = nn.ModuleList(
+             [
+                 _InceptionModuleMI(
+                     in_channels=intermediate_in_channels,
+                     n_filters=self.n_filters,
+                     n_convs=self.n_convs,
+                     kernel_unit_s=self.kernel_unit_s,
+                     sfreq=self.sfreq,
+                     activation=self.activation,
+                 )
+                 for _ in range(3)
+             ]
+         )
+
+         self.residual_block_2 = _ResidualModuleMI(
+             in_channels=intermediate_in_channels,
+             n_filters=intermediate_in_channels,
+             activation=self.activation,
+         )
+
+         # XXX The paper mentions a final average pooling but does not indicate
+         # the kernel size... The only info available is Figure 1, showing a
+         # final AveragePooling layer, and Table 3, indicating that the spatial
+         # and channel dimensions are unchanged by this layer... This could
+         # indicate a stride=1, as for the MaxPooling layers. However, when we
+         # look at the number of parameters of the linear layer following the
+         # average pooling, we see a small number of parameters, potentially
+         # indicating that the whole time dimension is averaged at this stage
+         # for each channel. We follow this last hypothesis here to comply with
+         # the number of parameters reported in the paper.
+         self.ave_pooling = nn.AvgPool2d(
+             kernel_size=(1, self.n_times),
+         )
+
+         self.flat = nn.Flatten()
+
+         module = nn.Sequential()
+         module.add_module(
+             "fc",
+             nn.Linear(
+                 in_features=intermediate_in_channels,
+                 out_features=self.n_outputs,
+                 bias=True,
+             ),
+         )
+         module.add_module("out_fun", nn.Identity())
+         self.final_layer = module
+
+     def forward(
+         self,
+         X: torch.Tensor,
+     ) -> torch.Tensor:
+         X = self.ensuredims(X)
+         X = self.dimshuffle(X)
+
+         res1 = self.residual_block_1(X)
+
+         out = self.initial_inception_module(X)
+         for layer in self.intermediate_inception_modules_1:
+             out = layer(out)
+
+         out = out + res1
+
+         res2 = self.residual_block_2(out)
+
+         for layer in self.intermediate_inception_modules_2:
+             out = layer(out)
+
+         out = res2 + out
+
+         out = self.ave_pooling(out)
+         out = self.flat(out)
+         out = self.final_layer(out)
+         return out
+
+
+ class _InceptionModuleMI(nn.Module):
+     """
+     Inception module.
+
+     This module implements an inception-like architecture that processes input
+     feature maps through multiple convolutional paths with different kernel
+     sizes, allowing the network to capture features at multiple scales.
+     It includes bottleneck layers, convolutional layers, and pooling layers,
+     followed by batch normalization and an activation function.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input channels in the input tensor.
+     n_filters : int
+         Number of filters (output channels) for each convolutional layer.
+     n_convs : int
+         Number of convolutional layers in the module (excluding bottleneck
+         and pooling paths).
+     kernel_unit_s : float, optional
+         Base size (in seconds) for the convolutional kernels. The actual kernel
+         size is computed as ``(2 * n_units + 1) * kernel_unit``, where ``n_units``
+         ranges from 0 to ``n_convs - 1``. Default is 0.1 seconds.
+     sfreq : float, optional
+         Sampling frequency of the input data, used to convert kernel sizes from
+         seconds to samples. Default is 250 Hz.
+     activation : nn.Module class, optional
+         Activation function class to apply after batch normalization. Should be
+         a PyTorch activation module class like ``nn.ReLU`` or ``nn.ELU``.
+         Default is ``nn.ReLU``.
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         n_filters,
+         n_convs,
+         kernel_unit_s=0.1,
+         sfreq=250,
+         activation: nn.Module = nn.ReLU,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.n_filters = n_filters
+         self.n_convs = n_convs
+         self.kernel_unit_s = kernel_unit_s
+         self.sfreq = sfreq
+
+         self.bottleneck = nn.Conv2d(
+             in_channels=self.in_channels,
+             out_channels=self.n_filters,
+             kernel_size=1,
+             bias=True,
+         )
+
+         kernel_unit = int(self.kernel_unit_s * self.sfreq)
+
+         # XXX Maxpooling is usually used to reduce spatial resolution, with a
+         # stride equal to the kernel size... But it seems the authors use
+         # stride=1 in their paper according to the output shapes from Table 3,
+         # although this is not clearly specified in the paper text.
+         self.pooling = nn.MaxPool2d(
+             kernel_size=(1, kernel_unit),
+             stride=1,
+             padding=(0, int(kernel_unit // 2)),
+         )
+
+         self.pooling_conv = nn.Conv2d(
+             in_channels=self.in_channels,
+             out_channels=self.n_filters,
+             kernel_size=1,
+             bias=True,
+         )
+
+         self.conv_list = nn.ModuleList(
+             [
+                 nn.Conv2d(
+                     in_channels=self.n_filters,
+                     out_channels=self.n_filters,
+                     kernel_size=(1, (n_units * 2 + 1) * kernel_unit),
+                     padding="same",
+                     bias=True,
+                 )
+                 for n_units in range(self.n_convs)
+             ]
+         )
+
+         self.bn = nn.BatchNorm2d(self.n_filters * (self.n_convs + 1))
+
+         self.activation = activation()
+
+     def forward(
+         self,
+         X: torch.Tensor,
+     ) -> torch.Tensor:
+         X1 = self.bottleneck(X)
+
+         X1 = [conv(X1) for conv in self.conv_list]
+
+         X2 = self.pooling(X)
+         X2 = self.pooling_conv(X2)
+
+         out = torch.cat(X1 + [X2], 1)
+
+         out = self.bn(out)
+         return self.activation(out)
+
+
+ class _ResidualModuleMI(nn.Module):
+     """
+     Residual module.
+
+     This module performs a 1x1 convolution followed by batch normalization
+     and an activation function. It is designed to process input feature maps
+     and produce transformed output feature maps, often used in residual
+     connections within neural network architectures.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input channels in the input tensor.
+     n_filters : int
+         Number of filters (output channels) for the convolutional layer.
+     activation : nn.Module class, optional
+         Activation function class to apply after batch normalization. Should
+         be a PyTorch activation module class like ``nn.ReLU`` or ``nn.ELU``.
+         Default is ``nn.ReLU``.
+     """
+
+     def __init__(self, in_channels, n_filters, activation: nn.Module = nn.ReLU):
+         super().__init__()
+         self.in_channels = in_channels
+         self.n_filters = n_filters
+         self.activation = activation()
+
+         self.bn = nn.BatchNorm2d(self.n_filters)
+         self.conv = nn.Conv2d(
+             in_channels=self.in_channels,
+             out_channels=self.n_filters,
+             kernel_size=1,
+             bias=True,
+         )
+
+     def forward(
+         self,
+         X: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Forward pass of the residual module.
+
+         Parameters
+         ----------
+         X : torch.Tensor
+             Input tensor of shape (batch_size, n_chans, 1, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor after convolution, batch normalization, and
+             activation.
+         """
+         out = self.conv(X)
+         out = self.bn(out)
+         return self.activation(out)
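
As a quick sanity check on the new EEGInceptionMI model above, here is a minimal usage sketch. It is not part of the release; the import path follows braindecode conventions, and the channel count, class count, and random input are assumptions chosen to mirror the BCI IV 2a setup the docstring references:

    import torch
    from braindecode.models import EEGInceptionMI

    sfreq = 250.0               # Hz, paper default
    input_window_seconds = 4.5  # s, as in [1] for BCI IV 2a

    model = EEGInceptionMI(
        n_chans=22,             # assumed BCI IV 2a montage
        n_outputs=4,            # assumed four motor-imagery classes
        input_window_seconds=input_window_seconds,
        sfreq=sfreq,
    )

    # Random stand-in for a batch of EEG windows: (batch, channels, time)
    X = torch.randn(8, 22, int(sfreq * input_window_seconds))
    out = model(X)              # expected shape: (8, 4)

With these defaults, ``kernel_unit = int(0.1 * 250) = 25`` samples, so each inception module runs five parallel convolutions with kernels of 25, 75, 125, 175 and 225 samples and outputs ``(n_convs + 1) * n_filters = 288`` channels. Because the final average pooling collapses the whole time dimension, the classifier is a ``Linear(288, n_outputs)`` regardless of window length, which is the hypothesis the XXX comment in ``__init__`` adopts.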
braindecode/models/eegitnet.py
@@ -0,0 +1,301 @@
+ # Authors: Ghaith Bouallegue <ghaithbouallegue@gmail.com>
+ #
+ # License: BSD-3
+ import torch
+ from einops.layers.torch import Rearrange
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock
+
+
+ class EEGITNet(EEGModuleMixin, nn.Sequential):
+     """EEG-ITNet from Salami et al. (2022) [Salami2022]_
+
+     .. figure:: https://braindecode.org/dev/_static/model/eegitnet.jpg
+        :align: center
+        :alt: EEG-ITNet Architecture
+
+     EEG-ITNet: An Explainable Inception Temporal Convolutional Network for
+     motor imagery classification, from Salami et al. 2022.
+
+     See [Salami2022]_ for details.
+
+     Code adapted from https://github.com/abbassalami/eeg-itnet
+
+     Parameters
+     ----------
+     n_filters_time : int, optional
+         Base number of temporal filters for the inception branches (the three
+         branches use 1x, 2x and 4x this number of output channels).
+         Default is 2.
+     kernel_length : int, optional
+         Kernel length for inception branches. Determines the temporal
+         receptive field. Default is 16.
+     pool_kernel : int, optional
+         Pooling kernel size for the average pooling layer. Default is 4.
+     tcn_in_channel : int, optional
+         Number of input channels for Temporal Convolutional (TC) blocks.
+         Default is 14.
+     tcn_kernel_size : int, optional
+         Kernel size for the TC blocks. Determines the temporal receptive
+         field. Default is 4.
+     tcn_padding : int, optional
+         Padding size for the TC blocks to maintain the input dimensions.
+         Default is 3.
+     drop_prob : float, optional
+         Dropout probability applied after certain layers to prevent
+         overfitting. Default is 0.4.
+     tcn_dilatation : int, optional
+         Dilation rate for the first TC block. Subsequent blocks have dilation
+         rates multiplied by powers of 2. Default is 1.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct: it has not been
+     checked by the original authors and was only reimplemented from the paper,
+     based on the author implementation.
+
+     References
+     ----------
+     .. [Salami2022] A. Salami, J. Andreu-Perez and H. Gillmeister, "EEG-ITNet:
+        An Explainable Inception Temporal Convolutional Network for motor
+        imagery classification," in IEEE Access,
+        doi: 10.1109/ACCESS.2022.3161489.
+     """
+
+     def __init__(
+         self,
+         # Braindecode parameters
+         n_outputs=None,
+         n_chans=None,
+         n_times=None,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # Model parameters
+         n_filters_time: int = 2,
+         kernel_length: int = 16,
+         pool_kernel: int = 4,
+         tcn_in_channel: int = 14,
+         tcn_kernel_size: int = 4,
+         tcn_padding: int = 3,
+         drop_prob: float = 0.4,
+         tcn_dilatation: int = 1,
+         activation: nn.Module = nn.ELU,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         self.mapping = {
+             "classification.1.weight": "final_layer.clf.weight",
+             "classification.1.bias": "final_layer.clf.bias",
+         }
+
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         # ======== Handling EEG input ========================
+         self.add_module(
+             "input_preprocess",
+             nn.Sequential(Ensure4d(), Rearrange("ba ch t 1 -> ba 1 ch t")),
+         )
+         # ======== Inception branches ========================
+         block11 = self._get_inception_branch(
+             in_channels=self.n_chans,
+             out_channels=n_filters_time,
+             kernel_length=kernel_length,
+             activation=activation,
+         )
+         block12 = self._get_inception_branch(
+             in_channels=self.n_chans,
+             out_channels=n_filters_time * 2,
+             kernel_length=kernel_length * 2,
+             activation=activation,
+         )
+         block13 = self._get_inception_branch(
+             in_channels=self.n_chans,
+             out_channels=n_filters_time * 4,
+             kernel_length=kernel_length * 4,
+             activation=activation,
+         )
+         self.add_module("inception_block", InceptionBlock((block11, block12, block13)))
+         self.add_module(
+             "pooling",
+             nn.Sequential(
+                 nn.AvgPool2d(kernel_size=(1, pool_kernel)), nn.Dropout(drop_prob)
+             ),
+         )
+         # =========== TC blocks =====================
+         self.add_module(
+             "TC_block1",
+             _TCBlock(
+                 in_ch=tcn_in_channel,
+                 kernel_length=tcn_kernel_size,
+                 dilatation=tcn_dilatation,
+                 padding=tcn_padding,
+                 drop_prob=drop_prob,
+                 activation=activation,
+             ),
+         )
+         # ================================
+         self.add_module(
+             "TC_block2",
+             _TCBlock(
+                 in_ch=tcn_in_channel,
+                 kernel_length=tcn_kernel_size,
+                 dilatation=tcn_dilatation * 2,
+                 padding=tcn_padding * 2,
+                 drop_prob=drop_prob,
+                 activation=activation,
+             ),
+         )
+         # ================================
+         self.add_module(
+             "TC_block3",
+             _TCBlock(
+                 in_ch=tcn_in_channel,
+                 kernel_length=tcn_kernel_size,
+                 dilatation=tcn_dilatation * 4,
+                 padding=tcn_padding * 4,
+                 drop_prob=drop_prob,
+                 activation=activation,
+             ),
+         )
+         # ================================
+         self.add_module(
+             "TC_block4",
+             _TCBlock(
+                 in_ch=tcn_in_channel,
+                 kernel_length=tcn_kernel_size,
+                 dilatation=tcn_dilatation * 8,
+                 padding=tcn_padding * 8,
+                 drop_prob=drop_prob,
+                 activation=activation,
+             ),
+         )
+
+         # ============= Dimensionality reduction ===================
+         self.add_module(
+             "dim_reduction",
+             nn.Sequential(
+                 nn.Conv2d(tcn_in_channel, tcn_in_channel * 2, kernel_size=(1, 1)),
+                 nn.BatchNorm2d(tcn_in_channel * 2),
+                 activation(),
+                 nn.AvgPool2d((1, tcn_kernel_size)),
+                 nn.Dropout(drop_prob),
+             ),
+         )
+         # ============== Classifier ==================
+         # Moved flatten to another layer
+         self.add_module("flatten", nn.Flatten())
+
+         num_features = self.get_output_shape()[-1]
+
+         self.add_module("final_layer", nn.Linear(num_features, self.n_outputs))
+
+     @staticmethod
+     def _get_inception_branch(
+         in_channels,
+         out_channels,
+         kernel_length,
+         depth_multiplier=1,
+         activation: nn.Module = nn.ELU,
+     ):
+         return nn.Sequential(
+             nn.Conv2d(
+                 1,
+                 out_channels,
+                 kernel_size=(1, kernel_length),
+                 padding="same",
+                 bias=False,
+             ),
+             nn.BatchNorm2d(out_channels),
+             DepthwiseConv2d(
+                 out_channels,
+                 kernel_size=(in_channels, 1),
+                 depth_multiplier=depth_multiplier,
+                 bias=False,
+                 padding="valid",
+             ),
+             nn.BatchNorm2d(out_channels),
+             activation(),
+         )
+
+
+ class _TCBlock(nn.Module):
+     """
+     Temporal Convolutional (TC) block.
+
+     This module applies two depthwise separable convolutions with dilation and
+     residual connections, commonly used in temporal convolutional networks to
+     capture long-range dependencies in time-series data.
+
+     Parameters
+     ----------
+     in_ch : int
+         Number of input channels.
+     kernel_length : int
+         Length of the convolutional kernels.
+     dilatation : int
+         Dilation rate for the convolutions.
+     padding : int
+         Amount of padding to add to the input.
+     drop_prob : float, optional
+         Dropout probability. Default is 0.4.
+     activation : nn.Module class, optional
+         Activation function class to use. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+     """
+
+     def __init__(
+         self,
+         in_ch,
+         kernel_length,
+         dilatation,
+         padding,
+         drop_prob=0.4,
+         activation: nn.Module = nn.ELU,
+     ):
+         super().__init__()
+         self.pad = padding
+         self.tc1 = nn.Sequential(
+             DepthwiseConv2d(
+                 in_ch,
+                 kernel_size=(1, kernel_length),
+                 depth_multiplier=1,
+                 dilation=(1, dilatation),
+                 bias=False,
+                 padding="valid",
+             ),
+             nn.BatchNorm2d(in_ch),
+             activation(),
+             nn.Dropout(drop_prob),
+         )
+
+         self.tc2 = nn.Sequential(
+             DepthwiseConv2d(
+                 in_ch,
+                 kernel_size=(1, kernel_length),
+                 depth_multiplier=1,
+                 dilation=(1, dilatation),
+                 bias=False,
+                 padding="valid",
+             ),
+             nn.BatchNorm2d(in_ch),
+             activation(),
+             nn.Dropout(drop_prob),
+         )
+
+     def forward(self, x):
+         residual = x
+         paddings = (self.pad, 0, 0, 0, 0, 0, 0, 0)
+         x = nn.functional.pad(x, paddings)
+         x = self.tc1(x)
+         x = nn.functional.pad(x, paddings)
+         x = self.tc2(x) + residual
+         return x
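
Likewise, a minimal sketch for the EEG-ITNet model above. Again, the import path follows braindecode conventions, and the channel count, class count, and window length are illustrative assumptions, not values from the diff:

    import torch
    from braindecode.models import EEGITNet

    model = EEGITNet(
        n_chans=22,    # assumed electrode count
        n_outputs=4,   # assumed motor-imagery classes
        n_times=1000,  # assumed samples per window
    )

    # Random stand-in for a batch of EEG windows: (batch, channels, time)
    X = torch.randn(8, 22, 1000)
    out = model(X)     # expected shape: (8, 4)

Two details worth noting from the code above: the default ``tcn_in_channel=14`` equals the concatenated output of the three inception branches (``n_filters_time * (1 + 2 + 4) = 14``), so these two parameters must be changed together; and the asymmetric ``nn.functional.pad`` call in ``_TCBlock.forward`` pads only the left side of the time axis, making the TC blocks causal.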