braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/sccnet.py
@@ -0,0 +1,280 @@
# Authors: Chun-Shu Wei
#          Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
#
# License: BSD (3-clause)

import math
from warnings import warn

import torch
from einops.layers.torch import Rearrange
from torch import nn

from braindecode.models.base import EEGModuleMixin
from braindecode.modules import LogActivation


class SCCNet(EEGModuleMixin, nn.Module):
    r"""SCCNet from Wei, C S (2019) [sccnet]_.

    :bdg-success:`Convolution`

    Spatial component-wise convolutional network (SCCNet) for motor-imagery EEG
    classification.

    .. figure:: https://dt5vp8kor0orz.cloudfront.net/6e3ec5d729cd51fe8acc5a978db27d02a5df9e05/2-Figure1-1.png
       :align: center
       :alt: Spatial component-wise convolutional network
       :width: 680px

    .. rubric:: Architectural Overview

    SCCNet is a spatial-first convolutional network whose temporal kernel
    lengths are fixed in seconds, so that its filters correspond to
    neurophysiologically meaningful windows. The model comprises four stages:

    1. **Spatial Component Analysis**: Performs convolutional spatial filtering
       across all EEG channels to extract spatial components, effectively
       reducing the channel dimension.
    2. **Spatio-Temporal Filtering**: Applies convolution across the spatial
       components and the temporal domain to capture spatio-temporal patterns.
    3. **Temporal Smoothing (Pooling)**: Uses average pooling over time to
       smooth the features and reduce the temporal dimension, focusing on
       longer-term patterns.
    4. **Classification**: Flattens the features and applies a fully connected
       layer.

    .. rubric:: Macro Components

    - `SCCNet.spatial_conv` **(spatial component analysis)**

      - *Operations.*

        - :class:`~torch.nn.Conv2d` with kernel `(n_chans, N_t)` and stride
          `(1, 1)` on an input reshaped to `(B, 1, n_chans, T)`; the typical
          choice `N_t=1` yields a pure across-channel projection (a
          montage-wide linear spatial filter).
        - Zero padding to preserve time, then :class:`~torch.nn.BatchNorm2d`;
          the output holds `N_u` component signals shaped `(B, 1, N_u, T)`
          after a permute step.

      - *Interpretability/robustness.* Mimics CSP-like spatial filtering: each
        learned filter is a channel-weighted component, easing inspection and
        reducing channel noise.

    - `SCCNet.spatial_filt_conv` **(spatio-temporal filtering)**

      - *Operations.*

        - :class:`~torch.nn.Conv2d` with kernel `(N_u, 12)` over components and
          time (12 samples ≈ 0.1 s at 125 Hz),
        - :class:`~torch.nn.BatchNorm2d`,
        - a **power-like** nonlinearity: the activations are squared, as in
          :class:`~braindecode.models.ShallowFBCSPNet`, with
          :class:`~braindecode.modules.LogActivation` as the default
          `activation` applied after pooling,
        - :class:`~torch.nn.Dropout` with rate `p=0.5`.

      - *Role.* Learns frequency-selective energy features and inter-component
        interactions within a 0.1 s context (alpha/beta cycle scale).

    - `SCCNet.temporal_smoothing` **(aggregation + readout)**

      - *Operations.*

        - :class:`~torch.nn.AvgPool2d` with size `(1, 62)` (≈ 0.5 s) for
          temporal smoothing and downsampling,
        - :class:`~torch.nn.Flatten`,
        - :class:`~torch.nn.Linear` to `n_outputs`.

    .. rubric:: Convolutional Details

    * **Temporal (where time-domain patterns are learned).**
      The second block's kernel length is fixed to 12 samples (≈ 100 ms) and
      slides with stride 1; average pooling `(1, 62)` (≈ 500 ms) integrates
      power over longer spans. These choices bake in short-cycle detection
      followed by half-second trend smoothing.

    * **Spatial (how electrodes are processed).**
      The first block's kernel spans **all electrodes** `(n_chans, N_t)`. With
      `N_t=1`, it reduces to a montage-wide linear projection, mapping channels
      → `N_u` components. The second block mixes **across components** via
      kernel height `N_u`.

    * **Spectral (how frequency information is captured).**
      No explicit transform is used; learned **temporal kernels** serve as
      bandpass-like filters, and the **square/log power** nonlinearity plus
      0.5 s averaging approximate band-power estimation (ERD/ERS-style
      features).

    .. rubric:: Attention / Sequential Modules

    This model contains **no attention** and **no recurrent units**.

    .. rubric:: Additional Mechanisms

    - :class:`~torch.nn.BatchNorm2d` and zero-padding are applied to both
      convolutions; L2 weight decay was used in the original paper; dropout
      `p=0.5` combats overfitting.
    - Contrast with other compact networks: EEGNet performs a temporal conv
      followed by a **depthwise spatial** conv (separable design), learning
      temporal filters first. SCCNet inverts this order: it performs a **full
      spatial projection first** (CSP-like), then a short **spatio-temporal**
      conv with an explicit 0.1 s kernel, followed by a **power-like**
      nonlinearity and longer temporal averaging. EEGNet's ELU and separable
      design favor parameter efficiency; SCCNet's second-scale kernels and
      square/log emphasize interpretable **band-power** features.
    - Reference implementation: see [sccnetcode]_.

    .. rubric:: Usage and Configuration

    * **Training recommendations from the original authors.**

      * Match the window length so that `T` is comfortably larger than the
        pooling length (e.g., > 1.5-2 s for MI).
      * Start with standard MI augmentations (channel dropout/shuffle, time
        reverse) and tune `n_spatial_filters` before deeper changes.

    Parameters
    ----------
    n_spatial_filters : int, optional
        Number of spatial filters in the first convolutional layer, variable
        `N_u` from the original paper. Default is 22.
    n_spatial_filters_smooth : int, optional
        Number of spatio-temporal filters in the second convolutional layer.
        Default is 20.
    drop_prob : float, optional
        Dropout probability. Default is 0.5.
    activation : nn.Module, optional
        Activation function applied after the temporal smoothing layer.
        Default is the logarithm activation.

    References
    ----------
    .. [sccnet] Wei, C. S., Koike-Akino, T., & Wang, Y. (2019, March). Spatial
       component-wise convolutional network (SCCNet) for motor-imagery EEG
       classification. In 2019 9th International IEEE/EMBS Conference on
       Neural Engineering (NER) (pp. 328-331). IEEE.
    .. [sccnetcode] Hsieh, C. Y., Chou, J. L., Chang, Y. H., & Wei, C. S.
       XBrainLab: An Open-Source Software for Explainable Artificial
       Intelligence-Based EEG Analysis. In NeurIPS 2023 AI for Science
       Workshop.
    """

    def __init__(
        self,
        # Signal related parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model related parameters
        n_spatial_filters: int = 22,
        n_spatial_filters_smooth: int = 20,
        drop_prob: float = 0.5,
        activation: type[nn.Module] = LogActivation,
        batch_norm_momentum: float = 0.1,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        # Parameters
        self.n_spatial_filters = n_spatial_filters
        self.n_spatial_filters_smooth = n_spatial_filters_smooth
        self.drop_prob = drop_prob

        # Original temporal design of SCCNet
        conv_kernel_time = 0.1  # 100 ms
        pool_kernel_time = 0.5  # 500 ms

        # Calculate sample-based sizes from time durations
        conv_kernel_samples = int(math.floor(self.sfreq * conv_kernel_time))
        pool_kernel_samples = int(math.floor(self.sfreq * pool_kernel_time))

        # If the input window is too short for the default kernel sizes,
        # scale them down proportionally.
        total_kernel_samples = conv_kernel_samples + pool_kernel_samples

        if self.n_times < total_kernel_samples:
            warning_msg = (
                f"Input window seconds ({self.input_window_seconds:.2f}s) is smaller than the "
                f"model's combined kernel sizes ({(total_kernel_samples / self.sfreq):.2f}s). "
                "Scaling temporal parameters down proportionally."
            )
            warn(warning_msg, UserWarning, stacklevel=2)

            scaling_factor = self.n_times / total_kernel_samples
            conv_kernel_samples = int(math.floor(conv_kernel_samples * scaling_factor))
            pool_kernel_samples = int(math.floor(pool_kernel_samples * scaling_factor))

        # Ensure kernels are at least 1 sample wide
        self.samples_100ms = max(1, conv_kernel_samples)
        self.kernel_size_pool = max(1, pool_kernel_samples)

        num_features = self._calc_num_features()

        # Layers
        self.ensure_dim = Rearrange("batch nchan times -> batch 1 nchan times")

        self.activation = LogActivation() if activation is None else activation()

        self.spatial_conv = nn.Conv2d(
            in_channels=1,
            out_channels=self.n_spatial_filters,
            kernel_size=(self.n_chans, 1),
        )

        self.spatial_batch_norm = nn.BatchNorm2d(
            self.n_spatial_filters, momentum=batch_norm_momentum
        )

        self.permute = Rearrange(
            "batch filspat nchans time -> batch nchans filspat time"
        )

        self.spatial_filt_conv = nn.Conv2d(
            in_channels=1,
            out_channels=self.n_spatial_filters_smooth,
            kernel_size=(self.n_spatial_filters, self.samples_100ms),
            bias=False,
        )
        self.batch_norm = nn.BatchNorm2d(
            self.n_spatial_filters_smooth, momentum=batch_norm_momentum
        )

        self.dropout = nn.Dropout(self.drop_prob)
        self.temporal_smoothing = nn.AvgPool2d(
            kernel_size=(1, self.kernel_size_pool),
            stride=(1, self.samples_100ms),
        )

        self.final_layer = nn.Linear(num_features, self.n_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shape: (batch_size, n_chans, n_times)
        x = self.ensure_dim(x)
        # Shape: (batch_size, 1, n_chans, n_times)
        x = self.spatial_conv(x)
        # Shape: (batch_size, n_filters, 1, n_times)
        x = self.spatial_batch_norm(x)
        # Shape: (batch_size, n_filters, 1, n_times)
        x = self.permute(x)
        # Shape: (batch_size, 1, n_filters, n_times)
        x = self.spatial_filt_conv(x)
        # Shape: (batch_size, n_filters_filt, 1, n_times_reduced)
        x = self.batch_norm(x)
        # Shape: (batch_size, n_filters_filt, 1, n_times_reduced)
        x = torch.pow(x, 2)
        # Shape: (batch_size, n_filters_filt, 1, n_times_reduced)
        x = self.dropout(x)
        # Shape: (batch_size, n_filters_filt, 1, n_times_reduced)
        x = self.temporal_smoothing(x)
        # Shape: (batch_size, n_filters_filt, 1, n_times_reduced_avg_pool)
        x = self.activation(x)
        # Shape: (batch_size, n_filters_filt, 1, n_times_reduced_avg_pool)
        x = x.view(x.size(0), -1)
        # Shape: (batch_size, n_filters_filt*n_times_reduced_avg_pool)
        x = self.final_layer(x)
        # Shape: (batch_size, n_outputs)
        return x

    def _calc_num_features(self) -> int:
        # Compute the number of features for the final linear layer
        w_out_conv2 = (
            self.n_times - self.samples_100ms + 1  # After second conv layer
        )
        w_out_pool = (
            (w_out_conv2 - self.kernel_size_pool) // self.samples_100ms + 1
            # After pooling layer
        )
        num_features = self.n_spatial_filters_smooth * w_out_pool
        return num_features
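
Below is a minimal usage sketch for the SCCNet module in the hunk above. It is not part of the package; the montage size, class count, sampling rate, window length, and batch size are illustrative assumptions chosen to match the paper's 125 Hz setting.

import torch

from braindecode.models import SCCNet

# Illustrative signal parameters: 22-channel motor-imagery EEG at 125 Hz,
# 4-second windows (n_times = 500), 4 output classes.
model = SCCNet(n_chans=22, n_outputs=4, n_times=500, sfreq=125)

# With sfreq=125, the constructor derives floor(125 * 0.1) = 12 samples for
# the spatio-temporal kernel and floor(125 * 0.5) = 62 samples for pooling,
# so _calc_num_features yields 500 - 12 + 1 = 489 conv outputs,
# (489 - 62) // 12 + 1 = 36 pooled outputs, and 20 * 36 = 720 linear inputs.
x = torch.randn(8, 22, 500)  # (batch, n_chans, n_times)
logits = model(x)
print(logits.shape)  # torch.Size([8, 4])
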
braindecode/models/shallow_fbcsp.py
@@ -0,0 +1,212 @@
# Authors: Robin Schirrmeister <robintibor@gmail.com>
#
# License: BSD (3-clause)

from typing import Callable

from einops.layers.torch import Rearrange
from torch import nn
from torch.nn import init

from braindecode.functional import square
from braindecode.models.base import EEGModuleMixin
from braindecode.modules import (
    CombinedConv,
    Ensure4d,
    Expression,
    SafeLog,
    SqueezeFinalOutput,
)


class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
    r"""Shallow ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.

    :bdg-success:`Convolution`

    .. figure:: https://onlinelibrary.wiley.com/cms/asset/221ea375-6701-40d3-ab3f-e411aad62d9e/hbm23730-fig-0002-m.jpg
       :align: center
       :alt: ShallowNet Architecture

    Model described in [Schirrmeister2017]_.

    Parameters
    ----------
    n_filters_time : int
        Number of temporal filters.
    filter_time_length : int
        Length of the temporal filter.
    n_filters_spat : int
        Number of spatial filters.
    pool_time_length : int
        Length of the temporal pooling filter.
    pool_time_stride : int
        Length of the stride between temporal pooling filters.
    final_conv_length : int | str
        Length of the final convolution layer.
        If set to "auto", the length of the input signal must be specified.
    conv_nonlin : callable
        Non-linear function to be used after convolution layers.
    pool_mode : str
        Method to use on pooling layers. "max" or "mean".
    activation_pool_nonlin : callable
        Non-linear function to be used after pooling layers.
    split_first_layer : bool
        Split the first layer into temporal and spatial layers (True) or use a
        single temporal layer (False). There is no non-linearity between the
        split layers.
    batch_norm : bool
        Whether to use batch normalisation.
    batch_norm_alpha : float
        Momentum for BatchNorm2d.
    drop_prob : float
        Dropout probability.

    References
    ----------
    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
       & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       Human Brain Mapping, Aug. 2017.
       Online: http://dx.doi.org/10.1002/hbm.23730
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=None,
        n_filters_time=40,
        filter_time_length=25,
        n_filters_spat=40,
        pool_time_length=75,
        pool_time_stride=15,
        final_conv_length="auto",
        conv_nonlin: Callable = square,
        pool_mode="mean",
        activation_pool_nonlin: type[nn.Module] = SafeLog,
        split_first_layer=True,
        batch_norm=True,
        batch_norm_alpha=0.1,
        drop_prob=0.5,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        if final_conv_length == "auto":
            assert self.n_times is not None
        self.n_filters_time = n_filters_time
        self.filter_time_length = filter_time_length
        self.n_filters_spat = n_filters_spat
        self.pool_time_length = pool_time_length
        self.pool_time_stride = pool_time_stride
        self.final_conv_length = final_conv_length
        self.conv_nonlin = conv_nonlin
        self.pool_mode = pool_mode
        self.pool_nonlin = activation_pool_nonlin
        self.split_first_layer = split_first_layer
        self.batch_norm = batch_norm
        self.batch_norm_alpha = batch_norm_alpha
        self.drop_prob = drop_prob

        self.mapping = {
            "conv_time.weight": "conv_time_spat.conv_time.weight",
            "conv_spat.weight": "conv_time_spat.conv_spat.weight",
            "conv_time.bias": "conv_time_spat.conv_time.bias",
            "conv_spat.bias": "conv_time_spat.conv_spat.bias",
            "conv_classifier.weight": "final_layer.conv_classifier.weight",
            "conv_classifier.bias": "final_layer.conv_classifier.bias",
        }

        self.add_module("ensuredims", Ensure4d())
        pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
        if self.split_first_layer:
            self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
            self.add_module(
                "conv_time_spat",
                CombinedConv(
                    in_chans=self.n_chans,
                    n_filters_time=self.n_filters_time,
                    n_filters_spat=self.n_filters_spat,
                    filter_time_length=filter_time_length,
                    bias_time=True,
                    bias_spat=not self.batch_norm,
                ),
            )
            n_filters_conv = self.n_filters_spat
        else:
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    self.n_chans,
                    self.n_filters_time,
                    (self.filter_time_length, 1),
                    stride=1,
                    bias=not self.batch_norm,
                ),
            )
            n_filters_conv = self.n_filters_time
        if self.batch_norm:
            self.add_module(
                "bnorm",
                nn.BatchNorm2d(
                    n_filters_conv, momentum=self.batch_norm_alpha, affine=True
                ),
            )
        self.add_module("conv_nonlin_exp", Expression(self.conv_nonlin))
        self.add_module(
            "pool",
            pool_class(
                kernel_size=(self.pool_time_length, 1),
                stride=(self.pool_time_stride, 1),
            ),
        )
        self.add_module("pool_nonlin_exp", self.pool_nonlin())
        self.add_module("drop", nn.Dropout(p=self.drop_prob))
        self.eval()
        if self.final_conv_length == "auto":
            self.final_conv_length = self.get_output_shape()[2]

        # Incorporate the classification module and subsequent ones into one final layer
        module = nn.Sequential()

        module.add_module(
            "conv_classifier",
            nn.Conv2d(
                n_filters_conv,
                self.n_outputs,
                (self.final_conv_length, 1),
                bias=True,
            ),
        )

        module.add_module("squeeze", SqueezeFinalOutput())

        self.add_module("final_layer", module)

        # Initialization; Xavier initialization matches the paper
        init.xavier_uniform_(self.conv_time_spat.conv_time.weight, gain=1)
        # conv_time has a bias unless the first layer is unsplit and batch norm is used
        if self.split_first_layer or (not self.batch_norm):
            init.constant_(self.conv_time_spat.conv_time.bias, 0)
        if self.split_first_layer:
            init.xavier_uniform_(self.conv_time_spat.conv_spat.weight, gain=1)
            if not self.batch_norm:
                init.constant_(self.conv_time_spat.conv_spat.bias, 0)
        if self.batch_norm:
            init.constant_(self.bnorm.weight, 1)
            init.constant_(self.bnorm.bias, 0)
        init.xavier_uniform_(self.final_layer.conv_classifier.weight, gain=1)
        init.constant_(self.final_layer.conv_classifier.bias, 0)

        self.train()
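
As above, a minimal usage sketch for ShallowFBCSPNet with illustrative shapes (not fixed by the package). With final_conv_length="auto", n_times is required so the classifier kernel can be inferred from the model's own output shape.

import torch

from braindecode.models import ShallowFBCSPNet

# Illustrative signal parameters: 22 channels, 4 classes, 4 s at 250 Hz.
model = ShallowFBCSPNet(
    n_chans=22,
    n_outputs=4,
    n_times=1000,
    final_conv_length="auto",  # infers the classifier kernel from n_times
)

# nn.Sequential forward: ensure 4d -> temporal + spatial conv -> batch norm ->
# square -> mean pool -> safe log -> dropout -> classifier conv spanning the
# remaining time axis -> squeeze.
x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
logits = model(x)
print(logits.shape)  # torch.Size([8, 4])
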