braindecode 1.3.0.dev180329405__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/base.py +1 -1
- braindecode/datasets/__init__.py +12 -4
- braindecode/datasets/base.py +115 -151
- braindecode/datasets/bcicomp.py +4 -4
- braindecode/datasets/bids.py +3 -3
- braindecode/datasets/experimental.py +2 -2
- braindecode/datasets/mne.py +3 -5
- braindecode/datasets/moabb.py +17 -7
- braindecode/datasets/nmt.py +2 -2
- braindecode/datasets/sleep_physio_challe_18.py +2 -2
- braindecode/datasets/sleep_physionet.py +2 -2
- braindecode/datasets/tuh.py +2 -2
- braindecode/datasets/xy.py +2 -2
- braindecode/datautil/__init__.py +11 -1
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/serialization.py +7 -7
- braindecode/functional/functions.py +6 -2
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +6 -0
- braindecode/models/atcnet.py +26 -27
- braindecode/models/attentionbasenet.py +37 -32
- braindecode/models/attn_sleep.py +2 -0
- braindecode/models/base.py +280 -2
- braindecode/models/bendr.py +469 -0
- braindecode/models/biot.py +2 -0
- braindecode/models/contrawr.py +2 -0
- braindecode/models/ctnet.py +8 -3
- braindecode/models/deepsleepnet.py +28 -19
- braindecode/models/eegconformer.py +2 -2
- braindecode/models/eeginception_erp.py +31 -25
- braindecode/models/eegitnet.py +2 -0
- braindecode/models/eegminer.py +2 -0
- braindecode/models/eegnet.py +1 -1
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +2 -0
- braindecode/models/fbcnet.py +5 -1
- braindecode/models/fblightconvnet.py +2 -0
- braindecode/models/fbmsnet.py +20 -6
- braindecode/models/ifnet.py +2 -0
- braindecode/models/labram.py +33 -26
- braindecode/models/medformer.py +758 -0
- braindecode/models/msvtnet.py +2 -0
- braindecode/models/patchedtransformer.py +1 -1
- braindecode/models/signal_jepa.py +111 -27
- braindecode/models/sinc_shallow.py +12 -9
- braindecode/models/sstdpn.py +11 -11
- braindecode/models/summary.csv +3 -0
- braindecode/models/syncnet.py +2 -0
- braindecode/models/tcn.py +2 -0
- braindecode/models/usleep.py +26 -21
- braindecode/models/util.py +3 -0
- braindecode/modules/attention.py +10 -10
- braindecode/modules/blocks.py +3 -3
- braindecode/modules/filter.py +2 -9
- braindecode/modules/layers.py +18 -17
- braindecode/preprocessing/__init__.py +232 -3
- braindecode/preprocessing/eegprep_preprocess.py +1202 -0
- braindecode/preprocessing/mne_preprocess.py +142 -10
- braindecode/preprocessing/preprocess.py +28 -18
- braindecode/preprocessing/util.py +166 -0
- braindecode/preprocessing/windowers.py +26 -20
- braindecode/samplers/base.py +8 -8
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/METADATA +6 -2
- braindecode-1.3.0.dev182330353.dist-info/RECORD +109 -0
- braindecode-1.3.0.dev180329405.dist-info/RECORD +0 -103
- {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,917 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, List, Tuple, cast
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from einops.layers.torch import Rearrange
|
|
7
|
+
from torch import nn
|
|
8
|
+
|
|
9
|
+
from braindecode.datautil.channel_utils import (
|
|
10
|
+
division_channels_idx,
|
|
11
|
+
match_hemisphere_chans,
|
|
12
|
+
)
|
|
13
|
+
from braindecode.models.base import EEGModuleMixin
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class EEGSym(EEGModuleMixin, nn.Module):
    """EEGSym from Pérez-Velasco et al (2022) [eegsym2022]_.

    :bdg-success:`Convolution` :bdg-dark-line:`Channel`

    .. figure:: ../../docs/_static/model/eegsym.png
        :align: center
        :alt: EEGSym Architecture


    The **EEGSym** is a novel Convolutional Neural Network (CNN) architecture designed for
    Motor Imagery (MI) based Brain-Computer Interfaces (BCIs), primarily aimed at
    **overcoming inter-subject variability** and significantly **reducing BCI inefficiency**
    [eegsym2022]_.

    The architecture integrates advances from Deep Learning (DL), complemented by
    Transfer Learning (TL) techniques and Data Augmentation (DA), to achieve strong
    performance in inter-subject MI classification [eegsym2022]_.

    .. rubric:: Architectural Overview

    EEGSym systematically incorporates three core features:

    #. **Inception Modules** for multi-scale temporal analysis [eegsym2022]_.
    #. **Residual Connections** maintain spatio-temporal signal structure and
       enable deeper feature extraction [eegsym2022]_.
    #. A **Siamese-network design** exploits the inherent symmetry of the brain
       across the mid-sagittal plane [eegsym2022]_.

    .. rubric:: Macro Components

    - **Symmetric division** (input processing, performed in :meth:`forward`)

      - *Operations.* The input is virtually split into left, right, and middle
        channels. Middle (central) channels are duplicated and concatenated to
        both the left and the right lateralized electrodes to form the two
        hemisphere inputs [eegsym2022]_.
      - *Role.* Prepares the data for the siamese-network approach, reducing
        the number of parameters in the spatial filters for the tempospatial
        analysis stage [eegsym2022]_.

    - ``EEGSym.inception_block1`` and ``EEGSym.inception_block2``
      **(Tempospatial Analysis - Temporal Feature Extraction)**

      - *Operations.* Use :class:`_InceptionBlock` modules, which apply parallel
        temporal convolutions with different kernel sizes (scales) [eegsym2022]_.
        This is followed by concatenation, residual connections, and average
        pooling for temporal dimensionality reduction [eegsym2022]_.
      - *Role.* Captures detailed temporal relationships in the architecture,
        similarly to :class:`~braindecode.models.eeginception_mi.EEGInceptionMI`
        [eeginception2020]_. The first block uses large temporal kernels
        (e.g., 500 ms, 250 ms, 125 ms) [eegsym2022]_.

    - ``EEGSym.residual_blocks``
      **(Tempospatial Analysis - Spatial Feature Extraction)**

      - *Operations.* Composed of three :class:`_ResidualBlock` modules
        [eegsym2022]_. Each block applies temporal convolution, pooling, and a
        spatial analysis layer (convolution or grouped convolution) [eegsym2022]_.
      - *Role.* Enhances spatial feature extraction by incorporating residual
        connections across all CNN stages, which helps maintain the
        spatio-temporal structure of the signal through deeper layers
        [eegsym2022]_.

    - ``EEGSym.temporal_reduction`` **(Temporal Reduction)**

      - *Operations.* A residual temporal convolution (:class:`_TemporalBlock`)
        followed by average pooling that halves the time axis.

    - ``EEGSym.channel_merging`` **(Hemisphere Merging)**

      - *Operations.* The :class:`_ChannelMergingBlock` reduces the spatial
        dimensionality (Z and C) to 1, performing two residual convolutions
        followed by a final grouped convolution that merges the feature
        information from the two hemispheres [eegsym2022]_.
      - *Role.* Extracts complex relationships between channels of both
        hemispheres as part of the symmetry exploitation [eegsym2022]_.

    - ``EEGSym.temporal_merging`` **(Temporal Collapse)**

      - *Operations.* The :class:`_TemporalMergingBlock` uses residual
        convolution followed by grouped convolution to reduce the temporal
        dimension (S) to 1 [eegsym2022]_.
      - *Role.* Final step of temporal aggregation before the output module
        [eegsym2022]_.

    - ``EEGSym.output_blocks`` **(Output Processing)**

      - *Operations.* The :class:`_OutputBlock` applies four residual
        convolution iterations (1x1x1 convolutions) followed by flattening
        [eegsym2022]_.
      - *Role.* Final feature refinement through residual connections before
        the fully connected classification layer [eegsym2022]_.

    .. rubric:: How the information is encoded temporally, spatially, and spectrally

    * **Temporal.**
      Temporal features are extracted across multiple scales in the inception
      modules using different temporal convolution kernel sizes (e.g.,
      corresponding to 500 ms, 250 ms, and 125 ms windows), very similar to
      [eeginception2020]_. Subsequent pooling operations and residual blocks
      continue to reduce the temporal dimension [eegsym2022]_.

    * **Spatial.**
      Spatial features are extracted via two main mechanisms:

      - (1) The **siamese-network design** implicitly introduces brain symmetry
        by treating the two hemispheres equally during feature extraction
        [eegsym2022]_.
      - (2) **Residual connections** are utilized in the Tempospatial Analysis
        stage to enhance the extraction of spatial correlations between
        electrodes [eegsym2022]_.

    * **Spectral.**
      Spectral information is implicitly captured by the varying kernel sizes
      of the temporal convolutions in the inception modules [eegsym2022]_.
      These kernels filter the signal across different temporal windows,
      corresponding to different frequency characteristics.

    Notes
    -----
    * EEGSym achieved competitive accuracies across five large MI datasets [eegsym2022]_.
    * The model maintained high accuracy using a reduced set of electrodes
      (8 or 16 channels) [eegsym2022]_.
    * The time axis is downsampled by a factor of 64 in total (six pooling
      steps of 2) before the temporal merging block, so ``n_times`` must be at
      least 64.
    * This is a PyTorch implementation of the EEGSym model, ported from the
      original TensorFlow code [eegsym2022code]_.

    Parameters
    ----------
    filters_per_branch : int, optional
        Number of filters in each inception branch. With the default three
        temporal scales it should be a multiple of 4, so that the later channel
        reductions (by 2 and by 4) and the grouped convolutions divide evenly.
        Default is 12 [eegsym2022]_.
    scales_time : tuple of int, optional
        Temporal scales (in milliseconds) for the temporal convolutions in the
        first inception module. Each scale ``s`` is converted to the odd kernel
        length ``int(s * sfreq / 2000) * 2 + 1`` samples.
        Default is (500, 250, 125) [eegsym2022]_.
    drop_prob : float, optional
        Dropout probability. Default is 0.25 [eegsym2022]_.
    activation : type[nn.Module], optional
        Activation function class to use. A single instance is created and
        shared by all sub-blocks. Default is :class:`nn.ELU` [eegsym2022]_.
    spatial_resnet_repetitions : int, optional
        Number of repetitions of the spatial analysis operations at each step.
        Default is 5 [eegsym2022]_.
    left_right_chs : list of tuple of str, optional
        List of tuples pairing left and right hemisphere channel names,
        e.g., ``[('C3', 'C4'), ('FC5', 'FC6')]``. If not provided, channels
        are automatically split into left/right hemispheres using
        :func:`~braindecode.datautil.channel_utils.division_channels_idx` and
        :func:`~braindecode.datautil.channel_utils.match_hemisphere_chans`;
        if that automatic matching fails, every channel is fed to both
        branches and no midline channels are used.
        Must be provided together with ``middle_chs`` [eegsym2022]_.
    middle_chs : list of str, optional
        List of midline (central) channel names that lie on the mid-sagittal plane,
        e.g., ``['FZ', 'CZ', 'PZ']``. These channels are duplicated and concatenated
        to both hemispheres. If not provided, channels are automatically identified
        using :func:`~braindecode.datautil.channel_utils.division_channels_idx`.
        Must be provided together with ``left_right_chs`` [eegsym2022]_.

    Raises
    ------
    ValueError
        If only one of ``left_right_chs`` / ``middle_chs`` is given, if
        ``left_right_chs`` is empty, if a requested channel name is not in
        ``chs_info``, or if ``n_times`` is smaller than 64.

    References
    ----------
    .. [eegsym2022] Pérez-Velasco, S., Santamaría-Vázquez, E., Martínez-Cagigal, V.,
        Marcos-Martínez, D., & Hornero, R. (2022). EEGSym: Overcoming inter-subject
        variability in motor imagery based BCIs with deep learning. IEEE Transactions
        on Neural Systems and Rehabilitation Engineering, 30, 1766-1775.
    .. [eegsym2022code] Pérez-Velasco, S., EEGSym source code.
        https://github.com/Serpeve/EEGSym
    .. [eeginception2020] Santamaría-Vázquez, E., Martínez-Cagigal, V.,
        Vaquerizo-Villar, F., & Hornero, R. (2020). EEG-Inception: A novel deep
        convolutional neural network for assistive ERP-based brain-computer interfaces.
        IEEE Transactions on Neural Systems and Rehabilitation Engineering, 28, 2773-2782.
    """

    def __init__(
        self,
        # braindecode parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model parameters
        filters_per_branch: int = 12,
        scales_time: Tuple[int, int, int] = (500, 250, 125),
        drop_prob: float = 0.25,
        activation: type[nn.Module] = nn.ELU,
        spatial_resnet_repetitions: int = 5,
        left_right_chs: list[tuple[str, str]] | None = None,
        middle_chs: list[str] | None = None,
    ):
        if (left_right_chs is None) != (middle_chs is None):
            raise ValueError(
                "Either both or none of left_right_chs and middle_chs must be provided."
            )
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.filters_per_branch = filters_per_branch
        self.scales_time = scales_time
        self.drop_prob = drop_prob
        # A single activation instance is created here and passed to every
        # sub-block below.
        self.activation = activation()
        self.spatial_resnet_repetitions = spatial_resnet_repetitions

        # Convert the scales from milliseconds to odd kernel lengths in samples
        # (``2 * k + 1``), so ``padding = size // 2`` preserves the time length.
        self.scales_samples = [int(s * self.sfreq / 2000) * 2 + 1 for s in scales_time]

        # Note: chs_info is actually list[dict] despite base class type hint
        # saying list[str]
        ch_names = [cast(dict[str, Any], ch)["ch_name"] for ch in self.chs_info]
        if left_right_chs is None:
            left_chs, right_chs, middle_chs = division_channels_idx(ch_names)
            try:
                # Try to match hemispheres based on channel naming
                left_chs, right_chs = zip(*match_hemisphere_chans(left_chs, right_chs))
            except (ValueError, IndexError):
                # Fallback: if matching fails, treat all channels as one hemisphere
                # This allows the model to work with arbitrary channel configurations
                left_chs = ch_names
                right_chs = ch_names
                middle_chs = []
        else:
            pairs = list(left_right_chs)
            if not pairs:
                raise ValueError(
                    "left_right_chs must contain at least one (left, right) channel pair."
                )
            left_chs, right_chs = zip(*pairs)
            # middle_chs is guaranteed to be not None when left_right_chs is not None
            # (checked at the top of __init__)
            assert middle_chs is not None, (
                "middle_chs must be provided with left_right_chs"
            )
        # Materialize once: it is iterated both for validation and indexing.
        middle_chs = list(middle_chs)

        # Report every unknown channel at once instead of list.index's
        # "'X' is not in list" for the first one.
        missing = [
            ch for ch in (*left_chs, *right_chs, *middle_chs) if ch not in ch_names
        ]
        if missing:
            raise ValueError(
                f"Channels {missing} are not present in chs_info "
                f"(available channels: {ch_names})."
            )

        # Convert to indices and store as tensors for TorchScript compatibility
        left_idx = [ch_names.index(ch) for ch in left_chs]
        right_idx = [ch_names.index(ch) for ch in right_chs]
        middle_idx = [ch_names.index(ch) for ch in middle_chs]

        # Register as buffers (non-trainable tensors) for TorchScript compatibility
        self.register_buffer("left_idx", torch.tensor(left_idx, dtype=torch.long))
        self.register_buffer("right_idx", torch.tensor(right_idx, dtype=torch.long))
        self.register_buffer("middle_idx", torch.tensor(middle_idx, dtype=torch.long))

        self.n_channels_per_hemi = len(left_idx) + len(middle_idx)

        # Temporal dimension at the temporal merging block.
        # After: Inc1 (pool/2), Inc2 (pool/2), Res1-3 (pool/2 each), TempRed (pool/2)
        # Total reduction: 2^6 = 64 (repeated floor division equals n_times // 64).
        temporal_dim_at_merging = self.n_times // 64
        if temporal_dim_at_merging < 1:
            raise ValueError(
                "EEGSym downsamples the time axis by a factor of 64, so n_times "
                f"must be at least 64 (got {self.n_times})."
            )

        n_filters = self.filters_per_branch * len(self.scales_samples)

        ##################
        # Build the model
        ##################
        self.include_extra_dim = Rearrange("batch channel time -> batch 1 channel time")

        self.permute_layer = Rearrange(
            "batch features z time space -> batch features z space time"
        )

        self.inception_block1 = _InceptionBlock(
            in_channels=1,
            scales_samples=self.scales_samples,
            filters_per_branch=self.filters_per_branch,
            ncha=self.n_channels_per_hemi,
            activation=self.activation,
            drop_prob=self.drop_prob,
            average_pool=2,
            spatial_resnet_repetitions=self.spatial_resnet_repetitions,
            init=True,
        )
        self.inception_block2 = _InceptionBlock(
            in_channels=n_filters,
            scales_samples=[max(1, s // 4) for s in self.scales_samples],
            filters_per_branch=self.filters_per_branch,
            ncha=self.n_channels_per_hemi,
            activation=self.activation,
            drop_prob=self.drop_prob,
            average_pool=2,
            spatial_resnet_repetitions=self.spatial_resnet_repetitions,
            init=False,
        )

        # Residual blocks (spatial dim is still n_channels_per_hemi through the network)
        self.residual_blocks = nn.Sequential(
            _ResidualBlock(
                in_channels=n_filters,
                filters=n_filters,  # No reduction
                kernel_size=16,
                ncha=self.n_channels_per_hemi,
                activation=self.activation,
                drop_prob=self.drop_prob,
                average_pool=2,
                spatial_resnet_repetitions=self.spatial_resnet_repetitions,
            ),
            _ResidualBlock(
                in_channels=n_filters,
                filters=int(n_filters / 2),  # Reduce by /2
                kernel_size=8,
                ncha=self.n_channels_per_hemi,
                activation=self.activation,
                drop_prob=self.drop_prob,
                average_pool=2,
                spatial_resnet_repetitions=self.spatial_resnet_repetitions,
            ),
            _ResidualBlock(
                in_channels=int(n_filters / 2),
                filters=int(n_filters / 4),  # Reduce by /2
                kernel_size=4,
                ncha=self.n_channels_per_hemi,
                activation=self.activation,
                drop_prob=self.drop_prob,
                average_pool=2,
                spatial_resnet_repetitions=self.spatial_resnet_repetitions,
            ),
        )

        # Temporal reduction
        self.temporal_reduction = nn.Sequential(
            _TemporalBlock(
                in_channels=int(n_filters / 4),
                filters=int(n_filters / 4),
                kernel_size=4,
                activation=self.activation,
                drop_prob=self.drop_prob,
            ),
            nn.AvgPool3d(kernel_size=(1, 2, 1)),
        )

        # Channel merging
        self.channel_merging = _ChannelMergingBlock(
            in_channels=int(n_filters / 4),
            filters=int(n_filters / 4),
            groups=int(n_filters / 12),  # 36/12=3 groups
            ncha=self.n_channels_per_hemi,
            division=2,
            activation=self.activation,
            drop_prob=self.drop_prob,
        )

        # Temporal merging
        self.temporal_merging = _TemporalMergingBlock(
            in_channels=int(n_filters / 4),
            filters=int(n_filters / 2),
            groups=int(n_filters / 4),
            n_times=temporal_dim_at_merging,
            activation=self.activation,
            drop_prob=self.drop_prob,
        )

        # Output layers
        self.output_blocks = nn.Sequential(
            _OutputBlock(
                in_channels=int(n_filters / 2),
                activation=self.activation,
                drop_prob=self.drop_prob,
            ),
            nn.Flatten(),
        )

        # Final fully connected layer
        self.final_layer = nn.Linear(
            in_features=int(n_filters / 2),
            out_features=self.n_outputs,
        )

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, n_outputs).
        """
        # Step 1: add a feature dimension -> (B, 1, C, T)
        x = self.include_extra_dim(x)

        # Step 2: split into left, right, and middle channels
        # (index_select keeps the module TorchScript compatible)
        left_data = torch.index_select(x, 2, self.left_idx)  # (B, 1, n_left, T)
        right_data = torch.index_select(x, 2, self.right_idx)  # (B, 1, n_right, T)
        middle_data = torch.index_select(x, 2, self.middle_idx)  # (B, 1, n_middle, T)

        # Step 3: concatenate the middle channels to both hemispheres
        left_hemi = torch.cat([left_data, middle_data], dim=2)
        right_hemi = torch.cat([right_data, middle_data], dim=2)

        # Step 4: stack hemispheres along a new Z axis -> (B, 1, 2, n_ch_per_hemi, T)
        x = torch.stack([left_hemi, right_hemi], dim=2)

        # Step 5: (B, F, Z, Space, Time) -> (B, F, Z, Time, Space)
        x = self.permute_layer(x)

        # Initial inception modules (they take and return lists of tensors)
        x = self.inception_block1([x])[0]
        x = self.inception_block2([x])[0]

        x = self.residual_blocks(x)
        x = self.temporal_reduction(x)
        x = self.channel_merging(x)
        x = self.temporal_merging(x)
        x = self.output_blocks(x)

        # Final fully connected layer
        return self.final_layer(x)
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
class _InceptionBlock(nn.Module):
    """Inception module used in the EEGSym architecture.

    Applies one temporal convolution per entry of ``scales_samples`` in
    parallel, concatenates the branch outputs along the feature axis, adds a
    residual connection to the input, average-pools the time axis, and then
    applies ``spatial_resnet_repetitions`` residual spatial convolutions.

    Tensors are laid out as ``(batch, features, z, time, space)``: kernels are
    ``(1, k, 1)`` for time and ``(1, 1, ncha)`` for space.

    Parameters
    ----------
    in_channels : int
        Number of input feature maps. For the residual sum to broadcast it must
        be either ``1`` or ``filters_per_branch * len(scales_samples)``.
    scales_samples : list of int
        Kernel lengths (in samples) of the parallel temporal convolutions.
        Odd and even lengths may be mixed.
    filters_per_branch : int
        Number of filters produced by each temporal branch.
    ncha : int
        Size of the spatial (electrode) axis. When it differs from 1, spatial
        convolutions spanning the whole axis are added.
    activation : nn.Module
        Activation module instance inserted after each batch normalization.
    drop_prob : float
        Dropout probability.
    average_pool : int
        Pooling factor along the time axis; ``1`` disables pooling.
    spatial_resnet_repetitions : int
        Number of residual spatial convolutions applied after pooling.
    init : bool, default=False
        Stored as ``self.init``; it is not used by :meth:`forward`.
    """

    def __init__(
        self,
        in_channels: int,
        scales_samples: List[int],
        filters_per_branch: int,
        ncha: int,
        activation: nn.Module,
        drop_prob: float,
        average_pool: int,
        spatial_resnet_repetitions: int,
        init: bool = False,
    ):
        super().__init__()
        self.activation = activation
        self.drop_prob = drop_prob
        self.average_pool = average_pool
        self.init = init

        n_filters = filters_per_branch * len(scales_samples)

        # One temporal branch per scale. ``padding = scale // 2`` keeps the
        # time length for odd kernels and adds one step for even kernels.
        self.temporal_convs = nn.ModuleList()
        for scale in scales_samples:
            self.temporal_convs.append(
                nn.Sequential(
                    nn.Conv3d(
                        in_channels=in_channels,
                        out_channels=filters_per_branch,
                        kernel_size=(1, scale, 1),
                        padding=(0, scale // 2, 0),
                    ),
                    nn.BatchNorm3d(filters_per_branch),
                    activation,
                    nn.Dropout(drop_prob),
                )
            )

        # Spatial convolutions span the whole electrode axis; they are only
        # meaningful when there is more than one electrode.
        if ncha != 1:
            self.spatial_convs = nn.ModuleList()
            for _ in range(spatial_resnet_repetitions):
                self.spatial_convs.append(
                    nn.Sequential(
                        nn.Conv3d(
                            in_channels=n_filters,
                            out_channels=n_filters,
                            kernel_size=(1, 1, ncha),
                            padding=(0, 0, 0),
                        ),
                        nn.BatchNorm3d(n_filters),
                        activation,
                        nn.Dropout(drop_prob),
                    )
                )
        else:
            self.spatial_convs = None

        self.pool = (
            nn.AvgPool3d(kernel_size=(1, average_pool, 1))
            if average_pool != 1
            else nn.Identity()
        )

    def forward(self, x_list: list[torch.Tensor]) -> list[torch.Tensor]:
        """Apply the block independently to every tensor in ``x_list``.

        Parameters
        ----------
        x_list : list of torch.Tensor
            Tensors shaped ``(batch, in_channels, z, time, space)``.

        Returns
        -------
        list of torch.Tensor
            One output per input, shaped
            ``(batch, filters_per_branch * n_scales, z, time // average_pool, space)``.
        """
        outputs: list[torch.Tensor] = []
        for x in x_list:
            n_times = x.shape[3]
            # Trim every branch to the input length *before* concatenating:
            # even kernels produce ``n_times + 1`` steps while odd kernels
            # produce ``n_times``, and torch.cat needs matching lengths.
            branches = [conv(x)[:, :, :, :n_times, :] for conv in self.temporal_convs]
            x_out = torch.cat(branches, dim=1)

            # Residual connection (broadcasts when in_channels == 1), then pool
            x_out = self.pool(x_out + x)

            if self.spatial_convs is not None:
                for spatial_conv in self.spatial_convs:
                    # The spatial conv collapses the space axis to size 1; the
                    # residual sum broadcasts it back over all electrodes.
                    x_out = x_out + spatial_conv(x_out)

            outputs.append(x_out)
        return outputs
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
class _ResidualBlock(nn.Module):
    """Residual block used in the EEGSym architecture.

    A temporal convolution is added to the (optionally projected) input, the
    sum is average-pooled along time, and ``spatial_resnet_repetitions``
    residual spatial convolutions are then applied.

    Tensors are laid out as ``(batch, features, z, time, space)``: kernels are
    ``(1, kernel_size, 1)`` for time and ``(1, 1, ncha)`` for space.

    Parameters
    ----------
    in_channels : int
        Number of input feature maps.
    filters : int
        Number of output feature maps. When it differs from ``in_channels``, a
        1x1x1 convolution projects the input so the residual sum matches.
    kernel_size : int
        Kernel length of the temporal convolution.
    ncha : int
        Size of the spatial (electrode) axis. When it differs from 1, spatial
        convolutions spanning the whole axis are added.
    activation : nn.Module
        Activation module instance inserted after each batch normalization.
    drop_prob : float
        Dropout probability.
    average_pool : int
        Pooling factor along the time axis.
    spatial_resnet_repetitions : int, default=5
        Number of residual spatial convolutions applied after pooling.
    """

    def __init__(
        self,
        in_channels: int,
        filters: int,
        kernel_size: int,
        ncha: int,
        activation: nn.Module,
        drop_prob: float,
        average_pool: int,
        spatial_resnet_repetitions: int = 5,
    ):
        super().__init__()
        self.activation = activation
        self.drop_prob = drop_prob

        # Temporal convolution; ``padding = kernel_size // 2`` yields one extra
        # time step for even kernels, trimmed in forward().
        self.temporal_conv = nn.Sequential(
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=filters,
                kernel_size=(1, kernel_size, 1),
                padding=(0, kernel_size // 2, 0),
            ),
            nn.BatchNorm3d(filters),
            activation,
            nn.Dropout(drop_prob),
        )

        # Projection layer for dimension matching if needed
        if in_channels != filters:
            self.projection = nn.Conv3d(
                in_channels=in_channels,
                out_channels=filters,
                kernel_size=(1, 1, 1),
            )
        else:
            self.projection = None

        # Average pooling along the time axis only
        self.avg_pool = nn.AvgPool3d(kernel_size=(1, average_pool, 1))

        # Spatial convolutions (multiple repetitions, as in the inception block)
        if ncha != 1:
            self.spatial_convs = nn.ModuleList()
            for _ in range(spatial_resnet_repetitions):
                self.spatial_convs.append(
                    nn.Sequential(
                        nn.Conv3d(
                            in_channels=filters,
                            out_channels=filters,
                            kernel_size=(1, 1, ncha),  # spans all electrodes
                            padding=(0, 0, 0),
                        ),
                        nn.BatchNorm3d(filters),
                        activation,
                        nn.Dropout(drop_prob),
                    )
                )
        else:
            self.spatial_convs = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual block.

        Parameters
        ----------
        x : torch.Tensor
            Tensor shaped ``(batch, in_channels, z, time, space)``.

        Returns
        -------
        torch.Tensor
            Tensor shaped ``(batch, filters, z, time // average_pool, space)``.
        """
        x_res = self.temporal_conv(x)

        # Trim temporal dimension if needed (due to even kernel sizes with padding)
        if x_res.shape[3] > x.shape[3]:
            x_res = x_res[:, :, :, : x.shape[3], :]

        # Handle channel dimension mismatch if needed
        if self.projection is not None:
            x = self.projection(x)

        x_out = self.avg_pool(x_res + x)  # residual connection, then pooling

        if self.spatial_convs is not None:
            for spatial_conv in self.spatial_convs:
                # The spatial conv collapses the space axis to size 1; the
                # residual sum broadcasts it back over all electrodes.
                x_out = x_out + spatial_conv(x_out)

        return x_out
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
class _TemporalBlock(nn.Module):
|
|
655
|
+
"""Temporal reduction block used in EEGSym architecture.
|
|
656
|
+
|
|
657
|
+
Parameters
|
|
658
|
+
----------
|
|
659
|
+
in_channels : int
|
|
660
|
+
Number of input channels.
|
|
661
|
+
filters : int
|
|
662
|
+
Number of filters for the convolutional layers.
|
|
663
|
+
kernel_size : int
|
|
664
|
+
Kernel size for the temporal convolution.
|
|
665
|
+
activation : nn.Module
|
|
666
|
+
Activation function to use.
|
|
667
|
+
drop_prob : float
|
|
668
|
+
Dropout probability.
|
|
669
|
+
residual : bool
|
|
670
|
+
If True, includes residual connections.
|
|
671
|
+
"""
|
|
672
|
+
|
|
673
|
+
def __init__(
|
|
674
|
+
self,
|
|
675
|
+
in_channels: int,
|
|
676
|
+
filters: int,
|
|
677
|
+
kernel_size: int,
|
|
678
|
+
activation: nn.Module,
|
|
679
|
+
drop_prob: float,
|
|
680
|
+
):
|
|
681
|
+
super().__init__()
|
|
682
|
+
self.activation = activation
|
|
683
|
+
self.drop_prob = drop_prob
|
|
684
|
+
|
|
685
|
+
self.conv = nn.Sequential(
|
|
686
|
+
nn.Conv3d(
|
|
687
|
+
in_channels=in_channels,
|
|
688
|
+
out_channels=filters,
|
|
689
|
+
kernel_size=(1, kernel_size, 1),
|
|
690
|
+
padding=(0, kernel_size // 2, 0),
|
|
691
|
+
),
|
|
692
|
+
nn.BatchNorm3d(filters),
|
|
693
|
+
activation,
|
|
694
|
+
nn.Dropout(drop_prob),
|
|
695
|
+
)
|
|
696
|
+
|
|
697
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
698
|
+
x_res = self.conv(x)
|
|
699
|
+
|
|
700
|
+
# Trim temporal dimension if needed (due to even kernel sizes with padding)
|
|
701
|
+
if x_res.shape[3] > x.shape[3]:
|
|
702
|
+
x_res = x_res[:, :, :, : x.shape[3], :]
|
|
703
|
+
|
|
704
|
+
x_res = x_res + x
|
|
705
|
+
return x_res
|
|
706
|
+
|
|
707
|
+
|
|
708
|
+
class _ChannelMergingBlock(nn.Module):
|
|
709
|
+
"""Channel merging block used in EEGSym architecture.
|
|
710
|
+
|
|
711
|
+
This block performs hemisphere merging through:
|
|
712
|
+
1. Two residual convolution iterations (with full spatial kernel)
|
|
713
|
+
2. One grouped convolution (merges Z dimension from 2 to 1)
|
|
714
|
+
|
|
715
|
+
Parameters
|
|
716
|
+
----------
|
|
717
|
+
in_channels : int
|
|
718
|
+
Number of input channels.
|
|
719
|
+
filters : int
|
|
720
|
+
Number of filters for the convolutional layers.
|
|
721
|
+
groups : int
|
|
722
|
+
Number of groups for the final grouped convolution.
|
|
723
|
+
ncha : int
|
|
724
|
+
Number of spatial channels to merge.
|
|
725
|
+
division : int
|
|
726
|
+
Z dimension size to merge (typically 2 for two hemispheres).
|
|
727
|
+
activation : nn.Module
|
|
728
|
+
Activation function to use.
|
|
729
|
+
drop_prob : float
|
|
730
|
+
Dropout probability.
|
|
731
|
+
"""
|
|
732
|
+
|
|
733
|
+
def __init__(
|
|
734
|
+
self,
|
|
735
|
+
in_channels: int,
|
|
736
|
+
filters: int,
|
|
737
|
+
groups: int,
|
|
738
|
+
ncha: int,
|
|
739
|
+
division: int,
|
|
740
|
+
activation: nn.Module,
|
|
741
|
+
drop_prob: float,
|
|
742
|
+
):
|
|
743
|
+
super().__init__()
|
|
744
|
+
self.activation = activation
|
|
745
|
+
self.drop_prob = drop_prob
|
|
746
|
+
|
|
747
|
+
# TWO residual convolution iterations
|
|
748
|
+
# Each reduces spatial dimension: ncha → 1
|
|
749
|
+
self.residual_convs = nn.ModuleList()
|
|
750
|
+
for _ in range(2):
|
|
751
|
+
self.residual_convs.append(
|
|
752
|
+
nn.Sequential(
|
|
753
|
+
nn.Conv3d(
|
|
754
|
+
in_channels=in_channels,
|
|
755
|
+
out_channels=filters,
|
|
756
|
+
kernel_size=(division, 1, ncha), # (Z, Time, Space)
|
|
757
|
+
padding=(0, 0, 0), # Valid padding
|
|
758
|
+
),
|
|
759
|
+
nn.BatchNorm3d(filters),
|
|
760
|
+
activation,
|
|
761
|
+
nn.Dropout(drop_prob),
|
|
762
|
+
)
|
|
763
|
+
)
|
|
764
|
+
|
|
765
|
+
# Final grouped convolution
|
|
766
|
+
# Merges Z dimension: 2 → 1
|
|
767
|
+
self.grouped_conv = nn.Sequential(
|
|
768
|
+
nn.Conv3d(
|
|
769
|
+
in_channels=in_channels,
|
|
770
|
+
out_channels=filters,
|
|
771
|
+
kernel_size=(division, 1, ncha), # (Z, Time, Space)
|
|
772
|
+
groups=groups,
|
|
773
|
+
padding=(0, 0, 0),
|
|
774
|
+
),
|
|
775
|
+
nn.BatchNorm3d(filters),
|
|
776
|
+
activation,
|
|
777
|
+
nn.Dropout(drop_prob),
|
|
778
|
+
)
|
|
779
|
+
|
|
780
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
781
|
+
# Apply 2 residual iterations
|
|
782
|
+
# Each iteration: conv reduces dims, then Add broadcasts back
|
|
783
|
+
for residual_conv in self.residual_convs:
|
|
784
|
+
x_res = residual_conv(x)
|
|
785
|
+
x = x + x_res # Broadcasts x_res (1,T,1) to match x (2,T,5)
|
|
786
|
+
|
|
787
|
+
# Apply final grouped conv (permanently reduces dimensions)
|
|
788
|
+
x = self.grouped_conv(x)
|
|
789
|
+
|
|
790
|
+
return x
|
|
791
|
+
|
|
792
|
+
|
|
793
|
+
class _TemporalMergingBlock(nn.Module):
|
|
794
|
+
"""Temporal merging block used in EEGSym architecture.
|
|
795
|
+
|
|
796
|
+
This block performs temporal dimension collapse through:
|
|
797
|
+
1. One residual convolution (temporal collapse with residual connection)
|
|
798
|
+
2. One grouped convolution (temporal collapse + double filters)
|
|
799
|
+
|
|
800
|
+
Parameters
|
|
801
|
+
----------
|
|
802
|
+
in_channels : int
|
|
803
|
+
Number of input channels.
|
|
804
|
+
filters : int
|
|
805
|
+
Number of output filters (should be 2x input channels).
|
|
806
|
+
groups : int
|
|
807
|
+
Number of groups for the grouped convolution.
|
|
808
|
+
n_times : int
|
|
809
|
+
Current temporal dimension size.
|
|
810
|
+
activation : nn.Module
|
|
811
|
+
Activation function to use.
|
|
812
|
+
drop_prob : float
|
|
813
|
+
Dropout probability.
|
|
814
|
+
"""
|
|
815
|
+
|
|
816
|
+
def __init__(
|
|
817
|
+
self,
|
|
818
|
+
in_channels: int,
|
|
819
|
+
filters: int,
|
|
820
|
+
groups: int,
|
|
821
|
+
n_times: int,
|
|
822
|
+
activation: nn.Module,
|
|
823
|
+
drop_prob: float,
|
|
824
|
+
):
|
|
825
|
+
super().__init__()
|
|
826
|
+
self.activation = activation
|
|
827
|
+
self.drop_prob = drop_prob
|
|
828
|
+
|
|
829
|
+
# Calculate temporal kernel size
|
|
830
|
+
# At this point in network, temporal dim has been reduced by pooling
|
|
831
|
+
self.temporal_kernel = n_times # Should be 6 for 384 input samples
|
|
832
|
+
|
|
833
|
+
# Residual convolution (collapses time dimension)
|
|
834
|
+
self.residual_conv = nn.Sequential(
|
|
835
|
+
nn.Conv3d(
|
|
836
|
+
in_channels=in_channels,
|
|
837
|
+
out_channels=in_channels, # Same channels for residual
|
|
838
|
+
kernel_size=(1, self.temporal_kernel, 1), # (Z, Time, Space)
|
|
839
|
+
padding=(0, 0, 0), # Valid padding - reduces time to 1
|
|
840
|
+
),
|
|
841
|
+
nn.BatchNorm3d(in_channels),
|
|
842
|
+
activation,
|
|
843
|
+
nn.Dropout(drop_prob),
|
|
844
|
+
)
|
|
845
|
+
|
|
846
|
+
# Grouped convolution (collapses time dimension, doubles filters)
|
|
847
|
+
self.grouped_conv = nn.Sequential(
|
|
848
|
+
nn.Conv3d(
|
|
849
|
+
in_channels=in_channels,
|
|
850
|
+
out_channels=filters, # Double the channels
|
|
851
|
+
kernel_size=(1, self.temporal_kernel, 1), # (Z, Time, Space)
|
|
852
|
+
groups=groups,
|
|
853
|
+
padding=(0, 0, 0),
|
|
854
|
+
),
|
|
855
|
+
nn.BatchNorm3d(filters),
|
|
856
|
+
activation,
|
|
857
|
+
nn.Dropout(drop_prob),
|
|
858
|
+
)
|
|
859
|
+
|
|
860
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
861
|
+
# Residual convolution with broadcasting
|
|
862
|
+
x_res = self.residual_conv(x)
|
|
863
|
+
x = x + x_res # Broadcasts x_res (1,1,1) back to x shape (1,6,1)
|
|
864
|
+
|
|
865
|
+
# Grouped convolution (reduces time to 1, doubles channels)
|
|
866
|
+
x = self.grouped_conv(x)
|
|
867
|
+
|
|
868
|
+
return x
|
|
869
|
+
|
|
870
|
+
|
|
871
|
+
class _OutputBlock(nn.Module):
|
|
872
|
+
"""Output block used in EEGSym architecture.
|
|
873
|
+
|
|
874
|
+
Parameters
|
|
875
|
+
----------
|
|
876
|
+
in_channels : int
|
|
877
|
+
Number of input channels.
|
|
878
|
+
activation : nn.Module
|
|
879
|
+
Activation function to use.
|
|
880
|
+
drop_prob : float
|
|
881
|
+
Dropout probability.
|
|
882
|
+
residual : bool
|
|
883
|
+
If True, includes residual connections.
|
|
884
|
+
"""
|
|
885
|
+
|
|
886
|
+
def __init__(
|
|
887
|
+
self,
|
|
888
|
+
in_channels: int,
|
|
889
|
+
activation: nn.Module,
|
|
890
|
+
drop_prob: float,
|
|
891
|
+
n_residual: int = 4,
|
|
892
|
+
):
|
|
893
|
+
super().__init__()
|
|
894
|
+
self.activation = activation
|
|
895
|
+
self.drop_prob = drop_prob
|
|
896
|
+
|
|
897
|
+
self.conv_blocks = nn.ModuleList()
|
|
898
|
+
for _ in range(n_residual):
|
|
899
|
+
self.conv_blocks.append(
|
|
900
|
+
nn.Sequential(
|
|
901
|
+
nn.Conv3d(
|
|
902
|
+
in_channels=in_channels,
|
|
903
|
+
out_channels=in_channels,
|
|
904
|
+
kernel_size=(1, 1, 1),
|
|
905
|
+
padding=(0, 0, 0),
|
|
906
|
+
),
|
|
907
|
+
nn.BatchNorm3d(in_channels),
|
|
908
|
+
activation,
|
|
909
|
+
nn.Dropout(drop_prob),
|
|
910
|
+
)
|
|
911
|
+
)
|
|
912
|
+
|
|
913
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
914
|
+
for conv_block in self.conv_blocks:
|
|
915
|
+
x_res = conv_block(x)
|
|
916
|
+
x = x + x_res
|
|
917
|
+
return x
|