nested-learning 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. nested_learning/__init__.py +12 -0
  2. nested_learning/__main__.py +12 -0
  3. nested_learning/assoc_memory.py +23 -0
  4. nested_learning/backbones.py +147 -0
  5. nested_learning/capabilities.py +104 -0
  6. nested_learning/cli.py +253 -0
  7. nested_learning/cms.py +92 -0
  8. nested_learning/config_utils.py +50 -0
  9. nested_learning/configs/ablations/cms_sparse.yaml +46 -0
  10. nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
  11. nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
  12. nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
  13. nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
  14. nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
  15. nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
  16. nested_learning/configs/data/continual_segments_sample.yaml +9 -0
  17. nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
  18. nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
  19. nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
  20. nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
  21. nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
  22. nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
  23. nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
  24. nested_learning/configs/deepspeed/zero3.json +25 -0
  25. nested_learning/configs/hope/mid.yaml +118 -0
  26. nested_learning/configs/hope/mid_fsdp.yaml +47 -0
  27. nested_learning/configs/hope/pilot.yaml +2 -0
  28. nested_learning/configs/hope/pilot_attention.yaml +9 -0
  29. nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
  30. nested_learning/configs/hope/pilot_transformer.yaml +9 -0
  31. nested_learning/configs/hope/target.yaml +145 -0
  32. nested_learning/configs/hope/target_fsdp.yaml +47 -0
  33. nested_learning/configs/mid_smoke.yaml +99 -0
  34. nested_learning/configs/mid_stage2.yaml +110 -0
  35. nested_learning/configs/mid_stage2_smoke.yaml +102 -0
  36. nested_learning/configs/mid_titan_baseline.yaml +92 -0
  37. nested_learning/configs/pilot.yaml +127 -0
  38. nested_learning/configs/pilot_paper_faithful.yaml +42 -0
  39. nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
  40. nested_learning/configs/pilot_smoke.yaml +80 -0
  41. nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
  42. nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
  43. nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
  44. nested_learning/continual_classification.py +136 -0
  45. nested_learning/continual_streaming.py +283 -0
  46. nested_learning/data.py +153 -0
  47. nested_learning/device.py +21 -0
  48. nested_learning/eval_state.py +72 -0
  49. nested_learning/fast_state.py +108 -0
  50. nested_learning/functional.py +69 -0
  51. nested_learning/hope/__init__.py +0 -0
  52. nested_learning/hope/block.py +1973 -0
  53. nested_learning/hope/self_mod.py +40 -0
  54. nested_learning/instrumentation.py +38 -0
  55. nested_learning/levels.py +94 -0
  56. nested_learning/logging_utils.py +64 -0
  57. nested_learning/memorize.py +382 -0
  58. nested_learning/model.py +604 -0
  59. nested_learning/optim/__init__.py +0 -0
  60. nested_learning/optim/deep.py +102 -0
  61. nested_learning/optim/factory.py +13 -0
  62. nested_learning/optim/m3.py +121 -0
  63. nested_learning/optim/manager.py +151 -0
  64. nested_learning/titan/__init__.py +0 -0
  65. nested_learning/titan/memory.py +88 -0
  66. nested_learning/titan/model.py +412 -0
  67. nested_learning/titan/self_modifying.py +724 -0
  68. nested_learning/tokenizer.py +28 -0
  69. nested_learning/tokenizer_coverage.py +77 -0
  70. nested_learning/training.py +1600 -0
  71. nested_learning/transformer.py +104 -0
  72. nested_learning-0.2.0.dist-info/METADATA +390 -0
  73. nested_learning-0.2.0.dist-info/RECORD +76 -0
  74. nested_learning-0.2.0.dist-info/WHEEL +4 -0
  75. nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
  76. nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,1600 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import json
5
+ import os
6
+ import pickle
7
+ import random
8
+ from contextlib import nullcontext
9
+ from dataclasses import dataclass
10
+ from hashlib import sha256
11
+ from pathlib import Path
12
+ from typing import Dict, Iterator, Protocol, Tuple, cast
13
+
14
+ import numpy as np
15
+ import torch
16
+ from omegaconf import DictConfig, OmegaConf
17
+ from torch.utils.data import DataLoader, DistributedSampler, IterableDataset
18
+
19
+ from .data import (
20
+ MixtureShardDataset,
21
+ ShardSourceConfig,
22
+ SyntheticTextConfig,
23
+ SyntheticTextDataset,
24
+ TokenShardDataset,
25
+ collate_batch,
26
+ )
27
+ from .levels import LevelSpec
28
+ from .logging_utils import BaseLogger, NullLogger, init_logger
29
+ from .model import HOPEModel, ModelConfig
30
+ from .optim.m3 import M3
31
+ from .titan.model import TitanOnlyModel, TitanOnlyModelConfig
32
+
33
+
34
@dataclass
class DistributedContext:
    """Placement information for one training process.

    A plain value container: no validation is performed on the fields.
    """

    # Index of this process; used as the DistributedSampler rank.
    rank: int
    # Number of participating processes; used as DistributedSampler num_replicas.
    world_size: int
    # Torch device associated with this process.
    device: torch.device
39
+
40
+
41
def unwrap_config(cfg: "DictConfig") -> "DictConfig":
    """Return the config node that holds the run settings.

    Hydra can nest a grouped config (e.g. ``hope/pilot``) under its group name.
    A config that already has a top-level ``model`` key is returned unchanged.
    Otherwise the first present group among ``hope`` and ``ablations`` is
    unwrapped. If neither is present, ``cfg`` itself is returned.
    """
    if "model" in cfg:
        return cfg
    for group_name in ("hope", "ablations"):
        if group_name in cfg:
            return cast("DictConfig", getattr(cfg, group_name))
    return cfg
50
+
51
+
52
def build_model_from_cfg(model_cfg: "DictConfig") -> torch.nn.Module:
    """Instantiate the model described by the ``model`` section of a run config.

    ``model_cfg.type == "titan"`` yields a ``TitanOnlyModel``. Any other value
    (default ``"hope"``) yields a ``HOPEModel``. Optional keys fall back to
    the defaults written inline below. The optional numeric keys
    ``local_conv_window``, ``surprise_threshold``,
    ``self_mod_chunk_size_memory`` and ``self_mod_local_conv_window`` are
    passed through as ``None`` when their value is ``None`` instead of being
    cast.
    """

    def optional(key: str, convert, default=None):
        # Cast the value stored under ``key`` with ``convert`` unless it is None.
        value = model_cfg.get(key, default)
        return None if value is None else convert(value)

    def plain_section(key: str) -> dict:
        # Resolve an optional nested section to a plain container ({} if absent).
        if key not in model_cfg:
            return {}
        return cast(dict, OmegaConf.to_container(model_cfg[key], resolve=True))

    architecture = model_cfg.get("type", "hope")
    optimizers = plain_section("optimizers")
    teach_scale = model_cfg.get("teach_scale", 1.0)
    teach_clip = model_cfg.get("teach_clip", 0.0)
    teach_schedule = plain_section("teach_schedule")
    qk_l2_norm = bool(model_cfg.get("qk_l2_norm", False))
    conv_window = optional("local_conv_window", int)
    threshold = optional("surprise_threshold", float)
    metric = str(model_cfg.get("surprise_metric", "l2"))
    use_cms_layernorm = bool(model_cfg.get("cms_use_layernorm", True))

    if architecture == "titan":
        titan_level = LevelSpec(**model_cfg.titan_level)
        titan_config = TitanOnlyModelConfig(
            vocab_size=model_cfg.vocab_size,
            dim=model_cfg.dim,
            num_layers=model_cfg.num_layers,
            heads=model_cfg.heads,
            titan_level=titan_level,
            optimizers=optimizers,
            teach_scale=teach_scale,
            teach_clip=teach_clip,
            teach_schedule=teach_schedule,
            qk_l2_norm=qk_l2_norm,
            local_conv_window=conv_window,
            surprise_threshold=threshold,
            surprise_metric=metric,
            freeze_backbone=model_cfg.get("freeze_backbone", False),
            self_mod_lr=float(model_cfg.get("self_mod_lr", 1e-3)),
            self_mod_hidden=int(model_cfg.get("self_mod_hidden", 4)),
        )
        return TitanOnlyModel(titan_config)

    # Any non-"titan" type builds the HOPE architecture.
    titan_level = LevelSpec(**model_cfg.titan_level)
    cms_levels = [LevelSpec(**level) for level in model_cfg.cms_levels]
    chunk_size_memory = optional("self_mod_chunk_size_memory", int)
    self_mod_conv_window = optional("self_mod_local_conv_window", int, 4)
    hope_config = ModelConfig(
        vocab_size=model_cfg.vocab_size,
        dim=model_cfg.dim,
        num_layers=model_cfg.num_layers,
        heads=model_cfg.heads,
        titan_level=titan_level,
        cms_levels=cms_levels,
        cms_flush_partial_at_end=bool(model_cfg.get("cms_flush_partial_at_end", False)),
        cms_use_layernorm=use_cms_layernorm,
        optimizers=optimizers,
        teach_scale=teach_scale,
        teach_clip=teach_clip,
        teach_schedule=teach_schedule,
        gradient_checkpointing=model_cfg.get("gradient_checkpointing", False),
        surprise_threshold=threshold,
        surprise_metric=metric,
        freeze_backbone=model_cfg.get("freeze_backbone", False),
        qk_l2_norm=qk_l2_norm,
        local_conv_window=conv_window,
        self_mod_lr=float(model_cfg.get("self_mod_lr", 1e-3)),
        self_mod_hidden=int(model_cfg.get("self_mod_hidden", 4)),
        self_mod_chunk_size=int(model_cfg.get("self_mod_chunk_size", 1)),
        self_mod_chunk_size_memory=chunk_size_memory,
        self_mod_objective=str(model_cfg.get("self_mod_objective", "l2")),
        self_mod_stopgrad_vhat=bool(model_cfg.get("self_mod_stopgrad_vhat", True)),
        self_mod_use_rank1_precond=bool(model_cfg.get("self_mod_use_rank1_precond", True)),
        self_mod_use_alpha=bool(model_cfg.get("self_mod_use_alpha", True)),
        self_mod_use_skip=bool(model_cfg.get("self_mod_use_skip", True)),
        self_mod_momentum=float(model_cfg.get("self_mod_momentum", 0.0)),
        self_mod_adaptive_q=bool(model_cfg.get("self_mod_adaptive_q", False)),
        self_mod_local_conv_window=self_mod_conv_window,
        transformer_mlp_hidden_multiplier=int(
            model_cfg.get("transformer_mlp_hidden_multiplier", 4)
        ),
        transformer_activation=str(model_cfg.get("transformer_activation", "gelu")),
        block_variant=str(model_cfg.get("block_variant", "hope_hybrid")),
    )
    return HOPEModel(hope_config)
146
+
147
+
148
def build_dataloader(
    data_cfg: "DictConfig",
    *,
    distributed: bool,
    dist_ctx: "DistributedContext | None",
    seed: int | None = None,
) -> Tuple[DataLoader, DistributedSampler | None]:
    """Build the training ``DataLoader`` described by ``data_cfg``.

    Sampling and shuffling:
        - Map-style datasets under ``distributed`` get a shuffling
          ``DistributedSampler``; ``dist_ctx`` is required in that case.
        - Other map-style datasets are shuffled by the loader itself.
        - Iterable datasets are never shuffled by the loader.

    When ``seed`` is given, the loader's ``torch.Generator`` and its worker
    init function are both derived from it.

    Returns:
        ``(dataloader, sampler)``, where ``sampler`` is the created
        ``DistributedSampler`` or ``None``.
    """
    dataset = _build_dataset(data_cfg)
    is_iterable = isinstance(dataset, IterableDataset)

    sampler: DistributedSampler | None = None
    if distributed and not is_iterable:
        assert dist_ctx is not None
        sampler = DistributedSampler(
            dataset,
            num_replicas=dist_ctx.world_size,
            rank=dist_ctx.rank,
            shuffle=True,
            drop_last=False,
        )
    # A sampler owns shuffling when present; iterable datasets cannot be shuffled.
    shuffle = sampler is None and not is_iterable

    generator: torch.Generator | None = None
    worker_init_fn = None
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)
        worker_init_fn = _make_worker_init_fn(seed)

    loader = DataLoader(
        dataset,
        batch_size=data_cfg.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        collate_fn=collate_batch,
        num_workers=data_cfg.get("num_workers", 0),
        pin_memory=True,
        worker_init_fn=worker_init_fn,
        generator=generator,
    )
    return loader, sampler
190
+
191
+
192
def _build_dataset(data_cfg: "DictConfig"):
    """Construct the dataset selected by ``data_cfg.source``.

    Supported sources:
        - ``"synthetic"``: ``SyntheticTextDataset``.
        - ``"shards"``: ``TokenShardDataset`` over ``data_cfg.shards_dir``.
        - ``"mixture"``: ``MixtureShardDataset`` over the named shard sources
          (each with a weight) listed in ``data_cfg.mixture.sources``.

    Raises:
        ValueError: if ``data_cfg.source`` is none of the above.
    """
    source = data_cfg.source
    if source == "synthetic":
        return SyntheticTextDataset(
            SyntheticTextConfig(
                vocab_size=data_cfg.vocab_size,
                seq_len=data_cfg.seq_len,
                dataset_size=data_cfg.dataset_size,
            )
        )
    if source == "shards":
        return TokenShardDataset(data_cfg.shards_dir)
    if source == "mixture":
        mixture = data_cfg.mixture
        shard_sources = [
            ShardSourceConfig(name=item.name, shards_dir=item.shards_dir, weight=item.weight)
            for item in mixture.sources
        ]
        return MixtureShardDataset(
            shard_sources,
            samples_per_epoch=mixture.samples_per_epoch,
            seed=mixture.get("seed", 0),
        )
    raise ValueError(f"Unsupported data source {source}")
223
+
224
+
225
+ def compute_teach_signal(
226
+ model: "_HasLMHead",
227
+ logits: torch.Tensor,
228
+ tokens: torch.Tensor,
229
+ *,
230
+ next_tokens: torch.Tensor | None = None,
231
+ ignore_index: int | None = None,
232
+ ) -> torch.Tensor:
233
+ """
234
+ Approximate dL/dh where h is the hidden state before the LM head.
235
+
236
+ This matches the gradient of mean next-token CE.
237
+
238
+ By default this corresponds to CE(logits[:, :-1], tokens[:, 1:]).
239
+ If `next_tokens` is provided, the final logit position is also supervised
240
+ against that boundary target (used for chunked streaming boundaries).
241
+
242
+ If ignore_index is provided, targets equal to ignore_index are masked out and
243
+ the mean reduction denominator becomes the number of active targets (matching
244
+ PyTorch CE semantics).
245
+ """
246
+ logits_detached = logits.detach()
247
+ probs = torch.softmax(logits_detached, dim=-1)
248
+ residual = probs.clone()
249
+ batch_size, seq_len, _ = residual.shape
250
+
251
+ targets = torch.zeros(
252
+ batch_size,
253
+ seq_len,
254
+ device=tokens.device,
255
+ dtype=tokens.dtype,
256
+ )
257
+ active = torch.zeros(
258
+ batch_size,
259
+ seq_len,
260
+ device=tokens.device,
261
+ dtype=torch.bool,
262
+ )
263
+ if seq_len > 1:
264
+ targets[:, :-1] = tokens[:, 1:]
265
+ active[:, :-1] = True
266
+ if next_tokens is not None:
267
+ if next_tokens.ndim == 2 and next_tokens.size(1) == 1:
268
+ next_targets = next_tokens[:, 0]
269
+ elif next_tokens.ndim == 1:
270
+ next_targets = next_tokens
271
+ else:
272
+ raise ValueError("next_tokens must have shape [B] or [B, 1]")
273
+ if next_targets.size(0) != batch_size:
274
+ raise ValueError("next_tokens batch dimension must match tokens batch dimension")
275
+ targets[:, -1] = next_targets.to(device=tokens.device, dtype=tokens.dtype)
276
+ active[:, -1] = True
277
+ if ignore_index is not None:
278
+ active = active & (targets != ignore_index)
279
+
280
+ active_f = active.to(dtype=residual.dtype)
281
+ residual.mul_(active_f.unsqueeze(-1))
282
+ safe_targets = torch.where(active, targets, torch.zeros_like(targets))
283
+ src = -active_f.unsqueeze(-1)
284
+ residual.scatter_add_(-1, safe_targets.unsqueeze(-1), src)
285
+ denom: torch.Tensor = active_f.sum().clamp(min=1.0)
286
+ residual = residual / denom
287
+
288
+ head_weight = model.lm_head.weight.detach()
289
+ if head_weight.dtype != residual.dtype:
290
+ head_weight = head_weight.to(dtype=residual.dtype)
291
+ grad = residual @ head_weight
292
+ return grad
293
+
294
+
295
def _compute_layer_teach_signals(
    loss: torch.Tensor,
    block_outputs: list[torch.Tensor],
    *,
    detach: bool = True,
    create_graph: bool = False,
) -> list[torch.Tensor]:
    """Return ``d(loss)/d(output)`` for each tensor in ``block_outputs``.

    The autograd graph is retained (``retain_graph=True``), so ``loss`` can
    still be backpropagated afterwards. Every output must contribute to
    ``loss`` (``allow_unused=False``); otherwise autograd raises.

    Args:
        detach: Detach each gradient from the graph before returning it.
        create_graph: Forwarded to ``torch.autograd.grad`` so the gradients
            themselves are differentiable.
    """
    gradients = torch.autograd.grad(
        outputs=loss,
        inputs=block_outputs,
        retain_graph=True,
        create_graph=create_graph,
        allow_unused=False,
    )
    return [grad.detach() if detach else grad for grad in gradients]
312
+
313
+
314
+ def _compute_surprise_override(
315
+ metric: str,
316
+ *,
317
+ logits: torch.Tensor,
318
+ tokens: torch.Tensor,
319
+ loss: torch.Tensor,
320
+ next_tokens: torch.Tensor | None = None,
321
+ ) -> float | None:
322
+ normalized = str(metric).strip().lower()
323
+ if normalized == "loss":
324
+ return float(loss.detach().item())
325
+ if normalized == "logit_entropy":
326
+ supervised_steps = int(tokens.size(1) - 1 + (0 if next_tokens is None else 1))
327
+ if supervised_steps <= 0:
328
+ return None
329
+ logits_detached = logits[:, :supervised_steps].detach().float()
330
+ probs = torch.softmax(logits_detached, dim=-1)
331
+ entropy = -(probs * torch.log(probs.clamp(min=1e-9))).sum(dim=-1).mean()
332
+ return float(entropy.item())
333
+ return None
334
+
335
+
336
+ def _infer_online_chunk_size(model: HOPEModel) -> int | None:
337
+ min_period: int | None = None
338
+ blocks = getattr(model, "blocks", [])
339
+ for block in blocks:
340
+ cfg = getattr(block, "config", None)
341
+ levels = getattr(cfg, "cms_levels", None)
342
+ if not levels:
343
+ continue
344
+ for spec in levels:
345
+ period = int(spec.update_period)
346
+ if period <= 0:
347
+ continue
348
+ min_period = period if min_period is None else min(min_period, period)
349
+ return min_period
350
+
351
+
352
def _iter_online_token_chunks(
    tokens: torch.Tensor, *, chunk_size: int
) -> Iterator[tuple[torch.Tensor, bool]]:
    """Split ``tokens`` along dim 1 into chunks that overlap by one token.

    Every chunk after the first starts one position before its core span.
    This keeps the pair (last token of the previous core, first token of
    this core) supervised, so the chunks together cover all ``T - 1``
    next-token pairs.

    Yields:
        ``(chunk_tokens, finalize_updates)``, where ``finalize_updates`` is
        true only for the chunk that reaches the end of the sequence.

    Raises:
        ValueError: if ``chunk_size < 1`` (on first iteration).
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")
    total = tokens.size(1)
    for begin in range(0, total, chunk_size):
        stop = min(begin + chunk_size, total)
        lead = max(begin - 1, 0)
        yield tokens[:, lead:stop], stop >= total
367
+
368
+
369
+ def _iter_online_boundary_chunks(
370
+ tokens: torch.Tensor, *, chunk_size: int
371
+ ) -> Iterator[tuple[torch.Tensor, torch.Tensor | None, bool]]:
372
+ """
373
+ Yield non-overlapping chunks plus the boundary target token for chunk end.
374
+
375
+ This enables exact boundary supervision without one-token overlap.
376
+ """
377
+ if chunk_size < 1:
378
+ raise ValueError("chunk_size must be >= 1")
379
+ seq_len = tokens.size(1)
380
+ for start in range(0, seq_len, chunk_size):
381
+ end = min(start + chunk_size, seq_len)
382
+ if end <= start:
383
+ continue
384
+ next_tokens = None
385
+ if end < seq_len:
386
+ next_tokens = tokens[:, end]
387
+ finalize_updates = end >= seq_len
388
+ yield tokens[:, start:end], next_tokens, finalize_updates
389
+
390
+
391
class _HasLMHead(Protocol):
    """Structural type for models that expose an ``lm_head`` projection.

    Used to type the ``model`` argument of ``compute_teach_signal``, which
    reads ``lm_head.weight``.
    """

    # Output projection applied to the final hidden state.
    lm_head: torch.nn.Linear
393
+
394
+
395
+ def _checksum_path(path: str | None) -> str | None:
396
+ if not path:
397
+ return None
398
+ candidate = Path(path)
399
+ if not candidate.exists() or not candidate.is_file():
400
+ return None
401
+ digest = sha256()
402
+ with candidate.open("rb") as handle:
403
+ for chunk in iter(lambda: handle.read(1 << 20), b""):
404
+ digest.update(chunk)
405
+ return digest.hexdigest()
406
+
407
+
408
def maybe_save_checkpoint(
    cfg: "DictConfig",
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    *,
    step: int,
    total_steps: int,
    distributed: bool,
    dist_ctx: "DistributedContext | None",
    step_offset: int = 0,
) -> None:
    """Save a checkpoint for ``step`` when ``cfg.train.checkpoint`` asks for one.

    Nothing happens unless ``cfg.train.checkpoint.enable`` is true. Under
    distributed training, only rank 0 writes.

    A save happens when either condition holds:
        - ``step + 1`` is a multiple of ``save_interval`` (default, or an
          explicit ``None``: ``total_steps``).
        - ``save_last`` (default true) is set and this is the final step.

    The checkpoint goes to ``<dir>/step_<global_step:06d>.pt`` with
    ``global_step = step + 1 + step_offset``. It is first written to a
    ``.tmp`` sibling and then moved into place with ``os.replace``. If
    writing fails, the temporary file is removed before the error
    propagates.

    The payload holds the model and optimizer state dicts, the local
    ``step + 1`` and the resolved config. ``write_checkpoint_metadata`` is
    then called for the saved path.
    """
    ckpt_cfg = cfg.train.get("checkpoint")
    if not ckpt_cfg or not ckpt_cfg.get("enable", False):
        return
    if distributed and dist_ctx is not None and dist_ctx.rank != 0:
        return
    save_interval = ckpt_cfg.get("save_interval", total_steps)
    if save_interval is None:
        # An explicit null would otherwise crash max(1, None); treat it as the default.
        save_interval = total_steps
    save_last = ckpt_cfg.get("save_last", True)
    completed = step + 1
    is_last_step = completed >= total_steps
    should_save = (completed % max(1, save_interval) == 0) or (save_last and is_last_step)
    if not should_save:
        return
    ckpt_dir = Path(ckpt_cfg.get("dir", "checkpoints/default"))
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    global_step = completed + int(step_offset)
    ckpt_path = ckpt_dir / f"step_{global_step:06d}.pt"
    tmp_path = ckpt_path.with_suffix(".tmp")
    resolved_cfg = OmegaConf.to_container(cfg, resolve=True)
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": completed,
        "config": resolved_cfg,
    }
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, ckpt_path)
    except BaseException:
        # Don't leave a stale partial file behind (disk full, interrupt, ...).
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass  # best effort only; the original failure is re-raised below
        raise
    write_checkpoint_metadata(cfg, ckpt_path, global_step)
    prefix = "[checkpoint]"
    if distributed and dist_ctx is not None:
        prefix = f"[checkpoint rank={dist_ctx.rank}]"
    print(f"{prefix} saved {ckpt_path} (global_step={global_step})")
449
+
450
+
451
def _validate_distributed_config(cfg: "DictConfig", distributed: bool) -> None:
    """Reject streaming features that DDP cannot honor when configured to fail hard.

    Only active when ``distributed`` is true and either
    ``train.strict_streaming_contract`` or
    ``train.fail_if_paper_faithful_disabled`` is set. In that case, the first
    enabled flag among ``per_layer_teach_signal``, ``online_updates``,
    ``online_boundary_targets`` and ``online_carry_attention_cache`` raises
    ``RuntimeError``. Otherwise the function returns without effect.
    """
    if not distributed:
        return
    train_cfg = cfg.train
    strict = bool(train_cfg.get("strict_streaming_contract", False))
    fail_if_faithful_disabled = bool(train_cfg.get("fail_if_paper_faithful_disabled", False))
    if not (strict or fail_if_faithful_disabled):
        return
    unsupported_under_ddp = (
        "per_layer_teach_signal",
        "online_updates",
        "online_boundary_targets",
        "online_carry_attention_cache",
    )
    for flag in unsupported_under_ddp:
        if bool(train_cfg.get(flag, False)):
            raise RuntimeError(
                f"train.{flag}=true is not supported under DDP in this repo. "
                "Set train.strict_streaming_contract=false and "
                "train.fail_if_paper_faithful_disabled=false to allow the fallback, "
                "or run single-process training."
            )
487
+
488
+
489
+ def _emit_streaming_warning(
490
+ *,
491
+ code: str,
492
+ message: str,
493
+ details: dict[str, object] | None = None,
494
+ ) -> None:
495
+ payload: dict[str, object] = {"warning_code": code, "message": message}
496
+ if details:
497
+ payload["details"] = details
498
+ print(f"[train.warn] {json.dumps(payload, sort_keys=True)}")
499
+
500
+
501
def _validate_paper_auditing_variant(cfg: "DictConfig") -> None:
    """Check that ``model.block_variant`` is a paper-defined HOPE variant.

    The variant is compared after stripping and lower-casing. An empty or
    unset variant is accepted. Any variant other than ``hope_attention`` or
    ``hope_selfmod``:
        - raises ``RuntimeError`` under ``train.strict_streaming_contract``;
        - otherwise emits a ``non_paper_variant`` streaming warning.
    """
    allowed = {"hope_attention", "hope_selfmod"}
    strict = bool(cfg.train.get("strict_streaming_contract", False))
    variant = str(cfg.model.get("block_variant", "")).strip().lower()
    # NOTE(review): build_model_from_cfg defaults block_variant to
    # "hope_hybrid" when it is unset, yet an unset value is accepted here.
    # Confirm this is intended (e.g. for titan models that have no variant).
    if not variant or variant in allowed:
        return
    msg = (
        "strict streaming contract expects a paper-defined HOPE variant "
        f"({sorted(allowed)}), got model.block_variant={variant!r}"
    )
    if strict:
        raise RuntimeError(msg)
    _emit_streaming_warning(
        code="non_paper_variant",
        message=msg,
        details={"block_variant": variant},
    )
520
+
521
+
522
def _validate_tied_lm_head_for_paper_auditing(
    cfg: "DictConfig",
    model: torch.nn.Module,
) -> None:
    """Require ``lm_head.weight`` to alias ``embed.weight`` in paper-auditing mode.

    Enforced only when ``train.strict_streaming_contract`` or
    ``train.fail_if_paper_faithful_disabled`` is set. Models missing either
    module, or either module's ``weight``, are not checked.

    Raises:
        RuntimeError: if both weights exist but have different data pointers.
    """
    train_cfg = cfg.train
    strict = bool(train_cfg.get("strict_streaming_contract", False))
    fail_if_faithful_disabled = bool(train_cfg.get("fail_if_paper_faithful_disabled", False))
    if not strict and not fail_if_faithful_disabled:
        return
    lm_weight, emb_weight = (
        getattr(getattr(model, attr, None), "weight", None) for attr in ("lm_head", "embed")
    )
    if lm_weight is None or emb_weight is None:
        return
    if lm_weight.data_ptr() != emb_weight.data_ptr():
        raise RuntimeError(
            "paper-auditing mode requires tied LM head and embedding weights "
            "(lm_head.weight must alias embed.weight)."
        )
544
+
545
+
546
def _validate_fast_state_batch_semantics(cfg: "DictConfig") -> None:
    """Flag ``train.use_fast_state`` combined with ``data.batch_size > 1``.

    Per the message below, fast state is shared across the batch, so a batch
    size above one weakens per-context semantics. Under
    ``train.strict_streaming_contract`` or
    ``train.fail_if_paper_faithful_disabled`` this raises ``RuntimeError``;
    otherwise a ``shared_fast_state_batch`` warning is emitted.

    The check is skipped when fast state is off, when there is no ``data``
    section, or when ``batch_size`` cannot be converted to ``int``.
    """
    if not bool(cfg.train.get("use_fast_state", False)):
        return
    data_cfg = cfg.get("data")
    if data_cfg is None:
        return
    raw_batch_size = data_cfg.get("batch_size", 1)
    try:
        batch_size = int(raw_batch_size)
    except (TypeError, ValueError):
        return
    if batch_size <= 1:
        return
    msg = (
        "train.use_fast_state=true currently shares CMS/TITAN fast state across the batch. "
        "For strict per-context semantics, set data.batch_size=1."
    )
    train_cfg = cfg.train
    strict = bool(train_cfg.get("strict_streaming_contract", False))
    fail_if_faithful_disabled = bool(train_cfg.get("fail_if_paper_faithful_disabled", False))
    if strict or fail_if_faithful_disabled:
        raise RuntimeError(msg)
    _emit_streaming_warning(
        code="shared_fast_state_batch",
        message=msg,
        details={"batch_size": batch_size},
    )
572
+
573
+
574
def _validate_online_update_fast_state_semantics(cfg: "DictConfig") -> None:
    """Flag ``train.online_updates`` without ``train.use_fast_state``.

    Raises ``RuntimeError`` under ``train.strict_streaming_contract`` or
    ``train.fail_if_paper_faithful_disabled``; otherwise emits an
    ``online_updates_without_fast_state`` warning. Does nothing when there is
    no ``train`` section, when online updates are off, or when fast state
    is on.
    """
    train_cfg = cfg.get("train")
    if train_cfg is None:
        return
    online = bool(train_cfg.get("online_updates", False))
    fast = bool(train_cfg.get("use_fast_state", False))
    if fast or not online:
        return
    msg = (
        "train.online_updates=true with train.use_fast_state=false applies online writes "
        "directly to base parameters within each step. This can make gradients across chunks "
        "harder to interpret. Use train.use_fast_state=true for paper-faithful runs."
    )
    strict = bool(train_cfg.get("strict_streaming_contract", False))
    fail_if_faithful_disabled = bool(train_cfg.get("fail_if_paper_faithful_disabled", False))
    if strict or fail_if_faithful_disabled:
        raise RuntimeError(msg)
    _emit_streaming_warning(
        code="online_updates_without_fast_state",
        message=msg,
        details={"online_updates": True, "use_fast_state": False},
    )
596
+
597
+
598
def _resolve_algorithm_mode(cfg: "DictConfig") -> str:
    """Return the whitespace-stripped, validated ``train.algorithm_mode``.

    Defaults to ``"two_pass_stopgrad_updates"`` when unset.

    Raises:
        RuntimeError: if the mode is not a supported value.
    """
    requested = str(cfg.train.get("algorithm_mode", "two_pass_stopgrad_updates")).strip()
    supported = frozenset({"two_pass_stopgrad_updates", "boundary_state_grad_through_write"})
    if requested in supported:
        return requested
    raise RuntimeError(
        f"Unsupported train.algorithm_mode={requested!r}; allowed={sorted(supported)}"
    )
604
+
605
+
606
def _validate_algorithm_mode_constraints(
    cfg: "DictConfig",
    *,
    algorithm_mode: str,
    distributed: bool,
) -> None:
    """Check the prerequisites of ``boundary_state_grad_through_write`` mode.

    Other modes are not checked. For this mode:
        - DDP is rejected.
        - ``online_updates``, ``per_layer_teach_signal`` and
          ``use_fast_state`` must all be true (checked in that order).
        - ``online_carry_attention_cache`` additionally requires
          ``online_boundary_targets``.

    A violation raises ``RuntimeError``. On success an
    ``experimental_boundary_state_mode`` warning is emitted.
    """
    if algorithm_mode != "boundary_state_grad_through_write":
        return
    if distributed:
        raise RuntimeError(
            "train.algorithm_mode='boundary_state_grad_through_write' is not supported in DDP."
        )
    train_cfg = cfg.train
    for required_flag in ("online_updates", "per_layer_teach_signal", "use_fast_state"):
        if not bool(train_cfg.get(required_flag, False)):
            raise RuntimeError(
                "train.algorithm_mode='boundary_state_grad_through_write' requires "
                f"train.{required_flag}=true."
            )
    carries_cache = bool(train_cfg.get("online_carry_attention_cache", False))
    if carries_cache and not bool(train_cfg.get("online_boundary_targets", False)):
        raise RuntimeError(
            "online_carry_attention_cache=true requires train.online_boundary_targets=true "
            "(non-overlap chunking)."
        )
    _emit_streaming_warning(
        code="experimental_boundary_state_mode",
        message=(
            "train.algorithm_mode='boundary_state_grad_through_write' is an experimental "
            "single-process path for mechanism probing and may use more memory."
        ),
        details={"algorithm_mode": algorithm_mode},
    )
648
+
649
+
650
def _validate_online_chunking_constraints(cfg: "DictConfig") -> None:
    """Check the prerequisites of ``train.online_carry_attention_cache``.

    When it is true, ``train.online_updates`` and
    ``train.online_boundary_targets`` must both be true (checked in that
    order); otherwise ``RuntimeError`` is raised.
    """
    train_cfg = cfg.train
    online_updates = bool(train_cfg.get("online_updates", False))
    boundary_targets = bool(train_cfg.get("online_boundary_targets", False))
    carry_cache = bool(train_cfg.get("online_carry_attention_cache", False))
    if not carry_cache:
        return
    if not online_updates:
        raise RuntimeError("online_carry_attention_cache=true requires train.online_updates=true")
    if not boundary_targets:
        raise RuntimeError(
            "online_carry_attention_cache=true requires train.online_boundary_targets=true "
            "(non-overlap chunking)."
        )
661
+
662
+
663
def _check_online_supervised_pairs(
    *,
    strict: bool,
    observed_pairs: int,
    seq_len: int,
) -> None:
    """Verify that online chunking supervised exactly ``seq_len - 1`` token pairs.

    The expected count is floored at 0. A mismatch raises ``RuntimeError``
    when ``strict`` is true; otherwise it emits an
    ``online_supervision_mismatch`` warning.
    """
    expected_pairs = max(int(seq_len) - 1, 0)
    if observed_pairs != expected_pairs:
        msg = (
            "online chunk supervision mismatch: observed pair coverage does not match sequence length "
            f"(observed_pairs={observed_pairs}, expected_pairs={expected_pairs})"
        )
        if strict:
            raise RuntimeError(msg)
        _emit_streaming_warning(
            code="online_supervision_mismatch",
            message=msg,
            details={"observed_pairs": observed_pairs, "expected_pairs": expected_pairs},
        )
683
+
684
+
685
+ def run_training_loop(
686
+ cfg: DictConfig,
687
+ *,
688
+ device: torch.device,
689
+ distributed: bool = False,
690
+ dist_ctx: DistributedContext | None = None,
691
+ ) -> Dict[str, float]:
692
+ algorithm_mode = _resolve_algorithm_mode(cfg)
693
+ _validate_algorithm_mode_constraints(
694
+ cfg,
695
+ algorithm_mode=algorithm_mode,
696
+ distributed=distributed,
697
+ )
698
+ _validate_online_chunking_constraints(cfg)
699
+ _validate_distributed_config(cfg, distributed)
700
+ _validate_paper_auditing_variant(cfg)
701
+ _validate_fast_state_batch_semantics(cfg)
702
+ _validate_online_update_fast_state_semantics(cfg)
703
+ model = build_model_from_cfg(cfg.model).to(device)
704
+ train_seed = cfg.train.get("seed")
705
+ deterministic = cfg.train.get("deterministic", False)
706
+ if train_seed is not None:
707
+ _seed_everything(int(train_seed), deterministic=bool(deterministic))
708
+ model = _maybe_compile_model(model, cfg.train.get("compile"))
709
+ if distributed:
710
+ assert dist_ctx is not None
711
+ if device.type == "cuda":
712
+ idx = device.index if device.index is not None else 0
713
+ model = torch.nn.parallel.DistributedDataParallel(
714
+ model,
715
+ device_ids=[idx],
716
+ output_device=idx,
717
+ find_unused_parameters=True,
718
+ )
719
+ else:
720
+ model = torch.nn.parallel.DistributedDataParallel(
721
+ model,
722
+ find_unused_parameters=True,
723
+ )
724
+ base_model = model.module
725
+ else:
726
+ base_model = model
727
+
728
+ _validate_tied_lm_head_for_paper_auditing(cfg, base_model)
729
+
730
+ seed_offset = 0
731
+ if train_seed is not None and dist_ctx is not None:
732
+ seed_offset = dist_ctx.rank
733
+ dataloader_seed = None if train_seed is None else int(train_seed) + seed_offset
734
+ dataloader, sampler = build_dataloader(
735
+ cfg.data,
736
+ distributed=distributed,
737
+ dist_ctx=dist_ctx,
738
+ seed=dataloader_seed,
739
+ )
740
+ optimizer = _build_optimizer(base_model, cfg, device=device)
741
+ autocast_factory = _make_autocast_factory(device, cfg.train.get("mixed_precision"))
742
+ logger = init_logger(getattr(cfg, "logging", None), cfg)
743
+ if distributed and dist_ctx is not None and dist_ctx.rank != 0:
744
+ logger = NullLogger()
745
+ _log_run_features(logger, base_model, cfg, optimizer, device)
746
+ steps = cfg.train.steps
747
+ log_interval = cfg.train.get("log_interval", 1)
748
+ per_layer_teach = bool(cfg.train.get("per_layer_teach_signal", False))
749
+ online_updates = bool(cfg.train.get("online_updates", False))
750
+ online_chunk_size = int(cfg.train.get("online_chunk_size", 0) or 0)
751
+ online_boundary_targets = bool(cfg.train.get("online_boundary_targets", False))
752
+ online_carry_attention_cache = bool(cfg.train.get("online_carry_attention_cache", False))
753
+ use_fast_state = bool(cfg.train.get("use_fast_state", False))
754
+ fail_if_faithful_disabled = bool(cfg.train.get("fail_if_paper_faithful_disabled", False))
755
+ strict_streaming = bool(cfg.train.get("strict_streaming_contract", False))
756
+ if distributed and per_layer_teach:
757
+ msg = "per_layer_teach_signal disabled under DDP (uses base model methods)"
758
+ if fail_if_faithful_disabled or strict_streaming:
759
+ raise RuntimeError(
760
+ f"{msg}. Set train.strict_streaming_contract=false and "
761
+ "train.fail_if_paper_faithful_disabled=false to allow the fallback, "
762
+ "or run single-process training."
763
+ )
764
+ _emit_streaming_warning(
765
+ code="ddp_disables_per_layer_teach",
766
+ message=msg,
767
+ details={"distributed": True},
768
+ )
769
+ per_layer_teach = False
770
+ if distributed and online_updates:
771
+ msg = "online_updates disabled under DDP (uses base model methods)"
772
+ if fail_if_faithful_disabled or strict_streaming:
773
+ raise RuntimeError(
774
+ f"{msg}. Set train.strict_streaming_contract=false and "
775
+ "train.fail_if_paper_faithful_disabled=false to allow the fallback, "
776
+ "or run single-process training."
777
+ )
778
+ _emit_streaming_warning(
779
+ code="ddp_disables_online_updates",
780
+ message=msg,
781
+ details={"distributed": True},
782
+ )
783
+ online_updates = False
784
+ if online_boundary_targets and not online_updates:
785
+ msg = "online_boundary_targets=true requires train.online_updates=true"
786
+ if fail_if_faithful_disabled or strict_streaming:
787
+ raise RuntimeError(msg)
788
+ _emit_streaming_warning(
789
+ code="boundary_targets_without_online_updates",
790
+ message=msg,
791
+ )
792
+ online_boundary_targets = False
793
+ if online_carry_attention_cache and not online_updates:
794
+ raise RuntimeError("online_carry_attention_cache=true requires train.online_updates=true")
795
+ if online_carry_attention_cache and not online_boundary_targets:
796
+ raise RuntimeError(
797
+ "online_carry_attention_cache=true requires train.online_boundary_targets=true "
798
+ "(non-overlap chunking)."
799
+ )
800
+ step_iter = iter(dataloader)
801
+ epoch = 0
802
+ metrics: Dict[str, float] = {}
803
+ surprise_metric_getter = getattr(base_model, "get_surprise_metric", None)
804
+ surprise_metric = (
805
+ str(surprise_metric_getter()).strip().lower()
806
+ if callable(surprise_metric_getter)
807
+ else str(cfg.model.get("surprise_metric", "l2")).strip().lower()
808
+ )
809
+ for step in range(steps):
810
+ if sampler is not None and step % len(dataloader) == 0:
811
+ sampler.set_epoch(epoch)
812
+ epoch += 1
813
+ try:
814
+ batch = next(step_iter)
815
+ except StopIteration:
816
+ step_iter = iter(dataloader)
817
+ batch = next(step_iter)
818
+ tokens = batch.to(device)
819
+ fast_state = None
820
+ if use_fast_state:
821
+ init_fn = getattr(base_model, "init_fast_state", None)
822
+ if not callable(init_fn):
823
+ raise ValueError("train.use_fast_state=true requires model.init_fast_state()")
824
+ fast_state = init_fn()
825
+ _apply_teach_schedule(base_model, cfg, step)
826
+ update_metrics: Dict[str, float] = {}
827
+ if online_updates and hasattr(base_model, "forward_with_block_outputs"):
828
+ total_loss = 0.0
829
+ total_tokens = 0
830
+ teach_signal_norm = 0.0
831
+ optimizer.zero_grad()
832
+ chunk_size = online_chunk_size
833
+ if chunk_size <= 0:
834
+ inferred = _infer_online_chunk_size(base_model)
835
+ chunk_size = inferred if inferred is not None else tokens.size(1)
836
+ if chunk_size < 1:
837
+ print(f"[train] online_chunk_size={chunk_size} is too small; clamping to 1")
838
+ chunk_size = 1
839
+ attention_cache = None
840
+ if online_carry_attention_cache:
841
+ init_attention_cache = getattr(base_model, "init_attention_cache", None)
842
+ if not callable(init_attention_cache):
843
+ raise RuntimeError(
844
+ "online_carry_attention_cache=true requires model.init_attention_cache()"
845
+ )
846
+ attention_cache = init_attention_cache()
847
+
848
+ chunk_iter: Iterator[tuple[torch.Tensor, torch.Tensor | None, bool]]
849
+ if online_boundary_targets:
850
+ chunk_iter = _iter_online_boundary_chunks(tokens, chunk_size=chunk_size)
851
+ else:
852
+ chunk_iter = (
853
+ (chunk, None, finalize_updates)
854
+ for chunk, finalize_updates in _iter_online_token_chunks(
855
+ tokens, chunk_size=chunk_size
856
+ )
857
+ )
858
+ for chunk_tokens, next_tokens, finalize_updates in chunk_iter:
859
+ target_count = chunk_tokens.size(1) - 1 + (0 if next_tokens is None else 1)
860
+ if target_count <= 0:
861
+ continue
862
+ chunk_attention_cache = attention_cache
863
+ with autocast_factory():
864
+ if attention_cache is not None:
865
+ logits, _pre, block_outputs, attention_cache = (
866
+ base_model.forward_with_block_outputs(
867
+ chunk_tokens,
868
+ fast_state=fast_state,
869
+ attention_cache=chunk_attention_cache,
870
+ return_attention_cache=True,
871
+ )
872
+ )
873
+ else:
874
+ logits, _pre, block_outputs = (
875
+ base_model.forward_with_block_outputs(
876
+ chunk_tokens,
877
+ fast_state=fast_state,
878
+ )
879
+ if fast_state is not None
880
+ else base_model.forward_with_block_outputs(chunk_tokens)
881
+ )
882
+ if next_tokens is None:
883
+ loss = torch.nn.functional.cross_entropy(
884
+ logits[:, :-1].reshape(-1, logits.size(-1)),
885
+ chunk_tokens[:, 1:].reshape(-1),
886
+ )
887
+ else:
888
+ boundary_targets = torch.cat(
889
+ [chunk_tokens[:, 1:], next_tokens.unsqueeze(1)],
890
+ dim=1,
891
+ )
892
+ loss = torch.nn.functional.cross_entropy(
893
+ logits[:, : boundary_targets.size(1), :].reshape(-1, logits.size(-1)),
894
+ boundary_targets.reshape(-1),
895
+ )
896
+ surprise_override = _compute_surprise_override(
897
+ surprise_metric,
898
+ logits=logits,
899
+ tokens=chunk_tokens,
900
+ loss=loss,
901
+ next_tokens=next_tokens,
902
+ )
903
+ if per_layer_teach:
904
+ differentiable_updates = algorithm_mode == "boundary_state_grad_through_write"
905
+ teach_signals = _compute_layer_teach_signals(
906
+ loss,
907
+ block_outputs,
908
+ detach=not differentiable_updates,
909
+ create_graph=differentiable_updates,
910
+ )
911
+ mean_teach_norm = torch.stack(
912
+ [sig.detach().norm(dim=-1).mean() for sig in teach_signals]
913
+ ).mean()
914
+ teach_signal_norm += float(
915
+ mean_teach_norm
916
+ ) * target_count
917
+ else:
918
+ teach_signal = compute_teach_signal(
919
+ base_model,
920
+ logits,
921
+ chunk_tokens,
922
+ next_tokens=next_tokens,
923
+ )
924
+ teach_signal_norm += teach_signal.norm(dim=-1).mean().item() * target_count
925
+ differentiable_updates = algorithm_mode == "boundary_state_grad_through_write"
926
+ # Boundary-state mode keeps a cross-chunk differentiable write path.
927
+ # Retain the graph so later chunks can backprop through earlier writes.
928
+ loss.backward(retain_graph=differentiable_updates)
929
+ if differentiable_updates:
930
+ if per_layer_teach:
931
+ base_model(
932
+ chunk_tokens,
933
+ teach_signals=teach_signals,
934
+ surprise_value=surprise_override,
935
+ fast_state=fast_state,
936
+ finalize_updates=finalize_updates,
937
+ attention_cache=chunk_attention_cache,
938
+ differentiable_updates=True,
939
+ )
940
+ else:
941
+ base_model(
942
+ chunk_tokens,
943
+ teach_signal=teach_signal,
944
+ surprise_value=surprise_override,
945
+ fast_state=fast_state,
946
+ finalize_updates=finalize_updates,
947
+ attention_cache=chunk_attention_cache,
948
+ differentiable_updates=True,
949
+ )
950
+ if hasattr(base_model, "pop_update_metrics"):
951
+ update_metrics = base_model.pop_update_metrics()
952
+ else:
953
+ with torch.no_grad():
954
+ if per_layer_teach:
955
+ base_model(
956
+ chunk_tokens,
957
+ teach_signals=teach_signals,
958
+ surprise_value=surprise_override,
959
+ fast_state=fast_state,
960
+ finalize_updates=finalize_updates,
961
+ attention_cache=chunk_attention_cache,
962
+ differentiable_updates=False,
963
+ )
964
+ else:
965
+ base_model(
966
+ chunk_tokens,
967
+ teach_signal=teach_signal,
968
+ surprise_value=surprise_override,
969
+ fast_state=fast_state,
970
+ finalize_updates=finalize_updates,
971
+ attention_cache=chunk_attention_cache,
972
+ differentiable_updates=False,
973
+ )
974
+ if hasattr(base_model, "pop_update_metrics"):
975
+ update_metrics = base_model.pop_update_metrics()
976
+ total_loss += loss.item() * target_count
977
+ total_tokens += target_count
978
+ _check_online_supervised_pairs(
979
+ strict=strict_streaming,
980
+ observed_pairs=total_tokens,
981
+ seq_len=int(tokens.size(1)),
982
+ )
983
+ torch.nn.utils.clip_grad_norm_(base_model.parameters(), max_norm=1.0)
984
+ optimizer.step()
985
+ loss = torch.tensor(total_loss / max(total_tokens, 1), device=device)
986
+ teach_signal_norm = teach_signal_norm / max(total_tokens, 1)
987
+ else:
988
+ with autocast_factory():
989
+ if per_layer_teach and hasattr(base_model, "forward_with_block_outputs"):
990
+ logits, _pre, block_outputs = (
991
+ base_model.forward_with_block_outputs(tokens, fast_state=fast_state)
992
+ if fast_state is not None
993
+ else base_model.forward_with_block_outputs(tokens)
994
+ )
995
+ loss = torch.nn.functional.cross_entropy(
996
+ logits[:, :-1].reshape(-1, logits.size(-1)),
997
+ tokens[:, 1:].reshape(-1),
998
+ )
999
+ else:
1000
+ if fast_state is not None:
1001
+ logits = model(tokens, fast_state=fast_state)
1002
+ else:
1003
+ logits = model(tokens)
1004
+ loss = torch.nn.functional.cross_entropy(
1005
+ logits[:, :-1].reshape(-1, logits.size(-1)),
1006
+ tokens[:, 1:].reshape(-1),
1007
+ )
1008
+ surprise_override = _compute_surprise_override(
1009
+ surprise_metric,
1010
+ logits=logits,
1011
+ tokens=tokens,
1012
+ loss=loss,
1013
+ next_tokens=None,
1014
+ )
1015
+ optimizer.zero_grad()
1016
+ if per_layer_teach and hasattr(base_model, "forward_with_block_outputs"):
1017
+ teach_signals = _compute_layer_teach_signals(loss, block_outputs)
1018
+ loss.backward()
1019
+ torch.nn.utils.clip_grad_norm_(base_model.parameters(), max_norm=1.0)
1020
+ optimizer.step()
1021
+ with torch.no_grad():
1022
+ if per_layer_teach and hasattr(base_model, "forward_with_block_outputs"):
1023
+ teach_signal_norm = float(
1024
+ torch.stack([sig.norm(dim=-1).mean() for sig in teach_signals]).mean()
1025
+ )
1026
+ base_model(
1027
+ tokens,
1028
+ teach_signals=teach_signals,
1029
+ surprise_value=surprise_override,
1030
+ fast_state=fast_state,
1031
+ )
1032
+ else:
1033
+ teach_signal = compute_teach_signal(base_model, logits, tokens)
1034
+ teach_signal_norm = teach_signal.norm(dim=-1).mean().item()
1035
+ base_model(
1036
+ tokens,
1037
+ teach_signal=teach_signal,
1038
+ surprise_value=surprise_override,
1039
+ fast_state=fast_state,
1040
+ )
1041
+ if hasattr(base_model, "pop_update_metrics"):
1042
+ update_metrics = base_model.pop_update_metrics()
1043
+ if step % log_interval == 0:
1044
+ ppl = torch.exp(loss.detach()).item()
1045
+ metrics_payload = {
1046
+ "loss": loss.item(),
1047
+ "ppl": ppl,
1048
+ "teach_signal_norm": teach_signal_norm,
1049
+ }
1050
+ metrics_payload.update(update_metrics)
1051
+ logger.log(metrics_payload, step=step)
1052
+ if (not distributed) or (dist_ctx and dist_ctx.rank == 0):
1053
+ print(
1054
+ f"[train] step={step} loss={loss.item():.4f} "
1055
+ f"ppl={ppl:.2f} teach_norm={teach_signal_norm:.4f}"
1056
+ )
1057
+ metrics = metrics_payload
1058
+ maybe_save_checkpoint(
1059
+ cfg,
1060
+ base_model,
1061
+ optimizer,
1062
+ step=step,
1063
+ total_steps=steps,
1064
+ distributed=distributed,
1065
+ dist_ctx=dist_ctx,
1066
+ step_offset=int(cfg.train.get("step_offset", 0) or 0),
1067
+ )
1068
+ logger.finish()
1069
+ return metrics
1070
+
1071
+
1072
def _apply_teach_schedule(model: HOPEModel, cfg: DictConfig, step: int) -> None:
    """Set the model's runtime teach-signal scale for training step ``step``.

    The base value is ``cfg.model.teach_scale`` (default 1.0). When
    ``cfg.model.teach_schedule`` is set, the base is multiplied by:

    * a linear warmup factor ``min(1, (step + 1) / warmup_steps)`` whenever
      ``warmup_steps`` is positive, and
    * a linear decay factor ``max(0, 1 - progress)`` with
      ``progress = min(1, (step + 1 - decay_start) / decay_duration)``, applied
      once ``step + 1`` exceeds ``decay_start`` (only when ``decay_start`` is
      set and ``decay_duration`` is positive).

    The resulting scale is handed to ``model.set_teach_runtime(scale=...)``.
    """
    schedule = cfg.model.get("teach_schedule")
    scale = cfg.model.get("teach_scale", 1.0)
    # Schedules are expressed in 1-based "steps completed".
    steps_done = step + 1
    if schedule:
        warmup = schedule.get("warmup_steps", 0)
        # ``None``/0/negative warmup values disable the ramp.
        if (warmup or 0) > 0:
            scale *= min(1.0, steps_done / warmup)
        decay_start = schedule.get("decay_start")
        decay_duration = schedule.get("decay_duration")
        in_decay = (
            decay_start is not None
            and (decay_duration or 0) > 0
            and steps_done > decay_start
        )
        if in_decay:
            progress = min(1.0, (steps_done - decay_start) / decay_duration)
            scale *= max(0.0, 1.0 - progress)
    model.set_teach_runtime(scale=scale)
1091
+
1092
+
1093
+ def _maybe_compile_model(model: torch.nn.Module, compile_cfg: dict | None) -> torch.nn.Module:
1094
+ if not compile_cfg or not compile_cfg.get("enable", False):
1095
+ return model
1096
+ kwargs = {}
1097
+ if "mode" in compile_cfg:
1098
+ kwargs["mode"] = compile_cfg["mode"]
1099
+ if "backend" in compile_cfg:
1100
+ kwargs["backend"] = compile_cfg["backend"]
1101
+ try:
1102
+ return cast(torch.nn.Module, torch.compile(model, **kwargs)) # type: ignore[attr-defined]
1103
+ except Exception as err: # pragma: no cover - compile is optional
1104
+ if compile_cfg.get("strict", False):
1105
+ raise
1106
+ print(f"[compile] fallback to eager due to: {err}")
1107
+ return model
1108
+
1109
+
1110
+ def _make_autocast_factory(device: torch.device, mp_cfg: dict | None):
1111
+ if not mp_cfg or not mp_cfg.get("enabled", False):
1112
+ return lambda: nullcontext()
1113
+ dtype = _resolve_autocast_dtype(mp_cfg.get("dtype", "bf16"))
1114
+ device_type = device.type
1115
+ if device_type not in {"cuda", "cpu", "mps"}:
1116
+ device_type = "cpu"
1117
+
1118
+ def factory():
1119
+ try:
1120
+ return torch.autocast(device_type=device_type, dtype=dtype)
1121
+ except Exception as err: # pragma: no cover - device/dtype support varies by backend
1122
+ print(f"[autocast] disabled for device_type={device_type} dtype={dtype}: {err}")
1123
+ return nullcontext()
1124
+
1125
+ return factory
1126
+
1127
+
1128
+ def _resolve_autocast_dtype(name: str) -> torch.dtype:
1129
+ normalized = str(name).lower()
1130
+ if normalized in {"bf16", "bfloat16"}:
1131
+ return torch.bfloat16
1132
+ if normalized in {"fp16", "float16", "half"}:
1133
+ return torch.float16
1134
+ msg = f"Unsupported autocast dtype {name}"
1135
+ raise ValueError(msg)
1136
+
1137
+
1138
def _build_optimizer(
    model: torch.nn.Module, cfg: DictConfig, *, device: torch.device
) -> torch.optim.Optimizer:
    """Build the outer-loop optimizer for ``model`` from ``cfg.optim``.

    Parameter selection: ``optim.param_policy`` picks which trainable
    parameters the optimizer owns. When it is unset, the boolean
    ``optim.outer_updates_memory_modules`` maps to "all" (true or unset) or
    "exclude_memory" (false).

    Optimizer type: ``optim.type`` "muon" or "m3" delegates to the matching
    hybrid builder. Any other value builds AdamW from:
      * ``lr`` (default 1e-3);
      * ``betas`` (default (0.9, 0.999));
      * ``weight_decay`` (default 0.0);
      * ``fused`` ("auto" enables the fused kernel only for a CUDA device with
        CUDA available; any other value is interpreted as a bool).

    Raises:
        ValueError: if the policy name is unsupported or selects no trainable
            parameters.
    """
    optimizer_cfg_raw = cfg.get("optim")
    if isinstance(optimizer_cfg_raw, DictConfig):
        optimizer_cfg = optimizer_cfg_raw
    else:
        optimizer_cfg = cast(DictConfig, OmegaConf.create(optimizer_cfg_raw or {}))
    param_policy_raw = optimizer_cfg.get("param_policy")
    if param_policy_raw is None:
        outer_updates_memory_modules = optimizer_cfg.get("outer_updates_memory_modules")
        if outer_updates_memory_modules is None:
            param_policy = "all"
        else:
            param_policy = "all" if bool(outer_updates_memory_modules) else "exclude_memory"
    else:
        param_policy = str(param_policy_raw).strip().lower()
    named_params = _select_outer_named_parameters(model, param_policy)
    if not named_params:
        raise ValueError(
            f"No trainable parameters selected for optim.param_policy={param_policy!r}. "
            "Check freeze_backbone, requires_grad flags, or adjust the policy."
        )
    optim_type = str(optimizer_cfg.get("type", "adamw")).lower()
    if optim_type == "muon":
        return _build_muon_optimizer(
            model,
            optimizer_cfg,
            device=device,
            named_params=named_params,
            param_policy=param_policy,
        )
    if optim_type == "m3":
        return _build_m3_optimizer(
            model,
            optimizer_cfg,
            device=device,
            named_params=named_params,
            param_policy=param_policy,
        )
    # OmegaConf returns list values as ListConfig. Normalize to builtin floats so
    # the optimizer's param_groups (and any saved optimizer state_dict) hold only
    # plain Python types, e.g. loadable with torch.load(weights_only=True).
    beta1, beta2 = optimizer_cfg.get("betas", (0.9, 0.999))
    kwargs = {
        "lr": float(optimizer_cfg.get("lr", 1e-3)),
        "betas": (float(beta1), float(beta2)),
        "weight_decay": float(optimizer_cfg.get("weight_decay", 0.0)),
    }
    fused_cfg = optimizer_cfg.get("fused", "auto")
    if fused_cfg == "auto":
        fused = device.type == "cuda" and torch.cuda.is_available()
    else:
        fused = bool(fused_cfg)
    if fused:
        kwargs["fused"] = True
    params = [param for _, param in named_params]
    return torch.optim.AdamW(params, **kwargs)
1192
+
1193
+
1194
def _build_muon_optimizer(
    model: torch.nn.Module,
    optimizer_cfg: DictConfig,
    *,
    device: torch.device,
    named_params: list[tuple[str, torch.nn.Parameter]] | None = None,
    param_policy: str | None = None,
):
    """Build a Muon + AdamW hybrid optimizer.

    Trainable 2-D parameters accepted by ``_is_muon_candidate`` go to
    ``torch.optim.Muon``, which only supports 2-D parameters. Every other
    trainable parameter goes to AdamW. Both optimizers share ``lr``
    (default 1e-3) and ``weight_decay`` (default 0.01).

    Muon also reads:
      * ``momentum`` (0.95);
      * ``eps`` (1e-7);
      * optional ``ns_coefficients`` / ``ns_steps``.

    AdamW also reads:
      * ``betas`` (default (0.9, 0.999));
      * ``fused`` ("auto" means fused only for a CUDA device with CUDA
        available).

    A side that receives no parameters is ``None``.

    Args:
        model: source of ``named_parameters()`` when ``named_params`` is None.
        optimizer_cfg: the ``optim`` config section.
        device: used only to resolve ``fused="auto"``.
        named_params: pre-selected ``(name, param)`` pairs, if any.
        param_policy: recorded on the returned wrapper.

    Returns:
        ``_HybridOptimizer`` with ``primary_name="muon"``.

    Raises:
        RuntimeError: if this PyTorch build has no ``torch.optim.Muon``.
    """
    if not hasattr(torch.optim, "Muon"):
        raise RuntimeError("torch.optim.Muon is not available in this PyTorch build")
    lr = float(optimizer_cfg.get("lr", 1e-3))
    weight_decay = float(optimizer_cfg.get("weight_decay", 0.01))
    momentum = float(optimizer_cfg.get("momentum", 0.95))
    ns_coefficients = optimizer_cfg.get("ns_coefficients")
    ns_steps = optimizer_cfg.get("ns_steps")
    eps = float(optimizer_cfg.get("eps", 1e-7))
    # OmegaConf hands lists back as ListConfig; keep optimizer hyperparameters
    # (and therefore saved state_dicts) as builtin Python types.
    beta1, beta2 = optimizer_cfg.get("betas", (0.9, 0.999))
    fused_cfg = optimizer_cfg.get("fused", "auto")
    if fused_cfg == "auto":
        fused = device.type == "cuda" and torch.cuda.is_available()
    else:
        fused = bool(fused_cfg)
    muon_params: list[torch.nn.Parameter] = []
    adamw_params: list[torch.nn.Parameter] = []
    source = named_params if named_params is not None else model.named_parameters()
    for name, param in source:
        if not param.requires_grad:
            continue
        # _is_muon_candidate also accepts >2-D tensors (e.g. conv kernels), but
        # torch.optim.Muon rejects any parameter whose ndim != 2.
        if param.ndim == 2 and _is_muon_candidate(name, param):
            muon_params.append(param)
        else:
            adamw_params.append(param)
    muon_kwargs = {
        "lr": lr,
        "weight_decay": weight_decay,
        "momentum": momentum,
        "eps": eps,
    }
    if ns_coefficients is not None:
        muon_kwargs["ns_coefficients"] = tuple(float(c) for c in ns_coefficients)
    if ns_steps is not None:
        muon_kwargs["ns_steps"] = int(ns_steps)
    muon_opt = torch.optim.Muon(muon_params, **muon_kwargs) if muon_params else None  # type: ignore[attr-defined]
    adamw_kwargs = {
        "lr": lr,
        "betas": (float(beta1), float(beta2)),
        "weight_decay": weight_decay,
    }
    if fused:
        adamw_kwargs["fused"] = True
    adamw_opt = torch.optim.AdamW(adamw_params, **adamw_kwargs) if adamw_params else None
    muon_elems = int(sum(p.numel() for p in muon_params))
    adamw_elems = int(sum(p.numel() for p in adamw_params))
    return _HybridOptimizer(
        muon_opt,
        adamw_opt,
        muon_elems,
        adamw_elems,
        primary_name="muon",
        param_policy=param_policy,
    )
1255
+
1256
+
1257
def _build_m3_optimizer(
    model: torch.nn.Module,
    optimizer_cfg: DictConfig,
    *,
    device: torch.device,
    named_params: list[tuple[str, torch.nn.Parameter]] | None = None,
    param_policy: str | None = None,
):
    """Build an M3 + AdamW hybrid optimizer.

    Trainable parameters accepted by ``_is_muon_candidate`` go to ``M3``.
    Every other trainable parameter goes to AdamW. Both optimizers share
    ``lr`` (default 1e-3) and ``weight_decay`` (default 0.01).

    M3 also reads:
      * ``beta1``/``beta2``/``beta3`` (0.9/0.999/0.9);
      * ``alpha`` (1.0);
      * ``eps`` (1e-8);
      * ``ns_steps`` (3);
      * ``slow_chunk`` (100).

    AdamW also reads:
      * ``betas`` (default (0.9, 0.999));
      * ``fused`` ("auto" means fused only for a CUDA device with CUDA
        available).

    A side that receives no parameters is ``None``.

    Args:
        model: source of ``named_parameters()`` when ``named_params`` is None.
        optimizer_cfg: the ``optim`` config section.
        device: used only to resolve ``fused="auto"``.
        named_params: pre-selected ``(name, param)`` pairs, if any.
        param_policy: recorded on the returned wrapper.

    Returns:
        ``_HybridOptimizer`` with ``primary_name="m3"``.
    """
    lr = float(optimizer_cfg.get("lr", 1e-3))
    weight_decay = float(optimizer_cfg.get("weight_decay", 0.01))
    beta1 = optimizer_cfg.get("beta1", 0.9)
    beta2 = optimizer_cfg.get("beta2", 0.999)
    beta3 = optimizer_cfg.get("beta3", 0.9)
    alpha = optimizer_cfg.get("alpha", 1.0)
    ns_steps = int(optimizer_cfg.get("ns_steps", 3))
    slow_chunk = int(optimizer_cfg.get("slow_chunk", 100))
    eps = optimizer_cfg.get("eps", 1e-8)
    # OmegaConf hands lists back as ListConfig; keep AdamW hyperparameters
    # (and therefore saved state_dicts) as builtin Python types.
    adamw_beta1, adamw_beta2 = optimizer_cfg.get("betas", (0.9, 0.999))
    fused_cfg = optimizer_cfg.get("fused", "auto")
    if fused_cfg == "auto":
        fused = device.type == "cuda" and torch.cuda.is_available()
    else:
        fused = bool(fused_cfg)

    m3_params: list[torch.nn.Parameter] = []
    adamw_params: list[torch.nn.Parameter] = []
    source = named_params if named_params is not None else model.named_parameters()
    for name, param in source:
        if not param.requires_grad:
            continue
        if _is_muon_candidate(name, param):
            m3_params.append(param)
        else:
            adamw_params.append(param)
    m3_opt = (
        M3(
            m3_params,
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            beta3=beta3,
            alpha=alpha,
            eps=eps,
            ns_steps=ns_steps,
            slow_chunk=slow_chunk,
            weight_decay=weight_decay,
        )
        if m3_params
        else None
    )
    adamw_kwargs = {
        "lr": lr,
        "betas": (float(adamw_beta1), float(adamw_beta2)),
        "weight_decay": weight_decay,
    }
    if fused:
        adamw_kwargs["fused"] = True
    adamw_opt = torch.optim.AdamW(adamw_params, **adamw_kwargs) if adamw_params else None
    m3_elems = int(sum(p.numel() for p in m3_params))
    adamw_elems = int(sum(p.numel() for p in adamw_params))
    return _HybridOptimizer(
        m3_opt,
        adamw_opt,
        m3_elems,
        adamw_elems,
        primary_name="m3",
        param_policy=param_policy,
    )
1325
+
1326
+
1327
+ def _select_outer_named_parameters(
1328
+ model: torch.nn.Module, param_policy: str
1329
+ ) -> list[tuple[str, torch.nn.Parameter]]:
1330
+ policy = str(param_policy).strip().lower()
1331
+ trainable: list[tuple[str, torch.nn.Parameter]] = [
1332
+ (name, param) for name, param in model.named_parameters() if param.requires_grad
1333
+ ]
1334
+ if policy in {"all", "full"}:
1335
+ return trainable
1336
+ if policy in {"exclude_memory", "no_memory"}:
1337
+ return [(name, param) for name, param in trainable if not _is_memory_param_name(name)]
1338
+ if policy in {"only_memory", "memory_only"}:
1339
+ return [(name, param) for name, param in trainable if _is_memory_param_name(name)]
1340
+ raise ValueError(
1341
+ f"Unsupported optim.param_policy={param_policy!r}. "
1342
+ "Expected one of ['all', 'exclude_memory', 'only_memory']."
1343
+ )
1344
+
1345
+
1346
+ def _is_memory_param_name(name: str) -> bool:
1347
+ lowered = name.lower()
1348
+ return any(token in lowered for token in (".cms.", ".titan_memory.", ".selfmod."))
1349
+
1350
+
1351
+ def _is_muon_candidate(name: str, param: torch.nn.Parameter) -> bool:
1352
+ if param.ndim < 2:
1353
+ return False
1354
+ lowered = name.lower()
1355
+ if "norm" in lowered or "embed" in lowered:
1356
+ return False
1357
+ return True
1358
+
1359
+
1360
+ class _HybridOptimizer:
1361
+ def __init__(
1362
+ self,
1363
+ primary_opt: torch.optim.Optimizer | None,
1364
+ secondary_opt: torch.optim.Optimizer | None,
1365
+ primary_param_elems: int,
1366
+ secondary_param_elems: int,
1367
+ *,
1368
+ primary_name: str = "muon",
1369
+ param_policy: str | None = None,
1370
+ ):
1371
+ self.primary_opt = primary_opt
1372
+ self.secondary_opt = secondary_opt
1373
+ self.primary_param_elems = primary_param_elems
1374
+ self.secondary_param_elems = secondary_param_elems
1375
+ self.primary_name = primary_name
1376
+ self.param_policy = param_policy
1377
+
1378
+ def zero_grad(self) -> None:
1379
+ if self.primary_opt:
1380
+ self.primary_opt.zero_grad()
1381
+ if self.secondary_opt:
1382
+ self.secondary_opt.zero_grad()
1383
+
1384
+ def step(self) -> None:
1385
+ if self.primary_opt:
1386
+ self.primary_opt.step()
1387
+ if self.secondary_opt:
1388
+ self.secondary_opt.step()
1389
+
1390
+ def state_dict(self) -> dict:
1391
+ return {
1392
+ self.primary_name: self.primary_opt.state_dict() if self.primary_opt else None,
1393
+ "adamw": self.secondary_opt.state_dict() if self.secondary_opt else None,
1394
+ }
1395
+
1396
+ def load_state_dict(self, state: dict) -> None:
1397
+ if self.primary_opt and state.get(self.primary_name) is not None:
1398
+ self.primary_opt.load_state_dict(state[self.primary_name])
1399
+ if self.secondary_opt and state.get("adamw") is not None:
1400
+ self.secondary_opt.load_state_dict(state["adamw"])
1401
+
1402
+ @property
1403
+ def param_groups(self):
1404
+ groups = []
1405
+ if self.primary_opt:
1406
+ groups.extend(self.primary_opt.param_groups)
1407
+ if self.secondary_opt:
1408
+ groups.extend(self.secondary_opt.param_groups)
1409
+ return groups
1410
+
1411
+ def get_param_split(self) -> dict[str, int]:
1412
+ return {
1413
+ self.primary_name: self.primary_param_elems,
1414
+ "adamw": self.secondary_param_elems,
1415
+ }
1416
+
1417
+
1418
def _log_run_features(
    logger: BaseLogger,
    model: torch.nn.Module,
    cfg: DictConfig,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> None:
    """Log (at step -1) and print a flat dict of the run's feature flags.

    Recorded values:
      * mixed-precision and compile settings;
      * the streaming/online-update training flags and algorithm mode;
      * whether flash attention is on (per ``_detect_flash_attention``);
      * the device type;
      * parameter-element counts for the resolved ``optim.param_policy``;
      * the optimizer's own split when it provides ``get_param_split()``.

    Failure to compute the policy counts is recorded as
    ``optim.param_policy_error`` rather than raised, since the data is purely
    diagnostic.
    """
    # ``or {}`` also covers keys explicitly set to null; ``.get(key, {})`` would
    # return None there and crash below, whereas the autocast and compile
    # helpers already accept None for these sections.
    mp_cfg = cfg.train.get("mixed_precision") or {}
    compile_cfg = cfg.train.get("compile") or {}
    algorithm_mode = str(cfg.train.get("algorithm_mode", "two_pass_stopgrad_updates"))
    features: dict[str, object] = {
        "train.mixed_precision_enabled": bool(mp_cfg.get("enabled", False)),
        "train.mixed_precision_dtype": str(mp_cfg.get("dtype", "bf16")),
        "train.compile_enabled": bool(compile_cfg.get("enable", False)),
        "train.compile_mode": str(compile_cfg.get("mode", "default")) if compile_cfg else "default",
        "train.strict_streaming_contract": bool(cfg.train.get("strict_streaming_contract", False)),
        "train.online_updates": bool(cfg.train.get("online_updates", False)),
        "train.online_boundary_targets": bool(cfg.train.get("online_boundary_targets", False)),
        "train.online_carry_attention_cache": bool(
            cfg.train.get("online_carry_attention_cache", False)
        ),
        "train.use_fast_state": bool(cfg.train.get("use_fast_state", False)),
        "train.algorithm_mode": algorithm_mode,
        "train.backprop_through_online_writes": algorithm_mode
        == "boundary_state_grad_through_write",
        "attention.flash_enabled": _detect_flash_attention(model),
        "device": device.type,
    }
    optimizer_cfg_raw = cfg.get("optim")
    if isinstance(optimizer_cfg_raw, DictConfig):
        optimizer_cfg = optimizer_cfg_raw
    else:
        optimizer_cfg = cast(DictConfig, OmegaConf.create(optimizer_cfg_raw or {}))
    param_policy_raw = optimizer_cfg.get("param_policy")
    if param_policy_raw is None:
        outer_updates_memory_modules = optimizer_cfg.get("outer_updates_memory_modules")
        if outer_updates_memory_modules is None:
            param_policy = "all"
        else:
            param_policy = "all" if bool(outer_updates_memory_modules) else "exclude_memory"
    else:
        param_policy = str(param_policy_raw).strip().lower()
    try:
        selected = _select_outer_named_parameters(model, param_policy)
        total_elems = int(sum(param.numel() for _, param in selected))
        memory_elems = int(
            sum(param.numel() for name, param in selected if _is_memory_param_name(name))
        )
        features["optim.param_policy"] = param_policy
        features["optim.param_policy_param_elems"] = total_elems
        features["optim.param_policy_memory_param_elems"] = memory_elems
        features["optim.param_policy_non_memory_param_elems"] = total_elems - memory_elems
    except Exception as err:  # pragma: no cover - purely diagnostic
        features["optim.param_policy"] = param_policy
        features["optim.param_policy_error"] = str(err)
    split_fn = getattr(optimizer, "get_param_split", None)
    if callable(split_fn):
        split = split_fn()
        for key, value in split.items():
            features[f"optim.{key}_param_elems"] = int(value)
    logger.log(features, step=-1)
    print(f"[train] run_features {features}")
1480
+
1481
+
1482
+ def _detect_flash_attention(model: torch.nn.Module) -> bool:
1483
+ blocks = getattr(model, "blocks", [])
1484
+ for block in blocks:
1485
+ attn = getattr(block, "attn", None)
1486
+ config = getattr(attn, "config", None)
1487
+ if config is not None and hasattr(config, "use_flash"):
1488
+ return bool(config.use_flash)
1489
+ return False
1490
+
1491
+
1492
def write_checkpoint_metadata(cfg: DictConfig, ckpt_path: Path, step: int) -> None:
    """Write the sidecar files that accompany a checkpoint.

    Each file is ``ckpt_path`` with its last suffix replaced:
      * ``.yaml``: ``cfg`` serialized with ``OmegaConf.to_yaml``;
      * ``.sha256``: ``"<sha256> <checkpoint file name>"``, written only when
        ``_checksum_path`` returned a checksum for the checkpoint;
      * ``.meta.json``: the step, checkpoint/config/tokenizer hashes, the
        config path, the streaming/online-update flags from ``cfg.train``
        and the RNG states from ``_capture_rng_states``.
    """
    config_yaml = OmegaConf.to_yaml(cfg)
    config_path = ckpt_path.with_suffix(".yaml")
    config_path.write_text(config_yaml)
    config_digest = sha256(config_yaml.encode("utf-8")).hexdigest()

    ckpt_digest = _checksum_path(str(ckpt_path))
    if ckpt_digest:
        ckpt_path.with_suffix(".sha256").write_text(f"{ckpt_digest} {ckpt_path.name}\n")

    tokenizer_path = cfg.data.get("tokenizer_path") if hasattr(cfg, "data") else None
    train_cfg = cfg.train
    metadata = {
        "step": step,
        "checkpoint_sha256": ckpt_digest,
        "config_sha256": config_digest,
        "tokenizer_hash": _checksum_path(tokenizer_path) if tokenizer_path else None,
        "config_path": str(config_path),
        "algorithm_mode": str(train_cfg.get("algorithm_mode", "two_pass_stopgrad_updates")),
        "online_updates": bool(train_cfg.get("online_updates", False)),
        "online_boundary_targets": bool(train_cfg.get("online_boundary_targets", False)),
        "online_carry_attention_cache": bool(
            train_cfg.get("online_carry_attention_cache", False)
        ),
        "use_fast_state": bool(train_cfg.get("use_fast_state", False)),
        "rng_states": _capture_rng_states(),
    }
    meta_path = ckpt_path.with_suffix(".meta.json")
    meta_path.write_text(json.dumps(metadata, indent=2))
1518
+
1519
+
1520
def verify_checkpoint_integrity(ckpt_path: Path) -> Dict[str, object]:
    """Check a checkpoint against its sidecar files and return its metadata.

    Checks, in order:
      1. the checkpoint and its ``.meta.json`` exist;
      2. the checkpoint checksum matches ``checkpoint_sha256`` in the metadata
         (skipped if either value is missing/empty);
      3. the first field of the ``.sha256`` file, if present, matches the
         computed checksum;
      4. the ``.yaml`` config exists and its SHA-256 matches ``config_sha256``
         (skipped if none was recorded);
      5. the metadata contains ``rng_states``.

    Raises:
        FileNotFoundError: if the checkpoint, metadata or config file is missing.
        ValueError: on any checksum mismatch or missing ``rng_states``.
    """
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint {ckpt_path} not found")
    meta_path = ckpt_path.with_suffix(".meta.json")
    if not meta_path.exists():
        raise FileNotFoundError(f"Metadata file {meta_path} missing")
    metadata = json.loads(meta_path.read_text())

    computed_sha = _checksum_path(str(ckpt_path))
    recorded_sha = metadata.get("checkpoint_sha256")
    if recorded_sha and computed_sha and recorded_sha != computed_sha:
        raise ValueError(
            f"Checkpoint SHA mismatch: recorded {recorded_sha} vs computed {computed_sha}"
        )

    sha_file = ckpt_path.with_suffix(".sha256")
    if sha_file.exists() and computed_sha:
        sidecar_fields = sha_file.read_text().strip().split()
        if sidecar_fields and sidecar_fields[0] != computed_sha:
            raise ValueError(f".sha256 mismatch: {sidecar_fields[0]} vs {computed_sha}")

    config_path = ckpt_path.with_suffix(".yaml")
    if not config_path.exists():
        raise FileNotFoundError(f"Config file {config_path} missing")
    config_hash = sha256(config_path.read_text().encode("utf-8")).hexdigest()
    recorded_cfg_hash = metadata.get("config_sha256")
    if recorded_cfg_hash and recorded_cfg_hash != config_hash:
        raise ValueError(
            f"Config SHA mismatch: recorded {recorded_cfg_hash} vs computed {config_hash}"
        )

    if "rng_states" not in metadata:
        raise ValueError("Metadata missing rng_states")
    return metadata
1552
+
1553
+
1554
def _capture_rng_states() -> Dict[str, object]:
    """Snapshot the RNG states as JSON-serializable strings.

    ``python`` and ``numpy`` states are pickled and base64-encoded. The
    ``torch`` CPU state is hex-encoded. When CUDA is available, the
    per-device states are added under ``torch_cuda`` as a list of hex strings.
    """
    payload: Dict[str, object] = {}
    payload["python"] = _encode_pickle(random.getstate())
    payload["numpy"] = _encode_pickle(np.random.get_state())
    payload["torch"] = _tensor_state_to_hex(torch.random.get_rng_state())
    if torch.cuda.is_available():
        cuda_states = torch.cuda.get_rng_state_all()  # type: ignore[attr-defined]
        payload["torch_cuda"] = list(map(_tensor_state_to_hex, cuda_states))
    return payload
1565
+
1566
+
1567
+ def _encode_pickle(obj: object) -> str:
1568
+ return base64.b64encode(pickle.dumps(obj)).decode("ascii")
1569
+
1570
+
1571
+ def _tensor_state_to_hex(state: torch.Tensor) -> str:
1572
+ return state.cpu().numpy().tobytes().hex()
1573
+
1574
+
1575
+ def _seed_everything(seed: int, *, deterministic: bool = False) -> None:
1576
+ random.seed(seed)
1577
+ np.random.seed(seed)
1578
+ torch.manual_seed(seed)
1579
+ if torch.cuda.is_available():
1580
+ torch.cuda.manual_seed_all(seed)
1581
+ if deterministic:
1582
+ torch.use_deterministic_algorithms(True, warn_only=True)
1583
+ if hasattr(torch.backends, "cudnn"):
1584
+ torch.backends.cudnn.benchmark = False # type: ignore[attr-defined]
1585
+ torch.backends.cudnn.deterministic = True # type: ignore[attr-defined]
1586
+ else:
1587
+ torch.use_deterministic_algorithms(False)
1588
+ if hasattr(torch.backends, "cudnn"):
1589
+ torch.backends.cudnn.benchmark = True # type: ignore[attr-defined]
1590
+ torch.backends.cudnn.deterministic = False # type: ignore[attr-defined]
1591
+
1592
+
1593
+ def _make_worker_init_fn(base_seed: int):
1594
+ def _init_fn(worker_id: int) -> None:
1595
+ worker_seed = base_seed + worker_id
1596
+ np.random.seed(worker_seed)
1597
+ random.seed(worker_seed)
1598
+ torch.manual_seed(worker_seed)
1599
+
1600
+ return _init_fn