cortexflowx 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cortexflow/__init__.py ADDED
@@ -0,0 +1,78 @@
1
+ """CortexFlow — Brain decoding with Diffusion Transformers & Flow Matching.
2
+
3
+ Reconstruct what someone saw, heard, or thought from fMRI brain activity.
4
+
5
+ Architecture: fMRI → BrainEncoder → DiT (flow matching) → stimulus
6
+
7
+ Modules:
8
+ dit Diffusion Transformer backbone (AdaLN-Zero, QK-Norm, SwiGLU)
9
+ flow_matching Rectified flow training & ODE sampling
10
+ vae Latent space compression (2D images, 1D audio)
11
+ brain_encoder fMRI → conditioning embeddings (global + tokens)
12
+ brain2img Full brain → image pipeline
13
+ brain2audio Full brain → audio pipeline
14
+ brain2text Full brain → text pipeline (autoregressive)
15
+ training Training loops, schedulers, synthetic data
16
+ """
17
+
18
+ from cortexflow._types import (
19
+ BrainData,
20
+ CortexFlowError,
21
+ DiTConfig,
22
+ FlowConfig,
23
+ Modality,
24
+ ReconstructionResult,
25
+ TrainingConfig,
26
+ VAEConfig,
27
+ )
28
+ from cortexflow.brain2audio import AudioDiT, Brain2Audio, build_brain2audio
29
+ from cortexflow.brain2img import Brain2Image, build_brain2img
30
+ from cortexflow.brain2text import Brain2Text, BrainTextDecoder, build_brain2text
31
+ from cortexflow.brain_encoder import (
32
+ BrainEncoder,
33
+ ROIBrainEncoder,
34
+ SubjectAdapter,
35
+ make_synthetic_brain_data,
36
+ )
37
+ from cortexflow.dit import DiffusionTransformer
38
+ from cortexflow.flow_matching import EMAModel, RectifiedFlowMatcher
39
+ from cortexflow.training import SyntheticBrainDataset, Trainer, WarmupCosineScheduler
40
+ from cortexflow.vae import AudioVAE, LatentVAE
41
+
42
+ __version__ = "0.1.0"
43
+
44
+ __all__ = [
45
+ # Types
46
+ "BrainData",
47
+ "ReconstructionResult",
48
+ "DiTConfig",
49
+ "FlowConfig",
50
+ "VAEConfig",
51
+ "TrainingConfig",
52
+ "Modality",
53
+ "CortexFlowError",
54
+ # Core
55
+ "DiffusionTransformer",
56
+ "RectifiedFlowMatcher",
57
+ "EMAModel",
58
+ "LatentVAE",
59
+ "AudioVAE",
60
+ "BrainEncoder",
61
+ "ROIBrainEncoder",
62
+ "SubjectAdapter",
63
+ # Pipelines
64
+ "Brain2Image",
65
+ "Brain2Audio",
66
+ "Brain2Text",
67
+ "AudioDiT",
68
+ "BrainTextDecoder",
69
+ # Factories
70
+ "build_brain2img",
71
+ "build_brain2audio",
72
+ "build_brain2text",
73
+ "make_synthetic_brain_data",
74
+ # Training
75
+ "Trainer",
76
+ "WarmupCosineScheduler",
77
+ "SyntheticBrainDataset",
78
+ ]
cortexflow/_types.py ADDED
@@ -0,0 +1,115 @@
1
+ """Core data types for cortexflow."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from enum import Enum
7
+ from typing import Any
8
+
9
+ import torch
10
+
11
+
12
class Modality(str, Enum):
    """Kind of stimulus being decoded from brain activity.

    Inherits from ``str`` so members compare equal to (and serialize as)
    their plain string values, e.g. ``Modality.IMAGE == "image"``.
    """

    IMAGE = "image"
    AUDIO = "audio"
    TEXT = "text"
18
+
19
+
20
@dataclass
class BrainData:
    """One or more fMRI activity samples plus acquisition metadata.

    Attributes:
        voxels: Activity tensor of shape ``(batch, n_voxels)`` or ``(n_voxels,)``.
        subject_id: Optional identifier used by subject-specific adapters.
        roi_mask: Optional boolean mask selecting region-of-interest voxels.
        tr: Repetition time in seconds (fMRI temporal resolution).
    """

    voxels: torch.Tensor
    subject_id: str | None = None
    roi_mask: torch.Tensor | None = None
    tr: float = 2.0

    @property
    def n_voxels(self) -> int:
        """Size of the trailing (voxel) dimension."""
        return self.voxels.shape[-1]

    @property
    def batch_size(self) -> int:
        """Number of samples; an unbatched 1-D tensor counts as a single sample."""
        return 1 if self.voxels.ndim == 1 else self.voxels.shape[0]
45
+
46
+
47
@dataclass
class ReconstructionResult:
    """Output bundle produced by a brain-decoding run.

    Attributes:
        modality: Which kind of stimulus was decoded.
        output: Decoded tensor — image, waveform, or token ids.
        brain_condition: Conditioning embeddings that drove the sampler.
        n_steps: Number of diffusion/ODE steps used.
        cfg_scale: Classifier-free guidance scale applied.
        metadata: Free-form extras attached by the pipeline.
    """

    modality: Modality
    output: torch.Tensor
    brain_condition: torch.Tensor
    n_steps: int = 50
    cfg_scale: float = 1.0
    metadata: dict[str, Any] = field(default_factory=dict)
57
+
58
+
59
@dataclass
class TrainingConfig:
    """Configuration for training a brain decoder.

    Covers optimization hyperparameters, schedule, EMA tracking, and
    logging/checkpoint cadence. The ``*_every`` fields are presumably
    measured in optimizer steps — confirm against the Trainer loop.
    """

    learning_rate: float = 1e-4
    batch_size: int = 16
    n_epochs: int = 100
    warmup_steps: int = 500  # warmup portion consumed by WarmupCosineScheduler
    weight_decay: float = 0.01
    ema_decay: float = 0.9999  # decay for the EMA copy of model weights
    grad_clip: float = 1.0  # max gradient norm
    mixed_precision: bool = False
    log_every: int = 100
    save_every: int = 1000
    eval_every: int = 500
74
+
75
+
76
@dataclass
class DiTConfig:
    """Configuration for the Diffusion Transformer backbone."""

    in_channels: int = 4  # latent channels (matches VAEConfig.latent_channels default)
    hidden_dim: int = 768  # transformer width
    depth: int = 12  # number of transformer blocks
    num_heads: int = 12  # attention heads
    patch_size: int = 2  # spatial patch size used to tokenize the latent
    cond_dim: int = 768  # brain conditioning dimension
    mlp_ratio: float = 4.0  # MLP hidden width = hidden_dim * mlp_ratio
    qk_norm: bool = True  # normalize queries/keys before attention
    use_cross_attn: bool = True  # cross-attend to brain condition tokens
89
+
90
+
91
@dataclass
class VAEConfig:
    """Configuration for the latent VAE."""

    in_channels: int = 3  # input channels (RGB)
    latent_channels: int = 4  # channels of the compressed latent
    hidden_dims: list[int] = field(default_factory=lambda: [64, 128, 256, 512])  # per-stage widths
    kl_weight: float = 1e-6  # weight of the KL term in the VAE objective
99
+
100
+
101
@dataclass
class FlowConfig:
    """Configuration for rectified flow matching."""

    num_steps: int = 50  # default number of ODE sampling steps
    sigma_min: float = 1e-5  # floor on the noise level
    logit_normal: bool = True  # SD3-style logit-normal timestep sampling
    logit_normal_mean: float = 0.0
    logit_normal_std: float = 1.0
    solver: str = "euler"  # ODE solver: "euler" or "midpoint"
    cfg_scale: float = 4.0  # default classifier-free guidance scale
112
+
113
+
114
class CortexFlowError(Exception):
    """Root of the cortexflow exception hierarchy.

    Catch this to handle any error deliberately raised by the package.
    """
@@ -0,0 +1,298 @@
1
+ """Brain → Audio reconstruction pipeline.
2
+
3
+ Reconstruct what someone heard from their fMRI activity.
4
+
5
+ Architecture: fMRI → BrainEncoder → DiT (flow matching on mel spectrograms)
6
+ → Griffin-Lim or learned vocoder → waveform.
7
+
8
+ The DiT operates on 1D latent sequences (compressed mel spectrograms)
9
+ rather than 2D spatial maps.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import math
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ from cortexflow._types import (
21
+ BrainData,
22
+ DiTConfig,
23
+ FlowConfig,
24
+ Modality,
25
+ ReconstructionResult,
26
+ )
27
+ from cortexflow.brain_encoder import BrainEncoder
28
+ from cortexflow.flow_matching import RectifiedFlowMatcher
29
+
30
+
31
+ # ── 1D DiT for audio ────────────────────────────────────────────────────
32
+
33
+
34
class DiTBlock1D(nn.Module):
    """Transformer block for 1D token sequences with AdaLN-Zero conditioning.

    Self-attention and the MLP are each modulated by shift/scale/gate
    vectors derived from the conditioning embedding ``c``; an optional
    ungated cross-attention reads from ``brain_tokens``.
    """

    def __init__(self, hidden_dim: int, num_heads: int, cond_dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, 6 * hidden_dim))
        self.norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.norm_cross = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.cross_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.cond_proj = nn.Linear(cond_dim, hidden_dim) if cond_dim != hidden_dim else nn.Identity()
        self.norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        mlp_hidden = int(hidden_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_hidden), nn.SiLU(),
            nn.Linear(mlp_hidden, hidden_dim),
        )
        # AdaLN-Zero: zero the modulation head so every gate starts at 0
        # and the block begins as (almost) the identity map.
        nn.init.zeros_(self.adaLN[-1].weight)
        nn.init.zeros_(self.adaLN[-1].bias)

    def forward(self, x: torch.Tensor, c: torch.Tensor, brain_tokens: torch.Tensor | None = None):
        shift_attn, scale_attn, gate_attn, shift_mlp, scale_mlp, gate_mlp = (
            self.adaLN(c).chunk(6, dim=-1)
        )

        # Gated, modulated self-attention over the sequence.
        modulated = self.norm1(x) * (1 + scale_attn.unsqueeze(1)) + shift_attn.unsqueeze(1)
        attn_out, _ = self.attn(modulated, modulated, modulated, need_weights=False)
        x = x + gate_attn.unsqueeze(1) * attn_out

        # Ungated cross-attention to brain tokens (projected to model width).
        if brain_tokens is not None:
            queries = self.norm_cross(x)
            keys_values = self.cond_proj(brain_tokens)
            cross_out, _ = self.cross_attn(queries, keys_values, keys_values, need_weights=False)
            x = x + cross_out

        # Gated, modulated MLP.
        modulated = self.norm2(x) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        return x + gate_mlp.unsqueeze(1) * self.mlp(modulated)
+ return x
71
+
72
+
73
class AudioDiT(nn.Module):
    """1D Diffusion Transformer over mel-spectrogram frames.

    Each time frame becomes one token; the network predicts the
    flow-matching velocity field conditioned on the timestep and on
    brain embeddings (a global vector plus optional cross-attention
    tokens).
    """

    def __init__(
        self,
        n_mels: int = 80,
        seq_len: int = 128,
        hidden_dim: int = 256,
        depth: int = 6,
        num_heads: int = 8,
        cond_dim: int = 256,
    ):
        super().__init__()
        self.n_mels = n_mels
        self.seq_len = seq_len
        self.input_proj = nn.Linear(n_mels, hidden_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # MLP over sinusoidal timestep features.
        self.time_embed = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4), nn.SiLU(),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )
        self.time_dim = hidden_dim
        self.cond_proj = nn.Sequential(nn.Linear(cond_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, hidden_dim))

        self.blocks = nn.ModuleList([
            DiTBlock1D(hidden_dim, num_heads, cond_dim) for _ in range(depth)
        ])
        self.final_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
        self.output_proj = nn.Linear(hidden_dim, n_mels)
        # Zero-init the head so the untrained model predicts zero velocity.
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

    def _sinusoidal_embed(self, t: torch.Tensor) -> torch.Tensor:
        """Map scalar timesteps to ``(B, time_dim)`` cos/sin features."""
        half = self.time_dim // 2
        idx = torch.arange(half, device=t.device).float()
        freqs = torch.exp(-math.log(10000) * idx / half)
        angles = t.view(-1, 1) * freqs.unsqueeze(0)
        return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        brain_global: torch.Tensor,
        brain_tokens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Predict velocity for mel spectrogram flow matching.

        Args:
            x: ``(B, n_mels, T)`` noisy mel spectrogram.
            t: ``(B,)`` timestep.
            brain_global: ``(B, cond_dim)`` global brain embedding.
            brain_tokens: Optional ``(B, K, cond_dim)`` tokens for cross-attention.

        Returns:
            ``(B, n_mels, T)`` predicted velocity.
        """
        # Tokens along time: (B, n_mels, T) → (B, T, n_mels) → (B, T, hidden).
        tokens = x.transpose(1, 2)
        tokens = self.input_proj(tokens) + self.pos_embed[:, :tokens.shape[1]]

        # Single conditioning vector: timestep embedding + projected brain state.
        cond = self.time_embed(self._sinusoidal_embed(t)) + self.cond_proj(brain_global)

        for block in self.blocks:
            tokens = block(tokens, cond, brain_tokens)

        out = self.output_proj(self.final_norm(tokens))  # (B, T, n_mels)
        return out.transpose(1, 2)  # back to (B, n_mels, T)
149
+
150
+
151
+ # ── Brain2Audio Pipeline ────────────────────────────────────────────────
152
+
153
+
154
class Brain2Audio(nn.Module):
    """Audio reconstruction from fMRI via a 1D DiT with flow matching.

    Pipeline::

        fMRI → BrainEncoder → AudioDiT (flow matching) → mel spectrogram
        → Griffin-Lim → waveform

    Args:
        n_voxels: Number of fMRI voxels.
        n_mels: Number of mel frequency bins.
        audio_len: Number of mel spectrogram frames.
        sample_rate: Audio sample rate (for mel spectrogram computation).
    """

    def __init__(
        self,
        n_voxels: int = 1024,
        n_mels: int = 80,
        audio_len: int = 128,
        hidden_dim: int = 256,
        depth: int = 6,
        num_heads: int = 8,
        sample_rate: int = 16000,
    ):
        super().__init__()
        self.n_mels = n_mels
        self.audio_len = audio_len
        self.sample_rate = sample_rate

        cond_dim = hidden_dim
        self.brain_encoder = BrainEncoder(
            n_voxels=n_voxels, cond_dim=cond_dim, n_tokens=16,
        )
        self.dit = AudioDiT(
            n_mels=n_mels, seq_len=audio_len, hidden_dim=hidden_dim,
            depth=depth, num_heads=num_heads, cond_dim=cond_dim,
        )
        self.flow_matcher = RectifiedFlowMatcher(FlowConfig())

        # Learned "null" conditioning: the unconditional branch for CFG.
        self.uncond_global = nn.Parameter(torch.zeros(1, cond_dim))
        self.uncond_tokens = nn.Parameter(torch.zeros(1, 16, cond_dim))

    def training_loss(self, mel: torch.Tensor, brain_data: BrainData) -> torch.Tensor:
        """Rectified-flow matching loss on mel spectrograms.

        Args:
            mel: Target mel spectrogram ``(B, n_mels, T)``.
            brain_data: Corresponding fMRI.
        """
        global_emb, token_emb = self.brain_encoder(brain_data.voxels)
        return self.flow_matcher.compute_loss(self.dit, mel, global_emb, token_emb)

    @torch.no_grad()
    def reconstruct(
        self,
        brain_data: BrainData,
        num_steps: int = 50,
        cfg_scale: float = 3.0,
    ) -> ReconstructionResult:
        """Sample a mel spectrogram conditioned on brain activity."""
        batch = brain_data.batch_size
        device = brain_data.voxels.device
        global_emb, token_emb = self.brain_encoder(brain_data.voxels)

        sampled = self.flow_matcher.sample(
            self.dit,
            (batch, self.n_mels, self.audio_len),
            global_emb,
            token_emb,
            num_steps=num_steps,
            cfg_scale=cfg_scale,
            brain_global_uncond=self.uncond_global.expand(batch, -1),
            brain_tokens_uncond=self.uncond_tokens.expand(batch, -1, -1),
        )
        return ReconstructionResult(
            modality=Modality.AUDIO,
            output=sampled,
            brain_condition=global_emb,
            n_steps=num_steps,
            cfg_scale=cfg_scale,
        )

    @staticmethod
    def mel_to_waveform(mel: torch.Tensor, n_fft: int = 1024, hop_length: int = 256) -> torch.Tensor:
        """Invert a mel spectrogram to a waveform via Griffin-Lim.

        Args:
            mel: ``(B, n_mels, T)`` mel spectrogram (linear scale).
            n_fft: FFT size.
            hop_length: Hop length.

        Returns:
            ``(B, samples)`` waveform.
        """
        batch, n_mels, n_frames = mel.shape
        n_freqs = n_fft // 2 + 1

        # Gaussian-bump pseudo filterbank relating mel bins to linear bins.
        centers = torch.linspace(0, 1, n_mels, device=mel.device).unsqueeze(1)
        grid = torch.linspace(0, 1, n_freqs, device=mel.device).unsqueeze(0)
        bank = torch.exp(-0.5 * ((centers - grid) / 0.05) ** 2)
        bank = bank / (bank.sum(dim=0, keepdim=True) + 1e-8)

        # Transposed filterbank acts as a crude pseudo-inverse:
        # mel → linear-frequency magnitude spectrogram.
        magnitude = torch.matmul(bank.T.unsqueeze(0), mel.clamp(min=0))  # (B, n_freqs, T)

        window = torch.hann_window(n_fft, device=mel.device)
        target_len = n_frames * hop_length

        def _to_wave(spec: torch.Tensor) -> torch.Tensor:
            # Inverse STFT back to a fixed-length waveform.
            return torch.istft(
                spec, n_fft=n_fft, hop_length=hop_length,
                window=window, length=target_len,
            )

        # Griffin-Lim: alternate time/frequency domains, keeping the known
        # magnitude and re-estimating phase each round from a random start.
        phase = torch.randn(batch, n_freqs, n_frames, device=mel.device) * 2 * math.pi
        for _ in range(32):
            wave = _to_wave(magnitude * torch.exp(1j * phase))
            spec = torch.stft(
                wave, n_fft=n_fft, hop_length=hop_length,
                window=window, return_complex=True,
            )
            # The round-trip STFT may have a different frame count; align
            # before copying phases (frames past the overlap stay zero).
            keep = min(spec.shape[-1], n_frames)
            phase = torch.zeros_like(magnitude)
            phase[:, :, :keep] = spec[:, :, :keep].angle()

        return _to_wave(magnitude * torch.exp(1j * phase))
+ return waveform
285
+
286
+
287
+ def build_brain2audio(
288
+ n_voxels: int = 1024,
289
+ n_mels: int = 80,
290
+ audio_len: int = 128,
291
+ hidden_dim: int = 256,
292
+ depth: int = 6,
293
+ ) -> Brain2Audio:
294
+ """Build a Brain2Audio model with sensible defaults."""
295
+ return Brain2Audio(
296
+ n_voxels=n_voxels, n_mels=n_mels, audio_len=audio_len,
297
+ hidden_dim=hidden_dim, depth=depth,
298
+ )