braindecode-1.3.0.dev177069446-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/medformer.py
@@ -0,0 +1,760 @@
1
+ # Authors: Yihe Wang <ywang145@charlotte.edu>
2
+ # Nan Huang <nhuang1@charlotte.edu>
3
+ # Taida Li <tli14@charlotte.edu>
4
+ #
5
+ # License: MIT
6
+
7
+ """Medformer: A Multi-Granularity Patching Transformer for Medical Time-Series Classification."""
8
+
9
+ import math
10
+ from math import sqrt
11
+ from typing import List, Optional, Tuple
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from braindecode.models.base import EEGModuleMixin
18
+
19
+
20
+ class MEDFormer(EEGModuleMixin, nn.Module):
21
+ r"""
22
+ Medformer from Wang et al. (2024) [Medformer2024]_.
23
+
24
+ :bdg-success:`Convolution` :bdg-danger:`Foundation Model`
25
+
26
+ .. figure:: https://raw.githubusercontent.com/DL4mHealth/Medformer/refs/heads/main/figs/medformer_architecture.png
27
+ :align: center
28
+ :alt: MEDFormer Architecture.
29
+
30
+ a) Workflow. b) For the input sample :math:`{x}_{\text{in}}`, the authors apply :math:`n`
31
+ different patch lengths in parallel to create patched features :math:`{x}_p^{(i)}`, where :math:`i`
32
+ ranges from 1 to :math:`n`. Each patch length represents a different granularity. These patched
33
+ features are linearly transformed into :math:`{x}_e^{(i)}` and augmented into :math:`\widetilde{x}_e^{(i)}`.
34
+ c) The final patch embedding :math:`{x}^{(i)}` fuses augmented :math:`\widetilde{x}_e^{(i)}` with the
35
+ positional embedding :math:`{W}_{\text{pos}}` and the granularity embedding :math:`{W}_{\text{gr}}^{(i)}`.
36
+ Each granularity employs a router :math:`{u}^{(i)}` to capture aggregated information.
37
+ Intra-granularity attention focuses within individual granularities, and inter-granularity attention
38
+ leverages the routers to integrate information across granularities.
39
+
40
+ The **MedFormer** is a multi-granularity patching transformer tailored to medical
41
+ time-series (MedTS) classification, with an emphasis on EEG and ECG signals. It captures
42
+ local temporal dynamics, inter-channel correlations, and multi-scale temporal structure
43
+ through cross-channel patching, multi-granularity embeddings, and two-stage attention
44
+ [Medformer2024]_.
45
+
46
+ .. rubric:: Architecture Overview
47
+
48
+ MedFormer integrates three mechanisms to enhance representation learning [Medformer2024]_:
49
+
50
+ 1. **Cross-channel patching.** Leverages inter-channel correlations by forming patches
51
+ across multiple channels and timestamps, capturing multi-timestamp and cross-channel
52
+ patterns.
53
+ 2. **Multi-granularity embedding.** Extracts features at different temporal scales from
54
+ :attr:`patch_len_list`, emulating frequency-band behavior without hand-crafted filters.
55
+ 3. **Two-stage multi-granularity self-attention.** Learns intra- and inter-granularity
56
+ correlations to fuse information across temporal scales.
57
+
58
+ .. rubric:: Macro Components
59
+
60
+ ``MEDFormer.enc_embedding`` (Embedding Layer)
61
+ **Operations.** :class:`~braindecode.models.medformer._ListPatchEmbedding` implements
62
+ cross-channel multi-granularity patching. For each patch length :math:`L_i`, the input
63
+ :math:`\mathbf{x}_{\text{in}} \in \mathbb{R}^{T \times C}` is segmented into
64
+ :math:`N_i` cross-channel non-overlapping patches
65
+ :math:`\mathbf{x}_p^{(i)} \in \mathbb{R}^{N_i \times (L_i \cdot C)}`, where
66
+ :math:`N_i = \lceil T/L_i \rceil`. Each patch is linearly projected via
67
+ :class:`~braindecode.models.medformer._CrossChannelTokenEmbedding` to obtain
68
+ :math:`\mathbf{x}_e^{(i)} \in \mathbb{R}^{N_i \times D}`. Data augmentations
69
+ (masking, jittering) produce augmented embeddings :math:`\tilde{\mathbf{x}}_e^{(i)}`.
70
+ The final embedding combines augmented patches, fixed positional embeddings
71
+ (:class:`~braindecode.models.medformer._PositionalEmbedding`), and learnable
72
+ granularity embeddings :math:`\mathbf{W}_{\text{gr}}^{(i)}`:
73
+
74
+ .. math::
75
+ \mathbf{x}^{(i)} = \tilde{\mathbf{x}}_e^{(i)} + \mathbf{W}_{\text{pos}}[1:N_i] + \mathbf{W}_{\text{gr}}^{(i)}
76
+
77
+ Additionally, a router token is initialized for each granularity:
78
+
79
+ .. math::
80
+ \mathbf{u}^{(i)} = \mathbf{W}_{\text{pos}}[N_i+1] + \mathbf{W}_{\text{gr}}^{(i)}
81
+
82
+ **Role.** Converts raw input into granularity-specific patch embeddings
83
+ :math:`\{\mathbf{x}^{(1)}, \ldots, \mathbf{x}^{(n)}\}` and router embeddings
84
+ :math:`\{\mathbf{u}^{(1)}, \ldots, \mathbf{u}^{(n)}\}` for multi-scale processing.
85
+
86
+ ``MEDFormer.encoder`` (Transformer Encoder Stack)
87
+ **Operations.** A stack of :class:`~braindecode.models.medformer._EncoderLayer` modules,
88
+ each containing a :class:`~braindecode.models.medformer._MedformerLayer` that implements
89
+ two-stage self-attention. The two-stage mechanism splits self-attention into:
90
+
91
+ **(a) Intra-Granularity Self-Attention.** For granularity :math:`i`, the patch embedding
92
+ :math:`\mathbf{x}^{(i)} \in \mathbb{R}^{N_i \times D}` and router embedding
93
+ :math:`\mathbf{u}^{(i)} \in \mathbb{R}^{1 \times D}` are concatenated:
94
+
95
+ .. math::
96
+ \mathbf{z}^{(i)} = [\mathbf{x}^{(i)} \| \mathbf{u}^{(i)}] \in \mathbb{R}^{(N_i+1) \times D}
97
+
98
+ Self-attention is applied to update both embeddings:
99
+
100
+ .. math::
101
+ \mathbf{x}^{(i)} &\leftarrow \text{Attn}_{\text{intra}}(\mathbf{x}^{(i)}, \mathbf{z}^{(i)}, \mathbf{z}^{(i)})\\
102
+ \mathbf{u}^{(i)} &\leftarrow \text{Attn}_{\text{intra}}(\mathbf{u}^{(i)}, \mathbf{z}^{(i)}, \mathbf{z}^{(i)})
103
+
104
+ This captures temporal features within each granularity independently.
105
+
106
+ **(b) Inter-Granularity Self-Attention.** All router embeddings are concatenated:
107
+
108
+ .. math::
109
+ \mathbf{U} = [\mathbf{u}^{(1)} \| \mathbf{u}^{(2)} \| \cdots \| \mathbf{u}^{(n)}] \in \mathbb{R}^{n \times D}
110
+
111
+ Self-attention among routers exchanges information across granularities:
112
+
113
+ .. math::
114
+ \mathbf{u}^{(i)} \leftarrow \text{Attn}_{\text{inter}}(\mathbf{u}^{(i)}, \mathbf{U}, \mathbf{U})
115
+
116
+ **Role.** Learns representations and correlations within and across temporal scales while
117
+ reducing complexity from :math:`O((\sum_i N_i)^2)` to
118
+ :math:`O(\sum_i N_i^2 + n^2)` through the router mechanism.
119
+
120
+ .. rubric:: Temporal, Spatial, and Spectral Encoding
121
+
122
+ - **Temporal:** Multiple patch lengths in :attr:`patch_len_list` capture features at several
123
+ temporal granularities, while intra-granularity attention supports long-range temporal
124
+ dependencies.
125
+ - **Spatial:** Cross-channel patching embeds inter-channel dependencies by applying kernels
126
+ that span every input channel.
127
+ - **Spectral:** Differing patch lengths simulate multiple sampling frequencies analogous to
128
+ clinically relevant bands (e.g., alpha, beta, gamma).
129
+
130
+ .. rubric:: Additional Mechanisms
131
+
132
+ - **Granularity router:** Each granularity :math:`i` receives a dedicated router token
133
+ :math:`\mathbf{u}^{(i)}`. Intra-attention updates the token, and inter-attention exchanges
134
+ aggregated information across scales.
135
+ - **Complexity:** Router-mediated two-stage attention maintains :math:`O(T^2)` complexity for
136
+ suitable patch lengths (e.g., a geometric progression such as powers of two), preserving transformer-like efficiency while
137
+ modeling multiple granularities.
138
+
139
+ Parameters
140
+ ----------
141
+ patch_len_list : list of int, optional
142
+ Patch lengths for multi-granularity patching; each entry selects a temporal scale.
143
+ The default is ``[2, 8, 16]``.
144
+ embed_dim : int, optional
145
+ Embedding dimensionality. The default is ``128``.
146
+ num_heads : int, optional
147
+ Number of attention heads, which must divide :attr:`embed_dim`. The default is ``8``.
148
+ drop_prob : float, optional
149
+ Dropout probability. The default is ``0.1``.
150
+ no_inter_attn : bool, optional
151
+ If ``True``, disables inter-granularity attention. The default is ``False``.
152
+ num_layers : int, optional
153
+ Number of encoder layers. The default is ``6``.
154
+ dim_feedforward : int, optional
155
+ Feedforward dimensionality. The default is ``256``.
156
+ activation_trans : nn.Module, optional
157
+ Activation module used in transformer encoder layers. The default is :class:`nn.ReLU`.
158
+ single_channel : bool, optional
159
+ If ``True``, processes each channel independently, increasing capacity and cost. The default is ``False``.
160
+ output_attention : bool, optional
161
+ If ``True``, the attention layers compute and propagate attention weights internally; :meth:`forward` itself still returns only the class scores. The default is ``True``.
162
+ activation_class : nn.Module, optional
163
+ Activation used in the final classification layer. The default is :class:`nn.GELU`.
164
+
165
+ Notes
166
+ -----
167
+ - MedFormer outperforms strong baselines across six metrics on five MedTS datasets in a
168
+ subject-independent evaluation [Medformer2024]_.
169
+ - Cross-channel patching provides the largest F1 improvement in ablation studies (average
170
+ +6.10%), highlighting its importance for MedTS tasks [Medformer2024]_.
171
+ - Setting :attr:`no_inter_attn` to ``True`` disables inter-granularity attention while retaining
172
+ intra-granularity attention.
173
+
174
+ References
175
+ ----------
176
+ .. [Medformer2024] Wang, Y., Huang, N., Li, T., Yan, Y., & Zhang, X. (2024).
177
+ Medformer: A Multi-Granularity Patching Transformer for Medical Time-Series Classification.
178
+ In A. Globerson, L. Mackey, D. Belgrave, A. Fan, U. Paquet, J. Tomczak, & C. Zhang (Eds.),
179
+ Advances in Neural Information Processing Systems (Vol. 37, pp. 36314-36341).
180
+ doi:10.52202/079017-1145.
181
+ """
182
+
183
+ def __init__(
184
+ self,
185
+ # Signal related parameters
186
+ n_chans=None,
187
+ n_outputs=None,
188
+ n_times=None,
189
+ chs_info=None,
190
+ input_window_seconds=None,
191
+ sfreq=None,
192
+ # Model parameters
193
+ patch_len_list: Optional[List[int]] = None,
194
+ embed_dim: int = 128,
195
+ num_heads: int = 8,
196
+ drop_prob: float = 0.1,
197
+ no_inter_attn: bool = False,
198
+ num_layers: int = 6,
199
+ dim_feedforward: int = 256,
200
+ activation_trans: type[nn.Module] | None = nn.ReLU,
201
+ single_channel: bool = False,
202
+ output_attention: bool = True,
203
+ activation_class: type[nn.Module] | None = nn.GELU,
204
+ ):
205
+ super().__init__(
206
+ n_outputs=n_outputs,
207
+ n_chans=n_chans,
208
+ chs_info=chs_info,
209
+ n_times=n_times,
210
+ input_window_seconds=input_window_seconds,
211
+ sfreq=sfreq,
212
+ )
213
+ del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
214
+
215
+ # In the original Medformer paper:
216
+ # - seq_len refers to the number of channels
217
+ # - enc_in refers to the number of time points
218
+
219
+ # Save model parameters as instance variables
220
+ self.embed_dim = embed_dim
221
+ self.num_heads = num_heads
222
+ self.drop_prob = drop_prob
223
+ self.no_inter_attn = no_inter_attn
224
+ self.num_layers = num_layers
225
+ self.dim_feedforward = dim_feedforward
226
+ self.activation_trans = activation_trans
227
+ self.output_attention = output_attention
228
+ self.single_channel = single_channel
229
+ self.activation_class = activation_class
230
+
231
+ # Process the sequence and patch configurations.
232
+ if patch_len_list is None:
233
+ patch_len_list = [2, 8, 16]
234
+
235
+ self.patch_len_list = patch_len_list
236
+ stride_list = patch_len_list # Using the same values for strides.
237
+ self.stride_list = stride_list
238
+ patch_num_list = [
239
+ int((self.n_chans - patch_len) / stride + 2)
240
+ for patch_len, stride in zip(patch_len_list, stride_list)
241
+ ]
242
+ self.patch_num_list = patch_num_list
243
+
244
+ # Initialize the embedding layer.
245
+ self.enc_embedding = _ListPatchEmbedding(
246
+ enc_in=self.n_times,
247
+ d_model=self.embed_dim,
248
+ seq_len=self.n_chans,
249
+ patch_len_list=self.patch_len_list,
250
+ stride_list=self.stride_list,
251
+ dropout=self.drop_prob,
252
+ single_channel=self.single_channel,
253
+ n_chans=self.n_chans,
254
+ n_times=self.n_times,
255
+ )
256
+ # Build the encoder with multiple layers.
257
+ self.encoder = _Encoder(
258
+ [
259
+ _EncoderLayer(
260
+ attention=_MedformerLayer(
261
+ num_blocks=len(self.patch_len_list),
262
+ d_model=self.embed_dim,
263
+ num_heads=self.num_heads,
264
+ dropout=self.drop_prob,
265
+ output_attention=self.output_attention,
266
+ no_inter=self.no_inter_attn,
267
+ ),
268
+ d_model=self.embed_dim,
269
+ dim_feedforward=self.dim_feedforward,
270
+ dropout=self.drop_prob,
271
+ activation=self.activation_trans()
272
+ if self.activation_trans is not None
273
+ else nn.ReLU(),
274
+ )
275
+ for _ in range(self.num_layers)
276
+ ],
277
+ norm_layer=torch.nn.LayerNorm(self.embed_dim),
278
+ )
279
+
280
+ # For classification tasks, add additional layers.
281
+ self.activation_layer = (
282
+ self.activation_class() if self.activation_class is not None else nn.GELU()
283
+ )
284
+ self.dropout = nn.Dropout(self.drop_prob)
285
+ self.final_layer = nn.Linear(
286
+ self.embed_dim
287
+ * len(self.patch_num_list)
288
+ * (1 if not self.single_channel else self.n_chans),
289
+ self.n_outputs,
290
+ )
291
+
292
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
293
+ """Forward pass of the Medformer model.
294
+
295
+ Parameters
296
+ ----------
297
+ x : torch.Tensor
298
+ Input tensor of shape (batch_size, n_chans, n_times).
299
+
300
+ Returns
301
+ -------
302
+ torch.Tensor
303
+ Output tensor of shape (batch_size, n_outputs).
304
+ """
305
+ # Embedding
306
+ enc_out = self.enc_embedding(x)
307
+ enc_out, _ = self.encoder(enc_out, attn_mask=None)
308
+
309
+ if self.single_channel:
310
+ # Reshape back from (batch_size * n_chans, ...) to (batch_size, n_chans, ...)
311
+ # Explicitly construct the reshape dimensions to be TorchScript compatible
312
+ batch_size = enc_out.shape[0] // self.n_chans
313
+ seq_len = enc_out.shape[1]
314
+ d_model = enc_out.shape[2]
315
+ enc_out = torch.reshape(
316
+ enc_out, (batch_size, self.n_chans, seq_len, d_model)
317
+ )
318
+
319
+ # Output
320
+ output = self.activation_layer(enc_out)
321
+ output = self.dropout(output)
322
+ output = output.reshape(
323
+ output.shape[0], -1
324
+ ) # (batch_size, seq_length * d_model)
325
+ output = self.final_layer(output) # (batch_size, num_classes)
326
+ return output
327
+
328
+
329
+ class _PositionalEmbedding(nn.Module):
330
+ def __init__(self, d_model: int, max_len: int = 5000):
331
+ super().__init__()
332
+ # If d_model is odd, temporarily work with d_model + 1.
333
+ if d_model % 2 == 1:
334
+ d_model_adj = d_model + 1
335
+ else:
336
+ d_model_adj = d_model
337
+ self.d_model = d_model # store the original dimension
338
+
339
+ # Create a pe tensor of size (max_len, d_model_adj)
340
+ pe = torch.zeros(max_len, d_model_adj).float()
341
+ pe.requires_grad = False
342
+
343
+ # Compute the sinusoidal factors.
344
+ position = torch.arange(0, max_len).float().unsqueeze(1)
345
+ # Use d_model_adj in the denominator so that the frequencies are computed over an even number.
346
+ div_term = torch.exp(
347
+ torch.arange(0, d_model_adj, 2).float() * (-math.log(10000.0) / d_model_adj)
348
+ )
349
+ pe[:, 0::2] = torch.sin(position * div_term)
350
+ pe[:, 1::2] = torch.cos(position * div_term)
351
+
352
+ # Unsqueeze to shape (1, max_len, d_model_adj)
353
+ pe = pe.unsqueeze(0)
354
+ self.register_buffer("pe", pe)
355
+
356
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
357
+ # x is assumed to have shape (B, L, d_model_target)
358
+ # We return the first self.d_model columns from the computed pe.
359
+ return self.pe[:, : x.size(1), : self.d_model]
360
+
361
+
362
+ class _CrossChannelTokenEmbedding(nn.Module):
363
+ def __init__(
364
+ self, c_in: int, l_patch: int, d_model: int, stride: Optional[int] = None
365
+ ):
366
+ super().__init__()
367
+ if stride is None:
368
+ stride = l_patch
369
+ self.token_conv = nn.Conv2d(
370
+ in_channels=1,
371
+ out_channels=d_model,
372
+ kernel_size=(c_in, l_patch),
373
+ stride=(1, stride),
374
+ padding=0,
375
+ padding_mode="circular",
376
+ bias=False,
377
+ )
378
+ for m in self.modules():
379
+ if isinstance(m, nn.Conv2d):
380
+ nn.init.kaiming_normal_(
381
+ m.weight, mode="fan_in", nonlinearity="leaky_relu"
382
+ )
383
+
384
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
385
+ x = self.token_conv(x)
386
+ return x
387
+
388
+
389
+ class _ListPatchEmbedding(nn.Module):
390
+ def __init__(
391
+ self,
392
+ enc_in: int,
393
+ d_model: int,
394
+ seq_len: int,
395
+ patch_len_list: List[int],
396
+ stride_list: List[int],
397
+ dropout: float,
398
+ single_channel: bool = False,
399
+ n_chans: Optional[int] = None,
400
+ n_times: Optional[int] = None,
401
+ ):
402
+ super().__init__()
403
+ self.patch_len_list = patch_len_list
404
+ self.stride_list = stride_list
405
+ # Use ModuleList so TorchScript can statically infer module attribute types
406
+ self.paddings: nn.ModuleList = nn.ModuleList(
407
+ [nn.ReplicationPad1d((0, stride)) for stride in stride_list]
408
+ )
409
+ self.single_channel = single_channel
410
+ self.n_chans = n_chans
411
+ self.n_times = n_times
412
+
413
+ # Number of different patch/granularity blocks (used to make loops TorchScript-friendly)
414
+ self.num_patches = len(patch_len_list)
415
+
416
+ linear_layers = [
417
+ _CrossChannelTokenEmbedding(
418
+ c_in=enc_in if not single_channel else 1,
419
+ l_patch=patch_len,
420
+ d_model=d_model,
421
+ )
422
+ for patch_len in patch_len_list
423
+ ]
424
+ self.value_embeddings = nn.ModuleList(linear_layers)
425
+ self.position_embedding = _PositionalEmbedding(d_model=d_model)
426
+ self.channel_embedding = _PositionalEmbedding(d_model=seq_len)
427
+ self.dropout = nn.Dropout(dropout)
428
+
429
+ self.learnable_embeddings = nn.ParameterList(
430
+ [nn.Parameter(torch.randn(1, d_model)) for _ in patch_len_list]
431
+ )
432
+
433
+ def forward(
434
+ self, x: torch.Tensor
435
+ ) -> List[torch.Tensor]: # (batch_size, seq_len, enc_in)
436
+ x = x.permute(0, 2, 1) # (batch_size, enc_in, seq_len)
437
+ if self.single_channel:
438
+ # After permute: x.shape = (batch_size, n_times, n_chans)
439
+ # We want to process each channel independently
440
+ batch_size = x.shape[0]
441
+ # Permute to get channels in the middle: (batch_size, n_chans, n_times)
442
+ x = x.permute(0, 2, 1)
443
+ # Reshape to treat each channel independently: (batch_size * n_chans, 1, n_times)
444
+ x = torch.reshape(x, (batch_size * self.n_chans, 1, self.n_times))
445
+
446
+ x_list = []
447
+ for padding, value_embedding in zip(self.paddings, self.value_embeddings):
448
+ x_copy = x.clone()
449
+ # add positional embedding to tag each channel (only when not single_channel)
450
+ if not self.single_channel:
451
+ x_new = x_copy + self.channel_embedding(x_copy)
452
+ else:
453
+ x_new = x_copy
454
+ x_new = padding(x_new).unsqueeze(
455
+ 1
456
+ ) # (batch_size, 1, enc_in, seq_len+stride)
457
+ x_new = value_embedding(x_new) # (batch_size, d_model, 1, patch_num)
458
+ x_new = x_new.squeeze(2).transpose(1, 2) # (batch_size, patch_num, d_model)
459
+ x_list.append(x_new)
460
+
461
+ # Combine each patch embedding with its corresponding learnable granularity
462
+ # embedding and positional embedding. Use an explicit indexed loop so
463
+ # TorchScript can statically determine lengths instead of iterating over
464
+ # Python lists/ParameterList via zip.
465
+ out_list: List[torch.Tensor] = []
466
+ # Iterate over learnable_embeddings with enumerate (supported by TorchScript)
467
+ for idx, cxt in enumerate(self.learnable_embeddings):
468
+ xi = x_list[idx]
469
+ xi = xi + cxt + self.position_embedding(xi)
470
+ out_list.append(xi)
471
+
472
+ return out_list
473
+
474
+
475
+ class _AttentionLayer(nn.Module):
476
+ def __init__(
477
+ self,
478
+ attention: nn.Module,
479
+ d_model: int,
480
+ num_heads: int,
481
+ d_keys: Optional[int] = None,
482
+ d_values: Optional[int] = None,
483
+ ):
484
+ super().__init__()
485
+
486
+ d_keys = d_keys or (d_model // num_heads)
487
+ d_values = d_values or (d_model // num_heads)
488
+
489
+ self.inner_attention = attention
490
+ self.query_projection = nn.Linear(d_model, d_keys * num_heads)
491
+ self.key_projection = nn.Linear(d_model, d_keys * num_heads)
492
+ self.value_projection = nn.Linear(d_model, d_values * num_heads)
493
+ self.out_projection = nn.Linear(d_values * num_heads, d_model)
494
+ self.num_heads = num_heads
495
+
496
+ def forward(
497
+ self,
498
+ queries: torch.Tensor,
499
+ keys: torch.Tensor,
500
+ values: torch.Tensor,
501
+ attn_mask: Optional[torch.Tensor],
502
+ tau: Optional[torch.Tensor] = None,
503
+ delta: Optional[torch.Tensor] = None,
504
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
505
+ batch_size, query_len, _ = queries.shape
506
+ _, key_len, _ = keys.shape
507
+ num_heads = self.num_heads
508
+
509
+ queries = self.query_projection(queries).view(
510
+ batch_size, query_len, num_heads, -1
511
+ ) # multi-head
512
+ keys = self.key_projection(keys).view(batch_size, key_len, num_heads, -1)
513
+ values = self.value_projection(values).view(batch_size, key_len, num_heads, -1)
514
+
515
+ out, attn = self.inner_attention(
516
+ queries, keys, values, attn_mask, tau=tau, delta=delta
517
+ )
518
+ out = out.view(batch_size, query_len, -1)
519
+
520
+ return self.out_projection(out), attn
521
+
522
+
523
+ class _TriangularCausalMask:
524
+ def __init__(
525
+ self, batch_size: int, seq_len: int, device: Optional[torch.device] = None
526
+ ):
527
+ # Normalize device to a torch.device for .to(device)
528
+ if device is None:
529
+ device = torch.device("cpu")
530
+ mask_shape = [batch_size, 1, seq_len, seq_len]
531
+ with torch.no_grad():
532
+ self._mask = torch.triu(
533
+ torch.ones(mask_shape, dtype=torch.bool), diagonal=1
534
+ ).to(device)
535
+
536
+ @property
537
+ def mask(self) -> torch.Tensor:
538
+ return self._mask
539
+
540
+
541
+ class _FullAttention(nn.Module):
542
+ def __init__(
543
+ self,
544
+ mask_flag: bool = True,
545
+ scale: Optional[float] = None,
546
+ attention_dropout: float = 0.1,
547
+ output_attention: bool = False,
548
+ ):
549
+ super().__init__()
550
+ self.scale = scale
551
+ self.mask_flag = mask_flag
552
+ self.output_attention = output_attention
553
+ self.dropout = nn.Dropout(attention_dropout)
559
+
560
+ def forward(
561
+ self,
562
+ queries: torch.Tensor,
563
+ keys: torch.Tensor,
564
+ values: torch.Tensor,
565
+ attn_mask: Optional[torch.Tensor],
566
+ tau: Optional[torch.Tensor] = None,
567
+ delta: Optional[torch.Tensor] = None,
568
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
569
+ batch_size, query_len, _, embed_dim = queries.shape
570
+ _, _, _, _ = values.shape
571
+ # Avoid using `or` because TorchScript may fail to cast None/float
572
+ if self.scale is None:
573
+ scale = 1.0 / sqrt(embed_dim)
574
+ else:
575
+ scale = self.scale
576
+
577
+ scores = torch.einsum("blhe,bshe->bhls", queries, keys)
578
+
579
+ if self.mask_flag:
580
+ if attn_mask is None:
581
+ # create a triangular causal mask tensor of shape [B,1,L,L]
582
+ attn_mask = torch.triu(
583
+ torch.ones([batch_size, 1, query_len, query_len], dtype=torch.bool),
584
+ diagonal=1,
585
+ ).to(queries.device)
586
+
587
+ # attn_mask is expected to be a boolean tensor with same shape
588
+ scores.masked_fill_(attn_mask, -np.inf)
589
+
590
+ attention_weights = self.dropout(
591
+ torch.softmax(scale * scores, dim=-1)
592
+ ) # Scaled Dot-Product Attention
593
+ output_values = torch.einsum("bhls,bshd->blhd", attention_weights, values)
594
+
595
+ if self.output_attention:
596
+ return output_values.contiguous(), attention_weights
597
+ else:
598
+ return output_values.contiguous(), None
599
+
600
+
601
+ class _MedformerLayer(nn.Module):
602
+ def __init__(
603
+ self,
604
+ num_blocks: int,
605
+ d_model: int,
606
+ num_heads: int,
607
+ dropout: float = 0.1,
608
+ output_attention: bool = False,
609
+ no_inter: bool = False,
610
+ ):
611
+ super().__init__()
612
+
613
+ self.intra_attentions = nn.ModuleList(
614
+ [
615
+ _AttentionLayer(
616
+ _FullAttention(
617
+ mask_flag=False,
618
+ attention_dropout=dropout,
619
+ output_attention=output_attention,
620
+ ),
621
+ d_model,
622
+ num_heads,
623
+ )
624
+ for _ in range(num_blocks)
625
+ ]
626
+ )
627
+ if no_inter or num_blocks <= 1:
628
+ # print("No inter attention for time")
629
+ self.inter_attention = None
630
+ else:
631
+ self.inter_attention = _AttentionLayer(
632
+ _FullAttention(
633
+ mask_flag=False,
634
+ attention_dropout=dropout,
635
+ output_attention=output_attention,
636
+ ),
637
+ d_model,
638
+ num_heads,
639
+ )
640
+
641
+ def forward(
642
+ self,
643
+ x: List[torch.Tensor],
644
+ attn_mask: Optional[List[Optional[torch.Tensor]]] = None,
645
+ tau: Optional[torch.Tensor] = None,
646
+ delta: Optional[torch.Tensor] = None,
647
+ ) -> Tuple[List[torch.Tensor], List[Optional[torch.Tensor]]]:
648
+ # Explicit None check because TorchScript cannot evaluate truthiness of lists
649
+ if attn_mask is None:
650
+ # Build a list of None with explicit typing for TorchScript
651
+ new_mask_list: List[Optional[torch.Tensor]] = []
652
+ for _ in range(len(x)):
653
+ new_mask_list.append(None)
654
+ attn_mask = new_mask_list
655
+ # Intra attention
656
+ x_intra: List[torch.Tensor] = []
657
+ attn_out: List[Optional[torch.Tensor]] = []
658
+ # Iterate over ModuleList with enumerate (TorchScript supports enumerate over modules)
659
+ for idx, layer in enumerate(self.intra_attentions):
660
+ x_in_i = x[idx]
661
+ mask_i = attn_mask[idx]
662
+ x_out_temp, attn_temp = layer(
663
+ x_in_i, x_in_i, x_in_i, attn_mask=mask_i, tau=tau, delta=delta
664
+ )
665
+ x_intra.append(x_out_temp) # (B, Li, D)
666
+ attn_out.append(attn_temp)
667
+ if self.inter_attention is not None:
668
+ # Inter attention
669
+ routers = torch.cat([xi[:, -1:] for xi in x_intra], dim=1) # (B, N, D)
670
+ x_inter, attn_inter = self.inter_attention(
671
+ routers, routers, routers, attn_mask=None, tau=tau, delta=delta
672
+ )
673
+ x_out = [
674
+ torch.cat([xi[:, :-1], x_inter[:, i : i + 1]], dim=1) # (B, Li, D)
675
+ for i, xi in enumerate(x_intra)
676
+ ]
677
+ attn_out += [attn_inter]
678
+ else:
679
+ x_out = x_intra
680
+ return x_out, attn_out
681
+
682
+
683
+ class _EncoderLayer(nn.Module):
684
+ def __init__(
685
+ self,
686
+ attention: nn.Module,
687
+ d_model: int,
688
+ dim_feedforward: Optional[int],
689
+ dropout: float,
690
+ activation: Optional[nn.Module] = None,
691
+ ):
692
+ super().__init__()
693
+ dim_feedforward = dim_feedforward or 4 * d_model
694
+ self.attention = attention
695
+ self.conv1 = nn.Conv1d(
696
+ in_channels=d_model, out_channels=dim_feedforward, kernel_size=1
697
+ )
698
+ self.conv2 = nn.Conv1d(
699
+ in_channels=dim_feedforward, out_channels=d_model, kernel_size=1
700
+ )
701
+ self.norm1 = nn.LayerNorm(d_model)
702
+ self.norm2 = nn.LayerNorm(d_model)
703
+ self.dropout = nn.Dropout(dropout)
704
+ self.activation = activation if activation is not None else nn.ReLU()
705
+
706
+ def forward(
707
+ self,
708
+ x: List[torch.Tensor],
709
+ attn_mask: Optional[List[Optional[torch.Tensor]]] = None,
710
+ tau: Optional[torch.Tensor] = None,
711
+ delta: Optional[torch.Tensor] = None,
712
+ ) -> Tuple[List[torch.Tensor], List[Optional[torch.Tensor]]]:
713
+ new_x, attn = self.attention(x, attn_mask=attn_mask, tau=tau, delta=delta)
714
+ x = [x_orig + self.dropout(x_new) for x_orig, x_new in zip(x, new_x)]
715
+
716
+ y = x = [self.norm1(x_val) for x_val in x]
717
+ y = [
718
+ self.dropout(self.activation(self.conv1(y_val.transpose(-1, 1))))
719
+ for y_val in y
720
+ ]
721
+ y = [self.dropout(self.conv2(y_val).transpose(-1, 1)) for y_val in y]
722
+
723
+ return [self.norm2(x_val + y_val) for x_val, y_val in zip(x, y)], attn
724
+
725
+
726
+ class _Encoder(nn.Module):
727
+ def __init__(
728
+ self, attn_layers: List[nn.Module], norm_layer: Optional[nn.Module] = None
729
+ ):
730
+ super().__init__()
731
+ self.attn_layers = nn.ModuleList(attn_layers)
732
+ self.norm = norm_layer
733
+
734
+ def forward(
735
+ self,
736
+ x: List[torch.Tensor],
737
+ attn_mask: Optional[List[Optional[torch.Tensor]]] = None,
738
+ tau: Optional[torch.Tensor] = None,
739
+ delta: Optional[torch.Tensor] = None,
740
+ ) -> Tuple[torch.Tensor, List[List[Optional[torch.Tensor]]]]:
741
+ # x [[B, L1, D], [B, L2, D], ...]
742
+ attns: List[List[Optional[torch.Tensor]]] = []
743
+ for attn_layer in self.attn_layers:
744
+ x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
745
+ attns.append(attn)
746
+
747
+ # concat all the outputs
748
+ """x = torch.cat(
749
+ x, dim=1
750
+ ) # (batch_size, patch_num_1 + patch_num_2 + ... , d_model)"""
751
+
752
+ # concat all the routers
753
+ x = torch.cat(
754
+ [xi[:, -1, :].unsqueeze(1) for xi in x], dim=1
755
+ ) # (batch_size, len(patch_len_list), d_model)
756
+
757
+ if self.norm is not None:
758
+ x = self.norm(x)
759
+
760
+ return x, attns
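
For orientation, a minimal usage sketch of the MEDFormer class added above; the shapes follow the forward docstring, and the channel/window sizes below are arbitrary example values, not package defaults:

# Illustrative sketch only -- not part of the packaged medformer.py above.
import torch

from braindecode.models.medformer import MEDFormer

# Example geometry: 22-channel EEG windows of 1000 samples, 4 output classes.
model = MEDFormer(n_chans=22, n_outputs=4, n_times=1000)

x = torch.randn(8, 22, 1000)   # (batch_size, n_chans, n_times)
y = model(x)                   # (batch_size, n_outputs) -> torch.Size([8, 4])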
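The patch_num_list computed in MEDFormer.__init__ mirrors the output size of the replication-padded, non-overlapping convolution in _CrossChannelTokenEmbedding; note that in this implementation the patching axis is the channel axis. A small sketch of that arithmetic with the constructor defaults and an assumed 22-channel montage:

# Patch-count arithmetic from MEDFormer.__init__ (illustrative values).
n_chans = 22                    # assumed montage size; the axis being patched here
patch_len_list = [2, 8, 16]     # constructor defaults
stride_list = patch_len_list    # strides equal patch lengths (non-overlapping patches)

patch_num_list = [
    int((n_chans - patch_len) / stride + 2)
    for patch_len, stride in zip(patch_len_list, stride_list)
]
print(patch_num_list)           # [12, 3, 2]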
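Finally, a shape-only sketch of the router bookkeeping in _MedformerLayer.forward: the last token of each granularity acts as its router, the routers attend to one another (full self-attention in the real layer, omitted here), and each updated router is written back into its sequence. Batch size, embedding width, and token counts are arbitrary example values:

# Router bookkeeping of the two-stage attention, shapes only (illustrative).
import torch

B, D = 8, 128
N = [12, 3, 2]                                   # tokens per granularity; the last one is the router
x_intra = [torch.randn(B, n, D) for n in N]      # per-granularity outputs of intra-attention

routers = torch.cat([xi[:, -1:] for xi in x_intra], dim=1)   # (B, len(N), D)
# ... inter-granularity self-attention over `routers` would go here ...
x_out = [
    torch.cat([xi[:, :-1], routers[:, i : i + 1]], dim=1)    # (B, N[i], D)
    for i, xi in enumerate(x_intra)
]
print([tuple(t.shape) for t in x_out])           # [(8, 12, 128), (8, 3, 128), (8, 2, 128)]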