braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,324 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Sequence
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from einops.layers.torch import Rearrange
|
|
7
|
+
from mne.utils import warn
|
|
8
|
+
from torch import nn
|
|
9
|
+
|
|
10
|
+
from braindecode.models.base import EEGModuleMixin
|
|
11
|
+
from braindecode.models.fbcnet import _valid_layers
|
|
12
|
+
from braindecode.modules import (
|
|
13
|
+
Conv2dWithConstraint,
|
|
14
|
+
FilterBankLayer,
|
|
15
|
+
LinearWithConstraint,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class FBMSNet(EEGModuleMixin, nn.Module):
    """FBMSNet from Liu et al (2022) [fbmsnet]_.

    .. figure:: https://raw.githubusercontent.com/Want2Vanish/FBMSNet/refs/heads/main/FBMSNet.png
       :align: center
       :alt: FBMSNet Architecture

    0. **FilterBank Layer**: Applying filterbank to transform the input.

    1. **Temporal Convolution Block**: Utilizes mixed depthwise convolution
       (MixConv) to extract multiscale temporal features from multiview EEG
       representations. The input is split into groups corresponding to different
       views each convolved with kernels of varying sizes.
       Kernel sizes are set relative to the EEG
       sampling rate, with ratio coefficients [0.5, 0.25, 0.125, 0.0625],
       dividing the input into four groups.

    2. **Spatial Convolution Block**: Applies depthwise convolution with a kernel
       size of (n_chans, 1) to span all EEG channels, effectively learning spatial
       filters. This is followed by batch normalization and the Swish activation
       function. A maximum norm constraint of 2 is imposed on the convolution
       weights to regularize the model.

    3. **Temporal Log-Variance Block**: Computes the log-variance.

    4. **Classification Layer**: A fully connected with weight constraint.

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been checked
    by the original authors; it has only been reimplemented from the paper
    description and source code [fbmsnetcode]_. There is an extra layer here to
    compute the filterbank during batch time and not on data time. This avoids
    data-leak, and allows the model to follow the braindecode convention.

    Parameters
    ----------
    n_bands : int, default=9
        Number of frequency bands requested from the filter bank (forwarded to
        ``FilterBankLayer`` as ``band_filters``). The effective number of bands
        is read back from the created filter bank layer.
    n_filters_spat : int, default=36
        Number of output channels from the MixedConv2d layer.
    temporal_layer : str, default='LogVarLayer'
        Name of the temporal aggregation layer to use; must be one of the
        layers registered in ``_valid_layers``.
    n_dim : int, default=3
        Dimension passed as ``dim`` to the temporal aggregation layer. With the
        reshape done in ``forward``, dimension 3 is the per-segment time axis.
    stride_factor : int, default=4
        Number of temporal segments the signal is split into before temporal
        aggregation. The time axis is zero-padded at the end to a multiple of
        this value when needed.
    dilatability : int, default=8
        Expansion factor for the spatial convolution block (its number of
        output channels is ``n_filters_spat * dilatability``).
    activation : type[nn.Module], default=nn.SiLU
        Activation function class, instantiated after the spatial convolution.
    kernels_weights : sequence of int, default=(15, 31, 63, 125)
        Temporal kernel length of each MixedConv2d group; one group per entry.
    cnn_max_norm : float, default=2
        Max-norm constraint applied to the spatial convolution weights.
    linear_max_norm : float, default=0.5
        Max-norm constraint applied to the final linear layer weights.
    verbose : bool, default=False
        Verbose parameter to create the filter using mne.
    filter_parameters : dict or None, default=None
        Extra keyword arguments forwarded to ``FilterBankLayer``.

    Raises
    ------
    NotImplementedError
        If ``temporal_layer`` is not a supported temporal aggregation layer.

    References
    ----------
    .. [fbmsnet] Liu, K., Yang, M., Yu, Z., Wang, G., & Wu, W. (2022).
        FBMSNet: A filter-bank multi-scale convolutional neural network for
        EEG-based motor imagery decoding. IEEE Transactions on Biomedical
        Engineering, 70(2), 436-445.
    .. [fbmsnetcode] Liu, K., Yang, M., Yu, Z., Wang, G., & Wu, W. (2022).
        FBMSNet: A filter-bank multi-scale convolutional neural network for
        EEG-based motor imagery decoding.
        https://github.com/Want2Vanish/FBMSNet
    """

    def __init__(
        self,
        # Braindecode parameters
        n_chans=None,
        n_outputs=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        # models parameters
        n_bands: int = 9,
        n_filters_spat: int = 36,
        temporal_layer: str = "LogVarLayer",
        n_dim: int = 3,
        stride_factor: int = 4,
        dilatability: int = 8,
        activation: type[nn.Module] = nn.SiLU,
        kernels_weights: Sequence[int] = (15, 31, 63, 125),
        cnn_max_norm: float = 2,
        linear_max_norm: float = 0.5,
        verbose: bool = False,
        filter_parameters: Optional[dict] = None,
    ):
        super().__init__(
            n_chans=n_chans,
            n_outputs=n_outputs,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # Parameters
        self.n_bands = n_bands
        self.n_filters_spat = n_filters_spat
        self.n_dim = n_dim
        self.stride_factor = stride_factor
        self.activation = activation
        self.dilatability = dilatability
        self.kernels_weights = kernels_weights
        self.filter_parameters = filter_parameters or {}
        self.out_channels_spatial = self.n_filters_spat * self.dilatability

        # Checkers
        if temporal_layer not in _valid_layers:
            raise NotImplementedError(
                f"Temporal layer '{temporal_layer}' is not implemented."
            )

        # Leftover samples that prevent splitting the time axis into
        # ``stride_factor`` equal segments; computed once and reused below to
        # build the padding layer.
        time_remainder = self.n_times % self.stride_factor
        if time_remainder != 0:
            warn(
                f"Time dimension ({self.n_times}) is not divisible by"
                f" stride_factor ({self.stride_factor}). Input will be padded.",
                UserWarning,
            )

        # Layers
        # Following paper nomenclature
        self.spectral_filtering = FilterBankLayer(
            n_chans=self.n_chans,
            sfreq=self.sfreq,
            band_filters=self.n_bands,
            verbose=verbose,
            **self.filter_parameters,
        )
        # As we have an internal process to create the bands,
        # we get the values from the filterbank
        self.n_bands = self.spectral_filtering.n_bands

        # MixedConv2d Layer
        self.mix_conv = nn.Sequential(
            _MixedConv2d(
                in_channels=self.n_bands,
                out_channels=self.n_filters_spat,
                stride=1,
                dilation=1,
                depthwise=False,
                kernels_weights=kernels_weights,
            ),
            nn.BatchNorm2d(self.n_filters_spat),
        )

        # Spatial Convolution Block (SCB): kernel spans all EEG channels,
        # grouped per temporal filter.
        self.spatial_conv = nn.Sequential(
            Conv2dWithConstraint(
                in_channels=self.n_filters_spat,
                out_channels=self.out_channels_spatial,
                kernel_size=(self.n_chans, 1),
                groups=self.n_filters_spat,
                max_norm=cnn_max_norm,
                padding=0,
            ),
            nn.BatchNorm2d(self.out_channels_spatial),
            self.activation(),
        )

        # Padding layer: zero-pad the end of the time axis so that it can be
        # reshaped into ``stride_factor`` equal segments in ``forward``.
        if time_remainder != 0:
            self.padding_size = self.stride_factor - time_remainder
            self.n_times_padded = self.n_times + self.padding_size
            self.padding_layer = nn.ConstantPad1d((0, self.padding_size), 0.0)
        else:
            self.padding_layer = nn.Identity()
            self.n_times_padded = self.n_times

        # Temporal Aggregation Layer
        self.temporal_layer = _valid_layers[temporal_layer](dim=self.n_dim)  # type: ignore

        self.flatten_layer = Rearrange("batch ... -> batch (...)")

        # Final fully connected layer: one aggregated value per spatial
        # channel and temporal segment.
        self.final_layer = LinearWithConstraint(
            in_features=self.out_channels_spatial * self.stride_factor,
            out_features=self.n_outputs,
            max_norm=linear_max_norm,
        )

    def forward(self, x):
        """
        Forward pass of the FBMSNet model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor with shape (batch_size, n_outputs).
        """
        batch, _, _ = x.shape

        # shape: (batch, n_chans, n_times)
        x = self.spectral_filtering(x)
        # shape: (batch, n_bands, n_chans, n_times)

        # Mixed convolution
        x = self.mix_conv(x)
        # shape: (batch, self.n_filters_spat, n_chans, n_times)

        # Spatial convolution block
        x = self.spatial_conv(x)
        # shape: (batch, self.out_channels_spatial, 1, n_times)

        # Apply some padding to the input to make it divisible by the stride factor
        x = self.padding_layer(x)
        # shape: (batch, self.out_channels_spatial, 1, n_times_padded)

        # Reshape for temporal layer
        x = x.view(batch, self.out_channels_spatial, self.stride_factor, -1)
        # shape: (batch, self.out_channels_spatial, self.stride_factor, n_times/self.stride_factor)

        # Temporal aggregation
        x = self.temporal_layer(x)
        # shape: (batch, self.out_channels_spatial, self.stride_factor, 1)

        # Flatten and classify
        x = self.flatten_layer(x)
        # shape: (batch, self.out_channels_spatial*self.stride_factor)

        x = self.final_layer(x)
        # shape: (batch, n_outputs)
        return x
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
class _MixedConv2d(nn.Module):
    """Mixed Grouped Convolution for multiscale feature extraction.

    The input channels and the output channels are each split into
    ``len(kernels_weights)`` groups (see :meth:`_split_channels`). Group ``i``
    is convolved with a ``(1, kernels_weights[i])`` kernel (i.e. along the last
    axis only) using ``"same"`` padding. The group outputs are concatenated
    along the channel dimension.

    Parameters
    ----------
    in_channels : int
        Number of input channels (dimension 1 of the input).
    out_channels : int
        Total number of output channels across all groups.
    kernels_weights : sequence of int, default=(15, 31, 63, 125)
        Kernel length of each group along the last axis; its length sets the
        number of groups.
    stride : int, default=1
        Stride passed to every group convolution. PyTorch only accepts
        ``padding="same"`` together with a stride of 1.
    dilation : int, default=1
        Dilation passed to every group convolution.
    depthwise : bool, default=False
        If True, each group convolution uses ``groups=out_ch`` for that group;
        otherwise ``groups=1``.

    Raises
    ------
    ValueError
        If ``kernels_weights`` is empty, or if ``in_channels`` or
        ``out_channels`` is smaller than the number of kernels (some group
        would receive zero channels).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernels_weights=(15, 31, 63, 125),
        stride=1,
        dilation=1,
        depthwise=False,
    ):
        super().__init__()

        num_groups = len(kernels_weights)
        if num_groups == 0:
            raise ValueError("kernels_weights must contain at least one kernel size.")
        # Every group needs at least one input and one output channel; otherwise
        # a zero-channel convolution would silently be created.
        if in_channels < num_groups or out_channels < num_groups:
            raise ValueError(
                f"in_channels ({in_channels}) and out_channels ({out_channels}) "
                f"must both be at least the number of kernels ({num_groups})."
            )

        in_splits = self._split_channels(in_channels, num_groups)
        out_splits = self._split_channels(out_channels, num_groups)
        # Per-group input sizes, used by ``forward`` to slice the input.
        self.splits = in_splits

        self.convs = nn.ModuleList()
        # Create a convolutional layer for each kernel size
        for k, in_ch, out_ch in zip(kernels_weights, in_splits, out_splits):
            # NOTE(review): the common depthwise convention is groups=in_ch;
            # groups=out_ch requires in_ch % out_ch == 0. Kept for compatibility
            # (FBMSNet itself uses depthwise=False).
            conv_groups = out_ch if depthwise else 1
            conv = nn.Conv2d(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=(1, k),
                stride=stride,
                padding="same",
                dilation=dilation,
                groups=conv_groups,
                bias=False,
            )
            self.convs.append(conv)

    @staticmethod
    def _split_channels(num_chan, num_groups):
        """
        Splits the total number of channels into a specified
        number of groups as evenly as possible.

        Parameters
        ----------
        num_chan : int
            The total number of channels to split.
        num_groups : int
            The number of groups to split the channels into.

        Returns
        -------
        list of int
            A list containing the number of channels in each group.
            The first group may have more channels if the division is not even.
        """
        split = [num_chan // num_groups for _ in range(num_groups)]
        split[0] += num_chan - sum(split)
        return split

    def forward(self, x):
        """Apply each group convolution to its channel slice and concatenate.

        Parameters
        ----------
        x : torch.Tensor
            Input whose dimension 1 holds ``sum(self.splits)`` channels,
            expected as ``(batch, in_channels, height, width)``.

        Returns
        -------
        torch.Tensor
            Group outputs concatenated along dimension 1.
        """
        # Slice the channel dimension into the per-group sizes in ``self.splits``.
        x_split = torch.split(x, self.splits, dim=1)

        # ``self.convs`` holds one convolution per group, in creation order.
        x_out = [conv(chunk) for conv, chunk in zip(self.convs, x_split)]

        # Concatenate the group outputs along the channel dimension.
        return torch.cat(x_out, dim=1)
|
braindecode/models/hybrid.py
CHANGED
|
@@ -6,14 +6,12 @@ import torch
|
|
|
6
6
|
from torch import nn
|
|
7
7
|
from torch.nn import ConstantPad2d
|
|
8
8
|
|
|
9
|
-
from .deep4 import Deep4Net
|
|
10
|
-
from .
|
|
11
|
-
from .shallow_fbcsp import ShallowFBCSPNet
|
|
12
|
-
from .base import EEGModuleMixin, deprecated_args
|
|
9
|
+
from braindecode.models.deep4 import Deep4Net
|
|
10
|
+
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
|
|
13
11
|
|
|
14
12
|
|
|
15
|
-
class HybridNet(
|
|
16
|
-
"""Hybrid ConvNet model from Schirrmeister et al 2017.
|
|
13
|
+
class HybridNet(nn.Module):
|
|
14
|
+
"""Hybrid ConvNet model from Schirrmeister, R T et al (2017) [Schirrmeister2017]_.
|
|
17
15
|
|
|
18
16
|
See [Schirrmeister2017]_ for details.
|
|
19
17
|
|
|
@@ -28,25 +26,21 @@ class HybridNet(EEGModuleMixin, nn.Module):
|
|
|
28
26
|
Online: http://dx.doi.org/10.1002/hbm.23730
|
|
29
27
|
"""
|
|
30
28
|
|
|
31
|
-
def __init__(
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
n_chans=n_chans,
|
|
44
|
-
n_times=n_times,
|
|
45
|
-
add_log_softmax=add_log_softmax,
|
|
46
|
-
)
|
|
29
|
+
def __init__(
|
|
30
|
+
self,
|
|
31
|
+
n_chans=None,
|
|
32
|
+
n_outputs=None,
|
|
33
|
+
n_times=None,
|
|
34
|
+
input_window_seconds=None,
|
|
35
|
+
sfreq=None,
|
|
36
|
+
chs_info=None,
|
|
37
|
+
activation: nn.Module = nn.ELU,
|
|
38
|
+
drop_prob: float = 0.5,
|
|
39
|
+
):
|
|
40
|
+
super().__init__()
|
|
47
41
|
self.mapping = {
|
|
48
|
-
|
|
49
|
-
|
|
42
|
+
"final_conv.weight": "final_layer.weight",
|
|
43
|
+
"final_conv.bias": "final_layer.bias",
|
|
50
44
|
}
|
|
51
45
|
|
|
52
46
|
deep_model = Deep4Net(
|
|
@@ -58,61 +52,52 @@ class HybridNet(EEGModuleMixin, nn.Module):
|
|
|
58
52
|
n_filters_3=50,
|
|
59
53
|
n_filters_4=60,
|
|
60
54
|
n_times=n_times,
|
|
55
|
+
input_window_seconds=input_window_seconds,
|
|
56
|
+
sfreq=sfreq,
|
|
57
|
+
chs_info=chs_info,
|
|
61
58
|
final_conv_length=2,
|
|
59
|
+
activation_first_conv_nonlin=activation,
|
|
60
|
+
activation_later_conv_nonlin=activation,
|
|
61
|
+
drop_prob=drop_prob,
|
|
62
62
|
)
|
|
63
63
|
shallow_model = ShallowFBCSPNet(
|
|
64
64
|
n_chans=n_chans,
|
|
65
65
|
n_outputs=n_outputs,
|
|
66
66
|
n_times=n_times,
|
|
67
|
+
input_window_seconds=input_window_seconds,
|
|
68
|
+
sfreq=sfreq,
|
|
69
|
+
chs_info=chs_info,
|
|
67
70
|
n_filters_time=30,
|
|
68
71
|
n_filters_spat=40,
|
|
69
72
|
filter_time_length=28,
|
|
70
73
|
final_conv_length=29,
|
|
74
|
+
drop_prob=drop_prob,
|
|
75
|
+
)
|
|
76
|
+
new_conv_layer = nn.Conv2d(
|
|
77
|
+
deep_model.final_layer.conv_classifier.in_channels,
|
|
78
|
+
60,
|
|
79
|
+
kernel_size=deep_model.final_layer.conv_classifier.kernel_size,
|
|
80
|
+
stride=deep_model.final_layer.conv_classifier.stride,
|
|
71
81
|
)
|
|
82
|
+
deep_model.final_layer = new_conv_layer
|
|
72
83
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
reduced_deep_model.add_module("deep_final_conv", new_conv_layer)
|
|
86
|
-
break
|
|
87
|
-
reduced_deep_model.add_module(name, module)
|
|
88
|
-
|
|
89
|
-
reduced_shallow_model = nn.Sequential()
|
|
90
|
-
for name, module in shallow_model.named_children():
|
|
91
|
-
if name == "final_layer":
|
|
92
|
-
new_conv_layer = nn.Conv2d(
|
|
93
|
-
module.conv_classifier.in_channels,
|
|
94
|
-
40,
|
|
95
|
-
kernel_size=module.conv_classifier.kernel_size,
|
|
96
|
-
stride=module.conv_classifier.stride,
|
|
97
|
-
)
|
|
98
|
-
reduced_shallow_model.add_module(
|
|
99
|
-
"shallow_final_conv", new_conv_layer
|
|
100
|
-
)
|
|
101
|
-
break
|
|
102
|
-
reduced_shallow_model.add_module(name, module)
|
|
103
|
-
|
|
104
|
-
to_dense_prediction_model(reduced_deep_model)
|
|
105
|
-
to_dense_prediction_model(reduced_shallow_model)
|
|
106
|
-
self.reduced_deep_model = reduced_deep_model
|
|
107
|
-
self.reduced_shallow_model = reduced_shallow_model
|
|
84
|
+
new_conv_layer = nn.Conv2d(
|
|
85
|
+
shallow_model.final_layer.conv_classifier.in_channels,
|
|
86
|
+
40,
|
|
87
|
+
kernel_size=shallow_model.final_layer.conv_classifier.kernel_size,
|
|
88
|
+
stride=shallow_model.final_layer.conv_classifier.stride,
|
|
89
|
+
)
|
|
90
|
+
shallow_model.final_layer = new_conv_layer
|
|
91
|
+
|
|
92
|
+
deep_model.to_dense_prediction_model()
|
|
93
|
+
shallow_model.to_dense_prediction_model()
|
|
94
|
+
self.reduced_deep_model = deep_model
|
|
95
|
+
self.reduced_shallow_model = shallow_model
|
|
108
96
|
|
|
109
97
|
self.final_layer = nn.Sequential(
|
|
110
|
-
nn.Conv2d(
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
kernel_size=(1, 1),
|
|
114
|
-
stride=1),
|
|
115
|
-
nn.LogSoftmax(dim=1) if self.add_log_softmax else nn.Identity())
|
|
98
|
+
nn.Conv2d(100, n_outputs, kernel_size=(1, 1), stride=1),
|
|
99
|
+
nn.Identity(),
|
|
100
|
+
)
|
|
116
101
|
|
|
117
102
|
def forward(self, x):
|
|
118
103
|
"""Forward pass.
|
|
@@ -128,13 +113,9 @@ class HybridNet(EEGModuleMixin, nn.Module):
|
|
|
128
113
|
n_diff_deep_shallow = deep_out.size()[2] - shallow_out.size()[2]
|
|
129
114
|
|
|
130
115
|
if n_diff_deep_shallow < 0:
|
|
131
|
-
deep_out = ConstantPad2d((0, 0, -n_diff_deep_shallow, 0), 0)(
|
|
132
|
-
deep_out
|
|
133
|
-
)
|
|
116
|
+
deep_out = ConstantPad2d((0, 0, -n_diff_deep_shallow, 0), 0)(deep_out)
|
|
134
117
|
elif n_diff_deep_shallow > 0:
|
|
135
|
-
shallow_out = ConstantPad2d((0, 0, n_diff_deep_shallow, 0), 0)(
|
|
136
|
-
shallow_out
|
|
137
|
-
)
|
|
118
|
+
shallow_out = ConstantPad2d((0, 0, n_diff_deep_shallow, 0), 0)(shallow_out)
|
|
138
119
|
|
|
139
120
|
merged_out = torch.cat((deep_out, shallow_out), dim=1)
|
|
140
121
|
|