omnius 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +4959 -0
- package/dist/index.d.ts +6 -0
- package/dist/index.js +630665 -0
- package/dist/launcher.cjs +78 -0
- package/dist/postinstall-daemon.cjs +776 -0
- package/dist/preinstall.cjs +92 -0
- package/dist/scripts/autoresearch-prepare.py +459 -0
- package/dist/scripts/autoresearch-train.py +661 -0
- package/dist/scripts/crawlee-scraper.py +358 -0
- package/dist/scripts/live-nemotron.py +478 -0
- package/dist/scripts/live-whisper.py +242 -0
- package/dist/scripts/ocr-advanced.py +571 -0
- package/dist/scripts/start-moondream.py +112 -0
- package/dist/scripts/tor/UPSTREAM-README.md +148 -0
- package/dist/scripts/tor/destroy_tor.sh +29 -0
- package/dist/scripts/tor/tor_setup.sh +163 -0
- package/dist/scripts/transcribe-file.py +63 -0
- package/dist/scripts/web_scrape.py +1295 -0
- package/npm-shrinkwrap.json +7412 -0
- package/package.json +142 -0
- package/prompts/agentic/system-large.md +569 -0
- package/prompts/agentic/system-medium.md +211 -0
- package/prompts/agentic/system-small.md +114 -0
- package/prompts/compaction/context-compaction.md +44 -0
- package/prompts/personality/level-1-minimal.md +3 -0
- package/prompts/personality/level-2-concise.md +3 -0
- package/prompts/personality/level-4-explanatory.md +3 -0
- package/prompts/personality/level-5-thorough.md +3 -0
- package/prompts/personality/level-autist.md +3 -0
- package/prompts/personality/level-stark.md +3 -0
- package/prompts/runners/dispatcher.md +24 -0
- package/prompts/runners/editor.md +44 -0
- package/prompts/runners/evaluator.md +30 -0
- package/prompts/runners/merge-summary.md +9 -0
- package/prompts/runners/normalizer.md +23 -0
- package/prompts/runners/planner.md +33 -0
- package/prompts/runners/scout.md +39 -0
- package/prompts/runners/verifier.md +36 -0
- package/prompts/skill-builder/seed-analysis.md +30 -0
- package/prompts/skill-builder/skill-expansion.md +76 -0
- package/prompts/skill-builder/skill-validation.md +31 -0
- package/prompts/templates/analysis.md +14 -0
- package/prompts/templates/code-review.md +16 -0
- package/prompts/templates/code.md +13 -0
- package/prompts/templates/document.md +13 -0
- package/prompts/templates/error-diagnosis.md +14 -0
- package/prompts/templates/general.md +9 -0
- package/prompts/templates/plan.md +15 -0
- package/prompts/templates/system.md +16 -0
- package/prompts/tui/dmn-gather.md +128 -0
- package/prompts/tui/dream-consolidate.md +48 -0
- package/prompts/tui/dream-lucid-eval.md +17 -0
- package/prompts/tui/dream-lucid-implement.md +14 -0
- package/prompts/tui/dream-stages.md +19 -0
- package/prompts/tui/emotion-behavioral.md +2 -0
- package/prompts/tui/emotion-center.md +12 -0
- package/voices/personaplex/OverBarn.pt +0 -0
- package/voices/personaplex/clone-voice.py +384 -0
- package/voices/personaplex/dequant-loader.py +174 -0
- package/voices/personaplex/quantize-weights.py +167 -0
|
@@ -0,0 +1,661 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Autoresearch pretraining script. Single-GPU, single-file.
|
|
3
|
+
Cherry-picked and simplified from nanochat.
|
|
4
|
+
Usage: uv run train.py
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
import sys
|
|
9
|
+
import subprocess
|
|
10
|
+
# Process-wide settings, applied up front (ahead of the torch import further down).
os.environ.update({
    "PYTORCH_ALLOC_CONF": "expandable_segments:True",
    "HF_HUB_DISABLE_PROGRESS_BARS": "1",
})
|
|
12
|
+
|
|
13
|
+
# ---------------------------------------------------------------------------
|
|
14
|
+
# Auto-bootstrap: ensure we're running in a venv with deps installed
|
|
15
|
+
# ---------------------------------------------------------------------------
|
|
16
|
+
|
|
17
|
+
def _ensure_venv():
|
|
18
|
+
"""If not in a venv, check for .venv in script dir and re-exec."""
|
|
19
|
+
if hasattr(sys, "real_prefix") or (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix):
|
|
20
|
+
return # Already in a venv
|
|
21
|
+
|
|
22
|
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
23
|
+
venv_python = os.path.join(script_dir, ".venv", "bin", "python")
|
|
24
|
+
|
|
25
|
+
if os.path.exists(venv_python):
|
|
26
|
+
print(f"Re-launching with venv Python: {venv_python}")
|
|
27
|
+
os.execv(venv_python, [venv_python] + sys.argv)
|
|
28
|
+
else:
|
|
29
|
+
# Try to run prepare.py bootstrap first (it creates the venv)
|
|
30
|
+
prepare_py = os.path.join(script_dir, "prepare.py")
|
|
31
|
+
if os.path.exists(prepare_py):
|
|
32
|
+
print("No venv found. Running prepare.py to bootstrap dependencies...")
|
|
33
|
+
subprocess.check_call([sys.executable, prepare_py, "--num-shards", "0"])
|
|
34
|
+
# After prepare.py creates the venv, re-exec
|
|
35
|
+
if os.path.exists(venv_python):
|
|
36
|
+
os.execv(venv_python, [venv_python] + sys.argv)
|
|
37
|
+
print("ERROR: No venv found. Run 'uv sync' or 'python prepare.py' first.", file=sys.stderr)
|
|
38
|
+
sys.exit(1)
|
|
39
|
+
|
|
40
|
+
# Runs before the third-party imports below (torch etc.); may re-exec or exit the process.
_ensure_venv()
|
|
41
|
+
|
|
42
|
+
import gc
|
|
43
|
+
import math
|
|
44
|
+
import time
|
|
45
|
+
from dataclasses import dataclass, asdict
|
|
46
|
+
|
|
47
|
+
import torch
|
|
48
|
+
import torch.nn as nn
|
|
49
|
+
import torch.nn.functional as F
|
|
50
|
+
|
|
51
|
+
from kernels import get_kernel
|
|
52
|
+
# Choose the FlashAttention-3 kernel repo from the GPU's compute capability:
# varunneal's FA3 build is Hopper only, so every non-Hopper GPU falls back to
# the kernels-community build.
cap = torch.cuda.get_device_capability()
if cap == (9, 0):
    repo = "varunneal/flash-attention-3"
else:
    repo = "kernels-community/flash-attn3"
fa3 = get_kernel(repo).flash_attn_interface
|
|
56
|
+
|
|
57
|
+
from prepare import MAX_SEQ_LEN, TIME_BUDGET, Tokenizer, make_dataloader, evaluate_bpb
|
|
58
|
+
|
|
59
|
+
# ---------------------------------------------------------------------------
|
|
60
|
+
# GPT Model
|
|
61
|
+
# ---------------------------------------------------------------------------
|
|
62
|
+
|
|
63
|
+
@dataclass
class GPTConfig:
    """Architecture hyperparameters read by the model classes in this file."""

    # Context length; also sizes the attention windows and the rotary cache.
    sequence_len: int = 2048
    # Rows of the token / value embeddings and width of the lm_head output.
    vocab_size: int = 32768
    # Number of transformer blocks.
    n_layer: int = 12
    # Number of query heads; must divide n_embd.
    n_head: int = 6
    # Number of key/value heads; must divide n_head.
    n_kv_head: int = 6
    # Model (residual stream) width.
    n_embd: int = 768
    # Per-layer attention window pattern, cycled over layers: "L" = full, "S" = half.
    window_pattern: str = "SSSL"
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def norm(x):
|
|
75
|
+
return F.rms_norm(x, (x.size(-1),))
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def has_ve(layer_idx, n_layer):
    """Tell whether layer ``layer_idx`` of an ``n_layer`` stack gets a Value Embedding.

    Layers are picked by parity, anchored on the last layer (index
    ``n_layer - 1``), so the last layer is always included and every
    second layer before it.
    """
    last_idx = n_layer - 1
    return (last_idx - layer_idx) % 2 == 0
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dimension of a 4-D tensor.

    The last dimension is split at its midpoint into ``(lo, hi)`` and mapped to
    ``(lo*cos + hi*sin, hi*cos - lo*sin)``; ``cos`` and ``sin`` must broadcast
    against each half.
    """
    assert x.ndim == 4
    half = x.shape[3] // 2
    lo = x[..., :half]
    hi = x[..., half:]
    rotated_lo = lo * cos + hi * sin
    rotated_hi = hi * cos - lo * sin
    return torch.cat((rotated_lo, rotated_hi), dim=3)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class CausalSelfAttention(nn.Module):
    """Causal self-attention with rotary embeddings, QK-norm and an optional
    gated value-embedding residual.

    Layers selected by ``has_ve`` own a small ``ve_gate`` linear layer that
    reads the first ``ve_gate_channels`` input channels and produces one gate
    per key/value head; other layers have ``ve_gate = None``.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        n_head, n_kv_head, n_embd = config.n_head, config.n_kv_head, config.n_embd
        self.n_head = n_head
        self.n_kv_head = n_kv_head
        self.n_embd = n_embd
        self.head_dim = n_embd // n_head
        assert n_embd % n_head == 0
        assert n_kv_head <= n_head and n_head % n_kv_head == 0
        q_dim = n_head * self.head_dim
        kv_dim = n_kv_head * self.head_dim
        self.c_q = nn.Linear(n_embd, q_dim, bias=False)
        self.c_k = nn.Linear(n_embd, kv_dim, bias=False)
        self.c_v = nn.Linear(n_embd, kv_dim, bias=False)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.ve_gate_channels = 32
        if has_ve(layer_idx, config.n_layer):
            self.ve_gate = nn.Linear(self.ve_gate_channels, n_kv_head, bias=False)
        else:
            self.ve_gate = None

    def forward(self, x, ve, cos_sin, window_size):
        """Attend over ``x`` of shape (B, T, C).

        ``ve`` is the value embedding for this layer or ``None``; ``cos_sin``
        is the ``(cos, sin)`` rotary pair; ``window_size`` is forwarded to the
        flash-attention kernel.
        """
        B, T, _ = x.size()
        q_shape = (B, T, self.n_head, self.head_dim)
        kv_shape = (B, T, self.n_kv_head, self.head_dim)
        q = self.c_q(x).view(q_shape)
        k = self.c_k(x).view(kv_shape)
        v = self.c_v(x).view(kv_shape)

        # Value residual (ResFormer): add the value embedding scaled by an
        # input-dependent per-head gate in the range (0, 2).
        if ve is not None:
            gate_input = x[..., :self.ve_gate_channels]
            gate = 2 * torch.sigmoid(self.ve_gate(gate_input))
            v = v + gate.unsqueeze(-1) * ve.view(kv_shape)

        # Rotary position encoding followed by normalization of queries and keys.
        cos, sin = cos_sin
        q = norm(apply_rotary_emb(q, cos, sin))
        k = norm(apply_rotary_emb(k, cos, sin))

        attended = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
        return self.c_proj(attended.contiguous().view(B, T, -1))
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class MLP(nn.Module):
    """Bias-free feed-forward block: expand 4x, squared ReLU, project back."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, width, bias=False)

    def forward(self, x):
        """Return ``c_proj(relu(c_fc(x)) ** 2)``."""
        activated = F.relu(self.c_fc(x)).square()
        return self.c_proj(activated)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class Block(nn.Module):
    """One transformer layer: pre-norm attention and pre-norm MLP, each added
    back onto the residual stream."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, ve, cos_sin, window_size):
        """Apply attention then MLP; ``ve``, ``cos_sin`` and ``window_size``
        are passed straight through to the attention module."""
        attn_out = self.attn(norm(x), ve, cos_sin, window_size)
        x = x + attn_out
        return x + self.mlp(norm(x))
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Components built in ``__init__``: a token embedding, ``n_layer`` blocks, an
    untied ``lm_head``, two learnable per-layer scalar vectors
    (``resid_lambdas`` / ``x0_lambdas``) that mix the running activations with
    the normalized embedding output, value embeddings for the layers selected
    by ``has_ve``, and non-persistent rotary ``cos``/``sin`` buffers covering
    ``10 * sequence_len`` positions. Call ``init_weights()`` after the
    parameters have been materialized on their device.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.window_sizes = self._compute_window_sizes(config)
        token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
        blocks = nn.ModuleList([Block(config, layer) for layer in range(config.n_layer)])
        self.transformer = nn.ModuleDict({"wte": token_embedding, "h": blocks})
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))
        # Value embeddings, keyed by the layer index as a string
        head_dim = config.n_embd // config.n_head
        kv_dim = config.n_kv_head * head_dim
        ve_layers = [layer for layer in range(config.n_layer) if has_ve(layer, config.n_layer)]
        self.value_embeds = nn.ModuleDict({
            str(layer): nn.Embedding(config.vocab_size, kv_dim) for layer in ve_layers
        })
        # Rotary embeddings
        self.rotary_seq_len = config.sequence_len * 10
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    @torch.no_grad()
    def init_weights(self):
        """Initialize every parameter in place and recompute the rotary buffers.

        Output projections of attention and MLP start at zero, as do the value
        gate weights (sigmoid(0) = 0.5, times 2 gives a neutral gate of 1.0).
        The token and value embeddings are cast to bfloat16 at the end.
        """
        cfg = self.config
        # Embedding and unembedding
        torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
        # Transformer blocks: uniform in [-bound, bound] for input-side matrices
        bound = 3**0.5 * cfg.n_embd**-0.5
        for block in self.transformer.h:
            attn, mlp = block.attn, block.mlp
            for proj in (attn.c_q, attn.c_k, attn.c_v):
                torch.nn.init.uniform_(proj.weight, -bound, bound)
            torch.nn.init.zeros_(attn.c_proj.weight)
            torch.nn.init.uniform_(mlp.c_fc.weight, -bound, bound)
            torch.nn.init.zeros_(mlp.c_proj.weight)
        # Per-layer scalars
        self.resid_lambdas.fill_(1.0)
        self.x0_lambdas.fill_(0.1)
        # Value embeddings
        for table in self.value_embeds.values():
            torch.nn.init.uniform_(table.weight, -bound, bound)
        # Gate weights
        for block in self.transformer.h:
            gate = block.attn.ve_gate
            if gate is not None:
                torch.nn.init.zeros_(gate.weight)
        # Rotary embeddings
        head_dim = cfg.n_embd // cfg.n_head
        self.cos, self.sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        # Cast embeddings to bf16
        self.transformer.wte.to(dtype=torch.bfloat16)
        for table in self.value_embeds.values():
            table.to(dtype=torch.bfloat16)

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        """Return bfloat16 ``(cos, sin)`` tables of shape (1, seq_len, 1, head_dim // 2).

        When ``device`` is None the token embedding's device is used.
        """
        if device is None:
            device = self.transformer.wte.weight.device
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        positions = torch.arange(seq_len, dtype=torch.float32, device=device)
        angles = torch.outer(positions, inv_freq)
        cos = angles.cos().bfloat16()[None, :, None, :]
        sin = angles.sin().bfloat16()[None, :, None, :]
        return cos, sin

    def _compute_window_sizes(self, config):
        """Return one ``(window, 0)`` tuple per layer.

        ``config.window_pattern`` is cycled over the layers ("L" = full
        ``sequence_len``, "S" = half of it); the last layer always gets the
        full window.
        """
        pattern = config.window_pattern.upper()
        assert all(c in "SL" for c in pattern)
        long_window = config.sequence_len
        windows = {"L": (long_window, 0), "S": (long_window // 2, 0)}
        window_sizes = [windows[pattern[layer % len(pattern)]] for layer in range(config.n_layer)]
        window_sizes[-1] = (long_window, 0)
        return window_sizes

    def estimate_flops(self):
        """Estimated FLOPs per token (forward + backward).

        Counts 6 FLOPs per parameter, excluding the token/value embeddings and
        the per-layer scalars, plus a per-layer attention term that depends on
        the layer's effective window.
        """
        total_params = sum(p.numel() for p in self.parameters())
        excluded = (
            self.transformer.wte.weight.numel()
            + sum(ve.weight.numel() for ve in self.value_embeds.values())
            + self.resid_lambdas.numel()
            + self.x0_lambdas.numel()
        )
        n_head = self.config.n_head
        head_dim = self.config.n_embd // self.config.n_head
        seq_len = self.config.sequence_len
        attn_flops = 0
        for window_size in self.window_sizes:
            window = window_size[0]
            effective_seq = seq_len if window < 0 else min(window, seq_len)
            attn_flops += 12 * n_head * head_dim * effective_seq
        return 6 * (total_params - excluded) + attn_flops

    def num_scaling_params(self):
        """Return parameter counts per component plus their ``'total'``."""
        def count(module):
            return sum(p.numel() for p in module.parameters())

        counts = {
            'wte': count(self.transformer.wte),
            'value_embeds': count(self.value_embeds),
            'lm_head': count(self.lm_head),
            'transformer_matrices': count(self.transformer.h),
            'scalars': self.resid_lambdas.numel() + self.x0_lambdas.numel(),
        }
        counts['total'] = sum(counts.values())
        return counts

    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02,
                        weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
        """Build the ``MuonAdamW`` optimizer for this model.

        AdamW groups: lm_head, token embedding, value embeddings (these three
        have their LR scaled by ``(n_embd / 768) ** -0.5``), ``resid_lambdas``
        (``scalar_lr * 0.01``) and ``x0_lambdas``. Every parameter inside the
        blocks goes to Muon, one group per distinct parameter shape. Each
        group's starting LR is recorded under ``"initial_lr"``.
        """
        model_dim = self.config.n_embd
        matrix_params = list(self.transformer.h.parameters())
        value_embeds_params = list(self.value_embeds.parameters())
        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())
        resid_params = [self.resid_lambdas]
        x0_params = [self.x0_lambdas]
        # Every parameter must be assigned to exactly one group.
        partition = (matrix_params, embedding_params, lm_head_params,
                     value_embeds_params, resid_params, x0_params)
        assert len(list(self.parameters())) == sum(len(part) for part in partition)
        # Scale LR ∝ 1/√dmodel (tuned at 768 dim)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}")

        def adamw_group(params, lr, betas):
            return dict(kind='adamw', params=params, lr=lr, betas=betas, eps=1e-10, weight_decay=0.0)

        param_groups = [
            adamw_group(lm_head_params, unembedding_lr * dmodel_lr_scale, adam_betas),
            adamw_group(embedding_params, embedding_lr * dmodel_lr_scale, adam_betas),
            adamw_group(value_embeds_params, embedding_lr * dmodel_lr_scale, adam_betas),
            adamw_group(resid_params, scalar_lr * 0.01, adam_betas),
            adamw_group(x0_params, scalar_lr, (0.96, 0.95)),
        ]
        for shape in sorted({p.shape for p in matrix_params}):
            same_shape = [p for p in matrix_params if p.shape == shape]
            param_groups.append(dict(
                kind='muon', params=same_shape, lr=matrix_lr,
                momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
            ))
        optimizer = MuonAdamW(param_groups)
        for group in optimizer.param_groups:
            group["initial_lr"] = group["lr"]
        return optimizer

    def forward(self, idx, targets=None, reduction='mean'):
        """Run the model on token ids ``idx`` of shape (B, T).

        Returns the cross-entropy loss (targets equal to -1 are ignored) when
        ``targets`` is given, otherwise the float32 logits, soft-capped to
        (-15, 15) with tanh.
        """
        _, T = idx.size()
        assert T <= self.cos.size(1)
        cos_sin = self.cos[:, :T], self.sin[:, :T]

        x = norm(self.transformer.wte(idx))
        x0 = x
        for i, block in enumerate(self.transformer.h):
            # Mix the running activations with the normalized embedding output.
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
            key = str(i)
            ve = self.value_embeds[key](idx) if key in self.value_embeds else None
            x = block(x, ve, cos_sin, self.window_sizes[i])
        x = norm(x)

        softcap = 15
        logits = self.lm_head(x).float()
        logits = softcap * torch.tanh(logits / softcap)

        if targets is None:
            return logits
        return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
                               ignore_index=-1, reduction=reduction)
|
|
323
|
+
|
|
324
|
+
# ---------------------------------------------------------------------------
|
|
325
|
+
# Optimizer (MuonAdamW, single GPU only)
|
|
326
|
+
# ---------------------------------------------------------------------------
|
|
327
|
+
|
|
328
|
+
# (a, b, c) coefficient triples for the orthogonalization loop in
# muon_step_fused; iteration i uses triple i, and only the first `ns_steps`
# triples are consumed.
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
|
|
335
|
+
|
|
336
|
+
@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
    """Apply one AdamW update to ``p`` in place.

    ``exp_avg`` and ``exp_avg_sq`` (first / second moment estimates) are also
    updated in place. The hyperparameters arrive as tensors (``*_t``); the
    caller in this file passes 0-D tensors whose values it refreshes before
    each call.
    """
    # Decoupled weight decay, applied before the gradient step.
    p.mul_(1 - lr_t * wd_t)
    # Exponential moving averages of the gradient and its square.
    exp_avg.lerp_(grad, 1 - beta1_t)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
    # Bias corrections for the current step count.
    correction1 = 1 - beta1_t ** step_t
    correction2 = 1 - beta2_t ** step_t
    rms = (exp_avg_sq / correction2).sqrt() + eps_t
    p.add_(exp_avg / rms, alpha=-(lr_t / correction1))
|
|
346
|
+
|
|
347
|
+
@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
                    momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
    """Apply one Muon update to a stack of same-shaped matrices, in place.

    Mutates ``stacked_params``, ``momentum_buffer``, ``second_momentum_buffer``
    and also ``stacked_grads`` (overwritten by the Nesterov step). ``ns_steps``
    selects how many ``polar_express_coeffs`` triples to run; ``red_dim`` is the
    dimension over which squared entries are averaged for the second-moment
    statistics.
    """
    # --- Nesterov momentum ---
    mu = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - mu)
    update = stacked_grads.lerp_(momentum_buffer, mu)

    # --- Polar express orthogonalization (run in bf16) ---
    X = update.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    # Form the Gram matrix on the smaller side of the matrix.
    tall = update.size(-2) > update.size(-1)
    for a, b, c in polar_express_coeffs[:ns_steps]:
        if tall:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
        else:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    update = X

    # --- NorMuon variance reduction ---
    beta2 = beta2_t.to(update.dtype)
    v_mean = update.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = update.size(red_dim)
    # Norm of the update before rescaling.
    v_norm = (v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size).sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    # Norm the update would have after applying step_size; used to restore v_norm.
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    update = update * final_scale.to(update.dtype)

    # --- Cautious weight decay + parameter update ---
    lr = lr_t.to(update.dtype)
    wd = wd_t.to(update.dtype)
    # Decay only the entries where update and parameter have the same sign.
    mask = (update * stacked_params) >= 0
    stacked_params.sub_(lr * update + lr * wd * stacked_params * mask)
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for others.

    Every param group must carry a ``kind`` key:

    * ``'adamw'`` groups need ``lr``, ``betas``, ``eps`` and ``weight_decay``.
    * ``'muon'`` groups need ``lr``, ``momentum``, ``beta2``, ``ns_steps`` and
      ``weight_decay``; all params in one group must share the same shape
      because they are stacked into a single tensor for the update.

    Any other ``kind`` makes ``step()`` raise ``ValueError`` instead of
    silently leaving those parameters untouched.
    """

    def __init__(self, param_groups):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    def _step_adamw(self, group):
        """Run the fused AdamW update on every param of ``group`` that has a gradient."""
        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad
            state = self.state[p]
            if not state:
                # Lazy per-parameter state creation on first use.
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            state['step'] += 1
            # Refresh the shared scalar tensors with this group's values.
            self._adamw_step_t.fill_(state['step'])
            self._adamw_lr_t.fill_(group['lr'])
            self._adamw_beta1_t.fill_(group['betas'][0])
            self._adamw_beta2_t.fill_(group['betas'][1])
            self._adamw_eps_t.fill_(group['eps'])
            self._adamw_wd_t.fill_(group['weight_decay'])
            adamw_step_fused(p, grad, state['exp_avg'], state['exp_avg_sq'],
                             self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                             self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)

    def _step_muon(self, group):
        """Run the fused Muon update on all params of ``group`` at once.

        The params (and their gradients) are stacked, updated as one tensor
        and copied back. Group state is stored under the first param. Every
        param in the group is expected to have a gradient.
        """
        params = group['params']
        if not params:
            return
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        if "second_momentum_buffer" not in state:
            # One second-moment entry per row (tall/square) or per column (wide).
            state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        stacked_grads = torch.stack([p.grad for p in params])
        stacked_params = torch.stack(params)
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # LR is scaled up by sqrt(rows / cols) for tall matrices.
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        state["momentum_buffer"], state["second_momentum_buffer"],
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        # torch.stack made a copy, so write the updated values back.
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step over all param groups.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss, as in ``torch.optim.Optimizer.step``.

        Returns:
            The closure's return value, or ``None`` when no closure is given.

        Raises:
            ValueError: if a param group has a ``kind`` other than
                ``'adamw'`` or ``'muon'``.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs gradients enabled.
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            kind = group['kind']
            if kind == 'adamw':
                self._step_adamw(group)
            elif kind == 'muon':
                self._step_muon(group)
            else:
                raise ValueError(f"Unknown param group kind: {kind!r} (expected 'adamw' or 'muon')")
        return loss
|
|
458
|
+
|
|
459
|
+
# ---------------------------------------------------------------------------
|
|
460
|
+
# Hyperparameters (edit these directly, no CLI flags needed)
|
|
461
|
+
# ---------------------------------------------------------------------------
|
|
462
|
+
|
|
463
|
+
# Model architecture
# model_dim = depth * ASPECT_RATIO
ASPECT_RATIO = 64
# target head dimension for attention
HEAD_DIM = 128
# sliding window pattern: L=full, S=half context
WINDOW_PATTERN = "SSSL"

# Optimization
# ~524K tokens per optimizer step
TOTAL_BATCH_SIZE = 2**19
# learning rate for token embeddings (Adam)
EMBEDDING_LR = 0.6
# learning rate for lm_head (Adam)
UNEMBEDDING_LR = 0.004
# learning rate for matrix parameters (Muon)
MATRIX_LR = 0.04
# learning rate for per-layer scalars (Adam)
SCALAR_LR = 0.5
# cautious weight decay for Muon
WEIGHT_DECAY = 0.2
# Adam beta1, beta2
ADAM_BETAS = (0.8, 0.95)
# fraction of time budget for LR warmup
WARMUP_RATIO = 0.0
# fraction of time budget for LR warmdown
WARMDOWN_RATIO = 0.5
# final LR as fraction of initial
FINAL_LR_FRAC = 0.0

# Model size
# number of transformer layers
DEPTH = 8
# per-device batch size (reduce if OOM)
DEVICE_BATCH_SIZE = 128
|
|
483
|
+
|
|
484
|
+
# ---------------------------------------------------------------------------
|
|
485
|
+
# Setup: tokenizer, model, optimizer, dataloader
|
|
486
|
+
# ---------------------------------------------------------------------------
|
|
487
|
+
|
|
488
|
+
# Wall-clock start of the whole script (startup included).
t_start = time.time()

# Fixed seeds for the CPU and CUDA generators.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Numeric setup: "high" float32 matmul precision and bf16 autocast on CUDA.
torch.set_float32_matmul_precision("high")
device = torch.device("cuda")
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

# Reference peak throughput used as the denominator of the MFU figures.
H100_BF16_PEAK_FLOPS = 989.5e12

tokenizer = Tokenizer.from_directory()
vocab_size = tokenizer.get_vocab_size()
print(f"Vocab size: {vocab_size:,}")
|
|
499
|
+
|
|
500
|
+
def build_model_config(depth):
    """Derive a ``GPTConfig`` for a model with ``depth`` layers.

    The width is ``depth * ASPECT_RATIO`` rounded up to a multiple of
    ``HEAD_DIM``; the head count follows from that width, with as many
    key/value heads as query heads.
    """
    base_dim = depth * ASPECT_RATIO
    n_chunks = (base_dim + HEAD_DIM - 1) // HEAD_DIM  # ceil(base_dim / HEAD_DIM)
    model_dim = n_chunks * HEAD_DIM
    num_heads = model_dim // HEAD_DIM
    return GPTConfig(
        sequence_len=MAX_SEQ_LEN,
        vocab_size=vocab_size,
        n_layer=depth,
        n_head=num_heads,
        n_kv_head=num_heads,
        n_embd=model_dim,
        window_pattern=WINDOW_PATTERN,
    )
|
|
509
|
+
|
|
510
|
+
config = build_model_config(DEPTH)
print(f"Model config: {asdict(config)}")

# Build the module on the meta device (no real storage), allocate it on the
# target device, then run the explicit initialization.
with torch.device("meta"):
    model = GPT(config)
model.to_empty(device=device)
model.init_weights()

param_counts = model.num_scaling_params()
print("Parameter counts:")
for key, value in param_counts.items():
    print(f" {key:24s}: {value:,}")
num_params = param_counts['total']
num_flops_per_token = model.estimate_flops()
print(f"Estimated FLOPs per token: {num_flops_per_token:e}")

# Tokens handled by one forward/backward pass; the total batch must be a
# whole number of such passes.
tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd

# The optimizer is created from the uncompiled module, before torch.compile wraps it.
optimizer = model.setup_optimizer(
    matrix_lr=MATRIX_LR,
    embedding_lr=EMBEDDING_LR,
    unembedding_lr=UNEMBEDDING_LR,
    scalar_lr=SCALAR_LR,
    weight_decay=WEIGHT_DECAY,
    adam_betas=ADAM_BETAS,
)

model = torch.compile(model, dynamic=False)

train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train")
x, y, epoch = next(train_loader)  # prefetch first batch

print(f"Time budget: {TIME_BUDGET}s")
print(f"Gradient accumulation steps: {grad_accum_steps}")
|
|
546
|
+
|
|
547
|
+
# Schedules (all based on progress = training_time / TIME_BUDGET)
|
|
548
|
+
|
|
549
|
+
def get_lr_multiplier(progress):
    """Return the LR multiplier for ``progress`` (fraction of the time budget used).

    Linear warmup over the first ``WARMUP_RATIO`` of the budget, constant 1.0
    in the middle, then a linear ramp from 1.0 down to ``FINAL_LR_FRAC`` over
    the last ``WARMDOWN_RATIO``. Note: the warmdown branch divides by
    ``WARMDOWN_RATIO``, so it raises ``ZeroDivisionError`` if that is 0 and
    ``progress`` is 1.0 or more.
    """
    if progress < WARMUP_RATIO:
        if WARMUP_RATIO > 0:
            return progress / WARMUP_RATIO
        return 1.0
    if progress < 1.0 - WARMDOWN_RATIO:
        return 1.0
    # Fraction of the warmdown phase still remaining (1 at its start, 0 at the end).
    remaining_frac = (1.0 - progress) / WARMDOWN_RATIO
    return remaining_frac * 1.0 + (1 - remaining_frac) * FINAL_LR_FRAC
|
|
557
|
+
|
|
558
|
+
def get_muon_momentum(step):
    """Return the Muon momentum for ``step``: a linear ramp from 0.85 to 0.95
    over the first 300 steps, constant at 0.95 afterwards."""
    ramp = min(step / 300, 1)
    start, end = 0.85, 0.95
    return (1 - ramp) * start + ramp * end
|
|
561
|
+
|
|
562
|
+
def get_weight_decay(progress):
    """Return the weight decay for ``progress``: ``WEIGHT_DECAY`` scaled
    linearly down to zero as ``progress`` goes from 0 to 1."""
    remaining = 1 - progress
    return remaining * WEIGHT_DECAY
|
|
564
|
+
|
|
565
|
+
# ---------------------------------------------------------------------------
|
|
566
|
+
# Training loop
|
|
567
|
+
# ---------------------------------------------------------------------------
|
|
568
|
+
|
|
569
|
+
# Time-budgeted training loop. Each iteration runs `grad_accum_steps`
# forward/backward passes followed by one optimizer step. Only iterations with
# step index > 10 count towards `total_training_time`, so the first iterations
# (which include compilation) do not consume the budget.
t_start_training = time.time()
smooth_train_loss = 0
total_training_time = 0
step = 0

while True:
    torch.cuda.synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        # Kept for logging; after the inner loop this is the last micro-batch's loss.
        train_loss = loss.detach()
        # Scale so the accumulated gradient is the mean over micro-batches.
        loss = loss / grad_accum_steps
        loss.backward()
        x, y, epoch = next(train_loader)

    # Progress and schedules
    progress = min(total_training_time / TIME_BUDGET, 1.0)
    lrm = get_lr_multiplier(progress)
    muon_momentum = get_muon_momentum(step)
    muon_weight_decay = get_weight_decay(progress)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group['kind'] == 'muon':
            group["momentum"] = muon_momentum
            group["weight_decay"] = muon_weight_decay
    optimizer.step()
    model.zero_grad(set_to_none=True)

    train_loss_f = train_loss.item()

    # Fast fail: abort if loss is exploding or NaN
    if math.isnan(train_loss_f) or train_loss_f > 100:
        print("FAIL")
        # sys.exit rather than the `exit` helper, which the site module only
        # provides for interactive use and which may be absent.
        sys.exit(1)

    torch.cuda.synchronize()
    t1 = time.time()
    dt = t1 - t0

    if step > 10:
        total_training_time += dt

    # Logging: bias-corrected exponential moving average of the training loss
    ema_beta = 0.9
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))
    pct_done = 100 * progress
    tok_per_sec = int(TOTAL_BATCH_SIZE / dt)
    mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE / dt / H100_BF16_PEAK_FLOPS
    remaining = max(0, TIME_BUDGET - total_training_time)

    print(f"\rstep {step:05d} ({pct_done:.1f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt*1000:.0f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.1f}% | epoch: {epoch} | remaining: {remaining:.0f}s ", end="", flush=True)

    # GC management (Python's GC causes ~500ms stalls): collect once, freeze
    # the survivors, turn automatic GC off, then collect manually every 5000 steps.
    if step == 0:
        gc.collect()
        gc.freeze()
        gc.disable()
    elif (step + 1) % 5000 == 0:
        gc.collect()

    step += 1

    # Time's up — but only stop after warmup steps so we don't count compilation
    if step > 10 and total_training_time >= TIME_BUDGET:
        break

print()  # newline after \r training log
|
|
638
|
+
|
|
639
|
+
total_tokens = step * TOTAL_BATCH_SIZE

# Final eval
model.eval()
with autocast_ctx:
    val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)

# Final summary
t_end = time.time()
startup_time = t_start_training - t_start
# total_training_time only accumulates iterations whose step index was > 10,
# i.e. indices 11 .. step-1, which is (step - 11) optimizer steps. Using
# (step - 10) here would overstate the steady-state MFU by one step.
timed_steps = step - 11
steady_state_mfu = (
    100 * num_flops_per_token * TOTAL_BATCH_SIZE * timed_steps / total_training_time / H100_BF16_PEAK_FLOPS
    if total_training_time > 0 else 0
)
peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024

print("---")
print(f"val_bpb: {val_bpb:.6f}")
print(f"training_seconds: {total_training_time:.1f}")
print(f"total_seconds: {t_end - t_start:.1f}")
print(f"peak_vram_mb: {peak_vram_mb:.1f}")
print(f"mfu_percent: {steady_state_mfu:.2f}")
print(f"total_tokens_M: {total_tokens / 1e6:.1f}")
print(f"num_steps: {step}")
print(f"num_params_M: {num_params / 1e6:.1f}")
print(f"depth: {DEPTH}")
|