braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode has been flagged as potentially problematic by the registry scanner.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
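The list above reflects the 1.0.0 package reorganization: reusable layers moved out of the deleted braindecode/models/modules.py and braindecode/models/functions.py into the new braindecode.modules and braindecode.functional packages, and the deprecated braindecode.datautil shims (mne.py, preprocess.py, windowers.py, xy.py) were removed. A minimal import-migration sketch, assuming the symbols shown in the deep4.py diff below are re-exported from braindecode.modules and that the windowing helpers remain available under braindecode.preprocessing (as the unchanged preprocessing/windowers.py entry suggests):

    # 0.8.1-style imports (modules removed in 1.0.0):
    # from braindecode.models.modules import Ensure4d, Expression
    # from braindecode.datautil.windowers import create_windows_from_events

    # 1.0.0-style imports:
    from braindecode.modules import Ensure4d, SqueezeFinalOutput
    from braindecode.preprocessing import create_windows_from_events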
braindecode/models/ctnet.py
@@ -0,0 +1,450 @@
+ """
+ CTNet: a convolutional transformer network for EEG-based motor imagery
+ classification from Wei Zhao et al. (2024).
+ """
+
+ # Authors: Wei Zhao <zhaowei701@163.com>
+ #          Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
+ # License: MIT
+
+ from __future__ import annotations
+
+ import math
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from mne.utils import warn
+ from torch import Tensor, nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import (
+     FeedForwardBlock,
+     MultiHeadAttention,
+ )
+
+
+ class CTNet(EEGModuleMixin, nn.Module):
+     """CTNet from Zhao, W et al (2024) [ctnet]_.
+
+     A Convolutional Transformer Network for EEG-Based Motor Imagery Classification
+
+     .. figure:: https://raw.githubusercontent.com/snailpt/CTNet/main/architecture.png
+        :align: center
+        :alt: CTNet Architecture
+
+     CTNet is an end-to-end neural network architecture designed for classifying motor imagery (MI) tasks from EEG signals.
+     The model combines convolutional neural networks (CNNs) with a Transformer encoder to capture both local and global temporal dependencies in the EEG data.
+
+     The architecture consists of three main components:
+
+     1. **Convolutional Module**:
+        - Apply EEGNetV4 to perform some feature extraction, denoted here as
+          _PatchEmbeddingEEGNet module.
+
+     2. **Transformer Encoder Module**:
+        - Utilizes multi-head self-attention mechanisms as EEGConformer but
+          with residual blocks.
+
+     3. **Classifier Module**:
+        - Combines features from both the convolutional module
+          and the Transformer encoder.
+        - Flattens the combined features and applies dropout for regularization.
+        - Uses a fully connected layer to produce the final classification output.
+
+     Parameters
+     ----------
+     activation : nn.Module, default=nn.GELU
+         Activation function to use in the network.
+     heads : int, default=4
+         Number of attention heads in the Transformer encoder.
+     emb_size : int, default=40
+         Embedding size (dimensionality) for the Transformer encoder.
+     depth : int, default=6
+         Number of encoder layers in the Transformer.
+     n_filters_time : int, default=20
+         Number of temporal filters in the first convolutional layer.
+     kernel_size : int, default=64
+         Kernel size for the temporal convolutional layer.
+     depth_multiplier : int, default=2
+         Multiplier for the number of depth-wise convolutional filters.
+     pool_size_1 : int, default=8
+         Pooling size for the first average pooling layer.
+     pool_size_2 : int, default=8
+         Pooling size for the second average pooling layer.
+     drop_prob_cnn : float, default=0.3
+         Dropout probability after convolutional layers.
+     drop_prob_posi : float, default=0.1
+         Dropout probability for the positional encoding in the Transformer.
+     drop_prob_final : float, default=0.5
+         Dropout probability before the final classification layer.
+
+     Notes
+     -----
+     This implementation is adapted from the original CTNet source code
+     [ctnetcode]_ to comply with Braindecode's model standards.
+
+     References
+     ----------
+     .. [ctnet] Zhao, W., Jiang, X., Zhang, B., Xiao, S., & Weng, S. (2024).
+        CTNet: a convolutional transformer network for EEG-based motor imagery
+        classification. Scientific Reports, 14(1), 20237.
+     .. [ctnetcode] Zhao, W., Jiang, X., Zhang, B., Xiao, S., & Weng, S. (2024).
+        CTNet source code:
+        https://github.com/snailpt/CTNet
+     """
+
+     def __init__(
+         self,
+         # Base arguments
+         n_outputs=None,
+         n_chans=None,
+         sfreq=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         # Model specific arguments
+         activation_patch: nn.Module = nn.ELU,
+         activation_transformer: nn.Module = nn.GELU,
+         drop_prob_cnn: float = 0.3,
+         drop_prob_posi: float = 0.1,
+         drop_prob_final: float = 0.5,
+         # other parameters
+         heads: int = 4,
+         emb_size: int = 40,
+         depth: int = 6,
+         n_filters_time: int = 20,
+         kernel_size: int = 64,
+         depth_multiplier: int = 2,
+         pool_size_1: int = 8,
+         pool_size_2: int = 8,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.emb_size = emb_size
+         self.activation_patch = activation_patch
+         self.activation_transformer = activation_transformer
+
+         self.n_filters_time = n_filters_time
+         self.drop_prob_cnn = drop_prob_cnn
+         self.pool_size_1 = pool_size_1
+         self.pool_size_2 = pool_size_2
+         self.depth_multiplier = depth_multiplier
+         self.kernel_size = kernel_size
+         self.drop_prob_posi = drop_prob_posi
+         self.drop_prob_final = drop_prob_final
+
+         # n_times - pool_size_1 / p
+         sequence_length = math.floor(
+             (
+                 math.floor((self.n_times - self.pool_size_1) / self.pool_size_1 + 1)
+                 - self.pool_size_2
+             )
+             / self.pool_size_2
+             + 1
+         )
+
+         # Layers
+         self.ensuredim = Rearrange("batch nchans time -> batch 1 nchans time")
+         self.flatten = nn.Flatten()
+
+         self.cnn = _PatchEmbeddingEEGNet(
+             n_filters_time=self.n_filters_time,
+             kernel_size=self.kernel_size,
+             depth_multiplier=self.depth_multiplier,
+             pool_size_1=self.pool_size_1,
+             pool_size_2=self.pool_size_2,
+             drop_prob=self.drop_prob_cnn,
+             n_chans=self.n_chans,
+             activation=self.activation_patch,
+         )
+
+         self.position = _PositionalEncoding(
+             emb_size=emb_size,
+             drop_prob=self.drop_prob_posi,
+             n_times=self.n_times,
+             pool_size=self.pool_size_1,
+         )
+
+         self.trans = _TransformerEncoder(
+             heads, depth, emb_size, activation=self.activation_transformer
+         )
+
+         self.flatten_drop_layer = nn.Sequential(
+             nn.Flatten(),
+             nn.Dropout(p=self.drop_prob_final),
+         )
+
+         self.final_layer = nn.Linear(
+             in_features=emb_size * sequence_length, out_features=self.n_outputs
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         """
+         Forward pass of the CTNet model.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor of shape (batch_size, n_channels, n_times).
+
+         Returns
+         -------
+         Tensor
+             Output with shape (batch_size, n_outputs).
+         """
+         x = self.ensuredim(x)
+         cnn = self.cnn(x)
+         cnn = cnn * math.sqrt(self.emb_size)
+         cnn = self.position(cnn)
+         trans = self.trans(cnn)
+         features = cnn + trans
+         flatten_feature = self.flatten(features)
+         out = self.final_layer(flatten_feature)
+         return out
+
+
+ class _PatchEmbeddingEEGNet(nn.Module):
+     def __init__(
+         self,
+         n_filters_time: int = 16,
+         kernel_size: int = 64,
+         depth_multiplier: int = 2,
+         pool_size_1: int = 8,
+         pool_size_2: int = 8,
+         drop_prob: float = 0.3,
+         n_chans: int = 22,
+         activation: nn.Module = nn.ELU,
+     ):
+         super().__init__()
+         n_filters_out = depth_multiplier * n_filters_time
+         self.eegnet_module = nn.Sequential(
+             # Temporal convolution
+             nn.Conv2d(
+                 in_channels=1,
+                 out_channels=n_filters_time,
+                 kernel_size=(1, kernel_size),
+                 stride=(1, 1),
+                 padding="same",
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_filters_time),
+             # Channel depth-wise convolution
+             nn.Conv2d(
+                 in_channels=n_filters_time,
+                 out_channels=n_filters_out,
+                 kernel_size=(n_chans, 1),
+                 stride=(1, 1),
+                 groups=n_filters_time,
+                 padding="valid",
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_filters_out),
+             activation(),
+             # First average pooling
+             nn.AvgPool2d(kernel_size=(1, pool_size_1)),
+             nn.Dropout(drop_prob),
+             # Spatial convolution
+             nn.Conv2d(
+                 in_channels=n_filters_out,
+                 out_channels=n_filters_out,
+                 kernel_size=(1, 16),
+                 padding="same",
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_filters_out),
+             activation(),
+             # Second average pooling
+             nn.AvgPool2d(kernel_size=(1, pool_size_2)),
+             nn.Dropout(drop_prob),
+         )
+
+         self.projection = nn.Sequential(
+             Rearrange("b e h w -> b (h w) e"),
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         """
+         Forward pass of the Patch Embedding CNN.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor of shape (batch_size, 1, n_channels, n_times).
+
+         Returns
+         -------
+         Tensor
+             Embedded patches of shape (batch_size, num_patches, embedding_dim).
+         """
+         x = self.eegnet_module(x)
+         x = self.projection(x)
+         return x
+
+
+ class _ResidualAdd(nn.Module):
+     def __init__(self, module: nn.Module, emb_size: int, drop_p: float):
+         super().__init__()
+         self.module = module
+         self.drop = nn.Dropout(drop_p)
+         self.layernorm = nn.LayerNorm(emb_size)
+
+     def forward(self, x: Tensor) -> Tensor:
+         """
+         Forward pass with residual connection.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor.
+         **kwargs : Any
+             Additional arguments.
+
+         Returns
+         -------
+         Tensor
+             Output tensor after applying residual connection.
+         """
+         res = self.module(x)
+         out = self.layernorm(self.drop(res) + x)
+         return out
+
+
+ class _TransformerEncoderBlock(nn.Module):
+     def __init__(
+         self,
+         dim_feedforward: int,
+         num_heads: int = 4,
+         drop_prob: float = 0.5,
+         forward_expansion: int = 4,
+         forward_drop_p: float = 0.5,
+         activation: nn.Module = nn.GELU,
+     ):
+         super().__init__()
+         self.attention = _ResidualAdd(
+             nn.Sequential(
+                 MultiHeadAttention(dim_feedforward, num_heads, drop_prob),
+             ),
+             dim_feedforward,
+             drop_prob,
+         )
+         self.feed_forward = _ResidualAdd(
+             nn.Sequential(
+                 FeedForwardBlock(
+                     dim_feedforward,
+                     expansion=forward_expansion,
+                     drop_p=forward_drop_p,
+                     activation=activation,
+                 ),
+             ),
+             dim_feedforward,
+             drop_prob,
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         """
+         Forward pass of the transformer encoder block.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor.
+         **kwargs : Any
+             Additional arguments.
+
+         Returns
+         -------
+         Tensor
+             Output tensor after transformer encoder block.
+         """
+         x = self.attention(x)
+         x = self.feed_forward(x)
+         return x
+
+
+ class _TransformerEncoder(nn.Module):
+     def __init__(
+         self,
+         nheads: int,
+         depth: int,
+         dim_feedforward: int,
+         activation: nn.Module = nn.GELU,
+     ):
+         super().__init__()
+         self.layers = nn.Sequential(
+             *[
+                 _TransformerEncoderBlock(
+                     dim_feedforward=dim_feedforward,
+                     num_heads=nheads,
+                     activation=activation,
+                 )
+                 for _ in range(depth)
+             ]
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         """
+         Forward pass of the transformer encoder.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor.
+
+         Returns
+         -------
+         Tensor
+             Output tensor after transformer encoder.
+         """
+         return self.layers(x)
+
+
+ class _PositionalEncoding(nn.Module):
+     def __init__(
+         self,
+         n_times: int,
+         emb_size: int,
+         length: int = 100,
+         drop_prob: float = 0.1,
+         pool_size: int = 8,
+     ):
+         super().__init__()
+         self.pool_size = pool_size
+         self.n_times = n_times
+
+         if int(n_times / (pool_size * pool_size)) > length:
+             warn(
+                 "the temporal dimensional is too long for this default length. "
+                 "The length parameter will be automatically adjusted to "
+                 "avoid inference issues."
+             )
+             length = int(n_times / (pool_size * pool_size))
+
+         self.dropout = nn.Dropout(drop_prob)
+         self.encoding = nn.Parameter(torch.randn(1, length, emb_size))
+
+     def forward(self, x: Tensor) -> Tensor:
+         """
+         Forward pass of the positional encoding.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor of shape (batch_size, sequence_length, embedding_dim).
+
+         Returns
+         -------
+         Tensor
+             Tensor with positional encoding added.
+         """
+         seq_length = x.size(1)
+         encoding = self.encoding[:, :seq_length, :]
+         x = x + encoding
+         return self.dropout(x)
braindecode/models/deep4.py
@@ -5,15 +5,22 @@
  from einops.layers.torch import Rearrange
  from torch import nn
  from torch.nn import init
- from torch.nn.functional import elu

- from .base import EEGModuleMixin, deprecated_args
- from .functions import identity, squeeze_final_output
- from .modules import AvgPool2dWithConv, CombinedConv, Ensure4d, Expression
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import (
+     AvgPool2dWithConv,
+     CombinedConv,
+     Ensure4d,
+     SqueezeFinalOutput,
+ )


  class Deep4Net(EEGModuleMixin, nn.Sequential):
-     """Deep ConvNet model from Schirrmeister et al 2017.
+     """Deep ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.
+
+     .. figure:: https://onlinelibrary.wiley.com/cms/asset/fc200ccc-d8c4-45b4-8577-56ce4d15999a/hbm23730-fig-0001-m.jpg
+        :align: center
+        :alt: Deep4Net Architecture

      Model described in [Schirrmeister2017]_.

@@ -44,13 +51,13 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
          Number of temporal filters in layer 4.
      filter_length_4: int
          Length of the temporal filter in layer 4.
-     first_conv_nonlin: callable
+     activation_first_conv_nonlin: nn.Module, default is nn.ELU
          Non-linear activation function to be used after convolution in layer 1.
      first_pool_mode: str
          Pooling mode in layer 1. "max" or "mean".
      first_pool_nonlin: callable
          Non-linear activation function to be used after pooling in layer 1.
-     later_conv_nonlin: callable
+     activation_later_conv_nonlin: nn.Module, default is nn.ELU
          Non-linear activation function to be used after convolution in later layers.
      later_pool_mode: str
          Pooling mode in later layers. "max" or "mean".
@@ -67,12 +74,6 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
          Momentum for BatchNorm2d.
      stride_before_pool: bool
          Stride before pooling.
-     in_chans :
-         Alias for n_chans.
-     n_classes:
-         Alias for n_outputs.
-     input_window_samples :
-         Alias for n_times.


      References
@@ -87,47 +88,37 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
      """

      def __init__(
-             self,
-             n_chans=None,
-             n_outputs=None,
-             n_times=None,
-             final_conv_length="auto",
-             n_filters_time=25,
-             n_filters_spat=25,
-             filter_time_length=10,
-             pool_time_length=3,
-             pool_time_stride=3,
-             n_filters_2=50,
-             filter_length_2=10,
-             n_filters_3=100,
-             filter_length_3=10,
-             n_filters_4=200,
-             filter_length_4=10,
-             first_conv_nonlin=elu,
-             first_pool_mode="max",
-             first_pool_nonlin=identity,
-             later_conv_nonlin=elu,
-             later_pool_mode="max",
-             later_pool_nonlin=identity,
-             drop_prob=0.5,
-             split_first_layer=True,
-             batch_norm=True,
-             batch_norm_alpha=0.1,
-             stride_before_pool=False,
-             chs_info=None,
-             input_window_seconds=None,
-             sfreq=None,
-             in_chans=None,
-             n_classes=None,
-             input_window_samples=None,
-             add_log_softmax=True,
+         self,
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         final_conv_length="auto",
+         n_filters_time=25,
+         n_filters_spat=25,
+         filter_time_length=10,
+         pool_time_length=3,
+         pool_time_stride=3,
+         n_filters_2=50,
+         filter_length_2=10,
+         n_filters_3=100,
+         filter_length_3=10,
+         n_filters_4=200,
+         filter_length_4=10,
+         activation_first_conv_nonlin: nn.Module = nn.ELU,
+         first_pool_mode="max",
+         first_pool_nonlin: nn.Module = nn.Identity,
+         activation_later_conv_nonlin: nn.Module = nn.ELU,
+         later_pool_mode="max",
+         later_pool_nonlin: nn.Module = nn.Identity,
+         drop_prob=0.5,
+         split_first_layer=True,
+         batch_norm=True,
+         batch_norm_alpha=0.1,
+         stride_before_pool=False,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
      ):
-         n_chans, n_outputs, n_times = deprecated_args(
-             self,
-             ('in_chans', 'n_chans', in_chans, n_chans),
-             ('n_classes', 'n_outputs', n_classes, n_outputs),
-             ('input_window_samples', 'n_times', input_window_samples, n_times),
-         )
          super().__init__(
              n_outputs=n_outputs,
              n_chans=n_chans,
@@ -135,10 +126,9 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
              n_times=n_times,
              input_window_seconds=input_window_seconds,
              sfreq=sfreq,
-             add_log_softmax=add_log_softmax,
          )
          del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-         del in_chans, n_classes, input_window_samples
+
          if final_conv_length == "auto":
              assert self.n_times is not None
          self.final_conv_length = final_conv_length
@@ -153,10 +143,10 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
          self.filter_length_3 = filter_length_3
          self.n_filters_4 = n_filters_4
          self.filter_length_4 = filter_length_4
-         self.first_nonlin = first_conv_nonlin
+         self.first_nonlin = activation_first_conv_nonlin
          self.first_pool_mode = first_pool_mode
          self.first_pool_nonlin = first_pool_nonlin
-         self.later_conv_nonlin = later_conv_nonlin
+         self.later_conv_nonlin = activation_later_conv_nonlin
          self.later_pool_mode = later_pool_mode
          self.later_pool_nonlin = later_pool_nonlin
          self.drop_prob = drop_prob
@@ -174,7 +164,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
              "conv_time.bias": "conv_time_spat.conv_time.bias",
              "conv_spat.bias": "conv_time_spat.conv_spat.bias",
              "conv_classifier.weight": "final_layer.conv_classifier.weight",
-             "conv_classifier.bias": "final_layer.conv_classifier.bias"
+             "conv_classifier.bias": "final_layer.conv_classifier.bias",
          }

          if self.stride_before_pool:
@@ -223,17 +213,17 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
                  eps=1e-5,
              ),
          )
-         self.add_module("conv_nonlin", Expression(self.first_nonlin))
+         self.add_module("conv_nonlin", self.first_nonlin())
          self.add_module(
              "pool",
              first_pool_class(
                  kernel_size=(self.pool_time_length, 1), stride=(pool_stride, 1)
              ),
          )
-         self.add_module("pool_nonlin", Expression(self.first_pool_nonlin))
+         self.add_module("pool_nonlin", self.first_pool_nonlin())

          def add_conv_pool_block(
-                 model, n_filters_before, n_filters, filter_length, block_nr
+             model, n_filters_before, n_filters, filter_length, block_nr
          ):
              suffix = "_{:d}".format(block_nr)
              self.add_module("drop" + suffix, nn.Dropout(p=self.drop_prob))
@@ -257,7 +247,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
                      eps=1e-5,
                  ),
              )
-             self.add_module("nonlin" + suffix, Expression(self.later_conv_nonlin))
+             self.add_module("nonlin" + suffix, self.later_conv_nonlin())

              self.add_module(
                  "pool" + suffix,
@@ -266,7 +256,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
                      stride=(pool_stride, 1),
                  ),
              )
-             self.add_module("pool_nonlin" + suffix, Expression(self.later_pool_nonlin))
+             self.add_module("pool_nonlin" + suffix, self.later_pool_nonlin())

          add_conv_pool_block(
              self, n_filters_conv, self.n_filters_2, self.filter_length_2, 2
@@ -286,17 +276,17 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
          # Incorporating classification module and subsequent ones in one final layer
          module = nn.Sequential()

-         module.add_module("conv_classifier",
-                           nn.Conv2d(
-                               self.n_filters_4,
-                               self.n_outputs,
-                               (self.final_conv_length, 1),
-                               bias=True, ))
-
-         if self.add_log_softmax:
-             module.add_module("logsoftmax", nn.LogSoftmax(dim=1))
+         module.add_module(
+             "conv_classifier",
+             nn.Conv2d(
+                 self.n_filters_4,
+                 self.n_outputs,
+                 (self.final_conv_length, 1),
+                 bias=True,
+             ),
+         )

-         module.add_module("squeeze", Expression(squeeze_final_output))
+         module.add_module("squeeze", SqueezeFinalOutput())

          self.add_module("final_layer", module)

@@ -329,5 +319,4 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
          init.xavier_uniform_(self.final_layer.conv_classifier.weight, gain=1)
          init.constant_(self.final_layer.conv_classifier.bias, 0)

-         # Start in eval mode
-         self.eval()
+         self.train()
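The Deep4Net diff removes the deprecated in_chans / n_classes / input_window_samples aliases and the add_log_softmax option, replaces callable activations with nn.Module classes, and leaves the model in train mode after construction instead of eval mode. A hedged migration sketch (hyperparameter values are illustrative, not recommendations):

    from torch import nn
    from braindecode.models import Deep4Net

    # 0.8.1-style call that no longer works in 1.0.0:
    # model = Deep4Net(in_chans=22, n_classes=4, input_window_samples=1000,
    #                  first_conv_nonlin=elu, add_log_softmax=True)

    # 1.0.0 equivalent: standard argument names, activation classes instead of
    # callables, and raw (non-log-softmax) outputs, so pair the model with a
    # criterion such as nn.CrossEntropyLoss during training.
    model = Deep4Net(
        n_chans=22,
        n_outputs=4,
        n_times=1000,
        final_conv_length="auto",
        activation_first_conv_nonlin=nn.ELU,
        activation_later_conv_nonlin=nn.ELU,
    )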