cortexflowx-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortexflow/__init__.py +78 -0
- cortexflow/_types.py +115 -0
- cortexflow/brain2audio.py +298 -0
- cortexflow/brain2img.py +234 -0
- cortexflow/brain2text.py +278 -0
- cortexflow/brain_encoder.py +228 -0
- cortexflow/dit.py +397 -0
- cortexflow/flow_matching.py +236 -0
- cortexflow/training.py +283 -0
- cortexflow/vae.py +232 -0
- cortexflowx-0.1.0.dist-info/METADATA +218 -0
- cortexflowx-0.1.0.dist-info/RECORD +14 -0
- cortexflowx-0.1.0.dist-info/WHEEL +4 -0
- cortexflowx-0.1.0.dist-info/licenses/LICENSE +190 -0
cortexflow/brain2img.py
ADDED
@@ -0,0 +1,234 @@
"""Brain → Image reconstruction pipeline.

End-to-end pipeline: fMRI voxels → BrainEncoder → DiT (flow matching)
→ VAE decoder → RGB image.

This is the core brain decoding pipeline. Given measured fMRI activity
while a subject views an image, reconstruct what they saw.

Reference architecture inspired by:
- MindEye (Scotti et al. 2023) for brain-to-image conditioning
- SD3 / FLUX for the DiT + flow matching backbone
"""

from __future__ import annotations

import torch
import torch.nn as nn

from cortexflow._types import (
    BrainData,
    DiTConfig,
    FlowConfig,
    Modality,
    ReconstructionResult,
    VAEConfig,
)
from cortexflow.brain_encoder import BrainEncoder
from cortexflow.dit import DiffusionTransformer
from cortexflow.flow_matching import RectifiedFlowMatcher
from cortexflow.vae import LatentVAE


class Brain2Image(nn.Module):
    """Reconstruct images from brain activity using DiT + Flow Matching.

    Full pipeline::

        fMRI voxels
        → BrainEncoder (global + tokens)
        → DiffusionTransformer (flow matching ODE)
        → LatentVAE.decode()
        → RGB image

    Args:
        n_voxels: Number of fMRI voxels.
        img_size: Output image spatial resolution.
        dit_config: DiT architecture configuration.
        vae_config: VAE configuration.
        flow_config: Flow matching configuration.
        n_brain_tokens: Number of brain conditioning tokens.
    """

    def __init__(
        self,
        n_voxels: int = 1024,
        img_size: int = 64,
        dit_config: DiTConfig | None = None,
        vae_config: VAEConfig | None = None,
        flow_config: FlowConfig | None = None,
        n_brain_tokens: int = 16,
    ) -> None:
        super().__init__()
        self.img_size = img_size

        # Configs
        dit_cfg = dit_config or DiTConfig()
        vae_cfg = vae_config or VAEConfig()
        flow_cfg = flow_config or FlowConfig()

        # VAE: determines latent spatial size
        self.vae = LatentVAE(vae_cfg)
        n_downsample = len(vae_cfg.hidden_dims)
        latent_size = img_size // (2 ** n_downsample)

        # Brain encoder
        self.brain_encoder = BrainEncoder(
            n_voxels=n_voxels,
            cond_dim=dit_cfg.cond_dim,
            n_tokens=n_brain_tokens,
        )

        # Unconditional embeddings for classifier-free guidance
        self.uncond_global = nn.Parameter(torch.zeros(1, dit_cfg.cond_dim))
        self.uncond_tokens = nn.Parameter(torch.zeros(1, n_brain_tokens, dit_cfg.cond_dim))

        # DiT operates on VAE latent space
        dit_cfg_updated = DiTConfig(
            in_channels=vae_cfg.latent_channels,
            hidden_dim=dit_cfg.hidden_dim,
            depth=dit_cfg.depth,
            num_heads=dit_cfg.num_heads,
            patch_size=dit_cfg.patch_size,
            cond_dim=dit_cfg.cond_dim,
            mlp_ratio=dit_cfg.mlp_ratio,
            qk_norm=dit_cfg.qk_norm,
            use_cross_attn=dit_cfg.use_cross_attn,
        )
        self.dit = DiffusionTransformer(dit_cfg_updated, img_size=latent_size)

        # Flow matcher
        self.flow_matcher = RectifiedFlowMatcher(flow_cfg)

        self._latent_size = latent_size
        self._latent_channels = vae_cfg.latent_channels

    def encode_brain(
        self, brain_data: BrainData
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode fMRI to conditioning signals."""
        return self.brain_encoder(brain_data.voxels)

    def training_loss(
        self,
        images: torch.Tensor,
        brain_data: BrainData,
        cfg_dropout: float = 0.1,
    ) -> torch.Tensor:
        """Compute training loss for brain-conditioned image generation.

        Args:
            images: Target images ``(B, 3, H, W)``.
            brain_data: Corresponding fMRI data.
            cfg_dropout: Probability of dropping brain condition for CFG training.

        Returns:
            Scalar loss.
        """
        B = images.shape[0]

        # Encode images to latent space (detach VAE — train separately or frozen)
        with torch.no_grad():
            z, _, _ = self.vae.encode(images)

        # Encode brain
        brain_global, brain_tokens = self.encode_brain(brain_data)

        # CFG training: randomly drop conditioning
        if cfg_dropout > 0 and self.training:
            mask = torch.rand(B, device=images.device) < cfg_dropout
            if mask.any():
                brain_global = brain_global.clone()
                brain_tokens = brain_tokens.clone()
                brain_global[mask] = self.uncond_global.expand(mask.sum(), -1)
                brain_tokens[mask] = self.uncond_tokens.expand(mask.sum(), -1, -1)

        # Flow matching loss on latent space
        return self.flow_matcher.compute_loss(
            self.dit, z, brain_global, brain_tokens
        )

    @torch.no_grad()
    def reconstruct(
        self,
        brain_data: BrainData,
        num_steps: int = 50,
        cfg_scale: float = 4.0,
    ) -> ReconstructionResult:
        """Reconstruct an image from brain activity.

        Args:
            brain_data: fMRI data to decode.
            num_steps: Number of ODE solver steps.
            cfg_scale: Classifier-free guidance scale.

        Returns:
            ReconstructionResult with the decoded image.
        """
        B = brain_data.batch_size
        device = brain_data.voxels.device

        # Encode brain
        brain_global, brain_tokens = self.encode_brain(brain_data)

        # Unconditional embeddings for CFG
        uncond_global = self.uncond_global.expand(B, -1)
        uncond_tokens = self.uncond_tokens.expand(B, -1, -1)

        # Sample latents via flow matching
        latent_shape = (B, self._latent_channels, self._latent_size, self._latent_size)
        z = self.flow_matcher.sample(
            self.dit,
            shape=latent_shape,
            brain_global=brain_global,
            brain_tokens=brain_tokens,
            num_steps=num_steps,
            cfg_scale=cfg_scale,
            brain_global_uncond=uncond_global,
            brain_tokens_uncond=uncond_tokens,
        )

        # Decode latents to images
        images = self.vae.decode(z)
        images = images.clamp(0, 1)

        return ReconstructionResult(
            modality=Modality.IMAGE,
            output=images,
            brain_condition=brain_global,
            n_steps=num_steps,
            cfg_scale=cfg_scale,
        )


def build_brain2img(
    n_voxels: int = 1024,
    img_size: int = 64,
    hidden_dim: int = 256,
    depth: int = 6,
    num_heads: int = 8,
) -> Brain2Image:
    """Build a Brain2Image model with sensible defaults.

    Args:
        n_voxels: Number of fMRI voxels.
        img_size: Target image size (square).
        hidden_dim: DiT hidden dimension.
        depth: Number of DiT blocks.
        num_heads: Attention heads.
    """
    dit_cfg = DiTConfig(
        hidden_dim=hidden_dim,
        depth=depth,
        num_heads=num_heads,
        cond_dim=hidden_dim,
    )
    vae_cfg = VAEConfig(hidden_dims=[32, 64])  # lightweight for small images
    flow_cfg = FlowConfig()
    return Brain2Image(
        n_voxels=n_voxels,
        img_size=img_size,
        dit_config=dit_cfg,
        vae_config=vae_cfg,
        flow_config=flow_cfg,
    )
cortexflow/brain2text.py
ADDED
@@ -0,0 +1,278 @@
"""Brain → Text reconstruction pipeline.

Reconstruct what someone read or thought in linguistic form from fMRI.

Architecture: fMRI → BrainEncoder → Transformer Decoder with
cross-attention to brain tokens → autoregressive text generation.

Unlike image/audio, which use flow matching on continuous latents,
text uses autoregressive decoding since language is inherently discrete.
"""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from cortexflow._types import BrainData, Modality, ReconstructionResult
from cortexflow.brain_encoder import BrainEncoder


class TextDecoderBlock(nn.Module):
    """Transformer decoder block with causal self-attention + brain cross-attention."""

    def __init__(self, hidden_dim: int, num_heads: int, cond_dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim, eps=1e-6)
        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_dim, eps=1e-6)
        self.cross_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.cond_proj = nn.Linear(cond_dim, hidden_dim) if cond_dim != hidden_dim else nn.Identity()
        self.norm3 = nn.LayerNorm(hidden_dim, eps=1e-6)
        mlp_h = int(hidden_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_h), nn.SiLU(),
            nn.Linear(mlp_h, hidden_dim),
        )

    def forward(
        self,
        x: torch.Tensor,
        brain_tokens: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Causal self-attention
        h = self.norm1(x)
        h, _ = self.self_attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + h

        # Cross-attention to brain tokens
        h = self.norm2(x)
        kv = self.cond_proj(brain_tokens)
        h, _ = self.cross_attn(h, kv, kv, need_weights=False)
        x = x + h

        # Feedforward
        h = self.norm3(x)
        x = x + self.mlp(h)
        return x


class BrainTextDecoder(nn.Module):
    """Transformer decoder for brain → text.

    Uses a simple character/word-piece vocabulary for self-contained
    operation without external tokenizers.
    """

    def __init__(
        self,
        vocab_size: int = 256,
        max_len: int = 128,
        hidden_dim: int = 256,
        depth: int = 6,
        num_heads: int = 8,
        cond_dim: int = 256,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.hidden_dim = hidden_dim

        self.token_embed = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embed = nn.Embedding(max_len, hidden_dim)

        self.blocks = nn.ModuleList([
            TextDecoderBlock(hidden_dim, num_heads, cond_dim) for _ in range(depth)
        ])
        self.final_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

        # Tie weights
        self.lm_head.weight = self.token_embed.weight

    def _causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generate a causal attention mask."""
        mask = torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1)
        return mask

    def forward(
        self,
        token_ids: torch.Tensor,
        brain_tokens: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass for training (teacher forcing).

        Args:
            token_ids: ``(B, T)`` input token IDs.
            brain_tokens: ``(B, K, cond_dim)`` brain conditioning tokens.

        Returns:
            ``(B, T, vocab_size)`` logits.
        """
        B, T = token_ids.shape
        positions = torch.arange(T, device=token_ids.device).unsqueeze(0)
        x = self.token_embed(token_ids) + self.pos_embed(positions)

        mask = self._causal_mask(T, token_ids.device)
        for block in self.blocks:
            x = block(x, brain_tokens, attn_mask=mask)

        x = self.final_norm(x)
        return self.lm_head(x)


class Brain2Text(nn.Module):
    """Reconstruct text from brain activity.

    Pipeline::

        fMRI → BrainEncoder → BrainTextDecoder (autoregressive) → tokens → text

    Uses byte-level encoding (vocab_size=256) for simplicity. Each byte
    is one token, so this handles any text without a tokenizer.

    Args:
        n_voxels: Number of fMRI voxels.
        max_len: Maximum output text length.
        hidden_dim: Model dimension.
        depth: Number of transformer layers.
        num_heads: Attention heads.
    """

    def __init__(
        self,
        n_voxels: int = 1024,
        max_len: int = 128,
        hidden_dim: int = 256,
        depth: int = 6,
        num_heads: int = 8,
    ):
        super().__init__()
        cond_dim = hidden_dim
        self.brain_encoder = BrainEncoder(
            n_voxels=n_voxels, cond_dim=cond_dim, n_tokens=16,
        )
        self.decoder = BrainTextDecoder(
            vocab_size=256,  # byte-level
            max_len=max_len,
            hidden_dim=hidden_dim,
            depth=depth,
            num_heads=num_heads,
            cond_dim=cond_dim,
        )
        self.max_len = max_len
        self.bos_token = 0  # null byte as BOS

    @staticmethod
    def text_to_tokens(text: str) -> torch.Tensor:
        """Encode text to byte-level token IDs."""
        return torch.tensor(list(text.encode("utf-8")), dtype=torch.long)

    @staticmethod
    def tokens_to_text(tokens: torch.Tensor) -> str:
        """Decode byte-level token IDs to text."""
        byte_list = tokens.cpu().tolist()
        # Stop at first null byte
        if 0 in byte_list:
            byte_list = byte_list[: byte_list.index(0)]
        return bytes(byte_list).decode("utf-8", errors="replace")

    def training_loss(
        self,
        text_tokens: torch.Tensor,
        brain_data: BrainData,
    ) -> torch.Tensor:
        """Compute cross-entropy loss for text generation.

        Args:
            text_tokens: ``(B, T)`` target token IDs (byte-level).
            brain_data: Corresponding fMRI data.

        Returns:
            Scalar cross-entropy loss.
        """
        _, brain_tokens = self.brain_encoder(brain_data.voxels)

        # Input: [BOS, t1, t2, ..., t_{n-1}], Target: [t1, t2, ..., t_n]
        input_tokens = text_tokens[:, :-1]
        target_tokens = text_tokens[:, 1:]

        logits = self.decoder(input_tokens, brain_tokens)
        return F.cross_entropy(
            logits.reshape(-1, logits.shape[-1]),
            target_tokens.reshape(-1),
            ignore_index=0,
        )

    @torch.no_grad()
    def reconstruct(
        self,
        brain_data: BrainData,
        max_len: int | None = None,
        temperature: float = 0.8,
        top_k: int = 50,
    ) -> ReconstructionResult:
        """Reconstruct text from brain activity via autoregressive decoding.

        Args:
            brain_data: fMRI data to decode.
            max_len: Maximum generation length.
            temperature: Sampling temperature.
            top_k: Top-k filtering.

        Returns:
            ReconstructionResult with generated text as metadata.
        """
        B = brain_data.batch_size
        device = brain_data.voxels.device
        gen_len = max_len or self.max_len

        _, brain_tokens = self.brain_encoder(brain_data.voxels)

        # Start with BOS token
        generated = torch.full((B, 1), self.bos_token, dtype=torch.long, device=device)

        for _ in range(gen_len - 1):
            logits = self.decoder(generated, brain_tokens)
            next_logits = logits[:, -1, :] / temperature

            # Top-k filtering
            if top_k > 0:
                topk_vals, _ = next_logits.topk(top_k, dim=-1)
                threshold = topk_vals[:, -1].unsqueeze(-1)
                next_logits = next_logits.masked_fill(next_logits < threshold, float("-inf"))

            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)

            # Stop if all sequences produced a null byte
            if (next_token == 0).all():
                break

        # Decode to text, skipping the leading BOS byte: it is 0, and
        # tokens_to_text truncates at the first null byte, so including it
        # would make every decoded string empty.
        texts = [self.tokens_to_text(generated[i, 1:]) for i in range(B)]

        return ReconstructionResult(
            modality=Modality.TEXT,
            output=generated,
            brain_condition=brain_tokens.mean(dim=1),
            metadata={"texts": texts},
        )


def build_brain2text(
    n_voxels: int = 1024,
    max_len: int = 128,
    hidden_dim: int = 256,
    depth: int = 6,
) -> Brain2Text:
    """Build a Brain2Text model with sensible defaults."""
    return Brain2Text(
        n_voxels=n_voxels, max_len=max_len,
        hidden_dim=hidden_dim, depth=depth,
    )
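
And a corresponding sketch for the text pipeline, under the same assumption about the BrainData constructor. Note that training targets are expected to begin with the null-byte BOS, matching the input/target shift in training_loss above.

import torch

from cortexflow._types import BrainData
from cortexflow.brain2text import Brain2Text, build_brain2text

model = build_brain2text(n_voxels=1024, max_len=128)

# Training: byte-level targets, prefixed with the null-byte BOS (0).
tokens = Brain2Text.text_to_tokens("the cat sat on the mat")
batch = torch.cat([torch.zeros(1, dtype=torch.long), tokens]).unsqueeze(0)
brain = BrainData(voxels=torch.randn(1, 1024))  # assumed constructor
loss = model.training_loss(batch, brain)
loss.backward()

# Inference: autoregressive sampling conditioned on brain tokens.
model.eval()
result = model.reconstruct(brain, temperature=0.8, top_k=50)
print(result.metadata["texts"][0])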