braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102) hide show
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,337 @@
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops.layers.torch import Rearrange
8
+
9
+ from braindecode.models.base import EEGModuleMixin
10
+
11
+
12
class SincShallowNet(EEGModuleMixin, nn.Module):
    """Sinc-ShallowNet from Borra, D et al (2020) [borra2020]_.

    .. figure:: https://ars.els-cdn.com/content/image/1-s2.0-S0893608020302021-gr2_lrg.jpg
        :align: center
        :alt: SincShallowNet Architecture

    The Sinc-ShallowNet architecture has these fundamental blocks:

    1. **Block 1: Spectral and Spatial Feature Extraction**
        - *Temporal Sinc-Convolutional Layer*:
            Uses parametrized sinc functions to learn band-pass filters,
            significantly reducing the number of trainable parameters by only
            learning the lower cutoff frequency and the bandwidth of each filter.
        - *Batch Normalization*
        - *Spatial Depthwise Convolutional Layer*:
            Applies depthwise convolutions to learn spatial filters for
            each temporal feature map independently, further reducing
            parameters and enhancing interpretability.

    2. **Block 2: Temporal Aggregation**
        - *Batch Normalization*
        - *Activation Function*: ELU by default
        - *Average Pooling Layer*: Aggregation by averaging over the time dimension
        - *Dropout Layer*
        - *Flatten Layer*

    3. **Block 3: Classification**
        - *Fully Connected Layer*: Maps the feature vector to n_outputs.

    **Implementation Notes:**

    - The low cutoff frequencies of the sinc filters are initialized on a
      regular grid ``first_freq + k * freq_stride`` for
      ``k = 0 .. num_time_filters - 1`` and learned during training. The high
      cutoff of each filter is the low cutoff plus the absolute learned
      bandwidth, clipped to ``[min_freq, sfreq / 2]``.

    Parameters
    ----------
    num_time_filters : int
        Number of temporal filters in the SincFilter layer.
    time_filter_len : int
        Size of the temporal filters. Must be odd.
    depth_multiplier : int
        Depth multiplier for spatial filtering.
    activation : type[nn.Module] | None, optional
        Activation module *class*, instantiated without arguments.
        ``None`` falls back to ``nn.ELU``. Default is ``nn.ELU``.
    drop_prob : float, optional
        Dropout probability. Default is 0.5.
    first_freq : float, optional
        The starting frequency for the first Sinc filter. Default is 5.0.
    min_freq : float, optional
        Minimum frequency allowed for the low frequencies of the filters. Default is 1.0.
    freq_stride : float, optional
        Frequency stride for the Sinc filters. Controls the spacing between the filter frequencies.
        Default is 1.0.
    padding : str, optional
        Padding mode for convolution, either 'same' or 'valid'. Default is 'same'.
    bandwidth : float, optional
        Initial bandwidth for each Sinc filter. Default is 4.0.
    pool_size : int, optional
        Size of the pooling window for the average pooling layer. Default is 55.
    pool_stride : int, optional
        Stride of the pooling operation. Default is 12.

    Raises
    ------
    ValueError
        If ``padding`` is neither 'same' nor 'valid' (case-insensitive).

    Notes
    -----
    This implementation is based on the implementation from [sincshallowcode]_.

    References
    ----------
    .. [borra2020] Borra, D., Fantozzi, S., & Magosso, E. (2020). Interpretable
       and lightweight convolutional neural network for EEG decoding: Application
       to movement execution and imagination. Neural Networks, 129, 55-74.
    .. [sincshallowcode] Sinc-ShallowNet re-implementation source code:
       https://github.com/marcellosicbaldi/SincNet-Tensorflow
    """

    def __init__(
        self,
        num_time_filters: int = 32,
        time_filter_len: int = 33,
        depth_multiplier: int = 2,
        activation: Optional[type[nn.Module]] = nn.ELU,
        drop_prob: float = 0.5,
        first_freq: float = 5.0,
        min_freq: float = 1.0,
        freq_stride: float = 1.0,
        padding: str = "same",
        bandwidth: float = 4.0,
        pool_size: int = 55,
        pool_stride: int = 12,
        # braindecode parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        chs_info=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # ``activation`` is instantiated below (``activation()``), so the
        # fallback must be the class itself, not an ``nn.ELU()`` instance.
        if activation is None:
            activation = nn.ELU

        # Low cutoff frequencies on a regular grid with exactly
        # ``num_time_filters`` entries. Scaling an integer range avoids
        # ``torch.arange`` with a float step returning one extra element due
        # to rounding, which would desynchronize the layer sizes below.
        low_freqs = first_freq + freq_stride * torch.arange(
            num_time_filters, dtype=torch.float32
        )
        self.n_filters = len(low_freqs)

        if padding.lower() == "valid":
            n_times_after_sinc_filter = self.n_times - time_filter_len + 1
        elif padding.lower() == "same":
            n_times_after_sinc_filter = self.n_times
        else:
            raise ValueError("Padding must be 'valid' or 'same'.")

        # Output length of AvgPool2d (no padding, floor mode) along time.
        size_after_pooling = (
            (n_times_after_sinc_filter - pool_size) // pool_stride
        ) + 1
        flattened_size = self.n_filters * depth_multiplier * size_after_pooling

        # Keras' BatchNormalization keeps ``momentum`` of the running stats
        # (default 0.99), whereas PyTorch's ``momentum`` is the weight of the
        # *new* batch statistics, so the Keras-equivalent value is 0.01.
        bn_momentum = 0.01

        # Layers
        self.ensuredims = Rearrange("batch chans times -> batch chans times 1")

        # Block 1: Sinc filter
        self.sinc_filter_layer = _SincFilter(
            low_freqs=low_freqs,
            kernel_size=time_filter_len,
            sfreq=self.sfreq,
            padding=padding,
            bandwidth=bandwidth,
            min_freq=min_freq,
        )

        self.depthwiseconv = nn.Sequential(
            # Sinc output is [batch, chans, time, nfilter]; move the filters to
            # the channel axis for the depthwise (grouped) conv.
            Rearrange("batch chans time nfilter -> batch nfilter chans time"),
            nn.BatchNorm2d(self.n_filters, momentum=bn_momentum),
            nn.Conv2d(
                in_channels=self.n_filters,
                out_channels=depth_multiplier * self.n_filters,
                kernel_size=(self.n_chans, 1),
                groups=self.n_filters,
                bias=False,
            ),
        )

        # Block 2: Batch norm, activation, pooling over time, dropout
        self.temporal_aggregation = nn.Sequential(
            nn.BatchNorm2d(depth_multiplier * self.n_filters, momentum=bn_momentum),
            activation(),
            nn.AvgPool2d(kernel_size=(1, pool_size), stride=(1, pool_stride)),
            nn.Dropout(p=drop_prob),
            nn.Flatten(),
        )

        # Final classification layer
        self.final_layer = nn.Linear(
            flattened_size,
            self.n_outputs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape [batch_size, num_channels, num_samples].

        Returns
        -------
        torch.Tensor
            Output logits of shape [batch_size, n_outputs].
        """
        x = self.ensuredims(x)
        x = self.sinc_filter_layer(x)
        x = self.depthwiseconv(x)
        x = self.temporal_aggregation(x)

        return self.final_layer(x)
208
+
209
+
210
class _SincFilter(nn.Module):
    """Sinc-based band-pass convolution from Ravanelli and Bengio (2018) [ravanelli]_.

    Every output feature map is produced by a symmetric kernel built from the
    difference of two sinc functions (a band-pass filter), tapered by a
    Hamming window. Only the low cutoff and the bandwidth of each filter are
    learnable, which keeps the parameter count small.

    Parameters
    ----------
    low_freqs : torch.Tensor
        Initial low cutoff frequencies, one per filter.
    kernel_size : int
        Length of the convolution kernels. Must be odd.
    sfreq : float
        Sampling rate of the input signal.
    bandwidth : float, optional
        Initial bandwidth shared by all filters. Default is 4.0.
    min_freq : float, optional
        Offset added to the absolute low cutoff, and lower clip of the high
        cutoff. Default is 1.0.
    padding : str, optional
        Convolution padding, 'same' or 'valid' (case-insensitive).
        Default is 'same'.

    Raises
    ------
    ValueError
        If ``kernel_size`` is even.

    References
    ----------
    .. [ravanelli] Ravanelli, M., & Bengio, Y. (2018, December). Speaker
       recognition from raw waveform with sincnet. In 2018 IEEE spoken language
       technology workshop (SLT) (pp. 1021-1028). IEEE.
    """

    def __init__(
        self,
        low_freqs: torch.Tensor,
        kernel_size: int,
        sfreq: float,
        bandwidth: float = 4.0,
        min_freq: float = 1.0,
        padding: str = "same",
    ):
        super().__init__()
        if not kernel_size % 2:
            raise ValueError("Kernel size must be odd.")

        half = kernel_size // 2
        self.num_filters = low_freqs.numel()
        self.kernel_size = kernel_size
        self.sfreq = sfreq
        self.min_freq = min_freq
        self.padding = padding.lower()

        # Left half of a symmetric Hamming window, as a column: [half, 1].
        hamming = torch.hamming_window(kernel_size, periodic=False)
        self.register_buffer("window", hamming[:half].unsqueeze(-1))

        # Angular time axis 2*pi*n/sfreq for the negative taps n = -half..-1,
        # as a column: [half, 1].
        taps = torch.arange(-half, 0, dtype=torch.float32)
        self.register_buffer("n_pi", (taps / sfreq * 2 * math.pi).unsqueeze(-1))

        # Learnable parameters, each of shape [1, num_filters].
        self.bandwidths = nn.Parameter(torch.full((1, self.num_filters), bandwidth))
        self.low_freqs = nn.Parameter(low_freqs.unsqueeze(0))

        # Center tap of every kernel (value 1 before normalization).
        self.register_buffer("ones", torch.ones(1, 1, 1, self.num_filters))

    def build_sinc_filters(self) -> torch.Tensor:
        """Assemble the band-pass kernels from the current parameters.

        Returns
        -------
        torch.Tensor
            Kernels of shape [1, kernel_size, 1, num_filters], divided by the
            standard deviation computed over all kernels together.
        """
        # Cutoffs: low = min_freq + |low|; high = low + |bandwidth|, clipped
        # to [min_freq, sfreq / 2]. No minimum bandwidth is enforced here.
        f_low = self.min_freq + torch.abs(self.low_freqs)
        f_high = torch.clamp(
            f_low + torch.abs(self.bandwidths),
            min=self.min_freq,
            max=self.sfreq / 2.0,
        )
        band = f_high - f_low

        # Left half of the band-pass impulse response: [half, num_filters].
        left = (torch.sin(self.n_pi * f_high) - torch.sin(self.n_pi * f_low)) / (
            self.n_pi / 2.0
        )
        left = left * self.window
        left = left / (2.0 * band)

        # Mirror around the center tap: [1, kernel_size, 1, num_filters].
        left = left[None, :, None, :]
        right = left.flip(1)
        kernel = torch.cat((left, self.ones, right), dim=1)
        return kernel / kernel.std()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Convolve every input channel with every sinc band-pass kernel.

        Parameters
        ----------
        inputs : torch.Tensor
            Input tensor of shape [batch_size, num_channels, num_samples, 1].

        Returns
        -------
        torch.Tensor
            Filtered tensor of shape [batch_size, num_channels, num_samples', num_filters],
            where num_samples' depends on ``padding``.
        """
        kernel = self.build_sinc_filters().to(inputs.device)
        # [1, K, 1, F] -> conv2d weight layout [F, 1, 1, K]
        weight = kernel.permute(3, 2, 0, 1)
        # channels-last [B, C, T, 1] -> channels-first [B, 1, C, T]
        x = inputs.permute(0, 3, 1, 2)
        out = F.conv2d(x, weight, padding=self.padding)
        # [B, F, C, T'] -> [B, C, T', F]
        return out.permute(0, 2, 3, 1)
@@ -0,0 +1,167 @@
1
+ # Authors: Divyesh Narayanan <divyesh.narayanan@gmail.com>
2
+ #
3
+ # License: BSD (3-clause)
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from braindecode.models.base import EEGModuleMixin
9
+
10
+
11
class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
    """Sleep staging architecture from Blanco et al. (2020) from [Blanco2020]_

    .. figure:: https://media.springernature.com/full/springer-static/image/art%3A10.1007%2Fs00500-019-04174-1/MediaObjects/500_2019_4174_Fig2_HTML.png
        :align: center
        :alt: SleepStagerBlanco2020 Architecture

    Convolutional neural network for sleep staging described in [Blanco2020]_.
    A series of seven convolutional layers with kernel sizes running down from 7 to 3,
    in an attempt to extract more general features at the beginning, while more specific
    and complex features were extracted in the final stages.

    Parameters
    ----------
    n_conv_chans : int
        Number of convolutional channels. Set to 20 in [Blanco2020]_.
    n_groups : int
        Number of groups for the first convolution. Set to 2 in [Blanco2020]_
        for 2-channel EEG. It controls the connections between inputs and
        outputs; ``n_chans`` and ``n_conv_chans`` must be divisible by
        ``n_groups``. The following six convolutions are depthwise
        (``groups=n_conv_chans``).
    max_pool_size : int
        Temporal kernel size (and stride) of the max-pooling applied after each
        convolution.
    drop_prob : float
        Dropout rate before the output dense layer.
    apply_batch_norm : bool
        If True, apply batch normalization after each of the seven
        convolutional layers.
    return_feats : bool
        If True, return the features, i.e. the output of the feature extractor
        (before the final linear layer). If False, pass the features through
        the final linear layer.
    activation: nn.Module, default=nn.ReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.

    References
    ----------
    .. [Blanco2020] Fernandez-Blanco, E., Rivero, D. & Pazos, A. Convolutional
       neural networks for sleep stage scoring on a two-channel EEG signal.
       Soft Comput 24, 4067–4079 (2020). https://doi.org/10.1007/s00500-019-04174-1
    """

    def __init__(
        self,
        n_chans=None,
        sfreq=None,
        n_conv_chans=20,
        input_window_seconds=None,
        n_outputs=5,
        n_groups=2,
        max_pool_size=2,
        drop_prob=0.5,
        apply_batch_norm=False,
        return_feats=False,
        activation: nn.Module = nn.ReLU,
        chs_info=None,
        n_times=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.mapping = {
            "fc.1.weight": "final_layer.1.weight",
            "fc.1.bias": "final_layer.1.bias",
        }

        batch_norm = nn.BatchNorm2d if apply_batch_norm else nn.Identity

        # Seven (conv, norm, activation, max-pool) stages with temporal
        # kernels 7, 7, 5, 5, 5, 3, 3. Only the first conv mixes input
        # channels (grouped by ``n_groups``); the rest are depthwise.
        layers = []
        in_chans, groups = self.n_chans, n_groups
        for kernel_size in (7, 7, 5, 5, 5, 3, 3):
            layers += [
                nn.Conv2d(
                    in_chans,
                    n_conv_chans,
                    (1, kernel_size),
                    groups=groups,
                    padding=0,
                ),
                batch_norm(n_conv_chans),
                activation(),
                nn.MaxPool2d((1, max_pool_size)),
            ]
            in_chans, groups = n_conv_chans, n_conv_chans
        self.feature_extractor = nn.Sequential(*layers)

        self.len_last_layer = self._len_last_layer(self.n_chans, self.n_times)
        self.return_feats = return_feats

        # TODO: Add new way to handle return_features == True
        if not return_feats:
            self.final_layer = nn.Sequential(
                nn.Dropout(drop_prob),
                nn.Linear(self.len_last_layer, self.n_outputs),
                nn.Identity(),
            )

    def _len_last_layer(self, n_channels, input_size):
        """Return the number of features produced by ``feature_extractor``.

        A single zero-filled dummy input of shape
        (1, n_channels, 1, input_size) is passed through the extractor in
        eval mode without gradients; the previous training mode is restored
        afterwards.
        """
        was_training = self.feature_extractor.training
        self.feature_extractor.eval()
        try:
            with torch.no_grad():
                # Zeros rather than an uninitialized tensor: only the shape
                # matters, but reading garbage memory is nondeterministic.
                out = self.feature_extractor(
                    torch.zeros(1, n_channels, 1, input_size)
                )  # batch_size, n_channels, height, width
        finally:
            self.feature_extractor.train(was_training)
        return out.numel()

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Flattened features if ``return_feats`` is True, otherwise the
            output of ``final_layer``.
        """
        if x.ndim == 3:
            # (batch, chans, times) -> (batch, chans, 1, times)
            x = x.unsqueeze(2)

        feats = self.feature_extractor(x).flatten(start_dim=1)
        if self.return_feats:
            return feats
        else:
            return self.final_layer(feats)
@@ -0,0 +1,157 @@
1
+ # Authors: Hubert Banville <hubert.jbanville@gmail.com>
2
+ #
3
+ # License: BSD (3-clause)
4
+
5
+ import math
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from braindecode.models.base import EEGModuleMixin
11
+
12
+
13
class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
    """Sleep staging architecture from Chambon et al. (2018) [Chambon2018]_.

    .. figure:: https://braindecode.org/dev/_static/model/SleepStagerChambon2018.jpg
        :align: center
        :alt: SleepStagerChambon2018 Architecture

    Convolutional neural network for sleep staging described in [Chambon2018]_.
    When more than one channel is given, a spatial convolution first mixes the
    input channels into ``n_chans`` virtual channels. Two temporal
    convolution / max-pooling stages follow, and then a dropout + linear
    classifier.

    Parameters
    ----------
    n_conv_chs : int
        Number of convolutional channels. Set to 8 in [Chambon2018]_.
    time_conv_size_s : float
        Size of filters in temporal convolution layers, in seconds. Set to 0.5
        in [Chambon2018]_ (64 samples at sfreq=128).
    max_pool_size_s : float
        Max pooling size, in seconds. Set to 0.125 in [Chambon2018]_ (16
        samples at sfreq=128).
    pad_size_s : float
        Padding size, in seconds. Set to 0.25 in [Chambon2018]_ (half the
        temporal convolution kernel size).
    activation: nn.Module, default=nn.ReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
    drop_prob : float
        Dropout rate before the output dense layer.
    apply_batch_norm : bool
        If True, apply batch normalization after both temporal convolutional
        layers.
    return_feats : bool
        If True, return the features, i.e. the output of the feature extractor
        (before the final linear layer). If False, pass the features through
        the final linear layer.

    Notes
    -----
    Durations given in seconds are converted to samples with
    ``math.ceil(duration * sfreq)``.

    References
    ----------
    .. [Chambon2018] Chambon, S., Galtier, M. N., Arnal, P. J., Wainrib, G., &
       Gramfort, A. (2018). A deep learning architecture for temporal sleep
       stage classification using multivariate and multimodal time series.
       IEEE Transactions on Neural Systems and Rehabilitation Engineering,
       26(4), 758-769.
    """

    def __init__(
        self,
        n_chans=None,
        sfreq=None,
        n_conv_chs=8,
        time_conv_size_s=0.5,
        max_pool_size_s=0.125,
        pad_size_s=0.25,
        activation: nn.Module = nn.ReLU,
        input_window_seconds=None,
        n_outputs=5,
        drop_prob=0.25,
        apply_batch_norm=False,
        return_feats=False,
        chs_info=None,
        n_times=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.mapping = {
            "fc.1.weight": "final_layer.1.weight",
            "fc.1.bias": "final_layer.1.bias",
        }

        # Durations in seconds -> sample counts.
        time_conv_size, max_pool_size, pad_size = (
            math.ceil(seconds * self.sfreq)
            for seconds in (time_conv_size_s, max_pool_size_s, pad_size_s)
        )

        # Built before the feature extractor so parameters are created (and
        # randomly initialized) in the same order as in the paper's layout.
        self.spatial_conv = (
            nn.Conv2d(1, self.n_chans, (self.n_chans, 1))
            if self.n_chans > 1
            else nn.Identity()
        )

        norm_layer = nn.BatchNorm2d if apply_batch_norm else nn.Identity

        def conv_stage(in_chs):
            # Temporal conv -> optional batch norm -> activation -> max-pool.
            return [
                nn.Conv2d(
                    in_chs, n_conv_chs, (1, time_conv_size), padding=(0, pad_size)
                ),
                norm_layer(n_conv_chs),
                activation(),
                nn.MaxPool2d((1, max_pool_size)),
            ]

        self.feature_extractor = nn.Sequential(
            *conv_stage(1), *conv_stage(n_conv_chs)
        )
        self.return_feats = return_feats

        def length_after_stage(length):
            # Stride-1 padded conv, then non-overlapping max-pooling.
            return (length + 2 * pad_size - (time_conv_size - 1)) // max_pool_size

        dim_after_conv = length_after_stage(length_after_stage(self.n_times))
        self.len_last_layer = n_conv_chs * self.n_chans * dim_after_conv

        # TODO: Add new way to handle return_features == True
        if not return_feats:
            self.final_layer = nn.Sequential(
                nn.Dropout(p=drop_prob),
                nn.Linear(in_features=self.len_last_layer, out_features=self.n_outputs),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Flattened features if ``return_feats`` is True, otherwise the
            output of ``final_layer``.
        """
        if x.ndim == 3:
            # (batch, chans, times) -> (batch, 1, chans, times)
            x = x.unsqueeze(1)

        if self.n_chans > 1:
            # (batch, 1, chans, times) -> (batch, chans, 1, times), then swap
            # the mixed channels back onto the "height" axis.
            x = self.spatial_conv(x).transpose(1, 2)

        feats = self.feature_extractor(x).flatten(start_dim=1)
        return feats if self.return_feats else self.final_layer(feats)