braindecode 1.3.0.dev180329405__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- braindecode/augmentation/base.py +1 -1
- braindecode/datasets/__init__.py +12 -4
- braindecode/datasets/base.py +115 -151
- braindecode/datasets/bcicomp.py +4 -4
- braindecode/datasets/bids.py +3 -3
- braindecode/datasets/experimental.py +2 -2
- braindecode/datasets/mne.py +3 -5
- braindecode/datasets/moabb.py +17 -7
- braindecode/datasets/nmt.py +2 -2
- braindecode/datasets/sleep_physio_challe_18.py +2 -2
- braindecode/datasets/sleep_physionet.py +2 -2
- braindecode/datasets/tuh.py +2 -2
- braindecode/datasets/xy.py +2 -2
- braindecode/datautil/__init__.py +11 -1
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/serialization.py +7 -7
- braindecode/functional/functions.py +6 -2
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +6 -0
- braindecode/models/atcnet.py +26 -27
- braindecode/models/attentionbasenet.py +37 -32
- braindecode/models/attn_sleep.py +2 -0
- braindecode/models/base.py +280 -2
- braindecode/models/bendr.py +469 -0
- braindecode/models/biot.py +2 -0
- braindecode/models/contrawr.py +2 -0
- braindecode/models/ctnet.py +8 -3
- braindecode/models/deepsleepnet.py +28 -19
- braindecode/models/eegconformer.py +2 -2
- braindecode/models/eeginception_erp.py +31 -25
- braindecode/models/eegitnet.py +2 -0
- braindecode/models/eegminer.py +2 -0
- braindecode/models/eegnet.py +1 -1
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +2 -0
- braindecode/models/fbcnet.py +5 -1
- braindecode/models/fblightconvnet.py +2 -0
- braindecode/models/fbmsnet.py +20 -6
- braindecode/models/ifnet.py +2 -0
- braindecode/models/labram.py +33 -26
- braindecode/models/medformer.py +758 -0
- braindecode/models/msvtnet.py +2 -0
- braindecode/models/patchedtransformer.py +1 -1
- braindecode/models/signal_jepa.py +111 -27
- braindecode/models/sinc_shallow.py +12 -9
- braindecode/models/sstdpn.py +11 -11
- braindecode/models/summary.csv +3 -0
- braindecode/models/syncnet.py +2 -0
- braindecode/models/tcn.py +2 -0
- braindecode/models/usleep.py +26 -21
- braindecode/models/util.py +3 -0
- braindecode/modules/attention.py +10 -10
- braindecode/modules/blocks.py +3 -3
- braindecode/modules/filter.py +2 -9
- braindecode/modules/layers.py +18 -17
- braindecode/preprocessing/__init__.py +232 -3
- braindecode/preprocessing/eegprep_preprocess.py +1202 -0
- braindecode/preprocessing/mne_preprocess.py +142 -10
- braindecode/preprocessing/preprocess.py +28 -18
- braindecode/preprocessing/util.py +166 -0
- braindecode/preprocessing/windowers.py +26 -20
- braindecode/samplers/base.py +8 -8
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/METADATA +6 -2
- braindecode-1.3.0.dev182330353.dist-info/RECORD +109 -0
- braindecode-1.3.0.dev180329405.dist-info/RECORD +0 -103
- {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/top_level.txt +0 -0
braindecode/models/medformer.py (new file):

@@ -0,0 +1,758 @@

# Authors: Yihe Wang <ywang145@charlotte.edu>
#          Nan Huang <nhuang1@charlotte.edu>
#          Taida Li <tli14@charlotte.edu>
#
# License: MIT

"""Medformer: A Multi-Granularity Patching Transformer for Medical Time-Series Classification."""

import math
from math import sqrt
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from braindecode.models.base import EEGModuleMixin


class MEDFormer(EEGModuleMixin, nn.Module):
    r"""Medformer from Wang et al. (2024) [Medformer2024]_.

    :bdg-success:`Convolution` :bdg-danger:`Large Brain Model`

    .. figure:: https://raw.githubusercontent.com/DL4mHealth/Medformer/refs/heads/main/figs/medformer_architecture.png
        :align: center
        :alt: MEDFormer Architecture.

    a) Workflow. b) For the input sample :math:`{x}_{\textrm{in}}`, the authors apply :math:`n`
    different patch lengths in parallel to create patched features :math:`{x}_p^{(i)}`, where :math:`i`
    ranges from 1 to :math:`n`. Each patch length represents a different granularity. These patched
    features are linearly transformed into :math:`{x}_e^{(i)}` and augmented into :math:`\widetilde{x}_e^{(i)}`.
    c) The final patch embedding :math:`{x}^{(i)}` fuses the augmented :math:`\widetilde{{x}}_e^{(i)}` with the
    positional embedding :math:`{W}_{\text{pos}}` and the granularity embedding :math:`{W}_{\text{gr}}^{(i)}`.
    Each granularity employs a router :math:`{u}^{(i)}` to capture aggregated information.
    Intra-granularity attention focuses within individual granularities, and inter-granularity attention
    leverages the routers to integrate information across granularities.

    **MedFormer** is a multi-granularity patching transformer tailored to medical
    time-series (MedTS) classification, with an emphasis on EEG and ECG signals. It captures
    local temporal dynamics, inter-channel correlations, and multi-scale temporal structure
    through cross-channel patching, multi-granularity embeddings, and two-stage attention
    [Medformer2024]_.

    .. rubric:: Architecture Overview

    MedFormer integrates three mechanisms to enhance representation learning [Medformer2024]_:

    1. **Cross-channel patching.** Leverages inter-channel correlations by forming patches
       across multiple channels and timestamps, capturing multi-timestamp and cross-channel
       patterns.
    2. **Multi-granularity embedding.** Extracts features at different temporal scales from
       :attr:`patch_len_list`, emulating frequency-band behavior without hand-crafted filters.
    3. **Two-stage multi-granularity self-attention.** Learns intra- and inter-granularity
       correlations to fuse information across temporal scales.

    .. rubric:: Macro Components

    ``MEDFormer.enc_embedding`` (Embedding Layer)
        **Operations.** :class:`~braindecode.models.medformer._ListPatchEmbedding` implements
        cross-channel multi-granularity patching. For each patch length :math:`L_i`, the input
        :math:`\mathbf{x}_{\text{in}} \in \mathbb{R}^{T \times C}` is segmented into
        :math:`N_i` cross-channel non-overlapping patches
        :math:`\mathbf{x}_p^{(i)} \in \mathbb{R}^{N_i \times (L_i \cdot C)}`, where
        :math:`N_i = \lceil T/L_i \rceil`. Each patch is linearly projected via
        :class:`~braindecode.models.medformer._CrossChannelTokenEmbedding` to obtain
        :math:`\mathbf{x}_e^{(i)} \in \mathbb{R}^{N_i \times D}`. Data augmentations
        (masking, jittering) produce augmented embeddings :math:`\tilde{\mathbf{x}}_e^{(i)}`.
        The final embedding combines augmented patches, fixed positional embeddings
        (:class:`~braindecode.models.medformer._PositionalEmbedding`), and learnable
        granularity embeddings :math:`\mathbf{W}_{\text{gr}}^{(i)}`:

        .. math::
            \mathbf{x}^{(i)} = \tilde{\mathbf{x}}_e^{(i)} + \mathbf{W}_{\text{pos}}[1:N_i] + \mathbf{W}_{\text{gr}}^{(i)}

        Additionally, a router token is initialized for each granularity:

        .. math::
            \mathbf{u}^{(i)} = \mathbf{W}_{\text{pos}}[N_i+1] + \mathbf{W}_{\text{gr}}^{(i)}

        **Role.** Converts raw input into granularity-specific patch embeddings
        :math:`\{\mathbf{x}^{(1)}, \ldots, \mathbf{x}^{(n)}\}` and router embeddings
        :math:`\{\mathbf{u}^{(1)}, \ldots, \mathbf{u}^{(n)}\}` for multi-scale processing.

    ``MEDFormer.encoder`` (Transformer Encoder Stack)
        **Operations.** A stack of :class:`~braindecode.models.medformer._EncoderLayer` modules,
        each containing a :class:`~braindecode.models.medformer._MedformerLayer` that implements
        two-stage self-attention. The two-stage mechanism splits self-attention into:

        **(a) Intra-Granularity Self-Attention.** For granularity :math:`i`, the patch embedding
        :math:`\mathbf{x}^{(i)} \in \mathbb{R}^{N_i \times D}` and router embedding
        :math:`\mathbf{u}^{(i)} \in \mathbb{R}^{1 \times D}` are concatenated:

        .. math::
            \mathbf{z}^{(i)} = [\mathbf{x}^{(i)} \| \mathbf{u}^{(i)}] \in \mathbb{R}^{(N_i+1) \times D}

        Self-attention is applied to update both embeddings:

        .. math::
            \mathbf{x}^{(i)} &\leftarrow \text{Attn}_{\text{intra}}(\mathbf{x}^{(i)}, \mathbf{z}^{(i)}, \mathbf{z}^{(i)})\\
            \mathbf{u}^{(i)} &\leftarrow \text{Attn}_{\text{intra}}(\mathbf{u}^{(i)}, \mathbf{z}^{(i)}, \mathbf{z}^{(i)})

        This captures temporal features within each granularity independently.

        **(b) Inter-Granularity Self-Attention.** All router embeddings are concatenated:

        .. math::
            \mathbf{U} = [\mathbf{u}^{(1)} \| \mathbf{u}^{(2)} \| \cdots \| \mathbf{u}^{(n)}] \in \mathbb{R}^{n \times D}

        Self-attention among the routers exchanges information across granularities:

        .. math::
            \mathbf{u}^{(i)} \leftarrow \text{Attn}_{\text{inter}}(\mathbf{u}^{(i)}, \mathbf{U}, \mathbf{U})

        **Role.** Learns representations and correlations within and across temporal scales while
        reducing complexity from :math:`O((\sum_i N_i)^2)` to
        :math:`O(\sum_i N_i^2 + n^2)` through the router mechanism.

    .. rubric:: Temporal, Spatial, and Spectral Encoding

    - **Temporal:** Multiple patch lengths in :attr:`patch_len_list` capture features at several
      temporal granularities, while intra-granularity attention supports long-range temporal
      dependencies.
    - **Spatial:** Cross-channel patching embeds inter-channel dependencies by applying kernels
      that span every input channel.
    - **Spectral:** Differing patch lengths simulate multiple sampling frequencies analogous to
      clinically relevant bands (e.g., alpha, beta, gamma).

    .. rubric:: Additional Mechanisms

    - **Granularity router:** Each granularity :math:`i` receives a dedicated router token
      :math:`\mathbf{u}^{(i)}`. Intra-attention updates the token, and inter-attention exchanges
      aggregated information across scales.
    - **Complexity:** Router-mediated two-stage attention maintains :math:`O(T^2)` complexity for
      suitable patch lengths (e.g., a power series), preserving transformer-like efficiency while
      modeling multiple granularities.

    Parameters
    ----------
    patch_len_list : list of int, optional
        Patch lengths for multi-granularity patching; each entry selects a temporal scale.
        The default is ``[2, 8, 16]``.
    d_model : int, optional
        Embedding dimensionality. The default is ``128``.
    num_heads : int, optional
        Number of attention heads, which must divide :attr:`d_model`. The default is ``8``.
    drop_prob : float, optional
        Dropout probability. The default is ``0.1``.
    no_inter_attn : bool, optional
        If ``True``, disables inter-granularity attention. The default is ``False``.
    n_layers : int, optional
        Number of encoder layers. The default is ``6``.
    dim_feedforward : int, optional
        Feedforward dimensionality. The default is ``256``.
    activation_trans : nn.Module, optional
        Activation module used in transformer encoder layers. The default is :class:`nn.ReLU`.
    single_channel : bool, optional
        If ``True``, processes each channel independently, increasing capacity and cost. The default is ``False``.
    output_attention : bool, optional
        If ``True``, returns attention weights for interpretability. The default is ``True``.
    activation_class : nn.Module, optional
        Activation used in the final classification layer. The default is :class:`nn.GELU`.

    Notes
    -----
    - MedFormer outperforms strong baselines across six metrics on five MedTS datasets in a
      subject-independent evaluation [Medformer2024]_.
    - Cross-channel patching provides the largest F1 improvement in ablation studies (average
      +6.10%), highlighting its importance for MedTS tasks [Medformer2024]_.
    - Setting :attr:`no_inter_attn` to ``True`` disables inter-granularity attention while retaining
      intra-granularity attention.

    References
    ----------
    .. [Medformer2024] Wang, Y., Huang, N., Li, T., Yan, Y., & Zhang, X. (2024).
       Medformer: A Multi-Granularity Patching Transformer for Medical Time-Series Classification.
       In A. Globerson, L. Mackey, D. Belgrave, A. Fan, U. Paquet, J. Tomczak, & C. Zhang (Eds.),
       Advances in Neural Information Processing Systems (Vol. 37, pp. 36314-36341).
       doi:10.52202/079017-1145.
    """

    def __init__(
        self,
        # Signal related parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model parameters
        patch_len_list: Optional[List[int]] = None,
        d_model: int = 128,
        num_heads: int = 8,
        drop_prob: float = 0.1,
        no_inter_attn: bool = False,
        n_layers: int = 6,
        dim_feedforward: int = 256,
        activation_trans: Optional[nn.Module] = nn.ReLU,
        single_channel: bool = False,
        output_attention: bool = True,
        activation_class: Optional[nn.Module] = nn.GELU,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # In the original Medformer paper:
        # - seq_len refers to the number of channels
        # - enc_in refers to the number of time points

        # Save model parameters as instance variables
        self.d_model = d_model
        self.num_heads = num_heads
        self.drop_prob = drop_prob
        self.no_inter_attn = no_inter_attn
        self.n_layers = n_layers
        self.dim_feedforward = dim_feedforward
        self.activation_trans = activation_trans
        self.output_attention = output_attention
        self.single_channel = single_channel
        self.activation_class = activation_class

        # Process the sequence and patch configurations.
        if patch_len_list is None:
            patch_len_list = [2, 8, 16]

        self.patch_len_list = patch_len_list
        stride_list = patch_len_list  # Using the same values for strides.
        self.stride_list = stride_list
        patch_num_list = [
            int((self.n_chans - patch_len) / stride + 2)
            for patch_len, stride in zip(patch_len_list, stride_list)
        ]
        self.patch_num_list = patch_num_list

        # Initialize the embedding layer.
        self.enc_embedding = _ListPatchEmbedding(
            enc_in=self.n_times,
            d_model=self.d_model,
            seq_len=self.n_chans,
            patch_len_list=self.patch_len_list,
            stride_list=self.stride_list,
            dropout=self.drop_prob,
            single_channel=self.single_channel,
            n_chans=self.n_chans,
            n_times=self.n_times,
        )
        # Build the encoder with multiple layers.
        self.encoder = _Encoder(
            [
                _EncoderLayer(
                    attention=_MedformerLayer(
                        num_blocks=len(self.patch_len_list),
                        d_model=self.d_model,
                        num_heads=self.num_heads,
                        dropout=self.drop_prob,
                        output_attention=self.output_attention,
                        no_inter=self.no_inter_attn,
                    ),
                    d_model=self.d_model,
                    dim_feedforward=self.dim_feedforward,
                    dropout=self.drop_prob,
                    activation=self.activation_trans()
                    if self.activation_trans is not None
                    else nn.ReLU(),
                )
                for _ in range(self.n_layers)
            ],
            norm_layer=torch.nn.LayerNorm(self.d_model),
        )

        # For classification tasks, add additional layers.
        self.activation_layer = (
            self.activation_class() if self.activation_class is not None else nn.GELU()
        )
        self.dropout = nn.Dropout(self.drop_prob)
        self.final_layer = nn.Linear(
            self.d_model
            * len(self.patch_num_list)
            * (1 if not self.single_channel else self.n_chans),
            self.n_outputs,
        )

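    # A worked sizing example (illustrative values, not from the source): with
    # n_chans=22 and patch_len=stride=2, the channel axis is replication-padded
    # by one stride before patching, giving int((22 - 2) / 2 + 2) = 12 tokens
    # for that granularity; the final linear layer then sees
    # d_model * len(patch_num_list) features built from the fused router tokens.
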
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the Medformer model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, n_outputs).
        """
        # Embedding
        enc_out = self.enc_embedding(x)
        enc_out, _ = self.encoder(enc_out, attn_mask=None)

        if self.single_channel:
            # Reshape back from (batch_size * n_chans, ...) to (batch_size, n_chans, ...)
            # Explicitly construct the reshape dimensions to be TorchScript compatible
            batch_size = enc_out.shape[0] // self.n_chans
            seq_len = enc_out.shape[1]
            d_model = enc_out.shape[2]
            enc_out = torch.reshape(
                enc_out, (batch_size, self.n_chans, seq_len, d_model)
            )

        # Output
        output = self.activation_layer(enc_out)
        output = self.dropout(output)
        output = output.reshape(
            output.shape[0], -1
        )  # (batch_size, seq_length * d_model)
        output = self.final_layer(output)  # (batch_size, num_classes)
        return output


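# Shape walkthrough of MEDFormer.forward (a sketch with assumed sizes, not part
# of the diff): for x of shape (8, 22, 1000) and patch_len_list=[2, 8, 16],
# enc_embedding returns one (8, N_i, 128) tensor per granularity, the encoder
# fuses them and returns the stacked router tokens of shape (8, 3, 128), and
# flattening followed by final_layer maps (8, 3 * 128) to (8, n_outputs).
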
class _PositionalEmbedding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        # If d_model is odd, temporarily work with d_model + 1.
        if d_model % 2 == 1:
            d_model_adj = d_model + 1
        else:
            d_model_adj = d_model
        self.d_model = d_model  # store the original dimension

        # Create a pe tensor of size (max_len, d_model_adj)
        pe = torch.zeros(max_len, d_model_adj).float()
        pe.requires_grad = False

        # Compute the sinusoidal factors.
        position = torch.arange(0, max_len).float().unsqueeze(1)
        # Use d_model_adj in the denominator so that the frequencies are computed over an even number.
        div_term = torch.exp(
            torch.arange(0, d_model_adj, 2).float() * (-math.log(10000.0) / d_model_adj)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Unsqueeze to shape (1, max_len, d_model_adj)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is assumed to have shape (B, L, d_model_target)
        # We return the first self.d_model columns from the computed pe.
        return self.pe[:, : x.size(1), : self.d_model]


class _CrossChannelTokenEmbedding(nn.Module):
    def __init__(
        self, c_in: int, l_patch: int, d_model: int, stride: Optional[int] = None
    ):
        super().__init__()
        if stride is None:
            stride = l_patch
        self.token_conv = nn.Conv2d(
            in_channels=1,
            out_channels=d_model,
            kernel_size=(c_in, l_patch),
            stride=(1, stride),
            padding=0,
            padding_mode="circular",
            bias=False,
        )
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_in", nonlinearity="leaky_relu"
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.token_conv(x)
        return x

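
# The Conv2d above is the cross-channel patching step: its kernel spans the
# full height c_in of the input, so each sliding position yields exactly one
# patch token. A minimal shape check (assumed sizes, illustrative only):
#     emb = _CrossChannelTokenEmbedding(c_in=1000, l_patch=2, d_model=128)
#     out = emb(torch.randn(8, 1, 1000, 24))  # -> (8, 128, 1, 12)
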
class _ListPatchEmbedding(nn.Module):
    def __init__(
        self,
        enc_in: int,
        d_model: int,
        seq_len: int,
        patch_len_list: List[int],
        stride_list: List[int],
        dropout: float,
        single_channel: bool = False,
        n_chans: Optional[int] = None,
        n_times: Optional[int] = None,
    ):
        super().__init__()
        self.patch_len_list = patch_len_list
        self.stride_list = stride_list
        # Use ModuleList so TorchScript can statically infer module attribute types
        self.paddings: nn.ModuleList = nn.ModuleList(
            [nn.ReplicationPad1d((0, stride)) for stride in stride_list]
        )
        self.single_channel = single_channel
        self.n_chans = n_chans
        self.n_times = n_times

        # Number of different patch/granularity blocks (used to make loops TorchScript-friendly)
        self.num_patches = len(patch_len_list)

        linear_layers = [
            _CrossChannelTokenEmbedding(
                c_in=enc_in if not single_channel else 1,
                l_patch=patch_len,
                d_model=d_model,
            )
            for patch_len in patch_len_list
        ]
        self.value_embeddings = nn.ModuleList(linear_layers)
        self.position_embedding = _PositionalEmbedding(d_model=d_model)
        self.channel_embedding = _PositionalEmbedding(d_model=seq_len)
        self.dropout = nn.Dropout(dropout)

        self.learnable_embeddings = nn.ParameterList(
            [nn.Parameter(torch.randn(1, d_model)) for _ in patch_len_list]
        )

    def forward(
        self, x: torch.Tensor
    ) -> List[torch.Tensor]:  # (batch_size, seq_len, enc_in)
        x = x.permute(0, 2, 1)  # (batch_size, enc_in, seq_len)
        if self.single_channel:
            # After permute: x.shape = (batch_size, n_times, n_chans)
            # We want to process each channel independently
            batch_size = x.shape[0]
            # Permute to get channels in the middle: (batch_size, n_chans, n_times)
            x = x.permute(0, 2, 1)
            # Reshape to treat each channel independently: (batch_size * n_chans, 1, n_times)
            x = torch.reshape(x, (batch_size * self.n_chans, 1, self.n_times))

        x_list = []
        for padding, value_embedding in zip(self.paddings, self.value_embeddings):
            x_copy = x.clone()
            # add positional embedding to tag each channel (only when not single_channel)
            if not self.single_channel:
                x_new = x_copy + self.channel_embedding(x_copy)
            else:
                x_new = x_copy
            x_new = padding(x_new).unsqueeze(
                1
            )  # (batch_size, 1, enc_in, seq_len+stride)
            x_new = value_embedding(x_new)  # (batch_size, d_model, 1, patch_num)
            x_new = x_new.squeeze(2).transpose(1, 2)  # (batch_size, patch_num, d_model)
            x_list.append(x_new)

        # Combine each patch embedding with its corresponding learnable granularity
        # embedding and positional embedding. Use an explicit indexed loop so
        # TorchScript can statically determine lengths instead of iterating over
        # Python lists/ParameterList via zip.
        out_list: List[torch.Tensor] = []
        # Iterate over learnable_embeddings with enumerate (supported by TorchScript)
        for idx, cxt in enumerate(self.learnable_embeddings):
            xi = x_list[idx]
            xi = xi + cxt + self.position_embedding(xi)
            out_list.append(xi)

        return out_list

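
# _ListPatchEmbedding (above) returns one token tensor per granularity rather
# than a single concatenated sequence; downstream layers consume the list
# directly. Sketch with assumed sizes: for input (8, 22, 1000) and
# patch_len_list [2, 8, 16], the returned list holds tensors of shapes
# (8, 12, 128), (8, 3, 128) and (8, 2, 128).
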
class _AttentionLayer(nn.Module):
    def __init__(
        self,
        attention: nn.Module,
        d_model: int,
        num_heads: int,
        d_keys: Optional[int] = None,
        d_values: Optional[int] = None,
    ):
        super().__init__()

        d_keys = d_keys or (d_model // num_heads)
        d_values = d_values or (d_model // num_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * num_heads)
        self.key_projection = nn.Linear(d_model, d_keys * num_heads)
        self.value_projection = nn.Linear(d_model, d_values * num_heads)
        self.out_projection = nn.Linear(d_values * num_heads, d_model)
        self.num_heads = num_heads

    def forward(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        tau: Optional[torch.Tensor] = None,
        delta: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size, query_len, _ = queries.shape
        _, key_len, _ = keys.shape
        num_heads = self.num_heads

        queries = self.query_projection(queries).view(
            batch_size, query_len, num_heads, -1
        )  # multi-head
        keys = self.key_projection(keys).view(batch_size, key_len, num_heads, -1)
        values = self.value_projection(values).view(batch_size, key_len, num_heads, -1)

        out, attn = self.inner_attention(
            queries, keys, values, attn_mask, tau=tau, delta=delta
        )
        out = out.view(batch_size, query_len, -1)

        return self.out_projection(out), attn


class _TriangularCausalMask:
    def __init__(
        self, batch_size: int, seq_len: int, device: Optional[torch.device] = None
    ):
        # Normalize device to a torch.device for .to(device)
        if device is None:
            device = torch.device("cpu")
        mask_shape = [batch_size, 1, seq_len, seq_len]
        with torch.no_grad():
            self._mask = torch.triu(
                torch.ones(mask_shape, dtype=torch.bool), diagonal=1
            ).to(device)

    @property
    def mask(self) -> torch.Tensor:
        return self._mask


class _FullAttention(nn.Module):
    def __init__(
        self,
        mask_flag: bool = True,
        scale: Optional[float] = None,
        attention_dropout: float = 0.1,
        output_attention: bool = False,
    ):
        super().__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        tau: Optional[torch.Tensor] = None,
        delta: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size, query_len, _, embed_dim = queries.shape
        _, _, _, _ = values.shape  # unpacking asserts values is 4-D
        # Avoid using `or` because TorchScript may fail to cast None/float
        if self.scale is None:
            scale = 1.0 / sqrt(embed_dim)
        else:
            scale = self.scale

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            if attn_mask is None:
                # create a triangular causal mask tensor of shape [B,1,L,L]
                attn_mask = torch.triu(
                    torch.ones([batch_size, 1, query_len, query_len], dtype=torch.bool),
                    diagonal=1,
                ).to(queries.device)

            # attn_mask is expected to be a boolean tensor with the same shape
            scores.masked_fill_(attn_mask, -np.inf)

        attention_weights = self.dropout(
            torch.softmax(scale * scores, dim=-1)
        )  # Scaled Dot-Product Attention
        output_values = torch.einsum("bhls,bshd->blhd", attention_weights, values)

        if self.output_attention:
            return output_values.contiguous(), attention_weights
        else:
            return output_values.contiguous(), None

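
# Two-stage attention in brief: each _MedformerLayer below runs one
# _AttentionLayer per granularity over that granularity's tokens (intra), then,
# unless no_inter is set, a shared _AttentionLayer over the per-granularity
# router tokens (inter). A toy invocation (assumed sizes, illustrative only):
#     layer = _MedformerLayer(num_blocks=3, d_model=128, num_heads=8)
#     xs = [torch.randn(8, n, 128) for n in (12, 3, 2)]
#     outs, attns = layer(xs)  # outs keeps the per-granularity shapes
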
class _MedformerLayer(nn.Module):
    def __init__(
        self,
        num_blocks: int,
        d_model: int,
        num_heads: int,
        dropout: float = 0.1,
        output_attention: bool = False,
        no_inter: bool = False,
    ):
        super().__init__()

        self.intra_attentions = nn.ModuleList(
            [
                _AttentionLayer(
                    _FullAttention(
                        mask_flag=False,
                        attention_dropout=dropout,
                        output_attention=output_attention,
                    ),
                    d_model,
                    num_heads,
                )
                for _ in range(num_blocks)
            ]
        )
        if no_inter or num_blocks <= 1:
            # print("No inter attention for time")
            self.inter_attention = None
        else:
            self.inter_attention = _AttentionLayer(
                _FullAttention(
                    mask_flag=False,
                    attention_dropout=dropout,
                    output_attention=output_attention,
                ),
                d_model,
                num_heads,
            )

    def forward(
        self,
        x: List[torch.Tensor],
        attn_mask: Optional[List[Optional[torch.Tensor]]] = None,
        tau: Optional[torch.Tensor] = None,
        delta: Optional[torch.Tensor] = None,
    ) -> Tuple[List[torch.Tensor], List[Optional[torch.Tensor]]]:
        # Explicit None check because TorchScript cannot evaluate truthiness of lists
        if attn_mask is None:
            # Build a list of None with explicit typing for TorchScript
            new_mask_list: List[Optional[torch.Tensor]] = []
            for _ in range(len(x)):
                new_mask_list.append(None)
            attn_mask = new_mask_list
        # Intra attention
        x_intra: List[torch.Tensor] = []
        attn_out: List[Optional[torch.Tensor]] = []
        # Iterate over ModuleList with enumerate (TorchScript supports enumerate over modules)
        for idx, layer in enumerate(self.intra_attentions):
            x_in_i = x[idx]
            mask_i = attn_mask[idx]
            x_out_temp, attn_temp = layer(
                x_in_i, x_in_i, x_in_i, attn_mask=mask_i, tau=tau, delta=delta
            )
            x_intra.append(x_out_temp)  # (B, Li, D)
            attn_out.append(attn_temp)
        if self.inter_attention is not None:
            # Inter attention
            routers = torch.cat([xi[:, -1:] for xi in x_intra], dim=1)  # (B, N, D)
            x_inter, attn_inter = self.inter_attention(
                routers, routers, routers, attn_mask=None, tau=tau, delta=delta
            )
            x_out = [
                torch.cat([xi[:, :-1], x_inter[:, i : i + 1]], dim=1)  # (B, Li, D)
                for i, xi in enumerate(x_intra)
            ]
            attn_out += [attn_inter]
        else:
            x_out = x_intra
        return x_out, attn_out


class _EncoderLayer(nn.Module):
    def __init__(
        self,
        attention: nn.Module,
        d_model: int,
        dim_feedforward: Optional[int],
        dropout: float,
        activation: Optional[nn.Module] = None,
    ):
        super().__init__()
        dim_feedforward = dim_feedforward or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(
            in_channels=d_model, out_channels=dim_feedforward, kernel_size=1
        )
        self.conv2 = nn.Conv1d(
            in_channels=dim_feedforward, out_channels=d_model, kernel_size=1
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = activation if activation is not None else nn.ReLU()

    def forward(
        self,
        x: List[torch.Tensor],
        attn_mask: Optional[List[Optional[torch.Tensor]]] = None,
        tau: Optional[torch.Tensor] = None,
        delta: Optional[torch.Tensor] = None,
    ) -> Tuple[List[torch.Tensor], List[Optional[torch.Tensor]]]:
        new_x, attn = self.attention(x, attn_mask=attn_mask, tau=tau, delta=delta)
        x = [x_orig + self.dropout(x_new) for x_orig, x_new in zip(x, new_x)]

        y = x = [self.norm1(x_val) for x_val in x]
        y = [
            self.dropout(self.activation(self.conv1(y_val.transpose(-1, 1))))
            for y_val in y
        ]
        y = [self.dropout(self.conv2(y_val).transpose(-1, 1)) for y_val in y]

        return [self.norm2(x_val + y_val) for x_val, y_val in zip(x, y)], attn


class _Encoder(nn.Module):
    def __init__(
        self, attn_layers: List[nn.Module], norm_layer: Optional[nn.Module] = None
    ):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer

    def forward(
        self,
        x: List[torch.Tensor],
        attn_mask: Optional[List[Optional[torch.Tensor]]] = None,
        tau: Optional[torch.Tensor] = None,
        delta: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[List[Optional[torch.Tensor]]]]:
        # x [[B, L1, D], [B, L2, D], ...]
        attns: List[List[Optional[torch.Tensor]]] = []
        for attn_layer in self.attn_layers:
            x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
            attns.append(attn)

        # concat all the outputs (disabled):
        # x = torch.cat(x, dim=1)  # (batch_size, patch_num_1 + patch_num_2 + ..., d_model)

        # concat all the routers
        x = torch.cat(
            [xi[:, -1, :].unsqueeze(1) for xi in x], dim=1
        )  # (batch_size, len(patch_len_list), d_model)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns
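
For orientation, a minimal usage sketch of the new model (not part of the diff; all sizes are illustrative, and the import goes through the module path added by this release rather than any assumed re-export in braindecode/models/__init__.py):

    import torch
    from braindecode.models.medformer import MEDFormer

    model = MEDFormer(n_chans=22, n_times=1000, n_outputs=4, sfreq=250)
    x = torch.randn(8, 22, 1000)  # (batch_size, n_chans, n_times)
    logits = model(x)             # (8, 4)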