braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- braindecode/augmentation/__init__.py +3 -5
- braindecode/augmentation/base.py +5 -8
- braindecode/augmentation/functional.py +22 -25
- braindecode/augmentation/transforms.py +42 -51
- braindecode/classifier.py +16 -11
- braindecode/datasets/__init__.py +3 -5
- braindecode/datasets/base.py +13 -17
- braindecode/datasets/bbci.py +14 -13
- braindecode/datasets/bcicomp.py +5 -4
- braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
- braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
- braindecode/datasets/{bids/hub.py → hub.py} +350 -375
- braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
- braindecode/datasets/mne.py +19 -19
- braindecode/datasets/moabb.py +10 -10
- braindecode/datasets/nmt.py +56 -58
- braindecode/datasets/sleep_physio_challe_18.py +5 -3
- braindecode/datasets/sleep_physionet.py +5 -5
- braindecode/datasets/tuh.py +18 -21
- braindecode/datasets/xy.py +9 -10
- braindecode/datautil/__init__.py +3 -3
- braindecode/datautil/serialization.py +20 -22
- braindecode/datautil/util.py +7 -120
- braindecode/eegneuralnet.py +52 -22
- braindecode/functional/functions.py +10 -7
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +3 -5
- braindecode/models/atcnet.py +39 -43
- braindecode/models/attentionbasenet.py +41 -37
- braindecode/models/attn_sleep.py +24 -26
- braindecode/models/base.py +6 -6
- braindecode/models/bendr.py +26 -50
- braindecode/models/biot.py +30 -61
- braindecode/models/contrawr.py +5 -5
- braindecode/models/ctnet.py +35 -35
- braindecode/models/deep4.py +5 -5
- braindecode/models/deepsleepnet.py +7 -7
- braindecode/models/eegconformer.py +26 -31
- braindecode/models/eeginception_erp.py +2 -2
- braindecode/models/eeginception_mi.py +6 -6
- braindecode/models/eegitnet.py +5 -5
- braindecode/models/eegminer.py +1 -1
- braindecode/models/eegnet.py +3 -3
- braindecode/models/eegnex.py +2 -2
- braindecode/models/eegsimpleconv.py +2 -2
- braindecode/models/eegsym.py +7 -7
- braindecode/models/eegtcnet.py +6 -6
- braindecode/models/fbcnet.py +2 -2
- braindecode/models/fblightconvnet.py +3 -3
- braindecode/models/fbmsnet.py +3 -3
- braindecode/models/hybrid.py +2 -2
- braindecode/models/ifnet.py +5 -5
- braindecode/models/labram.py +46 -70
- braindecode/models/luna.py +5 -60
- braindecode/models/medformer.py +21 -23
- braindecode/models/msvtnet.py +15 -15
- braindecode/models/patchedtransformer.py +55 -55
- braindecode/models/sccnet.py +2 -2
- braindecode/models/shallow_fbcsp.py +3 -5
- braindecode/models/signal_jepa.py +12 -39
- braindecode/models/sinc_shallow.py +4 -3
- braindecode/models/sleep_stager_blanco_2020.py +2 -2
- braindecode/models/sleep_stager_chambon_2018.py +2 -2
- braindecode/models/sparcnet.py +8 -8
- braindecode/models/sstdpn.py +869 -869
- braindecode/models/summary.csv +17 -19
- braindecode/models/syncnet.py +2 -2
- braindecode/models/tcn.py +5 -5
- braindecode/models/tidnet.py +3 -3
- braindecode/models/tsinception.py +3 -3
- braindecode/models/usleep.py +7 -7
- braindecode/models/util.py +14 -165
- braindecode/modules/__init__.py +1 -9
- braindecode/modules/activation.py +3 -29
- braindecode/modules/attention.py +0 -123
- braindecode/modules/blocks.py +1 -53
- braindecode/modules/convolution.py +0 -53
- braindecode/modules/filter.py +0 -31
- braindecode/modules/layers.py +0 -84
- braindecode/modules/linear.py +1 -22
- braindecode/modules/stats.py +0 -10
- braindecode/modules/util.py +0 -9
- braindecode/modules/wrapper.py +0 -17
- braindecode/preprocessing/preprocess.py +0 -3
- braindecode/regressor.py +18 -15
- braindecode/samplers/ssl.py +1 -1
- braindecode/util.py +28 -38
- braindecode/version.py +1 -1
- braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
- braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
- braindecode/datasets/bids/__init__.py +0 -54
- braindecode/datasets/bids/format.py +0 -717
- braindecode/datasets/bids/hub_format.py +0 -717
- braindecode/datasets/bids/hub_io.py +0 -197
- braindecode/datasets/chb_mit.py +0 -163
- braindecode/datasets/siena.py +0 -162
- braindecode/datasets/utils.py +0 -67
- braindecode/models/brainmodule.py +0 -845
- braindecode/models/config.py +0 -233
- braindecode/models/reve.py +0 -843
- braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
- braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
@@ -1,845 +0,0 @@
-# Authors: Meta Platforms, Inc. and affiliates (original)
-# Bruno Aristimunha <b.aristimunha@gmail.com> (Braindecode adaptation)
-# Hubert Banville <hubertjb@meta.com> (Braindecode adaptation and Review)
-#
-# License: Attribution-NonCommercial 4.0 International
-
-"""BrainModule: Dilated Convolutional Encoder for EEG decoding."""
-
-from __future__ import annotations
-
-import math
-import typing as tp
-
-import torch
-import torchaudio as ta
-from torch import nn
-from torch.nn import functional as F
-
-from braindecode.models.base import EEGModuleMixin
-from braindecode.modules.layers import SubjectLayers
-
-__all__ = ["BrainModule"]
-
-
-class BrainModule(EEGModuleMixin, nn.Module):
-    r"""BrainModule from [brainmagick]_, also known as SimpleConv.
-
-    A dilated convolutional encoder for EEG decoding, using residual
-    connections and optional GLU gating for improved expressivity.
-
-    :bdg-success:`Convolution`
-
-    .. figure:: ../_static/model/simpleconv.png
-       :align: center
-       :alt: BrainModule Architecture
-       :width: 500px
-
-    Figure adapted from Extended Data Fig. 4 of [brainmagick]_ to highlight
-    only the model part. Architecture of the brain module used to process the
-    brain recordings. For each layer, the authors first note the number of
-    output channels; the number of time steps is constant throughout the
-    layers. The model is composed of a spatial attention layer, then a 1x1
-    convolution without activation. A 'Subject Layer' is selected based on the
-    subject index s, which consists of a 1x1 convolution learnt only for that
-    subject, with no activation. Then, the authors apply five convolutional
-    blocks made of three convolutions. The first two use residual skip
-    connections and increasing dilation, followed by a BatchNorm layer and a
-    GELU activation. The third convolution is not residual, and uses a GLU
-    activation (which halves the number of channels) and no normalization.
-    Finally, the authors apply two 1x1 convolutions with a GELU in between.
-
-    The BrainModule (also referred to as SimpleConv) is a deep dilated
-    convolutional encoder specifically designed to decode perceived speech from
-    non-invasive brain recordings like EEG and MEG. It is engineered to address
-    the high noise levels and inter-individual variability inherent in
-    non-invasive neuroimaging by using a single architecture trained across
-    large cohorts while accommodating participant-specific differences.
-
-    .. rubric:: Architecture Overview
-
-    The BrainModule integrates three primary mechanisms to align brain activity
-    with deep speech representations:
-
-    1. **Spatial-temporal feature extraction.** The model uses a dedicated
-       spatial attention layer to remap sensor data based on physical
-       locations, followed by temporal processing through dilated convolutions.
-    2. **Subject-specific adaptation.** To leverage inter-subject variability,
-       the architecture includes a "Subject Layer" or participant-specific
-       1x1 convolution that allows the model to share core weights across a
-       cohort while learning individual-specific neural patterns.
-    3. **Dilated residual blocks with gating.** The core encoder employs a
-       stack of convolutional blocks featuring skip connections and increasing
-       dilation to expand the receptive field without losing temporal
-       resolution, supplemented by optional Gated Linear Units (GLU) for
-       increased expressivity.
-
-    .. rubric:: Macro Components
-
-    ``BrainModule.input_projection`` (Initial Processing)
-        **Operations.** Raw M/EEG input
-        :math:`\mathbf{X} \in \mathbb{R}^{C \times T}` is first processed
-        through a spatial attention layer that projects sensor locations onto a
-        2D plane using Fourier-parameterized functions. This is followed by a
-        subject-specific 1x1 convolution
-        :math:`\mathbf{M}_s \in \mathbb{R}^{D_1 \times D_1}` if subject
-        features are enabled. The resulting features are projected to the
-        ``hidden_dim`` (default 320) to ensure compatibility with subsequent
-        residual connections.
-
-        **Role.** Converts high-dimensional, subject-dependent sensor data into
-        a standardized latent space while preserving spatial and temporal
-        relationships.
-
-    ``BrainModule.encoder`` (Convolutional Sequence)
-        **Operations.** Implemented via
-        :class:`~braindecode.models.brainmodule._ConvSequence`, this component
-        consists of a stack of ``k`` convolutional blocks. Each block typically
-        contains: (a) **Residual dilated convolutions.** Two layers with kernel
-        size 3, residual skip connections, and dilation factors that grow
-        exponentially (e.g., powers of two with periodic resets) to capture
-        multi-scale temporal context. (b) **GLU gating.** Every ``N`` layers
-        (defined by ``glu``), a Gated Linear Unit is applied, which halves the
-        channel dimension and introduces non-linear gating to filter
-        intermediate representations.
-
-        **Role.** Extracts deep hierarchical temporal features from the brain
-        signal, significantly expanding the model's receptive field to align
-        with the contextual windows of speech modules like wav2vec 2.0.
-
-    .. rubric:: Temporal, Spatial, and Spectral Encoding
-
-    - **Temporal:** Increasing dilation factors across layers allow the model to
-      integrate information over large time windows without the computational
-      cost of standard large kernels, while a 150 ms input shift facilitates
-      alignment between stimulus and brain response.
-    - **Spatial:** The spatial attention layer learns a softmax weighting over
-      input sensors based on their 3D coordinates, allowing the model to focus
-      on regions typically activated during auditory stimulation (e.g., the
-      temporal cortex).
-    - **Spectral:** Through the optional ``n_fft`` parameter, the model can
-      apply an STFT transformation, converting time-domain signals into a
-      spectrogram representation before encoding.
-
-    .. rubric:: Additional Mechanisms
-
-    - **Clamping and scaling:** The model relies on clamping input values
-      (e.g., at 20 standard deviations) to prevent outliers and large
-      electromagnetic artifacts from destabilizing the BatchNorm estimates and
-      optimization process.
-    - **Scaled subject embeddings:** When ``subject_dim`` is used, the
-      :class:`~braindecode.models.brainmodule._ScaledEmbedding` layer scales up
-      the learning rate for subject-specific features to prevent slow
-      convergence in multi-participant training.
-
-    - **_ConvSequence and residual logic:** This class handles the actual
-      stacking of layers. It is designed to be flexible with the ``growth``
-      parameter; if the channel size changes between layers (``growth != 1.0``),
-      it automatically applies a 1x1 ``skip_projection`` convolution to the
-      residual path so dimensions match for addition.
-    - **_ChannelDropout:** Unlike standard dropout, which zeroes individual
-      neurons, this zeroes entire channels. It includes a rescale feature that
-      multiplies the remaining channels by a factor
-      ``total_channels / active_channels`` to maintain the expected value of the
-      signal during training.
-    - **_ScaledEmbedding:** This is a clever optimization for multi-subject
-      learning. By dividing the initial weights by a scale and then multiplying
-      the output by the same scale, it effectively increases the gradient
-      magnitude for the embedding weights, allowing subject-specific features to
-      learn faster than the shared backbone.
-
-    Parameters
-    ----------
-    hidden_dim : int, default=320
-        Hidden dimension for convolutional layers. Input is projected to this
-        dimension before the convolutional blocks.
-    depth : int, default=10
-        Number of convolutional blocks. Each block contains a dilated convolution
-        with batch normalization and activation, followed by a residual connection.
-    kernel_size : int, default=3
-        Convolutional kernel size. Must be odd for proper padding with dilation.
-    growth : float, default=1.0
-        Channel size multiplier: hidden_dim * (growth ** layer_index).
-        Values > 1.0 grow channels deeper; < 1.0 shrink them.
-        Note: growth != 1.0 disables residual connections between layers
-        with different channel sizes.
-    dilation_growth : int, default=2
-        Dilation multiplier per layer (e.g., 2 means dilation doubles each layer).
-        Improves receptive field exponentially. Requires odd kernel_size.
-    dilation_period : int, default=5
-        Reset dilation to 1 every N layers. Prevents dilation from growing
-        too large and maintains local connectivity.
-    conv_drop_prob : float, default=0.0
-        Dropout probability for convolutional layers.
-    dropout_input : float, default=0.0
-        Dropout probability applied to model input only.
-    batch_norm : bool, default=True
-        If True, apply batch normalization after each convolution.
-    activation : type[nn.Module], default=nn.GELU
-        Activation function class to use (e.g., nn.GELU, nn.ReLU, nn.ELU).
-    n_subjects : int, default=200
-        Number of unique subjects (for subject-specific pathways).
-        Only used if subject_dim > 0.
-    subject_dim : int, default=0
-        Dimension of subject embeddings. If 0, no subject-specific features.
-        If > 0, adds subject embeddings to the input before encoding.
-    subject_layers : bool, default=False
-        If True, apply subject-specific linear transformations to input channels.
-        Each subject has its own weight matrix. Requires subject_dim > 0.
-    subject_layers_dim : str, default="input"
-        Where to apply subject layers: "input" or "hidden".
-    subject_layers_id : bool, default=False
-        If True, initialize subject layers as identity matrices.
-    embedding_scale : float, default=1.0
-        Scaling factor for the subject embeddings' learning rate.
-    n_fft : int, optional
-        FFT size for STFT processing. If None, no STFT is applied.
-        If specified, applies spectrogram transform before encoding.
-    fft_complex : bool, default=True
-        If True, keep complex spectrogram. If False, use power spectrogram.
-        Only used when n_fft is not None.
-    channel_dropout_prob : float, default=0.0
-        Probability of dropping each channel during training (0.0 to 1.0).
-        If 0.0, no channel dropout is applied.
-    channel_dropout_type : str, optional
-        If specified with chs_info, only drop channels of this type
-        (e.g., 'eeg', 'ref', 'eog'). If None with dropout_prob > 0, drops any channel.
-    glu : int, default=2
-        If > 0, applies Gated Linear Units (GLU) every N convolutional layers.
-        GLUs gate intermediate representations for more expressivity.
-        If 0, no GLU is applied.
-    glu_context : int, default=1
-        Context window size for GLU gates. If > 0, uses contextual information
-        from neighboring time steps for gating. Requires glu > 0.
-
-    References
-    ----------
-    .. [brainmagick] Défossez, A., Caucheteux, C., Rapin, J., Kabeli, O., & King, J. R.
-       (2023). Decoding speech perception from non-invasive brain recordings.
-       Nature Machine Intelligence, 5(10), 1097-1107.
-
-    Notes
-    -----
-    - Input shape: (batch, n_chans, n_times)
-    - Output shape: (batch, n_outputs)
-    - The model uses dilated convolutions with stride=1 to maintain temporal
-      resolution while achieving large receptive fields.
-    - Residual connections are applied at every layer where input and output
-      channels match.
-    - Subject-specific features (subject_dim > 0, subject_layers) require passing
-      subject indices in the forward pass as an optional parameter or via batch.
-    - STFT processing (n_fft > 0) automatically transforms input to spectrogram domain.
-
-    .. versionadded:: 1.2
-
-    """
-
-    def __init__(
-        self,
-        # braindecode EEGModuleMixin parameters
-        n_chans: int | None = None,
-        n_outputs: int | None = None,
-        n_times: int | None = None,
-        sfreq: float | None = None,
-        chs_info: list[dict] | None = None,
-        input_window_seconds: float | None = None,
-        ########
-        # Model related parameters
-        # Architecture
-        hidden_dim: int = 320,
-        depth: int = 10,
-        kernel_size: int = 3,
-        growth: float = 1.0,
-        dilation_growth: int = 2,
-        dilation_period: int = 5,
-        # Regularization
-        conv_drop_prob: float = 0.0,
-        dropout_input: float = 0.0,
-        batch_norm: bool = True,
-        activation: type[nn.Module] = nn.GELU,
-        # Subject-specific features (optional)
-        n_subjects: int = 200,
-        subject_dim: int = 0,
-        subject_layers: bool = False,
-        subject_layers_dim: str = "input",
-        subject_layers_id: bool = False,
-        embedding_scale: float = 1.0,
-        # STFT/Spectrogram (optional)
-        n_fft: int | None = None,
-        fft_complex: bool = True,
-        # Channel dropout (optional)
-        channel_dropout_prob: float = 0.0,
-        channel_dropout_type: str | None = None,
-        # GLU gates (optional)
-        glu: int = 2,
-        glu_context: int = 1,
-    ):
-        # Initialize EEGModuleMixin
-        super().__init__(
-            n_outputs=n_outputs,
-            n_chans=n_chans,
-            chs_info=chs_info,
-            n_times=n_times,
-            input_window_seconds=input_window_seconds,
-            sfreq=sfreq,
-        )
-
-        # Store parameters for later use
-        self.subject_dim = subject_dim
-        self.n_subjects = n_subjects
-        self.n_fft = n_fft
-        self.fft_complex = fft_complex
-        self.hidden_dim = hidden_dim
-
-        # Validate inputs
-        _validate_brainmodule_params(
-            subject_layers=subject_layers,
-            subject_dim=subject_dim,
-            depth=depth,
-            kernel_size=kernel_size,
-            growth=growth,
-            dilation_growth=dilation_growth,
-            channel_dropout_prob=channel_dropout_prob,
-            channel_dropout_type=channel_dropout_type,
-            glu=glu,
-            glu_context=glu_context,
-        )
-
-        # Initialize channel dropout (optional)
-        self.channel_dropout = None
-        if channel_dropout_prob > 0:
-            self.channel_dropout = _ChannelDropout(
-                dropout_prob=channel_dropout_prob,
-                ch_info=chs_info,
-                channel_type=channel_dropout_type,
-            )
-
-        # Initialize subject-specific modules (optional)
-        self.subject_embedding = None
-        self.subject_layers_module = None
-        input_channels = self.n_chans
-
-        if subject_dim > 0:
-            self.subject_embedding = _ScaledEmbedding(
-                n_subjects, subject_dim, embedding_scale
-            )
-            input_channels += subject_dim
-
-        if subject_layers:
-            assert subject_dim > 0, "subject_layers requires subject_dim > 0"
-            # Use n_chans for input dim since subject_layers is applied before
-            # subject embeddings are concatenated in forward()
-            meg_dim = self.n_chans
-            dim = hidden_dim if subject_layers_dim == "hidden" else meg_dim
-            self.subject_layers_module = SubjectLayers(
-                meg_dim, dim, n_subjects, subject_layers_id
-            )
-            # After subject_layers, we have 'dim' channels, then add subject_dim
-            input_channels = dim + subject_dim
-
-        # Initialize STFT module (optional)
-        self.stft = None
-        if n_fft is not None:
-            self.stft = ta.transforms.Spectrogram(
-                n_fft=n_fft,
-                hop_length=n_fft // 2,
-                normalized=True,
-                power=None if fft_complex else 1,
-                return_complex=True,
-            )
-            # Update input channels for spectrogram
-            freq_bins = n_fft // 2 + 1
-            if fft_complex:
-                input_channels *= 2 * freq_bins
-            else:
-                input_channels *= freq_bins
-
-        # Initial projection layer: project input channels to hidden_dim
-        # This is crucial for residual connections to work properly
-        self.input_projection = nn.Conv1d(input_channels, hidden_dim, 1)
-
-        # Build channel dimensions for encoder (all same size for residuals)
-        # With growth=1.0, all layers have hidden_dim channels (residuals work)
-        # With growth!=1.0, channels vary (residuals only where dims match)
-        encoder_dims = [hidden_dim] + [
-            int(round(hidden_dim * growth**k)) for k in range(depth)
-        ]
-
-        # Build encoder (stride=1, no downsampling)
-        self.encoder = _ConvSequence(
-            channels=encoder_dims,
-            kernel_size=kernel_size,
-            dilation_growth=dilation_growth,
-            dilation_period=dilation_period,
-            dropout=conv_drop_prob,
-            dropout_input=dropout_input,
-            batch_norm=batch_norm,
-            glu=glu,
-            glu_context=glu_context,
-            activation=activation,
-        )
-
-        # Final layer: temporal aggregation + output projection
-        # Use the last encoder dimension (may differ from hidden_dim if growth != 1)
-        final_hidden_dim = encoder_dims[-1]
-        self.final_layer = nn.Sequential(
-            nn.AdaptiveAvgPool1d(1),
-            nn.Flatten(start_dim=1),
-            nn.Linear(final_hidden_dim, self.n_outputs),
-        )
-
-    def forward(
-        self, x: torch.Tensor, subject_index: torch.Tensor | None = None
-    ) -> torch.Tensor:
-        """
-        Forward pass.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input EEG data of shape (batch, n_chans, n_times).
-        subject_index : torch.Tensor, optional
-            Subject indices of shape (batch,). Required if subject_dim > 0.
-
-        Returns
-        -------
-        torch.Tensor
-            Output logits/predictions of shape (batch, n_outputs).
-        """
-        # Validate input shape
-        if x.dim() != 3:
-            raise ValueError(
-                f"Expected 3D input (batch, channels, time), got shape {x.shape}"
-            )
-        if x.shape[1] != self.n_chans:
-            raise ValueError(f"Expected {self.n_chans} channels, got {x.shape[1]}")
-
-        # Apply STFT if enabled
-        if self.stft is not None:
-            # Pad for STFT window
-            assert self.n_fft is not None, "n_fft must be set if stft is not None"
-            pad_size = self.n_fft // 4
-            x = F.pad(
-                _pad_multiple(x, self.n_fft // 2), (pad_size, pad_size), mode="reflect"
-            )
-
-            # Apply STFT
-            spec = self.stft(x)  # (batch, channels, freq, time)
-            B, C, Fr, T = spec.shape
-
-            if self.fft_complex:
-                # Convert complex to real/imag channels
-                spec = torch.view_as_real(spec).permute(0, 1, 2, 4, 3)
-                x = spec.reshape(B, C * 2 * Fr, T)
-            else:
-                x = spec.reshape(B, C * Fr, T)
-
-        # Apply channel dropout if enabled
-        if self.channel_dropout is not None:
-            x = self.channel_dropout(x)
-
-        # Apply subject layers if enabled
-        if self.subject_layers_module is not None:
-            if subject_index is None:
-                raise ValueError(
-                    "subject_index is required when subject_layers is enabled"
-                )
-            x = self.subject_layers_module(x, subject_index)
-
-        # Apply subject embedding if enabled
-        if self.subject_embedding is not None:
-            if subject_index is None:
-                raise ValueError("subject_index is required when subject_dim > 0")
-            emb = self.subject_embedding(subject_index)  # (batch, subject_dim)
-            emb = emb[:, :, None].expand(
-                -1, -1, x.shape[-1]
-            )  # (batch, subject_dim, time)
-            x = torch.cat([x, emb], dim=1)  # Concatenate along channel dimension
-
-        # Project input to hidden dimension
-        x = self.input_projection(x)
-
-        # Encode with residual dilated convolutions
-        x = self.encoder(x)
-
-        # Apply final layer (pool + linear)
-        x = self.final_layer(x)
-
-        return x
-
-
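The whole module above is removed in this release. For reference, a minimal usage sketch of the deleted model, assuming its definitions are still importable from a pre-removal build; shapes and parameter names are taken from the docstring and ``__init__`` above, and the specific sizes are illustrative:

import torch

# Illustrative sizes; subject_dim > 0 enables the subject embedding, so
# forward() then requires a subject_index tensor with one id per trial.
model = BrainModule(
    n_chans=64, n_outputs=4, n_times=1000, sfreq=250.0,
    hidden_dim=320, depth=10, subject_dim=16, n_subjects=30,
)
x = torch.randn(8, 64, 1000)                # (batch, n_chans, n_times)
subject_index = torch.randint(0, 30, (8,))  # subject id per trial
logits = model(x, subject_index=subject_index)
print(logits.shape)                         # torch.Size([8, 4])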
-class _ConvSequence(nn.Module):
-    """Sequence of residual dilated convolutional layers with GLU activation.
-
-    This is a simplified encoder-only architecture that maintains temporal
-    resolution (stride=1) and applies residual connections at every layer
-    where input and output channels match.
-
-    Parameters
-    ----------
-    channels : Sequence[int]
-        Channel dimensions for each layer. E.g., [320, 320, 320, 320] for
-        a 3-layer network with 320 hidden dims.
-    kernel_size : int, default=3
-        Convolutional kernel size. Must be odd for proper padding with dilation.
-    dilation_growth : int, default=2
-        Dilation multiplier per layer. Improves receptive field exponentially.
-    dilation_period : int, default=5
-        Reset dilation to 1 every N layers.
-    dropout : float, default=0.0
-        Dropout probability after activation.
-    dropout_input : float, default=0.0
-        Dropout probability applied to input only.
-    batch_norm : bool, default=True
-        Whether to apply batch normalization.
-    glu : int, default=2
-        Apply GLU gating every N layers. If 0, no GLU.
-    glu_context : int, default=1
-        Context window for GLU convolution.
-    activation : type, default=nn.GELU
-        Activation function class.
-    """
-
-    def __init__(
-        self,
-        channels: tp.Sequence[int],
-        kernel_size: int = 3,
-        dilation_growth: int = 2,
-        dilation_period: int = 5,
-        dropout: float = 0.0,
-        dropout_input: float = 0.0,
-        batch_norm: bool = True,
-        glu: int = 2,
-        glu_context: int = 1,
-        activation: tp.Any = None,
-    ) -> None:
-        super().__init__()
-
-        if dilation_growth > 1:
-            assert kernel_size % 2 != 0, (
-                "Supports only odd kernel with dilation for now"
-            )
-
-        if activation is None:
-            activation = nn.GELU
-
-        self.sequence = nn.ModuleList()
-        self.glus = nn.ModuleList()
-        self.skip_projections = nn.ModuleList()  # For when chin != chout
-
-        dilation = 1
-        channels = tuple(channels)
-
-        for k, (chin, chout) in enumerate(zip(channels[:-1], channels[1:])):
-            layers: tp.List[nn.Module] = []
-
-            # Input dropout (only on first layer)
-            if k == 0 and dropout_input > 0:
-                layers.append(nn.Dropout(dropout_input))
-
-            # Reset dilation periodically
-            if dilation_period and (k % dilation_period) == 0:
-                dilation = 1
-
-            # Dilated convolution with proper padding to maintain temporal size
-            pad = kernel_size // 2 * dilation
-            layers.extend(
-                [
-                    nn.Conv1d(
-                        chin,
-                        chout,
-                        kernel_size=kernel_size,
-                        stride=1,  # Always stride=1 for residual connections
-                        padding=pad,
-                        dilation=dilation,
-                    ),
-                ]
-            )
-
-            # Batch norm + activation + dropout
-            if batch_norm:
-                layers.append(nn.BatchNorm1d(num_features=chout))
-            layers.append(activation())
-            if dropout > 0:
-                layers.append(nn.Dropout(dropout))
-
-            dilation *= dilation_growth
-
-            self.sequence.append(nn.Sequential(*layers))
-
-            # Add skip projection if channels don't match (for growth != 1.0)
-            if chin != chout:
-                self.skip_projections.append(nn.Conv1d(chin, chout, 1))
-            else:
-                self.skip_projections.append(None)
-
-            # GLU gating every N layers
-            if glu > 0 and (k + 1) % glu == 0:
-                self.glus.append(
-                    nn.Sequential(
-                        nn.Conv1d(
-                            chout, 2 * chout, 1 + 2 * glu_context, padding=glu_context
-                        ),
-                        nn.GLU(dim=1),
-                    )
-                )
-            else:
-                self.glus.append(None)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        for module, glu, skip_proj in zip(
-            self.sequence, self.glus, self.skip_projections
-        ):
-            # Apply residual connection
-            # If channels match, add directly; otherwise use projection
-            if skip_proj is not None:
-                x = skip_proj(x) + module(x)
-            else:
-                x = x + module(x)
-            if glu is not None:
-                x = glu(x)
-        return x
-
-
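The receptive field of ``_ConvSequence`` comes entirely from its dilation schedule: dilation is reset to 1 every ``dilation_period`` layers and multiplied by ``dilation_growth`` after each layer, with padding chosen so the temporal length never changes. A standalone sketch (not package code) mirroring the loop above for the defaults ``depth=10, kernel_size=3, dilation_growth=2, dilation_period=5``:

def dilation_schedule(depth=10, kernel_size=3, growth=2, period=5):
    # Mirrors _ConvSequence.__init__: reset before use, multiply after.
    dilation, schedule = 1, []
    for k in range(depth):
        if period and k % period == 0:
            dilation = 1
        pad = kernel_size // 2 * dilation  # keeps the temporal size unchanged
        schedule.append((k, dilation, pad))
        dilation *= growth
    return schedule

print([d for _, d, _ in dilation_schedule()])
# [1, 2, 4, 8, 16, 1, 2, 4, 8, 16]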
-def _pad_multiple(x: torch.Tensor, base: int) -> torch.Tensor:
-    """Pad tensor to be a multiple of base."""
-    length = x.shape[-1]
-    target = math.ceil(length / base) * base
-    return F.pad(x, (0, target - length))
-
-
-class _ScaledEmbedding(nn.Module):
-    """Scaled embedding layer for subjects.
-
-    Scales up the learning rate for the embedding to prevent slow convergence.
-    Used for subject-specific representations.
-    """
-
-    def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0):
-        super().__init__()
-        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
-        self.embedding.weight.data /= scale
-        self.scale = scale
-
-    @property
-    def weight(self) -> torch.Tensor:
-        """Get scaled embedding weights."""
-        return self.embedding.weight * self.scale
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Subject indices of shape (batch,).
-
-        Returns
-        -------
-        torch.Tensor
-            Scaled embeddings of shape (batch, embedding_dim).
-        """
-        return self.embedding(x) * self.scale
-
-
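The point of ``_ScaledEmbedding`` is purely about gradients: dividing the stored weights by ``scale`` at init and multiplying the output by ``scale`` leaves the forward values unchanged, but makes the gradient with respect to the stored weights ``scale`` times larger. A small standalone sketch (not package code) that makes the effect visible:

import torch
from torch import nn

s = 10.0
emb = nn.Embedding(4, 3)
emb.weight.data /= s           # same trick as _ScaledEmbedding.__init__
idx = torch.tensor([0, 1])
loss = (emb(idx) * s).sum()    # forward multiplies by s
loss.backward()
print(emb.weight.grad[idx])    # every entry is s = 10.0, i.e. s times the
                               # gradient a plain embedding would receive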
-class _ChannelDropout(nn.Module):
-    """Channel dropout with rescaling and optional ch_info support.
-
-    Randomly drops channels during training and rescales output to maintain
-    expected value. Optionally supports selective channel dropout based on
-    channel type (EEG, reference, EOG, etc.) using ch_info metadata.
-
-    Parameters
-    ----------
-    dropout_prob : float, default=0.0
-        Probability of dropping each channel (0.0 to 1.0).
-        If 0.0, no dropout is applied.
-    ch_info : list of dict, optional
-        Channel information from MNE (e.g., from raw.info['chs']).
-        Each dict should have 'ch_name' and 'ch_type' keys.
-        If provided, enables selective channel dropout by type.
-    channel_type : str, optional
-        If specified with ch_info, only drop channels of this type
-        (e.g., 'eeg', 'ref', 'eog'). If None, drop from all available channels.
-    rescale : bool, default=True
-        If True, rescale output to maintain expected value.
-        scale_factor = n_channels / (n_channels - n_dropped)
-
-    Examples
-    --------
-    >>> # Random channel dropout
-    >>> dropout = _ChannelDropout(dropout_prob=0.1)
-    >>> x = torch.randn(4, 32, 1000)
-    >>> x_dropped = dropout(x)
-
-    >>> # Selective EEG dropout using ch_info
-    >>> ch_info = [{'ch_name': 'Fp1', 'ch_type': 'eeg'}, ...]
-    >>> dropout_eeg = _ChannelDropout(
-    ...     dropout_prob=0.1,
-    ...     ch_info=ch_info,
-    ...     channel_type='eeg'  # Only drop EEG channels
-    ... )
-    >>> x_dropped = dropout_eeg(x)  # Reference channels never dropped
-    """
-
-    def __init__(
-        self,
-        dropout_prob: float = 0.0,
-        ch_info: list[dict] | None = None,
-        channel_type: str | None = None,
-        rescale: bool = True,
-    ):
-        super().__init__()
-
-        if not 0.0 <= dropout_prob <= 1.0:
-            raise ValueError(f"dropout_prob must be in [0.0, 1.0], got {dropout_prob}")
-        if channel_type is not None and ch_info is None:
-            raise ValueError("channel_type requires ch_info to be provided")
-
-        self.dropout_prob = dropout_prob
-        self.rescale = rescale
-        self.ch_info = ch_info
-        self.channel_type = channel_type
-
-        # Compute droppable channel indices
-        self.droppable_indices: list[int] | None = None
-        if ch_info is not None:
-            if channel_type is not None:
-                # Drop only specific type
-                self.droppable_indices = [
-                    i
-                    for i, ch in enumerate(ch_info)
-                    if ch.get("ch_type", ch.get("kind")) == channel_type
-                ]
-            else:
-                # Drop any channel
-                self.droppable_indices = list(range(len(ch_info)))
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass with channel dropout.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input of shape (batch, channels, time).
-
-        Returns
-        -------
-        torch.Tensor
-            Output of same shape as input, with selected channels randomly zeroed.
-        """
-        if not self.training or self.dropout_prob == 0:
-            return x
-
-        _, channels, _ = x.shape
-
-        # Determine which channels to drop
-        if self.droppable_indices is not None:
-            # Only drop from specified indices
-            n_droppable = len(self.droppable_indices)
-            n_to_drop = max(1, int(n_droppable * self.dropout_prob))
-            if n_to_drop > 0:
-                drop_indices = torch.tensor(
-                    self.droppable_indices, device=x.device, dtype=torch.long
-                )
-                # Randomly select which droppable indices to actually drop
-                selected = torch.randperm(n_droppable, device=x.device)[:n_to_drop]
-                drop_indices = drop_indices[selected]
-            else:
-                drop_indices = torch.tensor([], device=x.device, dtype=torch.long)
-        else:
-            # Drop from any channel
-            n_to_drop = max(1, int(channels * self.dropout_prob))
-            drop_indices = torch.randperm(channels, device=x.device)[:n_to_drop]
-
-        # Clone and apply dropout
-        if len(drop_indices) > 0:
-            x_out = x.clone()
-            x_out[:, drop_indices, :] = 0
-
-            # Rescale to maintain expected value
-            if self.rescale:
-                scale_factor = channels / (channels - len(drop_indices))
-                x_out = x_out * scale_factor
-        else:
-            x_out = x
-
-        return x_out
-
-    def __repr__(self) -> str:
-        return (
-            f"ChannelDropout(dropout_prob={self.dropout_prob}, "
-            f"rescale={self.rescale}, "
-            f"channel_type={self.channel_type})"
-        )
-
-
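A quick check of the rescaling performed by ``_ChannelDropout.forward`` (a sketch that assumes the class definition above is in scope): with ``dropout_prob=0.25`` on 8 channels, ``n_to_drop = max(1, int(8 * 0.25)) = 2`` channels are zeroed and the survivors are multiplied by ``8 / 6``:

import torch

drop = _ChannelDropout(dropout_prob=0.25)
drop.train()                  # forward() is a no-op in eval mode
y = drop(torch.ones(1, 8, 4))
kept = y[0, :, 0] != 0
print(int((~kept).sum()))     # 2 channels were zeroed
print(y[0, kept, 0])          # survivors all equal 8/6 ≈ 1.3333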
-def _validate_brainmodule_params(
-    subject_layers: bool,
-    subject_dim: int,
-    depth: int,
-    kernel_size: int,
-    growth: float,
-    dilation_growth: int,
-    channel_dropout_prob: float,
-    channel_dropout_type: str | None,
-    glu: int,
-    glu_context: int,
-) -> None:
-    """Validate BrainModule parameters.
-
-    Parameters
-    ----------
-    subject_layers : bool
-        Whether to use subject-specific layer transformations.
-    subject_dim : int
-        Dimension of subject embeddings.
-    depth : int
-        Number of convolutional blocks.
-    kernel_size : int
-        Convolutional kernel size.
-    growth : float
-        Channel size multiplier per layer.
-    dilation_growth : int
-        Dilation multiplier per layer.
-    channel_dropout_prob : float
-        Channel dropout probability.
-    channel_dropout_type : str or None
-        Channel type to selectively drop.
-    glu : int
-        GLU gating interval.
-    glu_context : int
-        GLU context window size.
-
-    Raises
-    ------
-    ValueError
-        If any parameter combination is invalid.
-    """
-    validations = [
-        (
-            subject_layers and subject_dim == 0,
-            "subject_layers=True requires subject_dim > 0",
-        ),
-        (depth < 1, "depth must be >= 1"),
-        (kernel_size <= 0, "kernel_size must be > 0"),
-        (kernel_size % 2 == 0, "kernel_size must be odd for proper padding"),
-        (growth <= 0, "growth must be > 0"),
-        (dilation_growth < 1, "dilation_growth must be >= 1"),
-        (
-            not 0.0 <= channel_dropout_prob <= 1.0,
-            "channel_dropout_prob must be in [0.0, 1.0]",
-        ),
-        (
-            channel_dropout_type is not None and channel_dropout_prob == 0.0,
-            "channel_dropout_type requires channel_dropout_prob > 0",
-        ),
-        (glu < 0, "glu must be >= 0"),
-        (glu_context < 0, "glu_context must be >= 0"),
-        (glu_context > 0 and glu == 0, "glu_context > 0 requires glu > 0"),
-        (glu_context >= kernel_size, "glu_context must be < kernel_size"),
-    ]
-
-    for condition, message in validations:
-        if condition:
-            raise ValueError(message)
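These checks run at construction time, before any layers are built, so an invalid configuration fails fast with a readable message. A sketch (assuming the definitions above are in scope):

try:
    BrainModule(n_chans=32, n_outputs=2, n_times=500, kernel_size=4)
except ValueError as err:
    print(err)  # kernel_size must be odd for proper padding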