braindecode 0.8-py3-none-any.whl → 1.0.0-py3-none-any.whl
This diff represents the content of the publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
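Two of the larger new model files, fbcnet.py and fblightconvnet.py, are reproduced below. Beyond the expanded model zoo, the listing shows 1.0.0 factoring shared building blocks into new braindecode.modules and braindecode.functional subpackages. As a minimal sketch of what that means for downstream imports — using only names and call signatures that appear verbatim in the hunks below; anything further about these APIs is an assumption:

import torch

# The import and the FilterBankLayer call mirror, verbatim, the usage inside
# the fbcnet.py and fblightconvnet.py hunks below.
from braindecode.modules import FilterBankLayer, LogVarLayer

# Illustrative sizes (not from the diff): 22-channel EEG at 250 Hz.
bank = FilterBankLayer(n_chans=22, sfreq=250, band_filters=9, verbose=False)
x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
y = bank(x)  # (batch, n_bands, n_chans, n_times), per the shape comments below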
braindecode/models/fbcnet.py
@@ -0,0 +1,221 @@
from __future__ import annotations

from typing import Any

import torch
from einops.layers.torch import Rearrange
from mne.utils import warn
from torch import Tensor, nn

from braindecode.models.base import EEGModuleMixin
from braindecode.modules import (
    Conv2dWithConstraint,
    FilterBankLayer,
    LinearWithConstraint,
    LogVarLayer,
    MaxLayer,
    MeanLayer,
    StdLayer,
    VarLayer,
)

_valid_layers = {
    "VarLayer": VarLayer,
    "StdLayer": StdLayer,
    "LogVarLayer": LogVarLayer,
    "MeanLayer": MeanLayer,
    "MaxLayer": MaxLayer,
}


class FBCNet(EEGModuleMixin, nn.Module):
    """FBCNet from Mane, R et al (2021) [fbcnet2021]_.

    .. figure:: https://raw.githubusercontent.com/ravikiran-mane/FBCNet/refs/heads/master/FBCNet-V2.png
       :align: center
       :alt: FBCNet Architecture

    The FBCNet model applies spatial convolution and variance calculation along
    the time axis, inspired by the Filter Bank Common Spatial Pattern (FBCSP)
    algorithm.

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been checked
    by the original authors; it has only been reimplemented from the paper
    description and source code [fbcnetcode2021]_. There is a difference in the
    activation function: in the paper, ELU is used as the activation function,
    but in the original code, SiLU is used. We followed the code.

    Parameters
    ----------
    n_bands : int or None or list[tuple[int, int]], default=9
        Number of frequency bands or a list of frequency band tuples. If a list
        of tuples is provided, each tuple defines the lower and upper bounds of
        a frequency band.
    n_filters_spat : int, default=32
        Number of spatial filters for the first convolution.
    n_dim : int, default=3
        Number of dimensions for the temporal reduction layer.
    temporal_layer : str, default='LogVarLayer'
        Type of temporal aggregator layer. Options: 'VarLayer', 'StdLayer',
        'LogVarLayer', 'MeanLayer', 'MaxLayer'.
    stride_factor : int, default=4
        Stride factor for reshaping.
    activation : nn.Module, default=nn.SiLU
        Activation function class to apply in the Spatial Convolution Block.
    cnn_max_norm : float, default=2.0
        Maximum norm for the spatial convolution layer.
    linear_max_norm : float, default=0.5
        Maximum norm for the final linear layer.
    filter_parameters : dict, default=None
        Parameters for the FilterBankLayer.

    References
    ----------
    .. [fbcnet2021] Mane, R., Chew, E., Chua, K., Ang, K. K., Robinson, N.,
       Vinod, A. P., ... & Guan, C. (2021). FBCNet: A multi-view convolutional
       neural network for brain-computer interface. arXiv preprint
       arXiv:2104.01233.
    .. [fbcnetcode2021] Link to source code:
       https://github.com/ravikiran-mane/FBCNet
    """

    def __init__(
        self,
        # Braindecode parameters
        n_chans=None,
        n_outputs=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        # Model parameters
        n_bands=9,
        n_filters_spat: int = 32,
        temporal_layer: str = "LogVarLayer",
        n_dim: int = 3,
        stride_factor: int = 4,
        activation: nn.Module = nn.SiLU,
        linear_max_norm: float = 0.5,
        cnn_max_norm: float = 2.0,
        filter_parameters: dict[Any, Any] | None = None,
    ):
        super().__init__(
            n_chans=n_chans,
            n_outputs=n_outputs,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # Parameters
        self.n_bands = n_bands
        self.n_filters_spat = n_filters_spat
        self.n_dim = n_dim
        self.stride_factor = stride_factor
        self.activation = activation
        self.filter_parameters = filter_parameters or {}

        # Checkers
        if temporal_layer not in _valid_layers:
            raise NotImplementedError(
                f"Temporal layer '{temporal_layer}' is not implemented."
            )

        if self.n_times % self.stride_factor != 0:
            warn(
                f"Time dimension ({self.n_times}) is not divisible by"
                f" stride_factor ({self.stride_factor}). Input will be padded.",
                UserWarning,
            )

        # Layers
        # Following paper nomenclature
        self.spectral_filtering = FilterBankLayer(
            n_chans=self.n_chans,
            sfreq=self.sfreq,
            band_filters=self.n_bands,
            verbose=False,
            **self.filter_parameters,
        )
        # As we have an internal process to create the bands,
        # we get the values from the filterbank
        self.n_bands = self.spectral_filtering.n_bands

        # Spatial Convolution Block (SCB)
        self.spatial_conv = nn.Sequential(
            Conv2dWithConstraint(
                in_channels=self.n_bands,
                out_channels=self.n_filters_spat * self.n_bands,
                kernel_size=(self.n_chans, 1),
                groups=self.n_bands,
                max_norm=cnn_max_norm,
                padding=0,
            ),
            nn.BatchNorm2d(self.n_filters_spat * self.n_bands),
            self.activation(),
        )

        # Padding layer
        if self.n_times % self.stride_factor != 0:
            self.padding_size = stride_factor - (self.n_times % stride_factor)
            self.n_times_padded = self.n_times + self.padding_size
            self.padding_layer = nn.ConstantPad1d((0, self.padding_size), 0.0)
        else:
            self.padding_layer = nn.Identity()
            self.n_times_padded = self.n_times

        # Temporal aggregator
        self.temporal_layer = _valid_layers[temporal_layer](dim=self.n_dim)  # type: ignore

        # Flatten layer
        self.flatten_layer = Rearrange("batch ... -> batch (...)")

        # Final fully connected layer
        self.final_layer = LinearWithConstraint(
            in_features=self.n_filters_spat * self.n_bands * self.stride_factor,
            out_features=self.n_outputs,
            max_norm=linear_max_norm,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the FBCNet model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor with shape (batch_size, n_outputs).
        """
        # input: (batch_size, n_chans, n_times)
        x = self.spectral_filtering(x)

        # shape: (batch_size, n_bands, n_chans, n_times)
        x = self.spatial_conv(x)
        batch_size, channels, _, _ = x.shape

        # shape: (batch_size, n_filters_spat * n_bands, 1, n_times)
        x = self.padding_layer(x)

        # shape: (batch_size, n_filters_spat * n_bands, 1, n_times_padded)
        x = x.view(
            batch_size,
            channels,
            self.stride_factor,
            self.n_times_padded // self.stride_factor,
        )
        # shape: (batch_size, n_filters_spat * n_bands, stride, n_times_padded/stride)
        x = self.temporal_layer(x)  # type: ignore[operator]

        # shape: (batch_size, n_filters_spat * n_bands, stride, 1)
        x = self.flatten_layer(x)

        # shape: (batch_size, n_filters_spat * n_bands * stride)
        x = self.final_layer(x)
        # shape: (batch_size, n_outputs)
        return x
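A quick way to sanity-check the shape bookkeeping above is a forward pass. This is a hedged sketch, not an official example: it assumes FBCNet is re-exported from braindecode.models (whose __init__.py grows by 52 lines in this release), and the channel/time/class counts are illustrative assumptions, not values from the diff.

import torch

from braindecode.models import FBCNet  # assumed re-export, see models/__init__.py

# Illustrative values: 22 channels, 4 s at 250 Hz (n_times=1000), 4 classes.
# 1000 % stride_factor(=4) == 0, so the padding branch is skipped.
model = FBCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)

x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
logits = model(x)
# Filter bank (9 bands) -> SCB (32 * 9 feature maps) -> LogVar over 4 segments
# -> linear layer on 32 * 9 * 4 = 1152 features.
assert logits.shape == (8, 4)  # (batch, n_outputs)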
braindecode/models/fblightconvnet.py
@@ -0,0 +1,313 @@
from __future__ import annotations

from typing import Optional

import torch
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from mne.utils import warn
from torch import nn

from braindecode.models.base import EEGModuleMixin
from braindecode.modules import (
    FilterBankLayer,
    LogVarLayer,
)


class FBLightConvNet(EEGModuleMixin, nn.Module):
    """LightConvNet from Ma, X et al (2023) [lightconvnet]_.

    .. figure:: https://raw.githubusercontent.com/Ma-Xinzhi/LightConvNet/refs/heads/main/network_architecture.png
       :align: center
       :alt: LightConvNet Neural Network

    A lightweight convolutional neural network incorporating temporal
    dependency learning and attention mechanisms. The architecture is
    designed to efficiently capture spatial and temporal features through
    specialized convolutional layers and **multi-head attention**.

    The network architecture consists of four main modules:

    1. **Spatial and Spectral Information Learning**:
       Applies filterbank and spatial convolutions, followed by batch
       normalization and an activation function to enhance feature
       representation.

    2. **Temporal Segmentation and Feature Extraction**:
       Divides the processed data into non-overlapping temporal windows.
       Within each window, a variance-based layer extracts discriminative
       features, which are then log-transformed to stabilize variance before
       being passed to the attention module.

    3. **Temporal Attention Module**: Utilizes a multi-head attention
       mechanism with depthwise separable convolutions to capture dependencies
       across different temporal segments. The attention weights are normalized
       using softmax and aggregated to form a comprehensive temporal
       representation.

    4. **Final Layer**: Flattens the aggregated features and passes them
       through a linear layer to generate the final output predictions.

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been
    checked by the original authors; it is a braindecode adaptation of the
    PyTorch source code [lightconvnetcode]_.

    Parameters
    ----------
    n_bands : int or None or list of tuple of int, default=9
        Number of frequency bands or a list of frequency band tuples. If a
        list of tuples is provided, each tuple defines the lower and upper
        bounds of a frequency band.
    n_filters_spat : int, default=32
        Number of spatial filters in the depthwise convolutional layer.
    n_dim : int, default=3
        Number of dimensions for the temporal reduction layer.
    stride_factor : int, default=4
        Stride factor used for reshaping the temporal dimension.
    win_len : int, default=250
        Length, in samples, of the non-overlapping temporal windows.
    heads : int, default=8
        Number of attention heads in the multi-head attention mechanism.
    weight_softmax : bool, default=True
        If True, applies softmax to the attention weights.
    bias : bool, default=False
        If True, includes a bias term in the convolutional layers.
    activation : nn.Module, default=nn.ELU
        Activation function class to apply after convolutional layers.
    verbose : bool, default=False
        If True, enables verbose output during filter creation using mne.
    filter_parameters : dict, default=None
        Additional parameters for the FilterBankLayer.

    References
    ----------
    .. [lightconvnet] Ma, X., Chen, W., Pei, Z., Liu, J., Huang, B., & Chen, J.
       (2023). A temporal dependency learning CNN with attention mechanism
       for MI-EEG decoding. IEEE Transactions on Neural Systems and
       Rehabilitation Engineering.
    .. [lightconvnetcode] Link to source code:
       https://github.com/Ma-Xinzhi/LightConvNet
    """

    def __init__(
        self,
        # Braindecode parameters
        n_chans=None,
        n_outputs=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        # Model parameters
        n_bands=9,
        n_filters_spat: int = 32,
        n_dim: int = 3,
        stride_factor: int = 4,
        win_len: int = 250,
        heads: int = 8,
        weight_softmax: bool = True,
        bias: bool = False,
        activation: nn.Module = nn.ELU,
        verbose: bool = False,
        filter_parameters: Optional[dict] = None,
    ):
        super().__init__(
            n_chans=n_chans,
            n_outputs=n_outputs,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # Parameters
        self.n_bands = n_bands
        self.n_filters_spat = n_filters_spat
        self.n_dim = n_dim
        self.stride_factor = stride_factor
        self.win_len = win_len
        self.activation = activation
        self.heads = heads
        self.weight_softmax = weight_softmax
        self.bias = bias
        self.filter_parameters = filter_parameters or {}

        # Checkers
        self.n_times_truncated = self.n_times
        if self.n_times % self.win_len != 0:
            warn(
                f"Time dimension ({self.n_times}) is not divisible by"
                f" win_len ({self.win_len}). Input will be truncated by"
                f" {self.n_times % self.win_len} temporal points.",
                UserWarning,
            )
            self.n_times_truncated = self.n_times - (self.n_times % self.win_len)

        # Layers
        # Following paper nomenclature
        self.spectral_filtering = FilterBankLayer(
            n_chans=self.n_chans,
            sfreq=self.sfreq,
            band_filters=self.n_bands,
            verbose=verbose,
            **self.filter_parameters,
        )
        # As we have an internal process to create the bands,
        # we get the values from the filterbank
        self.n_bands = self.spectral_filtering.n_bands

        # Unlike FBCNet's constrained, band-grouped convolution, the spatial
        # convolution here is a plain Conv2d.
        self.spatial_conv = nn.Sequential(
            nn.Conv2d(
                in_channels=self.n_bands,
                out_channels=self.n_filters_spat,
                kernel_size=(self.n_chans, 1),
            ),
            nn.BatchNorm2d(self.n_filters_spat),
            self.activation(),
        )

        # Temporal aggregator
        self.temporal_layer = LogVarLayer(self.n_dim, False)

        self.flatten_layer = Rearrange("batch ... -> batch (...)")

        # Lightweight 1D convolution acting as temporal attention
        self.attn_conv = _LightweightConv1d(
            self.n_filters_spat,
            (self.n_times // self.win_len),
            heads=self.heads,
            weight_softmax=weight_softmax,
            bias=bias,
        )

        self.final_layer = nn.Linear(
            in_features=self.n_filters_spat,
            out_features=self.n_outputs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the FBLightConvNet model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor with shape (batch_size, n_outputs).
        """
        batch_size, _, _ = x.shape
        # x.shape: (batch, n_chans, n_times)

        x = self.spectral_filtering(x)
        # x.shape: (batch, n_bands, n_chans, n_times)

        x = self.spatial_conv(x)
        # x.shape: (batch, n_filters_spat, 1, n_times)

        x = x[:, :, :, : self.n_times_truncated]
        # x.shape: (batch, n_filters_spat, 1, n_times_truncated)

        x = x.reshape([batch_size, self.n_filters_spat, -1, self.win_len])
        # x.shape: (batch, n_filters_spat, n_windows, win_len),
        # where n_windows = n_times_truncated / win_len
        # and win_len = 250 by default

        x = self.temporal_layer(x)
        # x.shape: (batch, n_filters_spat, n_windows)

        x = self.attn_conv(x)
        # x.shape: (batch, n_filters_spat, 1)

        x = self.flatten_layer(x)
        # x.shape: (batch, n_filters_spat)

        x = self.final_layer(x)
        # x.shape: (batch, n_outputs)
        return x


class _LightweightConv1d(nn.Module):
    """Lightweight 1D Convolution Module.

    Applies a convolution operation with multiple heads, allowing for
    parallel filter applications. Optionally applies a softmax normalization
    to the convolution weights.

    Parameters
    ----------
    input_size : int
        Number of channels of the input and output.
    kernel_size : int, optional
        Size of the convolution kernel. Default is `1`.
    padding : int, optional
        Amount of zero-padding added to both sides of the input. Default is `0`.
    heads : int, optional
        Number of attention heads used. The weight has shape
        `(heads, 1, kernel_size)`. Default is `1`.
    weight_softmax : bool, optional
        If `True`, normalizes the convolution weights with softmax before
        applying the convolution. Default is `False`.
    bias : bool, optional
        If `True`, adds a learnable bias to the output. Default is `False`.
    """

    def __init__(
        self,
        input_size: int,
        kernel_size: int = 1,
        padding: int = 0,
        heads: int = 1,
        weight_softmax: bool = False,
        bias: bool = False,
    ):
        super().__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.heads = heads
        self.padding = padding
        self.weight_softmax = weight_softmax
        self.weight = nn.Parameter(torch.Tensor(heads, 1, kernel_size))

        if bias:
            self.bias = nn.Parameter(torch.Tensor(input_size))
        else:
            self.bias = None

        self._init_parameters()

    def _init_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)

    def forward(self, input):
        # input shape: (batch, n_filters_spat, n_windows)
        B, C, T = input.size()

        H = self.heads

        weight = self.weight
        if self.weight_softmax:
            # Normalize each head's kernel over the temporal axis so it acts
            # as attention weights; shape stays (heads, 1, kernel_size).
            weight = F.softmax(weight, dim=-1)

        # Fold channels into the batch so each consecutive group of H channels
        # becomes one batch element; the grouped conv then applies the h-th
        # kernel to the h-th channel of each group.
        # (B, C, T) -> (B * C / H, H, T)
        input = input.view(-1, H, T)
        output = F.conv1d(input, weight, padding=self.padding, groups=self.heads)
        # (B * C / H, H, T_out); with kernel_size == n_windows and no padding,
        # T_out == 1
        output = output.view(B, C, -1)
        # (batch, n_filters_spat, 1)
        if self.bias is not None:
            # Add bias if it exists
            output = output + self.bias.view(1, -1, 1)
        return output
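The same kind of hedged sketch checks FBLightConvNet's windowing arithmetic. As with FBCNet, the import assumes the model is re-exported from braindecode.models, and the concrete sizes are illustrative assumptions, not values from the diff:

import torch

from braindecode.models import FBLightConvNet  # assumed re-export, as above

# Illustrative values: with n_times=1000 and the default win_len=250,
# nothing is truncated and n_windows = 1000 // 250 = 4.
model = FBLightConvNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)

x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
out = model(x)
# _LightweightConv1d collapses the 4 log-variance windows with a grouped
# (heads=8, 1, kernel_size=4) conv before the final linear layer.
assert out.shape == (8, 4)  # (batch, n_outputs)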