braindecode 1.3.0.dev176728557__py3-none-any.whl → 1.3.0.dev178811222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/augmentation/functional.py +154 -54
- braindecode/augmentation/transforms.py +2 -2
- braindecode/datasets/base.py +1 -1
- braindecode/datasets/sleep_physio_challe_18.py +2 -1
- braindecode/datautil/serialization.py +11 -6
- braindecode/eegneuralnet.py +2 -0
- braindecode/models/__init__.py +4 -0
- braindecode/models/atcnet.py +7 -7
- braindecode/models/attentionbasenet.py +2 -0
- braindecode/models/biot.py +1 -1
- braindecode/models/eegnet.py +4 -3
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +8 -6
- braindecode/models/util.py +2 -0
- braindecode/preprocessing/preprocess.py +11 -2
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev178811222.dist-info}/METADATA +1 -1
- {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev178811222.dist-info}/RECORD +23 -21
- {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev178811222.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev178811222.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev178811222.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev178811222.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,869 @@
|
|
|
1
|
+
# Authors: Can Han <hancan@sjtu.edu.cn> (original paper and code,
|
|
2
|
+
# first iteration of braindecode adaptation)
|
|
3
|
+
# Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
|
|
4
|
+
#
|
|
5
|
+
# License: BSD (3-clause)
|
|
6
|
+
|
|
7
|
+
from typing import List, Optional, Tuple, Union
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
import torch.nn.functional as F
|
|
12
|
+
from einops import rearrange
|
|
13
|
+
|
|
14
|
+
from braindecode.models.base import EEGModuleMixin
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SSTDPN(EEGModuleMixin, nn.Module):
    r"""SSTDPN from Can Han et al (2025) [Han2025]_.

    :bdg-info:`Small Attention` :bdg-success:`Convolution`

    .. figure:: https://raw.githubusercontent.com/hancan16/SST-DPN/refs/heads/main/figs/framework.png
        :align: center
        :alt: SSTDPN Architecture
        :width: 1000px

    The **Spatial–Spectral** and **Temporal - Dual Prototype Network** (SST-DPN)
    is an end-to-end 1D convolutional architecture designed for motor imagery (MI) EEG decoding,
    aiming to address challenges related to discriminative feature extraction and
    small-sample sizes [Han2025]_.

    The framework systematically addresses three key challenges: extracting multi-channel
    spatial–spectral features, capturing long-term temporal features, and coping with
    small-sample sizes [Han2025]_.

    .. rubric:: Architectural Overview

    SST-DPN consists of a feature extractor (_SSTEncoder, comprising Adaptive Spatial-Spectral
    Fusion and Multi-scale Variance Pooling) followed by Dual Prototype Learning classification [Han2025]_.

    1. **Adaptive Spatial–Spectral Fusion (ASSF)**: Uses :class:`_DepthwiseTemporalConv1d` to generate a
       multi-channel spatial–spectral representation, followed by :class:`_SpatSpectralAttn`
       (Spatial-Spectral Attention) to model relationships and highlight key spatial–spectral
       channels [Han2025]_.

    2. **Multi-scale Variance Pooling (MVP)**: Applies :class:`_MultiScaleVarPooler` with variance pooling
       at multiple temporal scales to capture long-range temporal dependencies, serving as an
       efficient alternative to transformers [Han2025]_.

    3. **Dual Prototype Learning (DPL)**: A training strategy that employs two sets of
       prototypes—Inter-class Separation Prototypes (proto_sep) and Intra-class Compact
       Prototypes (proto_cpt)—to optimize the feature space, enhancing generalization ability and
       preventing overfitting on small datasets [Han2025]_. During inference (forward pass),
       classification decisions are based on the distance (dot product) between the
       feature vector and proto_sep for each class [Han2025]_.

    .. rubric:: Macro Components

    - `SSTDPN.encoder` **(Feature Extractor)**

        - *Operations.* Combines Adaptive Spatial–Spectral Fusion and Multi-scale Variance Pooling
          via an internal :class:`_SSTEncoder`.
        - *Role.* Maps the raw MI-EEG trial :math:`X_i \in \mathbb{R}^{C \times T}` to the
          feature space :math:`z_i \in \mathbb{R}^d`.

    - `_SSTEncoder.temporal_conv` **(Depthwise Temporal Convolution for Spectral Extraction)**

        - *Operations.* Internal :class:`_DepthwiseTemporalConv1d` applying separate temporal
          convolution filters to each channel with kernel size `temporal_conv_kernel_size` and
          depth multiplier `n_spectral_filters_temporal` (equivalent to :math:`F_1` in the paper).
        - *Role.* Extracts multiple distinct spectral bands from each EEG channel independently.

    - `_SSTEncoder.spt_attn` **(Spatial–Spectral Attention for Channel Gating)**

        - *Operations.* Internal :class:`_SpatSpectralAttn` module using Global Context Embedding
          via variance-based pooling, followed by adaptive channel normalization and gating.
        - *Role.* Reweights channels in the spatial–spectral dimension to extract efficient and
          discriminative features by emphasizing task-relevant regions and frequency bands.

    - `_SSTEncoder.chan_conv` **(Pointwise Fusion across Channels)**

        - *Operations.* A 1D pointwise convolution with `n_fused_filters` output channels
          (equivalent to :math:`F_2` in the paper), followed by BatchNorm and the specified
          `activation` function (default: ELU).
        - *Role.* Fuses the weighted spatial–spectral features across all electrodes to produce
          a fused representation :math:`X_{fused} \in \mathbb{R}^{F_2 \times T}`.

    - `_SSTEncoder.mvp` **(Multi-scale Variance Pooling for Temporal Extraction)**

        - *Operations.* Internal :class:`_MultiScaleVarPooler` using :class:`_VariancePool1D`
          layers at multiple scales (`mvp_kernel_sizes`), followed by concatenation.
        - *Role.* Captures long-range temporal features at multiple time scales. The variance
          operation leverages the prior that variance represents EEG spectral power.

    - `SSTDPN.proto_sep` / `SSTDPN.proto_cpt` **(Dual Prototypes)**

        - *Operations.* Learnable vectors optimized during training using prototype learning losses.
          The `proto_sep` (Inter-class Separation Prototype) is constrained via L2 weight-normalization
          (:math:`\lVert s_i \rVert_2 \leq` `proto_sep_maxnorm`) on every forward pass.
        - *Role.* `proto_sep` achieves inter-class separation; `proto_cpt` enhances intra-class compactness.

    .. rubric:: How the information is encoded temporally, spatially, and spectrally

    * **Temporal.**
      The initial :class:`_DepthwiseTemporalConv1d` uses a large kernel (e.g., 75). The MVP module employs pooling
      kernels that are much larger (e.g., 50, 100, 200 samples) to capture long-term temporal
      features effectively. Large kernel pooling layers are shown to be superior to transformer
      modules for this task in EEG decoding according to [Han2025]_.

    * **Spatial.**
      The initial :class:`_DepthwiseTemporalConv1d` uses a groups parameter of :math:`h=1`,
      meaning :math:`F_1` temporal filters are shared across channels. The Spatial-Spectral Attention
      mechanism explicitly models the relationships among these channels in the spatial–spectral
      dimension, allowing for finer-grained spatial feature modeling compared to conventional
      GCNs according to the authors [Han2025]_.
      In other words, all electrode channels share the same :math:`F_1` temporal filters, which are
      applied to each channel independently to produce the spatial–spectral representation.

    * **Spectral.**
      Spectral information is implicitly extracted via the :math:`F_1` filters in :class:`_DepthwiseTemporalConv1d`.
      Furthermore, the use of Variance Pooling (in MVP) explicitly leverages the neurophysiological
      prior that the **variance of EEG signals represents their spectral power**, which is an
      important feature for distinguishing different MI classes [Han2025]_.

    .. rubric:: Additional Mechanisms

    - **Attention.** A lightweight Spatial-Spectral Attention mechanism models spatial–spectral relationships
      at the channel level, distinct from applying attention to deep feature dimensions,
      which is common in comparison methods like :class:`ATCNet`.
    - **Regularization.** Dual Prototype Learning acts as a regularization technique
      by optimizing the feature space to be compact within classes and separated between
      classes. This enhances model generalization and classification performance, particularly
      useful for limited data typical of MI-EEG tasks, without requiring external transfer
      learning data, according to [Han2025]_.

    Notes
    -----
    * The implementation of the DPL loss functions (:math:`\mathcal{L}_S`, :math:`\mathcal{L}_C`, :math:`\mathcal{L}_{EF}`)
      and the optimization of ICPs are typically handled outside the primary ``forward`` method, within the training strategy
      (see Ref. 52 in [Han2025]_).
    * The default parameters are configured based on the BCI Competition IV 2a dataset.
    * The use of Prototype Learning (PL) methods is novel in the field of EEG-MI decoding.
    * **Lowest FLOPs:** Achieves the lowest Floating Point Operations (FLOPs) (9.65 M) among competitive
      SOTA methods, including braindecode models like :class:`ATCNet` (29.81 M) and
      :class:`EEGConformer` (63.86 M), demonstrating computational efficiency [Han2025]_.
    * **Transformer Alternative:** Multi-scale Variance Pooling (MVP) provides an accuracy
      improvement over temporal attention transformer modules in ablation studies, offering a more
      efficient alternative to transformer-based approaches like :class:`EEGConformer` [Han2025]_.

    .. warning::

        **Important:** To utilize the full potential of SSTDPN with Dual Prototype Learning (DPL),
        users must implement the DPL optimization strategy outside the model's forward method.
        For implementation details and training strategies, please consult the official code at
        [Han2025Code]_:
        https://github.com/hancan16/SST-DPN/blob/main/train.py

    Parameters
    ----------
    n_spectral_filters_temporal : int, optional
        Number of spectral filters extracted per channel via temporal convolution.
        These represent the temporal spectral bands (equivalent to :math:`F_1` in the paper).
        Default is 9.

    n_fused_filters : int, optional
        Number of output filters after pointwise fusion convolution.
        These fuse the spectral filters across all channels (equivalent to :math:`F_2` in the paper).
        Must be divisible by the number of entries in `mvp_kernel_sizes`.
        Default is 48.

    temporal_conv_kernel_size : int, optional
        Kernel size for the temporal convolution layer. Controls the receptive field for extracting
        spectral information. Default is 75 samples.

    mvp_kernel_sizes : list[int], optional
        Kernel sizes for Multi-scale Variance Pooling (MVP) module.
        Larger kernels capture long-term temporal dependencies. Every kernel size must be
        at least 2 (the pooling stride is ``k // 2``) and must not exceed ``n_times``.
        Default is ``[50, 100, 200]``.

    return_features : bool, optional
        If True, the forward pass returns (features, logits). If False, returns only logits.
        Default is False.

    proto_sep_maxnorm : float, optional
        Maximum L2 norm constraint for Inter-class Separation Prototypes during forward pass.
        This constraint acts as an implicit force to push features away from the origin. Default is 1.0.

    proto_cpt_std : float, optional
        Standard deviation for Intra-class Compactness Prototype initialization. Default is 0.01.

    spt_attn_global_context_kernel : int, optional
        Kernel size for global context embedding in Spatial-Spectral Attention module.
        Default is 250 samples.

    spt_attn_epsilon : float, optional
        Small epsilon value for numerical stability in Spatial-Spectral Attention. Default is 1e-5.

    spt_attn_mode : str, optional
        Embedding computation mode for Spatial-Spectral Attention ('var', 'l2', or 'l1').
        Default is 'var' (variance-based mean-var operation).

    activation : nn.Module, optional
        Activation function to apply after the pointwise fusion convolution in :class:`_SSTEncoder`.
        Should be a PyTorch activation module class (not an instance). Default is nn.ELU.


    References
    ----------
    .. [Han2025] Han, C., Liu, C., Wang, J., Wang, Y., Cai, C.,
        & Qian, D. (2025). A spatial–spectral and temporal dual
        prototype network for motor imagery brain–computer
        interface. Knowledge-Based Systems, 315, 113315.
    .. [Han2025Code] Han, C., Liu, C., Wang, J., Wang, Y.,
        Cai, C., & Qian, D. (2025). A spatial–spectral and
        temporal dual prototype network for motor imagery
        brain–computer interface. Knowledge-Based Systems,
        315, 113315. GitHub repository.
        https://github.com/hancan16/SST-DPN.
    """

    def __init__(
        self,
        # Braindecode standard parameters
        n_chans=None,
        n_times=None,
        n_outputs=None,
        input_window_seconds=None,
        sfreq=None,
        chs_info=None,
        # models parameters
        n_spectral_filters_temporal: int = 9,
        n_fused_filters: int = 48,
        temporal_conv_kernel_size: int = 75,
        mvp_kernel_sizes: Optional[List[int]] = None,
        return_features: bool = False,
        proto_sep_maxnorm: float = 1.0,
        proto_cpt_std: float = 0.01,
        spt_attn_global_context_kernel: int = 250,
        spt_attn_epsilon: float = 1e-5,
        spt_attn_mode: str = "var",
        activation: Optional[nn.Module] = nn.ELU,
    ) -> None:
        super().__init__(
            n_chans=n_chans,
            n_outputs=n_outputs,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # From here on only the mixin-resolved ``self.n_*`` values are used;
        # dropping the raw arguments prevents accidental use of them.
        del input_window_seconds, sfreq, chs_info, n_chans, n_outputs, n_times

        # Set default activation if not provided
        if activation is None:
            activation = nn.ELU

        # Store hyperparameters
        self.n_spectral_filters_temporal = n_spectral_filters_temporal
        self.n_fused_filters = n_fused_filters
        self.temporal_conv_kernel_size = temporal_conv_kernel_size
        self.mvp_kernel_sizes = (
            mvp_kernel_sizes if mvp_kernel_sizes is not None else [50, 100, 200]
        )
        self.return_features = return_features
        self.proto_sep_maxnorm = proto_sep_maxnorm
        self.proto_cpt_std = proto_cpt_std
        self.spt_attn_global_context_kernel = spt_attn_global_context_kernel
        self.spt_attn_epsilon = spt_attn_epsilon
        self.spt_attn_mode = spt_attn_mode
        self.activation = activation

        # Infer the feature dimension analytically. This also validates the
        # MVP configuration, so it runs *before* the encoder is built: an
        # invalid configuration then fails with a clear ValueError instead of
        # an error raised from inside a sub-module.
        feat_dim = self._compute_feature_dim()

        # Encoder accepts (batch, n_chans, n_times)
        self.encoder = _SSTEncoder(
            n_times=self.n_times,
            n_chans=self.n_chans,
            n_spectral_filters_temporal=self.n_spectral_filters_temporal,
            n_fused_filters=self.n_fused_filters,
            temporal_conv_kernel_size=self.temporal_conv_kernel_size,
            mvp_kernel_sizes=self.mvp_kernel_sizes,
            spt_attn_global_context_kernel=self.spt_attn_global_context_kernel,
            spt_attn_epsilon=self.spt_attn_epsilon,
            spt_attn_mode=self.spt_attn_mode,
            activation=self.activation,
        )

        # Prototypes: Inter-class Separation (ISP) and Intra-class Compactness (ICP)
        # ISP: provides inter-class separation via prototype learning
        # ICP: enhances intra-class compactness
        self.proto_sep = nn.Parameter(
            torch.empty(self.n_outputs, feat_dim), requires_grad=True
        )
        # This parameter is not used in the forward pass, only during training for the
        # prototype learning losses. You should implement the losses outside this class.
        self.proto_cpt = nn.Parameter(
            torch.empty(self.n_outputs, feat_dim), requires_grad=True
        )
        # just for braindecode compatibility
        self.final_layer = nn.Identity()

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialize prototype parameters.

        ``proto_sep`` is drawn with Kaiming-normal initialization and
        ``proto_cpt`` from a zero-mean normal with standard deviation
        ``proto_cpt_std``.
        """
        nn.init.kaiming_normal_(self.proto_sep)
        nn.init.normal_(self.proto_cpt, mean=0.0, std=self.proto_cpt_std)

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Classification is based on the dot product similarity with
        Inter-class Separation Prototypes (:attr:`SSTDPN.proto_sep`).

        Parameters
        ----------
        x : torch.Tensor
            Input tensor. Supported shapes:
              - (batch, n_chans, n_times)

        Returns
        -------
        logits : torch.Tensor
            Tensor of shape (batch, n_outputs).
            If ``self.return_features`` is True, the tuple ``(features, logits)``
            is returned instead, where features has shape (batch, feat_dim).
        """

        features = self.encoder(x)  # (b, feat_dim)
        # Renormalize inter-class separation prototypes so that every row has
        # an L2 norm of at most ``proto_sep_maxnorm``. The assignment goes
        # through ``.data`` on purpose: the projection rewrites the stored
        # parameter values without being recorded by autograd.
        self.proto_sep.data = torch.renorm(
            self.proto_sep.data, p=2, dim=1, maxnorm=self.proto_sep_maxnorm
        )
        logits = torch.einsum("bd,cd->bc", features, self.proto_sep)  # (b, n_outputs)
        logits = self.final_layer(logits)

        if self.return_features:
            return features, logits

        return logits

    def _compute_feature_dim(self) -> int:
        """Compute encoder feature dimensionality without a forward pass.

        The fused filters are divided evenly across the MVP scales; each scale
        pools the time axis with kernel ``k`` and stride ``k // 2`` (no
        padding), and the pooled lengths of all scales are summed.

        Returns
        -------
        int
            ``(n_fused_filters // n_scales) * sum(pooled length per scale)``.

        Raises
        ------
        ValueError
            If `mvp_kernel_sizes` is empty, if `n_fused_filters` is not
            divisible by the number of scales, if any kernel size gives a
            stride ``k // 2`` smaller than 1, or if any kernel size exceeds
            ``n_times``.
        """
        kernel_sizes = self.mvp_kernel_sizes
        if not kernel_sizes:
            raise ValueError(
                "`mvp_kernel_sizes` must contain at least one kernel size."
            )

        num_scales = len(kernel_sizes)
        channels_per_scale, rest = divmod(self.n_fused_filters, num_scales)
        if rest:
            raise ValueError(
                "Number of fused filters must be divisible by the number of MVP scales. "
                f"Got {self.n_fused_filters=} and {num_scales=}."
            )

        # Validate all kernel sizes at once (stride = k // 2 must be >= 1).
        # ``< 1`` rather than ``== 0`` so that negative sizes are rejected too.
        invalid = [k for k in kernel_sizes if k // 2 < 1]
        if invalid:
            raise ValueError(
                "MVP kernel sizes too small to derive a valid stride (k//2 < 1): "
                f"{invalid}"
            )

        # A kernel longer than the window would contribute zero features here
        # and is expected to fail only later, in the pooling call of the
        # forward pass.
        n_times = self.n_times
        too_long = [k for k in kernel_sizes if k > n_times]
        if too_long:
            raise ValueError(
                "MVP kernel sizes must not exceed the input window length. "
                f"Got {too_long} with {n_times=}."
            )

        pooled_total = sum(
            self._pool1d_output_length(
                length=n_times, kernel_size=k, stride=k // 2, padding=0, dilation=1
            )
            for k in kernel_sizes
        )
        return channels_per_scale * pooled_total

    @staticmethod
    def _pool1d_output_length(
        length: int, kernel_size: int, stride: int, padding: int = 0, dilation: int = 1
    ) -> int:
        """Temporal length after 1D pooling (PyTorch-style formula).

        Computes ``(length + 2*padding - (dilation*(kernel_size-1) + 1)) // stride + 1``
        and floors the result at 0.
        """
        return max(
            0,
            (length + 2 * padding - (dilation * (kernel_size - 1) + 1)) // stride + 1,
        )
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
class _SSTEncoder(nn.Module):
    """Feature extractor of :class:`SSTDPN` (internal).

    Chains four stages: a depthwise temporal convolution, the Spatial-Spectral
    Attention gate, a pointwise fusion convolution (with BatchNorm and an
    activation), and Multi-scale Variance Pooling. It is meant to be created
    by :class:`SSTDPN` only, not instantiated directly.

    Parameters
    ----------
    n_times : int
        Number of time samples in the input window.
    n_chans : int
        Number of EEG channels.
    n_spectral_filters_temporal : int
        Number of spectral filters extracted via temporal convolution (:math:`F_1`).
    n_fused_filters : int
        Number of output filters after pointwise fusion (:math:`F_2`).
    temporal_conv_kernel_size : int
        Kernel size for temporal convolution.
    mvp_kernel_sizes : list[int]
        Kernel sizes for Multi-scale Variance Pooling. ``None`` selects
        ``[50, 100, 200]``.
    spt_attn_global_context_kernel : int
        Kernel size for global context in Spatial-Spectral Attention.
    spt_attn_epsilon : float
        Epsilon for numerical stability in Spatial-Spectral Attention.
    spt_attn_mode : str
        Mode for Spatial-Spectral Attention computation ('var', 'l2', or 'l1').
    activation : nn.Module, optional
        Activation function class to use after pointwise convolution.
        ``None`` selects nn.ELU.
    """

    def __init__(
        self,
        n_times: int,
        n_chans: int,
        n_spectral_filters_temporal: int = 9,
        n_fused_filters: int = 48,
        temporal_conv_kernel_size: int = 75,
        mvp_kernel_sizes: Optional[List[int]] = None,
        spt_attn_global_context_kernel: int = 250,
        spt_attn_epsilon: float = 1e-5,
        spt_attn_mode: str = "var",
        activation: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()

        # Resolve the optional arguments to their defaults.
        pool_kernels = [50, 100, 200] if mvp_kernel_sizes is None else mvp_kernel_sizes
        activation_cls = nn.ELU if activation is None else activation

        # Width of the spatial-spectral representation: every EEG channel is
        # expanded into ``n_spectral_filters_temporal`` filtered copies.
        n_spat_spectral = n_chans * n_spectral_filters_temporal

        # Stage 1 (ASSF): temporal convolution acting as a spectral filter bank.
        self.temporal_conv = _DepthwiseTemporalConv1d(
            in_channels=n_chans,
            num_heads=1,
            n_spectral_filters_temporal=n_spectral_filters_temporal,
            kernel_size=temporal_conv_kernel_size,
            stride=1,
            padding="same",
            bias=True,
            weight_softmax=False,
        )

        # Stage 2 (ASSF): gate that reweights the spatial-spectral channels.
        self.spt_attn = _SpatSpectralAttn(
            T=n_times,
            num_channels=n_spat_spectral,
            epsilon=spt_attn_epsilon,
            mode=spt_attn_mode,
            global_context_kernel=spt_attn_global_context_kernel,
        )

        # Stage 3: 1x1 convolution fusing the gated channels, then BatchNorm
        # and the chosen activation.
        pointwise = nn.Conv1d(
            n_spat_spectral,
            n_fused_filters,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.chan_conv = nn.Sequential(
            pointwise,
            nn.BatchNorm1d(n_fused_filters),
            activation_cls(),
        )

        # Stage 4 (MVP): variance pooling at several temporal scales.
        self.mvp = _MultiScaleVarPooler(kernel_sizes=pool_kernels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of EEG windows into feature vectors.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (batch, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Feature vector of shape (batch, feat_dim).
        """
        # (b, n_chans*n_spectral_filters_temporal, T)
        spectral = self.temporal_conv(x)
        # The attention module returns a pair; only its first element (the
        # reweighted signal) is used here.
        gated, _ = self.spt_attn(spectral)
        # (b, n_fused_filters, T)
        fused = self.chan_conv(gated)
        # (b, feat_dim)
        return self.mvp(fused)
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
class _DepthwiseTemporalConv1d(nn.Module):
    """Internal depthwise temporal convolution for spectral filtering.

    Applies the same bank of temporal convolution filters to each channel
    independently to extract spectral information across multiple bands. This
    is used to generate the spatial-spectral representation in SSTDPN.

    Not intended for external use.

    Parameters
    ----------
    in_channels : int
        Number of input channels. Must be divisible by `num_heads`.
    num_heads : int, optional
        Number of filter groups (typically 1). Default is 1.
    n_spectral_filters_temporal : int, optional
        Number of spectral filters per channel (depth multiplier). Default is 1.
    kernel_size : int, optional
        Temporal convolution kernel size. Default is 1.
    stride : int, optional
        Convolution stride. Default is 1.
    padding : str or int, optional
        Padding passed to the convolution. Default is 0.
    bias : bool, optional
        Whether to use bias. Default is True.
    weight_softmax : bool, optional
        Whether to apply softmax to weights (over the kernel axis).
        Default is False.

    Raises
    ------
    ValueError
        If `num_heads` is smaller than 1 or does not divide `in_channels`.
    """

    def __init__(
        self,
        in_channels: int,
        num_heads: int = 1,
        n_spectral_filters_temporal: int = 1,
        kernel_size: int = 1,
        stride: int = 1,
        padding: Union[str, int] = 0,
        bias: bool = True,
        weight_softmax: bool = False,
    ) -> None:
        super().__init__()
        # ``forward`` splits the channel axis into ``num_heads`` equal groups,
        # so an incompatible configuration is rejected up front.
        if num_heads < 1 or in_channels % num_heads != 0:
            raise ValueError(
                "`in_channels` must be divisible by a positive `num_heads`. "
                f"Got {in_channels=} and {num_heads=}."
            )

        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.num_heads = num_heads
        self.padding = padding
        self.weight_softmax = weight_softmax

        # One single-input-channel filter per (head, spectral filter) pair.
        # ``torch.empty`` replaces the legacy ``torch.Tensor(...)`` constructor;
        # the values are set by ``_init_parameters`` below.
        n_filters = num_heads * n_spectral_filters_temporal
        self.weight = nn.Parameter(torch.empty(n_filters, 1, kernel_size))
        self.bias = nn.Parameter(torch.empty(n_filters)) if bias else None

        self._init_parameters()

    def _init_parameters(self) -> None:
        """Initialize parameters (Xavier-uniform weights, zero bias)."""
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        inp : torch.Tensor
            Input of shape (batch, in_channels, time).

        Returns
        -------
        torch.Tensor
            Output of shape
            (batch, in_channels * n_spectral_filters_temporal, time_out),
            where ``time_out`` depends on kernel size, stride and padding.
        """
        B, _, _ = inp.size()
        H = self.num_heads
        weight = self.weight
        if self.weight_softmax:
            weight = F.softmax(weight, dim=-1)

        # Fold the per-head channels into the batch axis so the same filters
        # are applied to every channel independently.
        inp = rearrange(inp, "b (h c) t -> (b c) h t", h=H)
        # ``bias=None`` is accepted by ``F.conv1d``, so one call covers both
        # the biased and the bias-free configuration.
        output = F.conv1d(
            inp,
            weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            groups=self.num_heads,
        )
        # Unfold the channels from the batch axis; the conv output channels
        # and the original channels are merged into one channel axis.
        output = rearrange(output, "(b c) h t -> b (h c) t", b=B)
        return output
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
class _GlobalContextVarPool1D(nn.Module):
    """Internal global context variance pooling module.

    Computes variance-based global context embeddings using specified kernel size:
    the variance is computed per pooling window and the window variances are
    then averaged over the whole remaining time axis.
    Used in the Spatial-Spectral Attention module.

    Not intended for external use.

    Parameters
    ----------
    kernel_size : int
        Pooling kernel size.
    stride : int or None, optional
        Stride. If None, defaults to kernel_size. Default is None.
    padding : int, optional
        Padding. Default is 0.
    """

    def __init__(
        self, kernel_size: int, stride: Optional[int] = None, padding: int = 0
    ) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = kernel_size if stride is None else stride
        self.padding = padding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass computing global context via variance pooling.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor whose last axis is pooled
            (presumably (batch, channels, time) - confirm against the caller).

        Returns
        -------
        torch.Tensor
            Global context (variance-pooled) output with a last axis of
            length 1. Values are non-negative.
        """
        mean_of_squares = F.avg_pool1d(
            x**2, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding
        )
        square_of_mean = (
            F.avg_pool1d(
                x,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
            )
        ) ** 2
        # Var[x] = E[x^2] - E[x]^2. Floating-point cancellation can make this
        # slightly negative for near-constant windows; a variance cannot be
        # negative, so it is floored at zero before averaging.
        variance = torch.clamp(mean_of_squares - square_of_mean, min=0.0)
        # Average the per-window variances over all windows.
        out = F.avg_pool1d(variance, variance.shape[-1])
        return out
|
|
656
|
+
|
|
657
|
+
|
|
658
|
+
class _VariancePool1D(nn.Module):
    """Internal variance pooling module for temporal feature extraction.

    Applies variance pooling at a specified kernel size to capture temporal dynamics.
    Used in the Multi-scale Variance Pooling (MVP) module.

    Not intended for external use.

    Parameters
    ----------
    kernel_size : int
        Pooling kernel size (receptive field width).
    stride : int or None, optional
        Stride. If None, defaults to kernel_size. Default is None.
    padding : int, optional
        Padding. Default is 0.
    """

    def __init__(
        self, kernel_size: int, stride: Optional[int] = None, padding: int = 0
    ) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = kernel_size if stride is None else stride
        self.padding = padding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass computing variance pooling.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch, channels, time).

        Returns
        -------
        torch.Tensor
            Variance-pooled output of shape (batch, channels, n_windows),
            one non-negative variance estimate per pooling window.
        """
        pool_kwargs = {
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "padding": self.padding,
        }
        # Var[x] = E[x^2] - E[x]^2, evaluated per window with average pooling.
        mean_of_squares = F.avg_pool1d(x**2, **pool_kwargs)
        square_of_mean = F.avg_pool1d(x, **pool_kwargs) ** 2
        # Floating-point cancellation can push the difference slightly below
        # zero; clamp so the result is always a valid variance.
        return torch.clamp(mean_of_squares - square_of_mean, min=0.0)
|
|
711
|
+
|
|
712
|
+
|
|
713
|
+
class _SpatSpectralAttn(nn.Module):
    """Internal Spatial-Spectral Attention module with global context gating.

    This attention mechanism computes channel-wise gates based on global context
    embedding and applies adaptive reweighting to highlight task-relevant
    spatial-spectral features.

    Not intended for external use. Used internally in :class:`_SSTEncoder`.

    Parameters
    ----------
    T : int
        Sequence (temporal) length. Not used by this module; it is accepted
        so the constructor signature stays compatible with existing callers.
    num_channels : int
        Number of channels in the spatial-spectral dimension.
    epsilon : float, optional
        Small value for numerical stability. Default is 1e-5.
    mode : str, optional
        Embedding computation mode: 'var' (variance-based), 'l2' (L2-norm),
        or 'l1' (L1-norm). Default is 'var'.
    after_relu : bool, optional
        Whether ReLU is applied before this module. When True, the 'l1' mode
        skips the absolute value. Default is False.
    global_context_kernel : int, optional
        Kernel size for global context variance pooling (only used in 'var'
        mode). Default is 250.

    Raises
    ------
    ValueError
        If ``mode`` is not one of 'var', 'l2' or 'l1'.
    """

    def __init__(
        self,
        T: int,
        num_channels: int,
        epsilon: float = 1e-5,
        mode: str = "var",
        after_relu: bool = False,
        global_context_kernel: int = 250,
    ) -> None:
        super().__init__()
        # Fail fast on an invalid mode, before any parameters are allocated.
        if mode not in ("var", "l2", "l1"):
            raise ValueError(f"Unsupported Spatial-Spectral Attention mode: {mode}")

        # Learnable gating parameters: scale, normalize, and shift
        self.alpha = nn.Parameter(torch.ones(1, num_channels, 1))
        self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1))
        self.beta = nn.Parameter(torch.zeros(1, num_channels, 1))
        self.epsilon = epsilon
        self.mode = mode
        self.after_relu = after_relu

        # Global context module using variance pooling
        self.global_ctx = _GlobalContextVarPool1D(global_context_kernel)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass computing adaptive channel-wise gating.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (batch, channels, time).

        Returns
        -------
        tuple of torch.Tensor
            ``(gated_output, gate)``. ``gated_output`` has the same shape as
            the input; ``gate`` has shape (batch, channels, 1) and is
            broadcast over the time dimension.
        """
        if self.mode == "l2":
            # L2-norm based embedding
            embedding = (x.pow(2).sum(2, keepdim=True) + self.epsilon).pow(0.5)
            scale = (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
            norm = self.gamma / scale
        elif self.mode == "l1":
            # L1-norm based embedding; inputs that already went through a
            # ReLU are non-negative, so the absolute value can be skipped.
            magnitude = x if self.after_relu else torch.abs(x)
            embedding = magnitude.sum(2, keepdim=True)
            norm = self.gamma / (
                torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon
            )
        elif self.mode == "var":
            # Variance-based embedding (global context). Clamp at zero first:
            # a numerically negative variance would make the square root NaN.
            context = torch.clamp(self.global_ctx(x), min=0.0)
            embedding = (context + self.epsilon).pow(0.5) * self.alpha
            scale = (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
            norm = self.gamma / scale
        else:
            # Only reachable if ``mode`` was reassigned after construction.
            raise ValueError(
                f"Unsupported Spatial-Spectral Attention mode: {self.mode}"
            )

        # Compute adaptive gate: 1 + tanh(...), which lies in (0, 2) and
        # equals 1 (identity) while gamma and beta are still zero.
        gate = 1 + torch.tanh(embedding * norm + self.beta)
        return x * gate, gate
|
|
803
|
+
|
|
804
|
+
|
|
805
|
+
class _MultiScaleVarPooler(nn.Module):
    """Internal Multi-scale Variance Pooling (MVP) module for temporal feature extraction.

    Applies variance pooling at multiple temporal scales in parallel, then concatenates
    the results to capture long-range temporal dependencies. Each scale processes a subset
    of channels independently, enabling efficient feature extraction.

    Not intended for external use. Used internally in :class:`_SSTEncoder`.

    Parameters
    ----------
    kernel_sizes : list[int] or None, optional
        Kernel sizes for variance pooling layers at each scale. Each layer
        uses a stride of half its kernel size (at least 1). If None,
        defaults to [50, 100, 200] (suitable for 1000-sample windows).

    Raises
    ------
    ValueError
        If ``kernel_sizes`` is empty.
    """

    def __init__(self, kernel_sizes: Optional[List[int]] = None) -> None:
        super().__init__()

        if kernel_sizes is None:
            kernel_sizes = [50, 100, 200]
        if len(kernel_sizes) == 0:
            # Without any scale the channel split in ``forward`` would divide
            # by zero; reject the configuration up front instead.
            raise ValueError("kernel_sizes must contain at least one kernel size")

        self.num_scales = len(kernel_sizes)

        # One variance pooling layer per scale with 50% window overlap. The
        # stride is floored at 1 because a kernel size of 1 would otherwise
        # give a stride of 0, which average pooling rejects.
        self.var_layers = nn.ModuleList(
            [
                nn.Sequential(
                    _VariancePool1D(kernel_size=k, stride=max(1, int(k / 2))),
                    nn.Flatten(start_dim=1),
                )
                for k in kernel_sizes
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass applying multi-scale variance pooling in parallel.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (batch, channels, time). The channel dimension
            must be divisible by the number of scales.

        Returns
        -------
        torch.Tensor
            Concatenated multi-scale features of shape (batch, total_features).

        Raises
        ------
        AssertionError
            If the channel dimension is not divisible by the number of scales.
        """
        _, num_channels, _ = x.shape
        # Split channels equally across scales. Raised explicitly (rather than
        # via ``assert``) so the check survives ``python -O``; the exception
        # type and message are kept for backward compatibility.
        if num_channels % self.num_scales != 0:
            raise AssertionError(
                f"Channel dimension ({num_channels}) must be divisible by "
                f"number of scales ({self.num_scales})"
            )
        channels_per_scale = num_channels // self.num_scales
        channel_groups = torch.split(x, channels_per_scale, dim=1)

        # Apply variance pooling at each scale to its own channel group
        multi_scale_features = [
            layer(group) for layer, group in zip(self.var_layers, channel_groups)
        ]

        # Concatenate features from all scales
        return torch.cat(multi_scale_features, dim=1)
|