cortexflowx 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cortexflow/flow_matching.py ADDED
@@ -0,0 +1,236 @@
+ """Rectified Flow Matching for training and sampling.
+
+ Implements the rectified flow / flow matching framework (Liu et al., 2022;
+ Lipman et al., 2022) and the improvements from Esser et al. (2024, Stable Diffusion 3):
+
+ - **Linear interpolation** paths: x_t = (1 - t) · x_0 + t · x_1
+ - **Velocity prediction**: v = x_1 - x_0 (the model learns the vector
+   field that transports noise to data)
+ - **Logit-normal timestep sampling**: biases training toward perceptually
+   relevant noise levels (SD3 improvement over uniform sampling)
+ - **Euler / Midpoint ODE solvers** for inference
+
+ The training loss is simply MSE between predicted and target velocity:
+     L = E_{t, x_0, x_1} [ || v_θ(x_t, t, c) - (x_1 - x_0) ||² ]
+
+ Reference: "Flow Matching Guide and Code" (arXiv:2412.06264)
+ """
+
+ from __future__ import annotations
+
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from cortexflow._types import FlowConfig
+
+
+ class RectifiedFlowMatcher:
+     """Rectified flow matching training and sampling.
+
+     Usage::
+
+         fm = RectifiedFlowMatcher()
+
+         # Training step
+         loss = fm.compute_loss(model, x_clean, brain_global, brain_tokens)
+         loss.backward()
+
+         # Sampling
+         x_gen = fm.sample(model, shape=(B, C, H, W), brain_global=bg, brain_tokens=bt)
+     """
+
+     def __init__(self, config: FlowConfig | None = None) -> None:
+         self.config = config or FlowConfig()
+
+     # ── Timestep Sampling ────────────────────────────────────────────
+
+     def sample_timesteps(
+         self, batch_size: int, device: torch.device
+     ) -> torch.Tensor:
+         """Sample timesteps t ∈ (0, 1).
+
+         If ``logit_normal`` is enabled (SD3-style), samples from a
+         logit-normal distribution which biases toward intermediate noise
+         levels where the perceptual signal is strongest.
+         """
+         cfg = self.config
+
+         if cfg.logit_normal:
+             # Logit-normal: sample u ~ N(mean, std²), then t = sigmoid(u)
+             u = torch.randn(batch_size, device=device)
+             u = u * cfg.logit_normal_std + cfg.logit_normal_mean
+             t = torch.sigmoid(u)
+         else:
+             t = torch.rand(batch_size, device=device)
+
+         # Clamp away from exact 0 and 1 for numerical stability
+         return t.clamp(cfg.sigma_min, 1.0 - cfg.sigma_min)
+
+     # ── Training ─────────────────────────────────────────────────────
+
+     def compute_loss(
+         self,
+         model: nn.Module,
+         x_1: torch.Tensor,
+         brain_global: torch.Tensor,
+         brain_tokens: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         """Compute rectified flow matching loss.
+
+         Args:
+             model: DiT model that predicts velocity v(x_t, t, condition).
+             x_1: Clean data (target) ``(B, C, H, W)``.
+             brain_global: Global brain embedding ``(B, D)``.
+             brain_tokens: Brain token sequence ``(B, T, D)`` for cross-attention.
+
+         Returns:
+             Scalar MSE loss: E[ ||v_pred - v_target||² ]
+         """
+         B = x_1.shape[0]
+         device = x_1.device
+
+         # Sample timesteps
+         t = self.sample_timesteps(B, device)
+
+         # Sample noise
+         x_0 = torch.randn_like(x_1)
+
+         # Linear interpolation: x_t = (1-t) * noise + t * data
+         t_expand = t.view(B, *([1] * (x_1.ndim - 1)))
+         x_t = (1.0 - t_expand) * x_0 + t_expand * x_1
+
+         # Target velocity: points from noise to data
+         v_target = x_1 - x_0
+
+         # Model prediction
+         v_pred = model(x_t, t, brain_global, brain_tokens)
+
+         return F.mse_loss(v_pred, v_target)
+
+     # ── Sampling ─────────────────────────────────────────────────────
+
+     @torch.no_grad()
+     def sample(
+         self,
+         model: nn.Module,
+         shape: tuple[int, ...],
+         brain_global: torch.Tensor,
+         brain_tokens: torch.Tensor | None = None,
+         num_steps: int | None = None,
+         cfg_scale: float | None = None,
+         brain_global_uncond: torch.Tensor | None = None,
+         brain_tokens_uncond: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         """Generate samples by solving the ODE from noise to data.
+
+         Args:
+             model: Trained DiT model.
+             shape: Output shape ``(B, C, H, W)``.
+             brain_global: Global brain condition.
+             brain_tokens: Brain token sequence.
+             num_steps: Override number of ODE steps.
+             cfg_scale: Classifier-free guidance scale. 1.0 = no guidance.
+             brain_global_uncond: Unconditional brain embedding for CFG.
+             brain_tokens_uncond: Unconditional brain tokens for CFG.
+
+         Returns:
+             Generated latent ``(B, C, H, W)``.
+         """
+         cfg = self.config
+         steps = num_steps or cfg.num_steps
+         guidance = cfg_scale if cfg_scale is not None else cfg.cfg_scale
+         do_cfg = guidance > 1.0 and brain_global_uncond is not None
+         device = brain_global.device
+
+         # Start from pure noise
+         x = torch.randn(shape, device=device)
+         dt = 1.0 / steps
+
+         for i in range(steps):
+             t_val = i / steps
+             t = torch.full((shape[0],), t_val, device=device)
+
+             if cfg.solver == "midpoint":
+                 # Midpoint method: evaluate at t + dt/2
+                 v = self._get_velocity(
+                     model, x, t, brain_global, brain_tokens,
+                     guidance, do_cfg, brain_global_uncond, brain_tokens_uncond,
+                 )
+                 x_mid = x + v * (dt / 2)
+                 t_mid = torch.full((shape[0],), t_val + dt / 2, device=device)
+                 v_mid = self._get_velocity(
+                     model, x_mid, t_mid, brain_global, brain_tokens,
+                     guidance, do_cfg, brain_global_uncond, brain_tokens_uncond,
+                 )
+                 x = x + v_mid * dt
+             else:
+                 # Euler method
+                 v = self._get_velocity(
+                     model, x, t, brain_global, brain_tokens,
+                     guidance, do_cfg, brain_global_uncond, brain_tokens_uncond,
+                 )
+                 x = x + v * dt
+
+         return x
+
+     def _get_velocity(
+         self,
+         model: nn.Module,
+         x: torch.Tensor,
+         t: torch.Tensor,
+         brain_global: torch.Tensor,
+         brain_tokens: torch.Tensor | None,
+         guidance: float,
+         do_cfg: bool,
+         brain_global_uncond: torch.Tensor | None,
+         brain_tokens_uncond: torch.Tensor | None,
+     ) -> torch.Tensor:
+         """Get velocity with optional classifier-free guidance."""
+         if do_cfg and brain_global_uncond is not None:
+             # Conditional prediction
+             v_cond = model(x, t, brain_global, brain_tokens)
+             # Unconditional prediction
+             v_uncond = model(x, t, brain_global_uncond, brain_tokens_uncond)
+             # CFG interpolation
+             return v_uncond + guidance * (v_cond - v_uncond)
+         else:
+             return model(x, t, brain_global, brain_tokens)
+
+
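For orientation (this sketch is not part of the package diff), here is roughly how the matcher above is exercised end to end. The tiny velocity network is a hypothetical stand-in for the package's DiT model and only mirrors its (x_t, t, brain_global, brain_tokens) call signature; the sketch also assumes the FlowConfig defaults are usable as-is.

import torch
import torch.nn as nn

class ToyVelocityNet(nn.Module):
    # Hypothetical stand-in for the DiT: maps (x_t, t, brain_global, brain_tokens) -> velocity.
    def __init__(self, channels: int = 4, cond_dim: int = 16) -> None:
        super().__init__()
        self.proj = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.cond = nn.Linear(cond_dim, channels)

    def forward(self, x, t, brain_global, brain_tokens=None):
        # Broadcast the global condition over space; the timestep is ignored in this toy.
        return self.proj(x) + self.cond(brain_global)[:, :, None, None]

fm = RectifiedFlowMatcher()              # FlowConfig() defaults assumed
model = ToyVelocityNet()
x_clean = torch.randn(2, 4, 8, 8)        # stand-in for clean latents (B, C, H, W)
bg = torch.randn(2, 16)                  # global brain embedding (B, D)

loss = fm.compute_loss(model, x_clean, bg)   # MSE between predicted and target velocity
loss.backward()

samples = fm.sample(model, shape=(2, 4, 8, 8), brain_global=bg)  # ODE from noise to data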
+ class EMAModel:
+     """Exponential Moving Average of model parameters.
+
+     Maintains a shadow copy of model weights updated as:
+         θ_ema = decay · θ_ema + (1 - decay) · θ_model
+     """
+
+     def __init__(self, model: nn.Module, decay: float = 0.9999) -> None:
+         self.decay = decay
+         self.shadow: dict[str, torch.Tensor] = {}
+         for name, param in model.named_parameters():
+             if param.requires_grad:
+                 self.shadow[name] = param.data.clone()
+
+     @torch.no_grad()
+     def update(self, model: nn.Module) -> None:
+         for name, param in model.named_parameters():
+             if name in self.shadow:
+                 self.shadow[name].lerp_(param.data, 1.0 - self.decay)
+
+     def apply_to(self, model: nn.Module) -> dict[str, torch.Tensor]:
+         """Replace model params with EMA params. Returns original params."""
+         originals: dict[str, torch.Tensor] = {}
+         for name, param in model.named_parameters():
+             if name in self.shadow:
+                 originals[name] = param.data.clone()
+                 param.data.copy_(self.shadow[name])
+         return originals
+
+     def restore(self, model: nn.Module, originals: dict[str, torch.Tensor]) -> None:
+         """Restore original params after EMA evaluation."""
+         for name, param in model.named_parameters():
+             if name in originals:
+                 param.data.copy_(originals[name])
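Continuing the same toy setup from the sketch above, the intended use of EMAModel is the apply_to / restore swap around evaluation (again illustrative only, not package code):

opt = torch.optim.SGD(model.parameters(), lr=1e-3)
ema = EMAModel(model, decay=0.999)

for _ in range(10):                      # stand-in for a real training loop
    opt.zero_grad()
    fm.compute_loss(model, x_clean, bg).backward()
    opt.step()
    ema.update(model)                    # shadow <- decay*shadow + (1-decay)*params

originals = ema.apply_to(model)          # evaluate with the smoothed weights
ema_samples = fm.sample(model, shape=(2, 4, 8, 8), brain_global=bg)
ema.restore(model, originals)            # put the live training weights back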
cortexflow/training.py ADDED
@@ -0,0 +1,283 @@
+ """Training utilities for cortexflow models.
+
+ Provides a generic training loop, learning rate schedulers, and data
+ utilities for training brain decoders on fMRI datasets.
+ """
+
+ from __future__ import annotations
+
+ import math
+ from typing import Any, Callable, Iterator
+
+ import torch
+ import torch.nn as nn
+
+ from cortexflow._types import TrainingConfig
+ from cortexflow.flow_matching import EMAModel
+
+
+ class WarmupCosineScheduler:
+     """Learning rate schedule: linear warmup → cosine decay.
+
+     Args:
+         optimizer: PyTorch optimizer.
+         warmup_steps: Linear warmup duration.
+         total_steps: Total training steps.
+         min_lr_ratio: Minimum LR as a fraction of peak.
+     """
+
+     def __init__(
+         self,
+         optimizer: torch.optim.Optimizer,
+         warmup_steps: int,
+         total_steps: int,
+         min_lr_ratio: float = 0.01,
+     ) -> None:
+         self.optimizer = optimizer
+         self.warmup_steps = warmup_steps
+         self.total_steps = total_steps
+         self.min_lr_ratio = min_lr_ratio
+         self.base_lrs = [pg["lr"] for pg in optimizer.param_groups]
+         self._step = 0
+
+     def step(self) -> None:
+         self._step += 1
+         lr_scale = self._get_scale(self._step)
+         for pg, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
+             pg["lr"] = base_lr * lr_scale
+
+     def _get_scale(self, step: int) -> float:
+         if step <= self.warmup_steps:
+             return step / max(1, self.warmup_steps)
+         progress = (step - self.warmup_steps) / max(
+             1, self.total_steps - self.warmup_steps
+         )
+         progress = min(progress, 1.0)
+         cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
+         return self.min_lr_ratio + (1.0 - self.min_lr_ratio) * cosine
+
+     @property
+     def current_lr(self) -> float:
+         return self.optimizer.param_groups[0]["lr"]
+
+
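A quick check of the schedule's shape (illustrative; the optimizer and parameter values here are arbitrary, not package defaults):

import torch
net = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(net.parameters(), lr=1e-3)
sched = WarmupCosineScheduler(opt, warmup_steps=10, total_steps=100, min_lr_ratio=0.01)

for step in range(1, 101):
    sched.step()
    if step in (1, 10, 55, 100):
        print(step, sched.current_lr)
# Roughly: 1e-4 at step 1 (linear warmup), 1e-3 at the end of warmup,
# about half the peak midway through the cosine decay, and ~1e-5 at step 100.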
+ class SyntheticBrainDataset:
+     """Synthetic dataset for testing and development.
+
+     Yields ``{"fmri", "stimulus"}`` sample dicts with a configurable
+     stimulus modality (image, audio, or text).
+     """
+
+     def __init__(
+         self,
+         n_samples: int = 1000,
+         n_voxels: int = 1024,
+         modality: str = "image",
+         img_size: int = 64,
+         n_mels: int = 80,
+         audio_len: int = 128,
+         max_text_len: int = 64,
+     ) -> None:
+         self.n_samples = n_samples
+         self.n_voxels = n_voxels
+         self.modality = modality
+         self.img_size = img_size
+         self.n_mels = n_mels
+         self.audio_len = audio_len
+         self.max_text_len = max_text_len
+
+     def __len__(self) -> int:
+         return self.n_samples
+
+     def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
+         """Generate a single synthetic sample."""
+         g = torch.Generator().manual_seed(idx)
+         fmri = torch.randn(self.n_voxels, generator=g)
+
+         if self.modality == "image":
+             stimulus = torch.rand(3, self.img_size, self.img_size, generator=g)
+         elif self.modality == "audio":
+             stimulus = torch.randn(self.n_mels, self.audio_len, generator=g).abs()
+         elif self.modality == "text":
+             # Random byte-level tokens (printable ASCII range)
+             tokens = torch.randint(32, 127, (self.max_text_len,), generator=g)
+             stimulus = tokens
+         else:
+             raise ValueError(f"Unknown modality: {self.modality}")
+
+         return {"fmri": fmri, "stimulus": stimulus}
+
+     def to_loader(
+         self, batch_size: int = 8, shuffle: bool = True
+     ) -> Iterator[dict[str, torch.Tensor]]:
+         """Simple batch iterator (no DataLoader dependency)."""
+         indices = list(range(self.n_samples))
+         if shuffle:
+             g = torch.Generator()
+             perm = torch.randperm(self.n_samples, generator=g).tolist()
+             indices = perm
+
+         for start in range(0, self.n_samples, batch_size):
+             batch_indices = indices[start : start + batch_size]
+             items = [self[i] for i in batch_indices]
+             batch: dict[str, torch.Tensor] = {}
+             for key in items[0]:
+                 batch[key] = torch.stack([item[key] for item in items])
+             yield batch
+
+
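For reference, iterating the synthetic loader (shapes follow from the constructor defaults shown above; illustrative only):

ds = SyntheticBrainDataset(n_samples=32, n_voxels=1024, modality="image", img_size=64)
batch = next(iter(ds.to_loader(batch_size=8)))
print(batch["fmri"].shape)      # torch.Size([8, 1024])
print(batch["stimulus"].shape)  # torch.Size([8, 3, 64, 64])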
+ class Trainer:
+     """Generic training loop for cortexflow models.
+
+     Handles optimizer setup, LR scheduling, EMA, gradient clipping,
+     and logging.
+
+     Args:
+         model: The cortexflow model to optimize. Losses are supplied by the
+             ``loss_fn`` passed to ``train_step`` / ``fit`` (for flow models,
+             typically a thin wrapper around the model's ``training_loss``).
+         config: Training configuration.
+         device: Device to train on.
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         config: TrainingConfig | None = None,
+         device: str | torch.device = "cpu",
+     ) -> None:
+         self.model = model.to(device)
+         self.config = config or TrainingConfig()
+         self.device = torch.device(device)
+
+         cfg = self.config
+         self.optimizer = torch.optim.AdamW(
+             model.parameters(),
+             lr=cfg.learning_rate,
+             weight_decay=cfg.weight_decay,
+             betas=(0.9, 0.95),
+         )
+
+         self.ema: EMAModel | None = None
+         if cfg.ema_decay > 0:
+             self.ema = EMAModel(model, decay=cfg.ema_decay)
+
+         self.global_step = 0
+         self.scheduler: WarmupCosineScheduler | None = None
+         self._log_fn: Callable[[dict[str, Any]], None] | None = None
+
+     def set_logger(self, fn: Callable[[dict[str, Any]], None]) -> None:
+         """Set a logging callback: fn({"step": ..., "loss": ..., ...})."""
+         self._log_fn = fn
+
+     def _log(self, metrics: dict[str, Any]) -> None:
+         if self._log_fn is not None:
+             self._log_fn(metrics)
+
+     def train_step(
+         self,
+         batch: dict[str, torch.Tensor],
+         loss_fn: Callable[[nn.Module, dict[str, torch.Tensor]], torch.Tensor],
+     ) -> float:
+         """Execute a single training step.
+
+         Args:
+             batch: Dict of tensors (moved to device automatically).
+             loss_fn: Computes loss from (model, batch) → scalar.
+
+         Returns:
+             Loss value as float.
+         """
+         cfg = self.config
+         self.model.train()
+
+         # Move data to device
+         batch_dev = {
+             k: v.to(self.device) if isinstance(v, torch.Tensor) else v
+             for k, v in batch.items()
+         }
+
+         loss = loss_fn(self.model, batch_dev)
+         loss.backward()
+
+         # Gradient clipping
+         if cfg.grad_clip > 0:
+             nn.utils.clip_grad_norm_(self.model.parameters(), cfg.grad_clip)
+
+         self.optimizer.step()
+         self.optimizer.zero_grad()
+
+         # EMA update
+         if self.ema is not None:
+             self.ema.update(self.model)
+
+         # LR schedule
+         if self.scheduler is not None:
+             self.scheduler.step()
+
+         self.global_step += 1
+
+         # Logging
+         if self.global_step % cfg.log_every == 0:
+             lr = (
+                 self.scheduler.current_lr
+                 if self.scheduler
+                 else cfg.learning_rate
+             )
+             self._log({
+                 "step": self.global_step,
+                 "loss": loss.item(),
+                 "lr": lr,
+             })
+
+         return loss.item()
+
+     def fit(
+         self,
+         dataset: SyntheticBrainDataset,
+         loss_fn: Callable[[nn.Module, dict[str, torch.Tensor]], torch.Tensor],
+         n_epochs: int | None = None,
+     ) -> list[float]:
+         """Full training loop over a dataset.
+
+         Args:
+             dataset: Training data.
+             loss_fn: Loss function (model, batch) → scalar.
+             n_epochs: Override number of epochs.
+
+         Returns:
+             List of per-step losses.
+         """
+         cfg = self.config
+         epochs = n_epochs or cfg.n_epochs
+         steps_per_epoch = max(1, len(dataset) // cfg.batch_size)
+         total_steps = epochs * steps_per_epoch
+
+         self.scheduler = WarmupCosineScheduler(
+             self.optimizer, cfg.warmup_steps, total_steps
+         )
+
+         losses: list[float] = []
+         for epoch in range(epochs):
+             for batch in dataset.to_loader(batch_size=cfg.batch_size, shuffle=True):
+                 loss = self.train_step(batch, loss_fn)
+                 losses.append(loss)
+
+         return losses
+
+     def save_checkpoint(self, path: str) -> None:
+         """Save model + optimizer + EMA state."""
+         state = {
+             "model": self.model.state_dict(),
+             "optimizer": self.optimizer.state_dict(),
+             "global_step": self.global_step,
+         }
+         if self.ema is not None:
+             state["ema"] = self.ema.shadow
+         torch.save(state, path)
+
+     def load_checkpoint(self, path: str) -> None:
+         """Load model + optimizer + EMA state."""
+         state = torch.load(path, map_location=self.device, weights_only=False)
+         self.model.load_state_dict(state["model"])
+         self.optimizer.load_state_dict(state["optimizer"])
+         self.global_step = state.get("global_step", 0)
+         if self.ema is not None and "ema" in state:
+             self.ema.shadow = state["ema"]
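Finally, a minimal end-to-end sketch tying the pieces together (not part of the package source). The toy decoder and loss are illustrative stand-ins; a real setup would plug in a cortexflow decoder and, e.g., RectifiedFlowMatcher.compute_loss inside ``loss_fn``. It also assumes the TrainingConfig defaults are usable as-is.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyDecoder(nn.Module):
    # Hypothetical stand-in: regresses a 3x64x64 stimulus directly from fMRI.
    def __init__(self, n_voxels: int = 1024, img_size: int = 64) -> None:
        super().__init__()
        self.img_size = img_size
        self.head = nn.Linear(n_voxels, 3 * img_size * img_size)

    def forward(self, fmri: torch.Tensor) -> torch.Tensor:
        return self.head(fmri).view(-1, 3, self.img_size, self.img_size)

def loss_fn(model: nn.Module, batch: dict) -> torch.Tensor:
    # (model, batch) -> scalar, as expected by Trainer.train_step / Trainer.fit
    return F.mse_loss(model(batch["fmri"]), batch["stimulus"])

dataset = SyntheticBrainDataset(n_samples=64, n_voxels=1024, modality="image")
trainer = Trainer(ToyDecoder(), device="cpu")   # TrainingConfig() defaults assumed
trainer.set_logger(print)                        # prints {"step": ..., "loss": ..., "lr": ...}
losses = trainer.fit(dataset, loss_fn, n_epochs=1)
trainer.save_checkpoint("toy_decoder.pt")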