braindecode 1.2.0.dev184328194__py3-none-any.whl → 1.3.0.dev171178473__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode has been flagged for review by the registry.

Files changed (70)
  1. braindecode/augmentation/base.py +1 -1
  2. braindecode/augmentation/functional.py +154 -54
  3. braindecode/augmentation/transforms.py +2 -2
  4. braindecode/datasets/__init__.py +10 -2
  5. braindecode/datasets/base.py +116 -152
  6. braindecode/datasets/bcicomp.py +4 -4
  7. braindecode/datasets/bids.py +3 -3
  8. braindecode/datasets/experimental.py +218 -0
  9. braindecode/datasets/mne.py +3 -5
  10. braindecode/datasets/moabb.py +2 -2
  11. braindecode/datasets/nmt.py +2 -2
  12. braindecode/datasets/sleep_physio_challe_18.py +4 -3
  13. braindecode/datasets/sleep_physionet.py +2 -2
  14. braindecode/datasets/tuh.py +2 -2
  15. braindecode/datasets/xy.py +2 -2
  16. braindecode/datautil/serialization.py +18 -13
  17. braindecode/eegneuralnet.py +2 -0
  18. braindecode/functional/functions.py +6 -2
  19. braindecode/functional/initialization.py +2 -3
  20. braindecode/models/__init__.py +12 -8
  21. braindecode/models/atcnet.py +156 -17
  22. braindecode/models/attentionbasenet.py +148 -16
  23. braindecode/models/{sleep_stager_eldele_2021.py → attn_sleep.py} +12 -2
  24. braindecode/models/base.py +280 -2
  25. braindecode/models/bendr.py +469 -0
  26. braindecode/models/biot.py +3 -1
  27. braindecode/models/ctnet.py +7 -4
  28. braindecode/models/deep4.py +6 -2
  29. braindecode/models/deepsleepnet.py +127 -5
  30. braindecode/models/eegconformer.py +114 -15
  31. braindecode/models/eeginception_erp.py +82 -7
  32. braindecode/models/eeginception_mi.py +2 -0
  33. braindecode/models/eegnet.py +64 -177
  34. braindecode/models/eegnex.py +113 -6
  35. braindecode/models/eegsimpleconv.py +2 -0
  36. braindecode/models/eegtcnet.py +1 -1
  37. braindecode/models/labram.py +188 -84
  38. braindecode/models/patchedtransformer.py +640 -0
  39. braindecode/models/sccnet.py +81 -8
  40. braindecode/models/shallow_fbcsp.py +2 -0
  41. braindecode/models/signal_jepa.py +109 -27
  42. braindecode/models/sinc_shallow.py +10 -9
  43. braindecode/models/sleep_stager_blanco_2020.py +2 -0
  44. braindecode/models/sleep_stager_chambon_2018.py +2 -0
  45. braindecode/models/sparcnet.py +2 -0
  46. braindecode/models/sstdpn.py +869 -0
  47. braindecode/models/summary.csv +42 -41
  48. braindecode/models/tidnet.py +2 -0
  49. braindecode/models/tsinception.py +15 -3
  50. braindecode/models/usleep.py +108 -9
  51. braindecode/models/util.py +8 -5
  52. braindecode/modules/attention.py +10 -10
  53. braindecode/modules/blocks.py +3 -3
  54. braindecode/modules/filter.py +2 -3
  55. braindecode/modules/layers.py +18 -17
  56. braindecode/preprocessing/__init__.py +24 -0
  57. braindecode/preprocessing/eegprep_preprocess.py +1202 -0
  58. braindecode/preprocessing/preprocess.py +42 -39
  59. braindecode/preprocessing/util.py +166 -0
  60. braindecode/preprocessing/windowers.py +24 -19
  61. braindecode/samplers/base.py +8 -8
  62. braindecode/version.py +1 -1
  63. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/METADATA +12 -3
  64. braindecode-1.3.0.dev171178473.dist-info/RECORD +106 -0
  65. braindecode/models/eegresnet.py +0 -362
  66. braindecode-1.2.0.dev184328194.dist-info/RECORD +0 -101
  67. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/WHEEL +0 -0
  68. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/licenses/LICENSE.txt +0 -0
  69. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/licenses/NOTICE.txt +0 -0
  70. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/top_level.txt +0 -0
braindecode/datasets/mne.py

@@ -9,7 +9,7 @@ import mne
 import numpy as np
 import pandas as pd
 
-from .base import BaseConcatDataset, BaseDataset, WindowsDataset
+from .base import BaseConcatDataset, RawDataset, WindowsDataset
 
 
 def create_from_mne_raw(
@@ -75,11 +75,9 @@ def create_from_mne_raw(
             f"length of 'raws' ({len(raws)}) and 'description' "
             f"({len(descriptions)}) has to match"
         )
-        base_datasets = [
-            BaseDataset(raw, desc) for raw, desc in zip(raws, descriptions)
-        ]
+        base_datasets = [RawDataset(raw, desc) for raw, desc in zip(raws, descriptions)]
     else:
-        base_datasets = [BaseDataset(raw) for raw in raws]
+        base_datasets = [RawDataset(raw) for raw in raws]
 
     base_datasets = BaseConcatDataset(base_datasets)
     windows_datasets = create_windows_from_events(
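
The BaseDataset → RawDataset rename above recurs in every dataset module of this release (moabb, nmt, tuh, xy, the sleep datasets, and serialization below). A minimal migration sketch, assuming RawDataset keeps the (raw, description, target_name) constructor that BaseDataset had, which is what these hunks suggest:

    import mne
    import numpy as np
    import pandas as pd

    from braindecode.datasets import BaseConcatDataset, RawDataset  # was: BaseDataset

    # Build a toy Raw object; any mne.io.Raw works here.
    info = mne.create_info(ch_names=["C3", "C4"], sfreq=100.0, ch_types="eeg")
    raw = mne.io.RawArray(np.random.randn(2, 1000), info)

    # was: BaseDataset(raw, pd.Series({"subject": 1}))
    ds = RawDataset(raw, pd.Series({"subject": 1}))
    concat = BaseConcatDataset([ds])
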
braindecode/datasets/moabb.py

@@ -18,7 +18,7 @@ import pandas as pd
 
 from braindecode.util import _update_moabb_docstring
 
-from .base import BaseConcatDataset, BaseDataset
+from .base import BaseConcatDataset, RawDataset
 
 
 def _find_dataset_in_moabb(dataset_name, dataset_kwargs=None):
@@ -164,7 +164,7 @@ class MOABBDataset(BaseConcatDataset):
             dataset_load_kwargs=dataset_load_kwargs,
         )
         all_base_ds = [
-            BaseDataset(raw, row) for raw, (_, row) in zip(raws, description.iterrows())
+            RawDataset(raw, row) for raw, (_, row) in zip(raws, description.iterrows())
         ]
         super().__init__(all_base_ds)
 
braindecode/datasets/nmt.py

@@ -31,7 +31,7 @@ import pandas as pd
 from joblib import Parallel, delayed
 from mne.datasets import fetch_dataset
 
-from braindecode.datasets.base import BaseConcatDataset, BaseDataset
+from braindecode.datasets.base import BaseConcatDataset, RawDataset
 
 NMT_URL = "https://zenodo.org/record/10909103/files/NMT.zip"
 NMT_archive_name = "NMT.zip"
@@ -172,7 +172,7 @@ class NMT(BaseConcatDataset):
     d["n_samples"] = raw.n_times
     d["sfreq"] = raw.info["sfreq"]
     d["train"] = "train" in d.path.split(os.sep)
-    base_dataset = BaseDataset(raw, d, target_name)
+    base_dataset = RawDataset(raw, d, target_name)
     return base_dataset
 
 
braindecode/datasets/sleep_physio_challe_18.py

@@ -21,8 +21,7 @@ from mne.datasets.sleep_physionet._utils import _fetch_one
 from mne.datasets.utils import _get_path
 from mne.utils import warn
 
-from braindecode.datasets import BaseConcatDataset, BaseDataset
-from braindecode.preprocessing.preprocess import _preprocess
+from braindecode.datasets import BaseConcatDataset, RawDataset
 
 PC18_DIR = op.join(op.dirname(__file__), "data", "pc18")
 PC18_RECORDS = op.join(PC18_DIR, "sleep_records.csv")
@@ -404,9 +403,11 @@ class SleepPhysionetChallenge2018(BaseConcatDataset):
         },
         name="",
     )
-    base_dataset = BaseDataset(raw_file, desc)
+    base_dataset = RawDataset(raw_file, desc)
 
     if preproc is not None:
+        from braindecode.preprocessing.preprocess import _preprocess
+
         _preprocess(base_dataset, None, preproc)
 
     return base_dataset
braindecode/datasets/sleep_physionet.py

@@ -12,7 +12,7 @@ import numpy as np
 import pandas as pd
 from mne.datasets.sleep_physionet.age import fetch_data
 
-from .base import BaseConcatDataset, BaseDataset
+from .base import BaseConcatDataset, RawDataset
 
 
 class SleepPhysionet(BaseConcatDataset):
@@ -71,7 +71,7 @@ class SleepPhysionet(BaseConcatDataset):
                 crop_wake_mins=crop_wake_mins,
                 crop=crop,
             )
-            base_ds = BaseDataset(raw, desc)
+            base_ds = RawDataset(raw, desc)
            all_base_ds.append(base_ds)
         super().__init__(all_base_ds)
 
braindecode/datasets/tuh.py

@@ -22,7 +22,7 @@ import numpy as np
 import pandas as pd
 from joblib import Parallel, delayed
 
-from .base import BaseConcatDataset, BaseDataset
+from .base import BaseConcatDataset, RawDataset
 
 
 class TUH(BaseConcatDataset):
@@ -214,7 +214,7 @@ class TUH(BaseConcatDataset):
     d["report"] = physician_report
     additional_description = pd.Series(d)
     description = pd.concat([description, additional_description])
-    base_dataset = BaseDataset(raw, description, target_name=target_name)
+    base_dataset = RawDataset(raw, description, target_name=target_name)
     return base_dataset
 
 
braindecode/datasets/xy.py

@@ -12,7 +12,7 @@ import numpy as np
 import pandas as pd
 from numpy.typing import ArrayLike, NDArray
 
-from .base import BaseConcatDataset, BaseDataset
+from .base import BaseConcatDataset, RawDataset
 
 log = logging.getLogger(__name__)
 
@@ -69,7 +69,7 @@ def create_from_X_y(
         n_samples_per_x.append(x.shape[1])
         info = mne.create_info(ch_names=ch_names, sfreq=sfreq)
         raw = mne.io.RawArray(x, info)
-        base_dataset = BaseDataset(
+        base_dataset = RawDataset(
             raw, pd.Series({"target": target}), target_name="target"
         )
         base_datasets.append(base_dataset)
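
For context, create_from_X_y is the public entry point touched by this hunk; it wraps each (trial, target) pair in a RawDataset before windowing. A hedged usage sketch (parameter names follow the braindecode documentation and may differ slightly across versions):

    import numpy as np
    from braindecode.datasets import create_from_X_y

    X = np.random.randn(5, 2, 1000)   # 5 trials, 2 channels, 1000 samples
    y = np.array([0, 1, 0, 1, 0])

    windows_ds = create_from_X_y(
        X, y,
        drop_last_window=False,
        sfreq=100.0,
        ch_names=["C3", "C4"],
    )
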
braindecode/datautil/serialization.py

@@ -19,8 +19,8 @@ from joblib import Parallel, delayed
 
 from ..datasets.base import (
     BaseConcatDataset,
-    BaseDataset,
     EEGWindowsDataset,
+    RawDataset,
     WindowsDataset,
 )
 
@@ -35,7 +35,7 @@ def save_concat_dataset(path, concat_dataset, overwrite=False):
 
 
 def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
-    """Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
+    """Load a stored BaseConcatDataset from
     files.
 
     Parameters
@@ -52,7 +52,7 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=N
 
     Returns
     -------
-    concat_dataset: BaseConcatDataset of BaseDatasets or WindowsDatasets
+    concat_dataset: BaseConcatDataset
     """
     # assume we have a single concat dataset to load
     is_raw = (path / "0-raw.fif").is_file()
@@ -87,7 +87,7 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=N
     for i_signal, signal in enumerate(all_signals):
         if is_raw:
             datasets.append(
-                BaseDataset(
+                RawDataset(
                     signal, description.iloc[i_signal], target_name=target_name
                 )
             )
@@ -138,12 +138,17 @@ def _load_signals(fif_file, preload, is_raw):
         with open(pkl_file, "rb") as f:
             signals = pickle.load(f)
 
-        # If the file has been moved together with the pickle file, make sure
-        # the path links to correct fif file.
-        signals._fname = str(fif_file)
-        if preload:
-            signals.load_data()
-        return signals
+        if all(f.exists() for f in signals.filenames):
+            if preload:
+                signals.load_data()
+            return signals
+        else:  # This may happen if the file has been moved together with the pickle file.
+            warnings.warn(
+                f"Pickle file {pkl_file} exists, but the referenced fif "
+                "file(s) do not exist. Will read the fif file(s) directly "
+                "and re-create the pickle file.",
+                UserWarning,
+            )
 
     # If pickle didn't exist read via mne (likely slower) and save pkl after
     if is_raw:
@@ -170,7 +175,7 @@ def _load_signals(fif_file, preload, is_raw):
 
 
 def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_jobs=1):
-    """Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
+    """Load a stored BaseConcatDataset from
     files.
 
     Parameters
@@ -189,7 +194,7 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_job
 
     Returns
     -------
-    concat_dataset: BaseConcatDataset of BaseDatasets or WindowsDatasets
+    concat_dataset: BaseConcatDataset
     """
     # Make sure we always work with a pathlib.Path
     path = Path(path)
@@ -261,7 +266,7 @@ def _load_parallel(path, i, preload, is_raw, has_stored_windows):
     target_name = json.load(open(target_file_path, "r"))["target_name"]
 
     if is_raw and (not has_stored_windows):
-        dataset = BaseDataset(signals, description, target_name)
+        dataset = RawDataset(signals, description, target_name)
     else:
         window_kwargs = _load_kwargs_json("window_kwargs", sub_dir)
         windows_ds_kwargs = [
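
The _load_signals change replaces the old behavior of blindly patching signals._fname with a validity check on the fif paths recorded inside the pickled object. A standalone sketch of the new control flow (an illustrative helper, not the library function itself):

    import pickle
    import warnings
    from pathlib import Path

    def load_cached(pkl_file: Path, preload: bool):
        with open(pkl_file, "rb") as f:
            signals = pickle.load(f)
        # `signals.filenames` holds the fif paths recorded when the pickle was made.
        if all(Path(f).exists() for f in signals.filenames):
            if preload:
                signals.load_data()
            return signals
        warnings.warn(
            f"Pickle file {pkl_file} exists, but the referenced fif file(s) do "
            "not. Falling back to reading the fif file(s) directly.",
            UserWarning,
        )
        return None  # caller re-reads via mne and re-creates the pickle
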
braindecode/eegneuralnet.py

@@ -189,6 +189,8 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
                 "Skipping setting signal-related parameters from data."
             )
             return
+        if classes is None:
+            classes = getattr(self, "classes", None)
         # get kwargs from signal:
         signal_kwargs = dict()
         # Using shape to work both with torch.tensor and numpy.array:
braindecode/functional/functions.py

@@ -181,20 +181,24 @@ def plv_time(x, forward_fourier=True, epsilon: float = 1e-6):
     The Phase Locking Value (PLV) is a measure of the synchronization between
     different channels by evaluating the consistency of phase differences
     over time. It ranges from 0 (no synchronization) to 1 (perfect
-    synchronization) [1]_.
+    synchronization) [Lachaux1999]_.
 
     Parameters
     ----------
     x : torch.Tensor
         Input tensor containing the signal data.
+
         - If `forward_fourier` is `True`, the shape should be `(..., channels, time)`.
         - If `forward_fourier` is `False`, the shape should be `(..., channels, freqs, 2)`,
           where the last dimension represents the real and imaginary parts.
+
     forward_fourier : bool, optional
         Specifies the format of the input tensor `x`.
+
         - If `True`, `x` is assumed to be in the time domain.
         - If `False`, `x` is assumed to be in the Fourier domain with separate real and
           imaginary components.
+
         Default is `True`.
     epsilon : float, default 1e-6
         Small numerical value to ensure positivity constraint on the complex part
@@ -207,7 +211,7 @@ def plv_time(x, forward_fourier=True, epsilon: float = 1e-6):
 
     References
     ----------
-    [1] Lachaux, J. P., Rodriguez, E., Martinerie, J., & Varela, F. J. (1999).
+    .. [Lachaux1999] Lachaux, J. P., Rodriguez, E., Martinerie, J., & Varela, F. J. (1999).
        Measuring phase synchrony in brain signals. Human brain mapping,
        8(4), 194-208.
     """
braindecode/functional/initialization.py

@@ -5,9 +5,8 @@ from torch import nn
 
 def glorot_weight_zero_bias(model):
     """Initialize parameters of all modules by initializing weights with
-    glorot
-    uniform/xavier initialization, and setting biases to zero. Weights from
-    batch norm layers are set to 1.
+    glorot uniform/xavier initialization, and setting biases to zero. Weights from
+    batch norm layers are set to 1.
 
     Parameters
     ----------
braindecode/models/__init__.py

@@ -4,7 +4,9 @@ Some predefined network architectures for EEG decoding.
 
 from .atcnet import ATCNet
 from .attentionbasenet import AttentionBaseNet
+from .attn_sleep import AttnSleep
 from .base import EEGModuleMixin
+from .bendr import BENDR
 from .biot import BIOT
 from .contrawr import ContraWR
 from .ctnet import CTNet
@@ -15,9 +17,8 @@ from .eeginception_erp import EEGInceptionERP
 from .eeginception_mi import EEGInceptionMI
 from .eegitnet import EEGITNet
 from .eegminer import EEGMiner
-from .eegnet import EEGNetv1, EEGNetv4
+from .eegnet import EEGNet, EEGNetv4
 from .eegnex import EEGNeX
-from .eegresnet import EEGResNet
 from .eegsimpleconv import EEGSimpleConv
 from .eegtcnet import EEGTCNet
 from .fbcnet import FBCNet
@@ -27,6 +28,7 @@ from .hybrid import HybridNet
 from .ifnet import IFNet
 from .labram import Labram
 from .msvtnet import MSVTNet
+from .patchedtransformer import PBT
 from .sccnet import SCCNet
 from .shallow_fbcsp import ShallowFBCSPNet
 from .signal_jepa import (
@@ -38,12 +40,12 @@ from .signal_jepa import (
 from .sinc_shallow import SincShallowNet
 from .sleep_stager_blanco_2020 import SleepStagerBlanco2020
 from .sleep_stager_chambon_2018 import SleepStagerChambon2018
-from .sleep_stager_eldele_2021 import SleepStagerEldele2021
 from .sparcnet import SPARCNet
+from .sstdpn import SSTDPN
 from .syncnet import SyncNet
 from .tcn import BDTCN, TCN
 from .tidnet import TIDNet
-from .tsinception import TSceptionV1
+from .tsinception import TSception
 from .usleep import USleep
 from .util import _init_models_dict, models_mandatory_parameters
 
@@ -53,9 +55,11 @@ _init_models_dict()
 
 __all__ = [
     "ATCNet",
+    "AttnSleep",
     "AttentionBaseNet",
     "EEGModuleMixin",
     "BIOT",
+    "BENDR",
     "ContraWR",
     "CTNet",
     "Deep4Net",
@@ -65,10 +69,9 @@ __all__ = [
     "EEGInceptionMI",
     "EEGITNet",
     "EEGMiner",
-    "EEGNetv1",
+    "EEGNet",
     "EEGNetv4",
     "EEGNeX",
-    "EEGResNet",
     "EEGSimpleConv",
     "EEGTCNet",
     "FBCNet",
@@ -78,6 +81,7 @@ __all__ = [
     "IFNet",
     "Labram",
     "MSVTNet",
+    "PBT",
     "SCCNet",
     "ShallowFBCSPNet",
     "SignalJEPA",
@@ -85,15 +89,15 @@ __all__ = [
     "SignalJEPA_PostLocal",
     "SignalJEPA_PreLocal",
     "SincShallowNet",
+    "SSTDPN",
     "SleepStagerBlanco2020",
     "SleepStagerChambon2018",
-    "SleepStagerEldele2021",
     "SPARCNet",
     "SyncNet",
     "BDTCN",
     "TCN",
     "TIDNet",
-    "TSceptionV1",
+    "TSception",
     "USleep",
     "_init_models_dict",
     "models_mandatory_parameters",
braindecode/models/atcnet.py

@@ -13,13 +13,153 @@ from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear
 
 
 class ATCNet(EEGModuleMixin, nn.Module):
-    """ATCNet model from Altaheri et al. (2022) [1]_
+    """ATCNet from Altaheri et al. (2022) [1]_.
 
-    Pytorch implementation based on official tensorflow code [2]_.
+    :bdg-success:`Convolution` :bdg-info:`Small Attention`
 
     .. figure:: https://user-images.githubusercontent.com/25565236/185449791-e8539453-d4fa-41e1-865a-2cf7e91f60ef.png
-       :align: center
-       :alt: ATCNet Architecture
+        :align: center
+        :alt: ATCNet Architecture
+        :width: 650px
+
+    .. rubric:: Architectural Overview
+
+    ATCNet is a *convolution-first* architecture augmented with a *lightweight attention–TCN*
+    sequence module. The end-to-end flow is:
+
+    - (i) :class:`_ConvBlock` learns temporal filter-banks and spatial projections (EEGNet-style),
+      downsampling time to a compact feature map;
+
+    - (ii) Sliding Windows carve overlapping temporal windows from this map;
+
+    - (iii) for each window, :class:`_AttentionBlock` applies small multi-head self-attention
+      over time, followed by a :class:`_TCNResidualBlock` stack (causal, dilated);
+
+    - (iv) window-level features are aggregated (mean of window logits or concatenation)
+      and mapped via a max-norm-constrained linear layer.
+
+    Relative to ViT, ATCNet replaces linear patch projection with learned *temporal–spatial*
+    convolutions; it processes *parallel* window encoders (attention→TCN) instead of a deep
+    stack; and swaps the MLP head for a TCN suited to 1-D EEG sequences.
+
+    .. rubric:: Macro Components
+
+    - :class:`_ConvBlock` **(Shallow conv stem → feature map)**
+
+      - *Operations.*
+
+        - **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_t, 1)`` builds a
+          FIR-like filter bank (``F1`` maps).
+        - **Depthwise spatial conv** (:class:`torch.nn.Conv2d`, ``groups=F1``) with kernel
+          ``(1, n_chans)`` learns per-filter spatial projections (akin to EEGNet's CSP-like step).
+        - **BN → ELU → AvgPool → Dropout** to stabilize and condense activations.
+        - **Refining temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_r, 1)`` +
+          **BN → ELU → AvgPool → Dropout**.
+
+      The output shape is ``(B, F2, T_c, 1)`` with ``F2 = F1·D`` and ``T_c = T/(P1·P2)``.
+      Temporal kernels behave as FIR filters; the depthwise-spatial conv yields frequency-specific
+      topographies. Pooling acts as a local integrator, reducing variance and imposing a
+      useful inductive bias on short EEG windows.
+
+    - **Sliding-Window Sequencer**
+
+      From the condensed time axis (length ``T_c``), ATCNet forms ``n`` overlapping windows
+      of width ``T_w = T_c - n + 1`` (one start per index). Each window produces a sequence
+      ``(B, F2, T_w)`` forwarded to its own attention-TCN branch. This creates *parallel*
+      encoders over shifted contexts and is key to robustness on nonstationary EEG.
+
+    - :class:`_AttentionBlock` **(small MHA on temporal positions)**
+
+      Attention here is *local to a window* and purely temporal.
+
+      - *Operations.*
+
+        - Rearrange to ``(B, T_w, F2)``,
+        - Normalization :class:`torch.nn.LayerNorm`,
+        - Custom MultiHeadAttention :class:`_MHA` (``num_heads=H``, per-head dim ``d_h``) + residual add,
+        - Dropout :class:`torch.nn.Dropout`,
+        - Rearrange back to ``(B, F2, T_w)``.
+
+      *Role.* Re-weights evidence across the window, letting the model emphasize informative
+      segments (onsets, bursts) before causal convolutions aggregate history.
+
+    - :class:`_TCNResidualBlock` **(causal dilated temporal CNN)**
+
+      - *Operations.*
+
+        - Two :class:`braindecode.modules.CausalConv1d` layers per block with dilation ``1, 2, 4, …``,
+        - across blocks, `torch.nn.ELU` + `torch.nn.BatchNorm1d` + `torch.nn.Dropout`, plus
+          a residual (identity or 1x1 mapping).
+        - The final feature used per window is the *last* causal step ``[..., -1]`` (forecast-style).
+
+      *Role.* Efficient long-range temporal integration with stable gradients; the dilated
+      receptive field complements attention's soft selection.
+
+    - **Aggregation & Classifier**
+
+      - *Operations.*
+
+        - Either (a) map each window feature ``(B, F2)`` to logits via :class:`braindecode.modules.MaxNormLinear`
+          and **average** across windows (default, matching official code), or
+        - (b) **concatenate** all window features ``(B, n·F2)`` and apply a single :class:`MaxNormLinear`.
+          The max-norm constraint regularizes the readout.
+
+    .. rubric:: Convolutional Details
+
+    - **Temporal.** Temporal structure is learned in three places:
+
+      - (1) the stem's wide ``(L_t, 1)`` conv (learned filter bank),
+      - (2) the refining ``(L_r, 1)`` conv after pooling (short-term dynamics), and
+      - (3) the TCN's causal 1-D convolutions with exponentially increasing dilation
+        (long-range dependencies). The minimum sequence length required by the TCN stack is
+        ``(K_t - 1)·2^{L-1} + 1``; the implementation *auto-scales* kernels/pools/windows
+        when inputs are shorter to preserve feasibility.
+
+    - **Spatial.** A depthwise spatial conv spans the **full montage** (kernel ``(1, n_chans)``),
+      producing *per-temporal-filter* spatial projections (no cross-filter mixing at this step).
+      This mirrors EEGNet's interpretability: each temporal filter has its own spatial pattern.
+
+    .. rubric:: Attention / Sequential Modules
+
+    - **Type.** Multi-head self-attention with ``H`` heads and per-head dim ``d_h`` implemented
+      in :class:`_MHA`, allowing ``embed_dim = H·d_h`` independent of input and output dims.
+    - **Shapes.** ``(B, F2, T_w) → (B, T_w, F2) → (B, F2, T_w)``. Attention operates along
+      the **temporal** axis within a window; channels/features stay in the embedding dim ``F2``.
+    - **Role.** Highlights salient temporal positions prior to causal convolution; small attention
+      keeps compute modest while improving context modeling over pooled features.
+
+    .. rubric:: Additional Mechanisms
+
+    - **Parallel encoders over shifted windows.** Improves montage/phase robustness by
+      ensembling nearby contexts rather than committing to a single segmentation.
+    - **Max-norm classifier.** Enforces weight norm constraints at the readout, a common
+      stabilization trick in EEG decoding.
+    - **ViT vs. ATCNet (design choices).** Convolutional *nonlinear* projection rather than
+      linear patchification; attention followed by **TCN** (not MLP); *parallel* window
+      encoders rather than stacked encoders.
+
+    .. rubric:: Usage and Configuration
+
+    - ``conv_block_n_filters (F1)``, ``conv_block_depth_mult (D)`` → capacity of the stem
+      (with ``F2 = F1·D`` feeding attention/TCN), dimensions aligned to ``F2``, like :class:`EEGNet`.
+    - Pool sizes ``P1, P2`` trade temporal resolution for stability/compute; they set
+      ``T_c = T/(P1·P2)`` and thus window width ``T_w``.
+    - ``n_windows`` controls the ensemble over shifts (compute ∝ windows).
+    - ``att_num_heads``, ``att_head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
+    - ``tcn_depth``, ``tcn_kernel_size`` govern receptive field; larger values demand
+      longer inputs (see minimum length above). The implementation warns and *rescales*
+      kernels/pools/windows if inputs are too short.
+    - **Aggregation choice.** ``concat=False`` (default, average of per-window logits) matches
+      the official code; ``concat=True`` mirrors the paper's concatenation variant.
+
+    Notes
+    -----
+    - Inputs substantially shorter than the implied minimum length trigger **automatic
+      downscaling** of kernels, pools, windows, and TCN kernel size to maintain validity.
+    - The attention–TCN sequence operates **per window**; the last causal step is used as the
+      window feature, aligning the temporal semantics across windows.
+
+    .. versionadded:: 1.1
+
+       - More detailed documentation of the model.
+
 
     Parameters
     ----------
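
The new "Usage and Configuration" rubric maps directly onto the constructor. A hedged instantiation sketch, using the EEGModuleMixin signal parameters (n_chans, n_outputs, n_times); other keywords and defaults should be checked against the class signature:

    import torch
    from braindecode.models import ATCNet

    model = ATCNet(
        n_chans=22,     # full montage spanned by the depthwise spatial conv
        n_outputs=4,    # number of classes
        n_times=1125,   # e.g. 4.5 s at 250 Hz (BCI IV-2a style input)
    )
    x = torch.randn(8, 22, 1125)   # (batch, channels, time)
    logits = model(x)              # -> (8, 4)

As a worked instance of the minimum-length rule above: with tcn_kernel_size K_t = 4 and tcn_depth L = 2, the TCN stack needs at least (4 - 1)·2^(2-1) + 1 = 7 condensed time steps per window.
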
@@ -85,15 +225,13 @@ class ATCNet(EEGModuleMixin, nn.Module):
         Maximum L2-norm constraint imposed on weights of the last
         fully-connected layer. Defaults to 0.25.
 
-
     References
     ----------
-    .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman,
-       Physics-informed attention temporal convolutional network for EEG-based
-       motor imagery classification in IEEE Transactions on Industrial Informatics,
-       2022, doi: 10.1109/TII.2022.3197419.
-    .. [2] EEE-ATCNet implementation.
-       https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
+    .. [1] H. Altaheri, G. Muhammad, M. Alsulaiman (2022).
+       *Physics-informed attention temporal convolutional network for EEG-based motor imagery classification.*
+       IEEE Transactions on Industrial Informatics. doi:10.1109/TII.2022.3197419.
+    .. [2] Official EEG-ATCNet implementation (TensorFlow):
+       https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
     """
 
     def __init__(
@@ -231,7 +369,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
             nn.Sequential(
                 *[
                     _TCNResidualBlock(
-                        in_channels=self.F2,
+                        in_channels=self.F2 if i == 0 else self.tcn_n_filters,
                         kernel_size=self.tcn_kernel_size,
                         n_filters=self.tcn_n_filters,
                        dropout=self.tcn_dropout,
@@ -249,7 +387,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
         self.final_layer = nn.ModuleList(
             [
                 MaxNormLinear(
-                    in_features=self.F2 * self.n_windows,
+                    in_features=self.tcn_n_filters * self.n_windows,
                     out_features=self.n_outputs,
                     max_norm_val=self.max_norm_const,
                 )
@@ -259,7 +397,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
         self.final_layer = nn.ModuleList(
             [
                 MaxNormLinear(
-                    in_features=self.F2,
+                    in_features=self.tcn_n_filters,
                     out_features=self.n_outputs,
                     max_norm_val=self.max_norm_const,
                 )
@@ -269,7 +407,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
 
         self.out_fun = nn.Identity()
 
-    def forward(self, X):
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
         # Dimension: (batch_size, C, T)
         X = self.ensuredims(X)
         # Dimension: (batch_size, C, T, 1)
@@ -556,7 +694,8 @@ class _TCNResidualBlock(nn.Module):
         # Reshape the input for the residual connection when necessary
         if in_channels != n_filters:
             self.reshaping_conv = nn.Conv1d(
-                n_filters,
+                in_channels=in_channels,  # Specify input channels
+                out_channels=n_filters,  # Specify output channels
                 kernel_size=1,
                 padding="same",
             )
@@ -576,7 +715,7 @@ class _TCNResidualBlock(nn.Module):
         out = self.activation(out)
         out = self.drop2(out)
 
-        out = self.reshaping_conv(out)
+        X = self.reshaping_conv(X)
 
         # ----- Residual connection -----
         out = X + out
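
The last two hunks fix the residual path of _TCNResidualBlock: the 1x1 conv was declared with swapped channel arguments and applied to out (already n_filters wide) instead of to the input X, so X + out could not line up whenever in_channels != n_filters. A self-contained sketch of the corrected wiring (simplified to a single non-causal conv, without dilation or dropout):

    import torch
    from torch import nn

    class ResidualBlock1d(nn.Module):
        def __init__(self, in_channels: int, n_filters: int, kernel_size: int = 4):
            super().__init__()
            self.conv = nn.Conv1d(in_channels, n_filters, kernel_size, padding="same")
            self.activation = nn.ELU()
            # Project the residual path only when channel counts differ.
            self.reshape = (
                nn.Conv1d(in_channels, n_filters, kernel_size=1, padding="same")
                if in_channels != n_filters
                else nn.Identity()
            )

        def forward(self, X: torch.Tensor) -> torch.Tensor:
            out = self.activation(self.conv(X))
            return self.reshape(X) + out  # shapes agree: (B, n_filters, T)

The first TCN block consumes F2 channels while later blocks consume tcn_n_filters, which is exactly what the `in_channels=self.F2 if i == 0 else self.tcn_n_filters` change encodes; the final-layer hunks adjust MaxNormLinear's in_features to match.
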