braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic; the original page linked to more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -1,85 +1,111 @@
1
1
  # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
2
  #
3
3
  # License: BSD (3-clause)
4
+ from __future__ import annotations
5
+
6
+ from typing import Dict, Optional
4
7
 
5
- import torch
6
8
  from einops.layers.torch import Rearrange
9
+ from mne.utils import warn
7
10
  from torch import nn
8
- from torch.nn.functional import elu
9
-
10
- from .base import EEGModuleMixin, deprecated_args
11
- from .functions import squeeze_final_output
12
- from .modules import Ensure4d, Expression
13
-
14
11
 
15
- class Conv2dWithConstraint(nn.Conv2d):
16
- def __init__(self, *args, max_norm=1, **kwargs):
17
- self.max_norm = max_norm
18
- super(Conv2dWithConstraint, self).__init__(*args, **kwargs)
19
-
20
- def forward(self, x):
21
- self.weight.data = torch.renorm(
22
- self.weight.data, p=2, dim=0, maxnorm=self.max_norm
23
- )
24
- return super(Conv2dWithConstraint, self).forward(x)
12
+ from braindecode.functional import glorot_weight_zero_bias
13
+ from braindecode.models.base import EEGModuleMixin
14
+ from braindecode.modules import (
15
+ Conv2dWithConstraint,
16
+ Ensure4d,
17
+ Expression,
18
+ LinearWithConstraint,
19
+ SqueezeFinalOutput,
20
+ )
25
21
 
26
22
 
27
23
  class EEGNetv4(EEGModuleMixin, nn.Sequential):
28
- """EEGNet v4 model from Lawhern et al 2018.
24
+ """EEGNet v4 model from Lawhern et al. (2018) [EEGNet4]_.
25
+
26
+ .. figure:: https://content.cld.iop.org/journals/1741-2552/15/5/056013/revision2/jneaace8cf01_hr.jpg
27
+ :align: center
28
+ :alt: EEGNet4 Architecture
29
29
 
30
30
  See details in [EEGNet4]_.
31
31
 
32
32
  Parameters
33
33
  ----------
34
- final_conv_length : int | "auto"
35
- If int, final length of convolutional filters.
36
- in_chans :
37
- Alias for n_chans.
38
- n_classes:
39
- Alias for n_outputs.
40
- input_window_samples :
41
- Alias for n_times.
42
-
43
- Notes
44
- -----
45
- This implementation is not guaranteed to be correct, has not been checked
46
- by original authors, only reimplemented from the paper description.
34
+ final_conv_length : int or "auto", default="auto"
35
+ Length of the final convolution layer. If "auto", it is set based on n_times.
36
+ pool_mode : {"mean", "max"}, default="mean"
37
+ Pooling method to use in pooling layers.
38
+ F1 : int, default=8
39
+ Number of temporal filters in the first convolutional layer.
40
+ D : int, default=2
41
+ Depth multiplier for the depthwise convolution.
42
+ F2 : int or None, default=None
43
+ Number of pointwise filters in the separable convolution. Usually set to ``F1 * D``.
44
+ depthwise_kernel_length : int, default=16
45
+ Length of the depthwise convolution kernel in the separable convolution.
46
+ pool1_kernel_size : int, default=4
47
+ Kernel size of the first pooling layer.
48
+ pool2_kernel_size : int, default=8
49
+ Kernel size of the second pooling layer.
50
+ kernel_length : int, default=64
51
+ Length of the temporal convolution kernel.
52
+ conv_spatial_max_norm : float, default=1
53
+ Maximum norm constraint for the spatial (depthwise) convolution.
54
+ activation : nn.Module, default=nn.ELU
55
+ Non-linear activation function to be used in the layers.
56
+ batch_norm_momentum : float, default=0.01
57
+ Momentum for instance normalization in batch norm layers.
58
+ batch_norm_affine : bool, default=True
59
+ If True, batch norm has learnable affine parameters.
60
+ batch_norm_eps : float, default=1e-3
61
+ Epsilon for numeric stability in batch norm layers.
62
+ drop_prob : float, default=0.25
63
+ Dropout probability.
64
+ final_layer_with_constraint : bool, default=False
65
+ If ``False``, uses a convolution-based classification layer. If ``True``,
66
+ apply a flattened linear layer with constraint on the weights norm as the final classification step.
67
+ norm_rate : float, default=0.25
68
+ Max-norm constraint value for the linear layer (used if ``final_layer_conv=False``).
47
69
 
48
70
  References
49
71
  ----------
50
- .. [EEGNet4] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon,
51
- S. M., Hung, C. P., & Lance, B. J. (2018).
52
- EEGNet: A Compact Convolutional Network for EEG-based
53
- Brain-Computer Interfaces.
54
- arXiv preprint arXiv:1611.08024.
72
+ .. [EEGNet4] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon, S. M.,
73
+ Hung, C. P., & Lance, B. J. (2018). EEGNet: a compact convolutional
74
+ neural network for EEG-based brain–computer interfaces. Journal of
75
+ neural engineering, 15(5), 056013.
55
76
  """
56
77
 
57
78
  def __init__(
58
- self,
59
- n_chans=None,
60
- n_outputs=None,
61
- n_times=None,
62
- final_conv_length="auto",
63
- pool_mode="mean",
64
- F1=8,
65
- D=2,
66
- F2=16, # usually set to F1*D (?)
67
- kernel_length=64,
68
- third_kernel_size=(8, 4),
69
- drop_prob=0.25,
70
- chs_info=None,
71
- input_window_seconds=None,
72
- sfreq=None,
73
- in_chans=None,
74
- n_classes=None,
75
- input_window_samples=None,
79
+ self,
80
+ # signal's parameters
81
+ n_chans: Optional[int] = None,
82
+ n_outputs: Optional[int] = None,
83
+ n_times: Optional[int] = None,
84
+ # model's parameters
85
+ final_conv_length: str | int = "auto",
86
+ pool_mode: str = "mean",
87
+ F1: int = 8,
88
+ D: int = 2,
89
+ F2: Optional[int | None] = None,
90
+ kernel_length: int = 64,
91
+ *,
92
+ depthwise_kernel_length: int = 16,
93
+ pool1_kernel_size: int = 4,
94
+ pool2_kernel_size: int = 8,
95
+ conv_spatial_max_norm: int = 1,
96
+ activation: nn.Module = nn.ELU,
97
+ batch_norm_momentum: float = 0.01,
98
+ batch_norm_affine: bool = True,
99
+ batch_norm_eps: float = 1e-3,
100
+ drop_prob: float = 0.25,
101
+ final_layer_with_constraint: bool = False,
102
+ norm_rate: float = 0.25,
103
+ # Other ways to construct the signal related parameters
104
+ chs_info: Optional[list[Dict]] = None,
105
+ input_window_seconds=None,
106
+ sfreq=None,
107
+ **kwargs,
76
108
  ):
77
- n_chans, n_outputs, n_times = deprecated_args(
78
- self,
79
- ("in_chans", "n_chans", in_chans, n_chans),
80
- ("n_classes", "n_outputs", n_classes, n_outputs),
81
- ("input_window_samples", "n_times", input_window_samples, n_times),
82
- )
83
109
  super().__init__(
84
110
  n_outputs=n_outputs,
85
111
  n_chans=n_chans,
@@ -89,68 +115,106 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
89
115
  sfreq=sfreq,
90
116
  )
91
117
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
92
- del in_chans, n_classes, input_window_samples
93
118
  if final_conv_length == "auto":
94
119
  assert self.n_times is not None
120
+
121
+ if not final_layer_with_constraint:
122
+ warn(
123
+ "Parameter 'final_layer_with_constraint=False' is deprecated and will be "
124
+ "removed in a future release. Please use `final_layer_linear=True`.",
125
+ DeprecationWarning,
126
+ )
127
+
128
+ if "third_kernel_size" in kwargs:
129
+ warn(
130
+ "The parameter `third_kernel_size` is deprecated "
131
+ "and will be removed in a future version.",
132
+ )
133
+ unexpected_kwargs = set(kwargs) - {"third_kernel_size"}
134
+ if unexpected_kwargs:
135
+ raise TypeError(f"Unexpected keyword arguments: {unexpected_kwargs}")
136
+
95
137
  self.final_conv_length = final_conv_length
96
138
  self.pool_mode = pool_mode
97
139
  self.F1 = F1
98
140
  self.D = D
141
+
142
+ if F2 is None:
143
+ F2 = self.F1 * self.D
99
144
  self.F2 = F2
145
+
100
146
  self.kernel_length = kernel_length
101
- self.third_kernel_size = third_kernel_size
147
+ self.depthwise_kernel_length = depthwise_kernel_length
148
+ self.pool1_kernel_size = pool1_kernel_size
149
+ self.pool2_kernel_size = pool2_kernel_size
102
150
  self.drop_prob = drop_prob
151
+ self.activation = activation
152
+ self.batch_norm_momentum = batch_norm_momentum
153
+ self.batch_norm_affine = batch_norm_affine
154
+ self.batch_norm_eps = batch_norm_eps
155
+ self.conv_spatial_max_norm = conv_spatial_max_norm
156
+ self.norm_rate = norm_rate
157
+
103
158
  # For the load_state_dict
104
159
  # When padronize all layers,
105
160
  # add the old's parameters here
106
161
  self.mapping = {
107
162
  "conv_classifier.weight": "final_layer.conv_classifier.weight",
108
- "conv_classifier.bias": "final_layer.conv_classifier.bias"
163
+ "conv_classifier.bias": "final_layer.conv_classifier.bias",
109
164
  }
110
165
 
111
166
  pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
112
167
  self.add_module("ensuredims", Ensure4d())
113
168
 
114
- self.add_module("dimshuffle",
115
- Rearrange("batch ch t 1 -> batch 1 ch t"))
169
+ self.add_module("dimshuffle", Rearrange("batch ch t 1 -> batch 1 ch t"))
116
170
  self.add_module(
117
171
  "conv_temporal",
118
172
  nn.Conv2d(
119
173
  1,
120
174
  self.F1,
121
175
  (1, self.kernel_length),
122
- stride=1,
123
176
  bias=False,
124
177
  padding=(0, self.kernel_length // 2),
125
178
  ),
126
179
  )
127
180
  self.add_module(
128
181
  "bnorm_temporal",
129
- nn.BatchNorm2d(self.F1, momentum=0.01, affine=True, eps=1e-3),
182
+ nn.BatchNorm2d(
183
+ self.F1,
184
+ momentum=self.batch_norm_momentum,
185
+ affine=self.batch_norm_affine,
186
+ eps=self.batch_norm_eps,
187
+ ),
130
188
  )
131
189
  self.add_module(
132
190
  "conv_spatial",
133
191
  Conv2dWithConstraint(
134
- self.F1,
135
- self.F1 * self.D,
136
- (self.n_chans, 1),
137
- max_norm=1,
138
- stride=1,
192
+ in_channels=self.F1,
193
+ out_channels=self.F1 * self.D,
194
+ kernel_size=(self.n_chans, 1),
195
+ max_norm=self.conv_spatial_max_norm,
139
196
  bias=False,
140
197
  groups=self.F1,
141
- padding=(0, 0),
142
198
  ),
143
199
  )
144
200
 
145
201
  self.add_module(
146
202
  "bnorm_1",
147
203
  nn.BatchNorm2d(
148
- self.F1 * self.D, momentum=0.01, affine=True, eps=1e-3
204
+ self.F1 * self.D,
205
+ momentum=self.batch_norm_momentum,
206
+ affine=self.batch_norm_affine,
207
+ eps=self.batch_norm_eps,
149
208
  ),
150
209
  )
151
- self.add_module("elu_1", Expression(elu))
210
+ self.add_module("elu_1", activation())
152
211
 
153
- self.add_module("pool_1", pool_class(kernel_size=(1, 4), stride=(1, 4)))
212
+ self.add_module(
213
+ "pool_1",
214
+ pool_class(
215
+ kernel_size=(1, self.pool1_kernel_size),
216
+ ),
217
+ )
154
218
  self.add_module("drop_1", nn.Dropout(p=self.drop_prob))
155
219
 
156
220
  # https://discuss.pytorch.org/t/how-to-modify-a-conv2d-to-depthwise-separable-convolution/15843/7
@@ -159,11 +223,10 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
159
223
  nn.Conv2d(
160
224
  self.F1 * self.D,
161
225
  self.F1 * self.D,
162
- (1, 16),
163
- stride=1,
226
+ (1, self.depthwise_kernel_length),
164
227
  bias=False,
165
228
  groups=self.F1 * self.D,
166
- padding=(0, 16 // 2),
229
+ padding=(0, self.depthwise_kernel_length // 2),
167
230
  ),
168
231
  )
169
232
  self.add_module(
@@ -171,19 +234,27 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
171
234
  nn.Conv2d(
172
235
  self.F1 * self.D,
173
236
  self.F2,
174
- (1, 1),
175
- stride=1,
237
+ kernel_size=(1, 1),
176
238
  bias=False,
177
- padding=(0, 0),
178
239
  ),
179
240
  )
180
241
 
181
242
  self.add_module(
182
243
  "bnorm_2",
183
- nn.BatchNorm2d(self.F2, momentum=0.01, affine=True, eps=1e-3),
244
+ nn.BatchNorm2d(
245
+ self.F2,
246
+ momentum=self.batch_norm_momentum,
247
+ affine=self.batch_norm_affine,
248
+ eps=self.batch_norm_eps,
249
+ ),
250
+ )
251
+ self.add_module("elu_2", self.activation())
252
+ self.add_module(
253
+ "pool_2",
254
+ pool_class(
255
+ kernel_size=(1, self.pool2_kernel_size),
256
+ ),
184
257
  )
185
- self.add_module("elu_2", Expression(elu))
186
- self.add_module("pool_2", pool_class(kernel_size=(1, 8), stride=(1, 8)))
187
258
  self.add_module("drop_2", nn.Dropout(p=self.drop_prob))
188
259
 
189
260
  output_shape = self.get_output_shape()
@@ -195,27 +266,42 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
195
266
 
196
267
  # Incorporating classification module and subsequent ones in one final layer
197
268
  module = nn.Sequential()
198
-
199
- module.add_module("conv_classifier",
200
- nn.Conv2d(self.F2, self.n_outputs,
201
- (n_out_virtual_chans, self.final_conv_length), bias=True, ))
202
-
203
- if self.add_log_softmax:
204
- module.add_module("logsoftmax", nn.LogSoftmax(dim=1))
205
-
206
- # Transpose back to the logic of braindecode,
207
- # so time in third dimension (axis=2)
208
- module.add_module("permute_back", Rearrange("batch x y z -> batch x z y"), )
209
-
210
- module.add_module("squeeze", Expression(squeeze_final_output))
211
-
269
+ if not final_layer_with_constraint:
270
+ module.add_module(
271
+ "conv_classifier",
272
+ nn.Conv2d(
273
+ self.F2,
274
+ self.n_outputs,
275
+ (n_out_virtual_chans, self.final_conv_length),
276
+ bias=True,
277
+ ),
278
+ )
279
+
280
+ # Transpose back to the logic of braindecode,
281
+ # so time in third dimension (axis=2)
282
+ module.add_module(
283
+ "permute_back",
284
+ Rearrange("batch x y z -> batch x z y"),
285
+ )
286
+
287
+ module.add_module("squeeze", SqueezeFinalOutput())
288
+ else:
289
+ module.add_module("flatten", nn.Flatten())
290
+ module.add_module(
291
+ "linearconstraint",
292
+ LinearWithConstraint(
293
+ in_features=self.F2 * self.final_conv_length,
294
+ out_features=self.n_outputs,
295
+ max_norm=norm_rate,
296
+ ),
297
+ )
212
298
  self.add_module("final_layer", module)
213
299
 
214
- _glorot_weight_zero_bias(self)
300
+ glorot_weight_zero_bias(self)
215
301
 
216
302
 
217
303
  class EEGNetv1(EEGModuleMixin, nn.Sequential):
218
- """EEGNet model from Lawhern et al. 2016.
304
+ """EEGNet model from Lawhern et al. 2016 from [EEGNet]_.
219
305
 
220
306
  See details in [EEGNet]_.
221
307
 
@@ -227,6 +313,9 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
227
313
  Alias for n_outputs.
228
314
  input_window_samples :
229
315
  Alias for n_times.
316
+ activation: nn.Module, default=nn.ELU
317
+ Activation function class to apply. Should be a PyTorch activation
318
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
230
319
 
231
320
  Notes
232
321
  -----
@@ -243,29 +332,20 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
243
332
  """
244
333
 
245
334
  def __init__(
246
- self,
247
- n_chans=None,
248
- n_outputs=None,
249
- n_times=None,
250
- final_conv_length="auto",
251
- pool_mode="max",
252
- second_kernel_size=(2, 32),
253
- third_kernel_size=(8, 4),
254
- drop_prob=0.25,
255
- chs_info=None,
256
- input_window_seconds=None,
257
- sfreq=None,
258
- in_chans=None,
259
- n_classes=None,
260
- input_window_samples=None,
261
- add_log_softmax=True,
335
+ self,
336
+ n_chans=None,
337
+ n_outputs=None,
338
+ n_times=None,
339
+ final_conv_length="auto",
340
+ pool_mode="max",
341
+ second_kernel_size=(2, 32),
342
+ third_kernel_size=(8, 4),
343
+ drop_prob=0.25,
344
+ activation: nn.Module = nn.ELU,
345
+ chs_info=None,
346
+ input_window_seconds=None,
347
+ sfreq=None,
262
348
  ):
263
- n_chans, n_outputs, n_times = deprecated_args(
264
- self,
265
- ("in_chans", "n_chans", in_chans, n_chans),
266
- ("n_classes", "n_outputs", n_classes, n_outputs),
267
- ("input_window_samples", "n_times", input_window_samples, n_times),
268
- )
269
349
  super().__init__(
270
350
  n_outputs=n_outputs,
271
351
  n_chans=n_chans,
@@ -273,10 +353,14 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
273
353
  n_times=n_times,
274
354
  input_window_seconds=input_window_seconds,
275
355
  sfreq=sfreq,
276
- add_log_softmax=add_log_softmax,
277
356
  )
278
357
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
279
- del in_chans, n_classes, input_window_samples
358
+ warn(
359
+ "The class EEGNetv1 is deprecated and will be removed in the "
360
+ "release 1.0 of braindecode. Please use "
361
+ "braindecode.models.EEGNetv4 instead in the future.",
362
+ DeprecationWarning,
363
+ )
280
364
  if final_conv_length == "auto":
281
365
  assert self.n_times is not None
282
366
  self.final_conv_length = final_conv_length
@@ -289,7 +373,7 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
289
373
  # add the old's parameters here
290
374
  self.mapping = {
291
375
  "conv_classifier.weight": "final_layer.conv_classifier.weight",
292
- "conv_classifier.bias": "final_layer.conv_classifier.bias"
376
+ "conv_classifier.bias": "final_layer.conv_classifier.bias",
293
377
  }
294
378
 
295
379
  pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
@@ -303,11 +387,9 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
303
387
  "bnorm_1",
304
388
  nn.BatchNorm2d(n_filters_1, momentum=0.01, affine=True, eps=1e-3),
305
389
  )
306
- self.add_module("elu_1", Expression(elu))
390
+ self.add_module("elu_1", activation())
307
391
  # transpose to examples x 1 x (virtual, not EEG) channels x time
308
- self.add_module(
309
- "permute_1", Expression(lambda x: x.permute(0, 3, 1, 2))
310
- )
392
+ self.add_module("permute_1", Rearrange("batch x y z -> batch z x y"))
311
393
 
312
394
  self.add_module("drop_1", nn.Dropout(p=self.drop_prob))
313
395
 
@@ -332,7 +414,7 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
332
414
  "bnorm_2",
333
415
  nn.BatchNorm2d(n_filters_2, momentum=0.01, affine=True, eps=1e-3),
334
416
  )
335
- self.add_module("elu_2", Expression(elu))
417
+ self.add_module("elu_2", activation())
336
418
  self.add_module("pool_2", pool_class(kernel_size=(2, 4), stride=(2, 4)))
337
419
  self.add_module("drop_2", nn.Dropout(p=self.drop_prob))
338
420
 
@@ -352,7 +434,7 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
352
434
  "bnorm_3",
353
435
  nn.BatchNorm2d(n_filters_3, momentum=0.01, affine=True, eps=1e-3),
354
436
  )
355
- self.add_module("elu_3", Expression(elu))
437
+ self.add_module("elu_3", activation())
356
438
  self.add_module("pool_3", pool_class(kernel_size=(2, 4), stride=(2, 4)))
357
439
  self.add_module("drop_3", nn.Dropout(p=self.drop_prob))
358
440
 
@@ -366,40 +448,26 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
366
448
  # Incorporating classification module and subsequent ones in one final layer
367
449
  module = nn.Sequential()
368
450
 
369
- module.add_module("conv_classifier",
370
- nn.Conv2d(n_filters_3, self.n_outputs,
371
- (n_out_virtual_chans, self.final_conv_length), bias=True, ))
451
+ module.add_module(
452
+ "conv_classifier",
453
+ nn.Conv2d(
454
+ n_filters_3,
455
+ self.n_outputs,
456
+ (n_out_virtual_chans, self.final_conv_length),
457
+ bias=True,
458
+ ),
459
+ )
372
460
 
373
- if self.add_log_softmax:
374
- module.add_module("softmax", nn.LogSoftmax(dim=1))
375
461
  # Transpose back to the logic of braindecode,
376
462
 
377
463
  # so time in third dimension (axis=2)
378
- module.add_module("permute_2", Rearrange("batch x y z -> batch x z y"), )
464
+ module.add_module(
465
+ "permute_2",
466
+ Rearrange("batch x y z -> batch x z y"),
467
+ )
379
468
 
380
- module.add_module("squeeze", Expression(squeeze_final_output))
469
+ module.add_module("squeeze", SqueezeFinalOutput())
381
470
 
382
471
  self.add_module("final_layer", module)
383
472
 
384
- _glorot_weight_zero_bias(self)
385
-
386
-
387
- def _glorot_weight_zero_bias(model):
388
- """Initialize parameters of all modules by initializing weights with
389
- glorot
390
- uniform/xavier initialization, and setting biases to zero. Weights from
391
- batch norm layers are set to 1.
392
-
393
- Parameters
394
- ----------
395
- model: Module
396
- """
397
- for module in model.modules():
398
- if hasattr(module, "weight"):
399
- if "BatchNorm" not in module.__class__.__name__:
400
- nn.init.xavier_uniform_(module.weight, gain=1)
401
- else:
402
- nn.init.constant_(module.weight, 1)
403
- if hasattr(module, "bias"):
404
- if module.bias is not None:
405
- nn.init.constant_(module.bias, 0)
473
+ glorot_weight_zero_bias(self)