braindecode 1.3.0.dev180329405__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. braindecode/augmentation/base.py +1 -1
  2. braindecode/datasets/__init__.py +12 -4
  3. braindecode/datasets/base.py +115 -151
  4. braindecode/datasets/bcicomp.py +4 -4
  5. braindecode/datasets/bids.py +3 -3
  6. braindecode/datasets/experimental.py +2 -2
  7. braindecode/datasets/mne.py +3 -5
  8. braindecode/datasets/moabb.py +17 -7
  9. braindecode/datasets/nmt.py +2 -2
  10. braindecode/datasets/sleep_physio_challe_18.py +2 -2
  11. braindecode/datasets/sleep_physionet.py +2 -2
  12. braindecode/datasets/tuh.py +2 -2
  13. braindecode/datasets/xy.py +2 -2
  14. braindecode/datautil/__init__.py +11 -1
  15. braindecode/datautil/channel_utils.py +114 -0
  16. braindecode/datautil/serialization.py +7 -7
  17. braindecode/functional/functions.py +6 -2
  18. braindecode/functional/initialization.py +2 -3
  19. braindecode/models/__init__.py +6 -0
  20. braindecode/models/atcnet.py +26 -27
  21. braindecode/models/attentionbasenet.py +37 -32
  22. braindecode/models/attn_sleep.py +2 -0
  23. braindecode/models/base.py +280 -2
  24. braindecode/models/bendr.py +469 -0
  25. braindecode/models/biot.py +2 -0
  26. braindecode/models/contrawr.py +2 -0
  27. braindecode/models/ctnet.py +8 -3
  28. braindecode/models/deepsleepnet.py +28 -19
  29. braindecode/models/eegconformer.py +2 -2
  30. braindecode/models/eeginception_erp.py +31 -25
  31. braindecode/models/eegitnet.py +2 -0
  32. braindecode/models/eegminer.py +2 -0
  33. braindecode/models/eegnet.py +1 -1
  34. braindecode/models/eegsym.py +917 -0
  35. braindecode/models/eegtcnet.py +2 -0
  36. braindecode/models/fbcnet.py +5 -1
  37. braindecode/models/fblightconvnet.py +2 -0
  38. braindecode/models/fbmsnet.py +20 -6
  39. braindecode/models/ifnet.py +2 -0
  40. braindecode/models/labram.py +33 -26
  41. braindecode/models/medformer.py +758 -0
  42. braindecode/models/msvtnet.py +2 -0
  43. braindecode/models/patchedtransformer.py +1 -1
  44. braindecode/models/signal_jepa.py +111 -27
  45. braindecode/models/sinc_shallow.py +12 -9
  46. braindecode/models/sstdpn.py +11 -11
  47. braindecode/models/summary.csv +3 -0
  48. braindecode/models/syncnet.py +2 -0
  49. braindecode/models/tcn.py +2 -0
  50. braindecode/models/usleep.py +26 -21
  51. braindecode/models/util.py +3 -0
  52. braindecode/modules/attention.py +10 -10
  53. braindecode/modules/blocks.py +3 -3
  54. braindecode/modules/filter.py +2 -9
  55. braindecode/modules/layers.py +18 -17
  56. braindecode/preprocessing/__init__.py +232 -3
  57. braindecode/preprocessing/eegprep_preprocess.py +1202 -0
  58. braindecode/preprocessing/mne_preprocess.py +142 -10
  59. braindecode/preprocessing/preprocess.py +28 -18
  60. braindecode/preprocessing/util.py +166 -0
  61. braindecode/preprocessing/windowers.py +26 -20
  62. braindecode/samplers/base.py +8 -8
  63. braindecode/version.py +1 -1
  64. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/METADATA +6 -2
  65. braindecode-1.3.0.dev182330353.dist-info/RECORD +109 -0
  66. braindecode-1.3.0.dev180329405.dist-info/RECORD +0 -103
  67. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/WHEEL +0 -0
  68. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/LICENSE.txt +0 -0
  69. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/NOTICE.txt +0 -0
  70. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/top_level.txt +0 -0
@@ -12,7 +12,7 @@ import numpy as np
12
12
  import pandas as pd
13
13
  from numpy.typing import ArrayLike, NDArray
14
14
 
15
- from .base import BaseConcatDataset, BaseDataset
15
+ from .base import BaseConcatDataset, RawDataset
16
16
 
17
17
  log = logging.getLogger(__name__)
18
18
 
@@ -69,7 +69,7 @@ def create_from_X_y(
69
69
  n_samples_per_x.append(x.shape[1])
70
70
  info = mne.create_info(ch_names=ch_names, sfreq=sfreq)
71
71
  raw = mne.io.RawArray(x, info)
72
- base_dataset = BaseDataset(
72
+ base_dataset = RawDataset(
73
73
  raw, pd.Series({"target": target}), target_name="target"
74
74
  )
75
75
  base_datasets.append(base_dataset)
@@ -2,6 +2,10 @@
2
2
  Utilities for data manipulation.
3
3
  """
4
4
 
5
+ from .channel_utils import (
6
+ division_channels_idx,
7
+ match_hemisphere_chans,
8
+ )
5
9
  from .serialization import (
6
10
  _check_save_dir_empty,
7
11
  load_concat_dataset,
@@ -49,4 +53,10 @@ def __getattr__(name):
49
53
  raise AttributeError("No possible import named " + name)
50
54
 
51
55
 
52
- __all__ = ["load_concat_dataset", "save_concat_dataset", "_check_save_dir_empty"]
56
+ __all__ = [
57
+ "load_concat_dataset",
58
+ "save_concat_dataset",
59
+ "_check_save_dir_empty",
60
+ "match_hemisphere_chans",
61
+ "division_channels_idx",
62
+ ]
@@ -0,0 +1,114 @@
1
+ """
2
+ Utilities for EEG channel manipulation and selection.
3
+
4
+ This module provides functions for dividing and matching EEG channels,
5
+ particularly for hemisphere-aware processing.
6
+ """
7
+
8
+ import re
9
+ from re import search
10
+
11
+
12
def match_hemisphere_chans(left_chs, right_chs):
    """
    Match channels of the left and right hemispheres based on their names.

    For every left channel, the expected right-hemisphere name is built by
    taking the first number found in the left name, adding one, and writing
    that value in place of every digit run in the name (e.g. ``'C3'`` ->
    ``'C4'``). The expected name is then looked up in ``right_chs``; the
    order of ``right_chs`` does not matter.

    Parameters
    ----------
    left_chs : list of str
        A list of channel names from the left hemisphere.
    right_chs : list of str
        A list of channel names from the right hemisphere. The input is
        copied, so the caller's sequence is never modified.

    Returns
    -------
    list of tuples
        List of tuples with matched channel names from the left and right
        hemispheres. Each tuple contains (left_channel, right_channel), in
        the order of ``left_chs``.

    Raises
    ------
    ValueError
        If the left and right channels do not match in length.
    ValueError
        If a left channel name does not contain a number.
    ValueError
        If no matching right hemisphere channel is found for a left channel.

    Examples
    --------
    >>> left = ['C3', 'F3']
    >>> right = ['F4', 'C4']
    >>> match_hemisphere_chans(left, right)
    [('C3', 'C4'), ('F3', 'F4')]
    """
    if len(left_chs) != len(right_chs):
        raise ValueError("Left and right channels do not match.")

    number_pattern = r"\d+"
    # Work on a private copy: every matched right channel is consumed so it
    # cannot be paired with a second left channel.
    unmatched_right = list(right_chs)
    pairs = []
    for left in left_chs:
        found = re.search(number_pattern, left)
        if found is None:
            raise ValueError(f"Channel '{left}' does not contain a number.")
        # The right-hemisphere partner carries the next integer (3 -> 4).
        expected_right = re.sub(number_pattern, str(int(found.group()) + 1), left)
        try:
            position = unmatched_right.index(expected_right)
        except ValueError:
            raise ValueError(
                "Found no right hemisphere matching channel for "
                f"'{left}'."
            ) from None
        pairs.append((left, unmatched_right.pop(position)))
    return pairs
70
+
71
+
72
def division_channels_idx(ch_names):
    """
    Divide EEG channel names into left, right, and middle based on numbering.

    Each channel is classified by the first number found in its name:

    - Odd number → left hemisphere
    - Even number → right hemisphere
    - No number → middle/midline

    Parameters
    ----------
    ch_names : list of str
        A list of EEG channel names to be divided based on their numbering.

    Returns
    -------
    tuple of lists
        Three lists containing the channel names, each preserving the
        input order:

        - left: Odd-numbered channels (e.g., C3, F3, P3)
        - right: Even-numbered channels (e.g., C4, F4, P4)
        - middle: Channels without numbers (e.g., Cz, Fz, Pz)

    Notes
    -----
    Only the first run of digits in a name is inspected; it does not have
    to be a suffix. Standard 10-20 system EEG channel naming conventions
    use odd numbers for left hemisphere and even numbers for right
    hemisphere.

    Examples
    --------
    >>> channels = ['FP1', 'FP2', 'O1', 'O2', 'FZ']
    >>> division_channels_idx(channels)
    (['FP1', 'O1'], ['FP2', 'O2'], ['FZ'])
    """
    left, right, middle = [], [], []
    for ch in ch_names:
        number = re.search(r"\d+", ch)
        if number is None:
            # No digits at all: treated as a midline channel.
            middle.append(ch)
        elif int(number.group()) % 2:
            left.append(ch)
        else:
            right.append(ch)

    return left, right, middle
@@ -19,8 +19,8 @@ from joblib import Parallel, delayed
19
19
 
20
20
  from ..datasets.base import (
21
21
  BaseConcatDataset,
22
- BaseDataset,
23
22
  EEGWindowsDataset,
23
+ RawDataset,
24
24
  WindowsDataset,
25
25
  )
26
26
 
@@ -35,7 +35,7 @@ def save_concat_dataset(path, concat_dataset, overwrite=False):
35
35
 
36
36
 
37
37
  def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
38
- """Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
38
+ """Load a stored BaseConcatDataset from
39
39
  files.
40
40
 
41
41
  Parameters
@@ -52,7 +52,7 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=N
52
52
 
53
53
  Returns
54
54
  -------
55
- concat_dataset: BaseConcatDataset of BaseDatasets or WindowsDatasets
55
+ concat_dataset: BaseConcatDataset
56
56
  """
57
57
  # assume we have a single concat dataset to load
58
58
  is_raw = (path / "0-raw.fif").is_file()
@@ -87,7 +87,7 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=N
87
87
  for i_signal, signal in enumerate(all_signals):
88
88
  if is_raw:
89
89
  datasets.append(
90
- BaseDataset(
90
+ RawDataset(
91
91
  signal, description.iloc[i_signal], target_name=target_name
92
92
  )
93
93
  )
@@ -175,7 +175,7 @@ def _load_signals(fif_file, preload, is_raw):
175
175
 
176
176
 
177
177
  def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_jobs=1):
178
- """Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
178
+ """Load a stored BaseConcatDataset from
179
179
  files.
180
180
 
181
181
  Parameters
@@ -194,7 +194,7 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_job
194
194
 
195
195
  Returns
196
196
  -------
197
- concat_dataset: BaseConcatDataset of BaseDatasets or WindowsDatasets
197
+ concat_dataset: BaseConcatDataset
198
198
  """
199
199
  # Make sure we always work with a pathlib.Path
200
200
  path = Path(path)
@@ -266,7 +266,7 @@ def _load_parallel(path, i, preload, is_raw, has_stored_windows):
266
266
  target_name = json.load(open(target_file_path, "r"))["target_name"]
267
267
 
268
268
  if is_raw and (not has_stored_windows):
269
- dataset = BaseDataset(signals, description, target_name)
269
+ dataset = RawDataset(signals, description, target_name)
270
270
  else:
271
271
  window_kwargs = _load_kwargs_json("window_kwargs", sub_dir)
272
272
  windows_ds_kwargs = [
@@ -181,20 +181,24 @@ def plv_time(x, forward_fourier=True, epsilon: float = 1e-6):
181
181
  The Phase Locking Value (PLV) is a measure of the synchronization between
182
182
  different channels by evaluating the consistency of phase differences
183
183
  over time. It ranges from 0 (no synchronization) to 1 (perfect
184
- synchronization) [1]_.
184
+ synchronization) [Lachaux1999]_.
185
185
 
186
186
  Parameters
187
187
  ----------
188
188
  x : torch.Tensor
189
189
  Input tensor containing the signal data.
190
+
190
191
  - If `forward_fourier` is `True`, the shape should be `(..., channels, time)`.
191
192
  - If `forward_fourier` is `False`, the shape should be `(..., channels, freqs, 2)`,
192
193
  where the last dimension represents the real and imaginary parts.
194
+
193
195
  forward_fourier : bool, optional
194
196
  Specifies the format of the input tensor `x`.
197
+
195
198
  - If `True`, `x` is assumed to be in the time domain.
196
199
  - If `False`, `x` is assumed to be in the Fourier domain with separate real and
197
200
  imaginary components.
201
+
198
202
  Default is `True`.
199
203
  epsilon : float, default 1e-6
200
204
  Small numerical value to ensure positivity constraint on the complex part
@@ -207,7 +211,7 @@ def plv_time(x, forward_fourier=True, epsilon: float = 1e-6):
207
211
 
208
212
  References
209
213
  ----------
210
- [1] Lachaux, J. P., Rodriguez, E., Martinerie, J., & Varela, F. J. (1999).
214
+ .. [Lachaux1999] Lachaux, J. P., Rodriguez, E., Martinerie, J., & Varela, F. J. (1999).
211
215
  Measuring phase synchrony in brain signals. Human brain mapping,
212
216
  8(4), 194-208.
213
217
  """
@@ -5,9 +5,8 @@ from torch import nn
5
5
 
6
6
  def glorot_weight_zero_bias(model):
7
7
  """Initialize parameters of all modules by initializing weights with
8
- glorot
9
- uniform/xavier initialization, and setting biases to zero. Weights from
10
- batch norm layers are set to 1.
8
+ glorot uniform/xavier initialization, and setting biases to zero. Weights from
9
+ batch norm layers are set to 1.
11
10
 
12
11
  Parameters
13
12
  ----------
@@ -6,6 +6,7 @@ from .atcnet import ATCNet
6
6
  from .attentionbasenet import AttentionBaseNet
7
7
  from .attn_sleep import AttnSleep
8
8
  from .base import EEGModuleMixin
9
+ from .bendr import BENDR
9
10
  from .biot import BIOT
10
11
  from .contrawr import ContraWR
11
12
  from .ctnet import CTNet
@@ -19,6 +20,7 @@ from .eegminer import EEGMiner
19
20
  from .eegnet import EEGNet, EEGNetv4
20
21
  from .eegnex import EEGNeX
21
22
  from .eegsimpleconv import EEGSimpleConv
23
+ from .eegsym import EEGSym
22
24
  from .eegtcnet import EEGTCNet
23
25
  from .fbcnet import FBCNet
24
26
  from .fblightconvnet import FBLightConvNet
@@ -26,6 +28,7 @@ from .fbmsnet import FBMSNet
26
28
  from .hybrid import HybridNet
27
29
  from .ifnet import IFNet
28
30
  from .labram import Labram
31
+ from .medformer import MEDFormer
29
32
  from .msvtnet import MSVTNet
30
33
  from .patchedtransformer import PBT
31
34
  from .sccnet import SCCNet
@@ -58,6 +61,7 @@ __all__ = [
58
61
  "AttentionBaseNet",
59
62
  "EEGModuleMixin",
60
63
  "BIOT",
64
+ "BENDR",
61
65
  "ContraWR",
62
66
  "CTNet",
63
67
  "Deep4Net",
@@ -70,6 +74,7 @@ __all__ = [
70
74
  "EEGNet",
71
75
  "EEGNetv4",
72
76
  "EEGNeX",
77
+ "EEGSym",
73
78
  "EEGSimpleConv",
74
79
  "EEGTCNet",
75
80
  "FBCNet",
@@ -78,6 +83,7 @@ __all__ = [
78
83
  "HybridNet",
79
84
  "IFNet",
80
85
  "Labram",
86
+ "MEDFormer",
81
87
  "MSVTNet",
82
88
  "PBT",
83
89
  "SCCNet",
@@ -15,7 +15,7 @@ from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear
15
15
  class ATCNet(EEGModuleMixin, nn.Module):
16
16
  """ATCNet from Altaheri et al. (2022) [1]_.
17
17
 
18
- :bdg-success:`Convolution` :bdg-info:`Small Attention`
18
+ :bdg-success:`Convolution` :bdg-secondary:`Recurrent` :bdg-info:`Small Attention`
19
19
 
20
20
  .. figure:: https://user-images.githubusercontent.com/25565236/185449791-e8539453-d4fa-41e1-865a-2cf7e91f60ef.png
21
21
  :align: center
@@ -50,7 +50,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
50
50
  - **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_t, 1)`` builds a
51
51
  FIR-like filter bank (``F1`` maps).
52
52
  - **Depthwise spatial conv** (:class:`torch.nn.Conv2d`, ``groups=F1``) with kernel
53
- ``(1, n_chans)`` learns per-filter spatial projections (akin to EEGNets CSP-like step).
53
+ ``(1, n_chans)`` learns per-filter spatial projections (akin to EEGNet's CSP-like step).
54
54
  - **BN → ELU → AvgPool → Dropout** to stabilize and condense activations.
55
55
  - **Refining temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_r, 1)`` +
56
56
  **BN → ELU → AvgPool → Dropout**.
@@ -62,13 +62,15 @@ class ATCNet(EEGModuleMixin, nn.Module):
62
62
 
63
63
  - **Sliding-Window Sequencer**
64
64
 
65
- From the condensed time axis (length ``T_c``), ATCNet forms ``n`` overlapping windows
66
- of width ``T_w = T_c - n + 1`` (one start per index). Each window produces a sequence
67
- ``(B, F2, T_w)`` forwarded to its own attentionTCN branch. This creates *parallel*
68
- encoders over shifted contexts and is key to robustness on nonstationary EEG.
65
+ From the condensed time axis (length ``T_c``), ATCNet forms ``n`` overlapping windows
66
+ of width ``T_w = T_c - n + 1`` (one start per index). Each window produces a sequence
67
+ ``(B, F2, T_w)`` forwarded to its own attention-TCN branch. This creates *parallel*
68
+ encoders over shifted contexts and is key to robustness on nonstationary EEG.
69
69
 
70
70
  - :class:`_AttentionBlock` **(small MHA on temporal positions)**
71
71
 
72
+ Attention here is *local to a window* and purely temporal.
73
+
72
74
  - *Operations.*
73
75
  - Rearrange to ``(B, T_w, F2)``,
74
76
  - Normalization :class:`torch.nn.LayerNorm`
@@ -76,11 +78,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
76
78
  - Dropout :class:`torch.nn.Dropout`
77
79
  - Rearrange back to ``(B, F2, T_w)``.
78
80
 
79
-
80
- **Note**: Attention is *local to a window* and purely temporal.
81
-
82
- *Role.* Re-weights evidence across the window, letting the model emphasize informative
83
- segments (onsets, bursts) before causal convolutions aggregate history.
81
+ *Role.* Re-weights evidence across the window, letting the model emphasize informative
82
+ segments (onsets, bursts) before causal convolutions aggregate history.
84
83
 
85
84
  - :class:`_TCNResidualBlock` **(causal dilated temporal CNN)**
86
85
 
@@ -90,8 +89,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
90
89
  a residual (identity or 1x1 mapping).
91
90
  - The final feature used per window is the *last* causal step ``[..., -1]`` (forecast-style).
92
91
 
93
- *Role.* Efficient long-range temporal integration with stable gradients; the dilated
94
- receptive field complements attentions soft selection.
92
+ *Role.* Efficient long-range temporal integration with stable gradients; the dilated
93
+ receptive field complements attention's soft selection.
95
94
 
96
95
  - **Aggregation & Classifier**
97
96
 
@@ -104,16 +103,16 @@ class ATCNet(EEGModuleMixin, nn.Module):
104
103
  .. rubric:: Convolutional Details
105
104
 
106
105
  - **Temporal.** Temporal structure is learned in three places:
107
- - (1) the stems wide ``(L_t, 1)`` conv (learned filter bank),
106
+ - (1) the stem's wide ``(L_t, 1)`` conv (learned filter bank),
108
107
  - (2) the refining ``(L_r, 1)`` conv after pooling (short-term dynamics), and
109
- - (3) the TCNs causal 1-D convolutions with exponentially increasing dilation
108
+ - (3) the TCN's causal 1-D convolutions with exponentially increasing dilation
110
109
  (long-range dependencies). The minimum sequence length required by the TCN stack is
111
110
  ``(K_t - 1)·2^{L-1} + 1``; the implementation *auto-scales* kernels/pools/windows
112
111
  when inputs are shorter to preserve feasibility.
113
112
 
114
113
  - **Spatial.** A depthwise spatial conv spans the **full montage** (kernel ``(1, n_chans)``),
115
114
  producing *per-temporal-filter* spatial projections (no cross-filter mixing at this step).
116
- This mirrors EEGNets interpretability: each temporal filter has its own spatial pattern.
115
+ This mirrors EEGNet's interpretability: each temporal filter has its own spatial pattern.
117
116
 
118
117
 
119
118
  .. rubric:: Attention / Sequential Modules
@@ -137,17 +136,17 @@ class ATCNet(EEGModuleMixin, nn.Module):
137
136
 
138
137
  .. rubric:: Usage and Configuration
139
138
 
140
- - ``conv_block_n_filters (F1)``, ``conv_block_depth_mult (D)`` → capacity of the stem
141
- (with ``F2 = F1·D`` feeding attention/TCN), dimensions aligned to ``F2``, like :class:`EEGNet`.
142
- - Pool sizes ``P1,P2`` trade temporal resolution for stability/compute; they set
143
- ``T_c = T/(P1·P2)`` and thus window width ``T_w``.
144
- - ``n_windows`` controls the ensemble over shifts (compute ∝ windows).
145
- - ``att_num_heads``, ``att_head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
146
- - ``tcn_depth``, ``tcn_kernel_size`` govern receptive field; larger values demand
147
- longer inputs (see minimum length above). The implementation warns and *rescales*
148
- kernels/pools/windows if inputs are too short.
149
- - **Aggregation choice.** ``concat=False`` (default, average of per-window logits) matches
150
- the official code; ``concat=True`` mirrors the papers concatenation variant.
139
+ - ``conv_block_n_filters (F1)``, ``conv_block_depth_mult (D)`` → capacity of the stem
140
+ (with ``F2 = F1·D`` feeding attention/TCN), dimensions aligned to ``F2``, like :class:`EEGNet`.
141
+ - Pool sizes ``P1,P2`` trade temporal resolution for stability/compute; they set
142
+ ``T_c = T/(P1·P2)`` and thus window width ``T_w``.
143
+ - ``n_windows`` controls the ensemble over shifts (compute ∝ windows).
144
+ - ``att_num_heads``, ``att_head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
145
+ - ``tcn_depth``, ``tcn_kernel_size`` govern receptive field; larger values demand
146
+ longer inputs (see minimum length above). The implementation warns and *rescales*
147
+ kernels/pools/windows if inputs are too short.
148
+ - **Aggregation choice.** ``concat=False`` (default, average of per-window logits) matches
149
+ the official code; ``concat=True`` mirrors the paper's concatenation variant.
151
150
 
152
151
 
153
152
  Notes
@@ -97,7 +97,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
97
97
 
98
98
  - **Temporal (where time-domain patterns are learned).**
99
99
  Wide kernels in the stem (``(1, L_t)``) act as a learned filter bank for oscillatory
100
- bands/transients; the attention blocks depthwise temporal conv (``(1, L_a)``) sharpens
100
+ bands/transients; the attention block's depthwise temporal conv (``(1, L_a)``) sharpens
101
101
  short-term dynamics after downsampling. Pool sizes/strides (``P₁,S₁`` then ``P₂,S₂``)
102
102
  set the token rate and effective temporal resolution.
103
103
 
@@ -127,23 +127,24 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
127
127
 
128
128
  .. rubric:: Additional Mechanisms
129
129
 
130
- - **Attention variants at a glance.**
131
- - ``"se"``: Squeeze-and-Excitation (global pooling → bottleneck → gates).
132
- - ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
133
- - ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and ``freq_idx``).
134
- - ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
135
- - ``"eca"``: Efficient Channel Attention (local 1-D conv over channel descriptor; uses ``kernel_size``).
136
- - ``"ge"``: Gather–Excite (context pooling with optional MLP; can use ``extra_params``).
137
- - ``"gct"``: Gated Channel Transformation (global context normalization + gating).
138
- - ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
139
- - ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
140
- - ``"cat"`` / ``"catlite"``: Collaborative (channel ± temporal) attention; *lite* omits temporal.
141
- - **Auto-compatibility on short inputs.**
130
+ **Attention variants at a glance:**
142
131
 
143
- If the input duration is too short for the configured kernels/pools, the implementation
144
- **automatically rescales** temporal lengths/strides downward (with a warning) to keep
145
- shapes valid and preserve the pipeline semantics.
132
+ - ``"se"``: Squeeze-and-Excitation (global pooling bottleneck gates).
133
+ - ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
134
+ - ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and ``freq_idx``).
135
+ - ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
136
+ - ``"eca"``: Efficient Channel Attention (local 1-D conv over channel descriptor; uses ``kernel_size``).
137
+ - ``"ge"``: Gather–Excite (context pooling with optional MLP; can use ``extra_params``).
138
+ - ``"gct"``: Gated Channel Transformation (global context normalization + gating).
139
+ - ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
140
+ - ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
141
+ - ``"cat"`` / ``"catlite"``: Collaborative (channel ± temporal) attention; *lite* omits temporal.
146
142
 
143
+ **Auto-compatibility on short inputs:**
144
+
145
+ If the input duration is too short for the configured kernels/pools, the implementation
146
+ **automatically rescales** temporal lengths/strides downward (with a warning) to keep
147
+ shapes valid and preserve the pipeline semantics.
147
148
 
148
149
  .. rubric:: Usage and Configuration
149
150
 
@@ -158,9 +159,9 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
158
159
  - ``drop_prob_inp`` and ``drop_prob_attn``: regularize stem and attention stages.
159
160
  - **Training tips.**
160
161
 
161
- Start with moderate pooling (e.g., ``P₁=75,S₁=15``) and ELU activations; enable attention
162
- only after the stem learns stable filters. For small datasets, prefer simpler modes
163
- (``"se"``, ``"eca"``) before heavier ones (``"gsop"``, ``"encnet"``).
162
+ Start with moderate pooling (e.g., ``P₁=75,S₁=15``) and ELU activations; enable attention
163
+ only after the stem learns stable filters. For small datasets, prefer simpler modes
164
+ (``"se"``, ``"eca"``) before heavier ones (``"gsop"``, ``"encnet"``).
164
165
 
165
166
  Notes
166
167
  -----
@@ -170,6 +171,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
170
171
  specific variants (CBAM/CAT).
171
172
  - The paper and original code with more details about the methodological
172
173
  choices are available at the [Martin2023]_ and [MartinCode]_.
174
+
173
175
  .. versionadded:: 0.9
174
176
 
175
177
  Parameters
@@ -198,19 +200,21 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
198
200
  the depth of the network after the initial layer. Default is 16.
199
201
  attention_mode : str, optional
200
202
  The type of attention mechanism to apply. If `None`, no attention is applied.
201
- - "se" for Squeeze-and-excitation network
202
- - "gsop" for Global Second-Order Pooling
203
- - "fca" for Frequency Channel Attention Network
204
- - "encnet" for context encoding module
205
- - "eca" for Efficient channel attention for deep convolutional neural networks
206
- - "ge" for Gather-Excite
207
- - "gct" for Gated Channel Transformation
208
- - "srm" for Style-based Recalibration Module
209
- - "cbam" for Convolutional Block Attention Module
210
- - "cat" for Learning to collaborate channel and temporal attention
211
- from multi-information fusion
212
- - "catlite" for Learning to collaborate channel attention
213
- from multi-information fusion (lite version, cat w/o temporal attention)
203
+
204
+ - "se" for Squeeze-and-excitation network
205
+ - "gsop" for Global Second-Order Pooling
206
+ - "fca" for Frequency Channel Attention Network
207
+ - "encnet" for context encoding module
208
+ - "eca" for Efficient channel attention for deep convolutional neural networks
209
+ - "ge" for Gather-Excite
210
+ - "gct" for Gated Channel Transformation
211
+ - "srm" for Style-based Recalibration Module
212
+ - "cbam" for Convolutional Block Attention Module
213
+ - "cat" for Learning to collaborate channel and temporal attention
214
+ from multi-information fusion
215
+ - "catlite" for Learning to collaborate channel attention
216
+ from multi-information fusion (lite version, cat w/o temporal attention)
217
+
214
218
  pool_length : int, default=8
215
219
  The length of the window for the average pooling operation.
216
220
  pool_stride : int, default=8
@@ -499,6 +503,7 @@ class _ChannelAttentionBlock(nn.Module):
499
503
  ----------
500
504
  attention_mode : str, optional
501
505
  The type of attention mechanism to apply. If `None`, no attention is applied.
506
+
502
507
  - "se" for Squeeze-and-excitation network
503
508
  - "gsop" for Global Second-Order Pooling
504
509
  - "fca" for Frequency Channel Attention Network
@@ -18,6 +18,8 @@ from braindecode.modules import CausalConv1d
18
18
  class AttnSleep(EEGModuleMixin, nn.Module):
19
19
  """Sleep Staging Architecture from Eldele et al. (2021) [Eldele2021]_.
20
20
 
21
+ :bdg-success:`Convolution` :bdg-info:`Small Attention`
22
+
21
23
  .. figure:: https://raw.githubusercontent.com/emadeldeen24/AttnSleep/refs/heads/main/imgs/AttnSleep.png
22
24
  :align: center
23
25
  :alt: AttnSleep Architecture