braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/__init__.py +3 -5
- braindecode/augmentation/base.py +5 -8
- braindecode/augmentation/functional.py +22 -25
- braindecode/augmentation/transforms.py +42 -51
- braindecode/classifier.py +16 -11
- braindecode/datasets/__init__.py +3 -5
- braindecode/datasets/base.py +13 -17
- braindecode/datasets/bbci.py +14 -13
- braindecode/datasets/bcicomp.py +5 -4
- braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
- braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
- braindecode/datasets/{bids/hub.py → hub.py} +350 -375
- braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
- braindecode/datasets/mne.py +19 -19
- braindecode/datasets/moabb.py +10 -10
- braindecode/datasets/nmt.py +56 -58
- braindecode/datasets/sleep_physio_challe_18.py +5 -3
- braindecode/datasets/sleep_physionet.py +5 -5
- braindecode/datasets/tuh.py +18 -21
- braindecode/datasets/xy.py +9 -10
- braindecode/datautil/__init__.py +3 -3
- braindecode/datautil/serialization.py +20 -22
- braindecode/datautil/util.py +7 -120
- braindecode/eegneuralnet.py +52 -22
- braindecode/functional/functions.py +10 -7
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +3 -5
- braindecode/models/atcnet.py +39 -43
- braindecode/models/attentionbasenet.py +41 -37
- braindecode/models/attn_sleep.py +24 -26
- braindecode/models/base.py +6 -6
- braindecode/models/bendr.py +26 -50
- braindecode/models/biot.py +30 -61
- braindecode/models/contrawr.py +5 -5
- braindecode/models/ctnet.py +35 -35
- braindecode/models/deep4.py +5 -5
- braindecode/models/deepsleepnet.py +7 -7
- braindecode/models/eegconformer.py +26 -31
- braindecode/models/eeginception_erp.py +2 -2
- braindecode/models/eeginception_mi.py +6 -6
- braindecode/models/eegitnet.py +5 -5
- braindecode/models/eegminer.py +1 -1
- braindecode/models/eegnet.py +3 -3
- braindecode/models/eegnex.py +2 -2
- braindecode/models/eegsimpleconv.py +2 -2
- braindecode/models/eegsym.py +7 -7
- braindecode/models/eegtcnet.py +6 -6
- braindecode/models/fbcnet.py +2 -2
- braindecode/models/fblightconvnet.py +3 -3
- braindecode/models/fbmsnet.py +3 -3
- braindecode/models/hybrid.py +2 -2
- braindecode/models/ifnet.py +5 -5
- braindecode/models/labram.py +46 -70
- braindecode/models/luna.py +5 -60
- braindecode/models/medformer.py +21 -23
- braindecode/models/msvtnet.py +15 -15
- braindecode/models/patchedtransformer.py +55 -55
- braindecode/models/sccnet.py +2 -2
- braindecode/models/shallow_fbcsp.py +3 -5
- braindecode/models/signal_jepa.py +12 -39
- braindecode/models/sinc_shallow.py +4 -3
- braindecode/models/sleep_stager_blanco_2020.py +2 -2
- braindecode/models/sleep_stager_chambon_2018.py +2 -2
- braindecode/models/sparcnet.py +8 -8
- braindecode/models/sstdpn.py +869 -869
- braindecode/models/summary.csv +17 -19
- braindecode/models/syncnet.py +2 -2
- braindecode/models/tcn.py +5 -5
- braindecode/models/tidnet.py +3 -3
- braindecode/models/tsinception.py +3 -3
- braindecode/models/usleep.py +7 -7
- braindecode/models/util.py +14 -165
- braindecode/modules/__init__.py +1 -9
- braindecode/modules/activation.py +3 -29
- braindecode/modules/attention.py +0 -123
- braindecode/modules/blocks.py +1 -53
- braindecode/modules/convolution.py +0 -53
- braindecode/modules/filter.py +0 -31
- braindecode/modules/layers.py +0 -84
- braindecode/modules/linear.py +1 -22
- braindecode/modules/stats.py +0 -10
- braindecode/modules/util.py +0 -9
- braindecode/modules/wrapper.py +0 -17
- braindecode/preprocessing/preprocess.py +0 -3
- braindecode/regressor.py +18 -15
- braindecode/samplers/ssl.py +1 -1
- braindecode/util.py +28 -38
- braindecode/version.py +1 -1
- braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
- braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
- braindecode/datasets/bids/__init__.py +0 -54
- braindecode/datasets/bids/format.py +0 -717
- braindecode/datasets/bids/hub_format.py +0 -717
- braindecode/datasets/bids/hub_io.py +0 -197
- braindecode/datasets/chb_mit.py +0 -163
- braindecode/datasets/siena.py +0 -162
- braindecode/datasets/utils.py +0 -67
- braindecode/models/brainmodule.py +0 -845
- braindecode/models/config.py +0 -233
- braindecode/models/reve.py +0 -843
- braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
- braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/models/reve.py
DELETED
@@ -1,843 +0,0 @@
"""
REVE (Representation for EEG with versatile embeddings) model.
Authors: Jonathan Lys (jonathan.lys@imt-atlantique.org)
License: BSD 3 clause
"""

import json
import logging
import math
import os
from pathlib import Path
from typing import Optional, Union

import requests
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn.attention import SDPBackend, sdpa_kernel

from braindecode.models.base import EEGModuleMixin

logger = logging.getLogger(__name__)


class REVE(EEGModuleMixin, nn.Module):
    r"""
    **R**\ epresentation for **E**\ EG with **V**\ ersatile **E**\ mbeddings (REVE) from El Ouahidi et al. (2025) [reve]_.

    :bdg-danger:`Foundation Model` :bdg-info:`Attention/Transformer`

    .. figure:: https://brain-bzh.github.io/reve/static/images/architecture.png
       :align: center
       :alt: REVE Training pipeline overview
       :width: 1000px

    Foundation models have transformed machine learning by reducing reliance on
    task-specific data and induced biases through large-scale pretraining. While
    successful in language and vision, their adoption in EEG has lagged due to the
    heterogeneity of public datasets, which are collected under varying protocols,
    devices, and electrode configurations. Existing EEG foundation models struggle
    to generalize across these variations, often restricting pretraining to a single
    setup and resulting in suboptimal performance, particularly under linear probing.

    REVE is a pretrained model explicitly designed to generalize across diverse EEG signals. It introduces
    a **4D positional encoding** scheme that enables processing signals of arbitrary length and electrode
    arrangement. Using a masked autoencoding objective, REVE was pretrained on over **60,000 hours** of EEG
    data from **92 datasets** spanning **25,000 subjects**, the largest EEG pretraining effort to date.

    .. rubric:: Channels Invariant Positional Encoding

    Prior EEG foundation models (:class:`~braindecode.models.Labram`, :class:`~braindecode.models.BIOT`) rely on
    fixed positional embeddings, making direct transfer to unseen electrode layouts infeasible. CBraMod uses
    convolution-based positional encoding that requires fine-tuning when adapting to new configurations.
    As noted in the CBraMod paper: *"fixing the pre-trained parameters during training on downstream
    datasets will lead to a very large performance decline."*

    REVE's 4D positional encoding jointly encodes spatial :math:`(x, y, z)` and temporal :math:`(t)` positions
    using Fourier embeddings, enabling true cross-configuration transfer without retraining. The Fourier embedding
    is inspired by brainmodule [brainmodule]_, generalized to 4D for EEG using the channel spatial coordinates
    and the temporal patch index.

    .. rubric:: Linear Probing Performance

    A key advantage of REVE is that it produces useful latent representations without heavy fine-tuning. Under linear
    probing (frozen encoder), REVE achieves state-of-the-art results on downstream EEG tasks.
    This enables practical deployment in low-data scenarios where extensive fine-tuning is not feasible.

    .. rubric:: Architecture

    The model adopts modern Transformer components validated through ablation studies:

    - **Normalization**: RMSNorm outperforms LayerNorm;
    - **Activation**: GEGLU outperforms GELU;
    - **Attention**: Flash Attention via PyTorch's SDPA;
    - **Masking ratio**: 55% optimal for spatio-temporal block masking.

    These choices align with best practices from large language models and were empirically validated
    on EEG data.

    .. rubric:: Secondary Loss

    A secondary reconstruction objective using attention pooling across layers prevents over-specialization
    in the final layer. This pooling acts as an information bottleneck, forcing the model to distill key
    information from the entire sequence. Ablations show this loss is crucial for linear probing quality:
    removing it drops average performance by 10% under the frozen evaluation.

    .. rubric:: Macro Components

    - ``REVE.to_patch_embedding`` **Patch Tokenization**

      The EEG signal is split into overlapping patches along the time dimension, generating
      :math:`p = \left\lceil \frac{T - w}{w - o} \right\rceil + \mathbf{1}[(T - w) \bmod (w - o) = 0]`
      patches of size :math:`w` with overlap :math:`o`, where :math:`T` is the signal length.
      Each patch is linearly projected to the embedding dimension.

    - ``REVE.fourier4d`` + ``REVE.mlp4d`` **4D Positional Embedding (4DPE)**

      The 4DPE encodes each token's 4D coordinates :math:`(x, y, z, t)` where :math:`(x, y, z)` are the
      3D spatial coordinates from a standardized electrode position bank, and :math:`t` is the temporal
      patch index. The encoding combines:

      1. **Fourier embedding**: Sinusoidal encoding across multiple frequencies for smooth interpolation
         to unseen positions
      2. **MLP embedding**: :class:`~torch.nn.Linear` (4 → embed_dim) → :class:`~torch.nn.GELU` → :class:`~torch.nn.LayerNorm` for learnable refinement

      Both components are summed and normalized. The 4DPE adds negligible computational overhead,
      scaling linearly with the number of tokens.

    - ``REVE.transformer`` **Transformer Encoder**

      Pre-norm Transformer with multi-head self-attention (:class:`~torch.nn.RMSNorm`), feed-forward networks (GEGLU
      activation), and residual connections. Default configuration: 22 layers, 8 heads, 512 embedding
      dimension (~72M parameters).

    - ``REVE.final_layer`` **Classification Head**

      Two modes (controlled by the ``attention_pooling`` parameter):

      - When ``attention_pooling`` is disabled (e.g., ``None`` or ``False``): flatten all tokens
        → :class:`~torch.nn.LayerNorm` → :class:`~torch.nn.Linear`
      - When ``attention_pooling`` is enabled: attention pooling with a learnable query token
        attending to all encoder outputs

    .. rubric:: Known Limitations

    - **Sparse electrode setups**: Performance degrades with very few channels. On motor imagery,
      accuracy drops from 0.824 (64 channels) to 0.660 (1 channel). For tasks requiring broad
      spatial coverage (e.g., imagined speech), performance with <4 channels approaches chance level.
    - **Demographic bias**: The pretraining corpus aggregates publicly available datasets, most
      originating from North America and Europe, resulting in limited demographic diversity;
      more details about the datasets used for pretraining can be found in the REVE paper [reve]_.

    .. rubric:: Pretrained Weights

    Weights are available on `HuggingFace <https://huggingface.co/collections/brain-bzh/reve>`_,
    but you must agree to the data usage terms before downloading:

    - ``brain-bzh/reve-base``: 72M parameters, 512 embedding dim, 22 layers (~260 A100 GPU hours)
    - ``brain-bzh/reve-large``: ~400M parameters, 1250 embedding dim

    .. important::
        **Pre-trained Weights Available (Registration Required)**

        This model has pre-trained weights available on the Hugging Face Hub.
        **You must first register and agree to the data usage terms on the authors'
        HuggingFace repository before you can access the weights.**
        `Link here <https://huggingface.co/collections/brain-bzh/reve>`_.

        You can load them using:

        .. code-block:: python

            from braindecode.models import REVE

            # Load pre-trained model from Hugging Face Hub
            model = REVE.from_pretrained("brain-bzh/reve-base")

        To push your own trained model to the Hub:

        .. code-block:: python

            # After training your model
            model.push_to_hub(
                repo_id="username/my-reve-model", commit_message="Upload trained REVE model"
            )

        Requires installing ``braindecode[hug]`` for Hub integration.

    .. rubric:: Usage

    .. code-block:: python

        from braindecode.models import REVE

        model = REVE(
            n_outputs=4,  # e.g., 4-class motor imagery
            n_chans=22,
            n_times=1000,  # 5 seconds at 200 Hz
            sfreq=200,
            chs_info=[{"ch_name": "C3"}, {"ch_name": "C4"}, ...],
        )

        # Forward pass: (batch, n_chans, n_times) -> (batch, n_outputs)
        output = model(eeg_data, pos=channel_positions)

    .. warning::

        Input data must be sampled at **200 Hz** to match pretraining. During pretraining,
        inputs were z-score normalized and then clipped at 15 standard deviations;
        users should apply similar preprocessing.

    Parameters
    ----------
    embed_dim : int, default=512
        Embedding dimension. Use 512 for REVE-Base, 1250 for REVE-Large.
    depth : int, default=22
        Number of Transformer layers.
    heads : int, default=8
        Number of attention heads.
    head_dim : int, default=64
        Dimension per attention head.
    mlp_dim_ratio : float, default=2.66
        FFN hidden dimension ratio: ``mlp_dim = embed_dim × mlp_dim_ratio``.
    use_geglu : bool, default=True
        Use GEGLU activation (recommended) or standard GELU.
    freqs : int, default=4
        Number of frequencies for Fourier positional embedding.
    patch_size : int, default=200
        Temporal patch size in samples (200 samples = 1 second at 200 Hz).
    patch_overlap : int, default=20
        Overlap between patches in samples.
    attention_pooling : bool, default=False
        Pooling strategy for aggregating transformer outputs before classification.
        If ``False`` (default), all tokens are flattened into a single vector of size
        ``(n_chans x n_patches x embed_dim)``, which is then passed through LayerNorm
        and a linear classifier. If ``True``, uses attention-based pooling with a
        learnable query token that attends to all encoder outputs, producing a single
        embedding of size ``embed_dim``. Attention pooling is more parameter-efficient
        for long sequences and variable-length inputs.

    References
    ----------
    .. [reve] El Ouahidi, Y., Lys, J., Thölke, P., Farrugia, N., Pasdeloup, B.,
       Gripon, V., Jerbi, K. & Lioi, G. (2025). REVE: A Foundation Model for EEG -
       Adapting to Any Setup with Large-Scale Pretraining on 25,000 Subjects.
       The Thirty-Ninth Annual Conference on Neural Information Processing Systems.
       https://openreview.net/forum?id=ZeFMtRBy4Z
    .. [brainmodule] Défossez, A., Caucheteux, C., Rapin, J., Kabeli, O., & King, J. R.
       (2023). Decoding speech perception from non-invasive brain recordings. Nature
       Machine Intelligence, 5(10), 1097-1107.

    Notes
    -----
    The position bank is downloaded from HuggingFace on first initialization, mapping
    standard 10-20/10-10/10-05 electrode names to 3D coordinates. This enables the
    4D positional encoding to generalize across electrode configurations without
    requiring matched layouts between pretraining and downstream tasks.
    """

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        # REVE specific parameters
        embed_dim: int = 512,
        depth: int = 22,
        heads: int = 8,
        head_dim: int = 64,
        mlp_dim_ratio: float = 2.66,
        use_geglu: bool = True,
        freqs: int = 4,
        patch_size: int = 200,
        patch_overlap: int = 20,
        attention_pooling: bool = False,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )

        self.embed_dim = embed_dim
        self.freqs = freqs
        self.patch_size = patch_size
        self.patch_overlap = patch_overlap
        self.depth = depth
        self.heads = heads
        self.head_dim = head_dim
        self.mlp_dim_ratio = mlp_dim_ratio
        self.use_geglu = use_geglu

        self.use_attention_pooling = attention_pooling

        self.to_patch_embedding = nn.Sequential(
            nn.Linear(in_features=self.patch_size, out_features=self.embed_dim)
        )

        self.fourier4d = FourierEmb4D(self.embed_dim, freqs=self.freqs)

        self.mlp4d = nn.Sequential(
            nn.Linear(4, self.embed_dim, bias=False),
            nn.GELU(),
            nn.LayerNorm(self.embed_dim),
        )

        self.ln = nn.LayerNorm(self.embed_dim)  # 4DPE module layernorm

        self.transformer = TransformerBackbone(
            dim=self.embed_dim,
            depth=self.depth,
            heads=self.heads,
            head_dim=self.head_dim,
            mlp_dim=int(self.embed_dim * self.mlp_dim_ratio),
            geglu=self.use_geglu,
        )

        final_dim = self._get_flattened_output_dim()

        if self.use_attention_pooling:
            self.final_layer = nn.Sequential(
                nn.LayerNorm(self.embed_dim),
                nn.Linear(self.embed_dim, self.n_outputs),
            )
        else:
            self.final_layer = nn.Sequential(
                nn.Flatten(),
                nn.LayerNorm(final_dim),
                nn.Linear(final_dim, self.n_outputs),
            )

        if self.use_attention_pooling:
            self.cls_query_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        self._position_bank = RevePositionBank()

        self.default_pos = None
        if chs_info is not None:
            self.default_pos = self.get_positions([ch["ch_name"] for ch in chs_info])

    def _get_flattened_output_dim(self) -> int:
        """Helper function to compute the flattened output dimension after the transformer."""

        if self.use_attention_pooling:
            return self.embed_dim

        n_patches = math.ceil(
            (self.n_times - self.patch_size) / (self.patch_size - self.patch_overlap)
        )

        if (self.n_times - self.patch_size) % (
            self.patch_size - self.patch_overlap
        ) == 0:
            n_patches += 1

        flat_dim = self.n_chans * n_patches * self.embed_dim
        return flat_dim

    def get_positions(self, channel_names: list[str]) -> torch.Tensor:
        """Fetch channel positions from the position bank. The position bank is downloaded when the model is instantiated.

        Parameters
        ----------
        channel_names : list[str]
            List of channel names for which to fetch positions.

        Returns
        -------
        torch.Tensor
            Tensor of shape (num_channels, 3) containing the (x, y, z) positions of the channels.
        """

        return self._position_bank.forward(channel_names)

    def forward(
        self,
        eeg: torch.Tensor,
        pos: Optional[torch.Tensor] = None,
        return_output: bool = False,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """
        Forward pass of the model.

        Goes through the following steps:
        1. Patch extraction from the EEG signal.
        2. 4D positional embedding computation.
        3. Transformer encoding.
        4. Final layer processing (if `return_output` is False).

        Parameters
        ----------
        eeg : torch.Tensor
            Input EEG tensor of shape (batch_size, channels, sequence_length).
        pos : torch.Tensor
            Position tensor of shape (batch_size, channels, 3) representing (x, y, z) coordinates.
        return_output : bool, optional
            If True, returns the output from the transformer directly.
            If False, applies the final layer and returns the processed output. Default is False.

        Returns
        -------
        Union[torch.Tensor, list[torch.Tensor]]
            - If `return_output` is False: Returns a single `torch.Tensor` (output after final layer).
            - If `return_output` is True: Returns a `list[torch.Tensor]` (outputs from transformer layers).

            The output tensor(s) from the model. If `return_output` is True,
            returns the transformer output; otherwise, returns the output after the final layer.
        """

        patches = eeg.unfold(
            dimension=2,
            size=self.patch_size,
            step=self.patch_size - self.patch_overlap,
        )
        batch_size, channel, n_patches, _ = patches.shape

        if pos is None:
            if self.default_pos is None:
                raise ValueError(
                    "No positions provided and no default positions available. Please provide channel positions."
                )
            pos = self.default_pos.expand(batch_size, -1, -1).to(eeg.device)

        pos = FourierEmb4D.add_time_patch(pos, n_patches)
        pos_embed = self.ln(self.fourier4d(pos) + self.mlp4d(pos))

        # Patch embedding: (batch, channels, n_patches, patch_size) -> (batch, channels, n_patches, embed_dim)
        patch_embeddings = self.to_patch_embedding(patches)

        # Flatten spatial and temporal dimensions into a single sequence dimension
        # (batch, channels, n_patches, embed_dim) -> (batch, channels * n_patches, embed_dim)
        # This creates a sequence of tokens where each token represents one patch from one channel
        x = (
            rearrange(
                patch_embeddings,
                "batch chan patch emb -> batch (chan patch) emb",
                chan=channel,
                patch=n_patches,
                emb=self.embed_dim,
            )
            + pos_embed
        )
        x = self.transformer(x, return_output)

        if return_output:
            return x

        # Reshape back from flattened sequence to separate channel and temporal dimensions
        # (batch, channels * n_patches, embed_dim) -> (batch, channels, n_patches, embed_dim)
        # This recovers the spatio-temporal structure for downstream processing
        x = rearrange(
            x,
            "batch (chan patch) emb -> batch chan patch emb",
            batch=batch_size,
            chan=channel,
            patch=n_patches,
            emb=self.embed_dim,
        )

        if self.use_attention_pooling:
            x = self._attention_pooling(x)

        x = self.final_layer(x)
        return x

    def _attention_pooling(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention pooling on the sequence dimension of x.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, C, S, E), where B is the batch size,
            C is the number of channels, S is the sequence length,
            and E is the embedding dimension. Typically the output from the transformer.

        Returns
        -------
        torch.Tensor
            Output tensor of shape (B, E) after attention pooling.
        """

        batch_size, n_channels, seq_len, embed_dim = x.shape
        # Flatten channel and sequence dimensions for attention pooling
        # (batch, channels, seq_len, embed_dim) -> (batch, channels * seq_len, embed_dim)
        x = rearrange(
            x,
            "batch chan seq emb -> batch (chan seq) emb",
            chan=n_channels,
            seq=seq_len,
            emb=embed_dim,
        )
        query_output = self.cls_query_token.expand(batch_size, -1, -1)  # (B, 1, E)
        attention_scores = torch.matmul(query_output, x.transpose(-1, -2)) / (
            self.embed_dim**0.5
        )  # (B, 1, C*S)
        attention_weights = torch.softmax(attention_scores, dim=-1)  # (B, 1, C*S)
        out = torch.matmul(attention_weights, x).squeeze(1)  # (B, E)
        return out


#################################################################################
#                                   Layers                                      #
#################################################################################


class GEGLU(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, gates = x.chunk(2, dim=-1)
        return F.gelu(gates) * x


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, geglu: bool):
        super().__init__()
        self.net = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, hidden_dim * 2 if geglu else hidden_dim, bias=False),
            GEGLU() if geglu else nn.GELU(),
            nn.Linear(hidden_dim, dim, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


#################################################################################
#                                  Attention                                    #
#################################################################################


class ClassicalAttention(nn.Module):
    def __init__(self, heads: int, use_sdpa: bool = True):
        super().__init__()
        self.use_sdpa = use_sdpa
        self.heads = heads

    def forward(self, qkv: torch.Tensor) -> torch.Tensor:
        # Split concatenated QKV into separate tensors
        # qkv shape: (batch, seq_len, 3 * heads * head_dim)
        q, k, v = qkv.chunk(3, dim=-1)

        # Reshape for multi-head attention: split last dim into (heads, head_dim)
        # (batch, seq_len, heads * head_dim) -> (batch, heads, seq_len, head_dim)
        q, k, v = (
            rearrange(
                t,
                "batch seq (heads dim) -> batch heads seq dim",
                heads=self.heads,
            )
            for t in (q, k, v)
        )

        if self.use_sdpa:  # SDPA Implementation
            with sdpa_kernel(
                [
                    SDPBackend.FLASH_ATTENTION,
                    SDPBackend.EFFICIENT_ATTENTION,
                    SDPBackend.MATH,
                ]
            ):
                out = F.scaled_dot_product_attention(q, k, v)
        else:  # Naive Implementation
            _, _, scale = q.shape[-2], q.device, q.shape[-1] ** -0.5
            dots = torch.matmul(q, k.transpose(-1, -2)) * scale
            attn = nn.Softmax(dim=-1)(dots)
            out = torch.matmul(attn, v)

        # Merge heads back: (batch, heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim)
        out = rearrange(out, "batch heads seq dim -> batch seq (heads dim)")
        return out


class Attention(nn.Module):
    """
    Multi-head self-attention layer with RMSNorm.
    """

    def __init__(self, dim: int, heads: int = 8, head_dim: int = 64):
        super().__init__()
        inner_dim = head_dim * heads
        self.heads = heads
        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        self.attend = ClassicalAttention(self.heads, use_sdpa=True)

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x)
        out = self.attend(qkv)
        return self.to_out(out)


#################################################################################
#                                 Transformer                                   #
#################################################################################


class TransformerBackbone(nn.Module):
    """
    Transformer backbone consisting of multiple layers of attention and feed-forward networks.
    """

    def __init__(self, dim, depth, heads, head_dim, mlp_dim, geglu):
        super().__init__()
        self.dim = dim
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Attention(
                            self.dim,
                            heads=heads,
                            head_dim=head_dim,
                        ),
                        FeedForward(self.dim, mlp_dim, geglu),
                    ]
                )
            )

    def forward(
        self, x, return_out_layers=False
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        out_layers = [x] if return_out_layers else []
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
            if return_out_layers:
                out_layers.append(x)
        return out_layers if return_out_layers else x


##################################################################################
#                                     4D PE                                      #
##################################################################################


class FourierEmb4D(nn.Module):
    """
    Fourier positional embedding for 4D positions (x, y, z, t).
    This version allows for a reduced number of frequencies (n_freqs),
    and ensures the output embedding has the specified dimension.

    Parameters
    ----------
    dimension : int
        The dimension of the output embedding. Must be an even number.
    freqs : int
        The number of frequencies to use for the Fourier embedding.
    increment_time : float, optional
        The time increment to scale the time dimension. Default is 0.1.
    margin : float, optional
        The margin to add to the position coordinates to avoid boundary issues. Default is 0.4.
    """

    def __init__(
        self, dimension: int, freqs: int, increment_time=0.1, margin: float = 0.4
    ):
        super().__init__()
        self.dimension = dimension
        self.freqs = freqs
        self.increment_time = increment_time
        self.margin = margin

    def forward(self, positions_: torch.Tensor) -> torch.Tensor:
        positions = positions_.clone()
        positions[:, :, -1] *= self.increment_time
        input_shape = positions.shape
        batch_dims = list(input_shape[:-1])

        freqs_w = torch.arange(self.freqs).to(positions)
        freqs_z = freqs_w[:, None]
        freqs_y = freqs_z[:, None]
        freqs_x = freqs_y[:, None]
        width = 1 + 2 * self.margin
        positions = positions + self.margin
        p_x = 2 * math.pi * freqs_x / width
        p_y = 2 * math.pi * freqs_y / width
        p_z = 2 * math.pi * freqs_z / width
        p_w = 2 * math.pi * freqs_w / width
        positions = positions[..., None, None, None, None, :]
        loc = (
            positions[..., 0] * p_x
            + positions[..., 1] * p_y
            + positions[..., 2] * p_z
            + positions[..., 3] * p_w
        )
        batch_dims.append(-1)
        loc = loc.view(batch_dims)

        half_dim = self.dimension // 2
        current_dim = loc.shape[-1]
        if current_dim != half_dim:
            if current_dim > half_dim:
                loc = loc[..., :half_dim]
            else:
                raise ValueError(
                    f"Input dimension ({current_dim}) is too small for target "
                    f"embedding dimension ({self.dimension}). Expected at least {half_dim}."
                )

        emb = torch.cat([torch.cos(loc), torch.sin(loc)], dim=-1)
        return emb

    @classmethod
    def add_time_patch(cls, pos: torch.Tensor, num_patches: int) -> torch.Tensor:
        """
        Expand the position tensor by adding a time dimension, handling batched data.

        Parameters
        ----------
        pos : torch.Tensor
            Input tensor of shape (B, C, 3), where B is the batch size,
            C is the number of channels, and 3 represents x, y, z.
        num_patches : int
            The number of time patches.

        Returns
        -------
        torch.Tensor
            Output tensor of shape (B, C * num_patches, 4), where each position is repeated with each time value.
        """
        batch, nchans, _ = pos.shape
        # Repeat each position for each time step
        pos_repeated = pos.unsqueeze(2).repeat(
            1, 1, num_patches, 1
        )  # Shape: (batch, nchans, num_patches, 3)
        # Generate time values with the specified increment
        time_values = torch.arange(
            0, num_patches, 1, device=pos.device
        ).float()  # Shape: (num_patches,)
        time_values = time_values.view(1, 1, num_patches, 1).expand(
            batch, nchans, num_patches, 1
        )  # (batch, nchans, num_patches, 1)
        # Concatenate the repeated positions with the time values along the last dimension
        pos_with_time = torch.cat(
            (pos_repeated, time_values), dim=-1
        )  # Shape: (batch, nchans, num_patches, 4)
        # Reshape to (batch, nchans * num_patches, 4)
        pos_with_time = pos_with_time.view(batch, nchans * num_patches, 4)

        return pos_with_time


class RevePositionBank(torch.nn.Module):
    """Position bank for REVE model that maps standard EEG channel names to 3D coordinates.

    The position bank is cached locally in the library root to avoid repeated downloads.

    The coordinates come from the 92 datasets used during REVE pretraining.

    Parameters
    ----------
    url : str, optional
        URL to download the position bank JSON file. Default is the HuggingFace URL.
    timeout : int, optional
        Timeout in seconds for the HTTP request. Default is 5 seconds.
    cache_dir : str, optional
        Directory to cache the position bank. Default is the models folder within the library.
    """

    def __init__(
        self,
        url: str = "https://huggingface.co/brain-bzh/reve-positions/resolve/main/positions.json",
        timeout: int = 5,
        cache_dir: Optional[str] = None,
    ):
        super().__init__()

        if cache_dir is None:
            # Use the model root directory
            cache_dir = str(Path(__file__).parent)

        cache_file = os.path.join(cache_dir, ".cache", "reve_positions.json")
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)

        config = None

        # Try to load from cache first
        if os.path.exists(cache_file):
            try:
                with open(cache_file, "r") as f:
                    config = json.load(f)
                logger.info(f"Loaded position bank from cache: {cache_file}")
            except (json.JSONDecodeError, IOError) as e:
                logger.warning(f"Failed to load cache, downloading: {e}")

        # Download if cache miss or failed to load
        if config is None:
            try:
                response = requests.get(url, timeout=timeout)
                response.raise_for_status()
                config = json.loads(response.text)

                # Save to cache
                with open(cache_file, "w") as f:
                    json.dump(config, f)
                logger.info(f"Downloaded and cached position bank to: {cache_file}")
            except (requests.RequestException, json.JSONDecodeError) as e:
                raise RuntimeError(
                    f"Failed to download or parse the position bank from {url}: {e}"
                ) from e

        try:
            self.position_names = list(config.keys())
            self.mapping = {name: i for i, name in enumerate(self.position_names)}
            positions_data = list(config.values())
            positions = torch.tensor(positions_data, dtype=torch.float32)
            self.register_buffer("embedding", positions)

            # Validate shape (N, 3)
            if self.embedding.dim() != 2 or self.embedding.shape[1] != 3:
                raise ValueError(
                    f"Position data must have shape (N, 3), but got {self.embedding.shape}"
                )

            assert self.embedding.shape == (len(self.mapping), 3), (
                f"Expected embedding shape ({len(self.mapping)}, 3), but got {self.embedding.shape}"
            )

        except (ValueError, TypeError, AssertionError) as e:
            raise RuntimeError(
                f"Invalid position data format in the downloaded config: {e}"
            ) from e

    def forward(self, channel_names: list[str]):
        indices = [self.mapping[q] for q in channel_names if q in self.mapping]

        if len(indices) < len(channel_names):
            logger.warning(
                f"Found {len(indices)} positions out of {len(channel_names)} channels"
            )

        indices = torch.tensor(indices, device=self.embedding.device)

        return self.embedding[indices]

    def get_all_positions(self):
        return self.position_names