braindecode 1.3.0.dev180851780__py3-none-any.whl → 1.3.0.dev183667303__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Potentially problematic release.
- braindecode/augmentation/base.py +1 -1
- braindecode/datasets/__init__.py +10 -2
- braindecode/datasets/base.py +115 -151
- braindecode/datasets/bcicomp.py +4 -4
- braindecode/datasets/bids.py +3 -3
- braindecode/datasets/experimental.py +2 -2
- braindecode/datasets/mne.py +3 -5
- braindecode/datasets/moabb.py +2 -2
- braindecode/datasets/nmt.py +2 -2
- braindecode/datasets/sleep_physio_challe_18.py +2 -2
- braindecode/datasets/sleep_physionet.py +2 -2
- braindecode/datasets/tuh.py +2 -2
- braindecode/datasets/xy.py +2 -2
- braindecode/datautil/serialization.py +7 -7
- braindecode/eegneuralnet.py +2 -0
- braindecode/functional/functions.py +6 -2
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +2 -0
- braindecode/models/atcnet.py +25 -26
- braindecode/models/attentionbasenet.py +39 -32
- braindecode/models/base.py +280 -2
- braindecode/models/bendr.py +469 -0
- braindecode/models/biot.py +2 -0
- braindecode/models/ctnet.py +6 -3
- braindecode/models/deepsleepnet.py +27 -18
- braindecode/models/eegconformer.py +2 -2
- braindecode/models/eeginception_erp.py +31 -25
- braindecode/models/eegnet.py +1 -1
- braindecode/models/labram.py +193 -87
- braindecode/models/signal_jepa.py +109 -27
- braindecode/models/sinc_shallow.py +10 -9
- braindecode/models/sstdpn.py +11 -11
- braindecode/models/summary.csv +1 -0
- braindecode/models/usleep.py +26 -21
- braindecode/models/util.py +1 -0
- braindecode/modules/attention.py +10 -10
- braindecode/modules/blocks.py +3 -3
- braindecode/modules/filter.py +2 -3
- braindecode/modules/layers.py +18 -17
- braindecode/preprocessing/__init__.py +24 -0
- braindecode/preprocessing/eegprep_preprocess.py +1202 -0
- braindecode/preprocessing/preprocess.py +12 -12
- braindecode/preprocessing/util.py +166 -0
- braindecode/preprocessing/windowers.py +24 -19
- braindecode/samplers/base.py +8 -8
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/METADATA +6 -2
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/RECORD +52 -49
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/top_level.txt +0 -0
braindecode/datautil/serialization.py
CHANGED

@@ -19,8 +19,8 @@ from joblib import Parallel, delayed
 
 from ..datasets.base import (
     BaseConcatDataset,
-    BaseDataset,
     EEGWindowsDataset,
+    RawDataset,
     WindowsDataset,
 )
 
@@ -35,7 +35,7 @@ def save_concat_dataset(path, concat_dataset, overwrite=False):
 
 
 def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
-    """Load a stored BaseConcatDataset
+    """Load a stored BaseConcatDataset from
     files.
 
     Parameters
@@ -52,7 +52,7 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
 
     Returns
     -------
-    concat_dataset: BaseConcatDataset
+    concat_dataset: BaseConcatDataset
     """
     # assume we have a single concat dataset to load
     is_raw = (path / "0-raw.fif").is_file()
@@ -87,7 +87,7 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
     for i_signal, signal in enumerate(all_signals):
         if is_raw:
             datasets.append(
-                BaseDataset(
+                RawDataset(
                     signal, description.iloc[i_signal], target_name=target_name
                 )
             )
@@ -175,7 +175,7 @@ def _load_signals(fif_file, preload, is_raw):
 
 
 def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_jobs=1):
-    """Load a stored BaseConcatDataset
+    """Load a stored BaseConcatDataset from
     files.
 
     Parameters
@@ -194,7 +194,7 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_jobs=1):
 
     Returns
     -------
-    concat_dataset: BaseConcatDataset
+    concat_dataset: BaseConcatDataset
     """
     # Make sure we always work with a pathlib.Path
     path = Path(path)
@@ -266,7 +266,7 @@ def _load_parallel(path, i, preload, is_raw, has_stored_windows):
     target_name = json.load(open(target_file_path, "r"))["target_name"]
 
     if is_raw and (not has_stored_windows):
-        dataset = BaseDataset(signals, description, target_name)
+        dataset = RawDataset(signals, description, target_name)
     else:
        window_kwargs = _load_kwargs_json("window_kwargs", sub_dir)
        windows_ds_kwargs = [
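The recurring change in this file renames the class used for unwindowed recordings from BaseDataset to RawDataset. A usage sketch of the round-trip this affects (annotation, not part of the diff; dataset name and paths are placeholders):

    from braindecode.datasets import MOABBDataset
    from braindecode.datautil import load_concat_dataset

    ds = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[1])
    ds.save(path="./stored/", overwrite=True)

    # The loading call is unchanged; after this release the unwindowed
    # entries come back as RawDataset instances (previously BaseDataset).
    loaded = load_concat_dataset(path="./stored/", preload=False, n_jobs=1)
    print(type(loaded.datasets[0]).__name__)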
braindecode/eegneuralnet.py
CHANGED
@@ -189,6 +189,8 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
                 "Skipping setting signal-related parameters from data."
             )
             return
+        if classes is None:
+            classes = getattr(self, "classes", None)
         # get kwargs from signal:
         signal_kwargs = dict()
         # Using shape to work both with torch.tensor and numpy.array:
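The added fallback lets a `classes` value given at construction be reused when the signal-parameter inference is later called with `classes=None`. A minimal sketch of the calling pattern this supports (the model choice is illustrative; `classes` is the usual skorch-style argument):

    from braindecode import EEGClassifier
    from braindecode.models import ShallowFBCSPNet

    # classes supplied here are now picked up when signal-related
    # parameters are inferred from the training data.
    net = EEGClassifier(ShallowFBCSPNet, classes=[0, 1, 2, 3])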
braindecode/functional/functions.py
CHANGED

@@ -181,20 +181,24 @@ def plv_time(x, forward_fourier=True, epsilon: float = 1e-6):
     The Phase Locking Value (PLV) is a measure of the synchronization between
     different channels by evaluating the consistency of phase differences
     over time. It ranges from 0 (no synchronization) to 1 (perfect
-    synchronization) [
+    synchronization) [Lachaux1999]_.
 
     Parameters
     ----------
     x : torch.Tensor
         Input tensor containing the signal data.
+
         - If `forward_fourier` is `True`, the shape should be `(..., channels, time)`.
         - If `forward_fourier` is `False`, the shape should be `(..., channels, freqs, 2)`,
           where the last dimension represents the real and imaginary parts.
+
     forward_fourier : bool, optional
         Specifies the format of the input tensor `x`.
+
         - If `True`, `x` is assumed to be in the time domain.
         - If `False`, `x` is assumed to be in the Fourier domain with separate real and
           imaginary components.
+
         Default is `True`.
     epsilon : float, default 1e-6
         Small numerical value to ensure positivity constraint on the complex part

@@ -207,7 +211,7 @@ def plv_time(x, forward_fourier=True, epsilon: float = 1e-6):
 
     References
     ----------
-    [
+    .. [Lachaux1999] Lachaux, J. P., Rodriguez, E., Martinerie, J., & Varela, F. J. (1999).
        Measuring phase synchrony in brain signals. Human brain mapping,
        8(4), 194-208.
     """
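A quick shape-and-range sketch matching the docstring above (annotation, not part of the diff; the import path is assumed, and the output shape depends on the implementation):

    import torch
    from braindecode.functional import plv_time  # import path assumed

    x = torch.randn(4, 22, 1000)               # (..., channels, time)
    plv = plv_time(x, forward_fourier=True)     # time-domain input
    print(plv.min().item(), plv.max().item())   # docstring contract: values in [0, 1]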
braindecode/functional/initialization.py
CHANGED

@@ -5,9 +5,8 @@ from torch import nn
 
 def glorot_weight_zero_bias(model):
     """Initialize parameters of all modules by initializing weights with
-    glorot
-
-    batch norm layers are set to 1.
+    glorot uniform/xavier initialization, and setting biases to zero. Weights from
+    batch norm layers are set to 1.
 
     Parameters
     ----------
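The reworded docstring now states the full contract: Xavier/Glorot-uniform weights, zero biases, batch-norm weights set to 1. A minimal sketch of that contract (import path assumed):

    import torch
    from torch import nn
    from braindecode.functional import glorot_weight_zero_bias  # import path assumed

    model = nn.Sequential(nn.Conv1d(22, 16, kernel_size=5), nn.BatchNorm1d(16))
    glorot_weight_zero_bias(model)
    print(torch.all(model[1].weight == 1).item())  # batch norm weights set to 1
    print(torch.all(model[0].bias == 0).item())    # biases zeroed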
braindecode/models/__init__.py
CHANGED
@@ -6,6 +6,7 @@ from .atcnet import ATCNet
 from .attentionbasenet import AttentionBaseNet
 from .attn_sleep import AttnSleep
 from .base import EEGModuleMixin
+from .bendr import BENDR
 from .biot import BIOT
 from .contrawr import ContraWR
 from .ctnet import CTNet

@@ -58,6 +59,7 @@ __all__ = [
     "AttentionBaseNet",
     "EEGModuleMixin",
     "BIOT",
+    "BENDR",
     "ContraWR",
     "CTNet",
     "Deep4Net",
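With these two lines, the new BENDR model (braindecode/models/bendr.py, +469 lines in the file list above) becomes importable from the package root. A hedged instantiation sketch; the constructor is assumed to follow the EEGModuleMixin convention shared by the other models:

    from braindecode.models import BENDR  # newly exported in this release

    # n_chans / n_outputs / n_times follow the EEGModuleMixin convention;
    # BENDR-specific kwargs are not visible in this part of the diff.
    model = BENDR(n_chans=22, n_outputs=4, n_times=1000)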
braindecode/models/atcnet.py
CHANGED
@@ -50,7 +50,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
     - **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_t, 1)`` builds a
       FIR-like filter bank (``F1`` maps).
     - **Depthwise spatial conv** (:class:`torch.nn.Conv2d`, ``groups=F1``) with kernel
-      ``(1, n_chans)`` learns per-filter spatial projections (akin to EEGNet
+      ``(1, n_chans)`` learns per-filter spatial projections (akin to EEGNet's CSP-like step).
     - **BN → ELU → AvgPool → Dropout** to stabilize and condense activations.
     - **Refining temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_r, 1)`` +
       **BN → ELU → AvgPool → Dropout**.
@@ -62,13 +62,15 @@ class ATCNet(EEGModuleMixin, nn.Module):
 
     - **Sliding-Window Sequencer**
 
-
-
-
-
+      From the condensed time axis (length ``T_c``), ATCNet forms ``n`` overlapping windows
+      of width ``T_w = T_c - n + 1`` (one start per index). Each window produces a sequence
+      ``(B, F2, T_w)`` forwarded to its own attention-TCN branch. This creates *parallel*
+      encoders over shifted contexts and is key to robustness on nonstationary EEG.
 
     - :class:`_AttentionBlock` **(small MHA on temporal positions)**
 
+      Attention here is *local to a window* and purely temporal.
+
       - *Operations.*
         - Rearrange to ``(B, T_w, F2)``,
         - Normalization :class:`torch.nn.LayerNorm`
@@ -76,11 +78,8 @@
         - Dropout :class:`torch.nn.Dropout`
         - Rearrange back to ``(B, F2, T_w)``.
 
-
-
-
-      *Role.* Re-weights evidence across the window, letting the model emphasize informative
-      segments (onsets, bursts) before causal convolutions aggregate history.
+      *Role.* Re-weights evidence across the window, letting the model emphasize informative
+      segments (onsets, bursts) before causal convolutions aggregate history.
 
     - :class:`_TCNResidualBlock` **(causal dilated temporal CNN)**
 
@@ -90,8 +89,8 @@
       a residual (identity or 1x1 mapping).
     - The final feature used per window is the *last* causal step ``[..., -1]`` (forecast-style).
 
-
-
+      *Role.* Efficient long-range temporal integration with stable gradients; the dilated
+      receptive field complements attention's soft selection.
 
     - **Aggregation & Classifier**
 
@@ -104,16 +103,16 @@
     .. rubric:: Convolutional Details
 
     - **Temporal.** Temporal structure is learned in three places:
-      - (1) the stem
+      - (1) the stem's wide ``(L_t, 1)`` conv (learned filter bank),
       - (2) the refining ``(L_r, 1)`` conv after pooling (short-term dynamics), and
-      - (3) the TCN
+      - (3) the TCN's causal 1-D convolutions with exponentially increasing dilation
       (long-range dependencies). The minimum sequence length required by the TCN stack is
       ``(K_t - 1)·2^{L-1} + 1``; the implementation *auto-scales* kernels/pools/windows
       when inputs are shorter to preserve feasibility.
 
     - **Spatial.** A depthwise spatial conv spans the **full montage** (kernel ``(1, n_chans)``),
       producing *per-temporal-filter* spatial projections (no cross-filter mixing at this step).
-      This mirrors EEGNet
+      This mirrors EEGNet's interpretability: each temporal filter has its own spatial pattern.
 
 
     .. rubric:: Attention / Sequential Modules
@@ -137,17 +136,17 @@
 
     .. rubric:: Usage and Configuration
 
-
-
-
-
-
-
-
-
-
-
-
+    - ``conv_block_n_filters (F1)``, ``conv_block_depth_mult (D)`` → capacity of the stem
+      (with ``F2 = F1·D`` feeding attention/TCN), dimensions aligned to ``F2``, like :class:`EEGNet`.
+    - Pool sizes ``P1,P2`` trade temporal resolution for stability/compute; they set
+      ``T_c = T/(P1·P2)`` and thus window width ``T_w``.
+    - ``n_windows`` controls the ensemble over shifts (compute ∝ windows).
+    - ``att_num_heads``, ``att_head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
+    - ``tcn_depth``, ``tcn_kernel_size`` govern receptive field; larger values demand
+      longer inputs (see minimum length above). The implementation warns and *rescales*
+      kernels/pools/windows if inputs are too short.
+    - **Aggregation choice.** ``concat=False`` (default, average of per-window logits) matches
+      the official code; ``concat=True`` mirrors the paper's concatenation variant.
 
 
     Notes
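A worked instance of the window arithmetic documented above (annotation, not part of the diff; the numbers are illustrative, not braindecode defaults):

    T, P1, P2 = 1125, 8, 7      # input length and the two pool sizes
    T_c = T // (P1 * P2)        # condensed length: 1125 // 56 = 20
    n = 5                       # number of sliding windows
    T_w = T_c - n + 1           # window width: 20 - 5 + 1 = 16
    print(T_c, T_w)             # -> 20 16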
braindecode/models/attentionbasenet.py
CHANGED

@@ -97,7 +97,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 
     - **Temporal (where time-domain patterns are learned).**
       Wide kernels in the stem (``(1, L_t)``) act as a learned filter bank for oscillatory
-      bands/transients; the attention block
+      bands/transients; the attention block's depthwise temporal conv (``(1, L_a)``) sharpens
       short-term dynamics after downsampling. Pool sizes/strides (``P₁,S₁`` then ``P₂,S₂``)
       set the token rate and effective temporal resolution.
 
@@ -127,23 +127,24 @@
 
     .. rubric:: Additional Mechanisms
 
-
-    - ``"se"``: Squeeze-and-Excitation (global pooling → bottleneck → gates).
-    - ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
-    - ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and ``freq_idx``).
-    - ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
-    - ``"eca"``: Efficient Channel Attention (local 1-D conv over channel descriptor; uses ``kernel_size``).
-    - ``"ge"``: Gather–Excite (context pooling with optional MLP; can use ``extra_params``).
-    - ``"gct"``: Gated Channel Transformation (global context normalization + gating).
-    - ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
-    - ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
-    - ``"cat"`` / ``"catlite"``: Collaborative (channel ± temporal) attention; *lite* omits temporal.
-    - **Auto-compatibility on short inputs.**
+    **Attention variants at a glance:**
 
-
-
-
+    - ``"se"``: Squeeze-and-Excitation (global pooling → bottleneck → gates).
+    - ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
+    - ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and ``freq_idx``).
+    - ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
+    - ``"eca"``: Efficient Channel Attention (local 1-D conv over channel descriptor; uses ``kernel_size``).
+    - ``"ge"``: Gather–Excite (context pooling with optional MLP; can use ``extra_params``).
+    - ``"gct"``: Gated Channel Transformation (global context normalization + gating).
+    - ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
+    - ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
+    - ``"cat"`` / ``"catlite"``: Collaborative (channel ± temporal) attention; *lite* omits temporal.
 
+    **Auto-compatibility on short inputs:**
+
+    If the input duration is too short for the configured kernels/pools, the implementation
+    **automatically rescales** temporal lengths/strides downward (with a warning) to keep
+    shapes valid and preserve the pipeline semantics.
 
     .. rubric:: Usage and Configuration
 
@@ -158,9 +159,9 @@
     - ``drop_prob_inp`` and ``drop_prob_attn``: regularize stem and attention stages.
     - **Training tips.**
 
-
-
-
+      Start with moderate pooling (e.g., ``P₁=75,S₁=15``) and ELU activations; enable attention
+      only after the stem learns stable filters. For small datasets, prefer simpler modes
+      (``"se"``, ``"eca"``) before heavier ones (``"gsop"``, ``"encnet"``).
 
     Notes
     -----
@@ -170,6 +171,7 @@
       specific variants (CBAM/CAT).
     - The paper and original code with more details about the methodological
       choices are available at the [Martin2023]_ and [MartinCode]_.
+
     .. versionadded:: 0.9
 
     Parameters
@@ -198,19 +200,21 @@
         the depth of the network after the initial layer. Default is 16.
     attention_mode : str, optional
         The type of attention mechanism to apply. If `None`, no attention is applied.
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        - "se" for Squeeze-and-excitation network
+        - "gsop" for Global Second-Order Pooling
+        - "fca" for Frequency Channel Attention Network
+        - "encnet" for context encoding module
+        - "eca" for Efficient channel attention for deep convolutional neural networks
+        - "ge" for Gather-Excite
+        - "gct" for Gated Channel Transformation
+        - "srm" for Style-based Recalibration Module
+        - "cbam" for Convolutional Block Attention Module
+        - "cat" for Learning to collaborate channel and temporal attention
+          from multi-information fusion
+        - "catlite" for Learning to collaborate channel attention
+          from multi-information fusion (lite version, cat w/o temporal attention)
+
     pool_length : int, default=8
         The length of the window for the average pooling operation.
     pool_stride : int, default=8
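A hedged instantiation sketch using one of the documented ``attention_mode`` values (annotation, not part of the diff; defaults not shown here are left untouched):

    from braindecode.models import AttentionBaseNet

    model = AttentionBaseNet(
        n_chans=22, n_outputs=4, n_times=1000,  # EEGModuleMixin convention
        attention_mode="se",                    # one of the modes listed above
    )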
@@ -381,6 +385,8 @@
         for k, pl, ps in zip(kernel_lengths, pool_lengths, pool_strides):
             out = math.floor(out + 2 * (k // 2) - k + 1)
             out = math.floor((out - pl) / ps + 1)
+            # Ensure output is at least 1 to avoid zero-sized tensors
+            out = max(1, out)
             seq_lengths.append(int(out))
         return seq_lengths
 
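Restated outside the class for illustration, the loop above computes the post-block sequence lengths; the new ``max(1, out)`` guard keeps short inputs from producing zero-length outputs (the helper name is hypothetical):

    import math

    def _seq_lengths(n_times, kernel_lengths, pool_lengths, pool_strides):
        out, seq_lengths = n_times, []
        for k, pl, ps in zip(kernel_lengths, pool_lengths, pool_strides):
            out = math.floor(out + 2 * (k // 2) - k + 1)   # "same" temporal conv
            out = math.floor((out - pl) / ps + 1)          # average pooling
            out = max(1, out)                              # guard added in this release
            seq_lengths.append(int(out))
        return seq_lengths

    print(_seq_lengths(50, [64], [75], [15]))  # [1] instead of a value <= 0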
@@ -497,6 +503,7 @@ class _ChannelAttentionBlock(nn.Module):
     ----------
     attention_mode : str, optional
         The type of attention mechanism to apply. If `None`, no attention is applied.
+
         - "se" for Squeeze-and-excitation network
         - "gsop" for Global Second-Order Pooling
         - "fca" for Frequency Channel Attention Network