braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (106)
  1. braindecode/augmentation/__init__.py +3 -5
  2. braindecode/augmentation/base.py +5 -8
  3. braindecode/augmentation/functional.py +22 -25
  4. braindecode/augmentation/transforms.py +42 -51
  5. braindecode/classifier.py +16 -11
  6. braindecode/datasets/__init__.py +3 -5
  7. braindecode/datasets/base.py +13 -17
  8. braindecode/datasets/bbci.py +14 -13
  9. braindecode/datasets/bcicomp.py +5 -4
  10. braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
  11. braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
  12. braindecode/datasets/{bids/hub.py → hub.py} +350 -375
  13. braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
  14. braindecode/datasets/mne.py +19 -19
  15. braindecode/datasets/moabb.py +10 -10
  16. braindecode/datasets/nmt.py +56 -58
  17. braindecode/datasets/sleep_physio_challe_18.py +5 -3
  18. braindecode/datasets/sleep_physionet.py +5 -5
  19. braindecode/datasets/tuh.py +18 -21
  20. braindecode/datasets/xy.py +9 -10
  21. braindecode/datautil/__init__.py +3 -3
  22. braindecode/datautil/serialization.py +20 -22
  23. braindecode/datautil/util.py +7 -120
  24. braindecode/eegneuralnet.py +52 -22
  25. braindecode/functional/functions.py +10 -7
  26. braindecode/functional/initialization.py +2 -3
  27. braindecode/models/__init__.py +3 -5
  28. braindecode/models/atcnet.py +39 -43
  29. braindecode/models/attentionbasenet.py +41 -37
  30. braindecode/models/attn_sleep.py +24 -26
  31. braindecode/models/base.py +6 -6
  32. braindecode/models/bendr.py +26 -50
  33. braindecode/models/biot.py +30 -61
  34. braindecode/models/contrawr.py +5 -5
  35. braindecode/models/ctnet.py +35 -35
  36. braindecode/models/deep4.py +5 -5
  37. braindecode/models/deepsleepnet.py +7 -7
  38. braindecode/models/eegconformer.py +26 -31
  39. braindecode/models/eeginception_erp.py +2 -2
  40. braindecode/models/eeginception_mi.py +6 -6
  41. braindecode/models/eegitnet.py +5 -5
  42. braindecode/models/eegminer.py +1 -1
  43. braindecode/models/eegnet.py +3 -3
  44. braindecode/models/eegnex.py +2 -2
  45. braindecode/models/eegsimpleconv.py +2 -2
  46. braindecode/models/eegsym.py +7 -7
  47. braindecode/models/eegtcnet.py +6 -6
  48. braindecode/models/fbcnet.py +2 -2
  49. braindecode/models/fblightconvnet.py +3 -3
  50. braindecode/models/fbmsnet.py +3 -3
  51. braindecode/models/hybrid.py +2 -2
  52. braindecode/models/ifnet.py +5 -5
  53. braindecode/models/labram.py +46 -70
  54. braindecode/models/luna.py +5 -60
  55. braindecode/models/medformer.py +21 -23
  56. braindecode/models/msvtnet.py +15 -15
  57. braindecode/models/patchedtransformer.py +55 -55
  58. braindecode/models/sccnet.py +2 -2
  59. braindecode/models/shallow_fbcsp.py +3 -5
  60. braindecode/models/signal_jepa.py +12 -39
  61. braindecode/models/sinc_shallow.py +4 -3
  62. braindecode/models/sleep_stager_blanco_2020.py +2 -2
  63. braindecode/models/sleep_stager_chambon_2018.py +2 -2
  64. braindecode/models/sparcnet.py +8 -8
  65. braindecode/models/sstdpn.py +869 -869
  66. braindecode/models/summary.csv +17 -19
  67. braindecode/models/syncnet.py +2 -2
  68. braindecode/models/tcn.py +5 -5
  69. braindecode/models/tidnet.py +3 -3
  70. braindecode/models/tsinception.py +3 -3
  71. braindecode/models/usleep.py +7 -7
  72. braindecode/models/util.py +14 -165
  73. braindecode/modules/__init__.py +1 -9
  74. braindecode/modules/activation.py +3 -29
  75. braindecode/modules/attention.py +0 -123
  76. braindecode/modules/blocks.py +1 -53
  77. braindecode/modules/convolution.py +0 -53
  78. braindecode/modules/filter.py +0 -31
  79. braindecode/modules/layers.py +0 -84
  80. braindecode/modules/linear.py +1 -22
  81. braindecode/modules/stats.py +0 -10
  82. braindecode/modules/util.py +0 -9
  83. braindecode/modules/wrapper.py +0 -17
  84. braindecode/preprocessing/preprocess.py +0 -3
  85. braindecode/regressor.py +18 -15
  86. braindecode/samplers/ssl.py +1 -1
  87. braindecode/util.py +28 -38
  88. braindecode/version.py +1 -1
  89. braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
  90. braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
  91. braindecode/datasets/bids/__init__.py +0 -54
  92. braindecode/datasets/bids/format.py +0 -717
  93. braindecode/datasets/bids/hub_format.py +0 -717
  94. braindecode/datasets/bids/hub_io.py +0 -197
  95. braindecode/datasets/chb_mit.py +0 -163
  96. braindecode/datasets/siena.py +0 -162
  97. braindecode/datasets/utils.py +0 -67
  98. braindecode/models/brainmodule.py +0 -845
  99. braindecode/models/config.py +0 -233
  100. braindecode/models/reve.py +0 -843
  101. braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
  102. braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
  103. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
  104. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
  105. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
  106. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
--- a/braindecode/datautil/util.py
+++ b/braindecode/datautil/util.py
@@ -2,19 +2,6 @@
  #
  # License: BSD (3-clause)

- import logging
- from typing import Any, Literal
-
- import mne
- import numpy as np
- from skorch.helper import SliceDataset
- from skorch.utils import is_dataset
-
- from braindecode.datasets.base import BaseConcatDataset, WindowsDataset
- from braindecode.models.util import SigArgName
-
- log = logging.getLogger(__name__)
-

  def ms_to_samples(ms, fs):
  """
@@ -22,15 +9,16 @@ def ms_to_samples(ms, fs):

  Parameters
  ----------
- ms : number
+ ms: number
  Milliseconds
- fs : number
+ fs: number
  Sampling rate

  Returns
  -------
- n_samples : int
+ n_samples: int
  Number of samples
+
  """
  return ms * fs / 1000.0

@@ -41,114 +29,13 @@ def samples_to_ms(n_samples, fs):

  Parameters
  ----------
- n_samples : number
+ n_samples: number
  Number of samples
- fs : number
+ fs: number
  Sampling rate

  Returns
  -------
- milliseconds : int
+ milliseconds: int
  """
  return n_samples * 1000.0 / fs
-
-
- def _get_n_outputs(y, classes, mode):
- if mode == "classification":
- classes_y = np.unique(y)
- if classes is not None:
- assert set(classes_y) <= set(classes)
- else:
- classes = classes_y
- return len(classes)
- elif mode == "regression":
- if y is None:
- return None
- if y.ndim == 1:
- return 1
- else:
- return y.shape[-1]
- else:
- raise ValueError(f"Unknown mode {mode}")
-
-
- def infer_signal_properties(
- X,
- y=None,
- mode: Literal["classification", "regression"] = "classification",
- classes: list | None = None,
- ) -> dict[SigArgName, Any]:
- """Infers signal properties from the data.
-
- The extracted signal properties are:
-
- + n_chans: number of channels
- + n_times: number of time points
- + n_outputs: number of outputs
- + chs_info: channel information
- + sfreq: sampling frequency
-
- The returned dictionary can serve as kwargs for model initialization.
-
- Depending on the type of input passed, not all properties can be inferred.
-
- Parameters
- ----------
- X : array-like or mne.BaseEpochs or Dataset
- Input data
- y : array-like or None
- Targets
- mode : "classification" or "regression"
- Mode of the task
- classes : list or None
- List of classes for classification
-
- Returns
- -------
- signal_kwargs : dict
- Dictionary with signal-properties. Can serve as kwargs for model
- initialization.
- """
- signal_kwargs: dict[SigArgName, Any] = {}
- # Using shape to work both with torch.tensor and numpy.array:
- if (
- isinstance(X, mne.BaseEpochs)
- or (hasattr(X, "shape") and len(X.shape) >= 2)
- or isinstance(X, SliceDataset)
- ):
- if y is None:
- raise ValueError("y must be specified if X is array-like.")
- signal_kwargs["n_outputs"] = _get_n_outputs(y, classes, mode)
- if isinstance(X, mne.BaseEpochs):
- log.info("Using mne.Epochs to find signal-related parameters.")
- signal_kwargs["n_times"] = len(X.times)
- signal_kwargs["sfreq"] = X.info["sfreq"]
- signal_kwargs["chs_info"] = X.info["chs"]
- elif isinstance(X, SliceDataset):
- log.info("Using SliceDataset to find signal-related parameters.")
- Xshape = X[0].shape
- signal_kwargs["n_times"] = Xshape[-1]
- signal_kwargs["n_chans"] = Xshape[-2]
- else:
- log.info("Using array-like to find signal-related parameters.")
- signal_kwargs["n_times"] = X.shape[-1]
- signal_kwargs["n_chans"] = X.shape[-2]
- elif is_dataset(X):
- log.info(f"Using Dataset {X!r} to find signal-related parameters.")
- X0 = X[0][0]
- Xshape = X0.shape
- signal_kwargs["n_times"] = Xshape[-1]
- signal_kwargs["n_chans"] = Xshape[-2]
- if isinstance(X, BaseConcatDataset) and all(
- ds.targets_from == "metadata" for ds in X.datasets
- ):
- y_target = X.get_metadata().target
- signal_kwargs["n_outputs"] = _get_n_outputs(y_target, classes, mode)
- elif isinstance(X, WindowsDataset) and X.targets_from == "metadata":
- y_target = X.windows.metadata.target
- signal_kwargs["n_outputs"] = _get_n_outputs(y_target, classes, mode)
- else:
- log.warning(
- f"Can only infer signal shape of array-like and Datasets, got {type(X)!r}."
- )
- return signal_kwargs
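
The hunks above drop `_get_n_outputs` and `infer_signal_properties` from `braindecode/datautil/util.py`; the same shape-based inference now lives inline in `eegneuralnet.py` (next file). A minimal, standalone sketch of that inference for a plain array input — the toy shapes and label values below are illustrative, not taken from the package:

```python
import numpy as np

# Toy data: 32 epochs, 22 channels, 1000 time points (illustrative values).
X = np.random.randn(32, 22, 1000)
y = np.random.randint(0, 4, size=32)

# Shape-based inference, mirroring the removed helper:
signal_kwargs = {
    "n_times": X.shape[-1],          # last axis -> time points (1000)
    "n_chans": X.shape[-2],          # second-to-last axis -> channels (22)
    "n_outputs": len(np.unique(y)),  # classification: number of distinct classes (4)
}
print(signal_kwargs)
```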

--- a/braindecode/eegneuralnet.py
+++ b/braindecode/eegneuralnet.py
@@ -7,7 +7,6 @@
  import abc
  import inspect
  import logging
- from typing import Literal

  import mne
  import numpy as np
@@ -15,10 +14,10 @@ import torch
  from sklearn.metrics import get_scorer
  from skorch import NeuralNet
  from skorch.callbacks import BatchScoring, EpochScoring, EpochTimer, PrintLog
- from skorch.utils import noop, to_numpy, train_loss_score, valid_loss_score
-
- from braindecode.datautil import infer_signal_properties
+ from skorch.helper import SliceDataset
+ from skorch.utils import is_dataset, noop, to_numpy, train_loss_score, valid_loss_score

+ from .datasets.base import BaseConcatDataset, WindowsDataset
  from .models.util import models_dict
  from .training.scoring import (
  CroppedTimeSeriesEpochScoring,
@@ -53,6 +52,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):

  If the module is already initialized and no parameter was changed, it
  will be left as is.
+
  """
  kwargs = self.get_params_for("module")
  module = _get_model(self.module)
@@ -174,9 +174,8 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
  ("print_log", PrintLog()),
  ]

- @property
  @abc.abstractmethod
- def mode(self) -> Literal["classification", "regression"]:
+ def _get_n_outputs(self, y, classes):
  pass

  def _set_signal_args(self, X, y, classes):
@@ -192,8 +191,50 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
  return
  if classes is None:
  classes = getattr(self, "classes", None)
- signal_kwargs = infer_signal_properties(X, y, mode=self.mode, classes=classes)
- if not signal_kwargs:
+ # get kwargs from signal:
+ signal_kwargs = dict()
+ # Using shape to work both with torch.tensor and numpy.array:
+ if (
+ isinstance(X, mne.BaseEpochs)
+ or (hasattr(X, "shape") and len(X.shape) >= 2)
+ or isinstance(X, SliceDataset)
+ ):
+ if y is None:
+ raise ValueError("y must be specified if X is array-like.")
+ signal_kwargs["n_outputs"] = self._get_n_outputs(y, classes)
+ if isinstance(X, mne.BaseEpochs):
+ self.log.info("Using mne.Epochs to find signal-related parameters.")
+ signal_kwargs["n_times"] = len(X.times)
+ signal_kwargs["sfreq"] = X.info["sfreq"]
+ signal_kwargs["chs_info"] = X.info["chs"]
+ elif isinstance(X, SliceDataset):
+ self.log.info("Using SliceDataset to find signal-related parameters.")
+ Xshape = X[0].shape
+ signal_kwargs["n_times"] = Xshape[-1]
+ signal_kwargs["n_chans"] = Xshape[-2]
+ else:
+ self.log.info("Using array-like to find signal-related parameters.")
+ signal_kwargs["n_times"] = X.shape[-1]
+ signal_kwargs["n_chans"] = X.shape[-2]
+ elif is_dataset(X):
+ self.log.info(f"Using Dataset {X!r} to find signal-related parameters.")
+ X0 = X[0][0]
+ Xshape = X0.shape
+ signal_kwargs["n_times"] = Xshape[-1]
+ signal_kwargs["n_chans"] = Xshape[-2]
+ if isinstance(X, BaseConcatDataset) and all(
+ ds.targets_from == "metadata" for ds in X.datasets
+ ):
+ y_target = X.get_metadata().target
+ signal_kwargs["n_outputs"] = self._get_n_outputs(y_target, classes)
+ elif isinstance(X, WindowsDataset) and X.targets_from == "metadata":
+ y_target = X.windows.metadata.target
+ signal_kwargs["n_outputs"] = self._get_n_outputs(y_target, classes)
+ else:
+ self.log.warning(
+ "Can only infer signal shape of array-like and Datasets, "
+ f"got {type(X)!r}."
+ )
  return

  # kick out missing kwargs:
@@ -208,18 +249,6 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
  else:
  self.log.warning(f"Module {self.module!r} is missing parameter {k!r}.")

- # kick out inferred signal kwargs if user specifies kwargs:
- user_specified_kwargs = self.get_params_for("module").items()
- if len(user_specified_kwargs) > 0:
- self.log.info(
- f"Overriding inferred parameters with user "
- f"specified parameters{user_specified_kwargs!r}."
- )
- for k, v in self.get_params_for("module").items():
- if k in module_kwargs:
- module_kwargs.pop(k)
- module_kwargs[k] = v
-
  # save kwargs to self:
  self.log.info(
  f"Passing additional parameters {module_kwargs!r} "
@@ -229,8 +258,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
  self.set_params(**module_kwargs)

  def get_dataset(self, X, y=None):
- """Get a dataset that contains the input data and is passed to.
-
+ """Get a dataset that contains the input data and is passed to
  the iterator.

  Override this if you want to initialize your dataset
@@ -262,6 +290,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
  -------
  dataset
  The initialized dataset.
+
  """
  if isinstance(X, mne.BaseEpochs):
  X = X.get_data(units="uV")
@@ -314,6 +343,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
  **fit_params : dict
  Additional parameters passed to the ``forward`` method of
  the module and to the ``self.train_split`` call.
+
  """
  # this needs to be executed before the net is initialized:
  if not self.signal_args_set_:
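
With the inlining above, the abstract `mode` property is replaced by an abstract `_get_n_outputs(self, y, classes)` that each estimator implements. As a rough illustration only (the real implementations live in `classifier.py` and `regressor.py`, whose diffs are not shown in full here), a classification version could mirror the branch removed from `datautil/util.py`:

```python
import numpy as np

class _SketchClassifier:
    # Illustrative mirror of the removed classification branch of _get_n_outputs;
    # not the actual EEGClassifier code.
    def _get_n_outputs(self, y, classes):
        classes_y = np.unique(y)
        if classes is not None:
            assert set(classes_y) <= set(classes)
        else:
            classes = classes_y
        return len(classes)

print(_SketchClassifier()._get_n_outputs([0, 1, 2, 1], classes=None))  # 3
```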

--- a/braindecode/functional/functions.py
+++ b/braindecode/functional/functions.py
@@ -24,13 +24,14 @@ def drop_path(
  ):
  """Drop paths (Stochastic Depth) per sample.

+
  Notes: This implementation is taken from timm library.

  All credit goes to Ross Wightman.

  Parameters
  ----------
- x : torch.Tensor
+ x: torch.Tensor
  input tensor
  drop_prob : float, optional
  survival rate (i.e. probability of being kept), by default 0.0
@@ -50,10 +51,11 @@
  etc. networks, however,
  the original name is misleading as 'Drop Connect' is a different form
  of dropout in a separate paper...
- See discussion : https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
  ... I've opted for changing the layer and argument names to 'drop path'
  rather than mix DropConnect as a layer name and use
  'survival rate' as the argument.
+
  """
  if drop_prob == 0.0 or not training:
  return x
@@ -69,8 +71,7 @@

  def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
  """
- Generates a 1-dimensional Gaussian kernel based on the specified kernel.
-
+ Generates a 1-dimensional Gaussian kernel based on the specified kernel
  size and standard deviation (sigma).
  This kernel is useful for Gaussian smoothing or filtering operations in
  image processing. The function calculates a range limit to ensure the kernel
@@ -79,14 +80,15 @@ def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
  Gaussian curve, normalized using a softmax function
  to ensure all weights sum to 1.

+
  Parameters
  ----------
- kernel_size : int
- sigma : float
+ kernel_size: int
+ sigma: float

  Returns
  -------
- kernel1d : torch.Tensor
+ kernel1d: torch.Tensor

  Notes
  -----
@@ -95,6 +97,7 @@ def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
  All rights reserved.

  LICENSE in https://github.com/pytorch/vision/blob/main/LICENSE
+
  """
  ksize_half = (kernel_size - 1) * 0.5
  x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
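
The `_get_gaussian_kernel1d` docstring above describes a symmetric grid of `kernel_size` points that is softmax-normalised so the weights sum to 1. A self-contained sketch of that construction (mirroring the lines shown, not calling the private helper):

```python
import torch

kernel_size, sigma = 5, 1.0
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
# Softmax of the negative scaled squared distances is the normalised Gaussian.
kernel1d = torch.softmax(-0.5 * (x / sigma) ** 2, dim=0)
print(kernel1d.sum())  # tensor(1.) -- weights sum to 1 by construction
```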

--- a/braindecode/functional/initialization.py
+++ b/braindecode/functional/initialization.py
@@ -4,14 +4,13 @@ from torch import nn


  def glorot_weight_zero_bias(model):
- """Initialize parameters of all modules by initializing weights with.
-
+ """Initialize parameters of all modules by initializing weights with
  glorot uniform/xavier initialization, and setting biases to zero. Weights from
  batch norm layers are set to 1.

  Parameters
  ----------
- model : Module
+ model: Module
  """
  for module in model.modules():
  if hasattr(module, "weight"):
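
Only the `glorot_weight_zero_bias` docstring changes here; its behaviour (Xavier-uniform weights, zero biases, batch-norm weights set to 1) stays as described. A usage sketch, assuming the helper is importable from the module path listed above and using a toy model:

```python
from torch import nn

from braindecode.functional.initialization import glorot_weight_zero_bias

# Toy model; the layer sizes are arbitrary.
model = nn.Sequential(nn.Linear(64, 32), nn.BatchNorm1d(32), nn.Linear(32, 4))
glorot_weight_zero_bias(model)
print(float(model[0].bias.abs().max()))  # 0.0 -- biases are zeroed
print(float(model[1].weight.mean()))     # 1.0 -- batch-norm weights set to 1
```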

--- a/braindecode/models/__init__.py
+++ b/braindecode/models/__init__.py
@@ -1,4 +1,6 @@
- """Some predefined network architectures for EEG decoding."""
+ """
+ Some predefined network architectures for EEG decoding.
+ """

  from .atcnet import ATCNet
  from .attentionbasenet import AttentionBaseNet
@@ -6,7 +8,6 @@ from .attn_sleep import AttnSleep
  from .base import EEGModuleMixin
  from .bendr import BENDR
  from .biot import BIOT
- from .brainmodule import BrainModule
  from .contrawr import ContraWR
  from .ctnet import CTNet
  from .deep4 import Deep4Net
@@ -31,7 +32,6 @@ from .luna import LUNA
  from .medformer import MEDFormer
  from .msvtnet import MSVTNet
  from .patchedtransformer import PBT
- from .reve import REVE
  from .sccnet import SCCNet
  from .shallow_fbcsp import ShallowFBCSPNet
  from .signal_jepa import (
@@ -71,7 +71,6 @@ __all__ = [
  "CTNet",
  "Deep4Net",
  "DeepSleepNet",
- "BrainModule",
  "EEGConformer",
  "EEGInceptionERP",
  "EEGInceptionMI",
@@ -94,7 +93,6 @@ __all__ = [
  "MEDFormer",
  "MSVTNet",
  "PBT",
- "REVE",
  "SCCNet",
  "ShallowFBCSPNet",
  "SignalJEPA",

--- a/braindecode/models/atcnet.py
+++ b/braindecode/models/atcnet.py
@@ -13,9 +13,9 @@ from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear


  class ATCNet(EEGModuleMixin, nn.Module):
- r"""ATCNet from Altaheri et al (2022) [1]_.
+ """ATCNet from Altaheri et al. (2022) [1]_.

- :bdg-success:`Convolution` :bdg-secondary:`Recurrent` :bdg-info:`Attention/Transformer`
+ :bdg-success:`Convolution` :bdg-secondary:`Recurrent` :bdg-info:`Small Attention`

  .. figure:: https://user-images.githubusercontent.com/25565236/185449791-e8539453-d4fa-41e1-865a-2cf7e91f60ef.png
  :align: center
@@ -83,8 +83,7 @@ class ATCNet(EEGModuleMixin, nn.Module):

  - :class:`_TCNResidualBlock` **(causal dilated temporal CNN)**

- *Operations:*
-
+ - *Operations.*
  - Two :class:`braindecode.modules.CausalConv1d` layers per block with dilation ``1, 2, 4, …``
  - Across blocks of `torch.nn.ELU` + `torch.nn.BatchNorm1d` + `torch.nn.Dropout`) +
  a residual (identity or 1x1 mapping).
@@ -95,12 +94,10 @@ class ATCNet(EEGModuleMixin, nn.Module):

  - **Aggregation & Classifier**

- *Operations:*
-
+ - *Operations.*
  - Either (a) map each window feature ``(B, F2)`` to logits via :class:`braindecode.modules.MaxNormLinear`
- and **average** across windows (default, matching official code), or
+ and **average** across windows (default, matching official code), or
  - (b) **concatenate** all window features ``(B, n·F2)`` and apply a single :class:`MaxNormLinear`.
-
  The max-norm constraint regularizes the readout.

  .. rubric:: Convolutional Details
@@ -117,6 +114,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
  producing *per-temporal-filter* spatial projections (no cross-filter mixing at this step).
  This mirrors EEGNet's interpretability: each temporal filter has its own spatial pattern.

+
  .. rubric:: Attention / Sequential Modules

  - **Type.** Multi-head self-attention with ``H`` heads and per-head dim ``d_h`` implemented
@@ -143,13 +141,26 @@ class ATCNet(EEGModuleMixin, nn.Module):
  - Pool sizes ``P1,P2`` trade temporal resolution for stability/compute; they set
  ``T_c = T/(P1·P2)`` and thus window width ``T_w``.
  - ``n_windows`` controls the ensemble over shifts (compute ∝ windows).
- - ``num_heads``, ``head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
+ - ``att_num_heads``, ``att_head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
  - ``tcn_depth``, ``tcn_kernel_size`` govern receptive field; larger values demand
  longer inputs (see minimum length above). The implementation warns and *rescales*
  kernels/pools/windows if inputs are too short.
  - **Aggregation choice.** ``concat=False`` (default, average of per-window logits) matches
  the official code; ``concat=True`` mirrors the paper's concatenation variant.

+
+ Notes
+ -----
+ - Inputs substantially shorter than the implied minimum length trigger **automatic
+ downscaling** of kernels, pools, windows, and TCN kernel size to maintain validity.
+ - The attention–TCN sequence operates **per window**; the last causal step is used as the
+ window feature, aligning the temporal semantics across windows.
+
+ .. versionadded:: 1.1
+
+ - More detailed documentation of the model.
+
+
  Parameters
  ----------
  input_window_seconds : float, optional
@@ -183,10 +194,10 @@ class ATCNet(EEGModuleMixin, nn.Module):
  table 1 of the paper [1]_. Defaults to 0.3 as in [1]_.
  n_windows : int
  Number of sliding windows, denoted n in [1]_. Defaults to 5 as in [1]_.
- head_dim : int
+ att_head_dim : int
  Embedding dimension used in each self-attention head, denoted dh in
  table 1 of the paper [1]_. Defaults to 8 as in [1]_.
- num_heads : int
+ att_num_heads : int
  Number of attention heads, denoted H in table 1 of the paper [1]_.
  Defaults to 2 as in [1]_.
  att_dropout : float
@@ -214,17 +225,6 @@ class ATCNet(EEGModuleMixin, nn.Module):
  Maximum L2-norm constraint imposed on weights of the last
  fully-connected layer. Defaults to 0.25.

- Notes
- -----
- - Inputs substantially shorter than the implied minimum length trigger **automatic
- downscaling** of kernels, pools, windows, and TCN kernel size to maintain validity.
- - The attention–TCN sequence operates **per window**; the last causal step is used as the
- window feature, aligning the temporal semantics across windows.
-
- .. versionadded:: 1.1
-
- - More detailed documentation of the model.
-
  References
  ----------
  .. [1] H. Altaheri, G. Muhammad, M. Alsulaiman (2022).
@@ -248,13 +248,13 @@ class ATCNet(EEGModuleMixin, nn.Module):
  conv_block_depth_mult=2,
  conv_block_dropout=0.3,
  n_windows=5,
- head_dim=8,
- num_heads=2,
+ att_head_dim=8,
+ att_num_heads=2,
  att_drop_prob=0.5,
  tcn_depth=2,
  tcn_kernel_size=4,
  tcn_drop_prob=0.3,
- tcn_activation: type[nn.Module] = nn.ELU,
+ tcn_activation: nn.Module = nn.ELU,
  concat=False,
  max_norm_const=0.25,
  chs_info=None,
@@ -316,8 +316,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
  self.conv_block_depth_mult = conv_block_depth_mult
  self.conv_block_dropout = conv_block_dropout
  self.n_windows = n_windows
- self.head_dim = head_dim
- self.num_heads = num_heads
+ self.att_head_dim = att_head_dim
+ self.att_num_heads = att_num_heads
  self.att_dropout = att_drop_prob
  self.tcn_depth = tcn_depth
  self.tcn_kernel_size = tcn_kernel_size
@@ -356,8 +356,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
  [
  _AttentionBlock(
  in_shape=self.F2,
- head_dim=self.head_dim,
- num_heads=num_heads,
+ head_dim=self.att_head_dim,
+ num_heads=att_num_heads,
  dropout=att_drop_prob,
  )
  for _ in range(self.n_windows)
@@ -460,8 +460,7 @@ class ATCNet(EEGModuleMixin, nn.Module):


  class _ConvBlock(nn.Module):
- r"""Convolutional block proposed in ATCNet [1]_, inspired by the EEGNet.
-
+ """Convolutional block proposed in ATCNet [1]_, inspired by the EEGNet
  architecture [2]_.

  References
@@ -563,8 +562,7 @@ class _ConvBlock(nn.Module):


  class _AttentionBlock(nn.Module):
- r"""Multi Head self Attention (MHA) block used in ATCNet [1]_, inspired from.
-
+ """Multi Head self Attention (MHA) block used in ATCNet [1]_, inspired from
  [2]_.

  References
@@ -638,9 +636,7 @@ class _AttentionBlock(nn.Module):


  class _TCNResidualBlock(nn.Module):
- r"""Modified TCN Residual block as proposed in [1]_.
-
- Inspired from
+ """Modified TCN Residual block as proposed in [1]_. Inspired from
  Temporal Convolutional Networks (TCN) [2]_.

  References
@@ -660,7 +656,7 @@ class _TCNResidualBlock(nn.Module):
  kernel_size=4,
  n_filters=32,
  dropout=0.3,
- activation: type[nn.Module] = nn.ELU,
+ activation: nn.Module = nn.ELU,
  dilation=1,
  ):
  super().__init__()
@@ -736,7 +732,7 @@ class _MHA(nn.Module):
  num_heads: int,
  dropout: float = 0.0,
  ):
- """Multi-head Attention.
+ """Multi-head Attention

  The difference between this module and torch.nn.MultiheadAttention is
  that this module supports embedding dimensions different then input
@@ -779,20 +775,20 @@ class _MHA(nn.Module):
  def forward(
  self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
  ) -> torch.Tensor:
- """Compute MHA(Q, K, V).
+ """Compute MHA(Q, K, V)

  Parameters
  ----------
- Q : torch.Tensor of size (batch_size, seq_len, input_dim)
+ Q: torch.Tensor of size (batch_size, seq_len, input_dim)
  Input query (Q) sequence.
- K : torch.Tensor of size (batch_size, seq_len, input_dim)
+ K: torch.Tensor of size (batch_size, seq_len, input_dim)
  Input key (K) sequence.
- V : torch.Tensor of size (batch_size, seq_len, input_dim)
+ V: torch.Tensor of size (batch_size, seq_len, input_dim)
  Input value (V) sequence.

  Returns
  -------
- O : torch.Tensor of size (batch_size, seq_len, output_dim)
+ O: torch.Tensor of size (batch_size, seq_len, output_dim)
  Output MHA(Q, K, V)
  """
  assert Q.shape[-1] == K.shape[-1] == V.shape[-1] == self.input_dim
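
The ATCNet hunks above are docstring edits plus a rename of the attention arguments (`head_dim`/`num_heads` → `att_head_dim`/`att_num_heads`) and a loosened `tcn_activation` annotation. A hedged construction sketch using the renamed arguments with the defaults shown in the diff; the channel/time configuration and the expected output shape are illustrative assumptions, not taken from this diff:

```python
import torch

from braindecode.models import ATCNet

model = ATCNet(
    n_chans=22,
    n_outputs=4,
    n_times=1125,        # e.g. 4.5 s at 250 Hz (assumed configuration)
    sfreq=250,
    att_head_dim=8,      # renamed from head_dim
    att_num_heads=2,     # renamed from num_heads
)
x = torch.randn(8, 22, 1125)   # (batch, channels, times)
print(model(x).shape)          # expected: torch.Size([8, 4])
```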