cortexflowx-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortexflow/__init__.py +78 -0
- cortexflow/_types.py +115 -0
- cortexflow/brain2audio.py +298 -0
- cortexflow/brain2img.py +234 -0
- cortexflow/brain2text.py +278 -0
- cortexflow/brain_encoder.py +228 -0
- cortexflow/dit.py +397 -0
- cortexflow/flow_matching.py +236 -0
- cortexflow/training.py +283 -0
- cortexflow/vae.py +232 -0
- cortexflowx-0.1.0.dist-info/METADATA +218 -0
- cortexflowx-0.1.0.dist-info/RECORD +14 -0
- cortexflowx-0.1.0.dist-info/WHEEL +4 -0
- cortexflowx-0.1.0.dist-info/licenses/LICENSE +190 -0
cortexflow/brain_encoder.py
ADDED
@@ -0,0 +1,228 @@
"""Brain encoder: fMRI voxels → conditioning embeddings for DiT.

Supports two encoding modes:

1. **Global**: MLP projects full voxel vector to a single global embedding.
   Used for AdaLN conditioning in DiT blocks.
2. **Tokenized**: MLP projects voxels to a sequence of ``n_tokens`` vectors.
   Used for cross-attention in DiT blocks, giving the model rich spatial
   information about brain activity patterns.

Optionally supports **ROI-aware encoding**: separate sub-encoders for
different cortical regions (V1, FFA, A1, etc.) whose outputs are
concatenated and projected.
"""

from __future__ import annotations

import torch
import torch.nn as nn


class BrainEncoder(nn.Module):
    """Project fMRI voxels to DiT conditioning embeddings.

    Produces both a global embedding (for AdaLN) and a token sequence
    (for cross-attention).

    Args:
        n_voxels: Number of input fMRI voxels.
        cond_dim: Dimension of each conditioning vector.
        n_tokens: Number of brain tokens for cross-attention.
        hidden_dim: Intermediate MLP dimension.
        dropout: Dropout rate for regularization.
    """

    def __init__(
        self,
        n_voxels: int,
        cond_dim: int = 768,
        n_tokens: int = 16,
        hidden_dim: int | None = None,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.n_voxels = n_voxels
        self.cond_dim = cond_dim
        self.n_tokens = n_tokens
        h = hidden_dim or cond_dim * 2

        # Shared backbone
        self.backbone = nn.Sequential(
            nn.Linear(n_voxels, h),
            nn.LayerNorm(h),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(h, h),
            nn.LayerNorm(h),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Global projection (for AdaLN conditioning)
        self.global_proj = nn.Sequential(
            nn.Linear(h, cond_dim),
            nn.LayerNorm(cond_dim),
        )

        # Token projection (for cross-attention)
        self.token_proj = nn.Sequential(
            nn.Linear(h, n_tokens * cond_dim),
        )

    def forward(self, voxels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode fMRI voxels to conditioning signals.

        Args:
            voxels: ``(B, n_voxels)`` BOLD activity.

        Returns:
            (brain_global, brain_tokens):
            - brain_global: ``(B, cond_dim)`` for AdaLN.
            - brain_tokens: ``(B, n_tokens, cond_dim)`` for cross-attention.
        """
        if voxels.ndim == 1:
            voxels = voxels.unsqueeze(0)

        h = self.backbone(voxels)
        brain_global = self.global_proj(h)
        brain_tokens = self.token_proj(h).view(-1, self.n_tokens, self.cond_dim)
        return brain_global, brain_tokens

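Editorial note: a minimal usage sketch of the encoder above (not part of the package; batch size and voxel count are illustrative):

import torch
from cortexflow.brain_encoder import BrainEncoder

enc = BrainEncoder(n_voxels=1024)       # defaults: cond_dim=768, n_tokens=16
voxels = torch.randn(4, 1024)           # (B, n_voxels) BOLD activity
brain_global, brain_tokens = enc(voxels)
print(brain_global.shape)               # torch.Size([4, 768]): AdaLN path
print(brain_tokens.shape)               # torch.Size([4, 16, 768]): cross-attention path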

class ROIBrainEncoder(nn.Module):
    """ROI-aware brain encoder with per-region sub-encoders.

    Different brain regions encode different information (V1 → low-level
    visual, FFA → faces, A1 → audio, etc.). This encoder processes each
    ROI independently and then fuses them.

    Args:
        roi_sizes: Dict mapping ROI names to voxel counts.
        cond_dim: Output conditioning dimension.
        n_tokens: Number of brain tokens for cross-attention.
        per_roi_dim: Hidden dim for each ROI sub-encoder.
    """

    def __init__(
        self,
        roi_sizes: dict[str, int],
        cond_dim: int = 768,
        n_tokens: int = 16,
        per_roi_dim: int = 128,
    ) -> None:
        super().__init__()
        self.roi_names = sorted(roi_sizes.keys())
        self.roi_sizes = roi_sizes
        self.cond_dim = cond_dim
        self.n_tokens = n_tokens

        # Per-ROI sub-encoders
        self.roi_encoders = nn.ModuleDict()
        for name in self.roi_names:
            n = roi_sizes[name]
            self.roi_encoders[name] = nn.Sequential(
                nn.Linear(n, per_roi_dim),
                nn.LayerNorm(per_roi_dim),
                nn.GELU(),
                nn.Linear(per_roi_dim, per_roi_dim),
            )

        total_roi_dim = per_roi_dim * len(self.roi_names)

        # Fusion
        self.global_proj = nn.Sequential(
            nn.Linear(total_roi_dim, cond_dim),
            nn.LayerNorm(cond_dim),
        )
        self.token_proj = nn.Linear(total_roi_dim, n_tokens * cond_dim)

    def forward(
        self, roi_voxels: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode per-ROI voxels.

        Args:
            roi_voxels: Dict mapping ROI names to tensors ``(B, n_voxels_roi)``.

        Returns:
            (brain_global, brain_tokens)
        """
        encoded = []
        for name in self.roi_names:
            x = roi_voxels[name]
            if x.ndim == 1:
                x = x.unsqueeze(0)
            encoded.append(self.roi_encoders[name](x))

        h = torch.cat(encoded, dim=-1)
        brain_global = self.global_proj(h)
        brain_tokens = self.token_proj(h).view(-1, self.n_tokens, self.cond_dim)
        return brain_global, brain_tokens

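Editorial note: a sketch of the ROI-aware variant. The ROI names and voxel counts below are made up for illustration; real sizes depend on the atlas and subject:

import torch
from cortexflow.brain_encoder import ROIBrainEncoder

enc = ROIBrainEncoder(roi_sizes={"V1": 512, "FFA": 256, "A1": 128})
roi_voxels = {name: torch.randn(4, n) for name, n in enc.roi_sizes.items()}
brain_global, brain_tokens = enc(roi_voxels)
print(brain_global.shape)   # torch.Size([4, 768])
print(brain_tokens.shape)   # torch.Size([4, 16, 768])

Because ``roi_names`` is sorted, the concatenation order is stable across runs regardless of dict insertion order.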

class SubjectAdapter(nn.Module):
    """Lightweight per-subject adapter (LoRA-style).

    Different subjects have different brain anatomy and functional
    organization. This adapter learns a low-rank residual per subject
    that adjusts the shared brain encoder output.

    Args:
        cond_dim: Conditioning dimension to adapt.
        rank: Low-rank dimension.
        n_subjects: Number of subjects.
    """

    def __init__(self, cond_dim: int = 768, rank: int = 16, n_subjects: int = 10) -> None:
        super().__init__()
        self.subject_embed = nn.Embedding(n_subjects, rank)
        self.down = nn.Linear(cond_dim, rank, bias=False)
        self.up = nn.Linear(rank, cond_dim, bias=False)

        nn.init.zeros_(self.up.weight)

    def forward(
        self, brain_global: torch.Tensor, subject_idx: torch.Tensor
    ) -> torch.Tensor:
        """Apply subject-specific adaptation.

        Args:
            brain_global: ``(B, cond_dim)``
            subject_idx: ``(B,)`` integer subject indices.

        Returns:
            Adapted ``(B, cond_dim)``
        """
        s = self.subject_embed(subject_idx)  # (B, rank)
        residual = self.up(self.down(brain_global) * s)
        return brain_global + residual

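Editorial note: because ``up`` is zero-initialized, the adapter starts as an identity map and only learns subject-specific corrections during training. A quick sketch:

import torch
from cortexflow.brain_encoder import SubjectAdapter

adapter = SubjectAdapter(cond_dim=768, rank=16, n_subjects=10)
brain_global = torch.randn(4, 768)
subject_idx = torch.tensor([0, 1, 2, 3])
adapted = adapter(brain_global, subject_idx)
assert torch.allclose(adapted, brain_global)   # residual is exactly zero at init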

def make_synthetic_brain_data(
    batch_size: int = 4,
    n_voxels: int = 1024,
    device: torch.device | str = "cpu",
) -> torch.Tensor:
    """Generate synthetic fMRI-like data for testing.

    Produces smooth, spatially-correlated activations that roughly
    mimic the statistics of real BOLD signal.
    """
    # Base signal with spatial correlation
    raw = torch.randn(batch_size, n_voxels, device=device)
    # Smooth with a 1D Gaussian kernel
    kernel_size = min(31, n_voxels // 2 * 2 + 1)
    if kernel_size >= 3:
        sigma = kernel_size / 6
        x = torch.arange(kernel_size, dtype=torch.float32, device=device) - kernel_size // 2
        kernel = torch.exp(-0.5 * (x / sigma) ** 2)
        kernel = kernel / kernel.sum()
        kernel = kernel.view(1, 1, -1)
        raw_3d = raw.unsqueeze(1)  # (B, 1, V)
        padding = kernel_size // 2
        smoothed = torch.nn.functional.conv1d(raw_3d, kernel, padding=padding)
        raw = smoothed.squeeze(1)
    # Normalize to zero mean, unit variance
    raw = (raw - raw.mean(dim=-1, keepdim=True)) / (raw.std(dim=-1, keepdim=True) + 1e-8)
    return raw
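Editorial note: the synthetic generator makes it easy to smoke-test the encoders without real fMRI data, e.g.:

import torch
from cortexflow.brain_encoder import BrainEncoder, make_synthetic_brain_data

voxels = make_synthetic_brain_data(batch_size=4, n_voxels=1024)
assert voxels.shape == (4, 1024)
assert torch.allclose(voxels.mean(dim=-1), torch.zeros(4), atol=1e-5)  # per-sample zero mean

brain_global, brain_tokens = BrainEncoder(n_voxels=1024)(voxels)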
cortexflow/dit.py
ADDED
@@ -0,0 +1,397 @@
"""Diffusion Transformer (DiT) backbone.

Implements the DiT architecture from Peebles & Xie (2022) with modern
improvements from SD3/FLUX:

- **AdaLN-Zero** conditioning: timestep + brain embeddings modulate
  LayerNorm parameters with a learned zero-initialized gate.
- **QK-Norm**: normalizes query and key projections for training
  stability at scale (per Dehghani et al. 2023).
- **SwiGLU** activation in the feedforward network (Shazeer 2020).
- **Cross-attention** to brain conditioning tokens for rich
  fMRI→latent interaction.

Reference: "Scalable Diffusion Models with Transformers" (arXiv:2212.09748)
"""

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from cortexflow._types import DiTConfig


# ── Helpers ──────────────────────────────────────────────────────────────


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply adaptive LayerNorm modulation: x * (1 + scale) + shift."""
    return x * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class SwiGLU(nn.Module):
    """SwiGLU feedforward network (Shazeer 2020).

    Uses two parallel projections to the intermediate dim, one of them
    SiLU-gated, multiplies them elementwise, and projects back.
    """

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

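Editorial note: two properties of these helpers worth making explicit. ``modulate`` reduces to the identity when shift and scale are zero, which is exactly the AdaLN-Zero starting point used in the blocks below, and ``SwiGLU`` preserves the model dimension:

import torch
from cortexflow.dit import SwiGLU, modulate

x = torch.randn(2, 64, 512)     # (B, N, D) patch tokens
zero = torch.zeros(2, 512)      # per-sample (B, D) shift/scale
assert torch.allclose(modulate(x, zero, zero), x)

ff = SwiGLU(dim=512, hidden_dim=2048)
assert ff(x).shape == x.shape   # gated FFN maps D back to D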

# ── Timestep Embedding ──────────────────────────────────────────────────


class TimestepEmbedding(nn.Module):
    """Sinusoidal timestep embedding → MLP projection.

    Maps scalar timestep ``t ∈ [0, 1]`` to a ``dim``-dimensional vector.
    """

    def __init__(self, dim: int, max_period: int = 10000) -> None:
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(self.max_period)
            * torch.arange(half, device=t.device, dtype=torch.float32)
            / half
        )
        # t: (batch,) or (batch, 1)
        if t.ndim == 0:
            t = t.unsqueeze(0)
        t_flat = t.view(-1, 1).float()
        args = t_flat * freqs.unsqueeze(0)
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2:
            embedding = F.pad(embedding, (0, 1))
        return self.mlp(embedding)

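Editorial note: a shape sketch. For flow matching, ``t`` is a continuous time in [0, 1] rather than a discrete diffusion step index:

import torch
from cortexflow.dit import TimestepEmbedding

te = TimestepEmbedding(dim=512)
t = torch.rand(4)        # (B,) uniform in [0, 1]
emb = te(t)
print(emb.shape)         # torch.Size([4, 512])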

# ── Patch Embedding ─────────────────────────────────────────────────────


class PatchEmbed(nn.Module):
    """Convert spatial latent maps into a sequence of patch tokens."""

    def __init__(self, in_channels: int, hidden_dim: int, patch_size: int = 2) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W) → (B, N, D)
        x = self.proj(x)  # (B, D, H', W')
        return x.flatten(2).transpose(1, 2)

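Editorial note: with the default ``patch_size=2``, a 32×32 latent becomes a 256-token sequence (channel count here is illustrative):

import torch
from cortexflow.dit import PatchEmbed

pe = PatchEmbed(in_channels=4, hidden_dim=512, patch_size=2)
z = torch.randn(2, 4, 32, 32)   # (B, C, H, W) latent
tokens = pe(z)
print(tokens.shape)             # torch.Size([2, 256, 512]): (32/2)² = 256 patches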

# ── DiT Block ───────────────────────────────────────────────────────────


class DiTBlock(nn.Module):
    """Diffusion Transformer block with AdaLN-Zero conditioning.

    Each block contains:
    1. AdaLN-modulated self-attention (with optional QK-Norm)
    2. Optional cross-attention to brain conditioning tokens
    3. AdaLN-modulated SwiGLU feedforward
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_norm: bool = True,
        use_cross_attn: bool = True,
        cond_dim: int = 768,
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.use_cross_attn = use_cross_attn

        # AdaLN modulation: one (shift, scale, gate) triple per sub-layer,
        # i.e. 9 vectors with cross-attention, else 6.
        n_adaln = 9 if use_cross_attn else 6
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_dim, n_adaln * hidden_dim),
        )

        # Self-attention
        self.norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(
            hidden_dim, num_heads, batch_first=True, dropout=0.0
        )

        # QK-Norm: normalize Q and K for training stability
        self.qk_norm = qk_norm
        if qk_norm:
            self.q_norm = nn.LayerNorm(hidden_dim // num_heads, eps=1e-6)
            self.k_norm = nn.LayerNorm(hidden_dim // num_heads, eps=1e-6)

        # Cross-attention to brain conditioning
        if use_cross_attn:
            self.norm_cross = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
            self.cross_attn = nn.MultiheadAttention(
                hidden_dim, num_heads, batch_first=True, dropout=0.0
            )
            self.cond_proj = (
                nn.Linear(cond_dim, hidden_dim)
                if cond_dim != hidden_dim
                else nn.Identity()
            )

        # Feedforward
        self.norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        mlp_hidden = int(hidden_dim * mlp_ratio)
        self.mlp = SwiGLU(hidden_dim, mlp_hidden)

        # Zero-initialize the gating parameters
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def _qk_norm_attn(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        """Self-attention with per-head QK normalization.

        Reuses the packed QKV and output projections of ``self.attn`` so
        the QK-Norm path has the same learned parameters as the standard
        path (otherwise those weights would go unused).
        """
        B, N, D = q.shape
        head_dim = D // self.num_heads

        # Project inputs with the shared QKV weights
        w_q, w_k, w_v = self.attn.in_proj_weight.chunk(3, dim=0)
        b_q, b_k, b_v = self.attn.in_proj_bias.chunk(3, dim=0)
        q = F.linear(q, w_q, b_q)
        k = F.linear(k, w_k, b_k)
        v = F.linear(v, w_v, b_v)

        # Reshape to (B, heads, N, head_dim)
        q = q.view(B, N, self.num_heads, head_dim).transpose(1, 2)
        k = k.view(B, N, self.num_heads, head_dim).transpose(1, 2)
        v = v.view(B, N, self.num_heads, head_dim).transpose(1, 2)

        # Normalize Q and K
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Scaled dot-product attention
        scale = head_dim ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)
        out = attn @ v

        out = out.transpose(1, 2).reshape(B, N, D)
        return self.attn.out_proj(out)

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        brain_tokens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Patch tokens ``(B, N, D)``.
            c: Conditioning vector ``(B, D)`` — fused timestep + brain global.
            brain_tokens: Brain conditioning tokens ``(B, T, cond_dim)`` for
                cross-attention.
        """
        # Compute all modulation parameters at once
        mod = self.adaLN_modulation(c)
        if self.use_cross_attn:
            (
                shift_sa, scale_sa, gate_sa,
                shift_ca, scale_ca, gate_ca,
                shift_ff, scale_ff, gate_ff,
            ) = mod.chunk(9, dim=-1)
        else:
            shift_sa, scale_sa, gate_sa, shift_ff, scale_ff, gate_ff = mod.chunk(
                6, dim=-1
            )

        # 1. Self-attention with AdaLN
        h = modulate(self.norm1(x), shift_sa, scale_sa)
        if self.qk_norm:
            h = self._qk_norm_attn(h, h, h)
        else:
            h, _ = self.attn(h, h, h, need_weights=False)
        x = x + gate_sa.unsqueeze(1) * h

        # 2. Cross-attention to brain tokens
        if self.use_cross_attn and brain_tokens is not None:
            h = modulate(self.norm_cross(x), shift_ca, scale_ca)
            kv = self.cond_proj(brain_tokens)
            h, _ = self.cross_attn(h, kv, kv, need_weights=False)
            x = x + gate_ca.unsqueeze(1) * h

        # 3. Feedforward with AdaLN
        h = modulate(self.norm2(x), shift_ff, scale_ff)
        h = self.mlp(h)
        x = x + gate_ff.unsqueeze(1) * h

        return x

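Editorial note: a consequence of the zero-initialized gates is that a fresh block is the identity map, so deep stacks start training stably. A sketch (dims illustrative):

import torch
from cortexflow.dit import DiTBlock

block = DiTBlock(hidden_dim=512, num_heads=8, cond_dim=768)
x = torch.randn(2, 256, 512)            # patch tokens (B, N, D)
c = torch.randn(2, 512)                 # fused timestep + brain global (B, D)
brain_tokens = torch.randn(2, 16, 768)  # (B, T, cond_dim)

out = block(x, c, brain_tokens)
assert out.shape == x.shape
assert torch.allclose(out, x)           # all three residual branches gated to zero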

# ── Final Layer ─────────────────────────────────────────────────────────


class FinalLayer(nn.Module):
    """DiT final layer: AdaLN → linear projection back to patch space."""

    def __init__(self, hidden_dim: int, patch_size: int, out_channels: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_dim, 2 * hidden_dim),
        )
        self.proj = nn.Linear(hidden_dim, patch_size * patch_size * out_channels)

        nn.init.zeros_(self.adaLN[-1].weight)
        nn.init.zeros_(self.adaLN[-1].bias)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN(c).chunk(2, dim=-1)
        x = modulate(self.norm(x), shift, scale)
        return self.proj(x)


# ── Diffusion Transformer ──────────────────────────────────────────────


class DiffusionTransformer(nn.Module):
    """Diffusion Transformer (DiT) for brain-conditioned generation.

    Operates on latent patches and predicts the velocity field for
    rectified flow matching.

    Args:
        config: DiT configuration.
        img_size: Spatial size of the latent input (H = W).
    """

    def __init__(self, config: DiTConfig | None = None, img_size: int = 32) -> None:
        super().__init__()
        cfg = config or DiTConfig()
        self.config = cfg
        self.img_size = img_size

        # Patch embedding
        self.patch_embed = PatchEmbed(cfg.in_channels, cfg.hidden_dim, cfg.patch_size)
        n_patches = (img_size // cfg.patch_size) ** 2

        # Learned positional embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches, cfg.hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Timestep embedding
        self.time_embed = TimestepEmbedding(cfg.hidden_dim)

        # Brain conditioning global projection
        self.cond_global_proj = nn.Sequential(
            nn.Linear(cfg.cond_dim, cfg.hidden_dim),
            nn.SiLU(),
            nn.Linear(cfg.hidden_dim, cfg.hidden_dim),
        )

        # Transformer blocks
        self.blocks = nn.ModuleList(
            [
                DiTBlock(
                    hidden_dim=cfg.hidden_dim,
                    num_heads=cfg.num_heads,
                    mlp_ratio=cfg.mlp_ratio,
                    qk_norm=cfg.qk_norm,
                    use_cross_attn=cfg.use_cross_attn,
                    cond_dim=cfg.cond_dim,
                )
                for _ in range(cfg.depth)
            ]
        )

        # Output projection
        self.final_layer = FinalLayer(cfg.hidden_dim, cfg.patch_size, cfg.in_channels)

        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize weights following DiT conventions."""

        def _init(m: nn.Module) -> None:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        self.apply(_init)

        # The blanket Xavier init above also hits the AdaLN modulation and
        # output projections, so restore their AdaLN-Zero state last.
        for block in self.blocks:
            nn.init.zeros_(block.adaLN_modulation[-1].weight)
            nn.init.zeros_(block.adaLN_modulation[-1].bias)
        nn.init.zeros_(self.final_layer.adaLN[-1].weight)
        nn.init.zeros_(self.final_layer.adaLN[-1].bias)
        nn.init.zeros_(self.final_layer.proj.weight)
        nn.init.zeros_(self.final_layer.proj.bias)

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape patch tokens back to spatial latent maps.

        Args:
            x: ``(B, N, patch_size² × C)``

        Returns:
            ``(B, C, H, W)``
        """
        p = self.config.patch_size
        c = self.config.in_channels
        h = w = self.img_size // p
        x = x.view(-1, h, w, p, p, c)
        x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
        return x.view(-1, c, h * p, w * p)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        brain_global: torch.Tensor,
        brain_tokens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Predict velocity field v(x_t, t | brain).

        Args:
            x: Noisy latent ``(B, C, H, W)``.
            t: Timestep ``(B,)`` in ``[0, 1]``.
            brain_global: Global brain embedding ``(B, cond_dim)`` — pooled
                fMRI representation used for AdaLN conditioning.
            brain_tokens: Sequence of brain tokens ``(B, T, cond_dim)`` for
                cross-attention. If None, cross-attention is skipped.

        Returns:
            Predicted velocity ``(B, C, H, W)``.
        """
        # Embed patches + add positional encoding
        x = self.patch_embed(x) + self.pos_embed

        # Condition = timestep + brain global embedding
        t_emb = self.time_embed(t)
        c_global = self.cond_global_proj(brain_global)
        c = t_emb + c_global

        # Transformer blocks
        for block in self.blocks:
            x = block(x, c, brain_tokens)

        # Final output
        x = self.final_layer(x, c)
        return self.unpatchify(x)
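Editorial note: an end-to-end sketch of one brain-conditioned velocity prediction. ``DiTConfig`` defaults live in cortexflow/_types.py (not shown in this diff), so this assumes the default ``patch_size`` divides ``img_size``:

import torch
from cortexflow._types import DiTConfig
from cortexflow.brain_encoder import BrainEncoder, make_synthetic_brain_data
from cortexflow.dit import DiffusionTransformer

cfg = DiTConfig()
dit = DiffusionTransformer(cfg, img_size=32)
enc = BrainEncoder(n_voxels=1024, cond_dim=cfg.cond_dim)

voxels = make_synthetic_brain_data(batch_size=2, n_voxels=1024)
brain_global, brain_tokens = enc(voxels)

x_t = torch.randn(2, cfg.in_channels, 32, 32)   # noisy latent
t = torch.rand(2)                               # flow-matching time in [0, 1]
v = dit(x_t, t, brain_global, brain_tokens)     # predicted velocity field
assert v.shape == x_t.shape

With the AdaLN-Zero initialization, the zero-initialized ``FinalLayer`` projection means ``v`` is exactly zero before any training; the model's initial prediction is the zero velocity field.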