omnius 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. package/README.md +4959 -0
  2. package/dist/index.d.ts +6 -0
  3. package/dist/index.js +630665 -0
  4. package/dist/launcher.cjs +78 -0
  5. package/dist/postinstall-daemon.cjs +776 -0
  6. package/dist/preinstall.cjs +92 -0
  7. package/dist/scripts/autoresearch-prepare.py +459 -0
  8. package/dist/scripts/autoresearch-train.py +661 -0
  9. package/dist/scripts/crawlee-scraper.py +358 -0
  10. package/dist/scripts/live-nemotron.py +478 -0
  11. package/dist/scripts/live-whisper.py +242 -0
  12. package/dist/scripts/ocr-advanced.py +571 -0
  13. package/dist/scripts/start-moondream.py +112 -0
  14. package/dist/scripts/tor/UPSTREAM-README.md +148 -0
  15. package/dist/scripts/tor/destroy_tor.sh +29 -0
  16. package/dist/scripts/tor/tor_setup.sh +163 -0
  17. package/dist/scripts/transcribe-file.py +63 -0
  18. package/dist/scripts/web_scrape.py +1295 -0
  19. package/npm-shrinkwrap.json +7412 -0
  20. package/package.json +142 -0
  21. package/prompts/agentic/system-large.md +569 -0
  22. package/prompts/agentic/system-medium.md +211 -0
  23. package/prompts/agentic/system-small.md +114 -0
  24. package/prompts/compaction/context-compaction.md +44 -0
  25. package/prompts/personality/level-1-minimal.md +3 -0
  26. package/prompts/personality/level-2-concise.md +3 -0
  27. package/prompts/personality/level-4-explanatory.md +3 -0
  28. package/prompts/personality/level-5-thorough.md +3 -0
  29. package/prompts/personality/level-autist.md +3 -0
  30. package/prompts/personality/level-stark.md +3 -0
  31. package/prompts/runners/dispatcher.md +24 -0
  32. package/prompts/runners/editor.md +44 -0
  33. package/prompts/runners/evaluator.md +30 -0
  34. package/prompts/runners/merge-summary.md +9 -0
  35. package/prompts/runners/normalizer.md +23 -0
  36. package/prompts/runners/planner.md +33 -0
  37. package/prompts/runners/scout.md +39 -0
  38. package/prompts/runners/verifier.md +36 -0
  39. package/prompts/skill-builder/seed-analysis.md +30 -0
  40. package/prompts/skill-builder/skill-expansion.md +76 -0
  41. package/prompts/skill-builder/skill-validation.md +31 -0
  42. package/prompts/templates/analysis.md +14 -0
  43. package/prompts/templates/code-review.md +16 -0
  44. package/prompts/templates/code.md +13 -0
  45. package/prompts/templates/document.md +13 -0
  46. package/prompts/templates/error-diagnosis.md +14 -0
  47. package/prompts/templates/general.md +9 -0
  48. package/prompts/templates/plan.md +15 -0
  49. package/prompts/templates/system.md +16 -0
  50. package/prompts/tui/dmn-gather.md +128 -0
  51. package/prompts/tui/dream-consolidate.md +48 -0
  52. package/prompts/tui/dream-lucid-eval.md +17 -0
  53. package/prompts/tui/dream-lucid-implement.md +14 -0
  54. package/prompts/tui/dream-stages.md +19 -0
  55. package/prompts/tui/emotion-behavioral.md +2 -0
  56. package/prompts/tui/emotion-center.md +12 -0
  57. package/voices/personaplex/OverBarn.pt +0 -0
  58. package/voices/personaplex/clone-voice.py +384 -0
  59. package/voices/personaplex/dequant-loader.py +174 -0
  60. package/voices/personaplex/quantize-weights.py +167 -0
@@ -0,0 +1,661 @@
1
+ """
2
+ Autoresearch pretraining script. Single-GPU, single-file.
3
+ Cherry-picked and simplified from nanochat.
4
+ Usage: uv run train.py
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import subprocess
10
# Configure the environment before torch / Hugging Face libraries are imported
# further down: expandable-segments allocator, no HF Hub progress bars.
os.environ.update(
    PYTORCH_ALLOC_CONF="expandable_segments:True",
    HF_HUB_DISABLE_PROGRESS_BARS="1",
)
12
+
13
+ # ---------------------------------------------------------------------------
14
+ # Auto-bootstrap: ensure we're running in a venv with deps installed
15
+ # ---------------------------------------------------------------------------
16
+
17
+ def _ensure_venv():
18
+ """If not in a venv, check for .venv in script dir and re-exec."""
19
+ if hasattr(sys, "real_prefix") or (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix):
20
+ return # Already in a venv
21
+
22
+ script_dir = os.path.dirname(os.path.abspath(__file__))
23
+ venv_python = os.path.join(script_dir, ".venv", "bin", "python")
24
+
25
+ if os.path.exists(venv_python):
26
+ print(f"Re-launching with venv Python: {venv_python}")
27
+ os.execv(venv_python, [venv_python] + sys.argv)
28
+ else:
29
+ # Try to run prepare.py bootstrap first (it creates the venv)
30
+ prepare_py = os.path.join(script_dir, "prepare.py")
31
+ if os.path.exists(prepare_py):
32
+ print("No venv found. Running prepare.py to bootstrap dependencies...")
33
+ subprocess.check_call([sys.executable, prepare_py, "--num-shards", "0"])
34
+ # After prepare.py creates the venv, re-exec
35
+ if os.path.exists(venv_python):
36
+ os.execv(venv_python, [venv_python] + sys.argv)
37
+ print("ERROR: No venv found. Run 'uv sync' or 'python prepare.py' first.", file=sys.stderr)
38
+ sys.exit(1)
39
+
40
+ _ensure_venv()
41
+
42
+ import gc
43
+ import math
44
+ import time
45
+ from dataclasses import dataclass, asdict
46
+
47
+ import torch
48
+ import torch.nn as nn
49
+ import torch.nn.functional as F
50
+
51
from kernels import get_kernel

# Choose the FlashAttention-3 kernel build by GPU compute capability:
# varunneal's FA3 is Hopper (9.0) only; other GPUs use kernels-community's build.
cap = torch.cuda.get_device_capability()
if cap == (9, 0):
    repo = "varunneal/flash-attention-3"
else:
    repo = "kernels-community/flash-attn3"
fa3 = get_kernel(repo).flash_attn_interface
56
+
57
+ from prepare import MAX_SEQ_LEN, TIME_BUDGET, Tokenizer, make_dataloader, evaluate_bpb
58
+
59
+ # ---------------------------------------------------------------------------
60
+ # GPT Model
61
+ # ---------------------------------------------------------------------------
62
+
63
@dataclass
class GPTConfig:
    """Architecture hyperparameters for GPT.

    Field order is significant: it is the positional order of the generated
    ``__init__`` and the key order produced by ``asdict``.
    """

    # Maximum sequence length; also the long attention window size.
    sequence_len: int = 2048
    # Number of token ids (embedding rows / lm_head outputs).
    vocab_size: int = 32768
    n_layer: int = 12    # number of transformer blocks
    n_head: int = 6      # query heads
    n_kv_head: int = 6   # key/value heads; must divide n_head
    n_embd: int = 768    # model width
    # Attention window pattern cycled over layers: "S" = half window, "L" = full.
    window_pattern: str = "SSSL"
72
+
73
+
74
+ def norm(x):
75
+ return F.rms_norm(x, (x.size(-1),))
76
+
77
+
78
def has_ve(layer_idx, n_layer):
    """Return True if layer ``layer_idx`` gets a Value Embedding.

    Layers alternate, anchored so the last layer (``n_layer - 1``) always has one.
    """
    distance_from_last = (n_layer - 1) - layer_idx
    return distance_from_last % 2 == 0
81
+
82
+
83
def apply_rotary_emb(x, cos, sin):
    """Apply rotary position embedding to a rank-4 tensor.

    The last dimension is split into halves (x1, x2), which are rotated to
    (x1*cos + x2*sin, -x1*sin + x2*cos) and concatenated back. ``cos`` and
    ``sin`` must broadcast against each half.
    """
    assert x.ndim == 4
    half = x.shape[3] // 2
    first, second = x[..., :half], x[..., half:]
    rotated = (
        first * cos + second * sin,
        first * (-sin) + second * cos,
    )
    return torch.cat(list(rotated), 3)
90
+
91
+
92
class CausalSelfAttention(nn.Module):
    """Causal self-attention with RoPE, QK-norm and an optional value embedding.

    Query and key/value head counts are configured separately (``n_kv_head``
    must divide ``n_head``). On layers where ``has_ve`` is true, a small gate
    mixes a per-token value embedding into V. The attention itself is
    computed by the FlashAttention-3 kernel ``fa3``.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        n_head, n_kv_head, n_embd = config.n_head, config.n_kv_head, config.n_embd
        self.n_head = n_head
        self.n_kv_head = n_kv_head
        self.n_embd = n_embd
        self.head_dim = n_embd // n_head
        assert n_embd % n_head == 0
        assert n_kv_head <= n_head and n_head % n_kv_head == 0
        q_width = n_head * self.head_dim
        kv_width = n_kv_head * self.head_dim
        self.c_q = nn.Linear(n_embd, q_width, bias=False)
        self.c_k = nn.Linear(n_embd, kv_width, bias=False)
        self.c_v = nn.Linear(n_embd, kv_width, bias=False)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=False)
        # The value-embedding gate only looks at the first 32 input channels.
        self.ve_gate_channels = 32
        if has_ve(layer_idx, config.n_layer):
            self.ve_gate = nn.Linear(self.ve_gate_channels, n_kv_head, bias=False)
        else:
            self.ve_gate = None

    def forward(self, x, ve, cos_sin, window_size):
        B, T, _ = x.size()
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

        if ve is not None:
            # Value residual (ResFormer): per-kv-head gate in (0, 2) computed from the input.
            gate_logits = self.ve_gate(x[..., :self.ve_gate_channels])
            gate = 2 * torch.sigmoid(gate_logits)
            v = v + gate.unsqueeze(-1) * ve.view(B, T, self.n_kv_head, self.head_dim)

        cos, sin = cos_sin
        q = norm(apply_rotary_emb(q, cos, sin))
        k = norm(apply_rotary_emb(k, cos, sin))

        attended = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
        return self.c_proj(attended.contiguous().view(B, T, -1))
128
+
129
+
130
class MLP(nn.Module):
    """Feed-forward sublayer: n_embd -> 4*n_embd -> n_embd with squared ReLU, no biases."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=False)

    def forward(self, x):
        activated = F.relu(self.c_fc(x)).square()  # ReLU^2 activation
        return self.c_proj(activated)
141
+
142
+
143
class Block(nn.Module):
    """Pre-norm transformer block: attention, then MLP, each added to the residual."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, ve, cos_sin, window_size):
        attn_out = self.attn(norm(x), ve, cos_sin, window_size)
        x = x + attn_out
        mlp_out = self.mlp(norm(x))
        return x + mlp_out
153
+
154
+
155
class GPT(nn.Module):
    """GPT language model: token embedding, ``n_layer`` Blocks, and an lm_head.

    Additions over a plain GPT, all visible in this class:
    - per-layer learnable scalars mixing the running residual ``x`` with the
      normalized input embedding ``x0``;
    - value embeddings on the layers selected by ``has_ve``;
    - rotary tables precomputed for 10x ``sequence_len``;
    - per-layer attention windows derived from ``window_pattern``;
    - tanh soft-capped logits.

    ``init_weights()`` overwrites PyTorch's default initialization with this
    model's scheme and casts the embeddings to bf16.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # One window-size tuple per layer (see _compute_window_sizes).
        self.window_sizes = self._compute_window_sizes(config)
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, i) for i in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Per-layer scalars: x <- resid_lambdas[i] * x + x0_lambdas[i] * x0 (see forward).
        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))
        # Value embeddings
        # Keyed by str(layer index); only layers where has_ve(...) is True get one.
        head_dim = config.n_embd // config.n_head
        kv_dim = config.n_kv_head * head_dim
        self.value_embeds = nn.ModuleDict({
            str(i): nn.Embedding(config.vocab_size, kv_dim)
            for i in range(config.n_layer) if has_ve(i, config.n_layer)
        })
        # Rotary embeddings
        # Tables cover 10x the configured sequence length; non-persistent
        # buffers are excluded from state_dict.
        self.rotary_seq_len = config.sequence_len * 10
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    @torch.no_grad()
    def init_weights(self):
        """Initialize all weights in place and cast the embeddings to bf16."""
        # Embedding and unembedding
        torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
        # Transformer blocks
        n_embd = self.config.n_embd
        # Uniform(-s, s) has std s / sqrt(3) = n_embd**-0.5.
        s = 3**0.5 * n_embd**-0.5
        for block in self.transformer.h:
            torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
            torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
            torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
            # Zero output projections: each residual branch starts out contributing 0.
            torch.nn.init.zeros_(block.attn.c_proj.weight)
            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
            torch.nn.init.zeros_(block.mlp.c_proj.weight)
        # Per-layer scalars
        self.resid_lambdas.fill_(1.0)
        self.x0_lambdas.fill_(0.1)
        # Value embeddings
        for ve in self.value_embeds.values():
            torch.nn.init.uniform_(ve.weight, -s, s)
        # Gate weights init to zero (sigmoid(0)=0.5, scaled by 2 -> 1.0 = neutral)
        for block in self.transformer.h:
            if block.attn.ve_gate is not None:
                torch.nn.init.zeros_(block.attn.ve_gate.weight)
        # Rotary embeddings
        # Recomputed here so the tables are valid even when the model was
        # constructed on another device (e.g. the meta device, then to_empty()).
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin
        # Cast embeddings to bf16
        self.transformer.wte.to(dtype=torch.bfloat16)
        for ve in self.value_embeds.values():
            ve.to(dtype=torch.bfloat16)

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        """Return bf16 (cos, sin) tables of shape (1, seq_len, 1, head_dim // 2).

        Defaults to the device of the token embedding weights.
        """
        if device is None:
            device = self.transformer.wte.weight.device
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()
        # Add broadcast dims for (batch, seq, heads, head_dim // 2) inputs.
        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
        return cos, sin

    def _compute_window_sizes(self, config):
        """Return one ``(window, 0)`` tuple per layer from ``config.window_pattern``.

        "L" maps to ``sequence_len`` and "S" to half of it; the pattern is cycled
        over layers, and the last layer is always forced to the long window.
        NOTE(review): the tuple is passed as ``window_size`` to
        ``fa3.flash_attn_func``; presumably (left, right) -- confirm against FA3.
        """
        pattern = config.window_pattern.upper()
        assert all(c in "SL" for c in pattern)
        long_window = config.sequence_len
        short_window = long_window // 2
        char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
        window_sizes = []
        for layer_idx in range(config.n_layer):
            char = pattern[layer_idx % len(pattern)]
            window_sizes.append(char_to_window[char])
        window_sizes[-1] = (long_window, 0)
        return window_sizes

    def estimate_flops(self):
        """Estimated FLOPs per token (forward + backward)."""
        # 6 FLOPs per non-embedding parameter per token (fwd + bwd); token and
        # value embeddings and the per-layer scalars are excluded.
        nparams = sum(p.numel() for p in self.parameters())
        value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
                           self.resid_lambdas.numel() + self.x0_lambdas.numel())
        h = self.config.n_head
        q = self.config.n_embd // self.config.n_head
        t = self.config.sequence_len
        # Attention term per layer, using the layer's window (capped at sequence_len).
        attn_flops = 0
        for window_size in self.window_sizes:
            window = window_size[0]
            effective_seq = t if window < 0 else min(window, t)
            attn_flops += 12 * h * q * effective_seq
        return 6 * (nparams - nparams_exclude) + attn_flops

    def num_scaling_params(self):
        """Return parameter counts per component plus their ``'total'``."""
        wte = sum(p.numel() for p in self.transformer.wte.parameters())
        value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
        lm_head = sum(p.numel() for p in self.lm_head.parameters())
        transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
        scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
        total = wte + value_embeds + lm_head + transformer_matrices + scalars
        return {
            'wte': wte, 'value_embeds': value_embeds, 'lm_head': lm_head,
            'transformer_matrices': transformer_matrices, 'scalars': scalars, 'total': total,
        }

    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02,
                        weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
        """Build a MuonAdamW optimizer for this model.

        AdamW groups: lm_head, token embeddings, value embeddings, resid and x0
        scalars (all with zero weight decay). Every parameter under
        ``transformer.h`` goes to Muon, one group per distinct shape;
        ``weight_decay`` applies only to those Muon groups. Each group's
        ``initial_lr`` records its starting LR for later scheduling.
        """
        model_dim = self.config.n_embd
        matrix_params = list(self.transformer.h.parameters())
        value_embeds_params = list(self.value_embeds.parameters())
        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())
        resid_params = [self.resid_lambdas]
        x0_params = [self.x0_lambdas]
        # Every parameter must land in exactly one of the lists above.
        assert len(list(self.parameters())) == (len(matrix_params) + len(embedding_params) +
            len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params))
        # Scale LR ∝ 1/√dmodel (tuned at 768 dim)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}")
        param_groups = [
            dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0),
        ]
        # Muon stacks a group's params, so each group must be a single shape.
        for shape in sorted({p.shape for p in matrix_params}):
            group_params = [p for p in matrix_params if p.shape == shape]
            param_groups.append(dict(
                kind='muon', params=group_params, lr=matrix_lr,
                momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
            ))
        optimizer = MuonAdamW(param_groups)
        for group in optimizer.param_groups:
            group["initial_lr"] = group["lr"]
        return optimizer

    def forward(self, idx, targets=None, reduction='mean'):
        """Run the model on token ids ``idx`` (B, T).

        Returns the cross-entropy loss (targets == -1 are ignored; ``reduction``
        is passed to F.cross_entropy) when ``targets`` is given, otherwise the
        soft-capped float32 logits.
        """
        B, T = idx.size()
        assert T <= self.cos.size(1)
        cos_sin = self.cos[:, :T], self.sin[:, :T]

        x = self.transformer.wte(idx)
        x = norm(x)
        x0 = x  # normalized input embedding, re-injected into every layer
        for i, block in enumerate(self.transformer.h):
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
            ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
            x = block(x, ve, cos_sin, self.window_sizes[i])
        x = norm(x)

        # Soft-cap logits to (-15, 15) with tanh, computed in float32.
        softcap = 15
        logits = self.lm_head(x)
        logits = logits.float()
        logits = softcap * torch.tanh(logits / softcap)

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
                                   ignore_index=-1, reduction=reduction)
            return loss
        return logits
323
+
324
+ # ---------------------------------------------------------------------------
325
+ # Optimizer (MuonAdamW, single GPU only)
326
+ # ---------------------------------------------------------------------------
327
+
328
# Per-iteration coefficients (a, b, c) for the Polar Express orthogonalization
# in muon_step_fused: each step forms A = X^T X (or X X^T), B = b*A + c*A@A and
# updates X <- a*X + X@B (or B@X). Only the first ``ns_steps`` rows are used.
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
335
+
336
@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
    """Apply one AdamW update to ``p`` in place (compiled as a single graph).

    Hyperparameters arrive as 0-D tensors so that changing their values does
    not trigger recompilation. ``p``, ``exp_avg`` and ``exp_avg_sq`` are all
    modified in place.
    """
    # Decoupled weight decay, applied before the Adam step.
    p.mul_(1 - lr_t * wd_t)
    # Running first and second moments of the gradient.
    exp_avg.lerp_(grad, 1 - beta1_t)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
    # Bias-corrected update.
    correction1 = 1 - beta1_t ** step_t
    correction2 = 1 - beta2_t ** step_t
    denominator = (exp_avg_sq / correction2).sqrt() + eps_t
    update = exp_avg / denominator
    p.add_(update, alpha=-(lr_t / correction1))
346
+
347
@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
                    momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
    """One Muon update for a stack of same-shaped matrices (compiled, fullgraph).

    ``stacked_grads``/``stacked_params`` hold a param group's gradients and
    parameters stacked along dim 0 (as built by MuonAdamW._step_muon).
    Modifies ``momentum_buffer``, ``second_momentum_buffer`` and
    ``stacked_params`` in place, and overwrites ``stacked_grads`` with the
    blended gradient. ``ns_steps`` is the number of Polar Express iterations;
    ``red_dim`` is the axis averaged over for the second-moment statistics.
    """
    # Nesterov momentum
    # Buffer moves toward the gradient; the update then blends the gradient toward the buffer.
    momentum = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)
    # Polar express orthogonalization
    # Dividing by 1.02x the Frobenius norm bounds the spectral norm below 1.
    X = g.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    # Use the smaller Gram matrix: X^T X for tall matrices, X X^T otherwise.
    if g.size(-2) > g.size(-1):
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X
    # NorMuon variance reduction
    # EMA of the mean squared update along red_dim, used to rescale each
    # row/column by its rsqrt; the result is then renormalized so the overall
    # norm matches the pre-scaling norm (v_norm / v_norm_new).
    beta2 = beta2_t.to(g.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)
    # Cautious weight decay + parameter update
    # Decay is applied only where the update and the parameter share a sign.
    lr = lr_t.to(g.dtype)
    wd = wd_t.to(g.dtype)
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
385
+
386
+
387
class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for others.

    Every param group carries a ``kind`` key:
    - ``'adamw'`` groups use ``lr``, ``betas``, ``eps`` and ``weight_decay``;
    - ``'muon'`` groups use ``lr``, ``momentum``, ``ns_steps``, ``beta2`` and
      ``weight_decay``, and must contain parameters of one shape (they are
      stacked and updated together).
    """

    def __init__(self, param_groups):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    def _step_adamw(self, group):
        """Apply one AdamW step to every parameter in ``group`` that has a gradient."""
        # Group-level hyperparameters are identical for every parameter: fill once.
        self._adamw_lr_t.fill_(group['lr'])
        self._adamw_beta1_t.fill_(group['betas'][0])
        self._adamw_beta2_t.fill_(group['betas'][1])
        self._adamw_eps_t.fill_(group['eps'])
        self._adamw_wd_t.fill_(group['weight_decay'])
        for p in group['params']:
            if p.grad is None:
                continue
            state = self.state[p]
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            state['step'] += 1
            # The step count is per parameter, so it is refilled on every call.
            self._adamw_step_t.fill_(state['step'])
            adamw_step_fused(p, p.grad, state['exp_avg'], state['exp_avg_sq'],
                             self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                             self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)

    def _step_muon(self, group):
        """Apply one Muon step to all (same-shaped) parameters of ``group`` at once.

        Gradients and parameters are stacked along a new leading dim, updated by
        ``muon_step_fused``, and copied back. The group's momentum buffers are
        stored in the state of its first parameter.
        """
        params = group['params']
        if not params:
            return
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        # Second-moment statistics are kept per row for tall/square matrices and
        # per column for wide ones; red_dim is the axis that is averaged away.
        if "second_momentum_buffer" not in state:
            state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        stacked_grads = torch.stack([param.grad for param in params])
        stacked_params = torch.stack(params)
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # LR is scaled up by sqrt(rows / cols) for tall matrices.
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        state["momentum_buffer"], state["second_momentum_buffer"],
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        # torch.stack copied the parameters; write the updated values back.
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step over all param groups.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss, as in the ``torch.optim.Optimizer.step`` API. It is
                called with gradients enabled before any update.

        Returns:
            The closure's return value, or None when no closure is given.

        Raises:
            ValueError: if a param group has a ``kind`` other than 'adamw'/'muon'.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            kind = group['kind']
            if kind == 'adamw':
                self._step_adamw(group)
            elif kind == 'muon':
                self._step_muon(group)
            else:
                # Previously such groups were silently never updated.
                raise ValueError(f"Unknown param group kind: {kind!r}")
        return loss
458
+
459
+ # ---------------------------------------------------------------------------
460
+ # Hyperparameters (edit these directly, no CLI flags needed)
461
+ # ---------------------------------------------------------------------------
462
+
463
# Model architecture
# Model width = DEPTH * ASPECT_RATIO, rounded up to a multiple of HEAD_DIM.
ASPECT_RATIO = 64
# Target head dimension for attention.
HEAD_DIM = 128
# Sliding-window pattern cycled over layers: L = full context, S = half context.
WINDOW_PATTERN = "SSSL"

# Optimization
# Tokens per optimizer step (~524K); must be a multiple of
# DEVICE_BATCH_SIZE * MAX_SEQ_LEN (one forward/backward micro-batch).
TOTAL_BATCH_SIZE = 1 << 19
EMBEDDING_LR = 0.6          # token embeddings (Adam)
UNEMBEDDING_LR = 0.004      # lm_head (Adam)
MATRIX_LR = 0.04            # matrix parameters (Muon)
SCALAR_LR = 0.5             # per-layer scalars (Adam)
WEIGHT_DECAY = 0.2          # cautious weight decay, Muon groups only
ADAM_BETAS = (0.8, 0.95)    # Adam (beta1, beta2)
WARMUP_RATIO = 0.0          # fraction of the time budget spent warming up the LR
WARMDOWN_RATIO = 0.5        # fraction of the time budget spent decaying the LR
FINAL_LR_FRAC = 0.0         # final LR as a fraction of the initial LR

# Model size
DEPTH = 8                   # number of transformer layers
DEVICE_BATCH_SIZE = 128     # sequences per micro-batch (reduce if OOM)
483
+
484
+ # ---------------------------------------------------------------------------
485
+ # Setup: tokenizer, model, optimizer, dataloader
486
+ # ---------------------------------------------------------------------------
487
+
488
# Wall-clock start of the whole run; feeds startup_time and total_seconds.
t_start = time.time()
# Fixed seeds, set before model initialization, for reproducible runs.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Permit reduced-precision (e.g. TF32) kernels for float32 matmuls.
torch.set_float32_matmul_precision("high")
device = torch.device("cuda")
# bf16 autocast context used around the training forward passes and the final eval.
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
# Peak H100 bf16 FLOP/s, used as the MFU denominator.
# NOTE(review): applied regardless of the actual GPU (the FA3 kernel selection
# explicitly supports non-Hopper GPUs), so MFU is only meaningful on an H100.
H100_BF16_PEAK_FLOPS = 989.5e12

tokenizer = Tokenizer.from_directory()
vocab_size = tokenizer.get_vocab_size()
print(f"Vocab size: {vocab_size:,}")
499
+
500
def build_model_config(depth):
    """Return the GPTConfig for a model with ``depth`` layers.

    Width is ``depth * ASPECT_RATIO`` rounded up to a multiple of HEAD_DIM;
    the head count is width / HEAD_DIM, with as many KV heads as query heads.
    Sequence length, vocab size and window pattern come from module settings.
    """
    base_dim = depth * ASPECT_RATIO
    n_chunks = (base_dim + HEAD_DIM - 1) // HEAD_DIM  # ceiling division
    width = n_chunks * HEAD_DIM
    heads = width // HEAD_DIM
    return GPTConfig(
        sequence_len=MAX_SEQ_LEN,
        vocab_size=vocab_size,
        n_layer=depth,
        n_head=heads,
        n_kv_head=heads,
        n_embd=width,
        window_pattern=WINDOW_PATTERN,
    )
509
+
510
config = build_model_config(DEPTH)
print(f"Model config: {asdict(config)}")

# Construct on the meta device (no real storage), materialize on the GPU with
# to_empty(), then initialize every weight explicitly.
with torch.device("meta"):
    model = GPT(config)
model.to_empty(device=device)
model.init_weights()

param_counts = model.num_scaling_params()
print("Parameter counts:")
for key, value in param_counts.items():
    print(f" {key:24s}: {value:,}")
num_params = param_counts['total']
num_flops_per_token = model.estimate_flops()
print(f"Estimated FLOPs per token: {num_flops_per_token:e}")

# Each optimizer step consumes TOTAL_BATCH_SIZE tokens, split into
# micro-batches of DEVICE_BATCH_SIZE sequences of MAX_SEQ_LEN tokens.
tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
if TOTAL_BATCH_SIZE % tokens_per_fwdbwd != 0:
    # An assert would be stripped under `python -O` and the integer division
    # below would then silently train on a smaller batch than configured.
    raise ValueError(
        f"TOTAL_BATCH_SIZE ({TOTAL_BATCH_SIZE}) must be a multiple of "
        f"DEVICE_BATCH_SIZE * MAX_SEQ_LEN ({tokens_per_fwdbwd})"
    )
grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd

optimizer = model.setup_optimizer(
    unembedding_lr=UNEMBEDDING_LR,
    embedding_lr=EMBEDDING_LR,
    scalar_lr=SCALAR_LR,
    adam_betas=ADAM_BETAS,
    matrix_lr=MATRIX_LR,
    weight_decay=WEIGHT_DECAY,
)

# Compile after the optimizer has captured the (uncompiled) parameters.
model = torch.compile(model, dynamic=False)

train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train")
x, y, epoch = next(train_loader)  # prefetch first batch

print(f"Time budget: {TIME_BUDGET}s")
print(f"Gradient accumulation steps: {grad_accum_steps}")
546
+
547
+ # Schedules (all based on progress = training_time / TIME_BUDGET)
548
+
549
def get_lr_multiplier(progress, warmup_ratio=None, warmdown_ratio=None, final_lr_frac=None):
    """Return the LR multiplier at training ``progress`` (fraction of TIME_BUDGET, 0..1).

    Schedule: linear warmup over the first ``warmup_ratio`` of training,
    constant 1.0, then a linear decay to ``final_lr_frac`` over the last
    ``warmdown_ratio``. Each argument left as None falls back to the
    module-level WARMUP_RATIO / WARMDOWN_RATIO / FINAL_LR_FRAC.
    """
    if warmup_ratio is None:
        warmup_ratio = WARMUP_RATIO
    if warmdown_ratio is None:
        warmdown_ratio = WARMDOWN_RATIO
    if final_lr_frac is None:
        final_lr_frac = FINAL_LR_FRAC

    if progress < warmup_ratio:
        return progress / warmup_ratio if warmup_ratio > 0 else 1.0
    if progress < 1.0 - warmdown_ratio:
        return 1.0
    if warmdown_ratio <= 0:
        # No warmdown phase and training is finished (progress >= 1.0): use the
        # final LR instead of dividing by a zero-length warmdown.
        return final_lr_frac
    cooldown = (1.0 - progress) / warmdown_ratio
    return cooldown * 1.0 + (1 - cooldown) * final_lr_frac
557
+
558
def get_muon_momentum(step):
    """Muon momentum: linear ramp from 0.85 to 0.95 over the first 300 steps, then constant."""
    start, end = 0.85, 0.95
    ramp = min(step / 300, 1)
    return (1 - ramp) * start + ramp * end
561
+
562
def get_weight_decay(progress):
    """Muon weight decay at ``progress``: linear decay from WEIGHT_DECAY (at 0) to 0 (at 1)."""
    remaining_fraction = 1 - progress
    return WEIGHT_DECAY * remaining_fraction
564
+
565
+ # ---------------------------------------------------------------------------
566
+ # Training loop
567
+ # ---------------------------------------------------------------------------
568
+
569
t_start_training = time.time()
smooth_train_loss = 0
total_training_time = 0
step = 0

# Train until TOTAL training time (excluding the first 11 steps, which include
# torch.compile warmup) reaches TIME_BUDGET.
while True:
    torch.cuda.synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()  # logged loss is the last micro-batch's
        loss = loss / grad_accum_steps
        loss.backward()
        x, y, epoch = next(train_loader)

    # Progress and schedules
    progress = min(total_training_time / TIME_BUDGET, 1.0)
    lrm = get_lr_multiplier(progress)
    muon_momentum = get_muon_momentum(step)
    muon_weight_decay = get_weight_decay(progress)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group['kind'] == 'muon':
            group["momentum"] = muon_momentum
            group["weight_decay"] = muon_weight_decay
    optimizer.step()
    model.zero_grad(set_to_none=True)

    train_loss_f = train_loss.item()

    # Fast fail: abort if loss is exploding or NaN
    if math.isnan(train_loss_f) or train_loss_f > 100:
        print()  # end the \r progress line so FAIL appears on its own line
        print("FAIL")
        sys.exit(1)  # `exit` is a site-module helper and is not always defined

    torch.cuda.synchronize()
    t1 = time.time()
    dt = t1 - t0

    # Only steps with index > 10 count toward the time budget.
    if step > 10:
        total_training_time += dt

    # Logging
    ema_beta = 0.9
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))
    pct_done = 100 * progress
    tok_per_sec = int(TOTAL_BATCH_SIZE / dt)
    mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE / dt / H100_BF16_PEAK_FLOPS
    remaining = max(0, TIME_BUDGET - total_training_time)

    print(f"\rstep {step:05d} ({pct_done:.1f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt*1000:.0f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.1f}% | epoch: {epoch} | remaining: {remaining:.0f}s ", end="", flush=True)

    # GC management (Python's GC causes ~500ms stalls)
    if step == 0:
        gc.collect()
        gc.freeze()
        gc.disable()
    elif (step + 1) % 5000 == 0:
        gc.collect()

    step += 1

    # Time's up — but only stop after warmup steps so we don't count compilation
    if step > 10 and total_training_time >= TIME_BUDGET:
        break
636
+
637
print()  # newline after \r training log

total_tokens = step * TOTAL_BATCH_SIZE

# Final eval
model.eval()
with autocast_ctx:
    val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)

# Final summary
t_end = time.time()
startup_time = t_start_training - t_start
# total_training_time only includes steps whose index was > 10, i.e. indices
# 11 .. step-1, which is step - 11 steps; MFU must be computed over those.
timed_steps = max(step - 11, 0)
steady_state_mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE * timed_steps / total_training_time / H100_BF16_PEAK_FLOPS if total_training_time > 0 else 0
peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024

print("---")
print(f"val_bpb: {val_bpb:.6f}")
print(f"training_seconds: {total_training_time:.1f}")
print(f"total_seconds: {t_end - t_start:.1f}")
print(f"peak_vram_mb: {peak_vram_mb:.1f}")
print(f"mfu_percent: {steady_state_mfu:.2f}")
print(f"total_tokens_M: {total_tokens / 1e6:.1f}")
print(f"num_steps: {step}")
print(f"num_params_M: {num_params / 1e6:.1f}")
print(f"depth: {DEPTH}")