braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108) hide show
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,85 +1,110 @@
1
1
  # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
2
  #
3
3
  # License: BSD (3-clause)
4
+ from __future__ import annotations
5
+
6
+ from typing import Dict, Optional
4
7
 
5
- import torch
6
8
  from einops.layers.torch import Rearrange
9
+ from mne.utils import warn
7
10
  from torch import nn
8
- from torch.nn.functional import elu
9
-
10
- from .base import EEGModuleMixin, deprecated_args
11
- from .functions import squeeze_final_output
12
- from .modules import Ensure4d, Expression
13
-
14
11
 
15
- class Conv2dWithConstraint(nn.Conv2d):
16
- def __init__(self, *args, max_norm=1, **kwargs):
17
- self.max_norm = max_norm
18
- super(Conv2dWithConstraint, self).__init__(*args, **kwargs)
19
-
20
- def forward(self, x):
21
- self.weight.data = torch.renorm(
22
- self.weight.data, p=2, dim=0, maxnorm=self.max_norm
23
- )
24
- return super(Conv2dWithConstraint, self).forward(x)
12
+ from braindecode.functional import glorot_weight_zero_bias
13
+ from braindecode.models.base import EEGModuleMixin
14
+ from braindecode.modules import (
15
+ Conv2dWithConstraint,
16
+ Ensure4d,
17
+ LinearWithConstraint,
18
+ SqueezeFinalOutput,
19
+ )
25
20
 
26
21
 
27
22
  class EEGNetv4(EEGModuleMixin, nn.Sequential):
28
- """EEGNet v4 model from Lawhern et al 2018.
23
+ """EEGNet v4 model from Lawhern et al. (2018) [EEGNet4]_.
24
+
25
+ .. figure:: https://content.cld.iop.org/journals/1741-2552/15/5/056013/revision2/jneaace8cf01_hr.jpg
26
+ :align: center
27
+ :alt: EEGNet4 Architecture
29
28
 
30
29
  See details in [EEGNet4]_.
31
30
 
32
31
  Parameters
33
32
  ----------
34
- final_conv_length : int | "auto"
35
- If int, final length of convolutional filters.
36
- in_chans :
37
- Alias for n_chans.
38
- n_classes:
39
- Alias for n_outputs.
40
- input_window_samples :
41
- Alias for n_times.
42
-
43
- Notes
44
- -----
45
- This implementation is not guaranteed to be correct, has not been checked
46
- by original authors, only reimplemented from the paper description.
33
+ final_conv_length : int or "auto", default="auto"
34
+ Length of the final convolution layer. If "auto", it is set based on n_times.
35
+ pool_mode : {"mean", "max"}, default="mean"
36
+ Pooling method to use in pooling layers.
37
+ F1 : int, default=8
38
+ Number of temporal filters in the first convolutional layer.
39
+ D : int, default=2
40
+ Depth multiplier for the depthwise convolution.
41
+ F2 : int or None, default=None
42
+ Number of pointwise filters in the separable convolution. Usually set to ``F1 * D``.
43
+ depthwise_kernel_length : int, default=16
44
+ Length of the depthwise convolution kernel in the separable convolution.
45
+ pool1_kernel_size : int, default=4
46
+ Kernel size of the first pooling layer.
47
+ pool2_kernel_size : int, default=8
48
+ Kernel size of the second pooling layer.
49
+ kernel_length : int, default=64
50
+ Length of the temporal convolution kernel.
51
+ conv_spatial_max_norm : float, default=1
52
+ Maximum norm constraint for the spatial (depthwise) convolution.
53
+ activation : nn.Module, default=nn.ELU
54
+ Non-linear activation function to be used in the layers.
55
+ batch_norm_momentum : float, default=0.01
56
+ Momentum for the running mean/variance estimates in the batch norm layers.
57
+ batch_norm_affine : bool, default=True
58
+ If True, batch norm has learnable affine parameters.
59
+ batch_norm_eps : float, default=1e-3
60
+ Epsilon for numeric stability in batch norm layers.
61
+ drop_prob : float, default=0.25
62
+ Dropout probability.
63
+ final_layer_with_constraint : bool, default=False
64
+ If ``False``, uses a convolution-based classification layer. If ``True``,
65
+ applies a flattened linear layer with a max-norm constraint on its weights as the final classification step.
66
+ norm_rate : float, default=0.25
67
+ Max-norm constraint value for the linear layer (used if ``final_layer_with_constraint=True``).
47
68
 
48
69
  References
49
70
  ----------
50
- .. [EEGNet4] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon,
51
- S. M., Hung, C. P., & Lance, B. J. (2018).
52
- EEGNet: A Compact Convolutional Network for EEG-based
53
- Brain-Computer Interfaces.
54
- arXiv preprint arXiv:1611.08024.
71
+ .. [EEGNet4] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon, S. M.,
72
+ Hung, C. P., & Lance, B. J. (2018). EEGNet: a compact convolutional
73
+ neural network for EEG-based brain–computer interfaces. Journal of
74
+ neural engineering, 15(5), 056013.
55
75
  """
56
76
 
57
77
  def __init__(
58
- self,
59
- n_chans=None,
60
- n_outputs=None,
61
- n_times=None,
62
- final_conv_length="auto",
63
- pool_mode="mean",
64
- F1=8,
65
- D=2,
66
- F2=16, # usually set to F1*D (?)
67
- kernel_length=64,
68
- third_kernel_size=(8, 4),
69
- drop_prob=0.25,
70
- chs_info=None,
71
- input_window_seconds=None,
72
- sfreq=None,
73
- in_chans=None,
74
- n_classes=None,
75
- input_window_samples=None,
78
+ self,
79
+ # signal's parameters
80
+ n_chans: Optional[int] = None,
81
+ n_outputs: Optional[int] = None,
82
+ n_times: Optional[int] = None,
83
+ # model's parameters
84
+ final_conv_length: str | int = "auto",
85
+ pool_mode: str = "mean",
86
+ F1: int = 8,
87
+ D: int = 2,
88
+ F2: Optional[int | None] = None,
89
+ kernel_length: int = 64,
90
+ *,
91
+ depthwise_kernel_length: int = 16,
92
+ pool1_kernel_size: int = 4,
93
+ pool2_kernel_size: int = 8,
94
+ conv_spatial_max_norm: int = 1,
95
+ activation: nn.Module = nn.ELU,
96
+ batch_norm_momentum: float = 0.01,
97
+ batch_norm_affine: bool = True,
98
+ batch_norm_eps: float = 1e-3,
99
+ drop_prob: float = 0.25,
100
+ final_layer_with_constraint: bool = False,
101
+ norm_rate: float = 0.25,
102
+ # Other ways to construct the signal related parameters
103
+ chs_info: Optional[list[Dict]] = None,
104
+ input_window_seconds=None,
105
+ sfreq=None,
106
+ **kwargs,
76
107
  ):
77
- n_chans, n_outputs, n_times = deprecated_args(
78
- self,
79
- ("in_chans", "n_chans", in_chans, n_chans),
80
- ("n_classes", "n_outputs", n_classes, n_outputs),
81
- ("input_window_samples", "n_times", input_window_samples, n_times),
82
- )
83
108
  super().__init__(
84
109
  n_outputs=n_outputs,
85
110
  n_chans=n_chans,
@@ -89,68 +114,106 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
89
114
  sfreq=sfreq,
90
115
  )
91
116
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
92
- del in_chans, n_classes, input_window_samples
93
117
  if final_conv_length == "auto":
94
118
  assert self.n_times is not None
119
+
120
+ if not final_layer_with_constraint:
121
+ warn(
122
+ "Parameter 'final_layer_with_constraint=False' is deprecated and will be "
123
+ "removed in a future release. Please use `final_layer_linear=True`.",
124
+ DeprecationWarning,
125
+ )
126
+
127
+ if "third_kernel_size" in kwargs:
128
+ warn(
129
+ "The parameter `third_kernel_size` is deprecated "
130
+ "and will be removed in a future version.",
131
+ )
132
+ unexpected_kwargs = set(kwargs) - {"third_kernel_size"}
133
+ if unexpected_kwargs:
134
+ raise TypeError(f"Unexpected keyword arguments: {unexpected_kwargs}")
135
+
95
136
  self.final_conv_length = final_conv_length
96
137
  self.pool_mode = pool_mode
97
138
  self.F1 = F1
98
139
  self.D = D
140
+
141
+ if F2 is None:
142
+ F2 = self.F1 * self.D
99
143
  self.F2 = F2
144
+
100
145
  self.kernel_length = kernel_length
101
- self.third_kernel_size = third_kernel_size
146
+ self.depthwise_kernel_length = depthwise_kernel_length
147
+ self.pool1_kernel_size = pool1_kernel_size
148
+ self.pool2_kernel_size = pool2_kernel_size
102
149
  self.drop_prob = drop_prob
150
+ self.activation = activation
151
+ self.batch_norm_momentum = batch_norm_momentum
152
+ self.batch_norm_affine = batch_norm_affine
153
+ self.batch_norm_eps = batch_norm_eps
154
+ self.conv_spatial_max_norm = conv_spatial_max_norm
155
+ self.norm_rate = norm_rate
156
+
103
157
  # For the load_state_dict
104
158
  # When standardizing all layers,
105
159
  # add the old's parameters here
106
160
  self.mapping = {
107
161
  "conv_classifier.weight": "final_layer.conv_classifier.weight",
108
- "conv_classifier.bias": "final_layer.conv_classifier.bias"
162
+ "conv_classifier.bias": "final_layer.conv_classifier.bias",
109
163
  }
110
164
 
111
165
  pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
112
166
  self.add_module("ensuredims", Ensure4d())
113
167
 
114
- self.add_module("dimshuffle",
115
- Rearrange("batch ch t 1 -> batch 1 ch t"))
168
+ self.add_module("dimshuffle", Rearrange("batch ch t 1 -> batch 1 ch t"))
116
169
  self.add_module(
117
170
  "conv_temporal",
118
171
  nn.Conv2d(
119
172
  1,
120
173
  self.F1,
121
174
  (1, self.kernel_length),
122
- stride=1,
123
175
  bias=False,
124
176
  padding=(0, self.kernel_length // 2),
125
177
  ),
126
178
  )
127
179
  self.add_module(
128
180
  "bnorm_temporal",
129
- nn.BatchNorm2d(self.F1, momentum=0.01, affine=True, eps=1e-3),
181
+ nn.BatchNorm2d(
182
+ self.F1,
183
+ momentum=self.batch_norm_momentum,
184
+ affine=self.batch_norm_affine,
185
+ eps=self.batch_norm_eps,
186
+ ),
130
187
  )
131
188
  self.add_module(
132
189
  "conv_spatial",
133
190
  Conv2dWithConstraint(
134
- self.F1,
135
- self.F1 * self.D,
136
- (self.n_chans, 1),
137
- max_norm=1,
138
- stride=1,
191
+ in_channels=self.F1,
192
+ out_channels=self.F1 * self.D,
193
+ kernel_size=(self.n_chans, 1),
194
+ max_norm=self.conv_spatial_max_norm,
139
195
  bias=False,
140
196
  groups=self.F1,
141
- padding=(0, 0),
142
197
  ),
143
198
  )
144
199
 
145
200
  self.add_module(
146
201
  "bnorm_1",
147
202
  nn.BatchNorm2d(
148
- self.F1 * self.D, momentum=0.01, affine=True, eps=1e-3
203
+ self.F1 * self.D,
204
+ momentum=self.batch_norm_momentum,
205
+ affine=self.batch_norm_affine,
206
+ eps=self.batch_norm_eps,
149
207
  ),
150
208
  )
151
- self.add_module("elu_1", Expression(elu))
209
+ self.add_module("elu_1", activation())
152
210
 
153
- self.add_module("pool_1", pool_class(kernel_size=(1, 4), stride=(1, 4)))
211
+ self.add_module(
212
+ "pool_1",
213
+ pool_class(
214
+ kernel_size=(1, self.pool1_kernel_size),
215
+ ),
216
+ )
154
217
  self.add_module("drop_1", nn.Dropout(p=self.drop_prob))
155
218
 
156
219
  # https://discuss.pytorch.org/t/how-to-modify-a-conv2d-to-depthwise-separable-convolution/15843/7
@@ -159,11 +222,10 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
159
222
  nn.Conv2d(
160
223
  self.F1 * self.D,
161
224
  self.F1 * self.D,
162
- (1, 16),
163
- stride=1,
225
+ (1, self.depthwise_kernel_length),
164
226
  bias=False,
165
227
  groups=self.F1 * self.D,
166
- padding=(0, 16 // 2),
228
+ padding=(0, self.depthwise_kernel_length // 2),
167
229
  ),
168
230
  )
169
231
  self.add_module(
@@ -171,19 +233,27 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
171
233
  nn.Conv2d(
172
234
  self.F1 * self.D,
173
235
  self.F2,
174
- (1, 1),
175
- stride=1,
236
+ kernel_size=(1, 1),
176
237
  bias=False,
177
- padding=(0, 0),
178
238
  ),
179
239
  )
180
240
 
181
241
  self.add_module(
182
242
  "bnorm_2",
183
- nn.BatchNorm2d(self.F2, momentum=0.01, affine=True, eps=1e-3),
243
+ nn.BatchNorm2d(
244
+ self.F2,
245
+ momentum=self.batch_norm_momentum,
246
+ affine=self.batch_norm_affine,
247
+ eps=self.batch_norm_eps,
248
+ ),
249
+ )
250
+ self.add_module("elu_2", self.activation())
251
+ self.add_module(
252
+ "pool_2",
253
+ pool_class(
254
+ kernel_size=(1, self.pool2_kernel_size),
255
+ ),
184
256
  )
185
- self.add_module("elu_2", Expression(elu))
186
- self.add_module("pool_2", pool_class(kernel_size=(1, 8), stride=(1, 8)))
187
257
  self.add_module("drop_2", nn.Dropout(p=self.drop_prob))
188
258
 
189
259
  output_shape = self.get_output_shape()
@@ -195,27 +265,42 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
195
265
 
196
266
  # Incorporating classification module and subsequent ones in one final layer
197
267
  module = nn.Sequential()
198
-
199
- module.add_module("conv_classifier",
200
- nn.Conv2d(self.F2, self.n_outputs,
201
- (n_out_virtual_chans, self.final_conv_length), bias=True, ))
202
-
203
- if self.add_log_softmax:
204
- module.add_module("logsoftmax", nn.LogSoftmax(dim=1))
205
-
206
- # Transpose back to the logic of braindecode,
207
- # so time in third dimension (axis=2)
208
- module.add_module("permute_back", Rearrange("batch x y z -> batch x z y"), )
209
-
210
- module.add_module("squeeze", Expression(squeeze_final_output))
211
-
268
+ if not final_layer_with_constraint:
269
+ module.add_module(
270
+ "conv_classifier",
271
+ nn.Conv2d(
272
+ self.F2,
273
+ self.n_outputs,
274
+ (n_out_virtual_chans, self.final_conv_length),
275
+ bias=True,
276
+ ),
277
+ )
278
+
279
+ # Transpose back to the logic of braindecode,
280
+ # so time in third dimension (axis=2)
281
+ module.add_module(
282
+ "permute_back",
283
+ Rearrange("batch x y z -> batch x z y"),
284
+ )
285
+
286
+ module.add_module("squeeze", SqueezeFinalOutput())
287
+ else:
288
+ module.add_module("flatten", nn.Flatten())
289
+ module.add_module(
290
+ "linearconstraint",
291
+ LinearWithConstraint(
292
+ in_features=self.F2 * self.final_conv_length,
293
+ out_features=self.n_outputs,
294
+ max_norm=norm_rate,
295
+ ),
296
+ )
212
297
  self.add_module("final_layer", module)
213
298
 
214
- _glorot_weight_zero_bias(self)
299
+ glorot_weight_zero_bias(self)
215
300
 
216
301
 
217
302
  class EEGNetv1(EEGModuleMixin, nn.Sequential):
218
- """EEGNet model from Lawhern et al. 2016.
303
+ """EEGNet model from Lawhern et al. 2016 from [EEGNet]_.
219
304
 
220
305
  See details in [EEGNet]_.
221
306
 
@@ -227,6 +312,9 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
227
312
  Alias for n_outputs.
228
313
  input_window_samples :
229
314
  Alias for n_times.
315
+ activation: nn.Module, default=nn.ELU
316
+ Activation function class to apply. Should be a PyTorch activation
317
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
230
318
 
231
319
  Notes
232
320
  -----
@@ -243,29 +331,20 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
243
331
  """
244
332
 
245
333
  def __init__(
246
- self,
247
- n_chans=None,
248
- n_outputs=None,
249
- n_times=None,
250
- final_conv_length="auto",
251
- pool_mode="max",
252
- second_kernel_size=(2, 32),
253
- third_kernel_size=(8, 4),
254
- drop_prob=0.25,
255
- chs_info=None,
256
- input_window_seconds=None,
257
- sfreq=None,
258
- in_chans=None,
259
- n_classes=None,
260
- input_window_samples=None,
261
- add_log_softmax=True,
334
+ self,
335
+ n_chans=None,
336
+ n_outputs=None,
337
+ n_times=None,
338
+ final_conv_length="auto",
339
+ pool_mode="max",
340
+ second_kernel_size=(2, 32),
341
+ third_kernel_size=(8, 4),
342
+ drop_prob=0.25,
343
+ activation: nn.Module = nn.ELU,
344
+ chs_info=None,
345
+ input_window_seconds=None,
346
+ sfreq=None,
262
347
  ):
263
- n_chans, n_outputs, n_times = deprecated_args(
264
- self,
265
- ("in_chans", "n_chans", in_chans, n_chans),
266
- ("n_classes", "n_outputs", n_classes, n_outputs),
267
- ("input_window_samples", "n_times", input_window_samples, n_times),
268
- )
269
348
  super().__init__(
270
349
  n_outputs=n_outputs,
271
350
  n_chans=n_chans,
@@ -273,10 +352,14 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
273
352
  n_times=n_times,
274
353
  input_window_seconds=input_window_seconds,
275
354
  sfreq=sfreq,
276
- add_log_softmax=add_log_softmax,
277
355
  )
278
356
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
279
- del in_chans, n_classes, input_window_samples
357
+ warn(
358
+ "The class EEGNetv1 is deprecated and will be removed in the "
359
+ "release 1.0 of braindecode. Please use "
360
+ "braindecode.models.EEGNetv4 instead in the future.",
361
+ DeprecationWarning,
362
+ )
280
363
  if final_conv_length == "auto":
281
364
  assert self.n_times is not None
282
365
  self.final_conv_length = final_conv_length
@@ -289,7 +372,7 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
289
372
  # add the old's parameters here
290
373
  self.mapping = {
291
374
  "conv_classifier.weight": "final_layer.conv_classifier.weight",
292
- "conv_classifier.bias": "final_layer.conv_classifier.bias"
375
+ "conv_classifier.bias": "final_layer.conv_classifier.bias",
293
376
  }
294
377
 
295
378
  pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
@@ -303,11 +386,9 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
303
386
  "bnorm_1",
304
387
  nn.BatchNorm2d(n_filters_1, momentum=0.01, affine=True, eps=1e-3),
305
388
  )
306
- self.add_module("elu_1", Expression(elu))
389
+ self.add_module("elu_1", activation())
307
390
  # transpose to examples x 1 x (virtual, not EEG) channels x time
308
- self.add_module(
309
- "permute_1", Expression(lambda x: x.permute(0, 3, 1, 2))
310
- )
391
+ self.add_module("permute_1", Rearrange("batch x y z -> batch z x y"))
311
392
 
312
393
  self.add_module("drop_1", nn.Dropout(p=self.drop_prob))
313
394
 
@@ -332,7 +413,7 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
332
413
  "bnorm_2",
333
414
  nn.BatchNorm2d(n_filters_2, momentum=0.01, affine=True, eps=1e-3),
334
415
  )
335
- self.add_module("elu_2", Expression(elu))
416
+ self.add_module("elu_2", activation())
336
417
  self.add_module("pool_2", pool_class(kernel_size=(2, 4), stride=(2, 4)))
337
418
  self.add_module("drop_2", nn.Dropout(p=self.drop_prob))
338
419
 
@@ -352,7 +433,7 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
352
433
  "bnorm_3",
353
434
  nn.BatchNorm2d(n_filters_3, momentum=0.01, affine=True, eps=1e-3),
354
435
  )
355
- self.add_module("elu_3", Expression(elu))
436
+ self.add_module("elu_3", activation())
356
437
  self.add_module("pool_3", pool_class(kernel_size=(2, 4), stride=(2, 4)))
357
438
  self.add_module("drop_3", nn.Dropout(p=self.drop_prob))
358
439
 
@@ -366,40 +447,26 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
366
447
  # Incorporating classification module and subsequent ones in one final layer
367
448
  module = nn.Sequential()
368
449
 
369
- module.add_module("conv_classifier",
370
- nn.Conv2d(n_filters_3, self.n_outputs,
371
- (n_out_virtual_chans, self.final_conv_length), bias=True, ))
450
+ module.add_module(
451
+ "conv_classifier",
452
+ nn.Conv2d(
453
+ n_filters_3,
454
+ self.n_outputs,
455
+ (n_out_virtual_chans, self.final_conv_length),
456
+ bias=True,
457
+ ),
458
+ )
372
459
 
373
- if self.add_log_softmax:
374
- module.add_module("softmax", nn.LogSoftmax(dim=1))
375
460
  # Transpose back to the logic of braindecode,
376
461
 
377
462
  # so time in third dimension (axis=2)
378
- module.add_module("permute_2", Rearrange("batch x y z -> batch x z y"), )
463
+ module.add_module(
464
+ "permute_2",
465
+ Rearrange("batch x y z -> batch x z y"),
466
+ )
379
467
 
380
- module.add_module("squeeze", Expression(squeeze_final_output))
468
+ module.add_module("squeeze", SqueezeFinalOutput())
381
469
 
382
470
  self.add_module("final_layer", module)
383
471
 
384
- _glorot_weight_zero_bias(self)
385
-
386
-
387
- def _glorot_weight_zero_bias(model):
388
- """Initialize parameters of all modules by initializing weights with
389
- glorot
390
- uniform/xavier initialization, and setting biases to zero. Weights from
391
- batch norm layers are set to 1.
392
-
393
- Parameters
394
- ----------
395
- model: Module
396
- """
397
- for module in model.modules():
398
- if hasattr(module, "weight"):
399
- if "BatchNorm" not in module.__class__.__name__:
400
- nn.init.xavier_uniform_(module.weight, gain=1)
401
- else:
402
- nn.init.constant_(module.weight, 1)
403
- if hasattr(module, "bias"):
404
- if module.bias is not None:
405
- nn.init.constant_(module.bias, 0)
472
+ glorot_weight_zero_bias(self)