braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,830 @@
|
|
|
1
|
+
# Authors: Cedric Rommel <cedric.rommel@inria.fr>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
import math
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from einops.layers.torch import Rearrange
|
|
8
|
+
from mne.utils import warn
|
|
9
|
+
from torch import nn
|
|
10
|
+
|
|
11
|
+
from braindecode.models.base import EEGModuleMixin
|
|
12
|
+
from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ATCNet(EEGModuleMixin, nn.Module):
    r"""ATCNet from Altaheri et al (2022) [1]_.

    :bdg-success:`Convolution` :bdg-secondary:`Recurrent` :bdg-info:`Attention/Transformer`

    .. figure:: https://user-images.githubusercontent.com/25565236/185449791-e8539453-d4fa-41e1-865a-2cf7e91f60ef.png
        :align: center
        :alt: ATCNet Architecture
        :width: 650px

    .. rubric:: Architectural Overview

    ATCNet is a *convolution-first* architecture augmented with a *lightweight attention–TCN*
    sequence module. The end-to-end flow is:

    - (i) :class:`_ConvBlock` learns temporal filter-banks and spatial projections (EEGNet-style),
      downsampling time to a compact feature map;

    - (ii) Sliding Windows carve overlapping temporal windows from this map;

    - (iii) for each window, :class:`_AttentionBlock` applies small multi-head self-attention
      over time, followed by a :class:`_TCNResidualBlock` stack (causal, dilated);

    - (iv) window-level features are aggregated (mean of window logits or concatenation)
      and mapped via a max-norm–constrained linear layer.

    Relative to ViT, ATCNet replaces linear patch projection with learned *temporal–spatial*
    convolutions; it processes *parallel* window encoders (attention→TCN) instead of a deep
    stack; and swaps the MLP head for a TCN suited to 1-D EEG sequences.

    .. rubric:: Macro Components

    - :class:`_ConvBlock` **(Shallow conv stem → feature map)**

        - *Operations.*
        - **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_t, 1)`` builds a
          FIR-like filter bank (``F1`` maps).
        - **Depthwise spatial conv** (:class:`torch.nn.Conv2d`, ``groups=F1``) with kernel
          ``(1, n_chans)`` learns per-filter spatial projections (akin to EEGNet's CSP-like step).
        - **BN → ELU → AvgPool → Dropout** to stabilize and condense activations.
        - **Refining temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_r, 1)`` +
          **BN → ELU → AvgPool → Dropout**.

      The output shape is ``(B, F2, T_c, 1)`` with ``F2 = F1·D`` and ``T_c = T/(P1·P2)``.
      Temporal kernels behave as FIR filters; the depthwise-spatial conv yields frequency-specific
      topographies. Pooling acts as a local integrator, reducing variance and imposing a
      useful inductive bias on short EEG windows.

    - **Sliding-Window Sequencer**

      From the condensed time axis (length ``T_c``), ATCNet forms ``n`` overlapping windows
      of width ``T_w = T_c - n + 1`` (one start per index). Each window produces a sequence
      ``(B, F2, T_w)`` forwarded to its own attention-TCN branch. This creates *parallel*
      encoders over shifted contexts and is key to robustness on nonstationary EEG.

    - :class:`_AttentionBlock` **(small MHA on temporal positions)**

      Attention here is *local to a window* and purely temporal.

        - *Operations.*
        - Rearrange to ``(B, T_w, F2)``,
        - Normalization :class:`torch.nn.LayerNorm`
        - Custom MultiHeadAttention :class:`_MHA` (``num_heads=H``, per-head dim ``d_h``) + residual add,
        - Dropout :class:`torch.nn.Dropout`
        - Rearrange back to ``(B, F2, T_w)``.

      *Role.* Re-weights evidence across the window, letting the model emphasize informative
      segments (onsets, bursts) before causal convolutions aggregate history.

    - :class:`_TCNResidualBlock` **(causal dilated temporal CNN)**

      *Operations:*

        - Two :class:`braindecode.modules.CausalConv1d` layers per block with dilation ``1, 2, 4, …``
        - Across blocks of (`torch.nn.ELU` + `torch.nn.BatchNorm1d` + `torch.nn.Dropout`) +
          a residual (identity or 1x1 mapping).
        - The final feature used per window is the *last* causal step ``[..., -1]`` (forecast-style).

      *Role.* Efficient long-range temporal integration with stable gradients; the dilated
      receptive field complements attention's soft selection.

    - **Aggregation & Classifier**

      *Operations:*

        - Either (a) map each window feature ``(B, F2)`` to logits via :class:`braindecode.modules.MaxNormLinear`
          and **average** across windows (default, matching official code), or
        - (b) **concatenate** all window features ``(B, n·F2)`` and apply a single :class:`MaxNormLinear`.

      The max-norm constraint regularizes the readout.

    .. rubric:: Convolutional Details

    - **Temporal.** Temporal structure is learned in three places:
        - (1) the stem's wide ``(L_t, 1)`` conv (learned filter bank),
        - (2) the refining ``(L_r, 1)`` conv after pooling (short-term dynamics), and
        - (3) the TCN's causal 1-D convolutions with exponentially increasing dilation
          (long-range dependencies). The minimum sequence length required by the TCN stack is
          ``(K_t - 1)·2^{L-1} + 1``; the implementation *auto-scales* kernels/pools/windows
          when inputs are shorter to preserve feasibility.

    - **Spatial.** A depthwise spatial conv spans the **full montage** (kernel ``(1, n_chans)``),
      producing *per-temporal-filter* spatial projections (no cross-filter mixing at this step).
      This mirrors EEGNet's interpretability: each temporal filter has its own spatial pattern.

    .. rubric:: Attention / Sequential Modules

    - **Type.** Multi-head self-attention with ``H`` heads and per-head dim ``d_h`` implemented
      in :class:`_MHA`, allowing ``embed_dim = H·d_h`` independent of input and output dims.
    - **Shapes.** ``(B, F2, T_w) → (B, T_w, F2) → (B, F2, T_w)``. Attention operates along
      the **temporal** axis within a window; channels/features stay in the embedding dim ``F2``.
    - **Role.** Highlights salient temporal positions prior to causal convolution; small attention
      keeps compute modest while improving context modeling over pooled features.

    .. rubric:: Additional Mechanisms

    - **Parallel encoders over shifted windows.** Improves montage/phase robustness by
      ensembling nearby contexts rather than committing to a single segmentation.
    - **Max-norm classifier.** Enforces weight norm constraints at the readout, a common
      stabilization trick in EEG decoding.
    - **ViT vs. ATCNet (design choices).** Convolutional *nonlinear* projection rather than
      linear patchification; attention followed by **TCN** (not MLP); *parallel* window
      encoders rather than stacked encoders.

    .. rubric:: Usage and Configuration

    - ``conv_block_n_filters (F1)``, ``conv_block_depth_mult (D)`` → capacity of the stem
      (with ``F2 = F1·D`` feeding attention/TCN), dimensions aligned to ``F2``, like :class:`EEGNet`.
    - Pool sizes ``P1,P2`` trade temporal resolution for stability/compute; they set
      ``T_c = T/(P1·P2)`` and thus window width ``T_w``.
    - ``n_windows`` controls the ensemble over shifts (compute ∝ windows).
    - ``num_heads``, ``head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
    - ``tcn_depth``, ``tcn_kernel_size`` govern receptive field; larger values demand
      longer inputs (see minimum length above). The implementation warns and *rescales*
      kernels/pools/windows if inputs are too short.
    - **Aggregation choice.** ``concat=False`` (default, average of per-window logits) matches
      the official code; ``concat=True`` mirrors the paper's concatenation variant.

    Parameters
    ----------
    input_window_seconds : float, optional
        Time length of inputs, in seconds. Defaults to 4.5 s, as in BCI-IV 2a
        dataset.
    sfreq : int, optional
        Sampling frequency of the inputs, in Hz. Default to 250 Hz, as in
        BCI-IV 2a dataset.
    conv_block_n_filters : int
        Number temporal filters in the first convolutional layer of the
        convolutional block, denoted F1 in figure 2 of the paper [1]_. Defaults
        to 16 as in [1]_.
    conv_block_kernel_length_1 : int
        Length of temporal filters in the first convolutional layer of the
        convolutional block, denoted Kc in table 1 of the paper [1]_. Defaults
        to 64 as in [1]_.
    conv_block_kernel_length_2 : int
        Length of temporal filters in the last convolutional layer of the
        convolutional block. Defaults to 16 as in [1]_.
    conv_block_pool_size_1 : int
        Length of first average pooling kernel in the convolutional block.
        Defaults to 8 as in [1]_.
    conv_block_pool_size_2 : int
        Length of second average pooling kernel in the convolutional block,
        denoted P2 in table 1 of the paper [1]_. Defaults to 7 as in [1]_.
    conv_block_depth_mult : int
        Depth multiplier of depthwise convolution in the convolutional block,
        denoted D in table 1 of the paper [1]_. Defaults to 2 as in [1]_.
    conv_block_dropout : float
        Dropout probability used in the convolution block, denoted pc in
        table 1 of the paper [1]_. Defaults to 0.3 as in [1]_.
    n_windows : int
        Number of sliding windows, denoted n in [1]_. Defaults to 5 as in [1]_.
    head_dim : int
        Embedding dimension used in each self-attention head, denoted dh in
        table 1 of the paper [1]_. Defaults to 8 as in [1]_.
    num_heads : int
        Number of attention heads, denoted H in table 1 of the paper [1]_.
        Defaults to 2 as in [1]_.
    att_drop_prob : float
        Dropout probability used in the attention block, denoted pa in table 1
        of the paper [1]_. Defaults to 0.5 as in [1]_.
    tcn_depth : int
        Depth of Temporal Convolutional Network block (i.e. number of TCN
        Residual blocks), denoted L in table 1 of the paper [1]_. Defaults to 2
        as in [1]_.
    tcn_kernel_size : int
        Temporal kernel size used in TCN block, denoted Kt in table 1 of the
        paper [1]_. Defaults to 4 as in [1]_.
    tcn_drop_prob : float
        Dropout probability used in the TCN block, denoted pt in table 1
        of the paper [1]_. Defaults to 0.3 as in [1]_.
    tcn_activation : type[torch.nn.Module]
        Nonlinear activation class to use (the class itself, not an
        instance). Defaults to ``nn.ELU``.
    concat : bool
        When ``True``, concatenates each sliding window embedding before
        feeding it to a fully-connected layer, as done in [1]_. When ``False``,
        maps each sliding window to `n_outputs` logits and average them.
        Defaults to ``False`` contrary to what is reported in [1]_, but
        matching what the official code does [2]_.
    max_norm_const : float
        Maximum L2-norm constraint imposed on weights of the last
        fully-connected layer. Defaults to 0.25.

    Notes
    -----
    - Inputs substantially shorter than the implied minimum length trigger **automatic
      downscaling** of kernels, pools, windows, and TCN kernel size to maintain validity.
    - The attention–TCN sequence operates **per window**; the last causal step is used as the
      window feature, aligning the temporal semantics across windows.

    .. versionadded:: 1.1

    - More detailed documentation of the model.

    References
    ----------
    .. [1] H. Altaheri, G. Muhammad, M. Alsulaiman (2022).
        *Physics-informed attention temporal convolutional network for EEG-based motor imagery classification.*
        IEEE Transactions on Industrial Informatics. doi:10.1109/TII.2022.3197419.
    .. [2] Official EEG-ATCNet implementation (TensorFlow):
        https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        input_window_seconds=None,
        sfreq=250.0,
        conv_block_n_filters=16,
        conv_block_kernel_length_1=64,
        conv_block_kernel_length_2=16,
        conv_block_pool_size_1=8,
        conv_block_pool_size_2=7,
        conv_block_depth_mult=2,
        conv_block_dropout=0.3,
        n_windows=5,
        head_dim=8,
        num_heads=2,
        att_drop_prob=0.5,
        tcn_depth=2,
        tcn_kernel_size=4,
        tcn_drop_prob=0.3,
        tcn_activation: type[nn.Module] = nn.ELU,
        concat=False,
        max_norm_const=0.25,
        chs_info=None,
        n_times=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # From here on, only the values resolved by the mixin (self.n_chans,
        # self.n_times, ...) are used; drop the raw arguments to avoid mixing
        # them up.
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # ----- Validate and adjust parameters based on input size -----
        # 1. Shortest sequence the dilated TCN stack can consume.
        min_len_tcn = (tcn_kernel_size - 1) * (2 ** (tcn_depth - 1)) + 1
        # Minimum feature-map length so every sliding window is long enough
        min_len_sliding = n_windows + min_len_tcn - 1
        # Minimum input size that produces the required feature map length
        min_n_times = min_len_sliding * conv_block_pool_size_1 * conv_block_pool_size_2

        # 2. If the input is shorter, shrink the size hyper-parameters
        #    proportionally instead of failing.
        if self.n_times < min_n_times:
            scaling_factor = self.n_times / min_n_times
            warn(
                f"n_times ({self.n_times}) is smaller than the minimum required "
                f"({min_n_times}) for the current model parameters configuration. "
                "Adjusting parameters to ensure compatibility. "
                "Reducing the kernel, pooling, and stride sizes accordingly. "
                f"Scaling factor: {scaling_factor:.2f}",
                UserWarning,
            )
            conv_block_kernel_length_1 = max(
                1, int(conv_block_kernel_length_1 * scaling_factor)
            )
            conv_block_kernel_length_2 = max(
                1, int(conv_block_kernel_length_2 * scaling_factor)
            )
            conv_block_pool_size_1 = max(
                1, int(conv_block_pool_size_1 * scaling_factor)
            )
            conv_block_pool_size_2 = max(
                1, int(conv_block_pool_size_2 * scaling_factor)
            )

            # n_windows should be at least 1
            n_windows = max(1, int(n_windows * scaling_factor))

            # tcn_kernel_size must be at least 2 for dilation to work
            tcn_kernel_size = max(2, int(tcn_kernel_size * scaling_factor))

        self.conv_block_n_filters = conv_block_n_filters
        self.conv_block_kernel_length_1 = conv_block_kernel_length_1
        self.conv_block_kernel_length_2 = conv_block_kernel_length_2
        self.conv_block_pool_size_1 = conv_block_pool_size_1
        self.conv_block_pool_size_2 = conv_block_pool_size_2
        self.conv_block_depth_mult = conv_block_depth_mult
        self.conv_block_dropout = conv_block_dropout
        self.n_windows = n_windows
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.att_dropout = att_drop_prob
        self.tcn_depth = tcn_depth
        self.tcn_kernel_size = tcn_kernel_size
        self.tcn_dropout = tcn_drop_prob
        self.tcn_activation = tcn_activation
        self.concat = concat
        self.max_norm_const = max_norm_const
        self.tcn_n_filters = int(self.conv_block_depth_mult * self.conv_block_n_filters)

        # Parameter-name mapping from the former ``max_norm_linears`` names to
        # the current ``final_layer`` names (one weight/bias pair per window).
        key_mapping = {}
        for w in range(self.n_windows):
            key_mapping[f"max_norm_linears.[{w}].weight"] = f"final_layer.[{w}].weight"
            key_mapping[f"max_norm_linears.[{w}].bias"] = f"final_layer.[{w}].bias"
        self.mapping = key_mapping

        # Check later if we want to keep the Ensure4d. Not sure if we can
        # remove it or replace it with einsum.
        self.ensuredims = Ensure4d()
        self.dimshuffle = Rearrange("batch C T 1 -> batch 1 T C")

        self.conv_block = _ConvBlock(
            n_channels=self.n_chans,  # input shape: (batch_size, 1, T, C)
            n_filters=conv_block_n_filters,
            kernel_length_1=conv_block_kernel_length_1,
            kernel_length_2=conv_block_kernel_length_2,
            pool_size_1=conv_block_pool_size_1,
            pool_size_2=conv_block_pool_size_2,
            depth_mult=conv_block_depth_mult,
            dropout=conv_block_dropout,
        )

        # F2: number of feature maps leaving the conv block.
        # Tc: length of the pooled time axis. Tw: width of each sliding window.
        self.F2 = int(conv_block_depth_mult * conv_block_n_filters)
        self.Tc = int(self.n_times / (conv_block_pool_size_1 * conv_block_pool_size_2))
        self.Tw = self.Tc - self.n_windows + 1

        # One independent attention block per sliding window.
        self.attention_blocks = nn.ModuleList(
            [
                _AttentionBlock(
                    in_shape=self.F2,
                    head_dim=self.head_dim,
                    num_heads=num_heads,
                    dropout=att_drop_prob,
                )
                for _ in range(self.n_windows)
            ]
        )

        # One independent TCN stack per sliding window; dilation doubles with
        # every residual block in the stack.
        self.temporal_conv_nets = nn.ModuleList(
            [
                nn.Sequential(
                    *[
                        _TCNResidualBlock(
                            in_channels=self.F2 if i == 0 else self.tcn_n_filters,
                            kernel_size=self.tcn_kernel_size,
                            n_filters=self.tcn_n_filters,
                            dropout=self.tcn_dropout,
                            activation=self.tcn_activation,
                            dilation=2**i,
                        )
                        for i in range(self.tcn_depth)
                    ]
                )
                for _ in range(self.n_windows)
            ]
        )

        if self.concat:
            # A single readout applied to the concatenation of all window
            # features. Kept inside a ModuleList so parameter names match the
            # non-concat layout (``final_layer.0.*``).
            self.final_layer = nn.ModuleList(
                [
                    MaxNormLinear(
                        in_features=self.tcn_n_filters * self.n_windows,
                        out_features=self.n_outputs,
                        max_norm_val=self.max_norm_const,
                    )
                ]
            )
        else:
            # One readout per window; the resulting logits are averaged.
            self.final_layer = nn.ModuleList(
                [
                    MaxNormLinear(
                        in_features=self.tcn_n_filters,
                        out_features=self.n_outputs,
                        max_norm_val=self.max_norm_const,
                    )
                    for _ in range(self.n_windows)
                ]
            )

        self.out_fun = nn.Identity()

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Compute the model output for a batch of EEG windows.

        Parameters
        ----------
        X : torch.Tensor
            Input batch of dimension ``(batch_size, C, T)``.

        Returns
        -------
        torch.Tensor
            Output of ``out_fun`` applied to the aggregated window
            predictions: the mean of the per-window readouts when
            ``concat`` is ``False``, otherwise a single readout of the
            concatenated window features.
        """
        # Dimension: (batch_size, C, T)
        X = self.ensuredims(X)
        # Dimension: (batch_size, C, T, 1)
        X = self.dimshuffle(X)
        # Dimension: (batch_size, 1, T, C)

        # ----- Convolutional block -----
        conv_feat = self.conv_block(X)
        # Dimension: (batch_size, F2, Tc, 1)
        conv_feat = conv_feat.view(-1, self.F2, self.Tc)
        # Dimension: (batch_size, F2, Tc)

        # ----- Sliding window -----
        # The readout layers are deliberately NOT part of this zip: with
        # ``concat=True`` there is a single readout, and zipping it in would
        # stop the loop after the first window.
        window_feats: list[torch.Tensor] = []
        for idx, (attention, tcn_module) in enumerate(
            zip(self.attention_blocks, self.temporal_conv_nets)
        ):
            conv_feat_w = conv_feat[..., idx : idx + self.Tw]
            # Dimension: (batch_size, F2, Tw)

            # ----- Attention block -----
            att_feat = attention(conv_feat_w)
            # Dimension: (batch_size, F2, Tw)

            # ----- Temporal convolutional network (TCN) -----
            # Only the last (causal) time step is kept as the window feature.
            window_feats.append(tcn_module(att_feat)[..., -1])
            # Dimension: (batch_size, F2)

        # ----- Aggregation and prediction -----
        if self.concat:
            # Concatenate window features, then map with the single readout.
            # Dimension: (batch_size, n_windows * F2)
            sw_concat_agg = self.final_layer[0](torch.cat(window_feats, dim=1))
        else:
            # Map each window with its own readout, then average the results
            # (with a single window the mean is that window's output).
            window_logits = [
                final_layer(feat)
                for final_layer, feat in zip(self.final_layer, window_feats)
            ]
            sw_concat_agg = torch.mean(torch.stack(window_logits, dim=0), dim=0)

        return self.out_fun(sw_concat_agg)
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
class _ConvBlock(nn.Module):
    r"""Convolutional stem of ATCNet [1]_, modelled after EEGNet [2]_.

    The block chains, in order: a temporal convolution, a depthwise
    convolution spanning all EEG channels, and a second temporal convolution.
    Each of the last two convolutions is followed by batch normalization,
    ELU, average pooling over time and dropout.

    Parameters
    ----------
    n_channels : int
        Width of the depthwise convolution kernel, i.e. the size of the last
        input dimension that gets collapsed to 1.
    n_filters : int
        Number of output maps of the first temporal convolution.
    kernel_length_1 : int
        Kernel length (along time) of the first temporal convolution.
    kernel_length_2 : int
        Kernel length (along time) of the last temporal convolution.
    pool_size_1 : int
        Kernel length of the first average pooling.
    pool_size_2 : int
        Kernel length of the second average pooling.
    depth_mult : int
        Depth multiplier of the depthwise convolution; the block outputs
        ``n_filters * depth_mult`` feature maps.
    dropout : float
        Dropout probability used after both pooling layers.

    References
    ----------
    .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed
        attention temporal convolutional network for EEG-based motor imagery
        classification," in IEEE Transactions on Industrial Informatics,
        2022, doi: 10.1109/TII.2022.3197419.
    .. [2] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon,
        S. M., Hung, C. P., & Lance, B. J. (2018).
        EEGNet: A Compact Convolutional Network for EEG-based
        Brain-Computer Interfaces.
        arXiv preprint arXiv:1611.08024.
    """

    def __init__(
        self,
        n_channels,
        n_filters=16,
        kernel_length_1=64,
        kernel_length_2=16,
        pool_size_1=8,
        pool_size_2=7,
        depth_mult=2,
        dropout=0.3,
    ):
        super().__init__()

        # Number of feature maps once the depthwise convolution has expanded
        # every temporal filter ``depth_mult`` times.
        expanded = n_filters * depth_mult

        # Stage 1: temporal filter bank (no activation, only batch norm).
        self.conv1 = nn.Conv2d(
            1, n_filters, (kernel_length_1, 1), padding="same", bias=False
        )
        self.bn1 = nn.BatchNorm2d(n_filters, eps=1e-4)

        # Stage 2: depthwise convolution across the channel axis
        # (``groups=n_filters`` keeps the temporal filters separate).
        self.conv2 = nn.Conv2d(
            n_filters,
            expanded,
            (1, n_channels),
            groups=n_filters,
            padding="valid",
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(expanded, eps=1e-4)
        self.activation2 = nn.ELU()
        self.pool2 = nn.AvgPool2d((pool_size_1, 1))
        self.drop2 = nn.Dropout2d(dropout)

        # Stage 3: second convolution along the (pooled) time axis.
        self.conv3 = nn.Conv2d(
            expanded, expanded, (kernel_length_2, 1), padding="same", bias=False
        )
        self.bn3 = nn.BatchNorm2d(expanded, eps=1e-4)
        self.activation3 = nn.ELU()
        self.pool3 = nn.AvgPool2d((pool_size_2, 1))
        self.drop3 = nn.Dropout2d(dropout)

    def forward(self, X):
        """Run the stem on ``X`` of dimension ``(batch_size, 1, T, C)``.

        Returns a tensor of dimension ``(batch_size, F1*D, T/(P1*P2), 1)``.
        """
        pipeline = (
            # Temporal convolution: (B, 1, T, C) -> (B, F1, T, C)
            self.conv1,
            self.bn1,
            # Depthwise channel convolution: -> (B, F1*D, T, 1)
            self.conv2,
            self.bn2,
            self.activation2,
            # Pooling over time: -> (B, F1*D, T/P1, 1)
            self.pool2,
            self.drop2,
            # Second temporal convolution: length stays T/P1
            self.conv3,
            self.bn3,
            self.activation3,
            # Pooling over time: -> (B, F1*D, T/(P1*P2), 1)
            self.pool3,
            self.drop3,
        )
        for layer in pipeline:
            X = layer(X)
        return X
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
class _AttentionBlock(nn.Module):
    r"""Multi-head self-attention (MHA) block used in ATCNet [1]_.

    Inspired from [2]_. The block layer-normalizes its input, applies
    multi-head self-attention followed by dropout, and adds the result
    back to the (un-normalized) input through a skip connection.

    Parameters
    ----------
    in_shape : int
        Size of the feature dimension. Used as the normalized shape of the
        layer normalization and as both input and output dimension of the
        attention layer.
    head_dim : int
        Per-head embedding dimension passed to the attention layer.
    num_heads : int
        Number of attention heads.
    dropout : float
        Dropout probability passed to the attention layer. The additional
        dropout applied before the skip connection is fixed at 0.3.

    References
    ----------
    .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed
       attention temporal convolutional network for EEG-based motor imagery
       classification," in IEEE Transactions on Industrial Informatics,
       2022, doi: 10.1109/TII.2022.3197419.
    .. [2] Vaswani, A. et al., "Attention is all you need",
       in Advances in neural information processing systems, 2017.
    """

    def __init__(
        self,
        in_shape=32,
        head_dim=8,
        num_heads=2,
        dropout=0.5,
    ):
        super().__init__()
        self.in_shape = in_shape
        self.head_dim = head_dim
        self.num_heads = num_heads

        # Swaps the last two axes so that time ends up at -2 and features
        # at -1; applying it a second time restores the original layout.
        self.dimshuffle = Rearrange("batch C T -> batch T C")

        # Layer normalization over the feature axis.
        self.ln = nn.LayerNorm(normalized_shape=in_shape, eps=1e-6)

        # Multi-head self-attention. A custom layer is used because the
        # original (tensorflow) implementation allows an embedding dimension
        # different from the input/output dimensions, which
        # torch.nn.MultiheadAttention does not.
        self.mha = _MHA(
            input_dim=in_shape,
            head_dim=head_dim,
            output_dim=in_shape,
            num_heads=num_heads,
            dropout=dropout,
        )

        # XXX: This extra dropout is odd, since the attention layer already
        # applies dropout and the paper does not mention any additional
        # dropout between the attention block and the TCN. It is kept here,
        # with the fixed rate 0.3, to follow the official code.
        self.drop = nn.Dropout(0.3)

    def forward(self, X):
        """Apply layer norm, self-attention, dropout and the skip connection.

        Parameters
        ----------
        X : torch.Tensor
            Input annotated as ``(batch_size, F2, Tw)``.

        Returns
        -------
        torch.Tensor
            Tensor with the same layout as ``X``.
        """
        # (batch_size, F2, Tw) -> (batch_size, Tw, F2)
        tokens = self.dimshuffle(X)

        # ----- Layer norm -----
        normed = self.ln(tokens)

        # ----- Self-Attention -----
        # query, key and value are all the normalized sequence;
        # the result keeps the (batch_size, Tw, F2) layout.
        attended = self.mha(normed, normed, normed)

        # XXX: The original note reports that fig. 1 of the paper places the
        # layer normalization differently relative to the skip connection
        # than the official code does. The official code is followed here.

        # ----- Skip connection -----
        residual = tokens + self.drop(attended)

        # Move back to shape (batch_size, F2, Tw) from the beginning
        return self.dimshuffle(residual)
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
class _TCNResidualBlock(nn.Module):
    r"""Modified TCN residual block as proposed in [1]_.

    Inspired from Temporal Convolutional Networks (TCN) [2]_. Two
    ``CausalConv1d`` layers, each followed by batch normalization, the
    activation and dropout, are combined with a residual connection; the
    activation is applied once more to the sum.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input.
    kernel_size : int
        Kernel size of both causal convolutions.
    n_filters : int
        Number of output channels of both causal convolutions.
    dropout : float
        Dropout probability used after each convolution stage.
    activation : type[nn.Module]
        Activation class, instantiated without arguments. The single
        instance is reused after both convolutions and after the sum.
    dilation : int
        Dilation of both causal convolutions.

    References
    ----------
    .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed
       attention temporal convolutional network for EEG-based motor imagery
       classification," in IEEE Transactions on Industrial Informatics,
       2022, doi: 10.1109/TII.2022.3197419.
    .. [2] Bai, S., Kolter, J. Z., & Koltun, V.
       "An empirical evaluation of generic convolutional and recurrent
       networks for sequence modeling", 2018.
    """

    def __init__(
        self,
        in_channels,
        kernel_size=4,
        n_filters=32,
        dropout=0.3,
        activation: type[nn.Module] = nn.ELU,
        dilation=1,
    ):
        super().__init__()
        self.activation = activation()
        self.dilation = dilation
        self.dropout = dropout
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        self.in_channels = in_channels

        # Settings shared by the two causal convolutions.
        conv_settings = dict(
            out_channels=n_filters,
            kernel_size=kernel_size,
            dilation=dilation,
        )

        # First stage: conv -> batch norm -> (activation) -> dropout
        self.conv1 = CausalConv1d(in_channels=in_channels, **conv_settings)
        nn.init.kaiming_uniform_(self.conv1.weight)
        self.bn1 = nn.BatchNorm1d(n_filters)
        self.drop1 = nn.Dropout(dropout)

        # Second stage, same layout but fed with n_filters channels
        self.conv2 = CausalConv1d(in_channels=n_filters, **conv_settings)
        nn.init.kaiming_uniform_(self.conv2.weight)
        self.bn2 = nn.BatchNorm1d(n_filters)
        self.drop2 = nn.Dropout(dropout)

        # The residual branch needs n_filters channels to be summed with the
        # convolution branch: pass through unchanged when the channel counts
        # already agree, otherwise project with a kernel-size-1 convolution.
        if in_channels == n_filters:
            self.reshaping_conv = nn.Identity()
        else:
            self.reshaping_conv = nn.Conv1d(
                in_channels=in_channels,
                out_channels=n_filters,
                kernel_size=1,
                padding="same",
            )

    def forward(self, X):
        """Apply the two dilated causal convolutions and the residual sum.

        Parameters
        ----------
        X : torch.Tensor
            Input annotated as ``(batch_size, F2, Tw)``.

        Returns
        -------
        torch.Tensor
            Activation of ``reshaping_conv(X) + conv_branch(X)``.
        """
        # ----- Double dilated convolutions -----
        hidden = self.drop1(self.activation(self.bn1(self.conv1(X))))
        hidden = self.drop2(self.activation(self.bn2(self.conv2(hidden))))

        # Bring the input to n_filters channels when necessary
        shortcut = self.reshaping_conv(X)

        # ----- Residual connection -----
        return self.activation(shortcut + hidden)
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
class _MHA(nn.Module):
    """Multi-head attention with a free per-head embedding dimension."""

    def __init__(
        self,
        input_dim: int,
        head_dim: int,
        output_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ):
        """Multi-head Attention.

        Unlike torch.nn.MultiheadAttention, the embedding dimension
        (``head_dim * num_heads``) may differ from the input and output
        dimensions. No attention mask or padding mask is supported.

        Parameters
        ----------
        input_dim : int
            Dimension of query, key and value inputs.
        head_dim : int
            Dimension of the embedded query, key and value in each head,
            before computing attention.
        output_dim : int
            Output dimension.
        num_heads : int
            Number of heads in the multi-head architecture.
        dropout : float, optional
            Dropout probability applied to the output of the final linear
            projection. Default: 0.0 (no dropout).
        """
        super().__init__()

        self.input_dim = input_dim
        self.head_dim = head_dim
        # Total embedding width: all heads side by side
        self.embed_dim = head_dim * num_heads

        # Linear embeddings of query, key and value for all heads at once
        self.fc_q = nn.Linear(input_dim, self.embed_dim)
        self.fc_k = nn.Linear(input_dim, self.embed_dim)
        self.fc_v = nn.Linear(input_dim, self.embed_dim)

        # Maps the concatenated head outputs to the output dimension
        self.fc_o = nn.Linear(self.embed_dim, output_dim)

        # Dropout on the projected output
        self.dropout = nn.Dropout(dropout)

    def forward(
        self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
    ) -> torch.Tensor:
        """Compute MHA(Q, K, V).

        Parameters
        ----------
        Q : torch.Tensor of size (batch_size, seq_len, input_dim)
            Input query (Q) sequence.
        K : torch.Tensor of size (batch_size, seq_len, input_dim)
            Input key (K) sequence.
        V : torch.Tensor of size (batch_size, seq_len, input_dim)
            Input value (V) sequence.

        Returns
        -------
        O : torch.Tensor of size (batch_size, seq_len, output_dim)
            Output MHA(Q, K, V)
        """
        assert Q.shape[-1] == K.shape[-1] == V.shape[-1] == self.input_dim

        # Unpacking also enforces that Q is a 3-D tensor.
        batch_size, _seq_len, _in_dim = Q.shape

        def stack_heads(embedded: torch.Tensor) -> torch.Tensor:
            # (B, S, num_heads * head_dim) -> (num_heads * B, S, head_dim):
            # cut the feature axis into per-head chunks and stack the chunks
            # along the batch axis.
            return torch.cat(embedded.split(self.head_dim, dim=-1), dim=0)

        queries = stack_heads(self.fc_q(Q))
        keys = stack_heads(self.fc_k(K))
        values = stack_heads(self.fc_v(V))

        # Scaled dot-product similarity of every query with every key,
        # normalized over the key axis: (num_heads * B, S, S)
        scores = torch.bmm(queries, keys.transpose(-2, -1)) / math.sqrt(
            self.head_dim
        )
        weights = torch.softmax(scores, dim=-1)

        # Weighted sum of the values: a value gets more weight when its key
        # has a larger dot product with the query. (num_heads * B, S, head_dim)
        context = torch.bmm(weights, values)

        # Undo the head stacking: batch-sized chunks back onto the feature
        # axis -> (B, S, num_heads * head_dim)
        merged = torch.cat(context.split(batch_size, dim=0), dim=-1)

        return self.dropout(self.fc_o(merged))
|