braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,339 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops.layers.torch import Rearrange
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class SincShallowNet(EEGModuleMixin, nn.Module):
+     r"""Sinc-ShallowNet from Borra et al. (2020) [borra2020]_.
+
+     :bdg-success:`Convolution` :bdg-warning:`Interpretability`
+
+     .. figure:: https://ars.els-cdn.com/content/image/1-s2.0-S0893608020302021-gr2_lrg.jpg
+         :align: center
+         :alt: SincShallowNet Architecture
+
+     The Sinc-ShallowNet architecture has these fundamental blocks:
+
+     1. **Block 1: Spectral and Spatial Feature Extraction**
+
+        - *Temporal Sinc-Convolutional Layer*: Uses parametrized sinc functions
+          to learn band-pass filters, significantly reducing the number of
+          trainable parameters by learning only the lower and upper cutoff
+          frequencies of each filter.
+        - *Spatial Depthwise Convolutional Layer*: Applies depthwise convolutions
+          to learn spatial filters for each temporal feature map independently,
+          further reducing parameters and enhancing interpretability.
+        - *Batch Normalization*
+
+     2. **Block 2: Temporal Aggregation**
+
+        - *Activation Function*: ELU
+        - *Average Pooling Layer*: Aggregates along the temporal dimension
+        - *Dropout Layer*
+        - *Flatten Layer*
+
+     3. **Block 3: Classification**
+
+        - *Fully Connected Layer*: Maps the feature vector to ``n_outputs``.
+
+     **Implementation Notes:**
+
+     - The sinc-convolutional layer initializes cutoff frequencies uniformly
+       within the desired frequency range and updates them during training while
+       ensuring the lower cutoff stays below the upper cutoff.
+
+     Parameters
+     ----------
+     num_time_filters : int
+         Number of temporal filters in the SincFilter layer.
+     time_filter_len : int
+         Size of the temporal filters.
+     depth_multiplier : int
+         Depth multiplier for spatial filtering.
+     activation : nn.Module, optional
+         Activation function class to apply. Default is ``nn.ELU``.
+     drop_prob : float, optional
+         Dropout probability. Default is 0.5.
+     first_freq : float, optional
+         The starting frequency for the first Sinc filter. Default is 5.0.
+     min_freq : float, optional
+         Minimum frequency allowed for the low frequencies of the filters.
+         Default is 1.0.
+     freq_stride : float, optional
+         Frequency stride for the Sinc filters. Controls the spacing between
+         the filter frequencies. Default is 1.0.
+     padding : str, optional
+         Padding mode for convolution, either 'same' or 'valid'. Default is 'same'.
+     bandwidth : float, optional
+         Initial bandwidth for each Sinc filter. Default is 4.0.
+     pool_size : int, optional
+         Size of the pooling window for the average pooling layer. Default is 55.
+     pool_stride : int, optional
+         Stride of the pooling operation. Default is 12.
+
+     Notes
+     -----
+     This implementation is based on the implementation from [sincshallowcode]_.
+
+     References
+     ----------
+     .. [borra2020] Borra, D., Fantozzi, S., & Magosso, E. (2020). Interpretable
+        and lightweight convolutional neural network for EEG decoding: Application
+        to movement execution and imagination. Neural Networks, 129, 55-74.
+     .. [sincshallowcode] Sinc-ShallowNet re-implementation source code:
+        https://github.com/marcellosicbaldi/SincNet-Tensorflow
+     """
+
+     def __init__(
+         self,
+         num_time_filters: int = 32,
+         time_filter_len: int = 33,
+         depth_multiplier: int = 2,
+         activation: type[nn.Module] | None = nn.ELU,
+         drop_prob: float = 0.5,
+         first_freq: float = 5.0,
+         min_freq: float = 1.0,
+         freq_stride: float = 1.0,
+         padding: str = "same",
+         bandwidth: float = 4.0,
+         pool_size: int = 55,
+         pool_stride: int = 12,
+         # braindecode parameters
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         chs_info=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         # Checks and variable creation
+         if activation is None:
+             activation = nn.ELU  # keep the class; it is instantiated in Block 2
+
+         # Define low frequencies for the SincFilter
+         low_freqs = torch.arange(
+             first_freq,
+             first_freq + num_time_filters * freq_stride,
+             freq_stride,
+             dtype=torch.float32,
+         )
+         self.n_filters = len(low_freqs)
+
+         if padding.lower() == "valid":
+             n_times_after_sinc_filter = self.n_times - time_filter_len + 1
+         elif padding.lower() == "same":
+             n_times_after_sinc_filter = self.n_times
+         else:
+             raise ValueError("Padding must be 'valid' or 'same'.")
+
+         size_after_pooling = (
+             (n_times_after_sinc_filter - pool_size) // pool_stride
+         ) + 1
+         flattened_size = num_time_filters * depth_multiplier * size_after_pooling
+
+         # Layers
+         self.ensuredims = Rearrange("batch chans times -> batch chans times 1")
+
+         # Block 1: Sinc filter
+         self.sinc_filter_layer = _SincFilter(
+             low_freqs=low_freqs,
+             kernel_size=time_filter_len,
+             sfreq=self.sfreq,
+             padding=padding,
+             bandwidth=bandwidth,
+             min_freq=min_freq,
+         )
+
+         self.depthwiseconv = nn.Sequential(
+             # Match dimensions for the depthwise convolution
+             Rearrange("batch timefil time nfilter -> batch nfilter timefil time"),
+             nn.BatchNorm2d(
+                 self.n_filters, momentum=0.99
+             ),  # To match the Keras implementation
+             nn.Conv2d(
+                 in_channels=self.n_filters,
+                 out_channels=depth_multiplier * self.n_filters,
+                 kernel_size=(self.n_chans, 1),
+                 groups=self.n_filters,
+                 bias=False,
+             ),
+         )
+
+         # Block 2: Batch norm, activation, pooling, dropout
+         self.temporal_aggregation = nn.Sequential(
+             nn.BatchNorm2d(depth_multiplier * self.n_filters, momentum=0.99),
+             activation(),
+             nn.AvgPool2d(kernel_size=(1, pool_size), stride=(1, pool_stride)),
+             nn.Dropout(p=drop_prob),
+             nn.Flatten(),
+         )
+
+         # Final classification layer
+         self.final_layer = nn.Linear(
+             flattened_size,
+             self.n_outputs,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape [batch_size, num_channels, num_samples].
+
+         Returns
+         -------
+         torch.Tensor
+             Output logits of shape [batch_size, num_classes].
+         """
+         x = self.ensuredims(x)
+         x = self.sinc_filter_layer(x)
+         x = self.depthwiseconv(x)
+         x = self.temporal_aggregation(x)
+
+         return self.final_layer(x)
+
+
+ class _SincFilter(nn.Module):
+     r"""Sinc-based convolutional layer for band-pass filtering from Ravanelli and Bengio (2018) [ravanelli]_.
+
+     The ``_SincFilter`` layer implements a convolutional layer where each
+     kernel is defined by a parametrized sinc function.
+     This design constrains each kernel to represent a band-pass filter,
+     reducing the number of trainable parameters.
+
+     Parameters
+     ----------
+     low_freqs : torch.Tensor
+         Initial low cutoff frequencies for each filter.
+     kernel_size : int
+         Size of the convolutional kernels (filters). Must be odd.
+     sfreq : float
+         Sampling rate of the input signal.
+     bandwidth : float, optional
+         Initial bandwidth for each filter. Default is 4.0.
+     min_freq : float, optional
+         Minimum frequency allowed for low frequencies. Default is 1.0.
+     padding : str, optional
+         Padding mode, either 'same' or 'valid'. Default is 'same'.
+
+     References
+     ----------
+     .. [ravanelli] Ravanelli, M., & Bengio, Y. (2018, December). Speaker
+        recognition from raw waveform with SincNet. In 2018 IEEE Spoken Language
+        Technology Workshop (SLT) (pp. 1021-1028). IEEE.
+     """
+
+     def __init__(
+         self,
+         low_freqs: torch.Tensor,
+         kernel_size: int,
+         sfreq: float,
+         bandwidth: float = 4.0,
+         min_freq: float = 1.0,
+         padding: str = "same",
+     ):
+         super().__init__()
+         if kernel_size % 2 == 0:
+             raise ValueError("Kernel size must be odd.")
+
+         self.num_filters = low_freqs.numel()
+         self.kernel_size = kernel_size
+         self.sfreq = sfreq
+         self.min_freq = min_freq
+         self.padding = padding.lower()
+
+         # Precompute constants
+         window = torch.hamming_window(kernel_size, periodic=False)
+
+         self.register_buffer("window", window[: kernel_size // 2].unsqueeze(-1))
+
+         n_pi = (
+             torch.arange(-(kernel_size // 2), 0, dtype=torch.float32)
+             / sfreq
+             * 2
+             * math.pi
+         )
+         self.register_buffer("n_pi", n_pi.unsqueeze(-1))
+
+         # Initialize learnable parameters
+         bandwidths = torch.full((1, self.num_filters), bandwidth)
+         self.bandwidths = nn.Parameter(bandwidths)
+         self.low_freqs = nn.Parameter(low_freqs.unsqueeze(0))
+
+         # Constant tensor of ones for filter construction
+         self.register_buffer("ones", torch.ones(1, 1, 1, self.num_filters))
+
+     def build_sinc_filters(self) -> torch.Tensor:
+         """Build the sinc filters from the current parameters."""
+         # Compute the low frequencies of the filters
+         low_freqs = self.min_freq + torch.abs(self.low_freqs)
+         # Enforce a minimum band and clamp to the Nyquist frequency
+         high_freqs = torch.clamp(
+             low_freqs + torch.abs(self.bandwidths),
+             min=self.min_freq,
+             max=self.sfreq / 2.0,
+         )
+         bandwidths = high_freqs - low_freqs
+
+         # Pass from the n_ domain to the corresponding f_times_t domain
+         low = self.n_pi * low_freqs  # [kernel_size // 2, num_filters]
+         high = self.n_pi * high_freqs  # [kernel_size // 2, num_filters]
+
+         filters_left = (torch.sin(high) - torch.sin(low)) / (self.n_pi / 2.0)
+         filters_left *= self.window
+         filters_left /= 2.0 * bandwidths
+
+         # [1, kernel_size // 2, 1, num_filters]
+         filters_left = filters_left.unsqueeze(0).unsqueeze(2)
+         filters_right = torch.flip(filters_left, dims=[1])
+
+         filters = torch.cat(
+             [filters_left, self.ones, filters_right], dim=1
+         )  # [1, kernel_size, 1, num_filters]
+         filters = filters / torch.std(filters)
+         return filters
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         """
+         Apply the sinc filters to the input signal.
+
+         Parameters
+         ----------
+         inputs : torch.Tensor
+             Input tensor of shape [batch_size, num_channels, num_samples, 1].
+
+         Returns
+         -------
+         torch.Tensor
+             Filtered output tensor of shape [batch_size, num_channels, num_samples, num_filters].
+         """
+         filters = self.build_sinc_filters().to(
+             inputs.device
+         )  # [1, kernel_size, 1, num_filters]
+
+         # Convert from channels_last to channels_first format
+         inputs = inputs.permute(0, 3, 1, 2)
+         # Permute the filters to match the conv2d weight layout
+         filters = filters.permute(3, 2, 0, 1)
+         # Apply the convolution
+         outputs = F.conv2d(inputs, filters, padding=self.padding)
+         # Back to channels_last: [batch, chans, times, filters]
+         outputs = outputs.permute(0, 2, 3, 1)
+
+         return outputs
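
A minimal usage sketch for the SincShallowNet model above (not part of the package diff). It assumes the class is re-exported from braindecode.models, as suggested by braindecode/models/__init__.py in the manifest; the shapes are illustrative.

import torch

from braindecode.models import SincShallowNet  # assumed public re-export

# Illustrative shapes: 22-channel EEG at 250 Hz, 2 s windows (500 samples).
model = SincShallowNet(n_chans=22, n_outputs=4, n_times=500, sfreq=250.0)
x = torch.randn(8, 22, 500)  # (batch, channels, times)
logits = model(x)            # sinc band-pass filters -> depthwise spatial
                             # conv -> pooling -> linear head
print(logits.shape)          # torch.Size([8, 4])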
@@ -0,0 +1,169 @@
+ # Authors: Divyesh Narayanan <divyesh.narayanan@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ import torch
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
+     r"""Sleep staging architecture from Blanco et al. (2020) [Blanco2020]_.
+
+     :bdg-success:`Convolution`
+
+     .. figure:: https://media.springernature.com/full/springer-static/image/art%3A10.1007%2Fs00500-019-04174-1/MediaObjects/500_2019_4174_Fig2_HTML.png
+         :align: center
+         :alt: SleepStagerBlanco2020 Architecture
+
+     Convolutional neural network for sleep staging described in [Blanco2020]_.
+     It stacks seven convolutional layers with kernel sizes decreasing from 7
+     to 3, so that more general features are extracted in the early layers
+     while more specific and complex features are extracted in the final stages.
+
+     Parameters
+     ----------
+     n_conv_chans : int
+         Number of convolutional channels. Set to 20 in [Blanco2020]_.
+     n_groups : int
+         Number of groups for the convolution. Set to 2 in [Blanco2020]_ for
+         two-channel EEG. Controls the connections between inputs and outputs;
+         ``n_chans`` and ``n_conv_chans`` must both be divisible by ``n_groups``.
+     drop_prob : float
+         Dropout rate before the output dense layer.
+     apply_batch_norm : bool
+         If True, apply batch normalization after both temporal convolutional
+         layers.
+     return_feats : bool
+         If True, return the features, i.e. the output of the feature extractor
+         (before the final linear layer). If False, pass the features through
+         the final linear layer.
+     n_channels : int
+         Alias for `n_chans`.
+     n_classes : int
+         Alias for `n_outputs`.
+     input_size_s : float
+         Alias for `input_window_seconds`.
+     activation : nn.Module, default=nn.ReLU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
+
+     References
+     ----------
+     .. [Blanco2020] Fernandez-Blanco, E., Rivero, D. & Pazos, A. Convolutional
+        neural networks for sleep stage scoring on a two-channel EEG signal.
+        Soft Comput 24, 4067–4079 (2020). https://doi.org/10.1007/s00500-019-04174-1
+     """
+
+     def __init__(
+         self,
+         n_chans=None,
+         sfreq=None,
+         n_conv_chans=20,
+         input_window_seconds=None,
+         n_outputs=5,
+         n_groups=2,
+         max_pool_size=2,
+         drop_prob=0.5,
+         apply_batch_norm=False,
+         return_feats=False,
+         activation: type[nn.Module] = nn.ReLU,
+         chs_info=None,
+         n_times=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.mapping = {
+             "fc.1.weight": "final_layer.1.weight",
+             "fc.1.bias": "final_layer.1.bias",
+         }
+
+         batch_norm = nn.BatchNorm2d if apply_batch_norm else nn.Identity
+
+         self.feature_extractor = nn.Sequential(
+             nn.Conv2d(self.n_chans, n_conv_chans, (1, 7), groups=n_groups, padding=0),
+             batch_norm(n_conv_chans),
+             activation(),
+             nn.MaxPool2d((1, max_pool_size)),
+             nn.Conv2d(
+                 n_conv_chans, n_conv_chans, (1, 7), groups=n_conv_chans, padding=0
+             ),
+             batch_norm(n_conv_chans),
+             activation(),
+             nn.MaxPool2d((1, max_pool_size)),
+             nn.Conv2d(
+                 n_conv_chans, n_conv_chans, (1, 5), groups=n_conv_chans, padding=0
+             ),
+             batch_norm(n_conv_chans),
+             activation(),
+             nn.MaxPool2d((1, max_pool_size)),
+             nn.Conv2d(
+                 n_conv_chans, n_conv_chans, (1, 5), groups=n_conv_chans, padding=0
+             ),
+             batch_norm(n_conv_chans),
+             activation(),
+             nn.MaxPool2d((1, max_pool_size)),
+             nn.Conv2d(
+                 n_conv_chans, n_conv_chans, (1, 5), groups=n_conv_chans, padding=0
+             ),
+             batch_norm(n_conv_chans),
+             activation(),
+             nn.MaxPool2d((1, max_pool_size)),
+             nn.Conv2d(
+                 n_conv_chans, n_conv_chans, (1, 3), groups=n_conv_chans, padding=0
+             ),
+             batch_norm(n_conv_chans),
+             activation(),
+             nn.MaxPool2d((1, max_pool_size)),
+             nn.Conv2d(
+                 n_conv_chans, n_conv_chans, (1, 3), groups=n_conv_chans, padding=0
+             ),
+             batch_norm(n_conv_chans),
+             activation(),
+             nn.MaxPool2d((1, max_pool_size)),
+         )
+
+         self.len_last_layer = self._len_last_layer(self.n_chans, self.n_times)
+         self.return_feats = return_feats
+
+         # TODO: Add new way to handle return_features == True
+         if not return_feats:
+             self.final_layer = nn.Sequential(
+                 nn.Dropout(drop_prob),
+                 nn.Linear(self.len_last_layer, self.n_outputs),
+                 nn.Identity(),
+             )
+
+     def _len_last_layer(self, n_channels, input_size):
+         self.feature_extractor.eval()
+         with torch.no_grad():
+             out = self.feature_extractor(
+                 # Zero-filled probe tensor (avoids uninitialized memory);
+                 # shape: (batch_size, n_channels, height, width)
+                 torch.zeros(1, n_channels, 1, input_size)
+             )
+         self.feature_extractor.train()
+         return len(out.flatten())
+
+     def forward(self, x):
+         """Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Batch of EEG windows of shape (batch_size, n_channels, n_times).
+         """
+         if x.ndim == 3:
+             x = x.unsqueeze(2)
+
+         feats = self.feature_extractor(x).flatten(start_dim=1)
+         if self.return_feats:
+             return feats
+         else:
+             return self.final_layer(feats)
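
A minimal usage sketch for SleepStagerBlanco2020 (not part of the diff; the import path and shapes are assumptions). Two-channel, 30 s epochs match the two-channel setting described in [Blanco2020]_:

import torch

from braindecode.models import SleepStagerBlanco2020  # assumed public re-export

# Illustrative shapes: two-channel EEG, 30 s epochs at 100 Hz (3000 samples).
model = SleepStagerBlanco2020(n_chans=2, sfreq=100.0, n_times=3000, n_outputs=5)
x = torch.randn(4, 2, 3000)  # (batch, channels, times)
logits = model(x)            # forward unsqueezes to (batch, chans, 1, times)
print(logits.shape)          # torch.Size([4, 5])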
@@ -0,0 +1,159 @@
+ # Authors: Hubert Banville <hubert.jbanville@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ import math
+
+ import torch
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
+     r"""Sleep staging architecture from Chambon et al. (2018) [Chambon2018]_.
+
+     :bdg-success:`Convolution`
+
+     .. figure:: https://braindecode.org/dev/_static/model/SleepStagerChambon2018.jpg
+         :align: center
+         :alt: SleepStagerChambon2018 Architecture
+
+     Convolutional neural network for sleep staging described in [Chambon2018]_.
+
+     Parameters
+     ----------
+     n_conv_chs : int
+         Number of convolutional channels. Set to 8 in [Chambon2018]_.
+     time_conv_size_s : float
+         Size of filters in temporal convolution layers, in seconds. Set to 0.5
+         in [Chambon2018]_ (64 samples at sfreq=128).
+     max_pool_size_s : float
+         Max pooling size, in seconds. Set to 0.125 in [Chambon2018]_ (16
+         samples at sfreq=128).
+     pad_size_s : float
+         Padding size, in seconds. Set to 0.25 in [Chambon2018]_ (half the
+         temporal convolution kernel size).
+     drop_prob : float
+         Dropout rate before the output dense layer.
+     apply_batch_norm : bool
+         If True, apply batch normalization after both temporal convolutional
+         layers.
+     return_feats : bool
+         If True, return the features, i.e. the output of the feature extractor
+         (before the final linear layer). If False, pass the features through
+         the final linear layer.
+     n_channels : int
+         Alias for `n_chans`.
+     input_size_s : float
+         Alias for `input_window_seconds`.
+     n_classes : int
+         Alias for `n_outputs`.
+     activation : nn.Module, default=nn.ReLU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
+
+     References
+     ----------
+     .. [Chambon2018] Chambon, S., Galtier, M. N., Arnal, P. J., Wainrib, G., &
+        Gramfort, A. (2018). A deep learning architecture for temporal sleep
+        stage classification using multivariate and multimodal time series.
+        IEEE Transactions on Neural Systems and Rehabilitation Engineering,
+        26(4), 758-769.
+     """
+
+     def __init__(
+         self,
+         n_chans=None,
+         sfreq=None,
+         n_conv_chs=8,
+         time_conv_size_s=0.5,
+         max_pool_size_s=0.125,
+         pad_size_s=0.25,
+         activation: type[nn.Module] = nn.ReLU,
+         input_window_seconds=None,
+         n_outputs=5,
+         drop_prob=0.25,
+         apply_batch_norm=False,
+         return_feats=False,
+         chs_info=None,
+         n_times=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.mapping = {
+             "fc.1.weight": "final_layer.1.weight",
+             "fc.1.bias": "final_layer.1.bias",
+         }
+
+         time_conv_size = math.ceil(time_conv_size_s * self.sfreq)
+         max_pool_size = math.ceil(max_pool_size_s * self.sfreq)
+         pad_size = math.ceil(pad_size_s * self.sfreq)
+
+         if self.n_chans > 1:
+             self.spatial_conv = nn.Conv2d(1, self.n_chans, (self.n_chans, 1))
+         else:
+             self.spatial_conv = nn.Identity()
+
+         batch_norm = nn.BatchNorm2d if apply_batch_norm else nn.Identity
+
+         self.feature_extractor = nn.Sequential(
+             nn.Conv2d(1, n_conv_chs, (1, time_conv_size), padding=(0, pad_size)),
+             batch_norm(n_conv_chs),
+             activation(),
+             nn.MaxPool2d((1, max_pool_size)),
+             nn.Conv2d(
+                 n_conv_chs, n_conv_chs, (1, time_conv_size), padding=(0, pad_size)
+             ),
+             batch_norm(n_conv_chs),
+             activation(),
+             nn.MaxPool2d((1, max_pool_size)),
+         )
+         self.return_feats = return_feats
+
+         dim_conv_1 = (
+             self.n_times + 2 * pad_size - (time_conv_size - 1)
+         ) // max_pool_size
+         dim_after_conv = (
+             dim_conv_1 + 2 * pad_size - (time_conv_size - 1)
+         ) // max_pool_size
+
+         self.len_last_layer = n_conv_chs * self.n_chans * dim_after_conv
+
+         # TODO: Add new way to handle return_features == True
+         if not return_feats:
+             self.final_layer = nn.Sequential(
+                 nn.Dropout(p=drop_prob),
+                 nn.Linear(in_features=self.len_last_layer, out_features=self.n_outputs),
+             )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Batch of EEG windows of shape (batch_size, n_channels, n_times).
+         """
+         if x.ndim == 3:
+             x = x.unsqueeze(1)
+
+         if self.n_chans > 1:
+             x = self.spatial_conv(x)
+             x = x.transpose(1, 2)
+
+         feats = self.feature_extractor(x).flatten(start_dim=1)
+
+         if self.return_feats:
+             return feats
+
+         return self.final_layer(feats)
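
A corresponding sketch for SleepStagerChambon2018, again with an assumed import path and illustrative shapes (the paper's 128 Hz sampling rate and 30 s windows):

import torch

from braindecode.models import SleepStagerChambon2018  # assumed public re-export

# Illustrative shapes: 2-channel EEG, 30 s epochs at 128 Hz (3840 samples).
model = SleepStagerChambon2018(n_chans=2, sfreq=128.0, n_times=3840, n_outputs=5)
x = torch.randn(4, 2, 3840)  # (batch, channels, times)
logits = model(x)            # spatial conv mixes channels, then two temporal
                             # conv/pool blocks feed the linear classifier
print(logits.shape)          # torch.Size([4, 5])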