braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the changes between the publicly released versions of the package, as published to a supported registry, and is provided for informational purposes only.


Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
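
The deep4.py diff below illustrates the release's main breaking changes: braindecode/models/modules.py and braindecode/models/functions.py are replaced by the new braindecode/modules and braindecode/functional packages, activation arguments now take nn.Module classes instead of plain callables, and the deprecated in_chans, n_classes, and input_window_samples aliases are removed along with add_log_softmax. A minimal migration sketch based only on the renames visible in this diff (the parameter values are illustrative):

    from torch import nn
    from braindecode.models import Deep4Net

    # 0.8.1 style, no longer accepted in 1.1.0: deprecated aliases and
    # callable activations
    # from torch.nn.functional import elu
    # model = Deep4Net(in_chans=22, n_classes=4, input_window_samples=1000,
    #                  first_conv_nonlin=elu, add_log_softmax=True)

    # 1.1.0 style: n_* names only, activations passed as nn.Module classes,
    # and no optional LogSoftmax layer (the model now returns logits)
    model = Deep4Net(
        n_chans=22,
        n_outputs=4,
        n_times=1000,
        activation_first_conv_nonlin=nn.ELU,
        activation_later_conv_nonlin=nn.ELU,
    )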
braindecode/models/deep4.py

@@ -3,17 +3,25 @@
 # License: BSD (3-clause)

 from einops.layers.torch import Rearrange
+from mne.utils import warn
 from torch import nn
 from torch.nn import init
-from torch.nn.functional import elu

-from .base import EEGModuleMixin, deprecated_args
-from .functions import identity, squeeze_final_output
-from .modules import AvgPool2dWithConv, CombinedConv, Ensure4d, Expression
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import (
+    AvgPool2dWithConv,
+    CombinedConv,
+    Ensure4d,
+    SqueezeFinalOutput,
+)


 class Deep4Net(EEGModuleMixin, nn.Sequential):
-    """Deep ConvNet model from Schirrmeister et al 2017.
+    """Deep ConvNet model from Schirrmeister et al. (2017) [Schirrmeister2017]_.
+
+    .. figure:: https://onlinelibrary.wiley.com/cms/asset/fc200ccc-d8c4-45b4-8577-56ce4d15999a/hbm23730-fig-0001-m.jpg
+       :align: center
+       :alt: Deep4Net Architecture

     Model described in [Schirrmeister2017]_.

@@ -44,13 +52,13 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         Number of temporal filters in layer 4.
     filter_length_4: int
         Length of the temporal filter in layer 4.
-    first_conv_nonlin: callable
+    activation_first_conv_nonlin: nn.Module, default is nn.ELU
         Non-linear activation function to be used after convolution in layer 1.
     first_pool_mode: str
         Pooling mode in layer 1. "max" or "mean".
     first_pool_nonlin: callable
         Non-linear activation function to be used after pooling in layer 1.
-    later_conv_nonlin: callable
+    activation_later_conv_nonlin: nn.Module, default is nn.ELU
         Non-linear activation function to be used after convolution in later layers.
     later_pool_mode: str
         Pooling mode in later layers. "max" or "mean".
@@ -67,12 +75,6 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         Momentum for BatchNorm2d.
     stride_before_pool: bool
         Stride before pooling.
-    in_chans :
-        Alias for n_chans.
-    n_classes:
-        Alias for n_outputs.
-    input_window_samples :
-        Alias for n_times.


     References
@@ -87,47 +89,38 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
     """

     def __init__(
-        self,
-        n_chans=None,
-        n_outputs=None,
-        n_times=None,
-        final_conv_length="auto",
-        n_filters_time=25,
-        n_filters_spat=25,
-        filter_time_length=10,
-        pool_time_length=3,
-        pool_time_stride=3,
-        n_filters_2=50,
-        filter_length_2=10,
-        n_filters_3=100,
-        filter_length_3=10,
-        n_filters_4=200,
-        filter_length_4=10,
-        first_conv_nonlin=elu,
-        first_pool_mode="max",
-        first_pool_nonlin=identity,
-        later_conv_nonlin=elu,
-        later_pool_mode="max",
-        later_pool_nonlin=identity,
-        drop_prob=0.5,
-        split_first_layer=True,
-        batch_norm=True,
-        batch_norm_alpha=0.1,
-        stride_before_pool=False,
-        chs_info=None,
-        input_window_seconds=None,
-        sfreq=None,
-        in_chans=None,
-        n_classes=None,
-        input_window_samples=None,
-        add_log_softmax=True,
+        self,
+        n_chans=None,
+        n_outputs=None,
+        n_times=None,
+        final_conv_length="auto",
+        n_filters_time=25,
+        n_filters_spat=25,
+        filter_time_length=10,
+        pool_time_length=3,
+        pool_time_stride=3,
+        n_filters_2=50,
+        filter_length_2=10,
+        n_filters_3=100,
+        filter_length_3=10,
+        n_filters_4=200,
+        filter_length_4=10,
+        activation_first_conv_nonlin: nn.Module = nn.ELU,
+        first_pool_mode="max",
+        first_pool_nonlin: nn.Module = nn.Identity,
+        activation_later_conv_nonlin: nn.Module = nn.ELU,
+        later_pool_mode="max",
+        later_pool_nonlin: nn.Module = nn.Identity,
+        drop_prob=0.5,
+        split_first_layer=True,
+        batch_norm=True,
+        batch_norm_alpha=0.1,
+        stride_before_pool=False,
+        # Braindecode EEGModuleMixin parameters
+        chs_info=None,
+        input_window_seconds=None,
+        sfreq=None,
     ):
-        n_chans, n_outputs, n_times = deprecated_args(
-            self,
-            ('in_chans', 'n_chans', in_chans, n_chans),
-            ('n_classes', 'n_outputs', n_classes, n_outputs),
-            ('input_window_samples', 'n_times', input_window_samples, n_times),
-        )
         super().__init__(
             n_outputs=n_outputs,
             n_chans=n_chans,
@@ -135,10 +128,9 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
             n_times=n_times,
             input_window_seconds=input_window_seconds,
             sfreq=sfreq,
-            add_log_softmax=add_log_softmax,
         )
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-        del in_chans, n_classes, input_window_samples
+
         if final_conv_length == "auto":
             assert self.n_times is not None
         self.final_conv_length = final_conv_length
@@ -153,10 +145,10 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         self.filter_length_3 = filter_length_3
         self.n_filters_4 = n_filters_4
         self.filter_length_4 = filter_length_4
-        self.first_nonlin = first_conv_nonlin
+        self.first_nonlin = activation_first_conv_nonlin
         self.first_pool_mode = first_pool_mode
         self.first_pool_nonlin = first_pool_nonlin
-        self.later_conv_nonlin = later_conv_nonlin
+        self.later_conv_nonlin = activation_later_conv_nonlin
         self.later_pool_mode = later_pool_mode
         self.later_pool_nonlin = later_pool_nonlin
         self.drop_prob = drop_prob
@@ -165,6 +157,27 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         self.batch_norm_alpha = batch_norm_alpha
         self.stride_before_pool = stride_before_pool

+        min_n_times = self._get_min_n_times()
+        if self.n_times < min_n_times:
+            # Calculate a scaling factor and apply it to all temporal
+            # kernel, pooling, and stride sizes
+            scaling_factor = self.n_times / min_n_times
+            warn(
+                f"n_times ({self.n_times}) is smaller than the minimum required "
+                f"({min_n_times}) for the current model parameter configuration. "
+                "Adjusting parameters to ensure compatibility. "
+                "Reducing the kernel, pooling, and stride sizes accordingly. "
+                "Scaling factor: {:.2f}".format(scaling_factor),
+                UserWarning,
+            )
+            self.filter_time_length = max(
+                1, int(self.filter_time_length * scaling_factor)
+            )
+            self.pool_time_length = max(1, int(self.pool_time_length * scaling_factor))
+            self.pool_time_stride = max(1, int(self.pool_time_stride * scaling_factor))
+            self.filter_length_2 = max(1, int(self.filter_length_2 * scaling_factor))
+            self.filter_length_3 = max(1, int(self.filter_length_3 * scaling_factor))
+            self.filter_length_4 = max(1, int(self.filter_length_4 * scaling_factor))
         # For the load_state_dict
         # When padronize all layers,
         # add the old's parameters here
@@ -174,7 +187,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
             "conv_time.bias": "conv_time_spat.conv_time.bias",
             "conv_spat.bias": "conv_time_spat.conv_spat.bias",
             "conv_classifier.weight": "final_layer.conv_classifier.weight",
-            "conv_classifier.bias": "final_layer.conv_classifier.bias"
+            "conv_classifier.bias": "final_layer.conv_classifier.bias",
         }

         if self.stride_before_pool:
@@ -223,17 +236,17 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
                 eps=1e-5,
             ),
         )
-        self.add_module("conv_nonlin", Expression(self.first_nonlin))
+        self.add_module("conv_nonlin", self.first_nonlin())
         self.add_module(
             "pool",
             first_pool_class(
                 kernel_size=(self.pool_time_length, 1), stride=(pool_stride, 1)
             ),
         )
-        self.add_module("pool_nonlin", Expression(self.first_pool_nonlin))
+        self.add_module("pool_nonlin", self.first_pool_nonlin())

         def add_conv_pool_block(
-            model, n_filters_before, n_filters, filter_length, block_nr
+            model, n_filters_before, n_filters, filter_length, block_nr
         ):
             suffix = "_{:d}".format(block_nr)
             self.add_module("drop" + suffix, nn.Dropout(p=self.drop_prob))
@@ -257,7 +270,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
                     eps=1e-5,
                 ),
             )
-            self.add_module("nonlin" + suffix, Expression(self.later_conv_nonlin))
+            self.add_module("nonlin" + suffix, self.later_conv_nonlin())

             self.add_module(
                 "pool" + suffix,
@@ -266,7 +279,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
                     stride=(pool_stride, 1),
                 ),
             )
-            self.add_module("pool_nonlin" + suffix, Expression(self.later_pool_nonlin))
+            self.add_module("pool_nonlin" + suffix, self.later_pool_nonlin())

         add_conv_pool_block(
             self, n_filters_conv, self.n_filters_2, self.filter_length_2, 2
@@ -278,7 +291,6 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
             self, self.n_filters_3, self.n_filters_4, self.filter_length_4, 4
         )

-        # self.add_module('drop_classifier', nn.Dropout(p=self.drop_prob))
         self.eval()
         if self.final_conv_length == "auto":
             self.final_conv_length = self.get_output_shape()[2]
@@ -286,17 +298,17 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         # Incorporating classification module and subsequent ones in one final layer
         module = nn.Sequential()

-        module.add_module("conv_classifier",
-                          nn.Conv2d(
-                              self.n_filters_4,
-                              self.n_outputs,
-                              (self.final_conv_length, 1),
-                              bias=True, ))
-
-        if self.add_log_softmax:
-            module.add_module("logsoftmax", nn.LogSoftmax(dim=1))
+        module.add_module(
+            "conv_classifier",
+            nn.Conv2d(
+                self.n_filters_4,
+                self.n_outputs,
+                (self.final_conv_length, 1),
+                bias=True,
+            ),
+        )

-        module.add_module("squeeze", Expression(squeeze_final_output))
+        module.add_module("squeeze", SqueezeFinalOutput())

         self.add_module("final_layer", module)

@@ -309,7 +321,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         if self.split_first_layer:
             init.xavier_uniform_(self.conv_time_spat.conv_spat.weight, gain=1)
             if not self.batch_norm:
-                init.constant_(self.conv_spat.bias, 0)
+                init.constant_(self.conv_time_spat.conv_spat.bias, 0)
         if self.batch_norm:
             init.constant_(self.bnorm.weight, 1)
             init.constant_(self.bnorm.bias, 0)
@@ -329,5 +341,32 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         init.xavier_uniform_(self.final_layer.conv_classifier.weight, gain=1)
         init.constant_(self.final_layer.conv_classifier.bias, 0)

-        # Start in eval mode
-        self.eval()
+        self.train()
+
+    def _get_min_n_times(self) -> int:
+        """
+        Calculate the minimum number of time samples required for the model
+        to work with the given temporal parameters.
+        """
+        # Start with the minimum valid output length of the network (1)
+        min_len = 1
+
+        # Conv kernel sizes and pool parameters of the 4 blocks, in reverse order
+        # Each tuple: (filter_length, pool_length, pool_stride)
+        block_params = [
+            (self.filter_length_4, self.pool_time_length, self.pool_time_stride),
+            (self.filter_length_3, self.pool_time_length, self.pool_time_stride),
+            (self.filter_length_2, self.pool_time_length, self.pool_time_stride),
+            (self.filter_time_length, self.pool_time_length, self.pool_time_stride),
+        ]
+
+        # Work backward from the last layer to the input
+        for filter_len, pool_len, pool_stride in block_params:
+            # Reverse the pooling operation:
+            # L_in = stride * (L_out - 1) + kernel_size
+            min_len = pool_stride * (min_len - 1) + pool_len
+            # Reverse the convolution operation (assuming stride=1):
+            # L_in = L_out + kernel_size - 1
+            min_len = min_len + filter_len - 1
+
+        return min_len
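
The new _get_min_n_times helper above inverts the usual output-length formulas, for pooling L_out = (L_in - kernel) / stride + 1 and for stride-1 convolution L_out = L_in - kernel + 1, walking backward from a single output sample to the input. A standalone sketch of the same computation with Deep4Net's default hyperparameters (all filter lengths 10, pool length 3, pool stride 3):

    def min_n_times(filter_lengths, pool_length, pool_stride):
        min_len = 1  # smallest valid output length of the network
        # walk backward from the output to the input
        for filter_len in reversed(filter_lengths):
            # invert the pooling: L_in = stride * (L_out - 1) + kernel_size
            min_len = pool_stride * (min_len - 1) + pool_length
            # invert the stride-1 convolution: L_in = L_out + kernel_size - 1
            min_len = min_len + filter_len - 1
        return min_len

    # filter_time_length=10 and filter_length_2/3/4=10 by default
    print(min_n_times([10, 10, 10, 10], pool_length=3, pool_stride=3))  # 441

So with the default configuration, any n_times below 441 samples triggers the rescaling warning introduced in this release.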
braindecode/models/deepsleepnet.py

@@ -4,11 +4,133 @@
 import torch
 import torch.nn as nn

-from .base import EEGModuleMixin, deprecated_args
+from braindecode.models.base import EEGModuleMixin


-class _SmallCNN(nn.Module):  # smaller filter sizes to learn temporal information
-    def __init__(self):
+class DeepSleepNet(EEGModuleMixin, nn.Module):
+    """Sleep staging architecture from Supratak et al. (2017) [Supratak2017]_.
+
+    .. figure:: https://raw.githubusercontent.com/akaraspt/deepsleepnet/refs/heads/master/img/deepsleepnet.png
+       :align: center
+       :alt: DeepSleepNet Architecture
+
+    Convolutional neural network and bidirectional Long Short-Term Memory
+    network for single-channel sleep staging described in [Supratak2017]_.
+
+    Parameters
+    ----------
+    activation_large: nn.Module, default=nn.ELU
+        Activation function class to apply in the large-filter branch. Should
+        be a PyTorch activation module class like ``nn.ReLU`` or ``nn.ELU``.
+    activation_small: nn.Module, default=nn.ReLU
+        Activation function class to apply in the small-filter branch. Should
+        be a PyTorch activation module class like ``nn.ReLU`` or ``nn.ELU``.
+    return_feats : bool
+        If True, return the features, i.e. the output of the feature extractor
+        (before the final linear layer). If False, pass the features through
+        the final linear layer.
+    drop_prob : float, default=0.5
+        The dropout rate for regularization. Values should be between 0 and 1.
+
+
+    References
+    ----------
+    .. [Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017).
+       DeepSleepNet: A model for automatic sleep stage scoring based
+       on raw single-channel EEG. IEEE Transactions on Neural Systems
+       and Rehabilitation Engineering, 25(11), 1998-2008.
+    """
+
+    def __init__(
+        self,
+        n_outputs=5,
+        return_feats=False,
+        n_chans=None,
+        chs_info=None,
+        n_times=None,
+        input_window_seconds=None,
+        sfreq=None,
+        activation_large: nn.Module = nn.ELU,
+        activation_small: nn.Module = nn.ReLU,
+        drop_prob: float = 0.5,
+    ):
+        super().__init__(
+            n_outputs=n_outputs,
+            n_chans=n_chans,
+            chs_info=chs_info,
+            n_times=n_times,
+            input_window_seconds=input_window_seconds,
+            sfreq=sfreq,
+        )
+        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+        self.cnn1 = _SmallCNN(activation=activation_small, drop_prob=drop_prob)
+        self.cnn2 = _LargeCNN(activation=activation_large, drop_prob=drop_prob)
+        self.dropout = nn.Dropout(0.5)
+        self.bilstm = _BiLSTM(input_size=3072, hidden_size=512, num_layers=2)
+        self.fc = nn.Sequential(
+            nn.Linear(3072, 1024, bias=False), nn.BatchNorm1d(num_features=1024)
+        )
+
+        self.features_extractor = nn.Identity()
+        self.len_last_layer = 1024
+        self.return_feats = return_feats
+
+        # TODO: Add new way to handle return_features == True
+        if not return_feats:
+            self.final_layer = nn.Linear(1024, self.n_outputs)
+        else:
+            self.final_layer = nn.Identity()
+
+    def forward(self, x):
+        """Forward pass.
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            Batch of EEG windows of shape (batch_size, n_channels, n_times).
+        """
+
+        if x.ndim == 3:
+            x = x.unsqueeze(1)
+
+        x1 = self.cnn1(x)
+        x1 = x1.flatten(start_dim=1)
+
+        x2 = self.cnn2(x)
+        x2 = x2.flatten(start_dim=1)
+
+        x = torch.cat((x1, x2), dim=1)
+        x = self.dropout(x)
+        temp = x.clone()
+        temp = self.fc(temp)
+        x = x.unsqueeze(1)
+        x = self.bilstm(x)
+        x = x.squeeze()
+        x = torch.add(x, temp)
+        x = self.dropout(x)
+
+        feats = self.features_extractor(x)
+
+        if self.return_feats:
+            return feats
+        else:
+            return self.final_layer(feats)
+
+
+class _SmallCNN(nn.Module):
+    """
+    Smaller filter sizes to learn temporal information.
+
+    Parameters
+    ----------
+    activation: nn.Module, default=nn.ReLU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
+    drop_prob : float, default=0.5
+        The dropout rate for regularization. Values should be between 0 and 1.
+    """
+
+    def __init__(self, activation: nn.Module = nn.ReLU, drop_prob: float = 0.5):
         super().__init__()
         self.conv1 = nn.Sequential(
             nn.Conv2d(
@@ -20,10 +142,10 @@ class _SmallCNN(nn.Module):
                 bias=False,
             ),
             nn.BatchNorm2d(num_features=64),
-            nn.ReLU(),
+            activation(),
         )
         self.pool1 = nn.MaxPool2d(kernel_size=(1, 8), stride=(1, 8), padding=(0, 2))
-        self.dropout = nn.Dropout(p=0.5)
+        self.dropout = nn.Dropout(p=drop_prob)
         self.conv2 = nn.Sequential(
             nn.Conv2d(
                 in_channels=64,
@@ -34,7 +156,7 @@ class _SmallCNN(nn.Module):
                 bias=False,
             ),
             nn.BatchNorm2d(num_features=128),
-            nn.ReLU(),
+            activation(),
         )
         self.conv3 = nn.Sequential(
             nn.Conv2d(
@@ -46,7 +168,7 @@ class _SmallCNN(nn.Module):
                 bias=False,
             ),
             nn.BatchNorm2d(num_features=128),
-            nn.ReLU(),
+            activation(),
         )
         self.conv4 = nn.Sequential(
             nn.Conv2d(
@@ -58,7 +180,7 @@ class _SmallCNN(nn.Module):
                 bias=False,
             ),
             nn.BatchNorm2d(num_features=128),
-            nn.ReLU(),
+            activation(),
         )
         self.pool2 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=(0, 1))

@@ -72,8 +194,19 @@ class _SmallCNN(nn.Module):
         return x


-class _LargeCNN(nn.Module):  # larger filter sizes to learn frequency information
-    def __init__(self):
+class _LargeCNN(nn.Module):
+    """
+    Larger filter sizes to learn frequency information.
+
+    Parameters
+    ----------
+    activation: nn.Module, default=nn.ELU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+    """
+
+    def __init__(self, activation: nn.Module = nn.ELU, drop_prob: float = 0.5):
         super().__init__()

         self.conv1 = nn.Sequential(
@@ -86,10 +219,10 @@ class _LargeCNN(nn.Module):
                 bias=False,
             ),
             nn.BatchNorm2d(num_features=64),
-            nn.ReLU(),
+            activation(),
         )
         self.pool1 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4))
-        self.dropout = nn.Dropout(p=0.5)
+        self.dropout = nn.Dropout(p=drop_prob)
         self.conv2 = nn.Sequential(
             nn.Conv2d(
                 in_channels=64,
@@ -100,7 +233,7 @@ class _LargeCNN(nn.Module):
                 bias=False,
             ),
             nn.BatchNorm2d(num_features=128),
-            nn.ReLU(),
+            activation(),
         )
         self.conv3 = nn.Sequential(
             nn.Conv2d(
@@ -112,7 +245,7 @@ class _LargeCNN(nn.Module):
                 bias=False,
             ),
             nn.BatchNorm2d(num_features=128),
-            nn.ReLU(),
+            activation(),
         )
         self.conv4 = nn.Sequential(
             nn.Conv2d(
@@ -124,7 +257,7 @@ class _LargeCNN(nn.Module):
                 bias=False,
             ),
             nn.BatchNorm2d(num_features=128),
-            nn.ReLU(),
+            activation(),
         )
         self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=(0, 1))

@@ -154,112 +287,9 @@ class _BiLSTM(nn.Module):

     def forward(self, x):
         # set initial hidden and cell states
-        h0 = torch.zeros(
-            self.num_layers * 2, x.size(0), self.hidden_size
-        ).to(x.device)
+        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
         c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)

         # forward propagate LSTM
         out, _ = self.lstm(x, (h0, c0))
         return out
-
-
-class DeepSleepNet(EEGModuleMixin, nn.Module):
-    """Sleep staging architecture from Supratak et al 2017.
-
-    Convolutional neural network and bidirectional-Long Short-Term
-    for single channels sleep staging described in [Supratak2017]_.
-
-    Parameters
-    ----------
-    return_feats : bool
-        If True, return the features, i.e. the output of the feature extractor
-        (before the final linear layer). If False, pass the features through
-        the final linear layer.
-    n_classes :
-        Alias for n_outputs.
-
-    References
-    ----------
-    .. [Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017).
-       DeepSleepNet: A model for automatic sleep stage scoring based
-       on raw single-channel EEG. IEEE Transactions on Neural Systems
-       and Rehabilitation Engineering, 25(11), 1998-2008.
-    """
-
-    def __init__(
-        self,
-        n_outputs=5,
-        return_feats=False,
-        n_chans=None,
-        chs_info=None,
-        n_times=None,
-        input_window_seconds=None,
-        sfreq=None,
-        n_classes=None,
-    ):
-        n_outputs, = deprecated_args(
-            self,
-            ('n_classes', 'n_outputs', n_classes, n_outputs),
-        )
-        super().__init__(
-            n_outputs=n_outputs,
-            n_chans=n_chans,
-            chs_info=chs_info,
-            n_times=n_times,
-            input_window_seconds=input_window_seconds,
-            sfreq=sfreq,
-        )
-        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-        del n_classes
-        self.cnn1 = _SmallCNN()
-        self.cnn2 = _LargeCNN()
-        self.dropout = nn.Dropout(0.5)
-        self.bilstm = _BiLSTM(input_size=3072, hidden_size=512, num_layers=2)
-        self.fc = nn.Sequential(nn.Linear(3072, 1024, bias=False),
-                                nn.BatchNorm1d(num_features=1024))
-
-        self.features_extractor = nn.Identity()
-        self.len_last_layer = 1024
-        self.return_feats = return_feats
-
-        # TODO: Add new way to handle return_features == True
-        if not return_feats:
-            self.final_layer = nn.Linear(1024, self.n_outputs)
-        else:
-            self.final_layer = nn.Identity()
-
-    def forward(self, x):
-        """Forward pass.
-
-        Parameters
-        ----------
-        x: torch.Tensor
-            Batch of EEG windows of shape (batch_size, n_channels, n_times).
-        """
-
-        if x.ndim == 3:
-            x = x.unsqueeze(1)
-
-        x1 = self.cnn1(x)
-        x1 = x1.flatten(start_dim=1)
-
-        x2 = self.cnn2(x)
-        x2 = x2.flatten(start_dim=1)
-
-        x = torch.cat((x1, x2), dim=1)
-        x = self.dropout(x)
-        temp = x.clone()
-        temp = self.fc(temp)
-        x = x.unsqueeze(1)
-        x = self.bilstm(x)
-        x = x.squeeze()
-        x = torch.add(x, temp)
-        x = self.dropout(x)
-
-        feats = self.features_extractor(x)
-
-        if self.return_feats:
-            return feats
-        else:
-            return self.final_layer(feats)
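
A detail the unchanged _BiLSTM code relies on: a bidirectional LSTM keeps separate hidden and cell states for the forward and backward directions of each layer, hence the num_layers * 2 leading dimension, and the two 512-unit directions are concatenated into a 1024-dimensional output, which is what lets DeepSleepNet.forward add the BiLSTM output to the 1024-unit fc shortcut. A standalone sketch of just that mechanism, assuming batch_first=True as the (batch, seq, feature) usage in forward implies:

    import torch
    from torch import nn

    lstm = nn.LSTM(input_size=3072, hidden_size=512, num_layers=2,
                   batch_first=True, bidirectional=True)
    x = torch.randn(8, 1, 3072)   # (batch, seq_len=1, features)
    # states are (num_layers * num_directions, batch, hidden_size)
    h0 = torch.zeros(2 * 2, 8, 512)
    c0 = torch.zeros(2 * 2, 8, 512)
    out, _ = lstm(x, (h0, c0))
    print(out.shape)              # torch.Size([8, 1, 1024])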