braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,830 @@
1
+ # Authors: Cedric Rommel <cedric.rommel@inria.fr>
2
+ #
3
+ # License: BSD (3-clause)
4
+ import math
5
+
6
+ import torch
7
+ from einops.layers.torch import Rearrange
8
+ from mne.utils import warn
9
+ from torch import nn
10
+
11
+ from braindecode.models.base import EEGModuleMixin
12
+ from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear
13
+
14
+
15
+ class ATCNet(EEGModuleMixin, nn.Module):
16
+ r"""ATCNet from Altaheri et al (2022) [1]_.
17
+
18
+ :bdg-success:`Convolution` :bdg-secondary:`Recurrent` :bdg-info:`Attention/Transformer`
19
+
20
+ .. figure:: https://user-images.githubusercontent.com/25565236/185449791-e8539453-d4fa-41e1-865a-2cf7e91f60ef.png
21
+ :align: center
22
+ :alt: ATCNet Architecture
23
+ :width: 650px
24
+
25
+ .. rubric:: Architectural Overview
26
+
27
+ ATCNet is a *convolution-first* architecture augmented with a *lightweight attention–TCN*
28
+ sequence module. The end-to-end flow is:
29
+
30
+ - (i) :class:`_ConvBlock` learns temporal filter-banks and spatial projections (EEGNet-style),
31
+ downsampling time to a compact feature map;
32
+
33
+ - (ii) Sliding Windows carve overlapping temporal windows from this map;
34
+
35
+ - (iii) for each window, :class:`_AttentionBlock` applies small multi-head self-attention
36
+ over time, followed by a :class:`_TCNResidualBlock` stack (causal, dilated);
37
+
38
+ - (iv) window-level features are aggregated (mean of window logits or concatenation)
39
+ and mapped via a max-norm–constrained linear layer.
40
+
41
+ Relative to ViT, ATCNet replaces linear patch projection with learned *temporal–spatial*
42
+ convolutions; it processes *parallel* window encoders (attention→TCN) instead of a deep
43
+ stack; and swaps the MLP head for a TCN suited to 1-D EEG sequences.
44
+
45
+ .. rubric:: Macro Components
46
+
47
+ - :class:`_ConvBlock` **(Shallow conv stem → feature map)**
48
+
49
+ - *Operations.*
50
+ - **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_t, 1)`` builds a
51
+ FIR-like filter bank (``F1`` maps).
52
+ - **Depthwise spatial conv** (:class:`torch.nn.Conv2d`, ``groups=F1``) with kernel
53
+ ``(1, n_chans)`` learns per-filter spatial projections (akin to EEGNet's CSP-like step).
54
+ - **BN → ELU → AvgPool → Dropout** to stabilize and condense activations.
55
+ - **Refining temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_r, 1)`` +
56
+ **BN → ELU → AvgPool → Dropout**.
57
+
58
+ The output shape is ``(B, F2, T_c, 1)`` with ``F2 = F1·D`` and ``T_c = T/(P1·P2)``.
59
+ Temporal kernels behave as FIR filters; the depthwise-spatial conv yields frequency-specific
60
+ topographies. Pooling acts as a local integrator, reducing variance and imposing a
61
+ useful inductive bias on short EEG windows.
62
+
63
+ - **Sliding-Window Sequencer**
64
+
65
+ From the condensed time axis (length ``T_c``), ATCNet forms ``n`` overlapping windows
66
+ of width ``T_w = T_c - n + 1`` (one start per index). Each window produces a sequence
67
+ ``(B, F2, T_w)`` forwarded to its own attention-TCN branch. This creates *parallel*
68
+ encoders over shifted contexts and is key to robustness on nonstationary EEG.
69
+
70
+ - :class:`_AttentionBlock` **(small MHA on temporal positions)**
71
+
72
+ Attention here is *local to a window* and purely temporal.
73
+
74
+ - *Operations.*
75
+ - Rearrange to ``(B, T_w, F2)``,
76
+ - Normalization :class:`torch.nn.LayerNorm`
77
+ - Custom MultiHeadAttention :class:`_MHA` (``num_heads=H``, per-head dim ``d_h``) + residual add,
78
+ - Dropout :class:`torch.nn.Dropout`
79
+ - Rearrange back to ``(B, F2, T_w)``.
80
+
81
+ *Role.* Re-weights evidence across the window, letting the model emphasize informative
82
+ segments (onsets, bursts) before causal convolutions aggregate history.
83
+
84
+ - :class:`_TCNResidualBlock` **(causal dilated temporal CNN)**
85
+
86
+ *Operations:*
87
+
88
+ - Two :class:`braindecode.modules.CausalConv1d` layers per block with dilation ``1, 2, 4, …``
89
+ - Each convolution is followed by :class:`torch.nn.BatchNorm1d`, :class:`torch.nn.ELU` and :class:`torch.nn.Dropout`, plus
90
+ a residual connection (identity, or a 1x1 convolution when channel counts differ).
91
+ - The final feature used per window is the *last* causal step ``[..., -1]`` (forecast-style).
92
+
93
+ *Role.* Efficient long-range temporal integration with stable gradients; the dilated
94
+ receptive field complements attention's soft selection.
95
+
96
+ - **Aggregation & Classifier**
97
+
98
+ *Operations:*
99
+
100
+ - Either (a) map each window feature ``(B, F2)`` to logits via :class:`braindecode.modules.MaxNormLinear`
101
+ and **average** across windows (default, matching official code), or
102
+ - (b) **concatenate** all window features ``(B, n·F2)`` and apply a single :class:`MaxNormLinear`.
103
+
104
+ The max-norm constraint regularizes the readout.
105
+
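+ As a concrete shape walk-through with the BCI-IV 2a defaults
+ (``T = 1125``, ``F1 = 16``, ``D = 2``, ``P1 = 8``, ``P2 = 7``, ``n = 5``):
+ the stem outputs ``F2 = 16·2 = 32`` feature maps of condensed length
+ ``T_c = 1125 // (8·7) = 20``, and each of the 5 sliding windows has width
+ ``T_w = 20 - 5 + 1 = 16``.
+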
106
+ .. rubric:: Convolutional Details
107
+
108
+ - **Temporal.** Temporal structure is learned in three places:
109
+ - (1) the stem's wide ``(L_t, 1)`` conv (learned filter bank),
110
+ - (2) the refining ``(L_r, 1)`` conv after pooling (short-term dynamics), and
111
+ - (3) the TCN's causal 1-D convolutions with exponentially increasing dilation
112
+ (long-range dependencies). The minimum sequence length required by the TCN stack is
113
+ ``(K_t - 1)·2^{L-1} + 1``; the implementation *auto-scales* kernels/pools/windows
114
+ when inputs are shorter to preserve feasibility.
115
+
116
+ - **Spatial.** A depthwise spatial conv spans the **full montage** (kernel ``(1, n_chans)``),
117
+ producing *per-temporal-filter* spatial projections (no cross-filter mixing at this step).
118
+ This mirrors EEGNet's interpretability: each temporal filter has its own spatial pattern.
119
+
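+ With the defaults (``K_t = 4``, ``L = 2``, ``P1 = 8``, ``P2 = 7``, ``n = 5``), this
+ minimum works out to ``(4 - 1)·2^{1} + 1 = 7`` condensed samples per window,
+ hence ``5 + 7 - 1 = 11`` condensed samples overall and ``11·8·7 = 616`` raw
+ time points; shorter inputs trigger the rescaling described above.
+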
120
+ .. rubric:: Attention / Sequential Modules
121
+
122
+ - **Type.** Multi-head self-attention with ``H`` heads and per-head dim ``d_h`` implemented
123
+ in :class:`_MHA`, allowing ``embed_dim = H·d_h`` independent of input and output dims.
124
+ - **Shapes.** ``(B, F2, T_w) → (B, T_w, F2) → (B, F2, T_w)``. Attention operates along
125
+ the **temporal** axis within a window; channels/features stay in the embedding dim ``F2``.
126
+ - **Role.** Highlights salient temporal positions prior to causal convolution; small attention
127
+ keeps compute modest while improving context modeling over pooled features.
128
+
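+ A shape sketch of a single window branch (illustrative only; the
+ ``attention_blocks`` and ``temporal_conv_nets`` attributes are internal
+ submodules rather than public API):
+
+ .. code-block:: python
+
+     import torch
+     from braindecode.models import ATCNet
+
+     net = ATCNet(n_chans=22, n_outputs=4, n_times=1125)
+     window = torch.randn(8, 32, 16)                  # (B, F2, T_w)
+     att = net.attention_blocks[0](window)            # attention keeps the shape
+     feat = net.temporal_conv_nets[0](att)[..., -1]   # last causal step: (8, 32)
+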
129
+ .. rubric:: Additional Mechanisms
130
+
131
+ - **Parallel encoders over shifted windows.** Improves montage/phase robustness by
132
+ ensembling nearby contexts rather than committing to a single segmentation.
133
+ - **Max-norm classifier.** Enforces weight norm constraints at the readout, a common
134
+ stabilization trick in EEG decoding.
135
+ - **ViT vs. ATCNet (design choices).** Convolutional *nonlinear* projection rather than
136
+ linear patchification; attention followed by **TCN** (not MLP); *parallel* window
137
+ encoders rather than stacked encoders.
138
+
139
+ .. rubric:: Usage and Configuration
140
+
141
+ - ``conv_block_n_filters (F1)``, ``conv_block_depth_mult (D)`` → capacity of the stem
142
+ (with ``F2 = F1·D`` feeding attention/TCN), dimensions aligned to ``F2``, like :class:`EEGNet`.
143
+ - Pool sizes ``P1,P2`` trade temporal resolution for stability/compute; they set
144
+ ``T_c = T/(P1·P2)`` and thus window width ``T_w``.
145
+ - ``n_windows`` controls the ensemble over shifts (compute ∝ windows).
146
+ - ``num_heads``, ``head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
147
+ - ``tcn_depth``, ``tcn_kernel_size`` govern receptive field; larger values demand
148
+ longer inputs (see minimum length above). The implementation warns and *rescales*
149
+ kernels/pools/windows if inputs are too short.
150
+ - **Aggregation choice.** ``concat=False`` (default, average of per-window logits) matches
151
+ the official code; ``concat=True`` mirrors the paper's concatenation variant.
152
+
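+ A minimal usage sketch (shapes assume the BCI-IV 2a defaults: 22 channels,
+ 4 classes, 4.5 s windows at 250 Hz, i.e. 1125 time points per input):
+
+ .. code-block:: python
+
+     import torch
+     from braindecode.models import ATCNet
+
+     model = ATCNet(n_chans=22, n_outputs=4, input_window_seconds=4.5, sfreq=250)
+     x = torch.randn(8, 22, 1125)  # dummy EEG: (batch, n_chans, n_times)
+     logits = model(x)             # -> (8, 4), one score per class
+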
153
+ Parameters
154
+ ----------
155
+ input_window_seconds : float, optional
156
+ Time length of inputs, in seconds. Defaults to 4.5 s, as in the
157
+ BCI-IV 2a dataset.
158
+ sfreq : int, optional
159
+ Sampling frequency of the inputs, in Hz. Defaults to 250 Hz, as in the
160
+ BCI-IV 2a dataset.
161
+ conv_block_n_filters : int
162
+ Number of temporal filters in the first convolutional layer of the
163
+ convolutional block, denoted F1 in figure 2 of the paper [1]_. Defaults
164
+ to 16 as in [1]_.
165
+ conv_block_kernel_length_1 : int
166
+ Length of temporal filters in the first convolutional layer of the
167
+ convolutional block, denoted Kc in table 1 of the paper [1]_. Defaults
168
+ to 64 as in [1]_.
169
+ conv_block_kernel_length_2 : int
170
+ Length of temporal filters in the last convolutional layer of the
171
+ convolutional block. Defaults to 16 as in [1]_.
172
+ conv_block_pool_size_1 : int
173
+ Length of first average pooling kernel in the convolutional block.
174
+ Defaults to 8 as in [1]_.
175
+ conv_block_pool_size_2 : int
176
+ Length of the second average pooling kernel in the convolutional block,
177
+ denoted P2 in table 1 of the paper [1]_. Defaults to 7 as in [1]_.
178
+ conv_block_depth_mult : int
179
+ Depth multiplier of depthwise convolution in the convolutional block,
180
+ denoted D in table 1 of the paper [1]_. Defaults to 2 as in [1]_.
181
+ conv_block_dropout : float
182
+ Dropout probability used in the convolution block, denoted pc in
183
+ table 1 of the paper [1]_. Defaults to 0.3 as in [1]_.
184
+ n_windows : int
185
+ Number of sliding windows, denoted n in [1]_. Defaults to 5 as in [1]_.
186
+ head_dim : int
187
+ Embedding dimension used in each self-attention head, denoted dh in
188
+ table 1 of the paper [1]_. Defaults to 8 as in [1]_.
189
+ num_heads : int
190
+ Number of attention heads, denoted H in table 1 of the paper [1]_.
191
+ Defaults to 2 as in [1]_.
192
+ att_drop_prob : float
193
+ Dropout probability used in the attention block, denoted pa in table 1
194
+ of the paper [1]_. Defaults to 0.5 as in [1]_.
195
+ tcn_depth : int
196
+ Depth of Temporal Convolutional Network block (i.e. number of TCN
197
+ Residual blocks), denoted L in table 1 of the paper [1]_. Defaults to 2
198
+ as in [1]_.
199
+ tcn_kernel_size : int
200
+ Temporal kernel size used in TCN block, denoted Kt in table 1 of the
201
+ paper [1]_. Defaults to 4 as in [1]_.
202
+ tcn_drop_prob : float
203
+ Dropout probability used in the TCN block, denoted pt in table 1
204
+ of the paper [1]_. Defaults to 0.3 as in [1]_.
205
+ tcn_activation : type[torch.nn.Module]
206
+ Nonlinear activation class to use. Defaults to ``torch.nn.ELU``.
207
+ concat : bool
208
+ When ``True``, concatenates each sliding window embedding before
209
+ feeding it to a fully-connected layer, as done in [1]_. When ``False``,
210
+ maps each sliding window to `n_outputs` logits and averages them.
211
+ Defaults to ``False`` contrary to what is reported in [1]_, but
212
+ matching what the official code does [2]_.
213
+ max_norm_const : float
214
+ Maximum L2-norm constraint imposed on weights of the last
215
+ fully-connected layer. Defaults to 0.25.
216
+
217
+ Notes
218
+ -----
219
+ - Inputs substantially shorter than the implied minimum length trigger **automatic
220
+ downscaling** of kernels, pools, windows, and TCN kernel size to maintain validity.
221
+ - The attention–TCN sequence operates **per window**; the last causal step is used as the
222
+ window feature, aligning the temporal semantics across windows.
223
+
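+ To illustrate the automatic downscaling: a 1 s window at 250 Hz
+ (``n_times = 250``) falls below the default minimum of 616 samples, so the
+ model warns and shrinks the kernel, pooling, window, and TCN kernel sizes by
+ roughly ``250 / 616 ≈ 0.41``.
+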
224
+ .. versionadded:: 1.1
225
+
226
+ - More detailed documentation of the model.
227
+
228
+ References
229
+ ----------
230
+ .. [1] H. Altaheri, G. Muhammad, M. Alsulaiman (2022).
231
+ *Physics-informed attention temporal convolutional network for EEG-based motor imagery classification.*
232
+ IEEE Transactions on Industrial Informatics. doi:10.1109/TII.2022.3197419.
233
+ .. [2] Official EEG-ATCNet implementation (TensorFlow):
234
+ https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
235
+ """
236
+
237
+ def __init__(
238
+ self,
239
+ n_chans=None,
240
+ n_outputs=None,
241
+ input_window_seconds=None,
242
+ sfreq=250.0,
243
+ conv_block_n_filters=16,
244
+ conv_block_kernel_length_1=64,
245
+ conv_block_kernel_length_2=16,
246
+ conv_block_pool_size_1=8,
247
+ conv_block_pool_size_2=7,
248
+ conv_block_depth_mult=2,
249
+ conv_block_dropout=0.3,
250
+ n_windows=5,
251
+ head_dim=8,
252
+ num_heads=2,
253
+ att_drop_prob=0.5,
254
+ tcn_depth=2,
255
+ tcn_kernel_size=4,
256
+ tcn_drop_prob=0.3,
257
+ tcn_activation: type[nn.Module] = nn.ELU,
258
+ concat=False,
259
+ max_norm_const=0.25,
260
+ chs_info=None,
261
+ n_times=None,
262
+ ):
263
+ super().__init__(
264
+ n_outputs=n_outputs,
265
+ n_chans=n_chans,
266
+ chs_info=chs_info,
267
+ n_times=n_times,
268
+ input_window_seconds=input_window_seconds,
269
+ sfreq=sfreq,
270
+ )
271
+ del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
272
+
273
+ # Validate and adjust parameters based on input size
274
+
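+ # 1. Minimum sequence length required by the TCN receptive field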
275
+ min_len_tcn = (tcn_kernel_size - 1) * (2 ** (tcn_depth - 1)) + 1
276
+ # Minimum length required to get at least one sliding window
277
+ min_len_sliding = n_windows + min_len_tcn - 1
278
+ # Minimum input size that produces the required feature map length
279
+ min_n_times = min_len_sliding * conv_block_pool_size_1 * conv_block_pool_size_2
280
+
281
+ # 2. If the input is shorter, calculate a scaling factor
282
+ if self.n_times < min_n_times:
283
+ scaling_factor = self.n_times / min_n_times
284
+ warn(
285
+ f"n_times ({self.n_times}) is smaller than the minimum required "
286
+ f"({min_n_times}) for the current model parameters configuration. "
287
+ "Adjusting parameters to ensure compatibility."
288
+ "Reducing the kernel, pooling, and stride sizes accordingly."
289
+ "Scaling factor: {:.2f}".format(scaling_factor),
290
+ UserWarning,
291
+ )
292
+ conv_block_kernel_length_1 = max(
293
+ 1, int(conv_block_kernel_length_1 * scaling_factor)
294
+ )
295
+ conv_block_kernel_length_2 = max(
296
+ 1, int(conv_block_kernel_length_2 * scaling_factor)
297
+ )
298
+ conv_block_pool_size_1 = max(
299
+ 1, int(conv_block_pool_size_1 * scaling_factor)
300
+ )
301
+ conv_block_pool_size_2 = max(
302
+ 1, int(conv_block_pool_size_2 * scaling_factor)
303
+ )
304
+
305
+ # n_windows should be at least 1
306
+ n_windows = max(1, int(n_windows * scaling_factor))
307
+
308
+ # tcn_kernel_size must be at least 2 for dilation to work
309
+ tcn_kernel_size = max(2, int(tcn_kernel_size * scaling_factor))
310
+
311
+ self.conv_block_n_filters = conv_block_n_filters
312
+ self.conv_block_kernel_length_1 = conv_block_kernel_length_1
313
+ self.conv_block_kernel_length_2 = conv_block_kernel_length_2
314
+ self.conv_block_pool_size_1 = conv_block_pool_size_1
315
+ self.conv_block_pool_size_2 = conv_block_pool_size_2
316
+ self.conv_block_depth_mult = conv_block_depth_mult
317
+ self.conv_block_dropout = conv_block_dropout
318
+ self.n_windows = n_windows
319
+ self.head_dim = head_dim
320
+ self.num_heads = num_heads
321
+ self.att_dropout = att_drop_prob
322
+ self.tcn_depth = tcn_depth
323
+ self.tcn_kernel_size = tcn_kernel_size
324
+ self.tcn_dropout = tcn_drop_prob
325
+ self.tcn_activation = tcn_activation
326
+ self.concat = concat
327
+ self.max_norm_const = max_norm_const
328
+ self.tcn_n_filters = int(self.conv_block_depth_mult * self.conv_block_n_filters)
329
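+ # Parameter-name mapping from the legacy "max_norm_linears" layer
+ # names to the current "final_layer" names.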
+ mapping = dict()
330
+ for w in range(self.n_windows):
331
+ mapping[f"max_norm_linears.[{w}].weight"] = f"final_layer.[{w}].weight"
332
+ mapping[f"max_norm_linears.[{w}].bias"] = f"final_layer.[{w}].bias"
333
+ self.mapping = mapping
334
+
335
+ # Check later if we want to keep the Ensure4d. Not sure if we can
336
+ # remove it or replace it with einsum.
337
+ self.ensuredims = Ensure4d()
338
+ self.dimshuffle = Rearrange("batch C T 1 -> batch 1 T C")
339
+
340
+ self.conv_block = _ConvBlock(
341
+ n_channels=self.n_chans, # input shape: (batch_size, 1, T, C)
342
+ n_filters=conv_block_n_filters,
343
+ kernel_length_1=conv_block_kernel_length_1,
344
+ kernel_length_2=conv_block_kernel_length_2,
345
+ pool_size_1=conv_block_pool_size_1,
346
+ pool_size_2=conv_block_pool_size_2,
347
+ depth_mult=conv_block_depth_mult,
348
+ dropout=conv_block_dropout,
349
+ )
350
+
351
+ self.F2 = int(conv_block_depth_mult * conv_block_n_filters)
352
+ self.Tc = int(self.n_times / (conv_block_pool_size_1 * conv_block_pool_size_2))
353
+ self.Tw = self.Tc - self.n_windows + 1
354
+
355
+ self.attention_blocks = nn.ModuleList(
356
+ [
357
+ _AttentionBlock(
358
+ in_shape=self.F2,
359
+ head_dim=self.head_dim,
360
+ num_heads=num_heads,
361
+ dropout=att_drop_prob,
362
+ )
363
+ for _ in range(self.n_windows)
364
+ ]
365
+ )
366
+
367
+ self.temporal_conv_nets = nn.ModuleList(
368
+ [
369
+ nn.Sequential(
370
+ *[
371
+ _TCNResidualBlock(
372
+ in_channels=self.F2 if i == 0 else self.tcn_n_filters,
373
+ kernel_size=self.tcn_kernel_size,
374
+ n_filters=self.tcn_n_filters,
375
+ dropout=self.tcn_dropout,
376
+ activation=self.tcn_activation,
377
+ dilation=2**i,
378
+ )
379
+ for i in range(self.tcn_depth)
380
+ ]
381
+ )
382
+ for _ in range(self.n_windows)
383
+ ]
384
+ )
385
+
386
+ if self.concat:
387
+ self.final_layer = nn.ModuleList(
388
+ [
389
+ MaxNormLinear(
390
+ in_features=self.tcn_n_filters * self.n_windows,
391
+ out_features=self.n_outputs,
392
+ max_norm_val=self.max_norm_const,
393
+ )
394
+ ]
395
+ )
396
+ else:
397
+ self.final_layer = nn.ModuleList(
398
+ [
399
+ MaxNormLinear(
400
+ in_features=self.tcn_n_filters,
401
+ out_features=self.n_outputs,
402
+ max_norm_val=self.max_norm_const,
403
+ )
404
+ for _ in range(self.n_windows)
405
+ ]
406
+ )
407
+
408
+ self.out_fun = nn.Identity()
409
+
410
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
411
+ # Dimension: (batch_size, C, T)
412
+ X = self.ensuredims(X)
413
+ # Dimension: (batch_size, C, T, 1)
414
+ X = self.dimshuffle(X)
415
+ # Dimension: (batch_size, 1, T, C)
416
+
417
+ # ----- Convolutional block -----
418
+ conv_feat = self.conv_block(X)
419
+ # Dimension: (batch_size, F2, Tc, 1)
420
+ conv_feat = conv_feat.view(-1, self.F2, self.Tc)
421
+ # Dimension: (batch_size, F2, Tc)
422
+
423
+ # ----- Sliding window -----
424
+ sw_concat: list[torch.Tensor] = [] # to store sliding window outputs
425
+ # for w in range(self.n_windows):
426
+ for idx, (attention, tcn_module, final_layer) in enumerate(
427
+ zip(self.attention_blocks, self.temporal_conv_nets, self.final_layer)
428
+ ):
429
+ conv_feat_w = conv_feat[..., idx : idx + self.Tw]
430
+ # Dimension: (batch_size, F2, Tw)
431
+
432
+ # ----- Attention block -----
433
+ att_feat = attention(conv_feat_w)
434
+ # Dimension: (batch_size, F2, Tw)
435
+
436
+ # ----- Temporal convolutional network (TCN) -----
437
+ tcn_feat = tcn_module(att_feat)[..., -1]
438
+ # Dimension: (batch_size, F2)
439
+
440
+ # Outputs of sliding window can be either averaged after being
441
+ # mapped by dense layer or concatenated then mapped by a dense
442
+ # layer
443
+ if not self.concat:
444
+ tcn_feat = final_layer(tcn_feat)
445
+
446
+ sw_concat.append(tcn_feat)
447
+
448
+ # ----- Aggregation and prediction -----
449
+ if self.concat:
450
+ sw_concat_agg = torch.cat(sw_concat, dim=1)
451
+ sw_concat_agg = self.final_layer[0](sw_concat_agg)
452
+ else:
453
+ if len(sw_concat) > 1: # more than one window
454
+ sw_concat_agg = torch.stack(sw_concat, dim=0)
455
+ sw_concat_agg = torch.mean(sw_concat_agg, dim=0)
456
+ else: # one window (# windows = 1)
457
+ sw_concat_agg = sw_concat[0]
458
+
459
+ return self.out_fun(sw_concat_agg)
460
+
461
+
462
+ class _ConvBlock(nn.Module):
463
+ r"""Convolutional block proposed in ATCNet [1]_, inspired by the EEGNet.
464
+
465
+ architecture [2]_.
466
+
467
+ References
468
+ ----------
469
+ .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed
470
+ attention temporal convolutional network for EEG-based motor imagery
471
+ classification," in IEEE Transactions on Industrial Informatics,
472
+ 2022, doi: 10.1109/TII.2022.3197419.
473
+ .. [2] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon,
474
+ S. M., Hung, C. P., & Lance, B. J. (2018).
475
+ EEGNet: A Compact Convolutional Network for EEG-based
476
+ Brain-Computer Interfaces.
477
+ arXiv preprint arXiv:1611.08024.
478
+ """
479
+
480
+ def __init__(
481
+ self,
482
+ n_channels,
483
+ n_filters=16,
484
+ kernel_length_1=64,
485
+ kernel_length_2=16,
486
+ pool_size_1=8,
487
+ pool_size_2=7,
488
+ depth_mult=2,
489
+ dropout=0.3,
490
+ ):
491
+ super().__init__()
492
+
493
+ self.conv1 = nn.Conv2d(
494
+ in_channels=1,
495
+ out_channels=n_filters,
496
+ kernel_size=(kernel_length_1, 1),
497
+ padding="same",
498
+ bias=False,
499
+ )
500
+
501
+ self.bn1 = nn.BatchNorm2d(num_features=n_filters, eps=1e-4)
502
+
503
+ n_depth_kernels = n_filters * depth_mult
504
+ self.conv2 = nn.Conv2d(
505
+ in_channels=n_filters,
506
+ out_channels=n_depth_kernels,
507
+ groups=n_filters,
508
+ kernel_size=(1, n_channels),
509
+ padding="valid",
510
+ bias=False,
511
+ )
512
+
513
+ self.bn2 = nn.BatchNorm2d(num_features=n_depth_kernels, eps=1e-4)
514
+
515
+ self.activation2 = nn.ELU()
516
+
517
+ self.pool2 = nn.AvgPool2d(kernel_size=(pool_size_1, 1))
518
+
519
+ self.drop2 = nn.Dropout2d(dropout)
520
+
521
+ self.conv3 = nn.Conv2d(
522
+ in_channels=n_depth_kernels,
523
+ out_channels=n_depth_kernels,
524
+ kernel_size=(kernel_length_2, 1),
525
+ padding="same",
526
+ bias=False,
527
+ )
528
+
529
+ self.bn3 = nn.BatchNorm2d(num_features=n_depth_kernels, eps=1e-4)
530
+
531
+ self.activation3 = nn.ELU()
532
+
533
+ self.pool3 = nn.AvgPool2d(kernel_size=(pool_size_2, 1))
534
+
535
+ self.drop3 = nn.Dropout2d(dropout)
536
+
537
+ def forward(self, X):
538
+ # ----- Temporal convolution -----
539
+ # Dimension: (batch_size, 1, T, C)
540
+ X = self.conv1(X)
541
+ X = self.bn1(X)
542
+ # Dimension: (batch_size, F1, T, C)
543
+
544
+ # ----- Depthwise channels convolution -----
545
+ X = self.conv2(X)
546
+ X = self.bn2(X)
547
+ X = self.activation2(X)
548
+ # Dimension: (batch_size, F1*D, T, 1)
549
+ X = self.pool2(X)
550
+ X = self.drop2(X)
551
+ # Dimension: (batch_size, F1*D, T/P1, 1)
552
+
553
+ # ----- "Spatial" convolution -----
554
+ X = self.conv3(X)
555
+ X = self.bn3(X)
556
+ X = self.activation3(X)
557
+ # Dimension: (batch_size, F1*D, T/P1, 1)
558
+ X = self.pool3(X)
559
+ X = self.drop3(X)
560
+ # Dimension: (batch_size, F1*D, T/(P1*P2), 1)
561
+
562
+ return X
563
+
564
+
565
+ class _AttentionBlock(nn.Module):
566
+ r"""Multi Head self Attention (MHA) block used in ATCNet [1]_, inspired from.
567
+
568
+ [2]_.
569
+
570
+ References
571
+ ----------
572
+ .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed
573
+ attention temporal convolutional network for EEG-based motor imagery
574
+ classification," in IEEE Transactions on Industrial Informatics,
575
+ 2022, doi: 10.1109/TII.2022.3197419.
576
+ .. [2] Vaswani, A. et al., "Attention is all you need",
577
+ in Advances in neural information processing systems, 2017.
578
+ """
579
+
580
+ def __init__(
581
+ self,
582
+ in_shape=32,
583
+ head_dim=8,
584
+ num_heads=2,
585
+ dropout=0.5,
586
+ ):
587
+ super().__init__()
588
+ self.in_shape = in_shape
589
+ self.head_dim = head_dim
590
+ self.num_heads = num_heads
591
+
592
+ # Puts time dimension at -2 and feature dim at -1
593
+ self.dimshuffle = Rearrange("batch C T -> batch T C")
594
+
595
+ # Layer normalization
596
+ self.ln = nn.LayerNorm(normalized_shape=in_shape, eps=1e-6)
597
+
598
+ # Multi-head self-attention layer
599
+ # (We had to reimplement it since the original code is in tensorflow,
600
+ # where it is possible to have an embedding dimension different than
601
+ # the input and output dimensions, which is not possible in pytorch.)
602
+ self.mha = _MHA(
603
+ input_dim=in_shape,
604
+ head_dim=head_dim,
605
+ output_dim=in_shape,
606
+ num_heads=num_heads,
607
+ dropout=dropout,
608
+ )
609
+
610
+ # XXX: This line in the official code is weird, as there is already
611
+ # dropout in the MultiheadAttention layer. They also don't mention
612
+ # any additional dropout between the attention block and TCN in the
613
+ # paper. We nevertheless keep it here to follow the official
614
+ # implementation.
615
+ self.drop = nn.Dropout(0.3)
616
+
617
+ def forward(self, X):
618
+ # Dimension: (batch_size, F2, Tw)
619
+ X = self.dimshuffle(X)
620
+ # Dimension: (batch_size, Tw, F2)
621
+
622
+ # ----- Layer norm -----
623
+ out = self.ln(X)
624
+
625
+ # ----- Self-Attention -----
626
+ out = self.mha(out, out, out)
627
+ # Dimension: (batch_size, Tw, F2)
628
+
629
+ # XXX In the paper fig. 1, it is drawn that layer normalization is
630
+ # performed before the skip connection, while it is done afterwards
631
+ # in the official code. Here we follow the code.
632
+
633
+ # ----- Skip connection -----
634
+ out = X + self.drop(out)
635
+
636
+ # Move back to shape (batch_size, F2, Tw) from the beginning
637
+ return self.dimshuffle(out)
638
+
639
+
640
+ class _TCNResidualBlock(nn.Module):
641
+ r"""Modified TCN Residual block as proposed in [1]_.
642
+
643
+ Inspired by
644
+ Temporal Convolutional Networks (TCN) [2]_.
645
+
646
+ References
647
+ ----------
648
+ .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed
649
+ attention temporal convolutional network for EEG-based motor imagery
650
+ classification," in IEEE Transactions on Industrial Informatics,
651
+ 2022, doi: 10.1109/TII.2022.3197419.
652
+ .. [2] Bai, S., Kolter, J. Z., & Koltun, V.
653
+ "An empirical evaluation of generic convolutional and recurrent
654
+ networks for sequence modeling", 2018.
655
+ """
656
+
657
+ def __init__(
658
+ self,
659
+ in_channels,
660
+ kernel_size=4,
661
+ n_filters=32,
662
+ dropout=0.3,
663
+ activation: type[nn.Module] = nn.ELU,
664
+ dilation=1,
665
+ ):
666
+ super().__init__()
667
+ self.activation = activation()
668
+ self.dilation = dilation
669
+ self.dropout = dropout
670
+ self.n_filters = n_filters
671
+ self.kernel_size = kernel_size
672
+ self.in_channels = in_channels
673
+
674
+ self.conv1 = CausalConv1d(
675
+ in_channels=in_channels,
676
+ out_channels=n_filters,
677
+ kernel_size=kernel_size,
678
+ dilation=dilation,
679
+ )
680
+ nn.init.kaiming_uniform_(self.conv1.weight)
681
+
682
+ self.bn1 = nn.BatchNorm1d(n_filters)
683
+
684
+ self.drop1 = nn.Dropout(dropout)
685
+
686
+ self.conv2 = CausalConv1d(
687
+ in_channels=n_filters,
688
+ out_channels=n_filters,
689
+ kernel_size=kernel_size,
690
+ dilation=dilation,
691
+ )
692
+ nn.init.kaiming_uniform_(self.conv2.weight)
693
+
694
+ self.bn2 = nn.BatchNorm1d(n_filters)
695
+
696
+ self.drop2 = nn.Dropout(dropout)
697
+
698
+ # Reshape the input for the residual connection when necessary
699
+ if in_channels != n_filters:
700
+ self.reshaping_conv = nn.Conv1d(
701
+ in_channels=in_channels, # Specify input channels
702
+ out_channels=n_filters, # Specify output channels
703
+ kernel_size=1,
704
+ padding="same",
705
+ )
706
+ else:
707
+ self.reshaping_conv = nn.Identity()
708
+
709
+ def forward(self, X):
710
+ # Dimension: (batch_size, F2, Tw)
711
+ # ----- Double dilated convolutions -----
712
+ out = self.conv1(X)
713
+ out = self.bn1(out)
714
+ out = self.activation(out)
715
+ out = self.drop1(out)
716
+
717
+ out = self.conv2(out)
718
+ out = self.bn2(out)
719
+ out = self.activation(out)
720
+ out = self.drop2(out)
721
+
722
+ X = self.reshaping_conv(X)
723
+
724
+ # ----- Residual connection -----
725
+ out = X + out
726
+
727
+ return self.activation(out)
728
+
729
+
730
+ class _MHA(nn.Module):
731
+ def __init__(
732
+ self,
733
+ input_dim: int,
734
+ head_dim: int,
735
+ output_dim: int,
736
+ num_heads: int,
737
+ dropout: float = 0.0,
738
+ ):
739
+ """Multi-head Attention.
740
+
741
+ The difference between this module and torch.nn.MultiheadAttention is
742
+ that this module supports an embedding dimension different from the
743
+ input and output ones. It also does not support sequences of
744
+ different lengths.
745
+
746
+ Parameters
747
+ ----------
748
+ input_dim : int
749
+ Dimension of query, key and value inputs.
750
+ head_dim : int
751
+ Dimension of the embedded query, key and value in each head,
752
+ before computing attention.
753
+ output_dim : int
754
+ Output dimension.
755
+ num_heads : int
756
+ Number of heads in the multi-head architecture.
757
+ dropout : float, optional
758
+ Dropout probability applied to the output. Default: 0.0 (no dropout).
759
+ """
760
+
761
+ super().__init__()
762
+
763
+ self.input_dim = input_dim
764
+ self.head_dim = head_dim
765
+ # typical choice for the split dimension of the heads
766
+ self.embed_dim = head_dim * num_heads
767
+
768
+ # embeddings for multi-head projections
769
+ self.fc_q = nn.Linear(input_dim, self.embed_dim)
770
+ self.fc_k = nn.Linear(input_dim, self.embed_dim)
771
+ self.fc_v = nn.Linear(input_dim, self.embed_dim)
772
+
773
+ # output mapping
774
+ self.fc_o = nn.Linear(self.embed_dim, output_dim)
775
+
776
+ # dropout
777
+ self.dropout = nn.Dropout(dropout)
778
+
779
+ def forward(
780
+ self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
781
+ ) -> torch.Tensor:
782
+ """Compute MHA(Q, K, V).
783
+
784
+ Parameters
785
+ ----------
786
+ Q : torch.Tensor of size (batch_size, seq_len, input_dim)
787
+ Input query (Q) sequence.
788
+ K : torch.Tensor of size (batch_size, seq_len, input_dim)
789
+ Input key (K) sequence.
790
+ V : torch.Tensor of size (batch_size, seq_len, input_dim)
791
+ Input value (V) sequence.
792
+
793
+ Returns
794
+ -------
795
+ O : torch.Tensor of size (batch_size, seq_len, output_dim)
796
+ Output MHA(Q, K, V)
797
+ """
798
+ assert Q.shape[-1] == K.shape[-1] == V.shape[-1] == self.input_dim
799
+
800
+ batch_size, _, _ = Q.shape
801
+
802
+ # embedding for multi-head projections
803
+ Q = self.fc_q(Q) # (B, S, D)
804
+ K, V = self.fc_k(K), self.fc_v(V) # (B, S, D)
805
+
806
+ # Split into num_head vectors (num_heads * batch_size, n/m, head_dim)
807
+ Q_ = torch.cat(Q.split(self.head_dim, -1), 0) # (B', S, D')
808
+ K_ = torch.cat(K.split(self.head_dim, -1), 0) # (B', S, D')
809
+ V_ = torch.cat(V.split(self.head_dim, -1), 0) # (B', S, D')
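+ # (each head becomes an extra batch element, so a single bmm below
+ # computes the attention of all heads at once)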
810
+
811
+ # Attention weights of size (num_heads * batch_size, n, m):
812
+ # measures how similar each pair of Q and K is.
813
+ W = torch.softmax(
814
+ Q_.bmm(K_.transpose(-2, -1)) / math.sqrt(self.head_dim),
815
+ -1, # (B', D', S)
816
+ ) # (B', N, M)
817
+
818
+ # Multihead output (batch_size, seq_len, dim):
819
+ # weighted sum of V where a value gets more weight if its corresponding
820
+ # key has larger dot product with the query.
821
+ H = torch.cat(
822
+ (W.bmm(V_)).split( # (B', S, S) # (B', S, D')
823
+ batch_size, 0
824
+ ), # [(B, S, D')] * num_heads
825
+ -1,
826
+ ) # (B, S, D)
827
+
828
+ out = self.fc_o(H)
829
+
830
+ return self.dropout(out)