cortexflowx 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,228 @@
"""Brain encoder: fMRI voxels → conditioning embeddings for DiT.

Supports two encoding modes:

1. **Global**: MLP projects the full voxel vector to a single global embedding.
   Used for AdaLN conditioning in DiT blocks.
2. **Tokenized**: MLP projects voxels to a sequence of ``n_tokens`` vectors.
   Used for cross-attention in DiT blocks, giving the model rich spatial
   information about brain activity patterns.

Optionally supports **ROI-aware encoding**: separate sub-encoders for
different cortical regions (V1, FFA, A1, etc.) whose outputs are
concatenated and projected.
"""

from __future__ import annotations

import torch
import torch.nn as nn


class BrainEncoder(nn.Module):
    """Project fMRI voxels to DiT conditioning embeddings.

    Produces both a global embedding (for AdaLN) and a token sequence
    (for cross-attention).

    Args:
        n_voxels: Number of input fMRI voxels.
        cond_dim: Dimension of each conditioning vector.
        n_tokens: Number of brain tokens for cross-attention.
        hidden_dim: Intermediate MLP dimension. Defaults to ``2 * cond_dim``.
        dropout: Dropout rate for regularization.
    """

    def __init__(
        self,
        n_voxels: int,
        cond_dim: int = 768,
        n_tokens: int = 16,
        hidden_dim: int | None = None,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.n_voxels = n_voxels
        self.cond_dim = cond_dim
        self.n_tokens = n_tokens
        h = hidden_dim or cond_dim * 2

        # Shared backbone
        self.backbone = nn.Sequential(
            nn.Linear(n_voxels, h),
            nn.LayerNorm(h),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(h, h),
            nn.LayerNorm(h),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Global projection (for AdaLN conditioning)
        self.global_proj = nn.Sequential(
            nn.Linear(h, cond_dim),
            nn.LayerNorm(cond_dim),
        )

        # Token projection (for cross-attention)
        self.token_proj = nn.Sequential(
            nn.Linear(h, n_tokens * cond_dim),
        )

    def forward(self, voxels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode fMRI voxels to conditioning signals.

        Args:
            voxels: ``(B, n_voxels)`` BOLD activity.

        Returns:
            (brain_global, brain_tokens):
                - brain_global: ``(B, cond_dim)`` for AdaLN.
                - brain_tokens: ``(B, n_tokens, cond_dim)`` for cross-attention.
        """
        if voxels.ndim == 1:
            voxels = voxels.unsqueeze(0)

        h = self.backbone(voxels)
        brain_global = self.global_proj(h)
        brain_tokens = self.token_proj(h).view(-1, self.n_tokens, self.cond_dim)
        return brain_global, brain_tokens
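
# Usage sketch (editorial, not part of the packaged file): a single call
# yields both conditioning signals with the shapes documented above.
_encoder = BrainEncoder(n_voxels=1024, cond_dim=768, n_tokens=16)
_g, _toks = _encoder(torch.randn(4, 1024))
assert _g.shape == (4, 768)
assert _toks.shape == (4, 16, 768)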


class ROIBrainEncoder(nn.Module):
    """ROI-aware brain encoder with per-region sub-encoders.

    Different brain regions encode different information (V1 → low-level
    visual, FFA → faces, A1 → audio, etc.). This encoder processes each
    ROI independently and then fuses them.

    Args:
        roi_sizes: Dict mapping ROI names to voxel counts.
        cond_dim: Output conditioning dimension.
        n_tokens: Number of brain tokens for cross-attention.
        per_roi_dim: Hidden dim for each ROI sub-encoder.
    """

    def __init__(
        self,
        roi_sizes: dict[str, int],
        cond_dim: int = 768,
        n_tokens: int = 16,
        per_roi_dim: int = 128,
    ) -> None:
        super().__init__()
        self.roi_names = sorted(roi_sizes.keys())
        self.roi_sizes = roi_sizes
        self.cond_dim = cond_dim
        self.n_tokens = n_tokens

        # Per-ROI sub-encoders
        self.roi_encoders = nn.ModuleDict()
        for name in self.roi_names:
            n = roi_sizes[name]
            self.roi_encoders[name] = nn.Sequential(
                nn.Linear(n, per_roi_dim),
                nn.LayerNorm(per_roi_dim),
                nn.GELU(),
                nn.Linear(per_roi_dim, per_roi_dim),
            )

        total_roi_dim = per_roi_dim * len(self.roi_names)

        # Fusion
        self.global_proj = nn.Sequential(
            nn.Linear(total_roi_dim, cond_dim),
            nn.LayerNorm(cond_dim),
        )
        self.token_proj = nn.Linear(total_roi_dim, n_tokens * cond_dim)

    def forward(
        self, roi_voxels: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode per-ROI voxels.

        Args:
            roi_voxels: Dict mapping ROI names to tensors ``(B, n_voxels_roi)``.

        Returns:
            (brain_global, brain_tokens)
        """
        encoded = []
        for name in self.roi_names:
            x = roi_voxels[name]
            if x.ndim == 1:
                x = x.unsqueeze(0)
            encoded.append(self.roi_encoders[name](x))

        h = torch.cat(encoded, dim=-1)
        brain_global = self.global_proj(h)
        brain_tokens = self.token_proj(h).view(-1, self.n_tokens, self.cond_dim)
        return brain_global, brain_tokens
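
# Usage sketch (editorial, not part of the packaged file): ROI names and
# voxel counts below are arbitrary placeholders.
_roi_enc = ROIBrainEncoder({"V1": 512, "FFA": 256, "A1": 128}, cond_dim=768)
_roi_voxels = {
    "V1": torch.randn(4, 512),
    "FFA": torch.randn(4, 256),
    "A1": torch.randn(4, 128),
}
_g, _toks = _roi_enc(_roi_voxels)
assert _g.shape == (4, 768) and _toks.shape == (4, 16, 768)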


class SubjectAdapter(nn.Module):
    """Lightweight per-subject adapter (LoRA-style).

    Different subjects have different brain anatomy and functional
    organization. This adapter learns a low-rank residual per subject
    that adjusts the shared brain encoder output.

    Args:
        cond_dim: Conditioning dimension to adapt.
        rank: Low-rank dimension.
        n_subjects: Number of subjects.
    """

    def __init__(self, cond_dim: int = 768, rank: int = 16, n_subjects: int = 10) -> None:
        super().__init__()
        self.subject_embed = nn.Embedding(n_subjects, rank)
        self.down = nn.Linear(cond_dim, rank, bias=False)
        self.up = nn.Linear(rank, cond_dim, bias=False)

        # Zero-init the up-projection so the adapter starts as the identity
        nn.init.zeros_(self.up.weight)

    def forward(
        self, brain_global: torch.Tensor, subject_idx: torch.Tensor
    ) -> torch.Tensor:
        """Apply subject-specific adaptation.

        Args:
            brain_global: ``(B, cond_dim)``
            subject_idx: ``(B,)`` integer subject indices.

        Returns:
            Adapted ``(B, cond_dim)``
        """
        s = self.subject_embed(subject_idx)  # (B, rank)
        residual = self.up(self.down(brain_global) * s)
        return brain_global + residual
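
# Usage sketch (editorial): because ``up`` is zero-initialized, the adapter
# is an exact identity at initialization and learns per-subject residuals
# during training.
_adapter = SubjectAdapter(cond_dim=768, rank=16, n_subjects=10)
_g = torch.randn(4, 768)
assert torch.equal(_adapter(_g, torch.tensor([0, 1, 2, 3])), _g)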


def make_synthetic_brain_data(
    batch_size: int = 4,
    n_voxels: int = 1024,
    device: torch.device | str = "cpu",
) -> torch.Tensor:
    """Generate synthetic fMRI-like data for testing.

    Produces smooth, spatially correlated activations that roughly
    mimic the statistics of real BOLD signal.
    """
    # Base signal, given spatial correlation by smoothing below
    raw = torch.randn(batch_size, n_voxels, device=device)
    # Smooth with a 1D Gaussian kernel (odd width, at most 31 taps)
    kernel_size = min(31, n_voxels // 2 * 2 + 1)
    if kernel_size >= 3:
        sigma = kernel_size / 6
        x = torch.arange(kernel_size, dtype=torch.float32, device=device) - kernel_size // 2
        kernel = torch.exp(-0.5 * (x / sigma) ** 2)
        kernel = kernel / kernel.sum()
        kernel = kernel.view(1, 1, -1)
        raw_3d = raw.unsqueeze(1)  # (B, 1, V)
        padding = kernel_size // 2
        smoothed = torch.nn.functional.conv1d(raw_3d, kernel, padding=padding)
        raw = smoothed.squeeze(1)
    # Normalize each sample to zero mean, unit variance
    raw = (raw - raw.mean(dim=-1, keepdim=True)) / (raw.std(dim=-1, keepdim=True) + 1e-8)
    return raw
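
# Usage sketch (editorial): smoothed, per-sample z-scored synthetic voxels.
_data = make_synthetic_brain_data(batch_size=2, n_voxels=1024)
assert _data.shape == (2, 1024)
assert torch.allclose(_data.mean(dim=-1), torch.zeros(2), atol=1e-4)
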
cortexflow/dit.py ADDED
@@ -0,0 +1,397 @@
"""Diffusion Transformer (DiT) backbone.

Implements the DiT architecture from Peebles & Xie (2022) with modern
improvements from SD3/FLUX:

- **AdaLN-Zero** conditioning: timestep + brain embeddings modulate
  LayerNorm parameters through a learned, zero-initialized gate.
- **QK-Norm**: normalizes query and key projections for training
  stability at scale (per Dehghani et al. 2023).
- **SwiGLU** activation in the feedforward network (Shazeer 2020).
- **Cross-attention** to brain conditioning tokens for rich
  fMRI→latent interaction.

Reference: "Scalable Diffusion Models with Transformers" (arXiv:2212.09748)
"""

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from cortexflow._types import DiTConfig


# ── Helpers ──────────────────────────────────────────────────────────────


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply adaptive LayerNorm modulation: ``x * (1 + scale) + shift``."""
    return x * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class SwiGLU(nn.Module):
    """SwiGLU feedforward network (Shazeer 2020).

    Two parallel projections to the intermediate dim (a SiLU gate and a
    value branch) are multiplied elementwise, then projected back to ``dim``.
    """

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate branch
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
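
# Sketch (editorial): zero shift/scale leaves modulate() as the identity,
# which is what lets AdaLN-Zero blocks start out as identity mappings.
_x = torch.randn(2, 64, 256)
assert torch.equal(modulate(_x, torch.zeros(2, 256), torch.zeros(2, 256)), _x)
assert SwiGLU(dim=256, hidden_dim=1024)(_x).shape == _x.shape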


# ── Timestep Embedding ──────────────────────────────────────────────────


class TimestepEmbedding(nn.Module):
    """Sinusoidal timestep embedding → MLP projection.

    Maps scalar timestep ``t ∈ [0, 1]`` to a ``dim``-dimensional vector.
    """

    def __init__(self, dim: int, max_period: int = 10000) -> None:
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # Geometric frequency ladder, as in standard sinusoidal embeddings
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(self.max_period)
            * torch.arange(half, device=t.device, dtype=torch.float32)
            / half
        )
        # t: scalar, (batch,) or (batch, 1)
        if t.ndim == 0:
            t = t.unsqueeze(0)
        t_flat = t.view(-1, 1).float()
        args = t_flat * freqs.unsqueeze(0)
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2:
            # Pad odd dims so the output is exactly ``dim``-dimensional
            embedding = F.pad(embedding, (0, 1))
        return self.mlp(embedding)
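
# Sketch (editorial): embedding a batch of flow-matching timesteps.
_temb = TimestepEmbedding(dim=256)
assert _temb(torch.rand(8)).shape == (8, 256)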


# ── Patch Embedding ─────────────────────────────────────────────────────


class PatchEmbed(nn.Module):
    """Convert spatial latent maps into a sequence of patch tokens."""

    def __init__(self, in_channels: int, hidden_dim: int, patch_size: int = 2) -> None:
        super().__init__()
        self.patch_size = patch_size
        # Non-overlapping patches: stride == kernel_size == patch_size
        self.proj = nn.Conv2d(
            in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W) → (B, N, D) with N = (H / p) · (W / p)
        x = self.proj(x)  # (B, D, H', W')
        return x.flatten(2).transpose(1, 2)
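
# Sketch (editorial): a 32×32 latent with patch_size=2 yields 16² = 256 tokens.
_pe = PatchEmbed(in_channels=4, hidden_dim=256, patch_size=2)
assert _pe(torch.randn(2, 4, 32, 32)).shape == (2, 256, 256)  # (B, N, D)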


# ── DiT Block ───────────────────────────────────────────────────────────


class DiTBlock(nn.Module):
    """Diffusion Transformer block with AdaLN-Zero conditioning.

    Each block contains:
    1. AdaLN-modulated self-attention (with optional QK-Norm)
    2. Optional cross-attention to brain conditioning tokens
    3. AdaLN-modulated SwiGLU feedforward
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_norm: bool = True,
        use_cross_attn: bool = True,
        cond_dim: int = 768,
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.use_cross_attn = use_cross_attn

        # AdaLN modulation: shift/scale/gate per residual branch
        # (9 vectors with cross-attention, 6 without)
        n_adaln = 9 if use_cross_attn else 6
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_dim, n_adaln * hidden_dim),
        )

        # Self-attention
        self.norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(
            hidden_dim, num_heads, batch_first=True, dropout=0.0
        )

        # QK-Norm: normalize Q and K per head for training stability
        self.qk_norm = qk_norm
        if qk_norm:
            self.q_norm = nn.LayerNorm(hidden_dim // num_heads, eps=1e-6)
            self.k_norm = nn.LayerNorm(hidden_dim // num_heads, eps=1e-6)
        self.num_heads = num_heads

        # Cross-attention to brain conditioning
        if use_cross_attn:
            self.norm_cross = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
            self.cross_attn = nn.MultiheadAttention(
                hidden_dim, num_heads, batch_first=True, dropout=0.0
            )
            self.cond_proj = (
                nn.Linear(cond_dim, hidden_dim)
                if cond_dim != hidden_dim
                else nn.Identity()
            )

        # Feedforward
        self.norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        mlp_hidden = int(hidden_dim * mlp_ratio)
        self.mlp = SwiGLU(hidden_dim, mlp_hidden)

        # Zero-initialize the modulation so every branch starts gated off
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def _qk_norm_attn(self, x: torch.Tensor) -> torch.Tensor:
        """Self-attention with per-head QK normalization.

        Reuses the QKV and output projections of ``self.attn`` so the
        parameter count matches the non-QK-Norm path.
        """
        B, N, D = x.shape
        head_dim = D // self.num_heads

        # Learned QKV projection (weights shared with self.attn)
        qkv = F.linear(x, self.attn.in_proj_weight, self.attn.in_proj_bias)
        q, k, v = qkv.chunk(3, dim=-1)

        # Reshape to (B, heads, N, head_dim)
        q = q.view(B, N, self.num_heads, head_dim).transpose(1, 2)
        k = k.view(B, N, self.num_heads, head_dim).transpose(1, 2)
        v = v.view(B, N, self.num_heads, head_dim).transpose(1, 2)

        # Normalize Q and K
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Scaled dot-product attention
        scale = head_dim ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)
        out = attn @ v

        out = out.transpose(1, 2).reshape(B, N, D)
        return self.attn.out_proj(out)

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        brain_tokens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Patch tokens ``(B, N, D)``.
            c: Conditioning vector ``(B, D)``, the fused timestep + brain
                global embedding.
            brain_tokens: Brain conditioning tokens ``(B, T, cond_dim)`` for
                cross-attention.
        """
        # Compute all modulation parameters at once
        mod = self.adaLN_modulation(c)
        if self.use_cross_attn:
            (
                shift_sa, scale_sa, gate_sa,
                shift_ca, scale_ca, gate_ca,
                shift_ff, scale_ff, gate_ff,
            ) = mod.chunk(9, dim=-1)
        else:
            shift_sa, scale_sa, gate_sa, shift_ff, scale_ff, gate_ff = mod.chunk(
                6, dim=-1
            )

        # 1. Self-attention with AdaLN
        h = modulate(self.norm1(x), shift_sa, scale_sa)
        if self.qk_norm:
            h = self._qk_norm_attn(h)
        else:
            h, _ = self.attn(h, h, h, need_weights=False)
        x = x + gate_sa.unsqueeze(1) * h

        # 2. Cross-attention to brain tokens
        if self.use_cross_attn and brain_tokens is not None:
            h = modulate(self.norm_cross(x), shift_ca, scale_ca)
            kv = self.cond_proj(brain_tokens)
            h, _ = self.cross_attn(h, kv, kv, need_weights=False)
            x = x + gate_ca.unsqueeze(1) * h

        # 3. Feedforward with AdaLN
        h = modulate(self.norm2(x), shift_ff, scale_ff)
        h = self.mlp(h)
        x = x + gate_ff.unsqueeze(1) * h

        return x
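
# Sketch (editorial): token shapes are preserved, and with freshly
# zero-initialized gates the block is an exact identity mapping.
_blk = DiTBlock(hidden_dim=256, num_heads=4, cond_dim=768)
_x = torch.randn(2, 64, 256)
_out = _blk(_x, torch.randn(2, 256), torch.randn(2, 16, 768))
assert torch.equal(_out, _x)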


# ── Final Layer ─────────────────────────────────────────────────────────


class FinalLayer(nn.Module):
    """DiT final layer: AdaLN → linear projection back to patch space."""

    def __init__(self, hidden_dim: int, patch_size: int, out_channels: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_dim, 2 * hidden_dim),
        )
        self.proj = nn.Linear(hidden_dim, patch_size * patch_size * out_channels)

        # Zero-init so the model predicts zero velocity at initialization
        nn.init.zeros_(self.adaLN[-1].weight)
        nn.init.zeros_(self.adaLN[-1].bias)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN(c).chunk(2, dim=-1)
        x = modulate(self.norm(x), shift, scale)
        return self.proj(x)
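
# Sketch (editorial): with both the AdaLN and the projection zeroed, the
# final layer emits all-zero patches at init.
_fl = FinalLayer(hidden_dim=256, patch_size=2, out_channels=4)
_out = _fl(torch.randn(2, 64, 256), torch.randn(2, 256))
assert _out.shape == (2, 64, 16)
assert torch.count_nonzero(_out) == 0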


# ── Diffusion Transformer ──────────────────────────────────────────────


class DiffusionTransformer(nn.Module):
    """Diffusion Transformer (DiT) for brain-conditioned generation.

    Operates on latent patches and predicts the velocity field for
    rectified flow matching.

    Args:
        config: DiT configuration.
        img_size: Spatial size of the square latent input (H = W).
    """

    def __init__(self, config: DiTConfig | None = None, img_size: int = 32) -> None:
        super().__init__()
        cfg = config or DiTConfig()
        self.config = cfg
        self.img_size = img_size

        # Patch embedding
        self.patch_embed = PatchEmbed(cfg.in_channels, cfg.hidden_dim, cfg.patch_size)
        n_patches = (img_size // cfg.patch_size) ** 2

        # Learned positional embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches, cfg.hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Timestep embedding
        self.time_embed = TimestepEmbedding(cfg.hidden_dim)

        # Brain conditioning global projection
        self.cond_global_proj = nn.Sequential(
            nn.Linear(cfg.cond_dim, cfg.hidden_dim),
            nn.SiLU(),
            nn.Linear(cfg.hidden_dim, cfg.hidden_dim),
        )

        # Transformer blocks
        self.blocks = nn.ModuleList(
            [
                DiTBlock(
                    hidden_dim=cfg.hidden_dim,
                    num_heads=cfg.num_heads,
                    mlp_ratio=cfg.mlp_ratio,
                    qk_norm=cfg.qk_norm,
                    use_cross_attn=cfg.use_cross_attn,
                    cond_dim=cfg.cond_dim,
                )
                for _ in range(cfg.depth)
            ]
        )

        # Output projection
        self.final_layer = FinalLayer(cfg.hidden_dim, cfg.patch_size, cfg.in_channels)

        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize weights following DiT conventions."""

        def _init(m: nn.Module) -> None:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        self.apply(_init)

        # ``self.apply`` above overwrites the AdaLN-Zero initialization done
        # in the sub-module constructors, so re-zero those layers here.
        for block in self.blocks:
            nn.init.zeros_(block.adaLN_modulation[-1].weight)
            nn.init.zeros_(block.adaLN_modulation[-1].bias)
        nn.init.zeros_(self.final_layer.adaLN[-1].weight)
        nn.init.zeros_(self.final_layer.adaLN[-1].bias)
        nn.init.zeros_(self.final_layer.proj.weight)
        nn.init.zeros_(self.final_layer.proj.bias)

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape patch tokens back to spatial latent maps.

        Args:
            x: ``(B, N, patch_size² × C)``

        Returns:
            ``(B, C, H, W)``
        """
        p = self.config.patch_size
        c = self.config.in_channels
        h = w = self.img_size // p
        x = x.view(-1, h, w, p, p, c)
        x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
        return x.view(-1, c, h * p, w * p)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        brain_global: torch.Tensor,
        brain_tokens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Predict the velocity field v(x_t, t | brain).

        Args:
            x: Noisy latent ``(B, C, H, W)``.
            t: Timestep ``(B,)`` in ``[0, 1]``.
            brain_global: Global brain embedding ``(B, cond_dim)``, the pooled
                fMRI representation used for AdaLN conditioning.
            brain_tokens: Sequence of brain tokens ``(B, T, cond_dim)`` for
                cross-attention. If None, cross-attention is skipped.

        Returns:
            Predicted velocity ``(B, C, H, W)``.
        """
        # Embed patches + add positional encoding
        x = self.patch_embed(x) + self.pos_embed

        # Condition = timestep + brain global embedding
        t_emb = self.time_embed(t)
        c_global = self.cond_global_proj(brain_global)
        c = t_emb + c_global

        # Transformer blocks
        for block in self.blocks:
            x = block(x, c, brain_tokens)

        # Final output
        x = self.final_layer(x, c)
        return self.unpatchify(x)
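
# Sketch (editorial): end-to-end velocity prediction, assuming the default
# DiTConfig is compatible with a 32×32 latent; the 16 brain tokens are an
# arbitrary choice. Shapes follow the forward() docstring.
_cfg = DiTConfig()
_model = DiffusionTransformer(_cfg, img_size=32)
_lat = torch.randn(2, _cfg.in_channels, 32, 32)
_v = _model(
    _lat,
    torch.rand(2),                       # timesteps in [0, 1]
    torch.randn(2, _cfg.cond_dim),       # global brain embedding
    torch.randn(2, 16, _cfg.cond_dim),   # brain tokens for cross-attention
)
assert _v.shape == _lat.shape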