nested-learning 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nested_learning/__init__.py +12 -0
- nested_learning/__main__.py +12 -0
- nested_learning/assoc_memory.py +23 -0
- nested_learning/backbones.py +147 -0
- nested_learning/capabilities.py +104 -0
- nested_learning/cli.py +253 -0
- nested_learning/cms.py +92 -0
- nested_learning/config_utils.py +50 -0
- nested_learning/configs/ablations/cms_sparse.yaml +46 -0
- nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
- nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
- nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
- nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
- nested_learning/configs/data/continual_segments_sample.yaml +9 -0
- nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
- nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
- nested_learning/configs/deepspeed/zero3.json +25 -0
- nested_learning/configs/hope/mid.yaml +118 -0
- nested_learning/configs/hope/mid_fsdp.yaml +47 -0
- nested_learning/configs/hope/pilot.yaml +2 -0
- nested_learning/configs/hope/pilot_attention.yaml +9 -0
- nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
- nested_learning/configs/hope/pilot_transformer.yaml +9 -0
- nested_learning/configs/hope/target.yaml +145 -0
- nested_learning/configs/hope/target_fsdp.yaml +47 -0
- nested_learning/configs/mid_smoke.yaml +99 -0
- nested_learning/configs/mid_stage2.yaml +110 -0
- nested_learning/configs/mid_stage2_smoke.yaml +102 -0
- nested_learning/configs/mid_titan_baseline.yaml +92 -0
- nested_learning/configs/pilot.yaml +127 -0
- nested_learning/configs/pilot_paper_faithful.yaml +42 -0
- nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
- nested_learning/configs/pilot_smoke.yaml +80 -0
- nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
- nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
- nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
- nested_learning/continual_classification.py +136 -0
- nested_learning/continual_streaming.py +283 -0
- nested_learning/data.py +153 -0
- nested_learning/device.py +21 -0
- nested_learning/eval_state.py +72 -0
- nested_learning/fast_state.py +108 -0
- nested_learning/functional.py +69 -0
- nested_learning/hope/__init__.py +0 -0
- nested_learning/hope/block.py +1973 -0
- nested_learning/hope/self_mod.py +40 -0
- nested_learning/instrumentation.py +38 -0
- nested_learning/levels.py +94 -0
- nested_learning/logging_utils.py +64 -0
- nested_learning/memorize.py +382 -0
- nested_learning/model.py +604 -0
- nested_learning/optim/__init__.py +0 -0
- nested_learning/optim/deep.py +102 -0
- nested_learning/optim/factory.py +13 -0
- nested_learning/optim/m3.py +121 -0
- nested_learning/optim/manager.py +151 -0
- nested_learning/titan/__init__.py +0 -0
- nested_learning/titan/memory.py +88 -0
- nested_learning/titan/model.py +412 -0
- nested_learning/titan/self_modifying.py +724 -0
- nested_learning/tokenizer.py +28 -0
- nested_learning/tokenizer_coverage.py +77 -0
- nested_learning/training.py +1600 -0
- nested_learning/transformer.py +104 -0
- nested_learning-0.2.0.dist-info/METADATA +390 -0
- nested_learning-0.2.0.dist-info/RECORD +76 -0
- nested_learning-0.2.0.dist-info/WHEEL +4 -0
- nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
- nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,1600 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import pickle
|
|
7
|
+
import random
|
|
8
|
+
from contextlib import nullcontext
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from hashlib import sha256
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Dict, Iterator, Protocol, Tuple, cast
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import torch
|
|
16
|
+
from omegaconf import DictConfig, OmegaConf
|
|
17
|
+
from torch.utils.data import DataLoader, DistributedSampler, IterableDataset
|
|
18
|
+
|
|
19
|
+
from .data import (
|
|
20
|
+
MixtureShardDataset,
|
|
21
|
+
ShardSourceConfig,
|
|
22
|
+
SyntheticTextConfig,
|
|
23
|
+
SyntheticTextDataset,
|
|
24
|
+
TokenShardDataset,
|
|
25
|
+
collate_batch,
|
|
26
|
+
)
|
|
27
|
+
from .levels import LevelSpec
|
|
28
|
+
from .logging_utils import BaseLogger, NullLogger, init_logger
|
|
29
|
+
from .model import HOPEModel, ModelConfig
|
|
30
|
+
from .optim.m3 import M3
|
|
31
|
+
from .titan.model import TitanOnlyModel, TitanOnlyModelConfig
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class DistributedContext:
    """Where this process sits in the training job.

    A plain value holder: constructing it does not create or query any
    process group.
    """

    # Index of this process; rank 0 is the one that writes checkpoints.
    rank: int
    # Number of participating processes.
    world_size: int
    # Torch device associated with this process.
    device: torch.device
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def unwrap_config(cfg: DictConfig) -> DictConfig:
    """Hydra can wrap grouped configs (e.g., hope/pilot) under the group name.

    A config that already has a top-level ``model`` key is returned unchanged.
    Otherwise the first wrapper group found (``hope``, then ``ablations``) is
    returned. If neither is present, ``cfg`` itself is returned.
    """
    if "model" in cfg:
        return cfg
    for group_name in ("hope", "ablations"):
        if group_name in cfg:
            return cast(DictConfig, getattr(cfg, group_name))
    return cfg
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def build_model_from_cfg(model_cfg: DictConfig) -> torch.nn.Module:
    """Build the model described by ``model_cfg``.

    ``model_cfg.type == "titan"`` yields a ``TitanOnlyModel``. Any other value
    (default ``"hope"``) yields a ``HOPEModel``. Optional keys fall back to the
    defaults written inline below. Optional numeric keys whose value is ``None``
    are passed through as ``None``.
    """

    def _maybe(key: str, convert, default=None):
        # Read an optional key and convert it only when it is not None.
        raw = model_cfg.get(key, default)
        return None if raw is None else convert(raw)

    def _as_plain(key: str) -> dict:
        # Resolve an optional nested node into plain containers ({} when absent).
        if key not in model_cfg:
            return {}
        return cast(dict, OmegaConf.to_container(getattr(model_cfg, key), resolve=True))

    model_type = model_cfg.get("type", "hope")
    optimizer_cfg = cast(Dict[str, dict], _as_plain("optimizers"))
    teach_scale = model_cfg.get("teach_scale", 1.0)
    teach_clip = model_cfg.get("teach_clip", 0.0)
    teach_schedule = cast(Dict[str, float], _as_plain("teach_schedule"))
    qk_l2_norm = bool(model_cfg.get("qk_l2_norm", False))
    local_conv_window = _maybe("local_conv_window", int)
    surprise_threshold = _maybe("surprise_threshold", float)
    surprise_metric = str(model_cfg.get("surprise_metric", "l2"))
    cms_use_layernorm = bool(model_cfg.get("cms_use_layernorm", True))
    titan_spec = LevelSpec(**model_cfg.titan_level)

    # Fields accepted identically by TitanOnlyModelConfig and ModelConfig.
    shared: dict = {
        "vocab_size": model_cfg.vocab_size,
        "dim": model_cfg.dim,
        "num_layers": model_cfg.num_layers,
        "heads": model_cfg.heads,
        "titan_level": titan_spec,
        "optimizers": optimizer_cfg,
        "teach_scale": teach_scale,
        "teach_clip": teach_clip,
        "teach_schedule": teach_schedule,
        "qk_l2_norm": qk_l2_norm,
        "local_conv_window": local_conv_window,
        "surprise_threshold": surprise_threshold,
        "surprise_metric": surprise_metric,
        "freeze_backbone": model_cfg.get("freeze_backbone", False),
        "self_mod_lr": float(model_cfg.get("self_mod_lr", 1e-3)),
        "self_mod_hidden": int(model_cfg.get("self_mod_hidden", 4)),
    }

    if model_type == "titan":
        return TitanOnlyModel(TitanOnlyModelConfig(**shared))

    cms_specs = [LevelSpec(**level_entry) for level_entry in model_cfg.cms_levels]
    hope_cfg = ModelConfig(
        **shared,
        cms_levels=cms_specs,
        cms_flush_partial_at_end=bool(model_cfg.get("cms_flush_partial_at_end", False)),
        cms_use_layernorm=cms_use_layernorm,
        gradient_checkpointing=model_cfg.get("gradient_checkpointing", False),
        self_mod_chunk_size=int(model_cfg.get("self_mod_chunk_size", 1)),
        self_mod_chunk_size_memory=_maybe("self_mod_chunk_size_memory", int),
        self_mod_objective=str(model_cfg.get("self_mod_objective", "l2")),
        self_mod_stopgrad_vhat=bool(model_cfg.get("self_mod_stopgrad_vhat", True)),
        self_mod_use_rank1_precond=bool(model_cfg.get("self_mod_use_rank1_precond", True)),
        self_mod_use_alpha=bool(model_cfg.get("self_mod_use_alpha", True)),
        self_mod_use_skip=bool(model_cfg.get("self_mod_use_skip", True)),
        self_mod_momentum=float(model_cfg.get("self_mod_momentum", 0.0)),
        self_mod_adaptive_q=bool(model_cfg.get("self_mod_adaptive_q", False)),
        # Unlike most optional ints, this one defaults to 4 when the key is absent.
        self_mod_local_conv_window=_maybe("self_mod_local_conv_window", int, 4),
        transformer_mlp_hidden_multiplier=int(
            model_cfg.get("transformer_mlp_hidden_multiplier", 4)
        ),
        transformer_activation=str(model_cfg.get("transformer_activation", "gelu")),
        block_variant=str(model_cfg.get("block_variant", "hope_hybrid")),
    )
    return HOPEModel(hope_cfg)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def build_dataloader(
    data_cfg: DictConfig,
    *,
    distributed: bool,
    dist_ctx: DistributedContext | None,
    seed: int | None = None,
) -> Tuple[DataLoader, DistributedSampler | None]:
    """Create the training DataLoader for ``data_cfg``.

    When ``distributed`` is true, map-style datasets get a ``DistributedSampler``
    (shuffled, tail kept). Otherwise the loader shuffles them itself.
    Iterable datasets are neither sampled nor shuffled here. When ``seed`` is
    given, the loader's generator and worker init function are seeded from it.

    Returns:
        ``(dataloader, sampler)``; ``sampler`` is ``None`` unless a
        ``DistributedSampler`` was created.
    """
    dataset = _build_dataset(data_cfg)
    iterable = isinstance(dataset, IterableDataset)

    sampler: DistributedSampler | None = None
    if distributed and not iterable:
        assert dist_ctx is not None
        sampler = DistributedSampler(
            dataset,
            num_replicas=dist_ctx.world_size,
            rank=dist_ctx.rank,
            shuffle=True,
            drop_last=False,
        )
    # DataLoader rejects shuffle=True together with a sampler or an IterableDataset.
    shuffle = sampler is None and not iterable

    generator = None
    worker_init_fn = None
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)
        worker_init_fn = _make_worker_init_fn(seed)

    loader = DataLoader(
        dataset,
        batch_size=data_cfg.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        collate_fn=collate_batch,
        num_workers=data_cfg.get("num_workers", 0),
        pin_memory=True,
        worker_init_fn=worker_init_fn,
        generator=generator,
    )
    return loader, sampler
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _build_dataset(data_cfg: DictConfig):
|
|
193
|
+
source = data_cfg.source
|
|
194
|
+
if source == "synthetic":
|
|
195
|
+
synth_cfg = SyntheticTextConfig(
|
|
196
|
+
vocab_size=data_cfg.vocab_size,
|
|
197
|
+
seq_len=data_cfg.seq_len,
|
|
198
|
+
dataset_size=data_cfg.dataset_size,
|
|
199
|
+
)
|
|
200
|
+
return SyntheticTextDataset(synth_cfg)
|
|
201
|
+
if source == "shards":
|
|
202
|
+
shard_dir = data_cfg.shards_dir
|
|
203
|
+
return TokenShardDataset(shard_dir)
|
|
204
|
+
if source == "mixture":
|
|
205
|
+
mixture_cfg = data_cfg.mixture
|
|
206
|
+
sources = [
|
|
207
|
+
ShardSourceConfig(
|
|
208
|
+
name=entry.name,
|
|
209
|
+
shards_dir=entry.shards_dir,
|
|
210
|
+
weight=entry.weight,
|
|
211
|
+
)
|
|
212
|
+
for entry in mixture_cfg.sources
|
|
213
|
+
]
|
|
214
|
+
samples_per_epoch = mixture_cfg.samples_per_epoch
|
|
215
|
+
seed = mixture_cfg.get("seed", 0)
|
|
216
|
+
return MixtureShardDataset(
|
|
217
|
+
sources,
|
|
218
|
+
samples_per_epoch=samples_per_epoch,
|
|
219
|
+
seed=seed,
|
|
220
|
+
)
|
|
221
|
+
msg = f"Unsupported data source {source}"
|
|
222
|
+
raise ValueError(msg)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def compute_teach_signal(
    model: "_HasLMHead",
    logits: torch.Tensor,
    tokens: torch.Tensor,
    *,
    next_tokens: torch.Tensor | None = None,
    ignore_index: int | None = None,
) -> torch.Tensor:
    """
    Approximate dL/dh where h is the hidden state before the LM head.

    This matches the gradient of mean next-token CE.

    By default this corresponds to CE(logits[:, :-1], tokens[:, 1:]).
    If `next_tokens` is provided, the final logit position is also supervised
    against that boundary target (used for chunked streaming boundaries).

    If ignore_index is provided, targets equal to ignore_index are masked out and
    the mean reduction denominator becomes the number of active targets (matching
    PyTorch CE semantics).

    Args:
        model: object exposing ``lm_head.weight`` (read detached).
        logits: ``[B, T, V]`` LM-head outputs; detached internally.
        tokens: ``[B, T]`` token ids of any integer dtype.
        next_tokens: optional ``[B]`` or ``[B, 1]`` target for position ``T - 1``.
        ignore_index: optional target id excluded from the loss and its mean.

    Returns:
        ``(softmax(logits) - onehot(targets)) / N @ lm_head.weight``, in the dtype of
        ``logits``. Positions without an active target contribute zero.

    Raises:
        ValueError: if ``next_tokens`` has the wrong shape or batch size.
    """
    # softmax already returns a fresh tensor, so it can be modified in place below.
    residual = torch.softmax(logits.detach(), dim=-1)
    batch_size, seq_len, _ = residual.shape

    # targets[:, t] is the token position t should predict; active marks positions
    # that contribute to the loss.
    targets = torch.zeros(batch_size, seq_len, device=tokens.device, dtype=tokens.dtype)
    active = torch.zeros(batch_size, seq_len, device=tokens.device, dtype=torch.bool)
    if seq_len > 1:
        targets[:, :-1] = tokens[:, 1:]
        active[:, :-1] = True
    if next_tokens is not None:
        if next_tokens.ndim == 2 and next_tokens.size(1) == 1:
            next_targets = next_tokens[:, 0]
        elif next_tokens.ndim == 1:
            next_targets = next_tokens
        else:
            raise ValueError("next_tokens must have shape [B] or [B, 1]")
        if next_targets.size(0) != batch_size:
            raise ValueError("next_tokens batch dimension must match tokens batch dimension")
        targets[:, -1] = next_targets.to(device=tokens.device, dtype=tokens.dtype)
        active[:, -1] = True
    if ignore_index is not None:
        active = active & (targets != ignore_index)

    active_f = active.to(dtype=residual.dtype)
    # dCE/dlogits = softmax - onehot(target) at active positions, zero elsewhere.
    residual.mul_(active_f.unsqueeze(-1))
    # scatter_add_ requires int64 indices; tokens may arrive as int32 (or narrower).
    # Inactive positions are redirected to index 0 with a zero contribution.
    safe_targets = torch.where(active, targets, torch.zeros_like(targets)).to(dtype=torch.long)
    residual.scatter_add_(-1, safe_targets.unsqueeze(-1), -active_f.unsqueeze(-1))
    # Mean reduction over active targets only.
    denom: torch.Tensor = active_f.sum().clamp(min=1.0)
    residual = residual / denom

    head_weight = model.lm_head.weight.detach().to(dtype=residual.dtype)
    return residual @ head_weight
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def _compute_layer_teach_signals(
|
|
296
|
+
loss: torch.Tensor,
|
|
297
|
+
block_outputs: list[torch.Tensor],
|
|
298
|
+
*,
|
|
299
|
+
detach: bool = True,
|
|
300
|
+
create_graph: bool = False,
|
|
301
|
+
) -> list[torch.Tensor]:
|
|
302
|
+
grads = torch.autograd.grad(
|
|
303
|
+
loss,
|
|
304
|
+
block_outputs,
|
|
305
|
+
retain_graph=True,
|
|
306
|
+
create_graph=create_graph,
|
|
307
|
+
allow_unused=False,
|
|
308
|
+
)
|
|
309
|
+
if detach:
|
|
310
|
+
return [g.detach() for g in grads]
|
|
311
|
+
return list(grads)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _compute_surprise_override(
|
|
315
|
+
metric: str,
|
|
316
|
+
*,
|
|
317
|
+
logits: torch.Tensor,
|
|
318
|
+
tokens: torch.Tensor,
|
|
319
|
+
loss: torch.Tensor,
|
|
320
|
+
next_tokens: torch.Tensor | None = None,
|
|
321
|
+
) -> float | None:
|
|
322
|
+
normalized = str(metric).strip().lower()
|
|
323
|
+
if normalized == "loss":
|
|
324
|
+
return float(loss.detach().item())
|
|
325
|
+
if normalized == "logit_entropy":
|
|
326
|
+
supervised_steps = int(tokens.size(1) - 1 + (0 if next_tokens is None else 1))
|
|
327
|
+
if supervised_steps <= 0:
|
|
328
|
+
return None
|
|
329
|
+
logits_detached = logits[:, :supervised_steps].detach().float()
|
|
330
|
+
probs = torch.softmax(logits_detached, dim=-1)
|
|
331
|
+
entropy = -(probs * torch.log(probs.clamp(min=1e-9))).sum(dim=-1).mean()
|
|
332
|
+
return float(entropy.item())
|
|
333
|
+
return None
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def _infer_online_chunk_size(model: HOPEModel) -> int | None:
|
|
337
|
+
min_period: int | None = None
|
|
338
|
+
blocks = getattr(model, "blocks", [])
|
|
339
|
+
for block in blocks:
|
|
340
|
+
cfg = getattr(block, "config", None)
|
|
341
|
+
levels = getattr(cfg, "cms_levels", None)
|
|
342
|
+
if not levels:
|
|
343
|
+
continue
|
|
344
|
+
for spec in levels:
|
|
345
|
+
period = int(spec.update_period)
|
|
346
|
+
if period <= 0:
|
|
347
|
+
continue
|
|
348
|
+
min_period = period if min_period is None else min(min_period, period)
|
|
349
|
+
return min_period
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def _iter_online_token_chunks(
|
|
353
|
+
tokens: torch.Tensor, *, chunk_size: int
|
|
354
|
+
) -> Iterator[tuple[torch.Tensor, bool]]:
|
|
355
|
+
if chunk_size < 1:
|
|
356
|
+
raise ValueError("chunk_size must be >= 1")
|
|
357
|
+
seq_len = tokens.size(1)
|
|
358
|
+
for core_start in range(0, seq_len, chunk_size):
|
|
359
|
+
core_end = min(core_start + chunk_size, seq_len)
|
|
360
|
+
if core_end <= core_start:
|
|
361
|
+
continue
|
|
362
|
+
# Carry one-token overlap so chunk boundaries still include next-token supervision.
|
|
363
|
+
chunk_start = core_start - 1 if core_start > 0 else core_start
|
|
364
|
+
chunk_tokens = tokens[:, chunk_start:core_end]
|
|
365
|
+
finalize_updates = core_end >= seq_len
|
|
366
|
+
yield chunk_tokens, finalize_updates
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
def _iter_online_boundary_chunks(
|
|
370
|
+
tokens: torch.Tensor, *, chunk_size: int
|
|
371
|
+
) -> Iterator[tuple[torch.Tensor, torch.Tensor | None, bool]]:
|
|
372
|
+
"""
|
|
373
|
+
Yield non-overlapping chunks plus the boundary target token for chunk end.
|
|
374
|
+
|
|
375
|
+
This enables exact boundary supervision without one-token overlap.
|
|
376
|
+
"""
|
|
377
|
+
if chunk_size < 1:
|
|
378
|
+
raise ValueError("chunk_size must be >= 1")
|
|
379
|
+
seq_len = tokens.size(1)
|
|
380
|
+
for start in range(0, seq_len, chunk_size):
|
|
381
|
+
end = min(start + chunk_size, seq_len)
|
|
382
|
+
if end <= start:
|
|
383
|
+
continue
|
|
384
|
+
next_tokens = None
|
|
385
|
+
if end < seq_len:
|
|
386
|
+
next_tokens = tokens[:, end]
|
|
387
|
+
finalize_updates = end >= seq_len
|
|
388
|
+
yield tokens[:, start:end], next_tokens, finalize_updates
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
class _HasLMHead(Protocol):
|
|
392
|
+
lm_head: torch.nn.Linear
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def _checksum_path(path: str | None) -> str | None:
|
|
396
|
+
if not path:
|
|
397
|
+
return None
|
|
398
|
+
candidate = Path(path)
|
|
399
|
+
if not candidate.exists() or not candidate.is_file():
|
|
400
|
+
return None
|
|
401
|
+
digest = sha256()
|
|
402
|
+
with candidate.open("rb") as handle:
|
|
403
|
+
for chunk in iter(lambda: handle.read(1 << 20), b""):
|
|
404
|
+
digest.update(chunk)
|
|
405
|
+
return digest.hexdigest()
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def maybe_save_checkpoint(
    cfg: DictConfig,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    *,
    step: int,
    total_steps: int,
    distributed: bool,
    dist_ctx: DistributedContext | None,
    step_offset: int = 0,
) -> None:
    """Write a checkpoint for loop step ``step`` when ``cfg.train.checkpoint`` asks for one.

    Nothing happens unless ``checkpoint.enable`` is true. In distributed runs only
    rank 0 writes. A save is due every ``save_interval`` steps and, if
    ``save_last`` (default true), on the final step. ``save_interval`` defaults to
    ``total_steps``; an explicit ``null`` is treated the same way.

    The checkpoint holds the model and optimizer state dicts, the step and the
    resolved config. It is written to a ``.tmp`` sibling and moved into place with
    ``os.replace``, so ``step_XXXXXX.pt`` is never observed half-written.

    Args:
        cfg: full run config; ``cfg.train.checkpoint`` holds the options.
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        step: loop step index; ``step + 1`` is treated as the number of completed steps.
        total_steps: number of steps in this run (used for the last-step check).
        distributed: whether this is a multi-process run.
        dist_ctx: rank information for distributed runs.
        step_offset: added to ``step + 1`` to form the global step used in the filename.

    Raises:
        Any error from ``torch.save`` / ``os.replace``; the temp file is removed first.
    """
    ckpt_cfg = cfg.train.get("checkpoint")
    if not ckpt_cfg or not ckpt_cfg.get("enable", False):
        return
    if distributed and dist_ctx is not None and dist_ctx.rank != 0:
        return
    save_interval = ckpt_cfg.get("save_interval", total_steps)
    if save_interval is None:
        # `save_interval: null` previously crashed in max(1, None); treat it as unset.
        save_interval = total_steps
    save_last = ckpt_cfg.get("save_last", True)
    is_last_step = (step + 1) >= total_steps
    should_save = ((step + 1) % max(1, save_interval) == 0) or (save_last and is_last_step)
    if not should_save:
        return
    ckpt_dir = Path(ckpt_cfg.get("dir", "checkpoints/default"))
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    global_step = step + 1 + int(step_offset)
    ckpt_path = ckpt_dir / f"step_{global_step:06d}.pt"
    tmp_path = ckpt_path.with_suffix(".tmp")
    resolved_cfg = OmegaConf.to_container(cfg, resolve=True)
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        # NOTE(review): stores the loop-local step (without step_offset), while the
        # filename uses global_step; confirm the resume path expects this.
        "step": step + 1,
        "config": resolved_cfg,
    }
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, ckpt_path)
    except BaseException:
        # Best-effort cleanup so a failed save does not leave a partial temp file.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
    write_checkpoint_metadata(cfg, ckpt_path, global_step)
    prefix = "[checkpoint]"
    if distributed and dist_ctx is not None:
        prefix = f"[checkpoint rank={dist_ctx.rank}]"
    print(f"{prefix} saved {ckpt_path} (global_step={global_step})")
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def _validate_distributed_config(cfg: DictConfig, distributed: bool) -> None:
|
|
452
|
+
if not distributed:
|
|
453
|
+
return
|
|
454
|
+
strict = bool(cfg.train.get("strict_streaming_contract", False))
|
|
455
|
+
fail_if_faithful_disabled = bool(cfg.train.get("fail_if_paper_faithful_disabled", False))
|
|
456
|
+
fail_hard = strict or fail_if_faithful_disabled
|
|
457
|
+
if not fail_hard:
|
|
458
|
+
return
|
|
459
|
+
if bool(cfg.train.get("per_layer_teach_signal", False)):
|
|
460
|
+
raise RuntimeError(
|
|
461
|
+
"train.per_layer_teach_signal=true is not supported under DDP in this repo. "
|
|
462
|
+
"Set train.strict_streaming_contract=false and "
|
|
463
|
+
"train.fail_if_paper_faithful_disabled=false to allow the fallback, "
|
|
464
|
+
"or run single-process training."
|
|
465
|
+
)
|
|
466
|
+
if bool(cfg.train.get("online_updates", False)):
|
|
467
|
+
raise RuntimeError(
|
|
468
|
+
"train.online_updates=true is not supported under DDP in this repo. "
|
|
469
|
+
"Set train.strict_streaming_contract=false and "
|
|
470
|
+
"train.fail_if_paper_faithful_disabled=false to allow the fallback, "
|
|
471
|
+
"or run single-process training."
|
|
472
|
+
)
|
|
473
|
+
if bool(cfg.train.get("online_boundary_targets", False)):
|
|
474
|
+
raise RuntimeError(
|
|
475
|
+
"train.online_boundary_targets=true is not supported under DDP in this repo. "
|
|
476
|
+
"Set train.strict_streaming_contract=false and "
|
|
477
|
+
"train.fail_if_paper_faithful_disabled=false to allow the fallback, "
|
|
478
|
+
"or run single-process training."
|
|
479
|
+
)
|
|
480
|
+
if bool(cfg.train.get("online_carry_attention_cache", False)):
|
|
481
|
+
raise RuntimeError(
|
|
482
|
+
"train.online_carry_attention_cache=true is not supported under DDP in this repo. "
|
|
483
|
+
"Set train.strict_streaming_contract=false and "
|
|
484
|
+
"train.fail_if_paper_faithful_disabled=false to allow the fallback, "
|
|
485
|
+
"or run single-process training."
|
|
486
|
+
)
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def _emit_streaming_warning(
|
|
490
|
+
*,
|
|
491
|
+
code: str,
|
|
492
|
+
message: str,
|
|
493
|
+
details: dict[str, object] | None = None,
|
|
494
|
+
) -> None:
|
|
495
|
+
payload: dict[str, object] = {"warning_code": code, "message": message}
|
|
496
|
+
if details:
|
|
497
|
+
payload["details"] = details
|
|
498
|
+
print(f"[train.warn] {json.dumps(payload, sort_keys=True)}")
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def _validate_paper_auditing_variant(cfg: DictConfig) -> None:
|
|
502
|
+
strict = bool(cfg.train.get("strict_streaming_contract", False))
|
|
503
|
+
block_variant = str(cfg.model.get("block_variant", "")).strip().lower()
|
|
504
|
+
if not block_variant:
|
|
505
|
+
return
|
|
506
|
+
allowed = {"hope_attention", "hope_selfmod"}
|
|
507
|
+
if block_variant in allowed:
|
|
508
|
+
return
|
|
509
|
+
msg = (
|
|
510
|
+
"strict streaming contract expects a paper-defined HOPE variant "
|
|
511
|
+
f"({sorted(allowed)}), got model.block_variant={block_variant!r}"
|
|
512
|
+
)
|
|
513
|
+
if strict:
|
|
514
|
+
raise RuntimeError(msg)
|
|
515
|
+
_emit_streaming_warning(
|
|
516
|
+
code="non_paper_variant",
|
|
517
|
+
message=msg,
|
|
518
|
+
details={"block_variant": block_variant},
|
|
519
|
+
)
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def _validate_tied_lm_head_for_paper_auditing(
|
|
523
|
+
cfg: DictConfig,
|
|
524
|
+
model: torch.nn.Module,
|
|
525
|
+
) -> None:
|
|
526
|
+
strict = bool(cfg.train.get("strict_streaming_contract", False))
|
|
527
|
+
fail_if_faithful_disabled = bool(cfg.train.get("fail_if_paper_faithful_disabled", False))
|
|
528
|
+
if not (strict or fail_if_faithful_disabled):
|
|
529
|
+
return
|
|
530
|
+
lm_head = getattr(model, "lm_head", None)
|
|
531
|
+
embed = getattr(model, "embed", None)
|
|
532
|
+
if lm_head is None or embed is None:
|
|
533
|
+
return
|
|
534
|
+
lm_weight = getattr(lm_head, "weight", None)
|
|
535
|
+
emb_weight = getattr(embed, "weight", None)
|
|
536
|
+
if lm_weight is None or emb_weight is None:
|
|
537
|
+
return
|
|
538
|
+
if lm_weight.data_ptr() == emb_weight.data_ptr():
|
|
539
|
+
return
|
|
540
|
+
raise RuntimeError(
|
|
541
|
+
"paper-auditing mode requires tied LM head and embedding weights "
|
|
542
|
+
"(lm_head.weight must alias embed.weight)."
|
|
543
|
+
)
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
def _validate_fast_state_batch_semantics(cfg: DictConfig) -> None:
|
|
547
|
+
if not bool(cfg.train.get("use_fast_state", False)):
|
|
548
|
+
return
|
|
549
|
+
data_cfg = cfg.get("data")
|
|
550
|
+
if data_cfg is None:
|
|
551
|
+
return
|
|
552
|
+
batch_size_raw = data_cfg.get("batch_size", 1)
|
|
553
|
+
try:
|
|
554
|
+
batch_size = int(batch_size_raw)
|
|
555
|
+
except (TypeError, ValueError):
|
|
556
|
+
return
|
|
557
|
+
if batch_size <= 1:
|
|
558
|
+
return
|
|
559
|
+
msg = (
|
|
560
|
+
"train.use_fast_state=true currently shares CMS/TITAN fast state across the batch. "
|
|
561
|
+
"For strict per-context semantics, set data.batch_size=1."
|
|
562
|
+
)
|
|
563
|
+
strict = bool(cfg.train.get("strict_streaming_contract", False))
|
|
564
|
+
fail_if_faithful_disabled = bool(cfg.train.get("fail_if_paper_faithful_disabled", False))
|
|
565
|
+
if strict or fail_if_faithful_disabled:
|
|
566
|
+
raise RuntimeError(msg)
|
|
567
|
+
_emit_streaming_warning(
|
|
568
|
+
code="shared_fast_state_batch",
|
|
569
|
+
message=msg,
|
|
570
|
+
details={"batch_size": batch_size},
|
|
571
|
+
)
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
def _validate_online_update_fast_state_semantics(cfg: DictConfig) -> None:
|
|
575
|
+
train_cfg = cfg.get("train")
|
|
576
|
+
if train_cfg is None:
|
|
577
|
+
return
|
|
578
|
+
online_updates = bool(train_cfg.get("online_updates", False))
|
|
579
|
+
use_fast_state = bool(train_cfg.get("use_fast_state", False))
|
|
580
|
+
if not online_updates or use_fast_state:
|
|
581
|
+
return
|
|
582
|
+
msg = (
|
|
583
|
+
"train.online_updates=true with train.use_fast_state=false applies online writes "
|
|
584
|
+
"directly to base parameters within each step. This can make gradients across chunks "
|
|
585
|
+
"harder to interpret. Use train.use_fast_state=true for paper-faithful runs."
|
|
586
|
+
)
|
|
587
|
+
strict = bool(train_cfg.get("strict_streaming_contract", False))
|
|
588
|
+
fail_if_faithful_disabled = bool(train_cfg.get("fail_if_paper_faithful_disabled", False))
|
|
589
|
+
if strict or fail_if_faithful_disabled:
|
|
590
|
+
raise RuntimeError(msg)
|
|
591
|
+
_emit_streaming_warning(
|
|
592
|
+
code="online_updates_without_fast_state",
|
|
593
|
+
message=msg,
|
|
594
|
+
details={"online_updates": True, "use_fast_state": False},
|
|
595
|
+
)
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
def _resolve_algorithm_mode(cfg: DictConfig) -> str:
|
|
599
|
+
mode = str(cfg.train.get("algorithm_mode", "two_pass_stopgrad_updates")).strip()
|
|
600
|
+
allowed = {"two_pass_stopgrad_updates", "boundary_state_grad_through_write"}
|
|
601
|
+
if mode not in allowed:
|
|
602
|
+
raise RuntimeError(f"Unsupported train.algorithm_mode={mode!r}; allowed={sorted(allowed)}")
|
|
603
|
+
return mode
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
def _validate_algorithm_mode_constraints(
|
|
607
|
+
cfg: DictConfig,
|
|
608
|
+
*,
|
|
609
|
+
algorithm_mode: str,
|
|
610
|
+
distributed: bool,
|
|
611
|
+
) -> None:
|
|
612
|
+
if algorithm_mode != "boundary_state_grad_through_write":
|
|
613
|
+
return
|
|
614
|
+
if distributed:
|
|
615
|
+
raise RuntimeError(
|
|
616
|
+
"train.algorithm_mode='boundary_state_grad_through_write' is not supported in DDP."
|
|
617
|
+
)
|
|
618
|
+
if not bool(cfg.train.get("online_updates", False)):
|
|
619
|
+
raise RuntimeError(
|
|
620
|
+
"train.algorithm_mode='boundary_state_grad_through_write' requires "
|
|
621
|
+
"train.online_updates=true."
|
|
622
|
+
)
|
|
623
|
+
if not bool(cfg.train.get("per_layer_teach_signal", False)):
|
|
624
|
+
raise RuntimeError(
|
|
625
|
+
"train.algorithm_mode='boundary_state_grad_through_write' requires "
|
|
626
|
+
"train.per_layer_teach_signal=true."
|
|
627
|
+
)
|
|
628
|
+
if not bool(cfg.train.get("use_fast_state", False)):
|
|
629
|
+
raise RuntimeError(
|
|
630
|
+
"train.algorithm_mode='boundary_state_grad_through_write' requires "
|
|
631
|
+
"train.use_fast_state=true."
|
|
632
|
+
)
|
|
633
|
+
if bool(cfg.train.get("online_carry_attention_cache", False)) and not bool(
|
|
634
|
+
cfg.train.get("online_boundary_targets", False)
|
|
635
|
+
):
|
|
636
|
+
raise RuntimeError(
|
|
637
|
+
"online_carry_attention_cache=true requires train.online_boundary_targets=true "
|
|
638
|
+
"(non-overlap chunking)."
|
|
639
|
+
)
|
|
640
|
+
_emit_streaming_warning(
|
|
641
|
+
code="experimental_boundary_state_mode",
|
|
642
|
+
message=(
|
|
643
|
+
"train.algorithm_mode='boundary_state_grad_through_write' is an experimental "
|
|
644
|
+
"single-process path for mechanism probing and may use more memory."
|
|
645
|
+
),
|
|
646
|
+
details={"algorithm_mode": algorithm_mode},
|
|
647
|
+
)
|
|
648
|
+
|
|
649
|
+
|
|
650
|
+
def _validate_online_chunking_constraints(cfg: DictConfig) -> None:
    """Check that attention-cache carrying is only used with compatible chunking.

    ``train.online_carry_attention_cache`` needs both ``train.online_updates``
    and ``train.online_boundary_targets`` (non-overlapping chunks).  Nothing is
    checked when cache carrying is off.

    Raises:
        RuntimeError: if cache carrying is requested without those flags.
    """
    flags = {
        key: bool(cfg.train.get(key, False))
        for key in (
            "online_updates",
            "online_boundary_targets",
            "online_carry_attention_cache",
        )
    }
    if not flags["online_carry_attention_cache"]:
        return
    if not flags["online_updates"]:
        raise RuntimeError("online_carry_attention_cache=true requires train.online_updates=true")
    if not flags["online_boundary_targets"]:
        raise RuntimeError(
            "online_carry_attention_cache=true requires train.online_boundary_targets=true "
            "(non-overlap chunking)."
        )
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
def _check_online_supervised_pairs(
    *,
    strict: bool,
    observed_pairs: int,
    seq_len: int,
) -> None:
    """Verify that online chunking supervised every next-token pair exactly once.

    A sequence of ``seq_len`` tokens has ``max(seq_len - 1, 0)`` (input, target)
    pairs.  If ``observed_pairs`` differs, the mismatch raises when ``strict`` is
    true and otherwise emits an ``online_supervision_mismatch`` warning.

    Raises:
        RuntimeError: on a mismatch when ``strict`` is true.
    """
    expected = max(int(seq_len) - 1, 0)
    if observed_pairs != expected:
        message = (
            "online chunk supervision mismatch: observed pair coverage does not match sequence length "
            f"(observed_pairs={observed_pairs}, expected_pairs={expected})"
        )
        if strict:
            raise RuntimeError(message)
        _emit_streaming_warning(
            code="online_supervision_mismatch",
            message=message,
            details={"observed_pairs": observed_pairs, "expected_pairs": expected},
        )
|
|
683
|
+
|
|
684
|
+
|
|
685
|
+
def run_training_loop(
    cfg: DictConfig,
    *,
    device: torch.device,
    distributed: bool = False,
    dist_ctx: DistributedContext | None = None,
) -> Dict[str, float]:
    """Train the model described by ``cfg`` for ``cfg.train.steps`` steps.

    Setup resolves and validates ``train.algorithm_mode`` and related config
    constraints.  It then builds the model (optionally ``torch.compile``-d and,
    when ``distributed``, wrapped in ``DistributedDataParallel``), the
    dataloader, the optimizer, an autocast factory and a logger (replaced by a
    ``NullLogger`` on non-zero ranks).

    Each step takes one batch and uses one of two paths:

    * Online path, when ``train.online_updates`` is set and the model has
      ``forward_with_block_outputs``: the sequence is split into chunks.  For
      each chunk the next-token loss is computed and ``loss.backward()`` is
      called.  A teach signal is derived (per layer when
      ``train.per_layer_teach_signal``), and the model is called a second time
      with that teach signal.  Gradients accumulate over all chunks; the
      optimizer steps once per batch.
    * Otherwise: one full-sequence forward/backward/optimizer step, followed by
      a single teach-signal model call under ``torch.no_grad()``.

    Under DDP, per-layer teach signals and online updates are turned off with a
    warning.  If ``train.strict_streaming_contract`` or
    ``train.fail_if_paper_faithful_disabled`` is set, they raise instead.

    Args:
        cfg: Full run config (``model``, ``data``, ``train``, ``optim`` and an
            optional ``logging`` section).
        device: Device the model and each batch are moved to.
        distributed: Whether to wrap the model in DDP; requires ``dist_ctx``.
        dist_ctx: Distributed context.  Its ``rank`` offsets the dataloader
            seed and decides which rank logs.

    Returns:
        The metrics from the most recent logging step (``loss``, ``ppl``,
        ``teach_signal_norm`` plus any ``pop_update_metrics()`` entries), or an
        empty dict if no step was logged.

    Raises:
        RuntimeError: for unsupported flag combinations (see the checks below).
        ValueError: if ``train.use_fast_state`` is set but the model has no
            callable ``init_fast_state``.
    """
    algorithm_mode = _resolve_algorithm_mode(cfg)
    _validate_algorithm_mode_constraints(
        cfg,
        algorithm_mode=algorithm_mode,
        distributed=distributed,
    )
    _validate_online_chunking_constraints(cfg)
    _validate_distributed_config(cfg, distributed)
    _validate_paper_auditing_variant(cfg)
    _validate_fast_state_batch_semantics(cfg)
    _validate_online_update_fast_state_semantics(cfg)
    model = build_model_from_cfg(cfg.model).to(device)
    train_seed = cfg.train.get("seed")
    deterministic = cfg.train.get("deterministic", False)
    if train_seed is not None:
        _seed_everything(int(train_seed), deterministic=bool(deterministic))
    model = _maybe_compile_model(model, cfg.train.get("compile"))
    if distributed:
        assert dist_ctx is not None
        if device.type == "cuda":
            # A CUDA device without an explicit index is treated as device 0.
            idx = device.index if device.index is not None else 0
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[idx],
                output_device=idx,
                find_unused_parameters=True,
            )
        else:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                find_unused_parameters=True,
            )
        # Methods beyond forward (teach signals, fast state, metrics) are
        # called on the unwrapped module.
        base_model = model.module
    else:
        base_model = model

    _validate_tied_lm_head_for_paper_auditing(cfg, base_model)

    # Offset the dataloader seed by rank so each rank gets a distinct stream.
    seed_offset = 0
    if train_seed is not None and dist_ctx is not None:
        seed_offset = dist_ctx.rank
    dataloader_seed = None if train_seed is None else int(train_seed) + seed_offset
    dataloader, sampler = build_dataloader(
        cfg.data,
        distributed=distributed,
        dist_ctx=dist_ctx,
        seed=dataloader_seed,
    )
    optimizer = _build_optimizer(base_model, cfg, device=device)
    autocast_factory = _make_autocast_factory(device, cfg.train.get("mixed_precision"))
    logger = init_logger(getattr(cfg, "logging", None), cfg)
    # Only rank 0 logs; every other rank gets a no-op logger.
    if distributed and dist_ctx is not None and dist_ctx.rank != 0:
        logger = NullLogger()
    _log_run_features(logger, base_model, cfg, optimizer, device)
    steps = cfg.train.steps
    log_interval = cfg.train.get("log_interval", 1)
    per_layer_teach = bool(cfg.train.get("per_layer_teach_signal", False))
    online_updates = bool(cfg.train.get("online_updates", False))
    online_chunk_size = int(cfg.train.get("online_chunk_size", 0) or 0)
    online_boundary_targets = bool(cfg.train.get("online_boundary_targets", False))
    online_carry_attention_cache = bool(cfg.train.get("online_carry_attention_cache", False))
    use_fast_state = bool(cfg.train.get("use_fast_state", False))
    fail_if_faithful_disabled = bool(cfg.train.get("fail_if_paper_faithful_disabled", False))
    strict_streaming = bool(cfg.train.get("strict_streaming_contract", False))
    # Under DDP the per-layer and online paths are disabled, because they call
    # base_model directly instead of going through the DDP wrapper.  In strict
    # modes this is an error; otherwise the run falls back with a warning.
    if distributed and per_layer_teach:
        msg = "per_layer_teach_signal disabled under DDP (uses base model methods)"
        if fail_if_faithful_disabled or strict_streaming:
            raise RuntimeError(
                f"{msg}. Set train.strict_streaming_contract=false and "
                "train.fail_if_paper_faithful_disabled=false to allow the fallback, "
                "or run single-process training."
            )
        _emit_streaming_warning(
            code="ddp_disables_per_layer_teach",
            message=msg,
            details={"distributed": True},
        )
        per_layer_teach = False
    if distributed and online_updates:
        msg = "online_updates disabled under DDP (uses base model methods)"
        if fail_if_faithful_disabled or strict_streaming:
            raise RuntimeError(
                f"{msg}. Set train.strict_streaming_contract=false and "
                "train.fail_if_paper_faithful_disabled=false to allow the fallback, "
                "or run single-process training."
            )
        _emit_streaming_warning(
            code="ddp_disables_online_updates",
            message=msg,
            details={"distributed": True},
        )
        online_updates = False
    # Boundary targets only make sense for chunked online updates.
    if online_boundary_targets and not online_updates:
        msg = "online_boundary_targets=true requires train.online_updates=true"
        if fail_if_faithful_disabled or strict_streaming:
            raise RuntimeError(msg)
        _emit_streaming_warning(
            code="boundary_targets_without_online_updates",
            message=msg,
        )
        online_boundary_targets = False
    # Re-checked on the effective flags: the fallbacks above may have switched
    # online_updates / online_boundary_targets off after config validation.
    if online_carry_attention_cache and not online_updates:
        raise RuntimeError("online_carry_attention_cache=true requires train.online_updates=true")
    if online_carry_attention_cache and not online_boundary_targets:
        raise RuntimeError(
            "online_carry_attention_cache=true requires train.online_boundary_targets=true "
            "(non-overlap chunking)."
        )
    step_iter = iter(dataloader)
    epoch = 0
    metrics: Dict[str, float] = {}
    # Prefer the model's own surprise metric; fall back to cfg.model.surprise_metric.
    surprise_metric_getter = getattr(base_model, "get_surprise_metric", None)
    surprise_metric = (
        str(surprise_metric_getter()).strip().lower()
        if callable(surprise_metric_getter)
        else str(cfg.model.get("surprise_metric", "l2")).strip().lower()
    )
    for step in range(steps):
        # Advance the sampler epoch once every len(dataloader) steps.
        if sampler is not None and step % len(dataloader) == 0:
            sampler.set_epoch(epoch)
            epoch += 1
        # Restart the iterator when the dataloader is exhausted.
        try:
            batch = next(step_iter)
        except StopIteration:
            step_iter = iter(dataloader)
            batch = next(step_iter)
        tokens = batch.to(device)
        # Fast state, when enabled, is re-initialized for every batch.
        fast_state = None
        if use_fast_state:
            init_fn = getattr(base_model, "init_fast_state", None)
            if not callable(init_fn):
                raise ValueError("train.use_fast_state=true requires model.init_fast_state()")
            fast_state = init_fn()
        _apply_teach_schedule(base_model, cfg, step)
        update_metrics: Dict[str, float] = {}
        if online_updates and hasattr(base_model, "forward_with_block_outputs"):
            # --- Online (chunked) path -------------------------------------
            total_loss = 0.0
            total_tokens = 0
            teach_signal_norm = 0.0
            # Zeroed once per batch: gradients from all chunks accumulate.
            optimizer.zero_grad()
            chunk_size = online_chunk_size
            if chunk_size <= 0:
                # No explicit size: ask the model, else use the whole sequence.
                inferred = _infer_online_chunk_size(base_model)
                chunk_size = inferred if inferred is not None else tokens.size(1)
            if chunk_size < 1:
                print(f"[train] online_chunk_size={chunk_size} is too small; clamping to 1")
                chunk_size = 1
            attention_cache = None
            if online_carry_attention_cache:
                init_attention_cache = getattr(base_model, "init_attention_cache", None)
                if not callable(init_attention_cache):
                    raise RuntimeError(
                        "online_carry_attention_cache=true requires model.init_attention_cache()"
                    )
                attention_cache = init_attention_cache()

            # Each item is (chunk_tokens, next_tokens or None, finalize_updates).
            # Only boundary chunking supplies next_tokens for the cross-chunk target.
            chunk_iter: Iterator[tuple[torch.Tensor, torch.Tensor | None, bool]]
            if online_boundary_targets:
                chunk_iter = _iter_online_boundary_chunks(tokens, chunk_size=chunk_size)
            else:
                chunk_iter = (
                    (chunk, None, finalize_updates)
                    for chunk, finalize_updates in _iter_online_token_chunks(
                        tokens, chunk_size=chunk_size
                    )
                )
            for chunk_tokens, next_tokens, finalize_updates in chunk_iter:
                # Number of next-token targets in this chunk (per sequence); a
                # boundary chunk also predicts the first token of the next chunk.
                target_count = chunk_tokens.size(1) - 1 + (0 if next_tokens is None else 1)
                if target_count <= 0:
                    continue
                # Cache as it was *before* this chunk; the update call below
                # reuses it so it sees the same context as the loss forward.
                chunk_attention_cache = attention_cache
                with autocast_factory():
                    if attention_cache is not None:
                        logits, _pre, block_outputs, attention_cache = (
                            base_model.forward_with_block_outputs(
                                chunk_tokens,
                                fast_state=fast_state,
                                attention_cache=chunk_attention_cache,
                                return_attention_cache=True,
                            )
                        )
                    else:
                        logits, _pre, block_outputs = (
                            base_model.forward_with_block_outputs(
                                chunk_tokens,
                                fast_state=fast_state,
                            )
                            if fast_state is not None
                            else base_model.forward_with_block_outputs(chunk_tokens)
                        )
                    if next_tokens is None:
                        loss = torch.nn.functional.cross_entropy(
                            logits[:, :-1].reshape(-1, logits.size(-1)),
                            chunk_tokens[:, 1:].reshape(-1),
                        )
                    else:
                        # Shift-by-one targets plus the next chunk's first token,
                        # so the chunk's last position is supervised too.
                        boundary_targets = torch.cat(
                            [chunk_tokens[:, 1:], next_tokens.unsqueeze(1)],
                            dim=1,
                        )
                        loss = torch.nn.functional.cross_entropy(
                            logits[:, : boundary_targets.size(1), :].reshape(-1, logits.size(-1)),
                            boundary_targets.reshape(-1),
                        )
                surprise_override = _compute_surprise_override(
                    surprise_metric,
                    logits=logits,
                    tokens=chunk_tokens,
                    loss=loss,
                    next_tokens=next_tokens,
                )
                if per_layer_teach:
                    # In boundary-state mode the teach signals keep their graph
                    # (create_graph) instead of being detached.
                    differentiable_updates = algorithm_mode == "boundary_state_grad_through_write"
                    teach_signals = _compute_layer_teach_signals(
                        loss,
                        block_outputs,
                        detach=not differentiable_updates,
                        create_graph=differentiable_updates,
                    )
                    mean_teach_norm = torch.stack(
                        [sig.detach().norm(dim=-1).mean() for sig in teach_signals]
                    ).mean()
                    # Norms are weighted by target_count and averaged after the loop.
                    teach_signal_norm += float(
                        mean_teach_norm
                    ) * target_count
                else:
                    teach_signal = compute_teach_signal(
                        base_model,
                        logits,
                        chunk_tokens,
                        next_tokens=next_tokens,
                    )
                    teach_signal_norm += teach_signal.norm(dim=-1).mean().item() * target_count
                differentiable_updates = algorithm_mode == "boundary_state_grad_through_write"
                # Boundary-state mode keeps a cross-chunk differentiable write path.
                # Retain the graph so later chunks can backprop through earlier writes.
                loss.backward(retain_graph=differentiable_updates)
                # Update call: run the model again on the same chunk with the
                # teach signal(s).  Only boundary-state mode runs it with autograd
                # enabled and differentiable_updates=True.
                if differentiable_updates:
                    if per_layer_teach:
                        base_model(
                            chunk_tokens,
                            teach_signals=teach_signals,
                            surprise_value=surprise_override,
                            fast_state=fast_state,
                            finalize_updates=finalize_updates,
                            attention_cache=chunk_attention_cache,
                            differentiable_updates=True,
                        )
                    else:
                        base_model(
                            chunk_tokens,
                            teach_signal=teach_signal,
                            surprise_value=surprise_override,
                            fast_state=fast_state,
                            finalize_updates=finalize_updates,
                            attention_cache=chunk_attention_cache,
                            differentiable_updates=True,
                        )
                    if hasattr(base_model, "pop_update_metrics"):
                        update_metrics = base_model.pop_update_metrics()
                else:
                    with torch.no_grad():
                        if per_layer_teach:
                            base_model(
                                chunk_tokens,
                                teach_signals=teach_signals,
                                surprise_value=surprise_override,
                                fast_state=fast_state,
                                finalize_updates=finalize_updates,
                                attention_cache=chunk_attention_cache,
                                differentiable_updates=False,
                            )
                        else:
                            base_model(
                                chunk_tokens,
                                teach_signal=teach_signal,
                                surprise_value=surprise_override,
                                fast_state=fast_state,
                                finalize_updates=finalize_updates,
                                attention_cache=chunk_attention_cache,
                                differentiable_updates=False,
                            )
                    # Overwritten per chunk: only the last chunk's metrics are logged.
                    if hasattr(base_model, "pop_update_metrics"):
                        update_metrics = base_model.pop_update_metrics()
                total_loss += loss.item() * target_count
                total_tokens += target_count
            # The chunks together should cover every next-token pair of the sequence.
            _check_online_supervised_pairs(
                strict=strict_streaming,
                observed_pairs=total_tokens,
                seq_len=int(tokens.size(1)),
            )
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), max_norm=1.0)
            optimizer.step()
            # For logging: target-weighted mean loss / teach norm across chunks.
            loss = torch.tensor(total_loss / max(total_tokens, 1), device=device)
            teach_signal_norm = teach_signal_norm / max(total_tokens, 1)
        else:
            # --- Full-sequence path ----------------------------------------
            with autocast_factory():
                if per_layer_teach and hasattr(base_model, "forward_with_block_outputs"):
                    logits, _pre, block_outputs = (
                        base_model.forward_with_block_outputs(tokens, fast_state=fast_state)
                        if fast_state is not None
                        else base_model.forward_with_block_outputs(tokens)
                    )
                    loss = torch.nn.functional.cross_entropy(
                        logits[:, :-1].reshape(-1, logits.size(-1)),
                        tokens[:, 1:].reshape(-1),
                    )
                else:
                    # Goes through `model` (possibly the DDP wrapper).
                    if fast_state is not None:
                        logits = model(tokens, fast_state=fast_state)
                    else:
                        logits = model(tokens)
                    loss = torch.nn.functional.cross_entropy(
                        logits[:, :-1].reshape(-1, logits.size(-1)),
                        tokens[:, 1:].reshape(-1),
                    )
            surprise_override = _compute_surprise_override(
                surprise_metric,
                logits=logits,
                tokens=tokens,
                loss=loss,
                next_tokens=None,
            )
            optimizer.zero_grad()
            # Teach signals are taken from the loss before loss.backward() runs.
            # NOTE(review): presumably because the helper differentiates through
            # the same graph; confirm against _compute_layer_teach_signals.
            if per_layer_teach and hasattr(base_model, "forward_with_block_outputs"):
                teach_signals = _compute_layer_teach_signals(loss, block_outputs)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), max_norm=1.0)
            optimizer.step()
            # Teach-signal update call, after the optimizer step and without autograd.
            with torch.no_grad():
                if per_layer_teach and hasattr(base_model, "forward_with_block_outputs"):
                    teach_signal_norm = float(
                        torch.stack([sig.norm(dim=-1).mean() for sig in teach_signals]).mean()
                    )
                    base_model(
                        tokens,
                        teach_signals=teach_signals,
                        surprise_value=surprise_override,
                        fast_state=fast_state,
                    )
                else:
                    teach_signal = compute_teach_signal(base_model, logits, tokens)
                    teach_signal_norm = teach_signal.norm(dim=-1).mean().item()
                    base_model(
                        tokens,
                        teach_signal=teach_signal,
                        surprise_value=surprise_override,
                        fast_state=fast_state,
                    )
            if hasattr(base_model, "pop_update_metrics"):
                update_metrics = base_model.pop_update_metrics()
        if step % log_interval == 0:
            ppl = torch.exp(loss.detach()).item()
            metrics_payload = {
                "loss": loss.item(),
                "ppl": ppl,
                "teach_signal_norm": teach_signal_norm,
            }
            metrics_payload.update(update_metrics)
            logger.log(metrics_payload, step=step)
            if (not distributed) or (dist_ctx and dist_ctx.rank == 0):
                print(
                    f"[train] step={step} loss={loss.item():.4f} "
                    f"ppl={ppl:.2f} teach_norm={teach_signal_norm:.4f}"
                )
            # Returned value is always the most recently logged payload.
            metrics = metrics_payload
        maybe_save_checkpoint(
            cfg,
            base_model,
            optimizer,
            step=step,
            total_steps=steps,
            distributed=distributed,
            dist_ctx=dist_ctx,
            step_offset=int(cfg.train.get("step_offset", 0) or 0),
        )
    logger.finish()
    return metrics
|
|
1070
|
+
|
|
1071
|
+
|
|
1072
|
+
def _apply_teach_schedule(model: HOPEModel, cfg: DictConfig, step: int) -> None:
    """Push the teach-signal scale for ``step`` into the model.

    The scale starts from ``cfg.model.teach_scale`` (default 1.0).  When
    ``cfg.model.teach_schedule`` is set it is further multiplied by:

    * a linear warmup factor ``min(1, (step + 1) / warmup_steps)`` when
      ``warmup_steps`` is positive, and
    * a linear decay factor that falls from 1 to 0 over ``decay_duration`` steps
      once ``step + 1`` exceeds ``decay_start`` (both must be set, and the
      duration must be positive).
    """
    scale = cfg.model.get("teach_scale", 1.0)
    schedule = cfg.model.get("teach_schedule")
    if schedule:
        one_based_step = step + 1

        warmup_steps = schedule.get("warmup_steps", 0)
        if warmup_steps and warmup_steps > 0:
            scale *= min(1.0, one_based_step / warmup_steps)

        start = schedule.get("decay_start")
        span = schedule.get("decay_duration")
        in_decay = start is not None and bool(span) and span > 0 and one_based_step > start
        if in_decay:
            fraction_done = min(1.0, (one_based_step - start) / span)
            scale *= max(0.0, 1.0 - fraction_done)

    model.set_teach_runtime(scale=scale)
|
|
1091
|
+
|
|
1092
|
+
|
|
1093
|
+
def _maybe_compile_model(model: torch.nn.Module, compile_cfg: dict | None) -> torch.nn.Module:
    """Return ``torch.compile(model)`` when ``compile_cfg.enable`` is set, else ``model``.

    Optional ``mode`` and ``backend`` entries are forwarded to ``torch.compile``.
    If ``torch.compile`` raises, the error is re-raised when ``compile_cfg.strict``
    is true; otherwise a message is printed and the eager ``model`` is returned.
    """
    if not (compile_cfg and compile_cfg.get("enable", False)):
        return model
    compile_kwargs = {key: compile_cfg[key] for key in ("mode", "backend") if key in compile_cfg}
    try:
        compiled = torch.compile(model, **compile_kwargs)  # type: ignore[attr-defined]
    except Exception as err:  # pragma: no cover - compile is optional
        if compile_cfg.get("strict", False):
            raise
        print(f"[compile] fallback to eager due to: {err}")
        return model
    return cast(torch.nn.Module, compiled)
|
|
1108
|
+
|
|
1109
|
+
|
|
1110
|
+
def _make_autocast_factory(device: torch.device, mp_cfg: dict | None):
    """Return a zero-argument callable that produces the forward-pass context.

    If ``mp_cfg`` is missing or ``mp_cfg["enabled"]`` is falsy, the callable
    returns a ``nullcontext``.  Otherwise it returns ``torch.autocast`` for the
    device type, where any type other than cuda/cpu/mps is treated as cpu.  The
    autocast dtype is named by ``mp_cfg["dtype"]`` (default ``"bf16"``).  If
    creating the autocast context raises, a message is printed and a
    ``nullcontext`` is returned instead.
    """
    if not (mp_cfg and mp_cfg.get("enabled", False)):
        def disabled_factory():
            return nullcontext()

        return disabled_factory

    autocast_dtype = _resolve_autocast_dtype(mp_cfg.get("dtype", "bf16"))
    autocast_device = device.type if device.type in {"cuda", "cpu", "mps"} else "cpu"

    def factory():
        try:
            return torch.autocast(device_type=autocast_device, dtype=autocast_dtype)
        except Exception as err:  # pragma: no cover - device/dtype support varies by backend
            print(
                f"[autocast] disabled for device_type={autocast_device} dtype={autocast_dtype}: {err}"
            )
            return nullcontext()

    return factory
|
|
1126
|
+
|
|
1127
|
+
|
|
1128
|
+
def _resolve_autocast_dtype(name: str | torch.dtype) -> torch.dtype:
    """Map a mixed-precision dtype name to the corresponding ``torch.dtype``.

    Accepted (case-insensitive, surrounding whitespace ignored):
    ``bf16``/``bfloat16`` -> ``torch.bfloat16`` and
    ``fp16``/``float16``/``half`` -> ``torch.float16``.  A leading ``torch.``
    prefix is also accepted, so ``torch.bfloat16`` / ``torch.float16`` (the
    dtype objects or their string forms) resolve as well.

    Raises:
        ValueError: for any other name.
    """
    # Normalize like other config strings (strip + lower); str() also turns a
    # torch.dtype into e.g. "torch.bfloat16".
    normalized = str(name).strip().lower()
    if normalized.startswith("torch."):
        normalized = normalized[len("torch."):]
    if normalized in {"bf16", "bfloat16"}:
        return torch.bfloat16
    if normalized in {"fp16", "float16", "half"}:
        return torch.float16
    msg = f"Unsupported autocast dtype {name}"
    raise ValueError(msg)
|
|
1136
|
+
|
|
1137
|
+
|
|
1138
|
+
def _build_optimizer(
    model: torch.nn.Module, cfg: DictConfig, *, device: torch.device
) -> torch.optim.Optimizer:
    """Build the outer-loop optimizer described by ``cfg.optim``.

    Parameter selection: ``optim.param_policy`` chooses which trainable
    parameters are optimized.  When it is unset, the legacy flag
    ``optim.outer_updates_memory_modules`` maps to ``"all"`` (unset or true) or
    ``"exclude_memory"`` (false).

    Optimizer type: ``optim.type`` selects ``"muon"`` or ``"m3"`` (each paired
    with AdamW for the remaining parameters) or, by default, ``torch.optim.AdamW``.
    ``optim.fused`` defaults to ``"auto"``, which enables fused AdamW only when
    ``device`` is CUDA and CUDA is available.

    Raises:
        ValueError: if the policy selects no trainable parameters, or (from
            ``_select_outer_named_parameters``) if the policy is unknown.
    """
    optimizer_cfg_raw = cfg.get("optim")
    if isinstance(optimizer_cfg_raw, DictConfig):
        optimizer_cfg = optimizer_cfg_raw
    else:
        optimizer_cfg = cast(DictConfig, OmegaConf.create(optimizer_cfg_raw or {}))
    param_policy_raw = optimizer_cfg.get("param_policy")
    if param_policy_raw is None:
        outer_updates_memory_modules = optimizer_cfg.get("outer_updates_memory_modules")
        if outer_updates_memory_modules is None:
            param_policy = "all"
        else:
            param_policy = "all" if bool(outer_updates_memory_modules) else "exclude_memory"
    else:
        param_policy = str(param_policy_raw).strip().lower()
    named_params = _select_outer_named_parameters(model, param_policy)
    if not named_params:
        raise ValueError(
            f"No trainable parameters selected for optim.param_policy={param_policy!r}. "
            "Check freeze_backbone, requires_grad flags, or adjust the policy."
        )
    optim_type = str(optimizer_cfg.get("type", "adamw")).lower()
    if optim_type == "muon":
        return _build_muon_optimizer(
            model,
            optimizer_cfg,
            device=device,
            named_params=named_params,
            param_policy=param_policy,
        )
    if optim_type == "m3":
        return _build_m3_optimizer(
            model,
            optimizer_cfg,
            device=device,
            named_params=named_params,
            param_policy=param_policy,
        )
    lr = optimizer_cfg.get("lr", 1e-3)
    # A YAML list comes back as an OmegaConf ListConfig.  Left as-is it would be
    # stored in param_groups and pickled into optimizer checkpoints, which then
    # need omegaconf to unpickle (and fail under torch.load(weights_only=True)).
    betas = tuple(float(beta) for beta in optimizer_cfg.get("betas", (0.9, 0.999)))
    weight_decay = optimizer_cfg.get("weight_decay", 0.0)
    fused_cfg = optimizer_cfg.get("fused", "auto")
    if fused_cfg == "auto":
        fused = device.type == "cuda" and torch.cuda.is_available()
    else:
        fused = bool(fused_cfg)
    kwargs = {"lr": lr, "betas": betas, "weight_decay": weight_decay}
    if fused:
        kwargs["fused"] = True
    params = [param for _, param in named_params]
    return torch.optim.AdamW(params, **kwargs)
|
|
1192
|
+
|
|
1193
|
+
|
|
1194
|
+
def _build_muon_optimizer(
    model: torch.nn.Module,
    optimizer_cfg: DictConfig,
    *,
    device: torch.device,
    named_params: list[tuple[str, torch.nn.Parameter]] | None = None,
    param_policy: str | None = None,
):
    """Build a Muon + AdamW hybrid optimizer.

    Trainable, exactly 2-D parameters accepted by ``_is_muon_candidate`` go to
    ``torch.optim.Muon``.  Every other trainable parameter goes to AdamW (fused
    per ``optim.fused``, ``"auto"`` = CUDA only).  Both use the same ``lr`` and
    ``weight_decay``.  ``named_params`` restricts the candidates; when it is
    ``None``, all of ``model.named_parameters()`` are considered.

    Raises:
        RuntimeError: if this PyTorch build has no ``torch.optim.Muon``.
    """
    if not hasattr(torch.optim, "Muon"):
        raise RuntimeError("torch.optim.Muon is not available in this PyTorch build")
    lr = optimizer_cfg.get("lr", 1e-3)
    weight_decay = optimizer_cfg.get("weight_decay", 0.01)
    momentum = optimizer_cfg.get("momentum", 0.95)
    ns_coefficients = optimizer_cfg.get("ns_coefficients")
    ns_steps = optimizer_cfg.get("ns_steps")
    eps = optimizer_cfg.get("eps", 1e-7)
    fused_cfg = optimizer_cfg.get("fused", "auto")
    if fused_cfg == "auto":
        fused = device.type == "cuda" and torch.cuda.is_available()
    else:
        fused = bool(fused_cfg)
    muon_params: list[torch.nn.Parameter] = []
    adamw_params: list[torch.nn.Parameter] = []
    source = named_params if named_params is not None else model.named_parameters()
    for name, param in source:
        if not param.requires_grad:
            continue
        # torch.optim.Muon only supports 2-D matrices; send higher-rank tensors
        # (e.g. conv kernels) to AdamW instead of letting Muon reject them.
        if param.ndim == 2 and _is_muon_candidate(name, param):
            muon_params.append(param)
        else:
            adamw_params.append(param)
    muon_kwargs = {
        "lr": lr,
        "weight_decay": weight_decay,
        "momentum": momentum,
        "eps": eps,
    }
    if ns_coefficients is not None:
        muon_kwargs["ns_coefficients"] = tuple(ns_coefficients)
    if ns_steps is not None:
        muon_kwargs["ns_steps"] = int(ns_steps)
    muon_opt = torch.optim.Muon(muon_params, **muon_kwargs) if muon_params else None  # type: ignore[attr-defined]
    adamw_kwargs = {
        "lr": lr,
        # Plain tuple: an OmegaConf ListConfig would leak into checkpoints.
        "betas": tuple(float(beta) for beta in optimizer_cfg.get("betas", (0.9, 0.999))),
        "weight_decay": weight_decay,
    }
    if fused:
        adamw_kwargs["fused"] = True
    adamw_opt = torch.optim.AdamW(adamw_params, **adamw_kwargs) if adamw_params else None
    muon_elems = int(sum(p.numel() for p in muon_params))
    adamw_elems = int(sum(p.numel() for p in adamw_params))
    return _HybridOptimizer(
        muon_opt,
        adamw_opt,
        muon_elems,
        adamw_elems,
        primary_name="muon",
        param_policy=param_policy,
    )
|
|
1255
|
+
|
|
1256
|
+
|
|
1257
|
+
def _build_m3_optimizer(
    model: torch.nn.Module,
    optimizer_cfg: DictConfig,
    *,
    device: torch.device,
    named_params: list[tuple[str, torch.nn.Parameter]] | None = None,
    param_policy: str | None = None,
):
    """Build an M3 + AdamW hybrid optimizer.

    Trainable parameters accepted by ``_is_muon_candidate`` go to ``M3``
    (configured by ``lr``, ``beta1``-``beta3``, ``alpha``, ``eps``, ``ns_steps``,
    ``slow_chunk`` and ``weight_decay``).  The rest go to AdamW (fused per
    ``optim.fused``, ``"auto"`` = CUDA only).  ``named_params`` restricts the
    candidates; when it is ``None``, all of ``model.named_parameters()`` are
    considered.
    """
    lr = optimizer_cfg.get("lr", 1e-3)
    weight_decay = optimizer_cfg.get("weight_decay", 0.01)
    beta1 = optimizer_cfg.get("beta1", 0.9)
    beta2 = optimizer_cfg.get("beta2", 0.999)
    beta3 = optimizer_cfg.get("beta3", 0.9)
    alpha = optimizer_cfg.get("alpha", 1.0)
    ns_steps = int(optimizer_cfg.get("ns_steps", 3))
    slow_chunk = int(optimizer_cfg.get("slow_chunk", 100))
    eps = optimizer_cfg.get("eps", 1e-8)
    fused_cfg = optimizer_cfg.get("fused", "auto")
    if fused_cfg == "auto":
        fused = device.type == "cuda" and torch.cuda.is_available()
    else:
        fused = bool(fused_cfg)

    m3_params: list[torch.nn.Parameter] = []
    adamw_params: list[torch.nn.Parameter] = []
    source = named_params if named_params is not None else model.named_parameters()
    for name, param in source:
        if not param.requires_grad:
            continue
        if _is_muon_candidate(name, param):
            m3_params.append(param)
        else:
            adamw_params.append(param)
    m3_opt = (
        M3(
            m3_params,
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            beta3=beta3,
            alpha=alpha,
            eps=eps,
            ns_steps=ns_steps,
            slow_chunk=slow_chunk,
            weight_decay=weight_decay,
        )
        if m3_params
        else None
    )
    adamw_kwargs = {
        "lr": lr,
        # Plain tuple: an OmegaConf ListConfig would leak into checkpoints.
        "betas": tuple(float(beta) for beta in optimizer_cfg.get("betas", (0.9, 0.999))),
        "weight_decay": weight_decay,
    }
    if fused:
        adamw_kwargs["fused"] = True
    adamw_opt = torch.optim.AdamW(adamw_params, **adamw_kwargs) if adamw_params else None
    m3_elems = int(sum(p.numel() for p in m3_params))
    adamw_elems = int(sum(p.numel() for p in adamw_params))
    return _HybridOptimizer(
        m3_opt,
        adamw_opt,
        m3_elems,
        adamw_elems,
        primary_name="m3",
        param_policy=param_policy,
    )
|
|
1325
|
+
|
|
1326
|
+
|
|
1327
|
+
def _select_outer_named_parameters(
    model: torch.nn.Module, param_policy: str
) -> list[tuple[str, torch.nn.Parameter]]:
    """Return the ``(name, param)`` pairs the outer optimizer should own.

    Only parameters with ``requires_grad`` are considered.  ``param_policy`` is
    matched case-insensitively after stripping whitespace:

    * ``all`` / ``full``: every trainable parameter;
    * ``exclude_memory`` / ``no_memory``: trainable, non-memory parameters;
    * ``only_memory`` / ``memory_only``: trainable memory parameters only.

    Raises:
        ValueError: for any other policy.
    """
    normalized = str(param_policy).strip().lower()
    candidates = [(pname, p) for pname, p in model.named_parameters() if p.requires_grad]

    if normalized in {"all", "full"}:
        return candidates
    if normalized in {"exclude_memory", "no_memory"}:
        want_memory = False
    elif normalized in {"only_memory", "memory_only"}:
        want_memory = True
    else:
        raise ValueError(
            f"Unsupported optim.param_policy={param_policy!r}. "
            "Expected one of ['all', 'exclude_memory', 'only_memory']."
        )
    return [(pname, p) for pname, p in candidates if _is_memory_param_name(pname) == want_memory]
|
|
1344
|
+
|
|
1345
|
+
|
|
1346
|
+
def _is_memory_param_name(name: str) -> bool:
    """Return True if a dotted parameter name lies inside a memory submodule.

    Matching is case-insensitive and requires a ``.cms.``, ``.titan_memory.``
    or ``.selfmod.`` path segment (with dots on both sides).
    """
    lowered_name = name.lower()
    for marker in (".cms.", ".titan_memory.", ".selfmod."):
        if marker in lowered_name:
            return True
    return False
|
|
1349
|
+
|
|
1350
|
+
|
|
1351
|
+
def _is_muon_candidate(name: str, param: torch.nn.Parameter) -> bool:
    """Return True for matrix-shaped parameters that are not norms or embeddings.

    A parameter qualifies when it has at least two dimensions and its name
    (case-insensitive) contains neither ``"norm"`` nor ``"embed"``.
    """
    excluded_tags = ("norm", "embed")
    return param.ndim >= 2 and not any(tag in name.lower() for tag in excluded_tags)
|
|
1358
|
+
|
|
1359
|
+
|
|
1360
|
+
class _HybridOptimizer:
    """Drive a primary optimizer and an AdamW optimizer as a single unit.

    Either optimizer may be ``None`` when its parameter partition is empty.
    ``state_dict`` is keyed by ``primary_name`` (e.g. ``"muon"``/``"m3"``) and
    ``"adamw"``, so a checkpoint records which optimizer produced each half.
    """

    def __init__(
        self,
        primary_opt: torch.optim.Optimizer | None,
        secondary_opt: torch.optim.Optimizer | None,
        primary_param_elems: int,
        secondary_param_elems: int,
        *,
        primary_name: str = "muon",
        param_policy: str | None = None,
    ):
        self.primary_opt = primary_opt
        self.secondary_opt = secondary_opt
        # Element counts per partition, reported by get_param_split().
        self.primary_param_elems = primary_param_elems
        self.secondary_param_elems = secondary_param_elems
        self.primary_name = primary_name
        self.param_policy = param_policy

    def _active(self) -> list[torch.optim.Optimizer]:
        """Return the optimizers that were actually built, primary first."""
        return [opt for opt in (self.primary_opt, self.secondary_opt) if opt is not None]

    def zero_grad(self, set_to_none: bool = True) -> None:
        """Clear gradients in both optimizers.

        With the default ``set_to_none=True`` each optimizer's ``zero_grad()``
        is called without arguments, i.e. with its own default.  Passing
        ``False`` forwards ``set_to_none=False``.
        """
        for opt in self._active():
            if set_to_none:
                # No-arg call keeps optimizers whose zero_grad() lacks the
                # set_to_none parameter working.
                opt.zero_grad()
            else:
                opt.zero_grad(set_to_none=False)

    def step(self, closure=None):
        """Step both optimizers; return the closure's loss if one is given.

        The closure (if any) is evaluated once, not once per optimizer, so the
        model is not re-run twice per step.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for opt in self._active():
            opt.step()
        return loss

    def state_dict(self) -> dict:
        """Return ``{primary_name: ..., "adamw": ...}``; missing halves are ``None``."""
        return {
            self.primary_name: self.primary_opt.state_dict() if self.primary_opt is not None else None,
            "adamw": self.secondary_opt.state_dict() if self.secondary_opt is not None else None,
        }

    def load_state_dict(self, state: dict) -> None:
        """Restore each half whose optimizer exists and whose entry is present."""
        if self.primary_opt is not None and state.get(self.primary_name) is not None:
            self.primary_opt.load_state_dict(state[self.primary_name])
        if self.secondary_opt is not None and state.get("adamw") is not None:
            self.secondary_opt.load_state_dict(state["adamw"])

    @property
    def param_groups(self):
        """Concatenated param groups, primary optimizer's first."""
        return [group for opt in self._active() for group in opt.param_groups]

    def get_param_split(self) -> dict[str, int]:
        """Return the number of parameter elements owned by each optimizer."""
        return {
            self.primary_name: self.primary_param_elems,
            "adamw": self.secondary_param_elems,
        }
|
|
1416
|
+
|
|
1417
|
+
|
|
1418
|
+
def _log_run_features(
|
|
1419
|
+
logger: BaseLogger,
|
|
1420
|
+
model: torch.nn.Module,
|
|
1421
|
+
cfg: DictConfig,
|
|
1422
|
+
optimizer: torch.optim.Optimizer,
|
|
1423
|
+
device: torch.device,
|
|
1424
|
+
) -> None:
|
|
1425
|
+
mp_cfg = cfg.train.get("mixed_precision", {})
|
|
1426
|
+
compile_cfg = cfg.train.get("compile", {})
|
|
1427
|
+
algorithm_mode = str(cfg.train.get("algorithm_mode", "two_pass_stopgrad_updates"))
|
|
1428
|
+
features: dict[str, object] = {
|
|
1429
|
+
"train.mixed_precision_enabled": bool(mp_cfg.get("enabled", False)),
|
|
1430
|
+
"train.mixed_precision_dtype": str(mp_cfg.get("dtype", "bf16")),
|
|
1431
|
+
"train.compile_enabled": bool(compile_cfg.get("enable", False)),
|
|
1432
|
+
"train.compile_mode": str(compile_cfg.get("mode", "default")) if compile_cfg else "default",
|
|
1433
|
+
"train.strict_streaming_contract": bool(cfg.train.get("strict_streaming_contract", False)),
|
|
1434
|
+
"train.online_updates": bool(cfg.train.get("online_updates", False)),
|
|
1435
|
+
"train.online_boundary_targets": bool(cfg.train.get("online_boundary_targets", False)),
|
|
1436
|
+
"train.online_carry_attention_cache": bool(
|
|
1437
|
+
cfg.train.get("online_carry_attention_cache", False)
|
|
1438
|
+
),
|
|
1439
|
+
"train.use_fast_state": bool(cfg.train.get("use_fast_state", False)),
|
|
1440
|
+
"train.algorithm_mode": algorithm_mode,
|
|
1441
|
+
"train.backprop_through_online_writes": algorithm_mode
|
|
1442
|
+
== "boundary_state_grad_through_write",
|
|
1443
|
+
"attention.flash_enabled": _detect_flash_attention(model),
|
|
1444
|
+
"device": device.type,
|
|
1445
|
+
}
|
|
1446
|
+
optimizer_cfg_raw = cfg.get("optim")
|
|
1447
|
+
if isinstance(optimizer_cfg_raw, DictConfig):
|
|
1448
|
+
optimizer_cfg = optimizer_cfg_raw
|
|
1449
|
+
else:
|
|
1450
|
+
optimizer_cfg = cast(DictConfig, OmegaConf.create(optimizer_cfg_raw or {}))
|
|
1451
|
+
param_policy_raw = optimizer_cfg.get("param_policy")
|
|
1452
|
+
if param_policy_raw is None:
|
|
1453
|
+
outer_updates_memory_modules = optimizer_cfg.get("outer_updates_memory_modules")
|
|
1454
|
+
if outer_updates_memory_modules is None:
|
|
1455
|
+
param_policy = "all"
|
|
1456
|
+
else:
|
|
1457
|
+
param_policy = "all" if bool(outer_updates_memory_modules) else "exclude_memory"
|
|
1458
|
+
else:
|
|
1459
|
+
param_policy = str(param_policy_raw).strip().lower()
|
|
1460
|
+
try:
|
|
1461
|
+
selected = _select_outer_named_parameters(model, param_policy)
|
|
1462
|
+
total_elems = int(sum(param.numel() for _, param in selected))
|
|
1463
|
+
memory_elems = int(
|
|
1464
|
+
sum(param.numel() for name, param in selected if _is_memory_param_name(name))
|
|
1465
|
+
)
|
|
1466
|
+
features["optim.param_policy"] = param_policy
|
|
1467
|
+
features["optim.param_policy_param_elems"] = total_elems
|
|
1468
|
+
features["optim.param_policy_memory_param_elems"] = memory_elems
|
|
1469
|
+
features["optim.param_policy_non_memory_param_elems"] = total_elems - memory_elems
|
|
1470
|
+
except Exception as err: # pragma: no cover - purely diagnostic
|
|
1471
|
+
features["optim.param_policy"] = param_policy
|
|
1472
|
+
features["optim.param_policy_error"] = str(err)
|
|
1473
|
+
split_fn = getattr(optimizer, "get_param_split", None)
|
|
1474
|
+
if callable(split_fn):
|
|
1475
|
+
split = split_fn()
|
|
1476
|
+
for key, value in split.items():
|
|
1477
|
+
features[f"optim.{key}_param_elems"] = int(value)
|
|
1478
|
+
logger.log(features, step=-1)
|
|
1479
|
+
print(f"[train] run_features {features}")
|
|
1480
|
+
|
|
1481
|
+
|
|
1482
|
+
def _detect_flash_attention(model: torch.nn.Module) -> bool:
|
|
1483
|
+
blocks = getattr(model, "blocks", [])
|
|
1484
|
+
for block in blocks:
|
|
1485
|
+
attn = getattr(block, "attn", None)
|
|
1486
|
+
config = getattr(attn, "config", None)
|
|
1487
|
+
if config is not None and hasattr(config, "use_flash"):
|
|
1488
|
+
return bool(config.use_flash)
|
|
1489
|
+
return False
|
|
1490
|
+
|
|
1491
|
+
|
|
1492
|
+
def write_checkpoint_metadata(cfg: DictConfig, ckpt_path: Path, step: int) -> None:
|
|
1493
|
+
config_yaml = OmegaConf.to_yaml(cfg)
|
|
1494
|
+
config_path = ckpt_path.with_suffix(".yaml")
|
|
1495
|
+
config_path.write_text(config_yaml)
|
|
1496
|
+
config_hash = sha256(config_yaml.encode("utf-8")).hexdigest()
|
|
1497
|
+
ckpt_hash = _checksum_path(str(ckpt_path))
|
|
1498
|
+
sha_path = ckpt_path.with_suffix(".sha256")
|
|
1499
|
+
if ckpt_hash:
|
|
1500
|
+
sha_path.write_text(f"{ckpt_hash} {ckpt_path.name}\n")
|
|
1501
|
+
tokenizer_path = cfg.data.get("tokenizer_path") if hasattr(cfg, "data") else None
|
|
1502
|
+
metadata = {
|
|
1503
|
+
"step": step,
|
|
1504
|
+
"checkpoint_sha256": ckpt_hash,
|
|
1505
|
+
"config_sha256": config_hash,
|
|
1506
|
+
"tokenizer_hash": _checksum_path(tokenizer_path) if tokenizer_path else None,
|
|
1507
|
+
"config_path": str(config_path),
|
|
1508
|
+
"algorithm_mode": str(cfg.train.get("algorithm_mode", "two_pass_stopgrad_updates")),
|
|
1509
|
+
"online_updates": bool(cfg.train.get("online_updates", False)),
|
|
1510
|
+
"online_boundary_targets": bool(cfg.train.get("online_boundary_targets", False)),
|
|
1511
|
+
"online_carry_attention_cache": bool(
|
|
1512
|
+
cfg.train.get("online_carry_attention_cache", False)
|
|
1513
|
+
),
|
|
1514
|
+
"use_fast_state": bool(cfg.train.get("use_fast_state", False)),
|
|
1515
|
+
"rng_states": _capture_rng_states(),
|
|
1516
|
+
}
|
|
1517
|
+
ckpt_path.with_suffix(".meta.json").write_text(json.dumps(metadata, indent=2))
|
|
1518
|
+
|
|
1519
|
+
|
|
1520
|
+
def verify_checkpoint_integrity(ckpt_path: Path) -> Dict[str, object]:
|
|
1521
|
+
if not ckpt_path.exists():
|
|
1522
|
+
raise FileNotFoundError(f"Checkpoint {ckpt_path} not found")
|
|
1523
|
+
meta_path = ckpt_path.with_suffix(".meta.json")
|
|
1524
|
+
if not meta_path.exists():
|
|
1525
|
+
raise FileNotFoundError(f"Metadata file {meta_path} missing")
|
|
1526
|
+
metadata = json.loads(meta_path.read_text())
|
|
1527
|
+
computed_sha = _checksum_path(str(ckpt_path))
|
|
1528
|
+
recorded_sha = metadata.get("checkpoint_sha256")
|
|
1529
|
+
if recorded_sha and computed_sha and recorded_sha != computed_sha:
|
|
1530
|
+
raise ValueError(
|
|
1531
|
+
f"Checkpoint SHA mismatch: recorded {recorded_sha} vs computed {computed_sha}"
|
|
1532
|
+
)
|
|
1533
|
+
sha_file = ckpt_path.with_suffix(".sha256")
|
|
1534
|
+
if sha_file.exists() and computed_sha:
|
|
1535
|
+
recorded_line = sha_file.read_text().strip().split()
|
|
1536
|
+
if recorded_line:
|
|
1537
|
+
recorded = recorded_line[0]
|
|
1538
|
+
if recorded != computed_sha:
|
|
1539
|
+
raise ValueError(f".sha256 mismatch: {recorded} vs {computed_sha}")
|
|
1540
|
+
config_path = ckpt_path.with_suffix(".yaml")
|
|
1541
|
+
if not config_path.exists():
|
|
1542
|
+
raise FileNotFoundError(f"Config file {config_path} missing")
|
|
1543
|
+
config_hash = sha256(config_path.read_text().encode("utf-8")).hexdigest()
|
|
1544
|
+
recorded_cfg_hash = metadata.get("config_sha256")
|
|
1545
|
+
if recorded_cfg_hash and recorded_cfg_hash != config_hash:
|
|
1546
|
+
raise ValueError(
|
|
1547
|
+
f"Config SHA mismatch: recorded {recorded_cfg_hash} vs computed {config_hash}"
|
|
1548
|
+
)
|
|
1549
|
+
if "rng_states" not in metadata:
|
|
1550
|
+
raise ValueError("Metadata missing rng_states")
|
|
1551
|
+
return metadata
|
|
1552
|
+
|
|
1553
|
+
|
|
1554
|
+
def _capture_rng_states() -> Dict[str, object]:
|
|
1555
|
+
payload: Dict[str, object] = {
|
|
1556
|
+
"python": _encode_pickle(random.getstate()),
|
|
1557
|
+
"numpy": _encode_pickle(np.random.get_state()),
|
|
1558
|
+
"torch": _tensor_state_to_hex(torch.random.get_rng_state()),
|
|
1559
|
+
}
|
|
1560
|
+
if torch.cuda.is_available():
|
|
1561
|
+
payload["torch_cuda"] = [
|
|
1562
|
+
_tensor_state_to_hex(state) for state in torch.cuda.get_rng_state_all()
|
|
1563
|
+
] # type: ignore[attr-defined]
|
|
1564
|
+
return payload
|
|
1565
|
+
|
|
1566
|
+
|
|
1567
|
+
def _encode_pickle(obj: object) -> str:
|
|
1568
|
+
return base64.b64encode(pickle.dumps(obj)).decode("ascii")
|
|
1569
|
+
|
|
1570
|
+
|
|
1571
|
+
def _tensor_state_to_hex(state: torch.Tensor) -> str:
|
|
1572
|
+
return state.cpu().numpy().tobytes().hex()
|
|
1573
|
+
|
|
1574
|
+
|
|
1575
|
+
def _seed_everything(seed: int, *, deterministic: bool = False) -> None:
|
|
1576
|
+
random.seed(seed)
|
|
1577
|
+
np.random.seed(seed)
|
|
1578
|
+
torch.manual_seed(seed)
|
|
1579
|
+
if torch.cuda.is_available():
|
|
1580
|
+
torch.cuda.manual_seed_all(seed)
|
|
1581
|
+
if deterministic:
|
|
1582
|
+
torch.use_deterministic_algorithms(True, warn_only=True)
|
|
1583
|
+
if hasattr(torch.backends, "cudnn"):
|
|
1584
|
+
torch.backends.cudnn.benchmark = False # type: ignore[attr-defined]
|
|
1585
|
+
torch.backends.cudnn.deterministic = True # type: ignore[attr-defined]
|
|
1586
|
+
else:
|
|
1587
|
+
torch.use_deterministic_algorithms(False)
|
|
1588
|
+
if hasattr(torch.backends, "cudnn"):
|
|
1589
|
+
torch.backends.cudnn.benchmark = True # type: ignore[attr-defined]
|
|
1590
|
+
torch.backends.cudnn.deterministic = False # type: ignore[attr-defined]
|
|
1591
|
+
|
|
1592
|
+
|
|
1593
|
+
def _make_worker_init_fn(base_seed: int):
|
|
1594
|
+
def _init_fn(worker_id: int) -> None:
|
|
1595
|
+
worker_seed = base_seed + worker_id
|
|
1596
|
+
np.random.seed(worker_seed)
|
|
1597
|
+
random.seed(worker_seed)
|
|
1598
|
+
torch.manual_seed(worker_seed)
|
|
1599
|
+
|
|
1600
|
+
return _init_fn
|