freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,4 @@
1
# Script entry point: run the worker's ``main`` and exit with whatever it returns.
from flash.engine.worker import main

if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
@@ -0,0 +1,400 @@
1
+ """Pre-compile the worker's hot kernels on a real GPU and persist a portable mega-cache.
2
+
3
+ Run this ON A GPU BUILDER (an image-build runner that actually has the target arch's GPU) to kill
4
+ the ~10-15 min first-use JIT that #194 reintroduced on a cold worker. It warms the kernels the
5
+ trainer hits early — FlashAttention fwd/bwd, the Liger fused cross-entropy, flash-linear-attention's
6
+ Gated-DeltaNet (Qwen3.5/3.6 hybrid), and a representative ``torch.compile`` — then calls
7
+ ``torch.compiler.save_cache_artifacts()`` to write ONE portable mega-cache blob into the cache dir.
8
+ ``flash.engine.worker._load_kernel_cache_if_present`` loads it back at worker boot
9
+ (``torch.compiler.load_cache_artifacts``); the Dockerfile bakes the produced ``build/kernel_cache/``
10
+ into the image when built with ``--build-arg BUILD_KERNEL_CACHE=true``.
11
+
12
+ Measured: cold compile ~124s -> warm load ~0.2s (537x).
13
+
14
+ This module is import-safe WITHOUT torch installed (it must ``py_compile`` on the CPU-only CI image
15
+ that builds the worker): every heavy import lives INSIDE a function. Everything is best-effort —
16
+ each warm step is independently guarded so a missing/uncompilable kernel never aborts the bake; we
17
+ save whatever did compile. CLI: ``python -m flash.engine.worker.kernel_warmup --arch <sm> --out <dir>``.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import argparse
23
+ import json
24
+ import os
25
+ import time
26
+
27
# Where the bake writes when no --out is given. Same path as the Dockerfile's
# /opt/flash/kernelcache; the mega-cache file is written directly beneath it, which is where
# _load_kernel_cache_if_present looks for it. These names must stay in lockstep with
# engine.worker._KERNEL_CACHE_DIR / _KERNEL_CACHE_FILE.
DEFAULT_CACHE_DIR = "/opt/flash/kernelcache"
MEGA_CACHE_FILENAME = "mega_cache.bin"
MEGA_CACHE_META_FILENAME = "mega_cache.json"
33
+
34
+
35
def _log(msg: str) -> None:
    """Print one ``[kernel-warmup]``-prefixed progress line to stdout and flush immediately.

    Every warm step reports through this function so the GPU builder's logs show each step.
    """
    line = f"[kernel-warmup] {msg}"
    print(line, flush=True)
38
+
39
+
40
def _point_backends_at(cache_dir: str) -> None:
    """Redirect the Triton and TorchInductor caches into sub-directories of ``cache_dir``.

    Creates ``<cache_dir>/triton`` and ``<cache_dir>/inductor`` (if missing) and then exports
    ``TRITON_CACHE_DIR`` / ``TORCHINDUCTOR_CACHE_DIR`` so anything compiled afterwards lands in
    the same tree the worker reads (matches the Dockerfile ENV).
    """
    env_to_dir = {
        "TRITON_CACHE_DIR": os.path.join(cache_dir, "triton"),
        "TORCHINDUCTOR_CACHE_DIR": os.path.join(cache_dir, "inductor"),
    }
    # Create both directories first; the environment is only touched once both exist.
    for directory in env_to_dir.values():
        os.makedirs(directory, exist_ok=True)
    os.environ.update(env_to_dir)
47
+
48
+
49
def _torch_sm(torch) -> str:
    """Return CUDA device 0's compute capability as an ``sm<major><minor>`` tag (e.g. ``sm90``)."""
    capability = torch.cuda.get_device_capability(0)
    major, minor = capability[0], capability[1]
    return f"sm{major}{minor}"
52
+
53
+
54
def _require_gpu():
    """Import torch and return the module when a CUDA device is visible; otherwise return None.

    Both failure modes are logged: torch failing to import (or any other error raised while
    probing) and torch importing fine but reporting no CUDA device. The warm steps only make
    sense on a real GPU of the target arch — kernels are content-addressed by arch + toolchain,
    so a CPU run would bake nothing usable.
    """
    try:
        import torch

        if torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(0)
            sm_tag = _torch_sm(torch)
            _log(f"GPU: {device_name} ({sm_tag}), torch {torch.__version__}")
            return torch
        _log("no CUDA device visible — kernel warmup must run on a GPU builder; nothing baked")
        return None
    except Exception as e:
        _log(f"torch unavailable ({e}); cannot warm kernels")
        return None
71
+
72
+
73
def warm_flash_attn(torch) -> bool:
    """Run one tiny causal attention fwd + bwd through FlashAttention (FA2, then FA3).

    The two variants are guarded independently; a failure in either is logged and skipped.
    Returns True when at least one of them ran to completion.
    """

    def _tiny_qkv():
        # Three separate (1, 64, 4, 64) bf16 CUDA tensors with gradients enabled.
        return [
            torch.randn(1, 64, 4, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
            for _ in range(3)
        ]

    any_compiled = False

    # --- FA2 -------------------------------------------------------------------------------
    try:
        from flash_attn import flash_attn_func

        q, k, v = _tiny_qkv()
        attn_out = flash_attn_func(q, k, v, causal=True)
        attn_out.sum().backward()  # the backward pass exercises the bwd kernel as well
        torch.cuda.synchronize()
        _log("flash-attn (FA2) fwd/bwd compiled")
        any_compiled = True
    except Exception as e:
        _log(f"flash-attn (FA2) warm skipped: {e}")

    # --- FA3 -------------------------------------------------------------------------------
    # On Hopper (sm90) the worker selects attn_implementation="flash_attention_3" (the local
    # flash_attn_interface build) for full-attention layers, so a baked H100 cache has to cover
    # it too. Off-Hopper this is expected to fail and is simply skipped (the kernel is
    # Hopper-only; the wheel just rides along).
    try:
        import flash_attn_interface

        q, k, v = _tiny_qkv()
        result = flash_attn_interface.flash_attn_func(q, k, v, causal=True)
        attn_out = result[0] if isinstance(result, tuple) else result
        attn_out.sum().backward()
        torch.cuda.synchronize()
        _log("flash-attn-3 (Hopper) fwd/bwd compiled")
        any_compiled = True
    except Exception as e:
        _log(f"flash-attn-3 warm skipped (expected off-Hopper): {e}")

    return any_compiled
110
+
111
+
112
def _warm_liger_plain_ce(torch) -> bool:
    """Run one small fwd/bwd through ``LigerCrossEntropyLoss``; True when it completed."""
    try:
        from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

        loss_fn = LigerCrossEntropyLoss()
        logits = torch.randn(64, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        labels = torch.randint(0, 4096, (64,), device="cuda")
        loss_fn(logits, labels).backward()
        torch.cuda.synchronize()
        _log("liger fused cross-entropy compiled")
        return True
    except Exception as e:
        _log(f"liger CE warm skipped: {e}")
        return False


def _warm_liger_fused_linear(torch) -> bool:
    """Run one fwd/bwd through Liger's fused-linear cross-entropy loss; True when it completed."""
    import importlib  # function-scope: only this helper needs it

    try:
        candidates = (
            ("liger_kernel.ops.fused_linear_cross_entropy", "LigerFusedLinearCrossEntropyLoss"),
            (
                "liger_kernel.transformers.fused_linear_cross_entropy",
                "LigerFusedLinearCrossEntropyLoss",
            ),
        )
        loss_cls = None
        for module_name, attr in candidates:
            try:
                loss_cls = getattr(importlib.import_module(module_name), attr)
                break
            except Exception:
                # Expected for locations that do not export the class; try the next one.
                continue
        if loss_cls is None:
            raise ImportError("no fused-linear Liger loss class found")

        # representative catalog vocab width (qwen3.5/3.6 lm_head ~248k); triton/liger specialize
        # the fused-ce chunking to the vocab shape, so warm the production width, not a toy 4096.
        vocab = 248_320
        hidden = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        weight = torch.randn(vocab, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        labels = torch.randint(0, vocab, (64,), device="cuda")
        loss_fn = loss_cls()

        # upstream signature is forward(self, lin_weight, _input, target): weight first, then
        # hidden. The known-good form goes first so a mismatched shape is never launched unless
        # the good forms were rejected (a mismatched launch could trigger a cuda illegal access
        # and poison the context before a later attempt can run).
        attempts = (
            ("(weight, hidden, labels)", lambda: loss_fn(weight, hidden, labels)),
            ("(weight, hidden, target=labels)", lambda: loss_fn(weight, hidden, target=labels)),
            ("(hidden, weight, labels)", lambda: loss_fn(hidden, weight, labels)),
        )
        last_error = None
        for call_form, call in attempts:
            try:
                out = call()
                if isinstance(out, tuple):
                    out = out[0]
                out.backward()
                torch.cuda.synchronize()
                _log("liger fused-linear loss compiled")
                return True
            except Exception as e:
                # Keep the reason visible instead of silently moving on to the next form.
                last_error = e
                _log(f"liger fused-linear call {call_form} rejected: {e}")
        raise RuntimeError("fused-linear Liger calls were not accepted") from last_error
    except Exception as e:
        _log(f"liger fused-linear warm skipped: {e}")
        return False


def _warm_liger_model_layers(torch) -> bool:
    """Run Liger's RMSNorm (fwd/bwd) and rotary embedding once; True when both completed."""
    try:
        # model-layer liger kernels (rmsnorm + rope) that use_liger_kernel patches in besides the
        # loss; these still jit on the first real forward/backward if only the ce loss was warmed.
        from liger_kernel.transformers.rms_norm import LigerRMSNorm
        from liger_kernel.transformers.rope import liger_rotary_pos_emb

        rms = LigerRMSNorm(hidden_size=256).to(device="cuda", dtype=torch.bfloat16)
        x = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        rms(x).sum().backward()
        b, h, t, d = 1, 4, 64, 64
        q = torch.randn(b, h, t, d, device="cuda", dtype=torch.bfloat16)
        k = torch.randn(b, h, t, d, device="cuda", dtype=torch.bfloat16)
        cos = torch.randn(b, t, d, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(b, t, d, device="cuda", dtype=torch.bfloat16)
        liger_rotary_pos_emb(q, k, cos, sin)
        torch.cuda.synchronize()
        _log("liger model-layer kernels (rmsnorm/rope) compiled")
        return True
    except Exception as e:
        _log(f"liger model-layer warm skipped: {e}")
        return False


def warm_liger_ce(torch) -> bool:
    """Compile the Liger kernels: plain cross-entropy, fused-linear cross-entropy, rmsnorm/rope.

    The three groups are warmed in that order and guarded independently, so a failure in one
    never prevents the others from running. Every failure is logged, including each rejected
    call form of the fused-linear loss.

    Args:
        torch: the imported ``torch`` module (a CUDA device is expected to be available).

    Returns:
        True when at least one of the three groups ran to completion, else False.
    """
    # Build the list eagerly so all three groups run regardless of earlier results.
    results = [
        _warm_liger_plain_ce(torch),
        _warm_liger_fused_linear(torch),
        _warm_liger_model_layers(torch),
    ]
    return any(results)
197
+
198
+
199
def warm_fla_gdn(torch) -> bool:
    """Run flash-linear-attention's ``chunk_gated_delta_rule`` fwd/bwd (Qwen3.5/3.6 hybrid path).

    Warms the equal-length call first and then, best-effort, the ``cu_seqlens`` (varlen) call.
    Returns True once the equal-length call completed (even if the varlen call was skipped),
    and False when the import or the equal-length call failed.
    """
    heads, head_dim = 4, 64

    def _make_inputs(seq_len):
        # Creation order (q, k, v, g, beta) is kept fixed; batch is always 1.
        on_gpu = {"device": "cuda", "requires_grad": True}
        q = torch.randn(1, seq_len, heads, head_dim, dtype=torch.bfloat16, **on_gpu)
        k = torch.randn(1, seq_len, heads, head_dim, dtype=torch.bfloat16, **on_gpu)
        v = torch.randn(1, seq_len, heads, head_dim, dtype=torch.bfloat16, **on_gpu)
        g = torch.randn(1, seq_len, heads, dtype=torch.float32, **on_gpu)
        beta = torch.rand(1, seq_len, heads, dtype=torch.bfloat16, **on_gpu)
        return q, k, v, g, beta

    def _fwd_bwd(kernel, inputs, **extra):
        # One forward + backward through the kernel, then wait for the GPU to finish.
        result = kernel(*inputs, use_qk_l2norm_in_kernel=True, **extra)
        if isinstance(result, tuple):
            result = result[0]
        result.sum().backward()
        torch.cuda.synchronize()

    try:
        # Mirror the worker boot: on Hopper (sm90) the production path runs this first so fla's
        # GDN chunk_bwd uses the CORRECT tilelang backend (fla #640: the Triton path
        # miscomputes/raises). Off-Hopper it is a no-op. Bake the same backend the runtime will
        # actually select.
        try:
            from flash.engine.worker.perf import _ensure_fla_fastpath_on_hopper

            _ensure_fla_fastpath_on_hopper()
        except Exception as e:
            _log(f"hopper fla fast-path setup skipped: {e}")

        from fla.ops.gated_delta_rule import chunk_gated_delta_rule

        _fwd_bwd(chunk_gated_delta_rule, _make_inputs(256))
        _log("flash-linear-attention GDN fwd/bwd compiled")

        # Also warm the varlen path that SFT token-packing (#218) uses:
        # BlockDiagonalCollator(emit_varlen=True) feeds cu_seq_lens into the fla DeltaNet so the
        # recurrence resets per packed example, which compiles different chunk kernels than the
        # equal-length call above. Shaped like one packed block (batch flattened to 1) holding
        # two unequal segments.
        try:
            packed_len = 128 + 96  # two example lengths packed into one sequence
            packed_inputs = _make_inputs(packed_len)
            boundaries = torch.tensor([0, 128, packed_len], device="cuda", dtype=torch.int32)
            _fwd_bwd(chunk_gated_delta_rule, packed_inputs, cu_seqlens=boundaries)
            _log("flash-linear-attention GDN varlen (cu_seqlens) fwd/bwd compiled")
        except Exception as e:
            _log(f"fla GDN varlen warm skipped: {e}")
        return True
    except Exception as e:
        _log(f"fla GDN warm skipped: {e}")
        return False
251
+
252
+
253
def warm_chalk_kernels() -> bool:
    """Run chalk's default kernel installers when freesolo-chalk is installed.

    Each installer (rope, fused lora-delta, fused embedding) is looked up and invoked under its
    own guard, so a missing name or a failure in one installer cannot drop the others — the same
    isolation the embedding installer already had, now applied to all three.

    Returns:
        True when at least one installer returned a truthy value; False when chalk is not
        importable or no installer reported success.
    """
    try:
        import chalk.transformers as chalk_transformers
    except Exception as e:
        _log(f"chalk warm skipped: {e}")
        return False

    installers = (
        ("install_qwen35_rope", "qwen35-rope"),
        ("install_fused_lora_delta", "fused-lora-delta"),
        # embedding gather is chalk's third default gap-filler (chalk_kernels._KERNELS)
        ("install_fused_embedding", "fused-embedding"),
    )
    warmed = False
    for attr, label in installers:
        try:
            install = getattr(chalk_transformers, attr)
            warmed = bool(install()) or warmed
        except Exception as e:
            # name/version skew or an installer failure: skip just this one
            _log(f"chalk {label} warm skipped: {e}")
    _log(f"chalk default kernel installers ran (warmed={warmed})")
    return warmed
273
+
274
+
275
def warm_torch_compile(torch) -> bool:
    """Compile and run one small matmul + GELU via ``torch.compile`` so TorchInductor fills its cache.

    Returns True when the compiled function ran and the device synchronized, else False (the
    failure is logged).
    """
    try:

        def _fused(a, b):
            return torch.nn.functional.gelu(a @ b)

        compiled = torch.compile(_fused)
        lhs = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
        rhs = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
        compiled(lhs, rhs)
        torch.cuda.synchronize()
        _log("torch.compile (TorchInductor) warmed")
        return True
    except Exception as e:
        _log(f"torch.compile warm skipped: {e}")
        return False
292
+
293
+
294
def save_mega_cache(torch, out_dir: str) -> bool:
    """Persist everything compiled this session into one portable blob.

    Calls ``torch.compiler.save_cache_artifacts()`` and writes the bytes payload to
    ``<out_dir>/MEGA_CACHE_FILENAME`` so the worker can ``load_cache_artifacts`` it at boot.
    The payload is written to a temporary sibling file and then moved into place with
    ``os.replace``, so a failed or interrupted write never leaves a truncated cache at the path
    the worker reads.

    Args:
        torch: the imported ``torch`` module.
        out_dir: directory to write into; created if missing.

    Returns:
        True when a non-empty payload was written; False when there was nothing to persist or
        any step failed (the reason is logged).
    """
    try:
        artifacts = torch.compiler.save_cache_artifacts()
        if not artifacts:
            _log("save_cache_artifacts returned nothing — no compiled kernels to persist")
            return False
        # save_cache_artifacts returns (bytes, meta); persist the bytes payload.
        blob = artifacts[0] if isinstance(artifacts, tuple) else artifacts
        if not blob:
            # A (b"", meta) tuple is truthy, so check the payload itself: a zero-byte cache
            # must not be reported as a successful bake.
            _log("save_cache_artifacts returned nothing — no compiled kernels to persist")
            return False
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, MEGA_CACHE_FILENAME)
        tmp_path = path + ".tmp"
        try:
            with open(tmp_path, "wb") as f:
                f.write(blob)
            os.replace(tmp_path, path)
        finally:
            # After a successful replace the temp file is gone; this only cleans up failures.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        _log(f"mega-cache saved: {path} ({len(blob)} bytes)")
        return True
    except Exception as e:
        _log(f"save_cache_artifacts failed: {e}")
        return False
314
+
315
+
316
def save_cache_metadata(torch, out_dir: str, *, requested_arch: str | None, warmed: int) -> bool:
    """Write a small JSON description of this bake to ``<out_dir>/MEGA_CACHE_META_FILENAME``.

    The record holds the live GPU's sm tag and device name, the requested arch, the torch and
    CUDA versions, the number of warmed kernel groups and a creation timestamp (epoch seconds).
    Returns True on success; on any failure the error is logged and False is returned.
    """
    try:
        meta = {}
        meta["sm"] = _torch_sm(torch)
        meta["requested_arch"] = requested_arch
        meta["torch"] = getattr(torch, "__version__", "unknown")
        version_info = getattr(torch, "version", None)
        meta["cuda"] = getattr(version_info, "cuda", None)
        meta["device"] = torch.cuda.get_device_name(0)
        meta["warmed_groups"] = int(warmed)
        meta["created_at"] = int(time.time())

        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, MEGA_CACHE_META_FILENAME)
        with open(path, "w") as f:
            f.write(json.dumps(meta, sort_keys=True))
        _log(f"cache metadata saved: {path} ({meta['sm']})")
        return True
    except Exception as e:
        _log(f"cache metadata save failed: {e}")
        return False
336
+
337
+
338
def warmup(out_dir: str = DEFAULT_CACHE_DIR, arch: str | None = None) -> int:
    """Run every warm step, then persist the mega-cache and its metadata.

    Best-effort end to end: individual kernel failures are tolerated (whatever compiled is
    baked). A non-zero exit is returned only when no GPU/torch is available, when ``arch`` does
    not match the live GPU, or when saving the cache or its metadata fails, so the builder
    surfaces it.

    Args:
        out_dir: cache output directory; the Triton/Inductor caches are pointed beneath it.
        arch: optional target arch such as ``"9.0"``. Used as the default for
            ``TORCH_CUDA_ARCH_LIST`` and checked against the live GPU.

    Returns:
        Process exit code: 0 on success, 1 otherwise.
    """
    t0 = time.time()
    _point_backends_at(out_dir)
    if arch:
        # let the caller pin the compile target for source builds that read it (e.g. flash-attn).
        # setdefault leaves an already-exported value alone, so report what is really in effect
        # instead of claiming the requested arch was applied.
        effective = os.environ.setdefault("TORCH_CUDA_ARCH_LIST", arch)
        if effective == arch:
            _log(f"target arch pinned: TORCH_CUDA_ARCH_LIST={arch}")
        else:
            _log(
                f"TORCH_CUDA_ARCH_LIST already set to {effective}; "
                f"not overridden by --arch {arch}"
            )
    torch = _require_gpu()
    if torch is None:
        return 1
    if arch:
        # --arch pins the compile target, but the JIT/source builds key off the LIVE GPU and the
        # saved metadata records the physical sm. On a mismatch (e.g. the sm90 publish step
        # mis-scheduled onto an sm89 runner) the image would be labelled for one arch while its
        # metadata names another, so every worker of the intended arch would reject it and
        # cold-JIT. Fail the bake rather than publish a mislabeled artifact.
        want_sm = "sm" + arch.replace(".", "")
        live_sm = _torch_sm(torch)
        if want_sm != live_sm:
            _log(
                f"ERROR: --arch {arch} ({want_sm}) does not match live GPU {live_sm}; refusing to "
                f"bake a mislabeled cache (a {want_sm} image would carry {live_sm} metadata). "
                "re-run on a matching GPU."
            )
            return 1
    # Evaluated eagerly and in order; each step guards itself and returns a bool.
    results = [
        warm_flash_attn(torch),
        warm_liger_ce(torch),
        warm_fla_gdn(torch),
        warm_chalk_kernels(),
        warm_torch_compile(torch),
    ]
    warmed = sum(results)
    _log(
        f"{warmed}/{len(results)} kernel groups compiled in {time.time() - t0:.1f}s; "
        "saving mega-cache"
    )
    saved = save_mega_cache(torch, out_dir)
    meta_saved = save_cache_metadata(torch, out_dir, requested_arch=arch, warmed=warmed)
    # Both flags decide the exit code, so log both.
    _log(f"done in {time.time() - t0:.1f}s (saved={saved}, meta_saved={meta_saved})")
    return 0 if saved and meta_saved else 1
381
+
382
+
383
def main() -> int:
    """Parse ``--arch`` / ``--out`` from the command line and run :func:`warmup`.

    Returns the exit code produced by ``warmup``.
    """
    parser = argparse.ArgumentParser(
        description="pre-compile hot worker kernels and bake a mega-cache"
    )
    arch_help = "target TORCH_CUDA_ARCH_LIST (e.g. '9.0' for Hopper); default: probe the live GPU"
    out_help = (
        f"cache output dir (default: {DEFAULT_CACHE_DIR}); "
        f"the bake produces <out>/{MEGA_CACHE_FILENAME}"
    )
    parser.add_argument("--arch", default=None, help=arch_help)
    parser.add_argument("--out", default=DEFAULT_CACHE_DIR, help=out_help)
    options = parser.parse_args()
    return warmup(out_dir=options.out, arch=options.arch)


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())