braindecode 1.3.0.dev180329405__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. braindecode/augmentation/base.py +1 -1
  2. braindecode/datasets/__init__.py +12 -4
  3. braindecode/datasets/base.py +115 -151
  4. braindecode/datasets/bcicomp.py +4 -4
  5. braindecode/datasets/bids.py +3 -3
  6. braindecode/datasets/experimental.py +2 -2
  7. braindecode/datasets/mne.py +3 -5
  8. braindecode/datasets/moabb.py +17 -7
  9. braindecode/datasets/nmt.py +2 -2
  10. braindecode/datasets/sleep_physio_challe_18.py +2 -2
  11. braindecode/datasets/sleep_physionet.py +2 -2
  12. braindecode/datasets/tuh.py +2 -2
  13. braindecode/datasets/xy.py +2 -2
  14. braindecode/datautil/__init__.py +11 -1
  15. braindecode/datautil/channel_utils.py +114 -0
  16. braindecode/datautil/serialization.py +7 -7
  17. braindecode/functional/functions.py +6 -2
  18. braindecode/functional/initialization.py +2 -3
  19. braindecode/models/__init__.py +6 -0
  20. braindecode/models/atcnet.py +26 -27
  21. braindecode/models/attentionbasenet.py +37 -32
  22. braindecode/models/attn_sleep.py +2 -0
  23. braindecode/models/base.py +280 -2
  24. braindecode/models/bendr.py +469 -0
  25. braindecode/models/biot.py +2 -0
  26. braindecode/models/contrawr.py +2 -0
  27. braindecode/models/ctnet.py +8 -3
  28. braindecode/models/deepsleepnet.py +28 -19
  29. braindecode/models/eegconformer.py +2 -2
  30. braindecode/models/eeginception_erp.py +31 -25
  31. braindecode/models/eegitnet.py +2 -0
  32. braindecode/models/eegminer.py +2 -0
  33. braindecode/models/eegnet.py +1 -1
  34. braindecode/models/eegsym.py +917 -0
  35. braindecode/models/eegtcnet.py +2 -0
  36. braindecode/models/fbcnet.py +5 -1
  37. braindecode/models/fblightconvnet.py +2 -0
  38. braindecode/models/fbmsnet.py +20 -6
  39. braindecode/models/ifnet.py +2 -0
  40. braindecode/models/labram.py +33 -26
  41. braindecode/models/medformer.py +758 -0
  42. braindecode/models/msvtnet.py +2 -0
  43. braindecode/models/patchedtransformer.py +1 -1
  44. braindecode/models/signal_jepa.py +111 -27
  45. braindecode/models/sinc_shallow.py +12 -9
  46. braindecode/models/sstdpn.py +11 -11
  47. braindecode/models/summary.csv +3 -0
  48. braindecode/models/syncnet.py +2 -0
  49. braindecode/models/tcn.py +2 -0
  50. braindecode/models/usleep.py +26 -21
  51. braindecode/models/util.py +3 -0
  52. braindecode/modules/attention.py +10 -10
  53. braindecode/modules/blocks.py +3 -3
  54. braindecode/modules/filter.py +2 -9
  55. braindecode/modules/layers.py +18 -17
  56. braindecode/preprocessing/__init__.py +232 -3
  57. braindecode/preprocessing/eegprep_preprocess.py +1202 -0
  58. braindecode/preprocessing/mne_preprocess.py +142 -10
  59. braindecode/preprocessing/preprocess.py +28 -18
  60. braindecode/preprocessing/util.py +166 -0
  61. braindecode/preprocessing/windowers.py +26 -20
  62. braindecode/samplers/base.py +8 -8
  63. braindecode/version.py +1 -1
  64. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/METADATA +6 -2
  65. braindecode-1.3.0.dev182330353.dist-info/RECORD +109 -0
  66. braindecode-1.3.0.dev180329405.dist-info/RECORD +0 -103
  67. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/WHEEL +0 -0
  68. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/LICENSE.txt +0 -0
  69. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/NOTICE.txt +0 -0
  70. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,758 @@
1
+ # Authors: Yihe Wang <ywang145@charlotte.edu>
2
+ # Nan Huang <nhuang1@charlotte.edu>
3
+ # Taida Li <tli14@charlotte.edu>
4
+ #
5
+ # License: MIT
6
+
7
+ """Medformer: A Multi-Granularity Patching Transformer for Medical Time-Series Classification."""
8
+
9
+ import math
10
+ from math import sqrt
11
+ from typing import List, Optional, Tuple
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from braindecode.models.base import EEGModuleMixin
18
+
19
+
20
+ class MEDFormer(EEGModuleMixin, nn.Module):
21
+ r"""Medformer from Wang et al. (2024) [Medformer2024]_.
22
+
23
+ :bdg-success:`Convolution` :bdg-danger:`Large Brain Model`
24
+
25
+ .. figure:: https://raw.githubusercontent.com/DL4mHealth/Medformer/refs/heads/main/figs/medformer_architecture.png
26
+ :align: center
27
+ :alt: MEDFormer Architecture.
28
+
29
+ a) Workflow. b) For the input sample :math:`{x}_{\textrm{in}}`, the authors apply :math:`n`
30
+ different patch lengths in parallel to create patched features :math:`{x}_p^{(i)}`, where :math:`i`
31
+ ranges from 1 to :math:`n`. Each patch length represents a different granularity. These patched
32
+ features are linearly transformed into :math:`{x}_e^{(i)}` and augmented into :math:`\widetilde{x}_e^{(i)}`.
33
+ c) The final patch embedding :math:`{x}^{(i)}` fuses augmented :math:`\widetilde{{x}}_e^{(i)}` with the
34
+ positional embedding :math:`{W}_{\text{pos}}` and the granularity embedding :math:`{W}_{\text{gr}}^{(i)}`.
35
+ Each granularity employs a router :math:`{u}^{(i)}` to capture aggregated information.
36
+ Intra-granularity attention focuses within individual granularities, and inter-granularity attention
37
+ leverages the routers to integrate information across granularities.
38
+
39
+ The **MedFormer** is a multi-granularity patching transformer tailored to medical
40
+ time-series (MedTS) classification, with an emphasis on EEG and ECG signals. It captures
41
+ local temporal dynamics, inter-channel correlations, and multi-scale temporal structure
42
+ through cross-channel patching, multi-granularity embeddings, and two-stage attention
43
+ [Medformer2024]_.
44
+
45
+ .. rubric:: Architecture Overview
46
+
47
+ MedFormer integrates three mechanisms to enhance representation learning [Medformer2024]_:
48
+
49
+ 1. **Cross-channel patching.** Leverages inter-channel correlations by forming patches
50
+ across multiple channels and timestamps, capturing multi-timestamp and cross-channel
51
+ patterns.
52
+ 2. **Multi-granularity embedding.** Extracts features at different temporal scales from
53
+ :attr:`patch_len_list`, emulating frequency-band behavior without hand-crafted filters.
54
+ 3. **Two-stage multi-granularity self-attention.** Learns intra- and inter-granularity
55
+ correlations to fuse information across temporal scales.
56
+
57
+ .. rubric:: Macro Components
58
+
59
+ ``MEDFormer.enc_embedding`` (Embedding Layer)
60
+ **Operations.** :class:`~braindecode.models.medformer._ListPatchEmbedding` implements
61
+ cross-channel multi-granularity patching. For each patch length :math:`L_i`, the input
62
+ :math:`\mathbf{x}_{\text{in}} \in \mathbb{R}^{T \times C}` is segmented into
63
+ :math:`N_i` cross-channel non-overlapping patches
64
+ :math:`\mathbf{x}_p^{(i)} \in \mathbb{R}^{N_i \times (L_i \cdot C)}`, where
65
+ :math:`N_i = \lceil T/L_i \rceil`. Each patch is linearly projected via
66
+ :class:`~braindecode.models.medformer._CrossChannelTokenEmbedding` to obtain
67
+ :math:`\mathbf{x}_e^{(i)} \in \mathbb{R}^{N_i \times D}`. Data augmentations
68
+ (masking, jittering) produce augmented embeddings :math:`\tilde{\mathbf{x}}_e^{(i)}`.
69
+ The final embedding combines augmented patches, fixed positional embeddings
70
+ (:class:`~braindecode.models.medformer._PositionalEmbedding`), and learnable
71
+ granularity embeddings :math:`\mathbf{W}_{\text{gr}}^{(i)}`:
72
+
73
+ .. math::
74
+ \mathbf{x}^{(i)} = \tilde{\mathbf{x}}_e^{(i)} + \mathbf{W}_{\text{pos}}[1:N_i] + \mathbf{W}_{\text{gr}}^{(i)}
75
+
76
+ Additionally, a router token is initialized for each granularity:
77
+
78
+ .. math::
79
+ \mathbf{u}^{(i)} = \mathbf{W}_{\text{pos}}[N_i+1] + \mathbf{W}_{\text{gr}}^{(i)}
80
+
81
+ **Role.** Converts raw input into granularity-specific patch embeddings
82
+ :math:`\{\mathbf{x}^{(1)}, \ldots, \mathbf{x}^{(n)}\}` and router embeddings
83
+ :math:`\{\mathbf{u}^{(1)}, \ldots, \mathbf{u}^{(n)}\}` for multi-scale processing.
84
+
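For illustration (values assumed, not taken from the paper): with :math:`T = 1000` time samples and
patch length :math:`L_i = 8`, this embedding yields :math:`N_i = \lceil 1000/8 \rceil = 125` patch
tokens of dimension :math:`D`, plus one router token :math:`\mathbf{u}^{(i)}` for that granularity.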
85
+ ``MEDFormer.encoder`` (Transformer Encoder Stack)
86
+ **Operations.** A stack of :class:`~braindecode.models.medformer._EncoderLayer` modules,
87
+ each containing a :class:`~braindecode.models.medformer._MedformerLayer` that implements
88
+ two-stage self-attention. The two-stage mechanism splits self-attention into:
89
+
90
+ **(a) Intra-Granularity Self-Attention.** For granularity :math:`i`, the patch embedding
91
+ :math:`\mathbf{x}^{(i)} \in \mathbb{R}^{N_i \times D}` and router embedding
92
+ :math:`\mathbf{u}^{(i)} \in \mathbb{R}^{1 \times D}` are concatenated:
93
+
94
+ .. math::
95
+ \mathbf{z}^{(i)} = [\mathbf{x}^{(i)} \| \mathbf{u}^{(i)}] \in \mathbb{R}^{(N_i+1) \times D}
96
+
97
+ Self-attention is applied to update both embeddings:
98
+
99
+ .. math::
100
+ \mathbf{x}^{(i)} &\leftarrow \text{Attn}_{\text{intra}}(\mathbf{x}^{(i)}, \mathbf{z}^{(i)}, \mathbf{z}^{(i)})\\
101
+ \mathbf{u}^{(i)} &\leftarrow \text{Attn}_{\text{intra}}(\mathbf{u}^{(i)}, \mathbf{z}^{(i)}, \mathbf{z}^{(i)})
102
+
103
+ This captures temporal features within each granularity independently.
104
+
105
+ **(b) Inter-Granularity Self-Attention.** All router embeddings are concatenated:
106
+
107
+ .. math::
108
+ \mathbf{U} = [\mathbf{u}^{(1)} \| \mathbf{u}^{(2)} \| \cdots \| \mathbf{u}^{(n)}] \in \mathbb{R}^{n \times D}
109
+
110
+ Self-attention among routers exchanges information across granularities:
111
+
112
+ .. math::
113
+ \mathbf{u}^{(i)} \leftarrow \text{Attn}_{\text{inter}}(\mathbf{u}^{(i)}, \mathbf{U}, \mathbf{U})
114
+
115
+ **Role.** Learns representations and correlations within and across temporal scales while
116
+ reducing complexity from :math:`O((\sum_i N_i)^2)` to
117
+ :math:`O(\sum_i N_i^2 + n^2)` through the router mechanism.
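For illustration (assumed patch counts): with three granularities of :math:`N_1 = 12`,
:math:`N_2 = 3`, and :math:`N_3 = 2` patches, attention over the concatenated
:math:`12 + 3 + 2 = 17` tokens would score :math:`17^2 = 289` pairs, while the two-stage scheme
scores :math:`12^2 + 3^2 + 2^2 = 157` intra-granularity pairs plus :math:`3^2 = 9` router pairs.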
118
+
+ .. rubric:: Temporal, Spatial, and Spectral Encoding
119
+
120
+ - **Temporal:** Multiple patch lengths in :attr:`patch_len_list` capture features at several
121
+ temporal granularities, while intra-granularity attention supports long-range temporal
122
+ dependencies.
123
+ - **Spatial:** Cross-channel patching embeds inter-channel dependencies by applying kernels
124
+ that span every input channel.
125
+ - **Spectral:** Differing patch lengths simulate multiple sampling frequencies analogous to
126
+ clinically relevant bands (e.g., alpha, beta, gamma).
127
+
128
+ .. rubric:: Additional Mechanisms
129
+
130
+ - **Granularity router:** Each granularity :math:`i` receives a dedicated router token
131
+ :math:`\mathbf{u}^{(i)}`. Intra-attention updates the token, and inter-attention exchanges
132
+ aggregated information across scales.
133
+ - **Complexity:** Router-mediated two-stage attention maintains :math:`O(T^2)` complexity for
134
+ suitable patch lengths (e.g., power series), preserving transformer-like efficiency while
135
+ modeling multiple granularities.
136
+
137
+ Parameters
138
+ ----------
139
+ patch_len_list : list of int, optional
140
+ Patch lengths for multi-granularity patching; each entry selects a temporal scale.
141
+ The default is ``[2, 8, 16]``.
142
+ d_model : int, optional
143
+ Embedding dimensionality. The default is ``128``.
144
+ num_heads : int, optional
145
+ Number of attention heads, which must divide :attr:`d_model`. The default is ``8``.
146
+ drop_prob : float, optional
147
+ Dropout probability. The default is ``0.1``.
148
+ no_inter_attn : bool, optional
149
+ If ``True``, disables inter-granularity attention. The default is ``False``.
150
+ n_layers : int, optional
151
+ Number of encoder layers. The default is ``6``.
152
+ dim_feedforward : int, optional
153
+ Feedforward dimensionality. The default is ``256``.
154
+ activation_trans : nn.Module, optional
155
+ Activation module used in transformer encoder layers. The default is :class:`nn.ReLU`.
156
+ single_channel : bool, optional
157
+ If ``True``, processes each channel independently, increasing capacity and cost. The default is ``False``.
158
+ output_attention : bool, optional
159
+ If ``True``, attention weights are computed inside the encoder layers (they are not returned by :meth:`forward`). The default is ``True``.
160
+ activation_class : nn.Module, optional
161
+ Activation used in the final classification layer. The default is :class:`nn.GELU`.
162
+
163
+ Notes
164
+ -----
165
+ - MedFormer outperforms strong baselines across six metrics on five MedTS datasets in a
166
+ subject-independent evaluation [Medformer2024]_.
167
+ - Cross-channel patching provides the largest F1 improvement in ablation studies (average
168
+ +6.10%), highlighting its importance for MedTS tasks [Medformer2024]_.
169
+ - Setting :attr:`no_inter_attn` to ``True`` disables inter-granularity attention while retaining
170
+ intra-granularity attention.
171
+
172
+ References
173
+ ----------
174
+ .. [Medformer2024] Wang, Y., Huang, N., Li, T., Yan, Y., & Zhang, X. (2024).
175
+ Medformer: A Multi-Granularity Patching Transformer for Medical Time-Series Classification.
176
+ In A. Globerson, L. Mackey, D. Belgrave, A. Fan, U. Paquet, J. Tomczak, & C. Zhang (Eds.),
177
+ Advances in Neural Information Processing Systems (Vol. 37, pp. 36314-36341).
178
+ doi:10.52202/079017-1145.
179
+ """
180
+
181
+ def __init__(
182
+ self,
183
+ # Signal related parameters
184
+ n_chans=None,
185
+ n_outputs=None,
186
+ n_times=None,
187
+ chs_info=None,
188
+ input_window_seconds=None,
189
+ sfreq=None,
190
+ # Model parameters
191
+ patch_len_list: Optional[List[int]] = None,
192
+ d_model: int = 128,
193
+ num_heads: int = 8,
194
+ drop_prob: float = 0.1,
195
+ no_inter_attn: bool = False,
196
+ n_layers: int = 6,
197
+ dim_feedforward: int = 256,
198
+ activation_trans: Optional[nn.Module] = nn.ReLU,
199
+ single_channel: bool = False,
200
+ output_attention: bool = True,
201
+ activation_class: Optional[nn.Module] = nn.GELU,
202
+ ):
203
+ super().__init__(
204
+ n_outputs=n_outputs,
205
+ n_chans=n_chans,
206
+ chs_info=chs_info,
207
+ n_times=n_times,
208
+ input_window_seconds=input_window_seconds,
209
+ sfreq=sfreq,
210
+ )
211
+ del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
212
+
213
+ # Naming note: the embedding below is constructed with
214
+ # - seq_len set to the number of channels
215
+ # - enc_in set to the number of time points
216
+
217
+ # Save model parameters as instance variables
218
+ self.d_model = d_model
219
+ self.num_heads = num_heads
220
+ self.drop_prob = drop_prob
221
+ self.no_inter_attn = no_inter_attn
222
+ self.n_layers = n_layers
223
+ self.dim_feedforward = dim_feedforward
224
+ self.activation_trans = activation_trans
225
+ self.output_attention = output_attention
226
+ self.single_channel = single_channel
227
+ self.activation_class = activation_class
228
+
229
+ # Process the sequence and patch configurations.
230
+ if patch_len_list is None:
231
+ patch_len_list = [2, 8, 16]
232
+
233
+ self.patch_len_list = patch_len_list
234
+ stride_list = patch_len_list # Using the same values for strides.
235
+ self.stride_list = stride_list
236
+ patch_num_list = [
237
+ int((self.n_chans - patch_len) / stride + 2)
238
+ for patch_len, stride in zip(patch_len_list, stride_list)
239
+ ]
240
+ self.patch_num_list = patch_num_list
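# Worked example (assumed sizes, for illustration only): with n_chans=22 and
# the default patch_len_list=[2, 8, 16] (stride equal to patch length), the
# expression above gives patch_num_list == [12, 3, 2]:
#   int((22 - 2) / 2 + 2) == 12
#   int((22 - 8) / 8 + 2) == 3
#   int((22 - 16) / 16 + 2) == 2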
241
+
242
+ # Initialize the embedding layer.
243
+ self.enc_embedding = _ListPatchEmbedding(
244
+ enc_in=self.n_times,
245
+ d_model=self.d_model,
246
+ seq_len=self.n_chans,
247
+ patch_len_list=self.patch_len_list,
248
+ stride_list=self.stride_list,
249
+ dropout=self.drop_prob,
250
+ single_channel=self.single_channel,
251
+ n_chans=self.n_chans,
252
+ n_times=self.n_times,
253
+ )
254
+ # Build the encoder with multiple layers.
255
+ self.encoder = _Encoder(
256
+ [
257
+ _EncoderLayer(
258
+ attention=_MedformerLayer(
259
+ num_blocks=len(self.patch_len_list),
260
+ d_model=self.d_model,
261
+ num_heads=self.num_heads,
262
+ dropout=self.drop_prob,
263
+ output_attention=self.output_attention,
264
+ no_inter=self.no_inter_attn,
265
+ ),
266
+ d_model=self.d_model,
267
+ dim_feedforward=self.dim_feedforward,
268
+ dropout=self.drop_prob,
269
+ activation=self.activation_trans()
270
+ if self.activation_trans is not None
271
+ else nn.ReLU(),
272
+ )
273
+ for _ in range(self.n_layers)
274
+ ],
275
+ norm_layer=torch.nn.LayerNorm(self.d_model),
276
+ )
277
+
278
+ # For classification tasks, add additional layers.
279
+ self.activation_layer = (
280
+ self.activation_class() if self.activation_class is not None else nn.GELU()
281
+ )
282
+ self.dropout = nn.Dropout(self.drop_prob)
283
+ self.final_layer = nn.Linear(
284
+ self.d_model
285
+ * len(self.patch_num_list)
286
+ * (1 if not self.single_channel else self.n_chans),
287
+ self.n_outputs,
288
+ )
289
+
290
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
291
+ """Forward pass of the Medformer model.
292
+
293
+ Parameters
294
+ ----------
295
+ x : torch.Tensor
296
+ Input tensor of shape (batch_size, n_chans, n_times).
297
+
298
+ Returns
299
+ -------
300
+ torch.Tensor
301
+ Output tensor of shape (batch_size, n_outputs).
302
+ """
303
+ # Embedding
304
+ enc_out = self.enc_embedding(x)
305
+ enc_out, _ = self.encoder(enc_out, attn_mask=None)
306
+
307
+ if self.single_channel:
308
+ # Reshape back from (batch_size * n_chans, ...) to (batch_size, n_chans, ...)
309
+ # Explicitly construct the reshape dimensions to be TorchScript compatible
310
+ batch_size = enc_out.shape[0] // self.n_chans
311
+ seq_len = enc_out.shape[1]
312
+ d_model = enc_out.shape[2]
313
+ enc_out = torch.reshape(
314
+ enc_out, (batch_size, self.n_chans, seq_len, d_model)
315
+ )
316
+
317
+ # Output
318
+ output = self.activation_layer(enc_out)
319
+ output = self.dropout(output)
320
+ output = output.reshape(
321
+ output.shape[0], -1
322
+ ) # (batch_size, seq_length * d_model)
323
+ output = self.final_layer(output) # (batch_size, num_classes)
324
+ return output
325
+
326
+
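# Minimal usage sketch (illustrative only; assumes the constructor and forward
# signatures defined above, with arbitrary example sizes):
#
#     import torch
#     from braindecode.models.medformer import MEDFormer
#
#     model = MEDFormer(n_chans=22, n_times=1000, n_outputs=4)
#     x = torch.randn(8, 22, 1000)   # (batch_size, n_chans, n_times)
#     y = model(x)                   # expected shape (8, 4) == (batch_size, n_outputs)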
327
+ class _PositionalEmbedding(nn.Module):
328
+ def __init__(self, d_model: int, max_len: int = 5000):
329
+ super().__init__()
330
+ # If d_model is odd, temporarily work with d_model + 1.
331
+ if d_model % 2 == 1:
332
+ d_model_adj = d_model + 1
333
+ else:
334
+ d_model_adj = d_model
335
+ self.d_model = d_model # store the original dimension
336
+
337
+ # Create a pe tensor of size (max_len, d_model_adj)
338
+ pe = torch.zeros(max_len, d_model_adj).float()
339
+ pe.requires_grad = False
340
+
341
+ # Compute the sinusoidal factors.
342
+ position = torch.arange(0, max_len).float().unsqueeze(1)
343
+ # Use d_model_adj in the denominator so that the frequencies are computed over an even number.
344
+ div_term = torch.exp(
345
+ torch.arange(0, d_model_adj, 2).float() * (-math.log(10000.0) / d_model_adj)
346
+ )
347
+ pe[:, 0::2] = torch.sin(position * div_term)
348
+ pe[:, 1::2] = torch.cos(position * div_term)
349
+
350
+ # Unsqueeze to shape (1, max_len, d_model_adj)
351
+ pe = pe.unsqueeze(0)
352
+ self.register_buffer("pe", pe)
353
+
354
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
355
+ # x is assumed to have shape (B, L, d_model_target)
356
+ # We return the first self.d_model columns from the computed pe.
357
+ return self.pe[:, : x.size(1), : self.d_model]
358
+
359
+
360
+ class _CrossChannelTokenEmbedding(nn.Module):
361
+ def __init__(
362
+ self, c_in: int, l_patch: int, d_model: int, stride: Optional[int] = None
363
+ ):
364
+ super().__init__()
365
+ if stride is None:
366
+ stride = l_patch
367
+ self.token_conv = nn.Conv2d(
368
+ in_channels=1,
369
+ out_channels=d_model,
370
+ kernel_size=(c_in, l_patch),
371
+ stride=(1, stride),
372
+ padding=0,
373
+ padding_mode="circular",
374
+ bias=False,
375
+ )
376
+ for m in self.modules():
377
+ if isinstance(m, nn.Conv2d):
378
+ nn.init.kaiming_normal_(
379
+ m.weight, mode="fan_in", nonlinearity="leaky_relu"
380
+ )
381
+
382
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
383
+ x = self.token_conv(x)
384
+ return x
385
+
386
+
387
+ class _ListPatchEmbedding(nn.Module):
388
+ def __init__(
389
+ self,
390
+ enc_in: int,
391
+ d_model: int,
392
+ seq_len: int,
393
+ patch_len_list: List[int],
394
+ stride_list: List[int],
395
+ dropout: float,
396
+ single_channel: bool = False,
397
+ n_chans: Optional[int] = None,
398
+ n_times: Optional[int] = None,
399
+ ):
400
+ super().__init__()
401
+ self.patch_len_list = patch_len_list
402
+ self.stride_list = stride_list
403
+ # Use ModuleList so TorchScript can statically infer module attribute types
404
+ self.paddings: nn.ModuleList = nn.ModuleList(
405
+ [nn.ReplicationPad1d((0, stride)) for stride in stride_list]
406
+ )
407
+ self.single_channel = single_channel
408
+ self.n_chans = n_chans
409
+ self.n_times = n_times
410
+
411
+ # Number of different patch/granularity blocks (used to make loops TorchScript-friendly)
412
+ self.num_patches = len(patch_len_list)
413
+
414
+ linear_layers = [
415
+ _CrossChannelTokenEmbedding(
416
+ c_in=enc_in if not single_channel else 1,
417
+ l_patch=patch_len,
418
+ d_model=d_model,
419
+ )
420
+ for patch_len in patch_len_list
421
+ ]
422
+ self.value_embeddings = nn.ModuleList(linear_layers)
423
+ self.position_embedding = _PositionalEmbedding(d_model=d_model)
424
+ self.channel_embedding = _PositionalEmbedding(d_model=seq_len)
425
+ self.dropout = nn.Dropout(dropout)
426
+
427
+ self.learnable_embeddings = nn.ParameterList(
428
+ [nn.Parameter(torch.randn(1, d_model)) for _ in patch_len_list]
429
+ )
430
+
431
+ def forward(
432
+ self, x: torch.Tensor
433
+ ) -> List[torch.Tensor]: # (batch_size, seq_len, enc_in)
434
+ x = x.permute(0, 2, 1) # (batch_size, enc_in, seq_len)
435
+ if self.single_channel:
436
+ # After permute: x.shape = (batch_size, n_times, n_chans)
437
+ # We want to process each channel independently
438
+ batch_size = x.shape[0]
439
+ # Permute to get channels in the middle: (batch_size, n_chans, n_times)
440
+ x = x.permute(0, 2, 1)
441
+ # Reshape to treat each channel independently: (batch_size * n_chans, 1, n_times)
442
+ x = torch.reshape(x, (batch_size * self.n_chans, 1, self.n_times))
443
+
444
+ x_list = []
445
+ for padding, value_embedding in zip(self.paddings, self.value_embeddings):
446
+ x_copy = x.clone()
447
+ # add positional embedding to tag each channel (only when not single_channel)
448
+ if not self.single_channel:
449
+ x_new = x_copy + self.channel_embedding(x_copy)
450
+ else:
451
+ x_new = x_copy
452
+ x_new = padding(x_new).unsqueeze(
453
+ 1
454
+ ) # (batch_size, 1, enc_in, seq_len+stride)
455
+ x_new = value_embedding(x_new) # (batch_size, d_model, 1, patch_num)
456
+ x_new = x_new.squeeze(2).transpose(1, 2) # (batch_size, patch_num, d_model)
457
+ x_list.append(x_new)
458
+
459
+ # Combine each patch embedding with its corresponding learnable granularity
460
+ # embedding and positional embedding. Use an explicit indexed loop so
461
+ # TorchScript can statically determine lengths instead of iterating over
462
+ # Python lists/ParameterList via zip.
463
+ out_list: List[torch.Tensor] = []
464
+ # Iterate over learnable_embeddings with enumerate (supported by TorchScript)
465
+ for idx, cxt in enumerate(self.learnable_embeddings):
466
+ xi = x_list[idx]
467
+ xi = xi + cxt + self.position_embedding(xi)
468
+ out_list.append(xi)
469
+
470
+ return out_list
471
+
472
+
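# Shape note (illustrative): for an input x of shape (batch_size, n_chans, n_times),
# _ListPatchEmbedding.forward returns one tensor of shape
# (batch_size, patch_num_i, d_model) per entry of patch_len_list, as tagged in the
# inline comments above.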
473
+ class _AttentionLayer(nn.Module):
474
+ def __init__(
475
+ self,
476
+ attention: nn.Module,
477
+ d_model: int,
478
+ num_heads: int,
479
+ d_keys: Optional[int] = None,
480
+ d_values: Optional[int] = None,
481
+ ):
482
+ super().__init__()
483
+
484
+ d_keys = d_keys or (d_model // num_heads)
485
+ d_values = d_values or (d_model // num_heads)
486
+
487
+ self.inner_attention = attention
488
+ self.query_projection = nn.Linear(d_model, d_keys * num_heads)
489
+ self.key_projection = nn.Linear(d_model, d_keys * num_heads)
490
+ self.value_projection = nn.Linear(d_model, d_values * num_heads)
491
+ self.out_projection = nn.Linear(d_values * num_heads, d_model)
492
+ self.num_heads = num_heads
493
+
494
+ def forward(
495
+ self,
496
+ queries: torch.Tensor,
497
+ keys: torch.Tensor,
498
+ values: torch.Tensor,
499
+ attn_mask: Optional[torch.Tensor],
500
+ tau: Optional[torch.Tensor] = None,
501
+ delta: Optional[torch.Tensor] = None,
502
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
503
+ batch_size, query_len, _ = queries.shape
504
+ _, key_len, _ = keys.shape
505
+ num_heads = self.num_heads
506
+
507
+ queries = self.query_projection(queries).view(
508
+ batch_size, query_len, num_heads, -1
509
+ ) # multi-head
510
+ keys = self.key_projection(keys).view(batch_size, key_len, num_heads, -1)
511
+ values = self.value_projection(values).view(batch_size, key_len, num_heads, -1)
512
+
513
+ out, attn = self.inner_attention(
514
+ queries, keys, values, attn_mask, tau=tau, delta=delta
515
+ )
516
+ out = out.view(batch_size, query_len, -1)
517
+
518
+ return self.out_projection(out), attn
519
+
520
+
521
+ class _TriangularCausalMask:
522
+ def __init__(
523
+ self, batch_size: int, seq_len: int, device: Optional[torch.device] = None
524
+ ):
525
+ # Normalize device to a torch.device for .to(device)
526
+ if device is None:
527
+ device = torch.device("cpu")
528
+ mask_shape = [batch_size, 1, seq_len, seq_len]
529
+ with torch.no_grad():
530
+ self._mask = torch.triu(
531
+ torch.ones(mask_shape, dtype=torch.bool), diagonal=1
532
+ ).to(device)
533
+
534
+ @property
535
+ def mask(self) -> torch.Tensor:
536
+ return self._mask
537
+
538
+
539
+ class _FullAttention(nn.Module):
540
+ def __init__(
541
+ self,
542
+ mask_flag: bool = True,
543
+ scale: Optional[float] = None,
544
+ attention_dropout: float = 0.1,
545
+ output_attention: bool = False,
546
+ ):
547
+ super().__init__()
548
+ self.scale = scale
549
+ self.mask_flag = mask_flag
550
+ self.output_attention = output_attention
551
+ self.dropout = nn.Dropout(attention_dropout)
557
+
558
+ def forward(
559
+ self,
560
+ queries: torch.Tensor,
561
+ keys: torch.Tensor,
562
+ values: torch.Tensor,
563
+ attn_mask: Optional[torch.Tensor],
564
+ tau: Optional[torch.Tensor] = None,
565
+ delta: Optional[torch.Tensor] = None,
566
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
567
+ batch_size, query_len, _, embed_dim = queries.shape
568
+ _, _, _, _ = values.shape
569
+ # Avoid using `or` because TorchScript may fail to cast None/float
570
+ if self.scale is None:
571
+ scale = 1.0 / sqrt(embed_dim)
572
+ else:
573
+ scale = self.scale
574
+
575
+ scores = torch.einsum("blhe,bshe->bhls", queries, keys)
576
+
577
+ if self.mask_flag:
578
+ if attn_mask is None:
579
+ # create a triangular causal mask tensor of shape [B,1,L,L]
580
+ attn_mask = torch.triu(
581
+ torch.ones([batch_size, 1, query_len, query_len], dtype=torch.bool),
582
+ diagonal=1,
583
+ ).to(queries.device)
584
+
585
+ # attn_mask is expected to be a boolean tensor with same shape
586
+ scores.masked_fill_(attn_mask, -np.inf)
587
+
588
+ attention_weights = self.dropout(
589
+ torch.softmax(scale * scores, dim=-1)
590
+ ) # Scaled Dot-Product Attention
591
+ output_values = torch.einsum("bhls,bshd->blhd", attention_weights, values)
592
+
593
+ if self.output_attention:
594
+ return output_values.contiguous(), attention_weights
595
+ else:
596
+ return output_values.contiguous(), None
597
+
598
+
599
+ class _MedformerLayer(nn.Module):
600
+ def __init__(
601
+ self,
602
+ num_blocks: int,
603
+ d_model: int,
604
+ num_heads: int,
605
+ dropout: float = 0.1,
606
+ output_attention: bool = False,
607
+ no_inter: bool = False,
608
+ ):
609
+ super().__init__()
610
+
611
+ self.intra_attentions = nn.ModuleList(
612
+ [
613
+ _AttentionLayer(
614
+ _FullAttention(
615
+ mask_flag=False,
616
+ attention_dropout=dropout,
617
+ output_attention=output_attention,
618
+ ),
619
+ d_model,
620
+ num_heads,
621
+ )
622
+ for _ in range(num_blocks)
623
+ ]
624
+ )
625
+ if no_inter or num_blocks <= 1:
626
+ # print("No inter attention for time")
627
+ self.inter_attention = None
628
+ else:
629
+ self.inter_attention = _AttentionLayer(
630
+ _FullAttention(
631
+ mask_flag=False,
632
+ attention_dropout=dropout,
633
+ output_attention=output_attention,
634
+ ),
635
+ d_model,
636
+ num_heads,
637
+ )
638
+
639
+ def forward(
640
+ self,
641
+ x: List[torch.Tensor],
642
+ attn_mask: Optional[List[Optional[torch.Tensor]]] = None,
643
+ tau: Optional[torch.Tensor] = None,
644
+ delta: Optional[torch.Tensor] = None,
645
+ ) -> Tuple[List[torch.Tensor], List[Optional[torch.Tensor]]]:
646
+ # Explicit None check because TorchScript cannot evaluate truthiness of lists
647
+ if attn_mask is None:
648
+ # Build a list of None with explicit typing for TorchScript
649
+ new_mask_list: List[Optional[torch.Tensor]] = []
650
+ for _ in range(len(x)):
651
+ new_mask_list.append(None)
652
+ attn_mask = new_mask_list
653
+ # Intra attention
654
+ x_intra: List[torch.Tensor] = []
655
+ attn_out: List[Optional[torch.Tensor]] = []
656
+ # Iterate over ModuleList with enumerate (TorchScript supports enumerate over modules)
657
+ for idx, layer in enumerate(self.intra_attentions):
658
+ x_in_i = x[idx]
659
+ mask_i = attn_mask[idx]
660
+ x_out_temp, attn_temp = layer(
661
+ x_in_i, x_in_i, x_in_i, attn_mask=mask_i, tau=tau, delta=delta
662
+ )
663
+ x_intra.append(x_out_temp) # (B, Li, D)
664
+ attn_out.append(attn_temp)
665
+ if self.inter_attention is not None:
666
+ # Inter attention
667
+ routers = torch.cat([xi[:, -1:] for xi in x_intra], dim=1) # (B, N, D)
668
+ x_inter, attn_inter = self.inter_attention(
669
+ routers, routers, routers, attn_mask=None, tau=tau, delta=delta
670
+ )
671
+ x_out = [
672
+ torch.cat([xi[:, :-1], x_inter[:, i : i + 1]], dim=1) # (B, Li, D)
673
+ for i, xi in enumerate(x_intra)
674
+ ]
675
+ attn_out += [attn_inter]
676
+ else:
677
+ x_out = x_intra
678
+ return x_out, attn_out
679
+
680
+
681
+ class _EncoderLayer(nn.Module):
682
+ def __init__(
683
+ self,
684
+ attention: nn.Module,
685
+ d_model: int,
686
+ dim_feedforward: Optional[int],
687
+ dropout: float,
688
+ activation: Optional[nn.Module] = None,
689
+ ):
690
+ super().__init__()
691
+ dim_feedforward = dim_feedforward or 4 * d_model
692
+ self.attention = attention
693
+ self.conv1 = nn.Conv1d(
694
+ in_channels=d_model, out_channels=dim_feedforward, kernel_size=1
695
+ )
696
+ self.conv2 = nn.Conv1d(
697
+ in_channels=dim_feedforward, out_channels=d_model, kernel_size=1
698
+ )
699
+ self.norm1 = nn.LayerNorm(d_model)
700
+ self.norm2 = nn.LayerNorm(d_model)
701
+ self.dropout = nn.Dropout(dropout)
702
+ self.activation = activation if activation is not None else nn.ReLU()
703
+
704
+ def forward(
705
+ self,
706
+ x: List[torch.Tensor],
707
+ attn_mask: Optional[List[Optional[torch.Tensor]]] = None,
708
+ tau: Optional[torch.Tensor] = None,
709
+ delta: Optional[torch.Tensor] = None,
710
+ ) -> Tuple[List[torch.Tensor], List[Optional[torch.Tensor]]]:
711
+ new_x, attn = self.attention(x, attn_mask=attn_mask, tau=tau, delta=delta)
712
+ x = [x_orig + self.dropout(x_new) for x_orig, x_new in zip(x, new_x)]
713
+
714
+ y = x = [self.norm1(x_val) for x_val in x]
715
+ y = [
716
+ self.dropout(self.activation(self.conv1(y_val.transpose(-1, 1))))
717
+ for y_val in y
718
+ ]
719
+ y = [self.dropout(self.conv2(y_val).transpose(-1, 1)) for y_val in y]
720
+
721
+ return [self.norm2(x_val + y_val) for x_val, y_val in zip(x, y)], attn
722
+
723
+
724
+ class _Encoder(nn.Module):
725
+ def __init__(
726
+ self, attn_layers: List[nn.Module], norm_layer: Optional[nn.Module] = None
727
+ ):
728
+ super().__init__()
729
+ self.attn_layers = nn.ModuleList(attn_layers)
730
+ self.norm = norm_layer
731
+
732
+ def forward(
733
+ self,
734
+ x: List[torch.Tensor],
735
+ attn_mask: Optional[List[Optional[torch.Tensor]]] = None,
736
+ tau: Optional[torch.Tensor] = None,
737
+ delta: Optional[torch.Tensor] = None,
738
+ ) -> Tuple[torch.Tensor, List[List[Optional[torch.Tensor]]]]:
739
+ # x [[B, L1, D], [B, L2, D], ...]
740
+ attns: List[List[Optional[torch.Tensor]]] = []
741
+ for attn_layer in self.attn_layers:
742
+ x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
743
+ attns.append(attn)
744
+
745
+ # concat all the outputs
746
+ """x = torch.cat(
747
+ x, dim=1
748
+ ) # (batch_size, patch_num_1 + patch_num_2 + ... , d_model)"""
749
+
750
+ # concat all the routers
751
+ x = torch.cat(
752
+ [xi[:, -1, :].unsqueeze(1) for xi in x], dim=1
753
+ ) # (batch_size, len(patch_len_list), d_model)
754
+
755
+ if self.norm is not None:
756
+ x = self.norm(x)
757
+
758
+ return x, attns