braindecode 1.3.0.dev176728557__py3-none-any.whl → 1.3.0.dev178811222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,869 @@
1
+ # Authors: Can Han <hancan@sjtu.edu.cn> (original paper and code,
2
+ # first iteration of braindecode adaptation)
3
+ # Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
4
+ #
5
+ # License: BSD (3-clause)
6
+
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+
14
+ from braindecode.models.base import EEGModuleMixin
15
+
16
+
17
+ class SSTDPN(EEGModuleMixin, nn.Module):
18
+ r"""SSTDPN from Can Han et al (2025) [Han2025]_.
19
+
20
+ :bdg-info:`Small Attention` :bdg-success:`Convolution`
21
+
22
+ .. figure:: https://raw.githubusercontent.com/hancan16/SST-DPN/refs/heads/main/figs/framework.png
23
+ :align: center
24
+ :alt: SSTDPN Architecture
25
+ :width: 1000px
26
+
27
+ The **Spatial–Spectral and Temporal Dual Prototype Network** (SST-DPN)
28
+ is an end-to-end 1D convolutional architecture designed for motor imagery (MI) EEG decoding,
29
+ aiming to address challenges related to discriminative feature extraction and
30
+ small-sample sizes [Han2025]_.
31
+
32
+ The framework systematically addresses three key challenges: extracting multi-channel
33
+ spatial–spectral features, capturing long-term temporal features, and decoding from small sample sizes [Han2025]_.
34
+
35
+ .. rubric:: Architectural Overview
36
+
37
+ SST-DPN consists of a feature extractor (:class:`_SSTEncoder`, comprising Adaptive Spatial-Spectral
38
+ Fusion and Multi-scale Variance Pooling) followed by Dual Prototype Learning classification [Han2025]_.
39
+
40
+ 1. **Adaptive Spatial–Spectral Fusion (ASSF)**: Uses :class:`_DepthwiseTemporalConv1d` to generate a
41
+ multi-channel spatial–spectral representation, followed by :class:`_SpatSpectralAttn`
42
+ (Spatial-Spectral Attention) to model relationships and highlight key spatial–spectral
43
+ channels [Han2025]_.
44
+
45
+ 2. **Multi-scale Variance Pooling (MVP)**: Applies :class:`_MultiScaleVarPooler` with variance pooling
46
+ at multiple temporal scales to capture long-range temporal dependencies, serving as an
47
+ efficient alternative to transformers [Han2025]_.
48
+
49
+ 3. **Dual Prototype Learning (DPL)**: A training strategy that employs two sets of
50
+ prototypes—Inter-class Separation Prototypes (proto_sep) and Intra-class Compactness
51
+ Prototypes (proto_cpt)—to optimize the feature space, enhancing generalization ability and
52
+ preventing overfitting on small datasets [Han2025]_. During inference (forward pass),
53
+ classification decisions are based on the dot-product similarity between the
54
+ feature vector and proto_sep for each class [Han2025]_.
55
+
56
+ .. rubric:: Macro Components
57
+
58
+ - `SSTDPN.encoder` **(Feature Extractor)**
59
+
60
+ - *Operations.* Combines Adaptive Spatial–Spectral Fusion and Multi-scale Variance Pooling
61
+ via an internal :class:`_SSTEncoder`.
62
+ - *Role.* Maps the raw MI-EEG trial :math:`X_i \in \mathbb{R}^{C \times T}` to the
63
+ feature space :math:`z_i \in \mathbb{R}^d`.
64
+
65
+ - `_SSTEncoder.temporal_conv` **(Depthwise Temporal Convolution for Spectral Extraction)**
66
+
67
+ - *Operations.* Internal :class:`_DepthwiseTemporalConv1d` applying separate temporal
68
+ convolution filters to each channel with kernel size `temporal_conv_kernel_size` and
69
+ depth multiplier `n_spectral_filters_temporal` (equivalent to :math:`F_1` in the paper).
70
+ - *Role.* Extracts multiple distinct spectral bands from each EEG channel independently.
71
+
72
+ - `_SSTEncoder.spt_attn` **(Spatial–Spectral Attention for Channel Gating)**
73
+
74
+ - *Operations.* Internal :class:`_SpatSpectralAttn` module using Global Context Embedding
75
+ via variance-based pooling, followed by adaptive channel normalization and gating.
76
+ - *Role.* Reweights channels in the spatial–spectral dimension to extract efficient and
77
+ discriminative features by emphasizing task-relevant regions and frequency bands.
78
+
79
+ - `_SSTEncoder.chan_conv` **(Pointwise Fusion across Channels)**
80
+
81
+ - *Operations.* A 1D pointwise convolution with `n_fused_filters` output channels
82
+ (equivalent to :math:`F_2` in the paper), followed by BatchNorm and the specified
83
+ `activation` function (default: ELU).
84
+ - *Role.* Fuses the weighted spatial–spectral features across all electrodes to produce
85
+ a fused representation :math:`X_{fused} \in \mathbb{R}^{F_2 \times T}`.
86
+
87
+ - `_SSTEncoder.mvp` **(Multi-scale Variance Pooling for Temporal Extraction)**
88
+
89
+ - *Operations.* Internal :class:`_MultiScaleVarPooler` using :class:`_VariancePool1D`
90
+ layers at multiple scales (`mvp_kernel_sizes`), followed by concatenation.
91
+ - *Role.* Captures long-range temporal features at multiple time scales. The variance
92
+ operation leverages the prior that variance represents EEG spectral power.
93
+
94
+ - `SSTDPN.proto_sep` / `SSTDPN.proto_cpt` **(Dual Prototypes)**
95
+
96
+ - *Operations.* Learnable vectors optimized during training using prototype learning losses.
97
+ The `proto_sep` (Inter-class Separation Prototype) is constrained via L2 weight-normalization
98
+ (:math:`\lVert s_i \rVert_2 \leq` `proto_sep_maxnorm`) on every forward pass.
99
+ - *Role.* `proto_sep` achieves inter-class separation; `proto_cpt` enhances intra-class compactness.
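+
+ In symbols, inference reduces to :math:`\hat{y}_i = \arg\max_{c} \, z_i^{\top} s_c`,
+ where :math:`s_c` is the max-norm-constrained separation prototype of class :math:`c`.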
100
+
101
+ .. rubric:: How the information is encoded temporally, spatially, and spectrally
102
+
103
+ * **Temporal.**
104
+ The initial :class:`_DepthwiseTemporalConv1d` uses a large kernel (e.g., 75). The MVP module employs large pooling
105
+ kernels (e.g., 50, 100, 200 samples) to capture long-term temporal
106
+ features effectively. Large kernel pooling layers are shown to be superior to transformer
107
+ modules for this task in EEG decoding according to [Han2025]_.
108
+
109
+ * **Spatial.**
110
+ The initial convolution in :class:`_DepthwiseTemporalConv1d` uses a groups parameter :math:`h=1`,
111
+ meaning :math:`F_1` temporal filters are shared across channels. The Spatial-Spectral Attention
112
+ mechanism explicitly models the relationships among these channels in the spatial–spectral
113
+ dimension, allowing for finer-grained spatial feature modeling compared to conventional
114
+ GCNs according to the authors [Han2025]_.
115
+ In other words, every electrode channel is filtered independently by the same
116
+ :math:`F_1` shared temporal filters to produce the spatial–spectral representation.
117
+
118
+ * **Spectral.**
119
+ Spectral information is implicitly extracted via the :math:`F_1` filters in :class:`_DepthwiseTemporalConv1d`.
120
+ Furthermore, the use of Variance Pooling (in MVP) explicitly leverages the neurophysiological
121
+ prior that the **variance of EEG signals represents their spectral power**, which is an
122
+ important feature for distinguishing different MI classes [Han2025]_.
123
+
124
+ .. rubric:: Additional Mechanisms
125
+
126
+ - **Attention.** A lightweight Spatial-Spectral Attention mechanism models spatial–spectral relationships
127
+ at the channel level, distinct from applying attention to deep feature dimensions,
128
+ which is common in comparison methods like :class:`ATCNet`.
129
+ - **Regularization.** Dual Prototype Learning acts as a regularization technique
130
+ by optimizing the feature space to be compact within classes and separated between
131
+ classes. This enhances model generalization and classification performance, particularly
132
+ useful for limited data typical of MI-EEG tasks, without requiring external transfer
133
+ learning data, according to [Han2025]_.
134
+
135
+ Notes
136
+ -----
137
+ * The implementation of the DPL loss functions (:math:`\mathcal{L}_S`, :math:`\mathcal{L}_C`, :math:`\mathcal{L}_{EF}`)
138
+ and the optimization of ICPs are typically handled outside the primary ``forward`` method, within the training strategy
139
+ (see Ref. 52 in [Han2025]_).
140
+ * The default parameters are configured based on the BCI Competition IV 2a dataset.
141
+ * The use of Prototype Learning (PL) methods is novel in the field of EEG-MI decoding.
142
+ * **Lowest FLOPs:** Achieves the lowest Floating Point Operations (FLOPs) (9.65 M) among competitive
143
+ SOTA methods, including braindecode models like :class:`ATCNet` (29.81 M) and
144
+ :class:`EEGConformer` (63.86 M), demonstrating computational efficiency [Han2025]_.
145
+ * **Transformer Alternative:** Multi-scale Variance Pooling (MVP) provides an accuracy
146
+ improvement over temporal attention transformer modules in ablation studies, offering a more
147
+ efficient alternative to transformer-based approaches like :class:`EEGConformer` [Han2025]_.
148
+
149
+ .. warning::
150
+
151
+ **Important:** To utilize the full potential of SSTDPN with Dual Prototype Learning (DPL),
152
+ users must implement the DPL optimization strategy outside the model's forward method.
153
+ For implementation details and training strategies, please consult the official code at
154
+ [Han2025Code]_:
155
+ https://github.com/hancan16/SST-DPN/blob/main/train.py
156
+
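+ For orientation only, below is a minimal sketch of one possible DPL training step. It is
+ not the authors' exact recipe: the data ``loader``, the squared-error form of the
+ compactness term, the ``0.1`` weighting, and the import path are assumptions here; the
+ reference implementation is ``train.py`` in [Han2025Code]_::
+
+     import torch
+     import torch.nn.functional as F
+
+     from braindecode.models import SSTDPN  # assumed export path
+
+     model = SSTDPN(n_chans=22, n_times=1000, n_outputs=4, return_features=True)
+     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+     for X, y in loader:  # X: (batch, n_chans, n_times), y: (batch,) class indices
+         features, logits = model(X)
+         # L_S: inter-class separation via the prototype logits
+         loss_sep = F.cross_entropy(logits, y)
+         # L_C: pull each feature vector towards its class compactness prototype
+         loss_cpt = F.mse_loss(features, model.proto_cpt[y])
+         loss = loss_sep + 0.1 * loss_cpt
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+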
157
+ Parameters
158
+ ----------
159
+ n_spectral_filters_temporal : int, optional
160
+ Number of spectral filters extracted per channel via temporal convolution.
161
+ These represent the temporal spectral bands (equivalent to :math:`F_1` in the paper).
162
+ Default is 9.
163
+
164
+ n_fused_filters : int, optional
165
+ Number of output filters after pointwise fusion convolution.
166
+ These fuse the spectral filters across all channels (equivalent to :math:`F_2` in the paper).
167
+ Default is 48.
168
+
169
+ temporal_conv_kernel_size : int, optional
170
+ Kernel size for the temporal convolution layer. Controls the receptive field for extracting
171
+ spectral information. Default is 75 samples.
172
+
173
+ mvp_kernel_sizes : list[int], optional
174
+ Kernel sizes for Multi-scale Variance Pooling (MVP) module.
175
+ Larger kernels capture long-term temporal dependencies. If None, defaults to [50, 100, 200].
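+ As a worked example, assuming n_times=1000 with the defaults (n_fused_filters=48,
+ stride k // 2), the pooled lengths are 39, 19 and 9 samples, giving an encoder
+ feature dimension of 16 * (39 + 19 + 9) = 1072.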
176
+
177
+ return_features : bool, optional
178
+ If True, the forward pass returns (features, logits). If False, returns only logits.
179
+ Default is False.
180
+
181
+ proto_sep_maxnorm : float, optional
182
+ Maximum L2 norm constraint for Inter-class Separation Prototypes during forward pass.
183
+ This constraint acts as an implicit force to push features away from the origin. Default is 1.0.
184
+
185
+ proto_cpt_std : float, optional
186
+ Standard deviation for Intra-class Compactness Prototype initialization. Default is 0.01.
187
+
188
+ spt_attn_global_context_kernel : int, optional
189
+ Kernel size for global context embedding in Spatial-Spectral Attention module.
190
+ Default is 250 samples.
191
+
192
+ spt_attn_epsilon : float, optional
193
+ Small epsilon value for numerical stability in Spatial-Spectral Attention. Default is 1e-5.
194
+
195
+ spt_attn_mode : str, optional
196
+ Embedding computation mode for Spatial-Spectral Attention ('var', 'l2', or 'l1').
197
+ Default is 'var' (variance-based global context embedding).
198
+
199
+ activation : nn.Module, optional
200
+ Activation function to apply after the pointwise fusion convolution in :class:`_SSTEncoder`.
201
+ Should be a PyTorch activation module class. Default is nn.ELU.
202
+
203
+
204
+ References
205
+ ----------
206
+ .. [Han2025] Han, C., Liu, C., Wang, J., Wang, Y., Cai, C.,
207
+ & Qian, D. (2025). A spatial–spectral and temporal dual
208
+ prototype network for motor imagery brain–computer
209
+ interface. Knowledge-Based Systems, 315, 113315.
210
+ .. [Han2025Code] Han, C., Liu, C., Wang, J., Wang, Y.,
211
+ Cai, C., & Qian, D. (2025). A spatial–spectral and
212
+ temporal dual prototype network for motor imagery
213
+ brain–computer interface. Knowledge-Based Systems,
214
+ 315, 113315. GitHub repository.
215
+ https://github.com/hancan16/SST-DPN.
216
+ """
217
+
218
+ def __init__(
219
+ self,
220
+ # Braindecode standard parameters
221
+ n_chans=None,
222
+ n_times=None,
223
+ n_outputs=None,
224
+ input_window_seconds=None,
225
+ sfreq=None,
226
+ chs_info=None,
227
+ # models parameters
228
+ n_spectral_filters_temporal: int = 9,
229
+ n_fused_filters: int = 48,
230
+ temporal_conv_kernel_size: int = 75,
231
+ mvp_kernel_sizes: Optional[List[int]] = None,
232
+ return_features: bool = False,
233
+ proto_sep_maxnorm: float = 1.0,
234
+ proto_cpt_std: float = 0.01,
235
+ spt_attn_global_context_kernel: int = 250,
236
+ spt_attn_epsilon: float = 1e-5,
237
+ spt_attn_mode: str = "var",
238
+ activation: Optional[nn.Module] = nn.ELU,
239
+ ) -> None:
240
+ super().__init__(
241
+ n_chans=n_chans,
242
+ n_outputs=n_outputs,
243
+ chs_info=chs_info,
244
+ n_times=n_times,
245
+ input_window_seconds=input_window_seconds,
246
+ sfreq=sfreq,
247
+ )
248
+ del input_window_seconds, sfreq, chs_info, n_chans, n_outputs, n_times
249
+
250
+ # Set default activation if not provided
251
+ if activation is None:
252
+ activation = nn.ELU
253
+
254
+ # Store hyperparameters
255
+ self.n_spectral_filters_temporal = n_spectral_filters_temporal
256
+ self.n_fused_filters = n_fused_filters
257
+ self.temporal_conv_kernel_size = temporal_conv_kernel_size
258
+ self.mvp_kernel_sizes = (
259
+ mvp_kernel_sizes if mvp_kernel_sizes is not None else [50, 100, 200]
260
+ )
261
+ self.return_features = return_features
262
+ self.proto_sep_maxnorm = proto_sep_maxnorm
263
+ self.proto_cpt_std = proto_cpt_std
264
+ self.spt_attn_global_context_kernel = spt_attn_global_context_kernel
265
+ self.spt_attn_epsilon = spt_attn_epsilon
266
+ self.spt_attn_mode = spt_attn_mode
267
+ self.activation = activation
268
+
269
+ # Encoder accepts (batch, n_chans, n_times)
270
+ self.encoder = _SSTEncoder(
271
+ n_times=self.n_times,
272
+ n_chans=self.n_chans,
273
+ n_spectral_filters_temporal=self.n_spectral_filters_temporal,
274
+ n_fused_filters=self.n_fused_filters,
275
+ temporal_conv_kernel_size=self.temporal_conv_kernel_size,
276
+ mvp_kernel_sizes=self.mvp_kernel_sizes,
277
+ spt_attn_global_context_kernel=self.spt_attn_global_context_kernel,
278
+ spt_attn_epsilon=self.spt_attn_epsilon,
279
+ spt_attn_mode=self.spt_attn_mode,
280
+ activation=self.activation,
281
+ )
282
+
283
+ # Infer feature dimension analytically
284
+ feat_dim = self._compute_feature_dim()
285
+
286
+ # Prototypes: Inter-class Separation (ISP) and Intra-class Compactness (ICP)
287
+ # ISP: provides inter-class separation via prototype learning
288
+ # ICP: enhances intra-class compactness
289
+ self.proto_sep = nn.Parameter(
290
+ torch.empty(self.n_outputs, feat_dim), requires_grad=True
291
+ )
292
+ # This parameter is not used in the forward pass, only during training for the
293
+ # prototype learning losses. You should implement the losses outside this class.
294
+ self.proto_cpt = nn.Parameter(
295
+ torch.empty(self.n_outputs, feat_dim), requires_grad=True
296
+ )
297
+ # just for braindecode compatibility
298
+ self.final_layer = nn.Identity()
299
+
300
+ self._reset_parameters()
301
+
302
+ def _reset_parameters(self) -> None:
303
+ """Initialize prototype parameters."""
304
+ nn.init.kaiming_normal_(self.proto_sep)
305
+ nn.init.normal_(self.proto_cpt, mean=0.0, std=self.proto_cpt_std)
306
+
307
+ def forward(
308
+ self, x: torch.Tensor
309
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
310
+ """
311
+ Classification is based on the dot product similarity with
312
+ Inter-class Separation Prototypes (:attr:`SSTDPN.proto_sep`).
313
+
314
+
315
+ Parameters
316
+ ----------
317
+ x : torch.Tensor
318
+ Input tensor. Supported shapes:
319
+ - (batch, n_chans, n_times)
320
+
321
+ Returns
322
+ -------
323
+ logits : torch.Tensor
324
+ Logits of shape (batch, n_outputs).
325
+ If `return_features` is True, the tuple (features, logits) is returned instead,
326
+ where features has shape (batch, feat_dim).
327
+ """
328
+
329
+ features = self.encoder(x) # (b, feat_dim)
330
+ # Renormalize inter-class separation prototypes
331
+ self.proto_sep.data = torch.renorm(
332
+ self.proto_sep.data, p=2, dim=1, maxnorm=self.proto_sep_maxnorm
333
+ )
334
+ logits = torch.einsum("bd,cd->bc", features, self.proto_sep) # (b, n_outputs)
335
+ logits = self.final_layer(logits)
336
+
337
+ if self.return_features:
338
+ return features, logits
339
+
340
+ return logits
341
+
342
+ def _compute_feature_dim(self) -> int:
343
+ """Compute encoder feature dimensionality without a forward pass."""
344
+ if not self.mvp_kernel_sizes:
345
+ raise ValueError(
346
+ "`mvp_kernel_sizes` must contain at least one kernel size."
347
+ )
348
+
349
+ num_scales = len(self.mvp_kernel_sizes)
350
+ channels_per_scale, rest = divmod(self.n_fused_filters, num_scales)
351
+ if rest:
352
+ raise ValueError(
353
+ "Number of fused filters must be divisible by the number of MVP scales. "
354
+ f"Got {self.n_fused_filters=} and {num_scales=}."
355
+ )
356
+
357
+ # Validate all kernel sizes at once (stride = k // 2 must be >= 1)
358
+ invalid = [k for k in self.mvp_kernel_sizes if k // 2 == 0]
359
+ if invalid:
360
+ raise ValueError(
361
+ "MVP kernel sizes too small to derive a valid stride (k//2 == 0): "
362
+ f"{invalid}"
363
+ )
364
+
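+ # Each MVP scale flattens (n_fused_filters // num_scales) channels over its pooled
+ # length, so the feature size is channels_per_scale * sum of pooled lengths.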
365
+ pooled_total = sum(
366
+ self._pool1d_output_length(
367
+ length=self.n_times, kernel_size=k, stride=k // 2, padding=0, dilation=1
368
+ )
369
+ for k in self.mvp_kernel_sizes
370
+ )
371
+ return channels_per_scale * pooled_total
372
+
373
+ @staticmethod
374
+ def _pool1d_output_length(
375
+ length: int, kernel_size: int, stride: int, padding: int = 0, dilation: int = 1
376
+ ) -> int:
377
+ """Temporal length after 1D pooling (PyTorch-style formula)."""
378
+ return max(
379
+ 0,
380
+ (length + 2 * padding - (dilation * (kernel_size - 1) + 1)) // stride + 1,
381
+ )
382
+
383
+
384
+ class _SSTEncoder(nn.Module):
385
+ """Internal encoder combining Adaptive Spatial-Spectral Fusion and Multi-scale Variance Pooling.
386
+
387
+ This class should not be instantiated directly. It is an internal component
388
+ of :class:`SSTDPN`.
389
+
390
+ Parameters
391
+ ----------
392
+ n_times : int
393
+ Number of time samples in the input window.
394
+ n_chans : int
395
+ Number of EEG channels.
396
+ n_spectral_filters_temporal : int
397
+ Number of spectral filters extracted via temporal convolution (:math:`F_1`).
398
+ n_fused_filters : int
399
+ Number of output filters after pointwise fusion (:math:`F_2`).
400
+ temporal_conv_kernel_size : int
401
+ Kernel size for temporal convolution.
402
+ mvp_kernel_sizes : list[int]
403
+ Kernel sizes for Multi-scale Variance Pooling.
404
+ spt_attn_global_context_kernel : int
405
+ Kernel size for global context in Spatial-Spectral Attention.
406
+ spt_attn_epsilon : float
407
+ Epsilon for numerical stability in Spatial-Spectral Attention.
408
+ spt_attn_mode : str
409
+ Mode for Spatial-Spectral Attention computation ('var', 'l2', or 'l1').
410
+ activation : nn.Module, optional
411
+ Activation function class to use after pointwise convolution. Default is nn.ELU.
412
+ """
413
+
414
+ def __init__(
415
+ self,
416
+ n_times: int,
417
+ n_chans: int,
418
+ n_spectral_filters_temporal: int = 9,
419
+ n_fused_filters: int = 48,
420
+ temporal_conv_kernel_size: int = 75,
421
+ mvp_kernel_sizes: Optional[List[int]] = None,
422
+ spt_attn_global_context_kernel: int = 250,
423
+ spt_attn_epsilon: float = 1e-5,
424
+ spt_attn_mode: str = "var",
425
+ activation: Optional[nn.Module] = None,
426
+ ) -> None:
427
+ super().__init__()
428
+
429
+ if mvp_kernel_sizes is None:
430
+ mvp_kernel_sizes = [50, 100, 200]
431
+
432
+ if activation is None:
433
+ activation = nn.ELU
434
+
435
+ # Adaptive Spatial-Spectral Fusion (ASSF): Temporal convolution for spectral filtering
436
+ self.temporal_conv = _DepthwiseTemporalConv1d(
437
+ in_channels=n_chans,
438
+ num_heads=1,
439
+ n_spectral_filters_temporal=n_spectral_filters_temporal,
440
+ kernel_size=temporal_conv_kernel_size,
441
+ stride=1,
442
+ padding="same",
443
+ bias=True,
444
+ weight_softmax=False,
445
+ )
446
+ # Spatial-Spectral Attention: Gate mechanism for channel weighting
447
+ self.spt_attn = _SpatSpectralAttn(
448
+ T=n_times,
449
+ num_channels=n_chans * n_spectral_filters_temporal,
450
+ epsilon=spt_attn_epsilon,
451
+ mode=spt_attn_mode,
452
+ global_context_kernel=spt_attn_global_context_kernel,
453
+ )
454
+
455
+ # Pointwise convolution for fusing spectral filters across channels
456
+ self.chan_conv = nn.Sequential(
457
+ nn.Conv1d(
458
+ n_chans * n_spectral_filters_temporal,
459
+ n_fused_filters,
460
+ kernel_size=1,
461
+ stride=1,
462
+ padding=0,
463
+ bias=True,
464
+ ),
465
+ nn.BatchNorm1d(n_fused_filters),
466
+ activation(),
467
+ )
468
+
469
+ # Multi-scale Variance Pooling (MVP): Temporal feature extraction at multiple scales
470
+ self.mvp = _MultiScaleVarPooler(kernel_sizes=mvp_kernel_sizes)
471
+
472
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
473
+ """
474
+ Forward pass.
475
+
476
+ Parameters
477
+ ----------
478
+ x : torch.Tensor
479
+ Input of shape (batch, n_chans, n_times).
480
+
481
+ Returns
482
+ -------
483
+ torch.Tensor
484
+ Feature vector of shape (batch, feat_dim).
485
+ """
486
+ x = self.temporal_conv(x) # (b, n_chans*n_spectral_filters_temporal, T)
487
+ x, _ = self.spt_attn(x) # (b, n_chans*n_spectral_filters_temporal, T)
488
+ x_fused = self.chan_conv(x) # (b, n_fused_filters, T)
489
+ feature = self.mvp(x_fused) # (b, feat_dim)
490
+ return feature
491
+
492
+
493
+ class _DepthwiseTemporalConv1d(nn.Module):
494
+ """Internal depthwise temporal convolution for spectral filtering.
495
+
496
+ Applies separate temporal convolution filters to each channel independently
497
+ to extract spectral information across multiple bands. This is used to generate
498
+ the spatial-spectral representation in SSTDPN.
499
+
500
+ Not intended for external use.
501
+
502
+ Parameters
503
+ ----------
504
+ in_channels : int
505
+ Number of input channels.
506
+ num_heads : int, optional
507
+ Number of filter groups (typically 1). Default is 1.
508
+ n_spectral_filters_temporal : int, optional
509
+ Number of spectral filters per channel (depth multiplier). Default is 1.
510
+ kernel_size : int, optional
511
+ Temporal convolution kernel size. Default is 1.
512
+ stride : int, optional
513
+ Convolution stride. Default is 1.
514
+ padding : str or int, optional
515
+ Padding mode. Default is 0.
516
+ bias : bool, optional
517
+ Whether to use bias. Default is True.
518
+ weight_softmax : bool, optional
519
+ Whether to apply softmax to weights. Default is False.
520
+ """
521
+
522
+ def __init__(
523
+ self,
524
+ in_channels: int,
525
+ num_heads: int = 1,
526
+ n_spectral_filters_temporal: int = 1,
527
+ kernel_size: int = 1,
528
+ stride: int = 1,
529
+ padding: Union[str, int] = 0,
530
+ bias: bool = True,
531
+ weight_softmax: bool = False,
532
+ ) -> None:
533
+ super().__init__()
534
+ self.in_channels = in_channels
535
+ self.kernel_size = kernel_size
536
+ self.stride = stride
537
+ self.num_heads = num_heads
538
+ self.padding = padding
539
+ self.weight_softmax = weight_softmax
540
+
541
+ self.weight = nn.Parameter(
542
+ torch.Tensor(num_heads * n_spectral_filters_temporal, 1, kernel_size)
543
+ )
544
+ self.bias = (
545
+ nn.Parameter(torch.Tensor(num_heads * n_spectral_filters_temporal))
546
+ if bias
547
+ else None
548
+ )
549
+
550
+ self._init_parameters()
551
+
552
+ def _init_parameters(self) -> None:
553
+ """Initialize parameters."""
554
+ nn.init.xavier_uniform_(self.weight)
555
+ if self.bias is not None:
556
+ nn.init.constant_(self.bias, 0.0)
557
+
558
+ def forward(self, inp: torch.Tensor) -> torch.Tensor:
559
+ """
560
+ Forward pass.
561
+
562
+ Parameters
563
+ ----------
564
+ inp : torch.Tensor
565
+ Input of shape (batch, in_channels, time).
566
+
567
+ Returns
568
+ -------
569
+ torch.Tensor
570
+ Output of shape (batch, in_channels * n_spectral_filters_temporal, time).
571
+ """
572
+ B, _, _ = inp.size()
573
+ H = self.num_heads
574
+ weight = self.weight
575
+ if self.weight_softmax:
576
+ weight = F.softmax(weight, dim=-1)
577
+
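+ # Fold the EEG channels into the batch dimension so that every channel is filtered
+ # independently by the same shared temporal filters (depthwise behaviour).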
578
+ inp = rearrange(inp, "b (h c) t -> (b c) h t", h=H)
579
+ if self.bias is None:
580
+ output = F.conv1d(
581
+ inp,
582
+ weight,
583
+ stride=self.stride,
584
+ padding=self.padding,
585
+ groups=self.num_heads,
586
+ )
587
+ else:
588
+ output = F.conv1d(
589
+ inp,
590
+ weight,
591
+ bias=self.bias,
592
+ stride=self.stride,
593
+ padding=self.padding,
594
+ groups=self.num_heads,
595
+ )
596
+ output = rearrange(output, "(b c) h t -> b (h c) t", b=B)
597
+ return output
598
+
599
+
600
+ class _GlobalContextVarPool1D(nn.Module):
601
+ """Internal global context variance pooling module.
602
+
603
+ Computes variance-based global context embeddings using specified kernel size.
604
+ Used in the Spatial-Spectral Attention module.
605
+
606
+ Not intended for external use.
607
+
608
+ Parameters
609
+ ----------
610
612
+ kernel_size : int
613
+ Pooling kernel size.
614
+ stride : int or None, optional
615
+ Stride. If None, defaults to kernel_size. Default is None.
616
+ padding : int, optional
617
+ Padding. Default is 0.
618
+ """
619
+
620
+ def __init__(
621
+ self, kernel_size: int, stride: Optional[int] = None, padding: int = 0
622
+ ) -> None:
623
+ super().__init__()
624
+ self.kernel_size = kernel_size
625
+ self.stride = kernel_size if stride is None else stride
626
+ self.padding = padding
627
+
628
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
629
+ """
630
+ Forward pass computing global context via variance pooling.
631
+
632
+ Parameters
633
+ ----------
634
+ x : torch.Tensor
635
+ Input tensor.
636
+
637
+ Returns
638
+ -------
639
+ torch.Tensor
640
+ Global context (variance-pooled) output.
641
+ """
642
+ mean_of_squares = F.avg_pool1d(
643
+ x**2, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding
644
+ )
645
+ square_of_mean = (
646
+ F.avg_pool1d(
647
+ x,
648
+ kernel_size=self.kernel_size,
649
+ stride=self.stride,
650
+ padding=self.padding,
651
+ )
652
+ ) ** 2
653
+ variance = mean_of_squares - square_of_mean
654
+ out = F.avg_pool1d(variance, variance.shape[-1])
655
+ return out
656
+
657
+
658
+ class _VariancePool1D(nn.Module):
659
+ """Internal variance pooling module for temporal feature extraction.
660
+
661
+ Applies variance pooling at a specified kernel size to capture temporal dynamics.
662
+ Used in the Multi-scale Variance Pooling (MVP) module.
663
+
664
+ Not intended for external use.
665
+
666
+ Parameters
667
+ ----------
668
+ kernel_size : int
669
+ Pooling kernel size (receptive field width).
670
+ stride : int or None, optional
671
+ Stride. If None, defaults to kernel_size. Default is None.
672
+ padding : int, optional
673
+ Padding. Default is 0.
674
+ """
675
+
676
+ def __init__(
677
+ self, kernel_size: int, stride: Optional[int] = None, padding: int = 0
678
+ ) -> None:
679
+ super().__init__()
680
+ self.kernel_size = kernel_size
681
+ self.stride = kernel_size if stride is None else stride
682
+ self.padding = padding
683
+
684
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
685
+ """
686
+ Forward pass computing variance pooling.
687
+
688
+ Parameters
689
+ ----------
690
+ x : torch.Tensor
691
+ Input tensor of shape (batch, channels, time).
692
+
693
+ Returns
694
+ -------
695
+ torch.Tensor
696
+ Variance-pooled output.
697
+ """
698
+ mean_of_squares = F.avg_pool1d(
699
+ x**2, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding
700
+ )
701
+ square_of_mean = (
702
+ F.avg_pool1d(
703
+ x,
704
+ kernel_size=self.kernel_size,
705
+ stride=self.stride,
706
+ padding=self.padding,
707
+ )
708
+ ) ** 2
709
+ variance = mean_of_squares - square_of_mean
710
+ return variance
711
+
712
+
713
+ class _SpatSpectralAttn(nn.Module):
714
+ """Internal Spatial-Spectral Attention module with global context gating.
715
+
716
+ This attention mechanism computes channel-wise gates based on global context
717
+ embedding and applies adaptive reweighting to highlight task-relevant
718
+ spatial-spectral features.
719
+
720
+ Not intended for external use. Used internally in :class:`_SSTEncoder`.
721
+
722
+ Parameters
723
+ ----------
724
+ T : int
725
+ Sequence (temporal) length (not used in the computation).
726
+ num_channels : int
727
+ Number of channels in the spatial-spectral dimension.
728
+ epsilon : float, optional
729
+ Small value for numerical stability. Default is 1e-5.
730
+ mode : str, optional
731
+ Embedding computation mode: 'var' (variance-based), 'l2' (L2-norm),
732
+ or 'l1' (L1-norm). Default is 'var'.
733
+ after_relu : bool, optional
734
+ Whether ReLU is applied before this module. Default is False.
735
+ global_context_kernel : int, optional
736
+ Kernel size for global context variance pooling. Default is 250.
737
+ """
738
+
739
+ def __init__(
740
+ self,
741
+ T: int,
742
+ num_channels: int,
743
+ epsilon: float = 1e-5,
744
+ mode: str = "var",
745
+ after_relu: bool = False,
746
+ global_context_kernel: int = 250,
747
+ ) -> None:
748
+ super().__init__()
749
+ # Learnable gating parameters: scale, normalize, and shift
750
+ self.alpha = nn.Parameter(torch.ones(1, num_channels, 1))
751
+ self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1))
752
+ self.beta = nn.Parameter(torch.zeros(1, num_channels, 1))
753
+ self.epsilon = epsilon
754
+ self.mode = mode
755
+ self.after_relu = after_relu
756
+ # check mode validity
757
+ if self.mode not in ["var", "l2", "l1"]:
758
+ raise ValueError(
759
+ f"Unsupported Spatial-Spectral Attention mode: {self.mode}"
760
+ )
761
+
762
+ # Global context module using variance pooling
763
+ self.global_ctx = _GlobalContextVarPool1D(global_context_kernel)
764
+
765
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
766
+ """
767
+ Forward pass computing adaptive channel-wise gating.
768
+
769
+ Parameters
770
+ ----------
771
+ x : torch.Tensor
772
+ Input of shape (batch, channels, time).
773
+
774
+ Returns
775
+ -------
776
+ tuple of torch.Tensor
777
+ (gated_output, gate) where both have the same shape as input.
778
+ """
779
+
780
+ if self.mode == "l2":
781
+ # L2-norm based embedding
782
+ embedding = (x.pow(2).sum(2, keepdim=True) + self.epsilon).pow(0.5)
783
+ norm = self.gamma / (
784
+ embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon
785
+ ).pow(0.5)
786
+ elif self.mode == "l1":
787
+ # L1-norm based embedding
788
+ _x = torch.abs(x) if not self.after_relu else x
789
+ embedding = _x.sum(2, keepdim=True)
790
+ norm = self.gamma / (
791
+ torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon
792
+ )
793
+ elif self.mode == "var":
794
+ # Variance-based embedding (global context)
795
+ embedding = (self.global_ctx(x) + self.epsilon).pow(0.5) * self.alpha
796
+ norm = (self.gamma) / (
797
+ embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon
798
+ ).pow(0.5)
799
+
800
+ # Compute adaptive gate: 1 + tanh(...)
801
+ gate = 1 + torch.tanh(embedding * norm + self.beta)
802
+ return x * gate, gate
803
+
804
+
805
+ class _MultiScaleVarPooler(nn.Module):
806
+ """Internal Multi-scale Variance Pooling (MVP) module for temporal feature extraction.
807
+
808
+ Applies variance pooling at multiple temporal scales in parallel, then concatenates
809
+ the results to capture long-range temporal dependencies. Each scale processes a subset
810
+ of channels independently, enabling efficient feature extraction.
811
+
812
+ Not intended for external use. Used internally in :class:`_SSTEncoder`.
813
+
814
+ Parameters
815
+ ----------
816
+ kernel_sizes : list[int] or None, optional
817
+ Kernel sizes for variance pooling layers at each scale. If None,
818
+ defaults to [50, 100, 200] (suitable for 1000-sample windows).
819
+ """
820
+
821
+ def __init__(self, kernel_sizes: Optional[List[int]] = None) -> None:
822
+ super().__init__()
823
+
824
+ if kernel_sizes is None:
825
+ kernel_sizes = [50, 100, 200]
826
+
827
+ self.var_layers = nn.ModuleList()
828
+ self.num_scales = len(kernel_sizes)
829
+
830
+ # Create variance pooling layer for each scale
831
+ for k in kernel_sizes:
832
+ self.var_layers.append(
833
+ nn.Sequential(
834
+ _VariancePool1D(kernel_size=k, stride=int(k / 2)),
835
+ nn.Flatten(start_dim=1),
836
+ )
837
+ )
838
+
839
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
840
+ """
841
+ Forward pass applying multi-scale variance pooling in parallel.
842
+
843
+ Parameters
844
+ ----------
845
+ x : torch.Tensor
846
+ Input of shape (batch, channels, time).
847
+
848
+ Returns
849
+ -------
850
+ torch.Tensor
851
+ Concatenated multi-scale features of shape (batch, total_features).
852
+ """
853
+ _, num_channels, _ = x.shape
854
+ # Split channels equally across scales
855
+ assert num_channels % self.num_scales == 0, (
856
+ f"Channel dimension ({num_channels}) must be divisible by "
857
+ f"number of scales ({self.num_scales})"
858
+ )
859
+ channels_per_scale = num_channels // self.num_scales
860
+ x_split = torch.split(x, channels_per_scale, dim=1)
861
+
862
+ # Apply variance pooling at each scale
863
+ multi_scale_features = []
864
+ for scale_idx, x_scale in enumerate(x_split):
865
+ multi_scale_features.append(self.var_layers[scale_idx](x_scale))
866
+
867
+ # Concatenate features from all scales
868
+ y = torch.concat(multi_scale_features, dim=1)
869
+ return y