freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,400 @@
|
|
|
1
|
+
"""Pre-compile the worker's hot kernels on a real GPU and persist a portable mega-cache.
|
|
2
|
+
|
|
3
|
+
Run this ON A GPU BUILDER (an image-build runner that actually has the target arch's GPU) to kill
|
|
4
|
+
the ~10-15 min first-use JIT that #194 reintroduced on a cold worker. It warms the kernels the
|
|
5
|
+
trainer hits early — FlashAttention fwd/bwd, the Liger fused cross-entropy, flash-linear-attention's
|
|
6
|
+
Gated-DeltaNet (Qwen3.5/3.6 hybrid), and a representative ``torch.compile`` — then calls
|
|
7
|
+
``torch.compiler.save_cache_artifacts()`` to write ONE portable mega-cache blob into the cache dir.
|
|
8
|
+
``flash.engine.worker._load_kernel_cache_if_present`` loads it back at worker boot
|
|
9
|
+
(``torch.compiler.load_cache_artifacts``); the Dockerfile bakes the produced ``build/kernel_cache/``
|
|
10
|
+
into the image when built with ``--build-arg BUILD_KERNEL_CACHE=true``.
|
|
11
|
+
|
|
12
|
+
Measured: cold compile ~124s -> warm load ~0.2s (537x).
|
|
13
|
+
|
|
14
|
+
This module is import-safe WITHOUT torch installed (it must ``py_compile`` on the CPU-only CI image
|
|
15
|
+
that builds the worker): every heavy import lives INSIDE a function. Everything is best-effort —
|
|
16
|
+
each warm step is independently guarded so a missing/uncompilable kernel never aborts the bake; we
|
|
17
|
+
save whatever did compile. CLI: ``python -m flash.engine.worker.kernel_warmup --arch <sm> --out <dir>``.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import argparse
|
|
23
|
+
import json
|
|
24
|
+
import os
|
|
25
|
+
import time
|
|
26
|
+
|
|
27
|
+
# Where the bake writes by default. It mirrors the Dockerfile's /opt/flash/kernelcache, and the
# mega-cache file is written directly under it so _load_kernel_cache_if_present can find it.
# Rename these in lockstep with engine.worker._KERNEL_CACHE_DIR / _KERNEL_CACHE_FILE.
DEFAULT_CACHE_DIR: str = "/opt/flash/kernelcache"
# File name of the serialized torch.compiler.save_cache_artifacts() payload.
MEGA_CACHE_FILENAME: str = "mega_cache.bin"
# File name of the JSON sidecar written by save_cache_metadata (sm, torch/CUDA versions, device,
# warmed group count, timestamp).
MEGA_CACHE_META_FILENAME: str = "mega_cache.json"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _log(msg: str) -> None:
    """Print one ``[kernel-warmup]``-tagged line and flush it immediately.

    This is the single progress channel, so each warm step shows up in the GPU builder's log as it
    happens.
    """
    print("[kernel-warmup]", msg, flush=True)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _point_backends_at(cache_dir: str) -> None:
    """Redirect the Triton and TorchInductor on-disk caches into ``cache_dir``.

    Creates ``<cache_dir>/triton`` and ``<cache_dir>/inductor`` if they are missing. Then exports
    ``TRITON_CACHE_DIR`` / ``TORCHINDUCTOR_CACHE_DIR`` pointing at them, so anything compiled
    afterwards is content-addressed under the same tree the worker reads (matches the Dockerfile
    ENV).
    """
    targets = {
        "TRITON_CACHE_DIR": os.path.join(cache_dir, "triton"),
        "TORCHINDUCTOR_CACHE_DIR": os.path.join(cache_dir, "inductor"),
    }
    # Create both directories before touching the environment, so a failed mkdir leaves the
    # process env unchanged.
    for path in targets.values():
        os.makedirs(path, exist_ok=True)
    os.environ.update(targets)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _torch_sm(torch) -> str:
    """Return CUDA device 0's compute capability formatted as ``sm<major><minor>`` (e.g. ``sm90``)."""
    capability = torch.cuda.get_device_capability(0)
    major, minor = capability[0], capability[1]
    return "sm{}{}".format(major, minor)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _require_gpu():
    """Import torch and confirm a CUDA device is visible.

    Returns the ``torch`` module when a GPU is present, otherwise ``None`` after logging why.
    Warming only makes sense on a real GPU of the target arch: kernels are content-addressed by
    arch + toolchain, so a CPU-only run would bake nothing usable.

    Any exception raised while importing or probing torch is logged as "torch unavailable" and
    yields ``None``.
    """
    try:
        import torch

        if torch.cuda.is_available():
            _log(f"GPU: {torch.cuda.get_device_name(0)} ({_torch_sm(torch)}), torch {torch.__version__}")
            return torch
        _log("no CUDA device visible — kernel warmup must run on a GPU builder; nothing baked")
        return None
    except Exception as e:
        _log(f"torch unavailable ({e}); cannot warm kernels")
        return None
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def warm_flash_attn(torch) -> bool:
    """Compile FlashAttention forward + backward with one tiny causal attention call.

    FA2 (``flash_attn``) is tried first. FA3 (``flash_attn_interface``) is tried second; it only
    succeeds on Hopper. Each attempt is guarded on its own.

    Returns:
        True if at least one of the two variants ran fwd+bwd successfully.
    """

    def _qkv():
        # Three fresh (1, 64, 4, 64) bf16 tensors that track gradients, so backward runs too.
        return tuple(
            torch.randn(1, 64, 4, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
            for _ in range(3)
        )

    any_ok = False

    try:
        from flash_attn import flash_attn_func

        q, k, v = _qkv()
        flash_attn_func(q, k, v, causal=True).sum().backward()  # fwd, then the bwd kernel
        torch.cuda.synchronize()
        _log("flash-attn (FA2) fwd/bwd compiled")
        any_ok = True
    except Exception as e:
        _log(f"flash-attn (FA2) warm skipped: {e}")

    try:
        # FA3: on Hopper (sm90) the worker uses attn_implementation="flash_attention_3" (the local
        # flash_attn_interface build) for full-attention layers, so an H100 bake must cover it too.
        # Off-Hopper this is expected to fail and be skipped; the wheel just rides along.
        import flash_attn_interface as fa3

        q, k, v = _qkv()
        result = fa3.flash_attn_func(q, k, v, causal=True)
        attn_out = result[0] if isinstance(result, tuple) else result
        attn_out.sum().backward()
        torch.cuda.synchronize()
        _log("flash-attn-3 (Hopper) fwd/bwd compiled")
        any_ok = True
    except Exception as e:
        _log(f"flash-attn-3 warm skipped (expected off-Hopper): {e}")

    return any_ok
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _warm_liger_cross_entropy(torch) -> bool:
    """Compile Liger's standalone cross-entropy fwd/bwd on a small (64 x 4096) bf16 logits batch."""
    try:
        from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

        loss_fn = LigerCrossEntropyLoss()
        logits = torch.randn(64, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        labels = torch.randint(0, 4096, (64,), device="cuda")
        loss_fn(logits, labels).backward()
        torch.cuda.synchronize()
        _log("liger fused cross-entropy compiled")
        return True
    except Exception as e:
        _log(f"liger CE warm skipped: {e}")
        return False


def _warm_liger_fused_linear_ce(torch) -> bool:
    """Compile Liger's fused linear + cross-entropy loss (fwd/bwd) at the production vocab width."""
    import importlib

    # The loss class has lived under different module paths across liger versions; use the first
    # one that exposes it.
    candidates = (
        ("liger_kernel.ops.fused_linear_cross_entropy", "LigerFusedLinearCrossEntropyLoss"),
        ("liger_kernel.transformers.fused_linear_cross_entropy", "LigerFusedLinearCrossEntropyLoss"),
    )
    try:
        loss_cls = None
        for module_name, attr in candidates:
            try:
                loss_cls = getattr(importlib.import_module(module_name), attr)
                break
            except Exception:
                continue
        if loss_cls is None:
            raise ImportError("no fused-linear Liger loss class found")

        # Representative catalog vocab width (qwen3.5/3.6 lm_head ~248k). Triton/liger specialize the
        # fused-ce chunking to the vocab shape, so warm the production width, not a toy 4096.
        vocab = 248_320
        hidden = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        weight = torch.randn(vocab, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        labels = torch.randint(0, vocab, (64,), device="cuda")
        loss_fn = loss_cls()
        # The upstream signature is forward(self, lin_weight, _input, target): weight first, then
        # hidden. Try the known-good form first so a mismatched shape (which can trigger a CUDA
        # illegal access and poison the context) is only ever launched as a last resort.
        attempts = (
            lambda: loss_fn(weight, hidden, labels),
            lambda: loss_fn(weight, hidden, target=labels),
            lambda: loss_fn(hidden, weight, labels),
        )
        for call in attempts:
            try:
                out = call()
                if isinstance(out, tuple):
                    out = out[0]
                out.backward()
                torch.cuda.synchronize()
                _log("liger fused-linear loss compiled")
                return True
            except Exception:
                continue
        raise RuntimeError("fused-linear Liger calls were not accepted")
    except Exception as e:
        _log(f"liger fused-linear warm skipped: {e}")
        return False


def _warm_liger_model_layers(torch) -> bool:
    """Compile the model-layer Liger kernels (RMSNorm and RoPE, fwd + bwd).

    ``use_liger_kernel`` patches these in besides the loss, so they would otherwise still JIT on
    the first real forward/backward. Each kernel has its own guard, so one failing does not hide
    the other.
    """
    warmed = False
    try:
        from liger_kernel.transformers.rms_norm import LigerRMSNorm

        rms = LigerRMSNorm(hidden_size=256).to(device="cuda", dtype=torch.bfloat16)
        x = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        rms(x).sum().backward()
        torch.cuda.synchronize()
        _log("liger rmsnorm fwd/bwd compiled")
        warmed = True
    except Exception as e:
        _log(f"liger rmsnorm warm skipped: {e}")
    try:
        from liger_kernel.transformers.rope import liger_rotary_pos_emb

        b, h, t, d = 1, 4, 64, 64
        # requires_grad on q/k so the backward pass also runs; it is a separately compiled
        # specialization from the forward, and training always hits it.
        q = torch.randn(b, h, t, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        k = torch.randn(b, h, t, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        cos = torch.randn(b, t, d, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(b, t, d, device="cuda", dtype=torch.bfloat16)
        q_out, k_out = liger_rotary_pos_emb(q, k, cos, sin)
        (q_out.sum() + k_out.sum()).backward()
        torch.cuda.synchronize()
        _log("liger rope fwd/bwd compiled")
        warmed = True
    except Exception as e:
        _log(f"liger rope warm skipped: {e}")
    return warmed


def warm_liger_ce(torch) -> bool:
    """Compile the Liger kernels the trainer hits early.

    Covers the plain cross-entropy, the fused linear cross-entropy (at production vocab width),
    and the RMSNorm / RoPE model-layer kernels. Every group is independently guarded and always
    attempted, so one failure never skips the others.

    Returns:
        True if at least one Liger kernel group compiled.
    """
    results = [
        _warm_liger_cross_entropy(torch),
        _warm_liger_fused_linear_ce(torch),
        _warm_liger_model_layers(torch),
    ]
    return any(results)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def warm_fla_gdn(torch) -> bool:
    """Compile flash-linear-attention's Gated-DeltaNet chunk kernels (Qwen3.5/3.6 hybrid path).

    Runs two calls, each forward + backward:
    - an equal-length batch;
    - a packed varlen batch (``cu_seqlens``), which compiles different chunk kernels.

    The varlen step is best-effort on its own.

    Returns:
        True once the equal-length fwd/bwd succeeded (regardless of the varlen step); False if
        the import or that first call failed.
    """

    def _make_inputs(batch, seq, heads, dim):
        # Built in the order q, k, v, g, beta (tuple elements evaluate left to right).
        bf16 = torch.bfloat16
        return (
            torch.randn(batch, seq, heads, dim, device="cuda", dtype=bf16, requires_grad=True),
            torch.randn(batch, seq, heads, dim, device="cuda", dtype=bf16, requires_grad=True),
            torch.randn(batch, seq, heads, dim, device="cuda", dtype=bf16, requires_grad=True),
            torch.randn(batch, seq, heads, device="cuda", dtype=torch.float32, requires_grad=True),
            torch.rand(batch, seq, heads, device="cuda", dtype=bf16, requires_grad=True),
        )

    def _backprop(result):
        # The kernel may return (output, final_state); only the output feeds the backward.
        primary = result[0] if isinstance(result, tuple) else result
        primary.sum().backward()
        torch.cuda.synchronize()

    try:
        # Mirror the worker boot. On Hopper (sm90) the production path runs this first so that
        # fla's GDN chunk_bwd uses the correct tilelang backend (fla #640: the Triton path
        # miscomputes/raises). Off-Hopper it is a no-op. Bake the same backend the runtime selects.
        try:
            from flash.engine.worker.perf import _ensure_fla_fastpath_on_hopper

            _ensure_fla_fastpath_on_hopper()
        except Exception as e:
            _log(f"hopper fla fast-path setup skipped: {e}")
        from fla.ops.gated_delta_rule import chunk_gated_delta_rule

        heads, dim = 4, 64
        q, k, v, g, beta = _make_inputs(1, 256, heads, dim)
        _backprop(chunk_gated_delta_rule(q, k, v, g, beta, use_qk_l2norm_in_kernel=True))
        _log("flash-linear-attention GDN fwd/bwd compiled")

        # Varlen path used by SFT token-packing (#218): BlockDiagonalCollator(emit_varlen=True)
        # feeds cu_seq_lens so the recurrence resets per packed example. Shape it like one
        # packed block (batch flattened to 1) with two unequal segments.
        try:
            packed_len = 128 + 96
            qv, kv, vv, gv, betav = _make_inputs(1, packed_len, heads, dim)
            boundaries = torch.tensor([0, 128, packed_len], device="cuda", dtype=torch.int32)
            _backprop(
                chunk_gated_delta_rule(
                    qv, kv, vv, gv, betav, use_qk_l2norm_in_kernel=True, cu_seqlens=boundaries
                )
            )
            _log("flash-linear-attention GDN varlen (cu_seqlens) fwd/bwd compiled")
        except Exception as e:
            _log(f"fla GDN varlen warm skipped: {e}")
        return True
    except Exception as e:
        _log(f"fla GDN warm skipped: {e}")
        return False
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def warm_chalk_kernels() -> bool:
    """Compile chalk's default self-test kernels when freesolo-chalk is installed.

    Runs chalk's three default gap-filler installers from ``chalk.transformers``:
    - ``install_qwen35_rope``
    - ``install_fused_lora_delta``
    - ``install_fused_embedding``

    Each installer is looked up and run under its own guard. A name missing in this chalk version,
    or an installer that raises, therefore drops only that kernel rather than every installer
    after it.

    Returns:
        True if at least one installer returned a truthy value.
    """
    try:
        import chalk.transformers as chalk_transformers
    except Exception as e:
        _log(f"chalk warm skipped: {e}")
        return False

    installers = (
        ("install_qwen35_rope", "qwen3.5-rope"),
        ("install_fused_lora_delta", "fused-lora-delta"),
        ("install_fused_embedding", "fused-embedding"),
    )
    warmed = False
    for attr, label in installers:
        try:
            installer = getattr(chalk_transformers, attr)
            warmed = bool(installer()) or warmed
        except Exception as e:
            _log(f"chalk {label} warm skipped: {e}")
    _log(f"chalk default kernel installers ran (warmed={warmed})")
    return warmed
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def warm_torch_compile(torch) -> bool:
    """Run one small compiled matmul + GELU so TorchInductor populates its cache.

    Returns:
        True once the compiled call has run and the device synchronized; False (after logging)
        on any failure.
    """
    try:

        def _fused(a, b):
            return torch.nn.functional.gelu(a @ b)

        compiled = torch.compile(_fused)
        lhs, rhs = (torch.randn(256, 256, device="cuda", dtype=torch.bfloat16) for _ in range(2))
        compiled(lhs, rhs)
        torch.cuda.synchronize()
        _log("torch.compile (TorchInductor) warmed")
        return True
    except Exception as e:
        _log(f"torch.compile warm skipped: {e}")
        return False
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def save_mega_cache(torch, out_dir: str) -> bool:
    """Persist everything compiled this session into one portable blob.

    Calls ``torch.compiler.save_cache_artifacts()``, which returns a ``(bytes, cache_info)`` tuple
    or ``None`` when nothing was compiled. The bytes are written to
    ``<out_dir>/MEGA_CACHE_FILENAME`` so the worker can ``load_cache_artifacts`` them at boot.

    The blob goes to a temporary sibling first and is moved into place with ``os.replace``. A
    failed write therefore never leaves a truncated cache file where the worker looks.

    Returns:
        True when a non-empty blob was written; False (after logging) when there was nothing to
        persist or the save failed.
    """
    try:
        artifacts = torch.compiler.save_cache_artifacts()
        blob = artifacts[0] if isinstance(artifacts, tuple) else artifacts
        # Check the payload, not the wrapper. A (b"", info) tuple is truthy, and would otherwise be
        # written as an empty cache file and reported as a successful save.
        if not blob:
            _log("save_cache_artifacts returned nothing — no compiled kernels to persist")
            return False
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, MEGA_CACHE_FILENAME)
        tmp_path = f"{path}.tmp-{os.getpid()}"
        try:
            with open(tmp_path, "wb") as f:
                f.write(blob)
            os.replace(tmp_path, path)
        except Exception:
            # Don't leave a half-written temp file behind in the baked cache dir.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
        _log(f"mega-cache saved: {path} ({len(blob)} bytes)")
        return True
    except Exception as e:
        _log(f"save_cache_artifacts failed: {e}")
        return False
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def save_cache_metadata(torch, out_dir: str, *, requested_arch: str | None, warmed: int) -> bool:
    """Write the JSON sidecar (``MEGA_CACHE_META_FILENAME``) describing this bake into ``out_dir``.

    Records:
    - the live GPU's sm;
    - the arch the caller requested;
    - the torch and CUDA versions;
    - the device name;
    - the number of warm groups that succeeded;
    - a Unix timestamp.

    Keys are written sorted. The save is best-effort: any failure is logged and reported as
    ``False`` rather than raised.
    """
    try:
        meta = dict(
            sm=_torch_sm(torch),
            requested_arch=requested_arch,
            torch=getattr(torch, "__version__", "unknown"),
            cuda=getattr(getattr(torch, "version", None), "cuda", None),
            device=torch.cuda.get_device_name(0),
            warmed_groups=int(warmed),
            created_at=int(time.time()),
        )
        os.makedirs(out_dir, exist_ok=True)
        meta_path = os.path.join(out_dir, MEGA_CACHE_META_FILENAME)
        with open(meta_path, "w") as fh:
            json.dump(meta, fh, sort_keys=True)
        _log(f"cache metadata saved: {meta_path} ({meta['sm']})")
        return True
    except Exception as e:
        _log(f"cache metadata save failed: {e}")
        return False
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def _arch_to_sm(arch: str) -> str:
    """Normalize a ``TORCH_CUDA_ARCH_LIST``-style entry to the ``sm<major><minor>`` form of ``_torch_sm``.

    ``9.0``, ``90``, ``sm90``, ``sm_90``, ``9.0a`` and ``9.0+PTX`` all map to ``sm90``. The
    ``+PTX`` marker and the arch-specific ``a``/``f`` suffixes change how code is generated for an
    arch, not which physical GPU it targets. A multi-arch list such as ``8.0;9.0`` still never
    equals a single live sm, so it keeps being rejected.
    """
    token = arch.strip().lower()
    if token.endswith("+ptx"):
        token = token[: -len("+ptx")]
    for prefix in ("sm_", "sm"):
        if token.startswith(prefix):
            token = token[len(prefix):]
            break
    token = token.rstrip("af").replace(".", "")
    return f"sm{token}"


def warmup(out_dir: str = DEFAULT_CACHE_DIR, arch: str | None = None) -> int:
    """Run every warm step, then persist the mega-cache and its metadata. Returns a process exit code.

    Best-effort end to end: individual kernel failures are tolerated, and whatever compiled is
    baked.

    Returns 1 when any of the following happens, otherwise 0:
    - torch or a CUDA GPU is unavailable;
    - ``arch`` does not name the live GPU;
    - the mega-cache blob or its metadata could not be saved.

    Args:
        out_dir: directory that receives the Triton/Inductor caches, the mega-cache blob and its
            metadata sidecar.
        arch: optional target arch (e.g. ``"9.0"`` or ``"9.0a"``). It is exported as
            ``TORCH_CUDA_ARCH_LIST`` unless that variable is already set, and it must match the
            live GPU.
    """
    t0 = time.perf_counter()
    _point_backends_at(out_dir)
    if arch:
        # Let the caller pin the compile target for source builds that read it (e.g. flash-attn).
        # An already-set TORCH_CUDA_ARCH_LIST wins, so log what is actually in effect.
        os.environ.setdefault("TORCH_CUDA_ARCH_LIST", arch)
        effective = os.environ["TORCH_CUDA_ARCH_LIST"]
        if effective == arch:
            _log(f"target arch pinned: TORCH_CUDA_ARCH_LIST={arch}")
        else:
            _log(f"--arch {arch} not applied: TORCH_CUDA_ARCH_LIST already set to {effective}")
    torch = _require_gpu()
    if torch is None:
        return 1
    if arch:
        # --arch pins the compile target, but the JIT/source builds key off the LIVE GPU, and the
        # saved metadata records the physical sm. A mismatch (e.g. the sm90 publish step
        # mis-scheduled onto an sm89 runner) would bake a cu128-sm90 image whose metadata says
        # sm89, so every H100 worker would reject it and cold-JIT. FAIL rather than publish a
        # mislabeled artifact.
        want_sm = _arch_to_sm(arch)
        live_sm = _torch_sm(torch)
        if want_sm != live_sm:
            _log(
                f"ERROR: --arch {arch} ({want_sm}) does not match live GPU {live_sm}; refusing to "
                f"bake a mislabeled cache (a {want_sm} image would carry {live_sm} metadata). "
                "re-run on a matching GPU."
            )
            return 1
    steps = (
        lambda: warm_flash_attn(torch),
        lambda: warm_liger_ce(torch),
        lambda: warm_fla_gdn(torch),
        warm_chalk_kernels,
        lambda: warm_torch_compile(torch),
    )
    warmed = sum(bool(step()) for step in steps)
    _log(
        f"{warmed}/{len(steps)} kernel groups compiled in {time.perf_counter() - t0:.1f}s; "
        "saving mega-cache"
    )
    saved = save_mega_cache(torch, out_dir)
    # Only describe a blob we actually wrote. Otherwise fresh metadata could end up paired with a
    # stale blob left in out_dir by an earlier bake.
    meta_saved = saved and save_cache_metadata(torch, out_dir, requested_arch=arch, warmed=warmed)
    _log(f"done in {time.perf_counter() - t0:.1f}s (saved={saved}, metadata={meta_saved})")
    return 0 if saved and meta_saved else 1
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: parse ``--arch`` / ``--out`` and run ``warmup``.

    Args:
        argv: arguments to parse. ``None`` (the default) makes argparse read ``sys.argv[1:]``, so
            ``python -m flash.engine.worker.kernel_warmup`` behaves exactly as before, while
            callers and tests can pass an explicit list.

    Returns:
        The exit code produced by ``warmup``.
    """
    ap = argparse.ArgumentParser(description="pre-compile hot worker kernels and bake a mega-cache")
    ap.add_argument(
        "--arch",
        default=None,
        help="target TORCH_CUDA_ARCH_LIST (e.g. '9.0' for Hopper); default: probe the live GPU",
    )
    ap.add_argument(
        "--out",
        default=DEFAULT_CACHE_DIR,
        help=f"cache output dir (default: {DEFAULT_CACHE_DIR}); the bake produces <out>/{MEGA_CACHE_FILENAME}",
    )
    args = ap.parse_args(argv)
    return warmup(out_dir=args.out, arch=args.arch)
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
if __name__ == "__main__":
    # Propagate warmup's exit code so the image builder sees a failed bake.
    exit_code = main()
    raise SystemExit(exit_code)
|