cortexflowx 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,234 @@
+ """Brain → Image reconstruction pipeline.
+
+ End-to-end pipeline: fMRI voxels → BrainEncoder → DiT (flow matching)
+ → VAE decoder → RGB image.
+
+ This is the core brain decoding pipeline. Given fMRI activity recorded
+ while a subject views an image, reconstruct what they saw.
+
+ Reference architecture inspired by:
+ - MindEye (Scotti et al. 2023) for brain-to-image conditioning
+ - SD3 / FLUX for the DiT + flow matching backbone
+ """
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+
+ from cortexflow._types import (
+     BrainData,
+     DiTConfig,
+     FlowConfig,
+     Modality,
+     ReconstructionResult,
+     VAEConfig,
+ )
+ from cortexflow.brain_encoder import BrainEncoder
+ from cortexflow.dit import DiffusionTransformer
+ from cortexflow.flow_matching import RectifiedFlowMatcher
+ from cortexflow.vae import LatentVAE
+
+
+ class Brain2Image(nn.Module):
+     """Reconstruct images from brain activity using DiT + Flow Matching.
+
+     Full pipeline::
+
+         fMRI voxels
+         → BrainEncoder (global + tokens)
+         → DiffusionTransformer (flow matching ODE)
+         → LatentVAE.decode()
+         → RGB image
+
+     Args:
+         n_voxels: Number of fMRI voxels.
+         img_size: Output image spatial resolution.
+         dit_config: DiT architecture configuration.
+         vae_config: VAE configuration.
+         flow_config: Flow matching configuration.
+         n_brain_tokens: Number of brain conditioning tokens.
+     """
+
+     def __init__(
+         self,
+         n_voxels: int = 1024,
+         img_size: int = 64,
+         dit_config: DiTConfig | None = None,
+         vae_config: VAEConfig | None = None,
+         flow_config: FlowConfig | None = None,
+         n_brain_tokens: int = 16,
+     ) -> None:
+         super().__init__()
+         self.img_size = img_size
+
+         # Configs
+         dit_cfg = dit_config or DiTConfig()
+         vae_cfg = vae_config or VAEConfig()
+         flow_cfg = flow_config or FlowConfig()
+
+         # VAE: determines latent spatial size
+         self.vae = LatentVAE(vae_cfg)
+         n_downsample = len(vae_cfg.hidden_dims)
+         latent_size = img_size // (2 ** n_downsample)
+
+         # Brain encoder
+         self.brain_encoder = BrainEncoder(
+             n_voxels=n_voxels,
+             cond_dim=dit_cfg.cond_dim,
+             n_tokens=n_brain_tokens,
+         )
+
+         # Unconditional embeddings for classifier-free guidance
+         self.uncond_global = nn.Parameter(torch.zeros(1, dit_cfg.cond_dim))
+         self.uncond_tokens = nn.Parameter(torch.zeros(1, n_brain_tokens, dit_cfg.cond_dim))
+
+         # DiT operates on the VAE latent space
+         dit_cfg_updated = DiTConfig(
+             in_channels=vae_cfg.latent_channels,
+             hidden_dim=dit_cfg.hidden_dim,
+             depth=dit_cfg.depth,
+             num_heads=dit_cfg.num_heads,
+             patch_size=dit_cfg.patch_size,
+             cond_dim=dit_cfg.cond_dim,
+             mlp_ratio=dit_cfg.mlp_ratio,
+             qk_norm=dit_cfg.qk_norm,
+             use_cross_attn=dit_cfg.use_cross_attn,
+         )
+         self.dit = DiffusionTransformer(dit_cfg_updated, img_size=latent_size)
+
+         # Flow matcher
+         self.flow_matcher = RectifiedFlowMatcher(flow_cfg)
+
+         self._latent_size = latent_size
+         self._latent_channels = vae_cfg.latent_channels
+
+     def encode_brain(
+         self, brain_data: BrainData
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Encode fMRI to conditioning signals."""
+         return self.brain_encoder(brain_data.voxels)
+
+     def training_loss(
+         self,
+         images: torch.Tensor,
+         brain_data: BrainData,
+         cfg_dropout: float = 0.1,
+     ) -> torch.Tensor:
+         """Compute training loss for brain-conditioned image generation.
+
+         Args:
+             images: Target images ``(B, 3, H, W)``.
+             brain_data: Corresponding fMRI data.
+             cfg_dropout: Probability of dropping the brain condition for CFG training.
+
+         Returns:
+             Scalar loss.
+         """
+         B = images.shape[0]
+
+         # Encode images to latent space (VAE is kept frozen or trained separately)
+         with torch.no_grad():
+             z, _, _ = self.vae.encode(images)
+
+         # Encode brain
+         brain_global, brain_tokens = self.encode_brain(brain_data)
+
+         # CFG training: randomly replace conditioning with the learned
+         # unconditional embeddings
+         if cfg_dropout > 0 and self.training:
+             mask = torch.rand(B, device=images.device) < cfg_dropout
+             if mask.any():
+                 brain_global = brain_global.clone()
+                 brain_tokens = brain_tokens.clone()
+                 n_drop = int(mask.sum())
+                 brain_global[mask] = self.uncond_global.expand(n_drop, -1)
+                 brain_tokens[mask] = self.uncond_tokens.expand(n_drop, -1, -1)
+
+         # Flow matching loss on the latent space
+         return self.flow_matcher.compute_loss(
+             self.dit, z, brain_global, brain_tokens
+         )
+
+     @torch.no_grad()
+     def reconstruct(
+         self,
+         brain_data: BrainData,
+         num_steps: int = 50,
+         cfg_scale: float = 4.0,
+     ) -> ReconstructionResult:
+         """Reconstruct an image from brain activity.
+
+         Args:
+             brain_data: fMRI data to decode.
+             num_steps: Number of ODE solver steps.
+             cfg_scale: Classifier-free guidance scale.
+
+         Returns:
+             ReconstructionResult with the decoded image.
+         """
+         B = brain_data.batch_size
+
+         # Encode brain
+         brain_global, brain_tokens = self.encode_brain(brain_data)
+
+         # Unconditional embeddings for CFG
+         uncond_global = self.uncond_global.expand(B, -1)
+         uncond_tokens = self.uncond_tokens.expand(B, -1, -1)
+
+         # Sample latents via flow matching
+         latent_shape = (B, self._latent_channels, self._latent_size, self._latent_size)
+         z = self.flow_matcher.sample(
+             self.dit,
+             shape=latent_shape,
+             brain_global=brain_global,
+             brain_tokens=brain_tokens,
+             num_steps=num_steps,
+             cfg_scale=cfg_scale,
+             brain_global_uncond=uncond_global,
+             brain_tokens_uncond=uncond_tokens,
+         )
+
+         # Decode latents to images
+         images = self.vae.decode(z)
+         images = images.clamp(0, 1)
+
+         return ReconstructionResult(
+             modality=Modality.IMAGE,
+             output=images,
+             brain_condition=brain_global,
+             n_steps=num_steps,
+             cfg_scale=cfg_scale,
+         )
+
+
+ def build_brain2img(
+     n_voxels: int = 1024,
+     img_size: int = 64,
+     hidden_dim: int = 256,
+     depth: int = 6,
+     num_heads: int = 8,
+ ) -> Brain2Image:
+     """Build a Brain2Image model with sensible defaults.
+
+     Args:
+         n_voxels: Number of fMRI voxels.
+         img_size: Target image size (square).
+         hidden_dim: DiT hidden dimension.
+         depth: Number of DiT blocks.
+         num_heads: Attention heads.
+     """
+     dit_cfg = DiTConfig(
+         hidden_dim=hidden_dim,
+         depth=depth,
+         num_heads=num_heads,
+         cond_dim=hidden_dim,
+     )
+     vae_cfg = VAEConfig(hidden_dims=[32, 64])  # lightweight for small images
+     flow_cfg = FlowConfig()
+     return Brain2Image(
+         n_voxels=n_voxels,
+         img_size=img_size,
+         dit_config=dit_cfg,
+         vae_config=vae_cfg,
+         flow_config=flow_cfg,
+     )
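
A minimal usage sketch of the image pipeline above. The diff does not show file names, so the module path `cortexflow.brain2image` is a guess, and the `BrainData(voxels=...)` constructor is inferred from how `brain_data.voxels` and `brain_data.batch_size` are accessed in the file; treat this as an illustration of the training/reconstruction API rather than a verified example.

import torch
from cortexflow._types import BrainData             # constructor assumed
from cortexflow.brain2image import build_brain2img  # module path assumed

model = build_brain2img(n_voxels=1024, img_size=64)

# Paired batch: 4 viewed images plus the matching fMRI recordings
images = torch.rand(4, 3, 64, 64)
brain = BrainData(voxels=torch.randn(4, 1024))

# Training step: flow matching loss on VAE latents; 10% of samples have
# their brain condition replaced by the learned unconditional embeddings
# so classifier-free guidance works at inference time
loss = model.training_loss(images, brain, cfg_dropout=0.1)
loss.backward()

# Inference: integrate the flow ODE for 50 steps, then decode to RGB
model.eval()
result = model.reconstruct(brain, num_steps=50, cfg_scale=4.0)
print(result.output.shape)  # expected: (4, 3, 64, 64)

At sampling time the matcher presumably applies the standard classifier-free guidance combination v = v_uncond + cfg_scale * (v_cond - v_uncond), which is why training uses condition dropout and the model carries learned unconditional embeddings.
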
@@ -0,0 +1,278 @@
+ """Brain → Text reconstruction pipeline.
+
+ Reconstruct what someone read or thought, in linguistic form, from fMRI.
+
+ Architecture: fMRI → BrainEncoder → Transformer decoder with
+ cross-attention to brain tokens → autoregressive text generation.
+
+ Unlike the image/audio pipelines, which use flow matching on continuous
+ latents, text uses autoregressive decoding since language is inherently
+ discrete.
+ """
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from cortexflow._types import BrainData, Modality, ReconstructionResult
+ from cortexflow.brain_encoder import BrainEncoder
+
+
+ class TextDecoderBlock(nn.Module):
+     """Transformer decoder block with causal self-attention + brain cross-attention."""
+
+     def __init__(self, hidden_dim: int, num_heads: int, cond_dim: int, mlp_ratio: float = 4.0):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(hidden_dim, eps=1e-6)
+         self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
+         self.norm2 = nn.LayerNorm(hidden_dim, eps=1e-6)
+         self.cross_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
+         self.cond_proj = nn.Linear(cond_dim, hidden_dim) if cond_dim != hidden_dim else nn.Identity()
+         self.norm3 = nn.LayerNorm(hidden_dim, eps=1e-6)
+         mlp_h = int(hidden_dim * mlp_ratio)
+         self.mlp = nn.Sequential(
+             nn.Linear(hidden_dim, mlp_h), nn.SiLU(),
+             nn.Linear(mlp_h, hidden_dim),
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         brain_tokens: torch.Tensor,
+         attn_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         # Causal self-attention
+         h = self.norm1(x)
+         h, _ = self.self_attn(h, h, h, attn_mask=attn_mask, need_weights=False)
+         x = x + h
+
+         # Cross-attention to brain tokens
+         h = self.norm2(x)
+         kv = self.cond_proj(brain_tokens)
+         h, _ = self.cross_attn(h, kv, kv, need_weights=False)
+         x = x + h
+
+         # Feedforward
+         h = self.norm3(x)
+         x = x + self.mlp(h)
+         return x
+
+
+ class BrainTextDecoder(nn.Module):
+     """Transformer decoder for brain → text.
+
+     Uses a simple byte-level vocabulary for self-contained operation
+     without external tokenizers.
+     """
+
+     def __init__(
+         self,
+         vocab_size: int = 256,
+         max_len: int = 128,
+         hidden_dim: int = 256,
+         depth: int = 6,
+         num_heads: int = 8,
+         cond_dim: int = 256,
+     ):
+         super().__init__()
+         self.vocab_size = vocab_size
+         self.max_len = max_len
+         self.hidden_dim = hidden_dim
+
+         self.token_embed = nn.Embedding(vocab_size, hidden_dim)
+         self.pos_embed = nn.Embedding(max_len, hidden_dim)
+
+         self.blocks = nn.ModuleList([
+             TextDecoderBlock(hidden_dim, num_heads, cond_dim) for _ in range(depth)
+         ])
+         self.final_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
+         self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
+
+         # Tie the output projection to the input embedding weights
+         self.lm_head.weight = self.token_embed.weight
+
+     def _causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
+         """Generate an additive causal attention mask (upper triangle = -inf)."""
+         mask = torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1)
+         return mask
+
+     def forward(
+         self,
+         token_ids: torch.Tensor,
+         brain_tokens: torch.Tensor,
+     ) -> torch.Tensor:
+         """Forward pass for training (teacher forcing).
+
+         Args:
+             token_ids: ``(B, T)`` input token IDs.
+             brain_tokens: ``(B, K, cond_dim)`` brain conditioning tokens.
+
+         Returns:
+             ``(B, T, vocab_size)`` logits.
+         """
+         _, T = token_ids.shape
+         positions = torch.arange(T, device=token_ids.device).unsqueeze(0)
+         x = self.token_embed(token_ids) + self.pos_embed(positions)
+
+         mask = self._causal_mask(T, token_ids.device)
+         for block in self.blocks:
+             x = block(x, brain_tokens, attn_mask=mask)
+
+         x = self.final_norm(x)
+         return self.lm_head(x)
+
+
+ class Brain2Text(nn.Module):
+     """Reconstruct text from brain activity.
+
+     Pipeline::
+
+         fMRI → BrainEncoder → BrainTextDecoder (autoregressive) → tokens → text
+
+     Uses byte-level encoding (vocab_size=256) for simplicity. Each byte
+     is one token, so this handles any text without a tokenizer.
+
+     Args:
+         n_voxels: Number of fMRI voxels.
+         max_len: Maximum output text length.
+         hidden_dim: Model dimension.
+         depth: Number of transformer layers.
+         num_heads: Attention heads.
+     """
+
+     def __init__(
+         self,
+         n_voxels: int = 1024,
+         max_len: int = 128,
+         hidden_dim: int = 256,
+         depth: int = 6,
+         num_heads: int = 8,
+     ):
+         super().__init__()
+         cond_dim = hidden_dim
+         self.brain_encoder = BrainEncoder(
+             n_voxels=n_voxels, cond_dim=cond_dim, n_tokens=16,
+         )
+         self.decoder = BrainTextDecoder(
+             vocab_size=256,  # byte-level
+             max_len=max_len,
+             hidden_dim=hidden_dim,
+             depth=depth,
+             num_heads=num_heads,
+             cond_dim=cond_dim,
+         )
+         self.max_len = max_len
+         self.bos_token = 0  # null byte doubles as BOS and EOS
+
+     @staticmethod
+     def text_to_tokens(text: str) -> torch.Tensor:
+         """Encode text to byte-level token IDs."""
+         return torch.tensor(list(text.encode("utf-8")), dtype=torch.long)
+
+     @staticmethod
+     def tokens_to_text(tokens: torch.Tensor) -> str:
+         """Decode byte-level token IDs to text."""
+         byte_list = tokens.cpu().tolist()
+         # Stop at the first null byte (EOS)
+         if 0 in byte_list:
+             byte_list = byte_list[: byte_list.index(0)]
+         return bytes(byte_list).decode("utf-8", errors="replace")
+
+     def training_loss(
+         self,
+         text_tokens: torch.Tensor,
+         brain_data: BrainData,
+     ) -> torch.Tensor:
+         """Compute cross-entropy loss for text generation.
+
+         Args:
+             text_tokens: ``(B, T)`` target token IDs (byte-level). Each
+                 sequence is expected to start with the BOS/null byte.
+             brain_data: Corresponding fMRI data.
+
+         Returns:
+             Scalar cross-entropy loss.
+         """
+         _, brain_tokens = self.brain_encoder(brain_data.voxels)
+
+         # Teacher-forcing shift:
+         # input [BOS, t1, ..., t_{n-1}] predicts target [t1, ..., t_n]
+         input_tokens = text_tokens[:, :-1]
+         target_tokens = text_tokens[:, 1:]
+
+         logits = self.decoder(input_tokens, brain_tokens)
+         return F.cross_entropy(
+             logits.reshape(-1, logits.shape[-1]),
+             target_tokens.reshape(-1),
+             ignore_index=0,
+         )
+
+     @torch.no_grad()
+     def reconstruct(
+         self,
+         brain_data: BrainData,
+         max_len: int | None = None,
+         temperature: float = 0.8,
+         top_k: int = 50,
+     ) -> ReconstructionResult:
+         """Reconstruct text from brain activity via autoregressive decoding.
+
+         Args:
+             brain_data: fMRI data to decode.
+             max_len: Maximum generation length.
+             temperature: Sampling temperature.
+             top_k: Top-k filtering.
+
+         Returns:
+             ReconstructionResult with the generated text in its metadata.
+         """
+         B = brain_data.batch_size
+         device = brain_data.voxels.device
+         gen_len = max_len or self.max_len
+
+         _, brain_tokens = self.brain_encoder(brain_data.voxels)
+
+         # Start with the BOS token
+         generated = torch.full((B, 1), self.bos_token, dtype=torch.long, device=device)
+
+         for _ in range(gen_len - 1):
+             logits = self.decoder(generated, brain_tokens)
+             next_logits = logits[:, -1, :] / temperature
+
+             # Top-k filtering: mask everything below the k-th largest logit
+             if top_k > 0:
+                 topk_vals, _ = next_logits.topk(top_k, dim=-1)
+                 threshold = topk_vals[:, -1].unsqueeze(-1)
+                 next_logits = next_logits.masked_fill(next_logits < threshold, float("-inf"))
+
+             probs = F.softmax(next_logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+             generated = torch.cat([generated, next_token], dim=1)
+
+             # Stop once every sequence has produced a null byte (EOS)
+             if (next_token == 0).all():
+                 break
+
+         # Decode to text, skipping the leading BOS: it is a null byte, so
+         # passing it to tokens_to_text would truncate everything at index 0
+         texts = [self.tokens_to_text(generated[i, 1:]) for i in range(B)]
+
+         return ReconstructionResult(
+             modality=Modality.TEXT,
+             output=generated,
+             brain_condition=brain_tokens.mean(dim=1),
+             metadata={"texts": texts},
+         )
+
+
+ def build_brain2text(
+     n_voxels: int = 1024,
+     max_len: int = 128,
+     hidden_dim: int = 256,
+     depth: int = 6,
+ ) -> Brain2Text:
+     """Build a Brain2Text model with sensible defaults."""
+     return Brain2Text(
+         n_voxels=n_voxels, max_len=max_len,
+         hidden_dim=hidden_dim, depth=depth,
+     )
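
A matching usage sketch for the text pipeline, under the same caveats (module path and `BrainData` construction assumed). Since `training_loss` shifts tokens internally, each target sequence should begin with the BOS/null byte, and 0 also serves as padding thanks to `ignore_index=0`.

import torch
from torch.nn.utils.rnn import pad_sequence
from cortexflow._types import BrainData                         # constructor assumed
from cortexflow.brain2text import Brain2Text, build_brain2text  # module path assumed

model = build_brain2text(n_voxels=1024, max_len=128)

# Byte-level round trip: every UTF-8 byte is one token, no tokenizer needed
tokens = Brain2Text.text_to_tokens("hello brain")
assert Brain2Text.tokens_to_text(tokens) == "hello brain"

# Prefix each target with the BOS/null byte, pad with 0
texts = ["the cat sat", "a dog ran"]
seqs = [
    torch.cat([torch.tensor([0]), Brain2Text.text_to_tokens(t)])
    for t in texts
]
text_tokens = pad_sequence(seqs, batch_first=True, padding_value=0)
brain = BrainData(voxels=torch.randn(2, 1024))

loss = model.training_loss(text_tokens, brain)
loss.backward()

# Inference: autoregressive sampling with temperature and top-k filtering
model.eval()
result = model.reconstruct(brain, temperature=0.8, top_k=50)
print(result.metadata["texts"])  # list of decoded strings

Byte-level decoding keeps the model self-contained, at the cost of longer sequences than a subword tokenizer would produce.
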