braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108) hide show
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,337 @@
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops.layers.torch import Rearrange
8
+
9
+ from braindecode.models.base import EEGModuleMixin
10
+
11
+
12
class SincShallowNet(EEGModuleMixin, nn.Module):
    """Sinc-ShallowNet from Borra, D et al (2020) [borra2020]_.

    .. figure:: https://ars.els-cdn.com/content/image/1-s2.0-S0893608020302021-gr2_lrg.jpg
        :align: center
        :alt: SincShallowNet Architecture

    The Sinc-ShallowNet architecture has these fundamental blocks:

    1. **Block 1: Spectral and Spatial Feature Extraction**
        - *Temporal Sinc-Convolutional Layer*:
            Uses parametrized sinc functions to learn band-pass filters,
            significantly reducing the number of trainable parameters by only
            learning the lower and upper cutoff frequencies for each filter.
        - *Spatial Depthwise Convolutional Layer*:
            Applies depthwise convolutions to learn spatial filters for
            each temporal feature map independently, further reducing
            parameters and enhancing interpretability.
        - *Batch Normalization*

    2. **Block 2: Temporal Aggregation**
        - *Activation Function*: ELU
        - *Average Pooling Layer*: Aggregation by averaging over the time
          dimension
        - *Dropout Layer*
        - *Flatten Layer*

    3. **Block 3: Classification**
        - *Fully Connected Layer*: Maps the feature vector to n_outputs.

    **Implementation Notes:**

    - The sinc-convolutional layer initializes cutoff frequencies uniformly
      within the desired frequency range and updates them during training while
      ensuring the lower cutoff is less than the upper cutoff.

    Parameters
    ----------
    num_time_filters : int
        Number of temporal filters in the SincFilter layer.
    time_filter_len : int
        Size of the temporal filters.
    depth_multiplier : int
        Depth multiplier for spatial filtering.
    activation : nn.Module, optional
        Activation function *class* to use; it is instantiated without
        arguments. ``None`` falls back to the default. Default is ``nn.ELU``.
    drop_prob : float, optional
        Dropout probability. Default is 0.5.
    first_freq : float, optional
        The starting frequency for the first Sinc filter. Default is 5.0.
    min_freq : float, optional
        Minimum frequency allowed for the low frequencies of the filters. Default is 1.0.
    freq_stride : float, optional
        Frequency stride for the Sinc filters. Controls the spacing between the filter frequencies.
        Default is 1.0.
    padding : str, optional
        Padding mode for convolution, either 'same' or 'valid'. Default is 'same'.
    bandwidth : float, optional
        Initial bandwidth for each Sinc filter. Default is 4.0.
    pool_size : int, optional
        Size of the pooling window for the average pooling layer. Default is 55.
    pool_stride : int, optional
        Stride of the pooling operation. Default is 12.

    Raises
    ------
    ValueError
        If ``padding`` is neither 'same' nor 'valid', or if the input window
        is too short for the pooling layer to produce any output.

    Notes
    -----
    This implementation is based on the implementation from [sincshallowcode]_.

    References
    ----------
    .. [borra2020] Borra, D., Fantozzi, S., & Magosso, E. (2020). Interpretable
       and lightweight convolutional neural network for EEG decoding: Application
       to movement execution and imagination. Neural Networks, 129, 55-74.
    .. [sincshallowcode] Sinc-ShallowNet re-implementation source code:
       https://github.com/marcellosicbaldi/SincNet-Tensorflow
    """

    def __init__(
        self,
        num_time_filters: int = 32,
        time_filter_len: int = 33,
        depth_multiplier: int = 2,
        activation: Optional[nn.Module] = nn.ELU,
        drop_prob: float = 0.5,
        first_freq: float = 5.0,
        min_freq: float = 1.0,
        freq_stride: float = 1.0,
        padding: str = "same",
        bandwidth: float = 4.0,
        pool_size: int = 55,
        pool_stride: int = 12,
        # braindecode parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        chs_info=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # `activation` is instantiated further down with ``activation()``, so
        # the fallback has to be the class itself; an instance here would be
        # *called* with no input and raise a TypeError.
        if activation is None:
            activation = nn.ELU

        # Define low frequencies for the SincFilter
        low_freqs = torch.arange(
            first_freq,
            first_freq + num_time_filters * freq_stride,
            freq_stride,
            dtype=torch.float32,
        )
        # Every layer below is sized from the number of cutoffs actually
        # generated, which is the single source of truth for the filter count.
        self.n_filters = len(low_freqs)

        if padding.lower() == "valid":
            n_times_after_sinc_filter = self.n_times - time_filter_len + 1
        elif padding.lower() == "same":
            n_times_after_sinc_filter = self.n_times
        else:
            raise ValueError("Padding must be 'valid' or 'same'.")

        size_after_pooling = (
            (n_times_after_sinc_filter - pool_size) // pool_stride
        ) + 1
        if size_after_pooling < 1:
            raise ValueError(
                "Input window is too short for the pooling layer: "
                f"{n_times_after_sinc_filter} samples after the sinc filter, "
                f"but pool_size={pool_size}."
            )
        # Use the same filter count as the convolutional layers so the linear
        # layer always matches the flattened feature map.
        flattened_size = self.n_filters * depth_multiplier * size_after_pooling

        # Layers
        self.ensuredims = Rearrange("batch chans times -> batch chans times 1")

        # Block 1: Sinc filter
        self.sinc_filter_layer = _SincFilter(
            low_freqs=low_freqs,
            kernel_size=time_filter_len,
            sfreq=self.sfreq,
            padding=padding,
            bandwidth=bandwidth,
            min_freq=min_freq,
        )

        # NOTE(review): momentum=0.99 was chosen "to match keras implementation".
        # PyTorch's momentum weights the *new* batch statistic, whereas Keras'
        # weights the running average, so the equivalent of Keras 0.99 may be
        # 0.01 here -- confirm the intended value before changing it.
        self.depthwiseconv = nn.Sequential(
            # Matching dim to depth wise conv!
            Rearrange("batch timefil time nfilter -> batch nfilter timefil time"),
            nn.BatchNorm2d(self.n_filters, momentum=0.99),
            # One group per sinc filter: each temporal feature map gets its
            # own `depth_multiplier` spatial filters spanning all channels.
            nn.Conv2d(
                in_channels=self.n_filters,
                out_channels=depth_multiplier * self.n_filters,
                kernel_size=(self.n_chans, 1),
                groups=self.n_filters,
                bias=False,
            ),
        )

        # Block 2: Batch norm, activation, pooling, dropout
        self.temporal_aggregation = nn.Sequential(
            nn.BatchNorm2d(depth_multiplier * self.n_filters, momentum=0.99),
            activation(),
            # Pooling window runs along the last (time) axis only.
            nn.AvgPool2d(kernel_size=(1, pool_size), stride=(1, pool_stride)),
            nn.Dropout(p=drop_prob),
            nn.Flatten(),
        )

        # Final classification layer
        self.final_layer = nn.Linear(
            flattened_size,
            self.n_outputs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape [batch_size, num_channels, num_samples].

        Returns
        -------
        torch.Tensor
            Output logits of shape [batch_size, num_classes].
        """
        x = self.ensuredims(x)
        x = self.sinc_filter_layer(x)
        x = self.depthwiseconv(x)
        x = self.temporal_aggregation(x)

        return self.final_layer(x)
208
+
209
+
210
class _SincFilter(nn.Module):
    """Sinc-Based Convolutional Layer for Band-Pass Filtering from Ravanelli and Bengio (2018) [ravanelli]_.

    The `SincFilter` layer implements a convolutional layer where each kernel is
    defined using a parametrized sinc function.
    This design enforces each kernel to represent a band-pass filter,
    reducing the number of trainable parameters.

    Parameters
    ----------
    low_freqs : torch.Tensor
        Initial low cutoff frequencies for each filter. The values are copied,
        so the caller's tensor is never modified by training. Non-floating
        tensors are converted to ``float32``.
    kernel_size : int
        Size of the convolutional kernels (filters). Must be odd.
    sfreq : float
        Sampling rate of the input signal.
    bandwidth : float, optional
        Initial bandwidth for each filter. Default is 4.0.
    min_freq : float, optional
        Minimum frequency allowed for low frequencies. Default is 1.0.
    padding : str, optional
        Padding mode, either 'same' or 'valid'. Default is 'same'.

    Raises
    ------
    ValueError
        If ``kernel_size`` is even or ``padding`` is not 'same' or 'valid'.

    References
    ----------
    .. [ravanelli] Ravanelli, M., & Bengio, Y. (2018, December). Speaker
       recognition from raw waveform with sincnet. In 2018 IEEE spoken language
       technology workshop (SLT) (pp. 1021-1028). IEEE.
    """

    def __init__(
        self,
        low_freqs: torch.Tensor,
        kernel_size: int,
        sfreq: float,
        bandwidth: float = 4.0,
        min_freq: float = 1.0,
        padding: str = "same",
    ):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError("Kernel size must be odd.")

        # Fail at construction time rather than on the first forward pass.
        padding = padding.lower()
        if padding not in ("same", "valid"):
            raise ValueError("Padding must be 'valid' or 'same'.")

        # Copy the cutoffs: `unsqueeze` alone returns a view, which would make
        # the trainable parameter share storage with the caller's tensor.
        low_freqs = low_freqs.detach().clone()
        if not low_freqs.is_floating_point():
            # Only floating-point tensors can require gradients.
            low_freqs = low_freqs.to(torch.float32)

        self.num_filters = low_freqs.numel()
        self.kernel_size = kernel_size
        self.sfreq = sfreq
        self.min_freq = min_freq
        self.padding = padding

        # Precompute constants. Only the left half of the (symmetric) kernel
        # is built explicitly; the right half is obtained by mirroring.
        half_len = kernel_size // 2
        window = torch.hamming_window(kernel_size, periodic=False)
        self.register_buffer("window", window[:half_len].unsqueeze(-1))

        n_pi = (
            torch.arange(-half_len, 0, dtype=torch.float32)
            / sfreq
            * 2
            * math.pi
        )
        self.register_buffer("n_pi", n_pi.unsqueeze(-1))

        # Initialize learnable parameters. `float(...)` keeps the parameter
        # floating-point even when an integer bandwidth is supplied.
        bandwidths = torch.full((1, self.num_filters), float(bandwidth))
        self.bandwidths = nn.Parameter(bandwidths)
        self.low_freqs = nn.Parameter(low_freqs.reshape(1, -1))

        # Constant tensor of ones for filter construction (the kernel centre)
        self.register_buffer("ones", torch.ones(1, 1, 1, self.num_filters))

    def build_sinc_filters(self) -> torch.Tensor:
        """Builds the sinc filters based on current parameters.

        Returns
        -------
        torch.Tensor
            Filter bank of shape [1, kernel_size, 1, num_filters], scaled by
            the standard deviation of all its entries.
        """
        # Computing the low frequencies of the filters
        low_freqs = self.min_freq + torch.abs(self.low_freqs)
        # Setting a minimum band and minimum freq
        high_freqs = torch.clamp(
            low_freqs + torch.abs(self.bandwidths),
            min=self.min_freq,
            max=self.sfreq / 2.0,
        )
        # NOTE: if a low cutoff reaches sfreq / 2 the clamp above makes this
        # bandwidth zero (or negative) and the normalisation below divides by
        # it; callers should keep the cutoffs below the Nyquist frequency.
        bandwidths = high_freqs - low_freqs

        # Passing from n_ to the corresponding f_times_t domain
        low = self.n_pi * low_freqs  # [kernel_size // 2, num_filters]
        high = self.n_pi * high_freqs  # [kernel_size // 2, num_filters]

        # Out-of-place arithmetic: these tensors are part of the autograd
        # graph, so intermediate results are not overwritten.
        filters_left = (torch.sin(high) - torch.sin(low)) / (self.n_pi / 2.0)
        filters_left = filters_left * self.window
        filters_left = filters_left / (2.0 * bandwidths)

        # [1, kernel_size // 2, 1, num_filters]
        filters_left = filters_left.unsqueeze(0).unsqueeze(2)
        filters_right = torch.flip(filters_left, dims=[1])

        filters = torch.cat(
            [filters_left, self.ones, filters_right], dim=1
        )  # [1, kernel_size, 1, num_filters]
        filters = filters / torch.std(filters)
        return filters

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Apply sinc filters to the input signal.

        Parameters
        ----------
        inputs : torch.Tensor
            Input tensor of shape [batch_size, num_channels, num_samples, 1].

        Returns
        -------
        torch.Tensor
            Filtered output tensor of shape [batch_size, num_channels, num_samples, num_filters].
            With 'valid' padding the sample axis shrinks by ``kernel_size - 1``.
        """
        filters = self.build_sinc_filters().to(
            inputs.device
        )  # [1, kernel_size, 1, num_filters]

        # Convert from channels_last to channels_first format
        inputs = inputs.permute(0, 3, 1, 2)
        # Permuting to match conv: [num_filters, 1, 1, kernel_size]
        filters = filters.permute(3, 2, 0, 1)
        # Apply convolution
        outputs = F.conv2d(inputs, filters, padding=self.padding)
        # Back to [batch, channels, samples, num_filters]
        outputs = outputs.permute(0, 2, 3, 1)

        return outputs
@@ -5,11 +5,15 @@
5
5
  import torch
6
6
  from torch import nn
7
7
 
8
- from .base import EEGModuleMixin, deprecated_args
8
+ from braindecode.models.base import EEGModuleMixin
9
9
 
10
10
 
11
11
  class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
12
- """Sleep staging architecture from Blanco et al 2020.
12
+ """Sleep staging architecture from Blanco et al. (2020) from [Blanco2020]_
13
+
14
+ .. figure:: https://media.springernature.com/full/springer-static/image/art%3A10.1007%2Fs00500-019-04174-1/MediaObjects/500_2019_4174_Fig2_HTML.png
15
+ :align: center
16
+ :alt: SleepStagerBlanco2020 Architecture
13
17
 
14
18
  Convolutional neural network for sleep staging described in [Blanco2020]_.
15
19
  A series of seven convolutional layers with kernel sizes running down from 7 to 3,
@@ -24,7 +28,7 @@ class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
24
28
  Number of groups for the convolution. Set to 2 in [Blanco2020]_ for 2 Channel EEG.
25
29
  controls the connections between inputs and outputs. n_channels and n_conv_chans must be
26
30
  divisible by n_groups.
27
- dropout : float
31
+ drop_prob : float
28
32
  Dropout rate before the output dense layer.
29
33
  apply_batch_norm : bool
30
34
  If True, apply batch normalization after both temporal convolutional
@@ -39,6 +43,9 @@ class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
39
43
  Alias for `n_outputs`.
40
44
  input_size_s : float
41
45
  Alias for `input_window_seconds`.
46
+ activation: nn.Module, default=nn.ReLU
47
+ Activation function class to apply. Should be a PyTorch activation
48
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
42
49
 
43
50
  References
44
51
  ----------
@@ -48,30 +55,21 @@ class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
48
55
  """
49
56
 
50
57
  def __init__(
51
- self,
52
- n_chans=None,
53
- sfreq=None,
54
- n_conv_chans=20,
55
- input_window_seconds=30,
56
- n_outputs=5,
57
- n_groups=2,
58
- max_pool_size=2,
59
- dropout=0.5,
60
- apply_batch_norm=False,
61
- return_feats=False,
62
- chs_info=None,
63
- n_times=None,
64
- n_channels=None,
65
- n_classes=None,
66
- input_size_s=None,
67
- add_log_softmax=True,
58
+ self,
59
+ n_chans=None,
60
+ sfreq=None,
61
+ n_conv_chans=20,
62
+ input_window_seconds=None,
63
+ n_outputs=5,
64
+ n_groups=2,
65
+ max_pool_size=2,
66
+ drop_prob=0.5,
67
+ apply_batch_norm=False,
68
+ return_feats=False,
69
+ activation: nn.Module = nn.ReLU,
70
+ chs_info=None,
71
+ n_times=None,
68
72
  ):
69
- n_chans, n_outputs, input_window_seconds, = deprecated_args(
70
- self,
71
- ("n_channels", "n_chans", n_channels, n_chans),
72
- ("n_classes", "n_outputs", n_classes, n_outputs),
73
- ("input_size_s", "input_window_seconds", input_size_s, input_window_seconds),
74
- )
75
73
  super().__init__(
76
74
  n_outputs=n_outputs,
77
75
  n_chans=n_chans,
@@ -79,14 +77,12 @@ class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
79
77
  n_times=n_times,
80
78
  input_window_seconds=input_window_seconds,
81
79
  sfreq=sfreq,
82
- add_log_softmax=add_log_softmax,
83
80
  )
84
81
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
85
- del n_channels, n_classes, input_size_s
86
82
 
87
83
  self.mapping = {
88
84
  "fc.1.weight": "final_layer.1.weight",
89
- "fc.1.bias": "final_layer.1.bias"
85
+ "fc.1.bias": "final_layer.1.bias",
90
86
  }
91
87
 
92
88
  batch_norm = nn.BatchNorm2d if apply_batch_norm else nn.Identity
@@ -94,32 +90,44 @@ class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
94
90
  self.feature_extractor = nn.Sequential(
95
91
  nn.Conv2d(self.n_chans, n_conv_chans, (1, 7), groups=n_groups, padding=0),
96
92
  batch_norm(n_conv_chans),
97
- nn.ReLU(),
93
+ activation(),
98
94
  nn.MaxPool2d((1, max_pool_size)),
99
- nn.Conv2d(n_conv_chans, n_conv_chans, (1, 7), groups=n_conv_chans, padding=0),
95
+ nn.Conv2d(
96
+ n_conv_chans, n_conv_chans, (1, 7), groups=n_conv_chans, padding=0
97
+ ),
100
98
  batch_norm(n_conv_chans),
101
- nn.ReLU(),
99
+ activation(),
102
100
  nn.MaxPool2d((1, max_pool_size)),
103
- nn.Conv2d(n_conv_chans, n_conv_chans, (1, 5), groups=n_conv_chans, padding=0),
101
+ nn.Conv2d(
102
+ n_conv_chans, n_conv_chans, (1, 5), groups=n_conv_chans, padding=0
103
+ ),
104
104
  batch_norm(n_conv_chans),
105
- nn.ReLU(),
105
+ activation(),
106
106
  nn.MaxPool2d((1, max_pool_size)),
107
- nn.Conv2d(n_conv_chans, n_conv_chans, (1, 5), groups=n_conv_chans, padding=0),
107
+ nn.Conv2d(
108
+ n_conv_chans, n_conv_chans, (1, 5), groups=n_conv_chans, padding=0
109
+ ),
108
110
  batch_norm(n_conv_chans),
109
- nn.ReLU(),
111
+ activation(),
110
112
  nn.MaxPool2d((1, max_pool_size)),
111
- nn.Conv2d(n_conv_chans, n_conv_chans, (1, 5), groups=n_conv_chans, padding=0),
113
+ nn.Conv2d(
114
+ n_conv_chans, n_conv_chans, (1, 5), groups=n_conv_chans, padding=0
115
+ ),
112
116
  batch_norm(n_conv_chans),
113
- nn.ReLU(),
117
+ activation(),
114
118
  nn.MaxPool2d((1, max_pool_size)),
115
- nn.Conv2d(n_conv_chans, n_conv_chans, (1, 3), groups=n_conv_chans, padding=0),
119
+ nn.Conv2d(
120
+ n_conv_chans, n_conv_chans, (1, 3), groups=n_conv_chans, padding=0
121
+ ),
116
122
  batch_norm(n_conv_chans),
117
- nn.ReLU(),
123
+ activation(),
118
124
  nn.MaxPool2d((1, max_pool_size)),
119
- nn.Conv2d(n_conv_chans, n_conv_chans, (1, 3), groups=n_conv_chans, padding=0),
125
+ nn.Conv2d(
126
+ n_conv_chans, n_conv_chans, (1, 3), groups=n_conv_chans, padding=0
127
+ ),
120
128
  batch_norm(n_conv_chans),
121
- nn.ReLU(),
122
- nn.MaxPool2d((1, max_pool_size))
129
+ activation(),
130
+ nn.MaxPool2d((1, max_pool_size)),
123
131
  )
124
132
 
125
133
  self.len_last_layer = self._len_last_layer(self.n_chans, self.n_times)
@@ -128,16 +136,17 @@ class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
128
136
  # TODO: Add new way to handle return_features == True
129
137
  if not return_feats:
130
138
  self.final_layer = nn.Sequential(
131
- nn.Dropout(dropout),
139
+ nn.Dropout(drop_prob),
132
140
  nn.Linear(self.len_last_layer, self.n_outputs),
133
- nn.LogSoftmax(dim=1) if self.add_log_softmax else nn.Identity()
141
+ nn.Identity(),
134
142
  )
135
143
 
136
144
  def _len_last_layer(self, n_channels, input_size):
137
145
  self.feature_extractor.eval()
138
146
  with torch.no_grad():
139
147
  out = self.feature_extractor(
140
- torch.Tensor(1, n_channels, 1, input_size)) # batch_size,n_channels,height,width
148
+ torch.Tensor(1, n_channels, 1, input_size)
149
+ ) # batch_size,n_channels,height,width
141
150
  self.feature_extractor.train()
142
151
  return len(out.flatten())
143
152
 
@@ -2,14 +2,20 @@
2
2
  #
3
3
  # License: BSD (3-clause)
4
4
 
5
+ import math
6
+
5
7
  import torch
6
8
  from torch import nn
7
- import numpy as np
8
- from .base import EEGModuleMixin, deprecated_args
9
+
10
+ from braindecode.models.base import EEGModuleMixin
9
11
 
10
12
 
11
13
  class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
12
- """Sleep staging architecture from Chambon et al 2018.
14
+ """Sleep staging architecture from Chambon et al. (2018) [Chambon2018]_.
15
+
16
+ .. figure:: https://braindecode.org/dev/_static/model/SleepStagerChambon2018.jpg
17
+ :align: center
18
+ :alt: SleepStagerChambon2018 Architecture
13
19
 
14
20
  Convolutional neural network for sleep staging described in [Chambon2018]_.
15
21
 
@@ -26,7 +32,7 @@ class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
26
32
  pad_size_s : float
27
33
  Padding size, in seconds. Set to 0.25 in [Chambon2018]_ (half the
28
34
  temporal convolution kernel size).
29
- dropout : float
35
+ drop_prob : float
30
36
  Dropout rate before the output dense layer.
31
37
  apply_batch_norm : bool
32
38
  If True, apply batch normalization after both temporal convolutional
@@ -41,6 +47,9 @@ class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
41
47
  Alias for `input_window_seconds`.
42
48
  n_classes:
43
49
  Alias for `n_outputs`.
50
+ activation: nn.Module, default=nn.ReLU
51
+ Activation function class to apply. Should be a PyTorch activation
52
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
44
53
 
45
54
  References
46
55
  ----------
@@ -52,30 +61,22 @@ class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
52
61
  """
53
62
 
54
63
  def __init__(
55
- self,
56
- n_chans=None,
57
- sfreq=None,
58
- n_conv_chs=8,
59
- time_conv_size_s=0.5,
60
- max_pool_size_s=0.125,
61
- pad_size_s=0.25,
62
- input_window_seconds=30,
63
- n_outputs=5,
64
- dropout=0.25,
65
- apply_batch_norm=False,
66
- return_feats=False,
67
- chs_info=None,
68
- n_times=None,
69
- n_channels=None,
70
- input_size_s=None,
71
- n_classes=None,
64
+ self,
65
+ n_chans=None,
66
+ sfreq=None,
67
+ n_conv_chs=8,
68
+ time_conv_size_s=0.5,
69
+ max_pool_size_s=0.125,
70
+ pad_size_s=0.25,
71
+ activation: nn.Module = nn.ReLU,
72
+ input_window_seconds=None,
73
+ n_outputs=5,
74
+ drop_prob=0.25,
75
+ apply_batch_norm=False,
76
+ return_feats=False,
77
+ chs_info=None,
78
+ n_times=None,
72
79
  ):
73
- n_chans, n_outputs, input_window_seconds, = deprecated_args(
74
- self,
75
- ("n_channels", "n_chans", n_channels, n_chans),
76
- ("n_classes", "n_outputs", n_classes, n_outputs),
77
- ("input_size_s", "input_window_seconds", input_size_s, input_window_seconds),
78
- )
79
80
  super().__init__(
80
81
  n_outputs=n_outputs,
81
82
  n_chans=n_chans,
@@ -85,54 +86,54 @@ class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
85
86
  sfreq=sfreq,
86
87
  )
87
88
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
88
- del n_channels, n_classes, input_size_s
89
89
 
90
90
  self.mapping = {
91
91
  "fc.1.weight": "final_layer.1.weight",
92
- "fc.1.bias": "final_layer.1.bias"
92
+ "fc.1.bias": "final_layer.1.bias",
93
93
  }
94
94
 
95
- time_conv_size = np.ceil(time_conv_size_s * self.sfreq).astype(int)
96
- max_pool_size = np.ceil(max_pool_size_s * self.sfreq).astype(int)
97
- pad_size = np.ceil(pad_size_s * self.sfreq).astype(int)
95
+ time_conv_size = math.ceil(time_conv_size_s * self.sfreq)
96
+ max_pool_size = math.ceil(max_pool_size_s * self.sfreq)
97
+ pad_size = math.ceil(pad_size_s * self.sfreq)
98
98
 
99
99
  if self.n_chans > 1:
100
100
  self.spatial_conv = nn.Conv2d(1, self.n_chans, (self.n_chans, 1))
101
+ else:
102
+ self.spatial_conv = nn.Identity()
101
103
 
102
104
  batch_norm = nn.BatchNorm2d if apply_batch_norm else nn.Identity
103
105
 
104
106
  self.feature_extractor = nn.Sequential(
105
- nn.Conv2d(
106
- 1, n_conv_chs, (1, time_conv_size), padding=(0, pad_size)),
107
+ nn.Conv2d(1, n_conv_chs, (1, time_conv_size), padding=(0, pad_size)),
107
108
  batch_norm(n_conv_chs),
108
- nn.ReLU(),
109
+ activation(),
109
110
  nn.MaxPool2d((1, max_pool_size)),
110
111
  nn.Conv2d(
111
- n_conv_chs, n_conv_chs, (1, time_conv_size),
112
- padding=(0, pad_size)),
112
+ n_conv_chs, n_conv_chs, (1, time_conv_size), padding=(0, pad_size)
113
+ ),
113
114
  batch_norm(n_conv_chs),
114
- nn.ReLU(),
115
- nn.MaxPool2d((1, max_pool_size))
115
+ activation(),
116
+ nn.MaxPool2d((1, max_pool_size)),
116
117
  )
117
- self.len_last_layer = self._len_last_layer(self.n_chans, self.n_times)
118
118
  self.return_feats = return_feats
119
119
 
120
+ dim_conv_1 = (
121
+ self.n_times + 2 * pad_size - (time_conv_size - 1)
122
+ ) // max_pool_size
123
+ dim_after_conv = (
124
+ dim_conv_1 + 2 * pad_size - (time_conv_size - 1)
125
+ ) // max_pool_size
126
+
127
+ self.len_last_layer = n_conv_chs * self.n_chans * dim_after_conv
128
+
120
129
  # TODO: Add new way to handle return_features == True
121
130
  if not return_feats:
122
131
  self.final_layer = nn.Sequential(
123
- nn.Dropout(dropout),
124
- nn.Linear(self.len_last_layer, self.n_outputs),
132
+ nn.Dropout(p=drop_prob),
133
+ nn.Linear(in_features=self.len_last_layer, out_features=self.n_outputs),
125
134
  )
126
135
 
127
- def _len_last_layer(self, n_channels, input_size):
128
- self.feature_extractor.eval()
129
- with torch.no_grad():
130
- out = self.feature_extractor(
131
- torch.Tensor(1, 1, n_channels, input_size))
132
- self.feature_extractor.train()
133
- return len(out.flatten())
134
-
135
- def forward(self, x):
136
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
136
137
  """
137
138
  Forward pass.
138
139
 
@@ -152,5 +153,5 @@ class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
152
153
 
153
154
  if self.return_feats:
154
155
  return feats
155
- else:
156
- return self.final_layer(feats)
156
+
157
+ return self.final_layer(feats)