cortexflowx-0.1.0-py3-none-any.whl
- cortexflow/__init__.py +78 -0
- cortexflow/_types.py +115 -0
- cortexflow/brain2audio.py +298 -0
- cortexflow/brain2img.py +234 -0
- cortexflow/brain2text.py +278 -0
- cortexflow/brain_encoder.py +228 -0
- cortexflow/dit.py +397 -0
- cortexflow/flow_matching.py +236 -0
- cortexflow/training.py +283 -0
- cortexflow/vae.py +232 -0
- cortexflowx-0.1.0.dist-info/METADATA +218 -0
- cortexflowx-0.1.0.dist-info/RECORD +14 -0
- cortexflowx-0.1.0.dist-info/WHEEL +4 -0
- cortexflowx-0.1.0.dist-info/licenses/LICENSE +190 -0
cortexflow/flow_matching.py
ADDED
@@ -0,0 +1,236 @@
"""Rectified Flow Matching for training and sampling.

Implements the rectified flow framework from Lipman et al. (2022) and
the improvements from Esser et al. (2024, Stable Diffusion 3):

- **Linear interpolation** paths: x_t = (1 - t) · x_0 + t · x_1
- **Velocity prediction**: v = x_1 - x_0 (the model learns the vector
  field that transports noise to data)
- **Logit-normal timestep sampling**: biases training toward perceptually
  relevant noise levels (SD3 improvement over uniform sampling)
- **Euler / Midpoint ODE solvers** for inference

The training loss is simply MSE between predicted and target velocity:
    L = E_{t, x_0, x_1} [ || v_θ(x_t, t, c) - (x_1 - x_0) ||² ]

Reference: "Flow Matching Guide and Code" (arXiv:2412.06264)
"""

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from cortexflow._types import FlowConfig


class RectifiedFlowMatcher:
    """Rectified flow matching training and sampling.

    Usage::

        fm = RectifiedFlowMatcher()

        # Training step
        loss = fm.compute_loss(model, x_clean, brain_global, brain_tokens)
        loss.backward()

        # Sampling
        x_gen = fm.sample(model, shape=(B, C, H, W), brain_global=bg, brain_tokens=bt)
    """

    def __init__(self, config: FlowConfig | None = None) -> None:
        self.config = config or FlowConfig()

    # ── Timestep Sampling ────────────────────────────────────────────

    def sample_timesteps(
        self, batch_size: int, device: torch.device
    ) -> torch.Tensor:
        """Sample timesteps t ∈ (0, 1).

        If ``logit_normal`` is enabled (SD3-style), samples from a
        logit-normal distribution which biases toward intermediate noise
        levels where the perceptual signal is strongest.
        """
        cfg = self.config

        if cfg.logit_normal:
            # Logit-normal: sample u ~ N(mean, std²), then t = sigmoid(u)
            u = torch.randn(batch_size, device=device)
            u = u * cfg.logit_normal_std + cfg.logit_normal_mean
            t = torch.sigmoid(u)
        else:
            t = torch.rand(batch_size, device=device)

        # Clamp away from exact 0 and 1 for numerical stability
        return t.clamp(cfg.sigma_min, 1.0 - cfg.sigma_min)

    # ── Training ─────────────────────────────────────────────────────

    def compute_loss(
        self,
        model: nn.Module,
        x_1: torch.Tensor,
        brain_global: torch.Tensor,
        brain_tokens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute rectified flow matching loss.

        Args:
            model: DiT model that predicts velocity v(x_t, t, condition).
            x_1: Clean data (target) ``(B, C, H, W)``.
            brain_global: Global brain embedding ``(B, D)``.
            brain_tokens: Brain token sequence ``(B, T, D)`` for cross-attention.

        Returns:
            Scalar MSE loss: E[ ||v_pred - v_target||² ]
        """
        B = x_1.shape[0]
        device = x_1.device

        # Sample timesteps
        t = self.sample_timesteps(B, device)

        # Sample noise
        x_0 = torch.randn_like(x_1)

        # Linear interpolation: x_t = (1-t) * noise + t * data
        t_expand = t.view(B, *([1] * (x_1.ndim - 1)))
        x_t = (1.0 - t_expand) * x_0 + t_expand * x_1

        # Target velocity: points from noise to data
        v_target = x_1 - x_0

        # Model prediction
        v_pred = model(x_t, t, brain_global, brain_tokens)

        return F.mse_loss(v_pred, v_target)

    # ── Sampling ─────────────────────────────────────────────────────

    @torch.no_grad()
    def sample(
        self,
        model: nn.Module,
        shape: tuple[int, ...],
        brain_global: torch.Tensor,
        brain_tokens: torch.Tensor | None = None,
        num_steps: int | None = None,
        cfg_scale: float | None = None,
        brain_global_uncond: torch.Tensor | None = None,
        brain_tokens_uncond: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Generate samples by solving the ODE from noise to data.

        Args:
            model: Trained DiT model.
            shape: Output shape ``(B, C, H, W)``.
            brain_global: Global brain condition.
            brain_tokens: Brain token sequence.
            num_steps: Override number of ODE steps.
            cfg_scale: Classifier-free guidance scale. 1.0 = no guidance.
            brain_global_uncond: Unconditional brain embedding for CFG.
            brain_tokens_uncond: Unconditional brain tokens for CFG.

        Returns:
            Generated latent ``(B, C, H, W)``.
        """
        cfg = self.config
        steps = num_steps or cfg.num_steps
        guidance = cfg_scale if cfg_scale is not None else cfg.cfg_scale
        do_cfg = guidance > 1.0 and brain_global_uncond is not None
        device = brain_global.device

        # Start from pure noise
        x = torch.randn(shape, device=device)
        dt = 1.0 / steps

        for i in range(steps):
            t_val = i / steps
            t = torch.full((shape[0],), t_val, device=device)

            if cfg.solver == "midpoint":
                # Midpoint method: evaluate at t + dt/2
                v = self._get_velocity(
                    model, x, t, brain_global, brain_tokens,
                    guidance, do_cfg, brain_global_uncond, brain_tokens_uncond,
                )
                x_mid = x + v * (dt / 2)
                t_mid = torch.full((shape[0],), t_val + dt / 2, device=device)
                v_mid = self._get_velocity(
                    model, x_mid, t_mid, brain_global, brain_tokens,
                    guidance, do_cfg, brain_global_uncond, brain_tokens_uncond,
                )
                x = x + v_mid * dt
            else:
                # Euler method
                v = self._get_velocity(
                    model, x, t, brain_global, brain_tokens,
                    guidance, do_cfg, brain_global_uncond, brain_tokens_uncond,
                )
                x = x + v * dt

        return x

    def _get_velocity(
        self,
        model: nn.Module,
        x: torch.Tensor,
        t: torch.Tensor,
        brain_global: torch.Tensor,
        brain_tokens: torch.Tensor | None,
        guidance: float,
        do_cfg: bool,
        brain_global_uncond: torch.Tensor | None,
        brain_tokens_uncond: torch.Tensor | None,
    ) -> torch.Tensor:
        """Get velocity with optional classifier-free guidance."""
        if do_cfg and brain_global_uncond is not None:
            # Conditional prediction
            v_cond = model(x, t, brain_global, brain_tokens)
            # Unconditional prediction
            v_uncond = model(x, t, brain_global_uncond, brain_tokens_uncond)
            # CFG interpolation
            return v_uncond + guidance * (v_cond - v_uncond)
        else:
            return model(x, t, brain_global, brain_tokens)


class EMAModel:
    """Exponential Moving Average of model parameters.

    Maintains a shadow copy of model weights updated as:
        θ_ema = decay · θ_ema + (1 - decay) · θ_model
    """

    def __init__(self, model: nn.Module, decay: float = 0.9999) -> None:
        self.decay = decay
        self.shadow: dict[str, torch.Tensor] = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        for name, param in model.named_parameters():
            if name in self.shadow:
                self.shadow[name].lerp_(param.data, 1.0 - self.decay)

    def apply_to(self, model: nn.Module) -> dict[str, torch.Tensor]:
        """Replace model params with EMA params. Returns original params."""
        originals: dict[str, torch.Tensor] = {}
        for name, param in model.named_parameters():
            if name in self.shadow:
                originals[name] = param.data.clone()
                param.data.copy_(self.shadow[name])
        return originals

    def restore(self, model: nn.Module, originals: dict[str, torch.Tensor]) -> None:
        """Restore original params after EMA evaluation."""
        for name, param in model.named_parameters():
            if name in originals:
                param.data.copy_(originals[name])
cortexflow/training.py
ADDED
@@ -0,0 +1,283 @@
"""Training utilities for cortexflow models.

Provides a generic training loop, learning rate schedulers, and data
utilities for training brain decoders on fMRI datasets.
"""

from __future__ import annotations

import math
from typing import Any, Callable, Iterator

import torch
import torch.nn as nn

from cortexflow._types import TrainingConfig
from cortexflow.flow_matching import EMAModel


class WarmupCosineScheduler:
    """Learning rate schedule: linear warmup → cosine decay.

    Args:
        optimizer: PyTorch optimizer.
        warmup_steps: Linear warmup duration.
        total_steps: Total training steps.
        min_lr_ratio: Minimum LR as a fraction of peak.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int,
        total_steps: int,
        min_lr_ratio: float = 0.01,
    ) -> None:
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr_ratio = min_lr_ratio
        self.base_lrs = [pg["lr"] for pg in optimizer.param_groups]
        self._step = 0

    def step(self) -> None:
        self._step += 1
        lr_scale = self._get_scale(self._step)
        for pg, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
            pg["lr"] = base_lr * lr_scale

    def _get_scale(self, step: int) -> float:
        if step <= self.warmup_steps:
            return step / max(1, self.warmup_steps)
        progress = (step - self.warmup_steps) / max(
            1, self.total_steps - self.warmup_steps
        )
        progress = min(progress, 1.0)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return self.min_lr_ratio + (1.0 - self.min_lr_ratio) * cosine

    @property
    def current_lr(self) -> float:
        return self.optimizer.param_groups[0]["lr"]


class SyntheticBrainDataset:
    """Synthetic dataset for testing and development.

    Yields (stimulus, fmri) pairs with configurable modality.
    """

    def __init__(
        self,
        n_samples: int = 1000,
        n_voxels: int = 1024,
        modality: str = "image",
        img_size: int = 64,
        n_mels: int = 80,
        audio_len: int = 128,
        max_text_len: int = 64,
    ) -> None:
        self.n_samples = n_samples
        self.n_voxels = n_voxels
        self.modality = modality
        self.img_size = img_size
        self.n_mels = n_mels
        self.audio_len = audio_len
        self.max_text_len = max_text_len

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Generate a single synthetic sample."""
        g = torch.Generator().manual_seed(idx)
        fmri = torch.randn(self.n_voxels, generator=g)

        if self.modality == "image":
            stimulus = torch.rand(3, self.img_size, self.img_size, generator=g)
        elif self.modality == "audio":
            stimulus = torch.randn(self.n_mels, self.audio_len, generator=g).abs()
        elif self.modality == "text":
            # Random byte-level tokens (printable ASCII range)
            tokens = torch.randint(32, 127, (self.max_text_len,), generator=g)
            stimulus = tokens
        else:
            raise ValueError(f"Unknown modality: {self.modality}")

        return {"fmri": fmri, "stimulus": stimulus}

    def to_loader(
        self, batch_size: int = 8, shuffle: bool = True
    ) -> Iterator[dict[str, torch.Tensor]]:
        """Simple batch iterator (no DataLoader dependency)."""
        indices = list(range(self.n_samples))
        if shuffle:
            g = torch.Generator()
            perm = torch.randperm(self.n_samples, generator=g).tolist()
            indices = perm

        for start in range(0, self.n_samples, batch_size):
            batch_indices = indices[start : start + batch_size]
            items = [self[i] for i in batch_indices]
            batch: dict[str, torch.Tensor] = {}
            for key in items[0]:
                batch[key] = torch.stack([item[key] for item in items])
            yield batch


class Trainer:
    """Generic training loop for cortexflow models.

    Handles optimizer setup, LR scheduling, EMA, gradient clipping,
    and logging.

    Args:
        model: A cortexflow model with a ``training_loss`` method.
        config: Training configuration.
        device: Device to train on.
    """

    def __init__(
        self,
        model: nn.Module,
        config: TrainingConfig | None = None,
        device: str | torch.device = "cpu",
    ) -> None:
        self.model = model.to(device)
        self.config = config or TrainingConfig()
        self.device = torch.device(device)

        cfg = self.config
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=cfg.learning_rate,
            weight_decay=cfg.weight_decay,
            betas=(0.9, 0.95),
        )

        self.ema: EMAModel | None = None
        if cfg.ema_decay > 0:
            self.ema = EMAModel(model, decay=cfg.ema_decay)

        self.global_step = 0
        self.scheduler: WarmupCosineScheduler | None = None
        self._log_fn: Callable[[dict[str, Any]], None] | None = None

    def set_logger(self, fn: Callable[[dict[str, Any]], None]) -> None:
        """Set a logging callback: fn({"step": ..., "loss": ..., ...})."""
        self._log_fn = fn

    def _log(self, metrics: dict[str, Any]) -> None:
        if self._log_fn is not None:
            self._log_fn(metrics)

    def train_step(
        self,
        batch: dict[str, torch.Tensor],
        loss_fn: Callable[[nn.Module, dict[str, torch.Tensor]], torch.Tensor],
    ) -> float:
        """Execute a single training step.

        Args:
            batch: Dict of tensors (moved to device automatically).
            loss_fn: Computes loss from (model, batch) → scalar.

        Returns:
            Loss value as float.
        """
        cfg = self.config
        self.model.train()

        # Move data to device
        batch_dev = {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }

        loss = loss_fn(self.model, batch_dev)
        loss.backward()

        # Gradient clipping
        if cfg.grad_clip > 0:
            nn.utils.clip_grad_norm_(self.model.parameters(), cfg.grad_clip)

        self.optimizer.step()
        self.optimizer.zero_grad()

        # EMA update
        if self.ema is not None:
            self.ema.update(self.model)

        # LR schedule
        if self.scheduler is not None:
            self.scheduler.step()

        self.global_step += 1

        # Logging
        if self.global_step % cfg.log_every == 0:
            lr = (
                self.scheduler.current_lr
                if self.scheduler
                else cfg.learning_rate
            )
            self._log({
                "step": self.global_step,
                "loss": loss.item(),
                "lr": lr,
            })

        return loss.item()

    def fit(
        self,
        dataset: SyntheticBrainDataset,
        loss_fn: Callable[[nn.Module, dict[str, torch.Tensor]], torch.Tensor],
        n_epochs: int | None = None,
    ) -> list[float]:
        """Full training loop over a dataset.

        Args:
            dataset: Training data.
            loss_fn: Loss function (model, batch) → scalar.
            n_epochs: Override number of epochs.

        Returns:
            List of per-step losses.
        """
        cfg = self.config
        epochs = n_epochs or cfg.n_epochs
        steps_per_epoch = max(1, len(dataset) // cfg.batch_size)
        total_steps = epochs * steps_per_epoch

        self.scheduler = WarmupCosineScheduler(
            self.optimizer, cfg.warmup_steps, total_steps
        )

        losses: list[float] = []
        for epoch in range(epochs):
            for batch in dataset.to_loader(batch_size=cfg.batch_size, shuffle=True):
                loss = self.train_step(batch, loss_fn)
                losses.append(loss)

        return losses

    def save_checkpoint(self, path: str) -> None:
        """Save model + optimizer + EMA state."""
        state = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "global_step": self.global_step,
        }
        if self.ema is not None:
            state["ema"] = self.ema.shadow
        torch.save(state, path)

    def load_checkpoint(self, path: str) -> None:
        """Load model + optimizer + EMA state."""
        state = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(state["model"])
        self.optimizer.load_state_dict(state["optimizer"])
        self.global_step = state.get("global_step", 0)
        if self.ema is not None and "ema" in state:
            self.ema.shadow = state["ema"]