freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,1048 @@
1
+ """Pure GPU/perf/optimizer probes for the fine-tuning worker.
2
+
3
+ These helpers take the model id / max length / capability as ARGUMENTS and read NONE of
4
+ the worker's run-scoped module globals (HF_REPO/RUN_ID/SEED/RUN_MODE/PHASE/JOB_SPEC/
5
+ ACTIVE_ENV/THINKING or the _HB_* heartbeat family), so they live here as a leaf module.
6
+ ``flash.engine.worker`` re-exports them; this module must NOT import that package (no cycle).
7
+ Torch and other heavy deps are imported lazily inside the functions (CPU-importable).
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import contextlib
13
+ import csv
14
+ import os
15
+ import sys
16
+ import time
17
+
18
+ # Fused-CE (Liger) gate thresholds live in ONE place — flash.engine.vram — so the worker's run-time
19
+ # gate and the cost estimator's offline mirror (sft_logits_fused) can never drift. vram is a pure
20
+ # leaf (no worker import), so this is cycle-free.
21
+ from flash.engine.vram import _LIGER_LONG_CTX_TOKENS, _LIGER_MIN_PARAMS_B
22
+
23
+
24
+ def _attn_impl_for_capability(
25
+ major: int, minor: int = 0, *, fa3_available: bool = False, fa2_available: bool = False
26
+ ) -> str | None:
27
+ """Map a CUDA compute capability to the trainer ``attn_implementation`` — the best per-arch
28
+ FlashAttention kernel for the model's FULL-attention (softmax) layers, so SFT *and* GRPO use a
29
+ real flash kernel on every arch where one exists (not plain SDPA). The Qwen3.5/3.6
30
+ Gated-DeltaNet *linear*-attention layers always keep their own path (fla, or the native
31
+ pure-PyTorch delta rule once fla is dropped on Hopper) — FlashAttention does not apply to linear
32
+ attention.
33
+
34
+ Each arch maps to its ONE best flash kernel; the fallback is UNIFORM — plain SDPA on every arch
35
+ when that kernel's package is absent (no special FA3->FA2 chain on Hopper):
36
+ * Hopper (sm90, H100/H200): "flash_attention_3" — FA3's warp-specialized async kernels are the
37
+ fastest exact attention on Hopper; transformers routes it to the LOCAL ``flash_attn_interface``
38
+ (no HF Kernels-Hub, whose torch2.10 versions break ``import transformers``). FA3 is baked into
39
+ the worker image by default (Dockerfile FLASH_ATTN_3_SPEC), so ``fa3_available`` is normally
40
+ True; absent -> plain SDPA, same as every other arch.
41
+ * Ampere (sm80 A100 / sm86 3090·A6000) + Ada (sm89 4090·L40S): "flash_attention_2" when the
42
+ ``flash_attn`` wheel is importable (``fa2_available``) — FA3 does NOT support these archs.
43
+ * consumer Blackwell (sm120 5090 / RTX Pro): "sdpa" forced to the cuDNN backend. THE ONE arch
44
+ with no flash: FA3/FA4 need TMEM/tcgen05 that sm120 lacks, and the prebuilt FA2 CUDA wheel's
45
+ sm120 coverage is unverified, so cuDNN SDPA is the validated best here.
46
+ * anything else / flash unavailable -> None: transformers picks SDPA (already flash-backed on
47
+ Ampere/Ada/Hopper).
48
+ Pure function (no torch / no imports) so it's unit-testable on CPU; ``fa2_available`` /
49
+ ``fa3_available`` are the caller's probes (``optimal_attn_impl``). The big LoRA win is still the
50
+ Liger/chalk fused kernels; flash helps only the ~25% full-attention layers of the hybrid arch."""
51
+ if major == 9 and fa3_available: # Hopper: FA3 is the arch's best flash kernel
52
+ return "flash_attention_3"
53
+ if major == 8 and minor in (0, 6, 9) and fa2_available: # Ampere 8.0/8.6 + Ada 8.9 ONLY: FA2
54
+ # (gate the minor so an unsupported sm8x like sm87 Jetson Orin doesn't get FA2 forced on it)
55
+ return "flash_attention_2"
56
+ if (
57
+ major == 12
58
+ ): # consumer Blackwell: cuDNN SDPA (the one exception — FA3/FA4 need TMEM/tcgen05)
59
+ return "sdpa"
60
+ return None # the arch's flash kernel is absent -> plain SDPA (the SAME fallback on every arch)
61
+
62
+
63
+ def _flash_attn_3_available() -> bool:
64
+ """True when FlashAttention-3 is usable by transformers on this worker — i.e. the
65
+ ``flash_attn_interface`` module (the ``flash-attn-3`` Hopper build) is importable.
66
+
67
+ transformers' ``flash_attention_3`` path does ``from flash_attn_interface import
68
+ flash_attn_func, ...`` (modeling_flash_attention_utils), so a present module is exactly what
69
+ makes ``attn_implementation="flash_attention_3"`` resolve WITHOUT the HF Kernels-Hub. Prefer
70
+ transformers' own ``is_flash_attn_3_available`` probe (it verifies real importability). Only if
71
+ that probe is itself unavailable (transformers not importable here) fall back to a GUARDED import
72
+ of ``flash_attn_interface`` — NOT a bare ``find_spec``, so an on-disk-but-broken install (ABI
73
+ mismatch / missing .so) reads as unavailable instead of a false positive that would later crash
74
+ transformers at model load. FA3 is used whenever it's importable — fixed, no disable knob."""
75
+ try:
76
+ from transformers.utils import is_flash_attn_3_available
77
+
78
+ return bool(is_flash_attn_3_available())
79
+ except Exception:
80
+ try:
81
+ import flash_attn_interface # noqa: F401 (guarded: verifies real importability)
82
+
83
+ return True
84
+ except Exception:
85
+ return False
86
+
87
+
88
+ def _flash_attn_available() -> bool:
89
+ """True when the ``flash_attn`` (FA2) wheel is importable (baked into the worker image).
90
+
91
+ Drives the FA2 ``attn_implementation`` selection on Ampere/Ada (via ``_attn_impl_for_capability``)
92
+ AND the SFT packing default on every arch. ``_attn_impl_for_capability`` itself never picks FA2 on
93
+ Hopper (FA3, else uniform SDPA); FA2 re-enters there ONLY through the SFT packing path, which
94
+ forces FA2 varlen when ``optimal_attn_impl`` returned None (Hopper without FA3). On sm120 the
95
+ selector returns ``"sdpa"`` and run_sft DISABLES packing instead (consumer Blackwell stays plain
96
+ SDPA — no flash), so sm120 never forces FA2. Packing rationale: TRL's ``packing_strategy='bfd'``
97
+ produces flattened/padding-free
98
+ batches whose example boundaries are carried by ``position_ids`` and enforced ONLY by a
99
+ varlen-capable attention impl (FA2/FA3/flex). Under plain SDPA, packed examples attend ACROSS
100
+ boundaries (silent quality loss). find_spec only — no import side effects (no CUDA init). FA2 is
101
+ used whenever the wheel is importable — fixed, no disable knob."""
102
+ try:
103
+ import importlib.util
104
+
105
+ return importlib.util.find_spec("flash_attn") is not None
106
+ except Exception:
107
+ return False
108
+
109
+
110
+ def optimal_attn_impl() -> str | None:
111
+ """Best ``attn_implementation`` for the live GPU (None = leave transformers' default)."""
112
+ try:
113
+ import torch
114
+
115
+ if not torch.cuda.is_available():
116
+ return None
117
+ major, minor = torch.cuda.get_device_capability(0)
118
+ except Exception as e:
119
+ print("optimal_attn_impl probe failed:", e)
120
+ return None
121
+ fa2 = _flash_attn_available() # FA2 wheel importable (Ampere/Ada/Hopper)
122
+ # Probe FA3 only on Hopper (the only arch it selects it for) so a non-Hopper run never imports
123
+ # the transformers FA3 helpers needlessly.
124
+ fa3 = _flash_attn_3_available() if major == 9 else False
125
+ impl = _attn_impl_for_capability(major, minor, fa3_available=fa3, fa2_available=fa2)
126
+ if impl in ("flash_attention_2", "flash_attention_3"):
127
+ ver = "FlashAttention-3" if impl == "flash_attention_3" else "FlashAttention-2"
128
+ print(
129
+ f"[attn] sm{major}{minor} -> attn_implementation={impl} ({ver}, full-attention layers)"
130
+ )
131
+ elif major == 9 and not fa3:
132
+ # Hopper but FA3 not selected -> plain SDPA (uniform fallback). FA3 is baked into the worker
133
+ # image by default, so this means flash_attn_interface is absent/broken — check the install.
134
+ print(f"[attn] sm{major}{minor}: FA3 unavailable (flash_attn_interface absent) -> SDPA")
135
+ elif major == 12: # the only arch that returns impl=="sdpa" -> this branch covers all of it
136
+ print(
137
+ f"[attn] sm{major}{minor} (consumer Blackwell) -> SDPA/cuDNN (FA3/FA4 need TMEM; n/a on sm120)"
138
+ )
139
+ elif not fa2:
140
+ print(f"[attn] sm{major}{minor}: flash_attn wheel absent -> SDPA")
141
+ return impl
142
+
143
+
144
# Liger's fused linear cross-entropy is a MEMORY optimization: it never materializes the fp32
# [B, T, vocab] logits. It is not a fixed-batch speed-up.
# PR #174 ledger (1B model, fixed batch, min-of-3) measured it as a NET LOSS on every arch:
# A100 0.86x, H100 0.90x, RTX 3090 0.78x, RTX 4090 0.83x, RTX 5090 0.79x. The per-step Triton
# overhead is not repaid when a small model's logits don't dominate memory.
# The payoff appears on LARGE models (a bigger batch fits / OOM is avoided), so the default is
# gated on estimated model size.
# Raw-param-count form (~3B) of the threshold in billions. The canonical value lives in
# flash.engine.vram. 1B-class models measured net-negative, so Liger defaults off below this.
_LIGER_MIN_PARAMS = 1e9 * _LIGER_MIN_PARAMS_B
153
+
154
+
155
+ def _estimate_params(cfg) -> float:
156
+ """Rough param count from a HF config: embeddings (+untied lm_head) + transformer blocks.
157
+ For multimodal checkpoints (e.g. Qwen3.5-VL) the LM dims live under ``text_config`` — read it
158
+ when the top-level dims are absent, else the gate underestimates and wrongly disables the
159
+ memory path (GC/Liger) for the 4B/9B tiers."""
160
+ tc = getattr(cfg, "text_config", None)
161
+ src = cfg if getattr(cfg, "hidden_size", 0) else (tc or cfg)
162
+ h = getattr(src, "hidden_size", 0) or 0
163
+ v = getattr(src, "vocab_size", 0) or getattr(cfg, "vocab_size", 0) or 0
164
+ n = getattr(src, "num_hidden_layers", 0) or 0
165
+ tied = getattr(src, "tie_word_embeddings", getattr(cfg, "tie_word_embeddings", False))
166
+ emb = v * h * (1 if tied else 2)
167
+ blocks = n * 12 * h * h # ~12 h^2 per transformer block (attn + MLP)
168
+ return float(emb + blocks)
169
+
170
+
171
def _liger_default_for_model(model_id: str) -> bool:
    """Decide whether Liger defaults ON for ``model_id``.

    Returns True only when the estimated parameter count reaches ``_LIGER_MIN_PARAMS`` (~3B),
    where fused-CE's memory win pays off. 1B-class models measured net-negative, so they default
    OFF. Any failure while loading the config or estimating its size is logged and yields False.
    """
    try:
        from transformers import AutoConfig

        # NOTE(review): trust_remote_code=True lets the model repo run its own config code just
        # for this size probe. Confirm that matches the trust policy used for the real model load.
        params = _estimate_params(AutoConfig.from_pretrained(model_id, trust_remote_code=True))
        big_enough = params >= _LIGER_MIN_PARAMS
    except Exception as e:
        print("liger model-size probe failed (default off):", e)
        return False
    return big_enough
182
+
183
+
184
def liger_on(default_on: bool) -> bool:
    """Decide whether to actually enable a Liger kernel path.

    ``default_on`` is the model-size decision: on only for models large enough that fused-CE's
    memory win pays off (1B-class is a measured net loss). Even then, both of these must hold:

    * a CUDA GPU is present, and
    * ``liger_kernel`` is importable.

    The local ``flash[gpu]`` extra does not ship ``liger_kernel``, so blindly setting
    use_liger_kernel would crash a local GPU run. Any probe failure means off.
    """
    if default_on:
        try:
            import importlib.util

            import torch

            gpu_ok = torch.cuda.is_available()
            return bool(gpu_ok and importlib.util.find_spec("liger_kernel") is not None)
        except Exception:
            pass
    return False
202
+
203
+
204
def setup_perf_backends() -> None:
    """Turn on arch-agnostic throughput knobs that are safe on every CUDA arch (no JIT/compile).

    The knob is TF32 for fp32 matmuls and cuDNN (Ampere+). The residual fp32 work in a bf16 LoRA
    run (some norms, the optimizer's fp32 master step, any fp32 GEMM) then runs on TF32 tensor
    cores at roughly no accuracy cost. It is a no-op on pre-Ampere.

    Without CUDA nothing is changed. Any failure is printed and swallowed.
    """
    try:
        import torch

        if torch.cuda.is_available():
            torch.set_float32_matmul_precision("high")  # TF32 for fp32 matmuls
            for backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
                backend.allow_tf32 = True
            print("[perf] TF32 matmul/cuDNN enabled")
    except Exception as e:
        print("setup_perf_backends skipped:", e)
222
+
223
+
224
+ def _remove_fla_from_disk() -> tuple[list[str], bool]:
225
+ """Physically delete every importable ``fla`` package dir from the worker's REAL sys.path.
226
+
227
+ Loops until ``find_spec('fla')`` is clean (removing one copy can expose another further down
228
+ the path) and invalidates import caches so transformers' is_fla_available() probe sees it
229
+ gone. ``pip uninstall`` alone is unreliable here — it targets one site-packages but the base
230
+ image bakes ``fla`` into another dir on the path (and can report success while leaving the
231
+ package dir). Returns ``(removed_dirs, still_importable)``. Used by the Hopper auto-drop.
232
+ """
233
+ import importlib
234
+ import importlib.util
235
+ import shutil
236
+
237
+ removed: list[str] = []
238
+ for _ in range(6): # a few passes: removing one copy can reveal another earlier on the path
239
+ importlib.invalidate_caches()
240
+ spec = importlib.util.find_spec("fla")
241
+ if spec is None:
242
+ break
243
+ # Resolve the package directory (submodule_search_locations for a package, else the file dir).
244
+ locs = list(getattr(spec, "submodule_search_locations", None) or [])
245
+ if not locs and spec.origin:
246
+ locs = [os.path.dirname(spec.origin)]
247
+ progressed = False
248
+ for loc in locs:
249
+ if loc and os.path.isdir(loc) and os.path.basename(loc.rstrip("/")) == "fla":
250
+ try:
251
+ shutil.rmtree(loc)
252
+ removed.append(loc)
253
+ progressed = True
254
+ except Exception as e:
255
+ print(f"[fla] could not remove {loc}: {e}", flush=True)
256
+ if not progressed:
257
+ break
258
+ importlib.invalidate_caches()
259
+ return removed, importlib.util.find_spec("fla") is not None
260
+
261
+
262
+ def _find_real_libcudart() -> str | None:
263
+ """Path to a REAL ``libcudart.so.<major>`` that exports ``cudaDeviceReset`` (the symbol
264
+ tilelang's stub lacks), or None if none can be found. Prefers the nvidia-cuda-runtime wheel, then
265
+ the CUDA toolkit baked into the worker's -devel base image, then the system loader's own resolver
266
+ — and VERIFIES the symbol is actually present (a path is only returned if ``CDLL(path)`` exposes
267
+ ``cudaDeviceReset``). Never raises.
268
+
269
+ CUDA-major-agnostic: the worker image pins cu12 today, but the runtime wheel's import path and
270
+ soname differ across CUDA majors — the cu12 wheel ships ``nvidia/cuda_runtime/lib/libcudart.so.12``
271
+ while the cu13 wheel ships ``nvidia/cu13/lib/libcudart.so.13`` (and has NO ``nvidia.cuda_runtime``
272
+ module at all). A ``.so.12``-only probe silently returns None on cu13, leaving the stub shadow in
273
+ place. So we probe every ``nvidia/*/lib`` subdir and any ``libcudart.so.*`` major; the symlink
274
+ repoint that consumes this works for any real libcudart (``_verify`` still gates on the symbol)."""
275
+ import ctypes
276
+ import ctypes.util
277
+ import glob
278
+
279
+ def _verify(cand: str) -> str | None:
280
+ """Absolute path to ``cand`` if it loads AND exports cudaDeviceReset, else None. Handles both
281
+ absolute paths (glob results) and bare sonames like ``libcudart.so.12`` (find_library, which
282
+ the loader resolves but ``os.path.exists`` would reject)."""
283
+ try:
284
+ lib = ctypes.CDLL(cand) # an abs path opens directly; a bare soname is loader-resolved
285
+ except OSError:
286
+ return None
287
+ if not hasattr(lib, "cudaDeviceReset"):
288
+ return None
289
+ if os.path.isabs(cand) and os.path.exists(cand):
290
+ return os.path.realpath(cand)
291
+ # Bare soname: resolve to the file the loader actually opened, via /proc/self/maps.
292
+ base = os.path.basename(cand)
293
+ try:
294
+ with open("/proc/self/maps") as f:
295
+ for line in f:
296
+ if base in line and "/" in line:
297
+ p = line[line.index("/"):].rstrip()
298
+ if os.path.basename(p).startswith(base) and os.path.exists(p):
299
+ return os.path.realpath(p)
300
+ except OSError:
301
+ pass
302
+ return None
303
+
304
+ candidates: list[str] = []
305
+ # 1. nvidia cuda-runtime PyPI wheel (a torch/vLLM dep on many images), any CUDA major. Import the
306
+ # ``nvidia`` namespace package (it resolves even when a specific ``nvidia.cuda_runtime`` subpkg
307
+ # is absent — e.g. the cu13 wheel has none) and glob every ``nvidia/*/lib`` for any libcudart
308
+ # soname, so both the cu12 layout (nvidia/cuda_runtime/lib/libcudart.so.12) and the cu13 layout
309
+ # (nvidia/cu13/lib/libcudart.so.13) are found. ``sorted`` keeps candidate order deterministic.
310
+ try:
311
+ import nvidia # type: ignore # namespace package; subpkg import may fail, this won't
312
+
313
+ for base in sorted(map(str, getattr(nvidia, "__path__", []) or [])):
314
+ candidates += sorted(glob.glob(os.path.join(base, "*", "lib", "libcudart.so.*")))
315
+ except Exception:
316
+ pass
317
+ # 2. CUDA toolkit in a -devel base image (Dockerfile.worker today: cuda12.8-cudnn9-devel), any major.
318
+ for pat in (
319
+ "/usr/local/cuda*/lib64/libcudart.so.*",
320
+ "/usr/local/cuda*/targets/*/lib/libcudart.so.*",
321
+ "/usr/lib/x86_64-linux-gnu/libcudart.so.*",
322
+ ):
323
+ candidates += sorted(glob.glob(pat))
324
+ # 3. The loader's own resolver (LD_LIBRARY_PATH / ldconfig) — returns a bare soname, handled above.
325
+ found = ctypes.util.find_library("cudart")
326
+ if found:
327
+ candidates.append(found)
328
+ for cand in candidates:
329
+ real = _verify(cand)
330
+ if real is not None:
331
+ return real
332
+ return None
333
+
334
+
335
+ def _neutralize_tilelang_cudart_stub() -> None:
336
+ """Stop tilelang's bundled ``libcudart_stub.so`` from shadowing the real CUDA runtime in vLLM.
337
+
338
+ tilelang ships a minimal ``libcudart_stub.so`` (soname ``libcudart_stub.so``) that
339
+ ``libtilelang.so`` / ``libtvm.so`` link against; it exports only a SUBSET of the CUDA runtime —
340
+ notably it is MISSING ``cudaDeviceReset``. vLLM's ``vllm/device_allocator/cumem.py`` builds a
341
+ ``CudaRTLibrary`` at MODULE TOP LEVEL (``libcudart = CudaRTLibrary()``), and that module is
342
+ imported on EVERY vLLM init via ``gpu_worker.load_model`` ->
343
+ ``_maybe_get_memory_pool_context(tag="weights")`` — so the crash is NOT gated on sleep mode or
344
+ model size (a 0.8B GRPO run hit it too); any GRPO vLLM init is exposed. ``CudaRTLibrary`` finds
345
+ libcudart by a SUBSTRING scan of ``/proc/self/maps`` and returns the FIRST matching line
346
+ (address-ordered, so host-dependent ~coin-flip). Once tilelang is loaded — the Hopper fla fast
347
+ path, or fla's backend probe on any arch — the stub is mapped into the process and can win that
348
+ scan, so ``CudaRTLibrary()`` dlopens the stub and aborts the import with ``undefined symbol:
349
+ cudaDeviceReset`` before step 0. See flash #184.
350
+
351
+ Fix: BEFORE anything imports tilelang/fla/vllm, repoint the stub path at the REAL
352
+ ``libcudart.so.12`` via a symlink. Then whichever copy the loader (or vLLM's scan) resolves has
353
+ the full symbol set, and the real lib's soname (``libcudart.so.12``) dedupes against the copy
354
+ torch already loaded — so no second CUDA-runtime instance is created and the stub-named mapping
355
+ drops out of ``/proc/self/maps`` entirely. tilelang keeps working: the real runtime is a strict
356
+ superset of the stub it linked against. Applies on EVERY arch and model size (the crash spans
357
+ 0.8B/4B and A100/cheaper classes) and to every provisioning path (baked image or runtime pip),
358
+ since it runs in the worker before the first tilelang load. Must run AFTER
359
+ ``_ensure_fla_fastpath_on_hopper`` (a tilelang (re)install there would otherwise rewrite the
360
+ stub) and BEFORE the model/vLLM import.
361
+
362
+ Idempotent and best-effort: a missing tilelang, a missing stub, an already-healthy stub, or no
363
+ discoverable real runtime is a clean no-op; any error is swallowed (the worker must never crash
364
+ on this hygiene step). No GPU required.
365
+ """
366
+ import importlib.util
367
+
368
+ try:
369
+ spec = importlib.util.find_spec("tilelang")
370
+ except Exception:
371
+ spec = None
372
+ locs = list(getattr(spec, "submodule_search_locations", None) or []) if spec else []
373
+ if not locs:
374
+ return # tilelang not installed -> nothing can shadow libcudart
375
+ stub = os.path.join(locs[0], "lib", "libcudart_stub.so")
376
+ if not os.path.lexists(stub): # lexists: a dangling symlink still counts as present
377
+ return
378
+ # Idempotency WITHOUT loading the stub: we only ever turn the stub into a symlink, and a pristine
379
+ # tilelang always ships it as a regular file, so a RESOLVING symlink here means a prior pass
380
+ # already neutralized it. Crucially, do NOT probe the stub with ctypes.CDLL — that dlopens it (it
381
+ # loads fine under lazy binding despite the missing cudaDeviceReset) and maps it into THIS
382
+ # process's /proc/self/maps, which is exactly the libcudart line vLLM's CudaRTLibrary scan would
383
+ # then pick up -> the very crash we're preventing. The stub must never be loaded; only the file
384
+ # is touched. A DANGLING symlink (our target moved/was removed) is NOT done — it leaves tilelang
385
+ # with a broken libcudart_stub.so, so fall through and re-point it (os.path.exists follows links).
386
+ if os.path.islink(stub) and os.path.exists(stub):
387
+ return
388
+ real = _find_real_libcudart()
389
+ if real is None:
390
+ print(
391
+ "[worker] libcudart stub shadow: no real libcudart found; left as-is (flash #184)",
392
+ flush=True,
393
+ )
394
+ return
395
+ try:
396
+ # Preserve the original stub ONCE (reversible / debuggable), then point the stub path at the
397
+ # real runtime. os.replace is atomic; symlink keeps soname-dedup (no duplicate libcudart).
398
+ backup = stub + ".orig"
399
+ if not os.path.exists(backup):
400
+ os.replace(stub, backup)
401
+ else:
402
+ with contextlib.suppress(FileNotFoundError):
403
+ os.remove(stub)
404
+ os.symlink(real, stub)
405
+ print(f"[worker] redirected tilelang libcudart_stub.so -> {real} (flash #184)", flush=True)
406
+ except Exception as e:
407
+ print(f"[worker] libcudart stub neutralize failed: {e}", flush=True)
408
+
409
+
410
# Long-context runs are memory-bound: activations and the vLLM KV cache grow with sequence
# length. They need the memory features even on a SMALL model. PR #174 measured a 1B model OOM
# on GRPO at 4096 ctx in speed mode, while it fits in memory mode.
# Hence "memory mode" = large model OR long context. The canonical value is in flash.engine.vram.
_LONG_CONTEXT_TOKENS = _LIGER_LONG_CTX_TOKENS
414
+
415
+
416
def _memory_mode(model_id: str, max_length: int = 0) -> bool:
    """Decide whether the memory-saving features (Liger, grad-checkpointing, vLLM sleep) default ON.

    They are on for a long context (``max_length >= _LONG_CONTEXT_TOKENS``, where
    activations/KV dominate) or a large model (per ``_liger_default_for_model``). A small model
    with a short context optimizes for speed instead. The model-size probe is skipped entirely
    when the context alone already qualifies.
    """
    long_context = bool(max_length) and max_length >= _LONG_CONTEXT_TOKENS
    return True if long_context else _liger_default_for_model(model_id)
423
+
424
+
425
def grad_checkpointing_on(model_id: str, max_length: int = 0) -> bool:
    """Decide whether gradient checkpointing is on.

    Checkpointing re-runs the forward during backward (~25% slower) to save activation memory.
    That makes it a MEMORY feature, so it follows ``_memory_mode``: on for large models or long
    contexts that need the headroom, off for small, short runs that fit without it (the speed
    win).
    """
    return _memory_mode(model_id, max_length=max_length)
430
+
431
+
432
def grpo_sleep_mode(
    model_id: str,
    *,
    max_length: int = 0,
    group_size: int = 8,
    max_tokens: int | None = None,
    lora_rank: int = 32,
    thinking: bool = False,
    card_vram_gb: float = 0.0,
) -> bool:
    """Decide whether colocated-vLLM GRPO enables vLLM sleep mode.

    Sleep mode offloads the rollout engine between steps. It trades a large per-step cost for
    memory, and on the large-model GRPO path the sleep/wake cycle STALLS the colocated rollout:
    completions become unparseable and then the worker hangs. So sleep mode is used only when the
    run genuinely cannot stay resident.

    The decision goes in three steps:

    1. Small model AND short rollout (``_memory_mode`` false for the real rollout length) -> False.
    2. ``card_vram_gb`` known and the policy + colocated rollout engine + training peak fit on it
       (``grpo_fits_resident``) -> False. This is the common case on an allocator-sized card.
    3. Otherwise -> True. This covers unknown card VRAM, a run that does not fit, and a
       resident-fit check that errored (the error is printed).
    """
    from flash.engine.vram import grpo_fits_resident, grpo_rollout_seq_len

    # Gate on the rollout length run_rl() ACTUALLY launches, not the raw max_length. When
    # [train].max_length is unset, that is max(1024, prompt+completion): 2368 by default, 3584
    # with thinking. With max_length unset (0), the size/context pre-filter would see a 0-length
    # "short" run and exit early for a sub-3B model, skipping the resident-fit check that a long
    # max_tokens rollout needs.
    seq_len = grpo_rollout_seq_len(max_length, max_tokens, thinking)
    if not _memory_mode(model_id, seq_len):
        return False  # small model AND genuinely short rollout -> never needed
    vram_known = bool(card_vram_gb) and card_vram_gb > 0
    if not vram_known:
        return True
    try:
        fits = bool(
            grpo_fits_resident(
                model_id,
                seq_len=seq_len,
                max_tokens=max_tokens,
                lora_rank=lora_rank,
                group_size=group_size,
                thinking=thinking,
                card_vram_gb=card_vram_gb,
            )
        )
    except Exception as e:
        print("[rl] grpo sleep-mode resident check skipped:", e)
        return True
    # Fits resident -> skip the (buggy, slow) sleep/wake cycle.
    return not fits
475
+
476
+
477
def fused_optim_name() -> str:
    """Return the TRL/HF ``optim`` value: bitsandbytes 8-bit paged AdamW.

    The int8 optimizer state is paged to host RAM. That lets a run fit a smaller/cheaper GPU, and
    it is the better default across the catalog.
    """
    optim = "paged_adamw_8bit"
    return optim
481
+
482
+
483
+ def _reset_peak_gpu() -> None:
484
+ """Reset the CUDA peak-memory counter so a subsequent ``_peak_gpu_gb`` measures only the work
485
+ that follows (e.g. the train loop, isolating the optimizer-state A/B from setup/model load)."""
486
+ try:
487
+ import torch
488
+
489
+ if torch.cuda.is_available():
490
+ torch.cuda.reset_peak_memory_stats()
491
+ except Exception:
492
+ pass
493
+
494
+
495
+ def _peak_gpu_gb() -> float:
496
+ """Peak torch-allocated CUDA memory (GB) since the last reset; 0.0 if no CUDA. Note: bnb paged
497
+ 8-bit optimizer state lives in unified/managed memory outside torch's caching allocator and is
498
+ NOT counted here — so this OVERSTATES the 8-bit saving. _GpuPeakSampler measures the true
499
+ device footprint (incl. bnb managed pages) for the honest A/B number."""
500
+ try:
501
+ import torch
502
+
503
+ if torch.cuda.is_available():
504
+ return round(torch.cuda.max_memory_allocated() / 1e9, 3)
505
+ except Exception:
506
+ pass
507
+ return 0.0
508
+
509
+
510
+ class _GpuPeakSampler:
511
+ """Background sampler of true device memory (GB) = (total - free) from cuda.mem_get_info, which
512
+ DOES include bitsandbytes managed/paged optimizer pages while they're GPU-resident (torch's
513
+ max_memory_allocated does not). This is the honest peak for the fp32-vs-8-bit optimizer A/B."""
514
+
515
+ def __init__(self, interval: float = 0.25):
516
+ self.interval = interval
517
+ self.peak_used = 0
518
+ self._stop = False
519
+ self._thread = None
520
+
521
+ def _run(self):
522
+ import torch
523
+
524
+ while not self._stop:
525
+ try:
526
+ free, total = torch.cuda.mem_get_info()
527
+ self.peak_used = max(self.peak_used, total - free)
528
+ except Exception:
529
+ pass
530
+ time.sleep(self.interval)
531
+
532
+ def start(self):
533
+ try:
534
+ import threading
535
+
536
+ import torch
537
+
538
+ if not torch.cuda.is_available():
539
+ return self
540
+ self._thread = threading.Thread(target=self._run, daemon=True)
541
+ self._thread.start()
542
+ except Exception:
543
+ pass
544
+ return self
545
+
546
+ def stop_gb(self) -> float:
547
+ self._stop = True
548
+ if self._thread is not None:
549
+ self._thread.join(timeout=2)
550
+ return round(self.peak_used / 1e9, 3)
551
+
552
+
553
+ def loraplus_optimizer_cls(optim_name: str):
554
+ """Optimizer class for the LoRA+ ``create_optimizer`` override (returns ``(cls, extra_kwargs)``).
555
+
556
+ The LoRA+ override has to *build* the optimizer itself (PEFT splits the LoRA A/B matrices into
557
+ separate param groups with different LRs), so it cannot inherit TRL's ``optim=`` string — it has
558
+ to choose a concrete class. Historically it always built a full-precision ``torch.optim.AdamW``,
559
+ which silently discarded the catalog's ``paged_adamw_8bit`` setting whenever LoRA+ was on.
560
+
561
+ PEFT's ``create_loraplus_optimizer`` accepts ANY ``optimizer_cls`` — including bitsandbytes 8-bit
562
+ optimizers (it registers embedding overrides with bnb's ``GlobalOptimManager`` to keep them
563
+ 32-bit) — so LoRA+ and the 8-bit paged optimizer state coexist. An ``8bit`` ``optim`` value
564
+ (the fleet default; ``fused_optim_name`` -> ``paged_adamw_8bit``) selects
565
+ ``bnb.optim.PagedAdamW8bit``; a non-8-bit ``optim`` keeps fp32 AdamW. This simply mirrors the
566
+ configured ``optim`` — there is no separate toggle: an on-GPU A/B (Qwen3.5-4B SFT, rank-128
567
+ LoRA, same seed/data/init) measured the 8-bit paged state at -75% optimizer memory
568
+ (1359 -> 346 MB) and -0.72 GB peak with NO convergence penalty (final loss 10.64 vs 11.16 from
569
+ an identical start), so it's unconditionally the default wherever ``optim`` is 8-bit. Falls
570
+ back to fp32 AdamW only if bitsandbytes is missing."""
571
+ import torch as _torch
572
+
573
+ # case-insensitive + str-safe: TRL normalizes optim to an OptimizerNames enum whose str() is
574
+ # "OptimizerNames.PAGED_ADAMW_8BIT" (uppercase), so a bare `"8bit" in optim_name` would miss it.
575
+ if "8bit" in str(optim_name or "").lower():
576
+ try:
577
+ import bitsandbytes as bnb
578
+
579
+ return bnb.optim.PagedAdamW8bit, {}
580
+ except Exception as e: # bnb missing / no CUDA build -> safe fp32 fallback
581
+ print(f"[lora+] bitsandbytes 8-bit optimizer unavailable ({e}); using fp32 AdamW")
582
+ return _torch.optim.AdamW, {}
583
+
584
+
585
+ def _sdpa_cudnn_ctx(attn_impl: str | None):
586
+ """Context forcing the cuDNN SDPA backend (real Blackwell-consumer kernels) when we fell
587
+ back to plain SDPA on sm120; a no-op context otherwise. Best-effort."""
588
+ if attn_impl != "sdpa":
589
+ return contextlib.nullcontext()
590
+ try:
591
+ from torch.nn.attention import SDPBackend, sdpa_kernel
592
+
593
+ # Priority-ordered: prefer the fast cuDNN/flash/efficient kernels, but ALWAYS include MATH
594
+ # as the final fallback. Restricting to only [CUDNN, EFFICIENT] makes sm120 GRPO crash with
595
+ # "RuntimeError: No available kernel" when neither has a kernel for the completion-batch
596
+ # attention shape (MEASURED: Qwen3.5 GRPO on RTX 5090). MATH is universal, so the candidate
597
+ # set is never empty; set_priority keeps cuDNN first whenever it CAN serve the shape (SFT
598
+ # fast path unchanged), only falling through for the shapes cuDNN/efficient reject.
599
+ return sdpa_kernel(
600
+ [
601
+ SDPBackend.CUDNN_ATTENTION,
602
+ SDPBackend.FLASH_ATTENTION,
603
+ SDPBackend.EFFICIENT_ATTENTION,
604
+ SDPBackend.MATH,
605
+ ],
606
+ set_priority=True,
607
+ )
608
+ except Exception as e:
609
+ print("[attn] cuDNN SDPA backend unavailable, using default SDPA:", e)
610
+ return contextlib.nullcontext()
611
+
612
+
613
+ def _float_or_none(value) -> float | None:
614
+ try:
615
+ text = str(value).strip()
616
+ if not text or text.upper() in {"N/A", "[N/A]", "NOT SUPPORTED", "[NOT SUPPORTED]"}:
617
+ return None
618
+ return float(text)
619
+ except (TypeError, ValueError):
620
+ return None
621
+
622
+
623
+ def _int_or_none(value) -> int | None:
624
+ num = _float_or_none(value)
625
+ return int(num) if num is not None else None
626
+
627
+
628
+ def _round_gb_from_mib(value) -> float | None:
629
+ num = _float_or_none(value)
630
+ if num is None:
631
+ return None
632
+ return round(num / 1024.0, 3)
633
+
634
+
635
+ def _clean_diag(diag: dict) -> dict:
636
+ return {k: v for k, v in diag.items() if v is not None and v != ""}
637
+
638
+
639
+ def _query_nvidia_gpu() -> dict:
640
+ import subprocess
641
+
642
+ fields = [
643
+ "index",
644
+ "uuid",
645
+ "driver_version",
646
+ "name",
647
+ "utilization.gpu",
648
+ "utilization.memory",
649
+ "memory.total",
650
+ "memory.used",
651
+ "memory.free",
652
+ "temperature.gpu",
653
+ "power.draw",
654
+ "power.limit",
655
+ "pstate",
656
+ "clocks.sm",
657
+ "clocks.mem",
658
+ "pcie.link.gen.current",
659
+ "pcie.link.width.current",
660
+ ]
661
+ out = subprocess.run(
662
+ ["nvidia-smi", f"--query-gpu={','.join(fields)}", "--format=csv,noheader,nounits"],
663
+ capture_output=True,
664
+ text=True,
665
+ timeout=8.0, # nvidia-smi diag timeout (fixed; flash is fully managed)
666
+ )
667
+ raw = (out.stdout or out.stderr).strip()
668
+ if out.returncode != 0:
669
+ return {"nvidia_smi_err": raw[:300]}
670
+ rows = list(csv.reader(raw.splitlines()))
671
+ if not rows:
672
+ return {}
673
+ first = [cell.strip() for cell in rows[0]]
674
+ row = dict(zip(fields, first, strict=False))
675
+ diag = {
676
+ "index": _int_or_none(row.get("index")),
677
+ "uuid": row.get("uuid"),
678
+ "driver_version": row.get("driver_version"),
679
+ "device_name": row.get("name"),
680
+ "gpu_util_pct": _int_or_none(row.get("utilization.gpu")),
681
+ "mem_util_pct": _int_or_none(row.get("utilization.memory")),
682
+ "memory_total_gb": _round_gb_from_mib(row.get("memory.total")),
683
+ "memory_used_gb": _round_gb_from_mib(row.get("memory.used")),
684
+ "memory_free_gb": _round_gb_from_mib(row.get("memory.free")),
685
+ "temperature_c": _int_or_none(row.get("temperature.gpu")),
686
+ "power_w": _float_or_none(row.get("power.draw")),
687
+ "power_limit_w": _float_or_none(row.get("power.limit")),
688
+ "pstate": row.get("pstate"),
689
+ "sm_clock_mhz": _int_or_none(row.get("clocks.sm")),
690
+ "mem_clock_mhz": _int_or_none(row.get("clocks.mem")),
691
+ "pcie_gen": _int_or_none(row.get("pcie.link.gen.current")),
692
+ "pcie_width": _int_or_none(row.get("pcie.link.width.current")),
693
+ }
694
+ clean = _clean_diag(diag)
695
+ clean["nvidia_smi"] = raw[:300]
696
+ return clean
697
+
698
+
699
+ def _query_nvidia_processes() -> list[dict]:
700
+ import subprocess
701
+
702
+ out = subprocess.run(
703
+ [
704
+ "nvidia-smi",
705
+ "--query-compute-apps=pid,process_name,used_memory",
706
+ "--format=csv,noheader,nounits",
707
+ ],
708
+ capture_output=True,
709
+ text=True,
710
+ timeout=8.0, # nvidia-smi diag timeout (fixed; flash is fully managed)
711
+ )
712
+ if out.returncode != 0 or not out.stdout.strip():
713
+ return []
714
+ rows = []
715
+ for row in csv.reader(out.stdout.splitlines()):
716
+ if len(row) < 3:
717
+ continue
718
+ rows.append(
719
+ _clean_diag(
720
+ {
721
+ "pid": _int_or_none(row[0]),
722
+ "process_name": row[1].strip(),
723
+ "used_memory_gb": _round_gb_from_mib(row[2]),
724
+ }
725
+ )
726
+ )
727
+ return sorted(rows, key=lambda r: float(r.get("used_memory_gb") or 0.0), reverse=True)[:8]
728
+
729
+
730
+ def gpu_diagnostics(include_torch: bool = True) -> dict:
731
+ """Collect live CUDA/GPU telemetry for run logs and status."""
732
+ diag = {}
733
+ if include_torch:
734
+ try:
735
+ import torch
736
+
737
+ diag["torch"] = torch.__version__
738
+ diag["torch_cuda"] = torch.version.cuda
739
+ diag["cuda_available"] = torch.cuda.is_available()
740
+ try:
741
+ diag["device_count"] = torch.cuda.device_count()
742
+ if torch.cuda.is_available():
743
+ diag["device_name"] = torch.cuda.get_device_name(0)
744
+ free, total = torch.cuda.mem_get_info()
745
+ diag["torch_memory_free_gb"] = round(free / (1024**3), 3)
746
+ diag["torch_memory_total_gb"] = round(total / (1024**3), 3)
747
+ diag["torch_memory_allocated_gb"] = round(
748
+ torch.cuda.memory_allocated() / (1024**3), 3
749
+ )
750
+ diag["torch_memory_reserved_gb"] = round(
751
+ torch.cuda.memory_reserved() / (1024**3), 3
752
+ )
753
+ except Exception as e:
754
+ diag["device_query_err"] = str(e)[:160]
755
+ except Exception as e:
756
+ diag["torch_import_err"] = str(e)[:160]
757
+ try:
758
+ diag.update(_query_nvidia_gpu())
759
+ processes = _query_nvidia_processes()
760
+ if processes:
761
+ diag["processes"] = processes
762
+ except Exception as e:
763
+ diag["nvidia_smi_err"] = str(e)[:160]
764
+ return _clean_diag(diag)
765
+
766
+
767
+ # Human-readable sentinel embedded in the error message (debug tag only — the runner classifies
768
+ # structurally off the worker's heartbeat ``retriable`` flag, not by matching this phrase).
769
+ RETRIABLE_INFRA_MARKER = "RETRIABLE_INFRA_GPU"
770
+
771
+
772
+ class RetriableInfraError(RuntimeError):
773
+ """An infrastructure failure the control plane should RETRY on a fresh worker.
774
+
775
+ Raised for a host the run can never train on — e.g. a GPU that never comes up
776
+ (``wait_for_gpu`` times out) or a required-upload failure. The worker's top-level handler
777
+ stamps ``retriable=True`` into heartbeat.json so the runner retries on a fresh worker.
778
+ """
779
+
780
+ def __init__(self, reason: str):
781
+ super().__init__(f"{RETRIABLE_INFRA_MARKER}: {reason}")
782
+
783
+
784
+ def detect_mig_slice() -> str | None:
785
+ """Return a reason string if this worker was handed a MIG slice (a partitioned GPU), else None.
786
+
787
+ A MIG slice NVML-asserts PyTorch's CUDA allocator — observed when a provider substitutes a
788
+ requested GPU type with a Blackwell MIG slice — which crashes the run with an opaque allocator
789
+ assert partway through setup. Detect it up front (via nvidia-smi, before the first real CUDA op)
790
+ so the worker can fail fast as RETRIABLE infra and the runner re-provisions a fresh FULL GPU,
791
+ rather than letting the run die mid-setup. Best-effort: never raises (a missing/odd nvidia-smi
792
+ just means "no MIG detected", which the subsequent CUDA readiness probe still guards)."""
793
+ import subprocess
794
+
795
+ try:
796
+ out = subprocess.run(
797
+ ["nvidia-smi", "-L"], capture_output=True, text=True, timeout=20
798
+ ).stdout
799
+ # A MIG slice appears as a nested device line, e.g.
800
+ # " MIG 1g.10gb Device 0: (UUID: MIG-xxxx)" (or any "UUID: MIG-..." entry).
801
+ for line in out.splitlines():
802
+ s = line.strip()
803
+ if s.startswith("MIG ") or "UUID: MIG-" in s:
804
+ return f"MIG slice detected (nvidia-smi -L: {s[:120]!r})"
805
+ except Exception:
806
+ pass
807
+ try:
808
+ q = subprocess.run(
809
+ ["nvidia-smi", "--query-gpu=mig.mode.current", "--format=csv,noheader"],
810
+ capture_output=True, text=True, timeout=20,
811
+ ).stdout.strip()
812
+ # "Enabled" => the assigned GPU is partitioned into MIG instances (no full-GPU access).
813
+ # "Disabled"/"N/A"/"[Not Supported]" (consumer + MIG-incapable cards) => fine.
814
+ if q and "enabled" in q.lower():
815
+ return f"MIG mode enabled on the assigned GPU (mig.mode.current={q!r})"
816
+ except Exception:
817
+ pass
818
+ return None
819
+
820
+
821
+ def wait_for_gpu():
822
+ """Rented nodes sometimes report 'CUDA device not ready' transiently at startup.
823
+ Poll a trivial CUDA op until it succeeds before doing real work; raise if never ready.
824
+
825
+ Also fails fast (retriable) if the assigned GPU is a MIG slice — a partitioned GPU crashes the
826
+ CUDA allocator, so we re-provision on a fresh FULL GPU instead of dying mid-setup."""
827
+ import time as _t
828
+
829
+ mig = detect_mig_slice()
830
+ if mig:
831
+ # Infra-shaped: a MIG slice can never train this workload -> retry on a fresh full GPU.
832
+ raise RetriableInfraError(f"{mig}; retrying on a fresh full (non-MIG) GPU")
833
+
834
+ last = None
835
+ for i in range(12):
836
+ try:
837
+ import torch
838
+
839
+ if torch.cuda.is_available():
840
+ # Force an actual kernel launch (alloc + add) to confirm the GPU is live.
841
+ _ = torch.zeros(8, device="cuda") + 1
842
+ torch.cuda.synchronize()
843
+ print(f"GPU ready after {i} retries: {torch.cuda.get_device_name(0)}")
844
+ return True
845
+ last = "cuda not available"
846
+ except Exception as e:
847
+ last = str(e)[:160]
848
+ print(f"GPU not ready (try {i + 1}/12): {last}; sleeping 10s")
849
+ _t.sleep(10)
850
+ # Infra-shaped: a host whose GPU never comes up is dead, not a code bug -> retry on a fresh one.
851
+ raise RetriableInfraError(f"GPU never became ready after 12 tries: {last}")
852
+
853
+
854
+ def free_gpu(trainer=None):
855
+ try:
856
+ import gc
857
+
858
+ import torch
859
+
860
+ try:
861
+ if trainer is not None and hasattr(trainer, "model"):
862
+ trainer.model = None
863
+ except Exception:
864
+ # Best-effort VRAM release before gc; any failure here is non-fatal cleanup.
865
+ pass
866
+ gc.collect()
867
+ if torch.cuda.is_available():
868
+ torch.cuda.empty_cache()
869
+ except Exception as e:
870
+ print("free_gpu warn:", e)
871
+
872
+
873
+ def _metric_curve(trainer, key: str) -> list:
874
+ """The logged values of `key` (e.g. 'loss' or 'reward') from the trainer's log history,
875
+ rounded + capped. Lets metrics.json carry the convergence/reward curve for an A/B without
876
+ relying on a checkpoint's trainer_state.json (only written on save_steps) or the console
877
+ (only uploaded on failure). Never raises."""
878
+ try:
879
+ vals = [round(float(h[key]), 4) for h in trainer.state.log_history if key in h]
880
+ return vals[:400]
881
+ except Exception:
882
+ return []
883
+
884
+
885
+ def _ensure_fla_fastpath_on_hopper() -> None:
886
+ """Make flash-linear-attention's GatedDeltaNet fast path CORRECT + fast on Hopper (sm90)
887
+ instead of dropping it.
888
+
889
+ fla's gated chunk_bwd Triton kernel is miscomputed on Hopper with Triton>=3.4 and HARD-RAISES
890
+ (fla #640). The worker historically DROPPED fla here and fell back to the pure-PyTorch delta
891
+ rule — correct but slow + memory-heavy. The real fix is fla's **tilelang** backend, which is
892
+ correct on Triton>=3.4. So on Hopper we ensure the working stack is present rather than
893
+ removing fla:
894
+ * the pinned ``tilelang==0.1.11`` (the correct GDN chunk_bwd backend) + the pinned
895
+ ``apache-tvm-ffi==0.1.11`` (0.1.12 double-registers the TVM-FFI runtime -> ``import
896
+ tilelang`` aborts; and tilelang's own ``apache-tvm-ffi~=0.1.0`` range would let 0.1.12
897
+ back in, so the pin is force-installed last and its resolved version is verified), and
898
+ * a COMPLETE ``fla`` (the PyPI ``flash-linear-attention`` wheel is a broken stub missing
899
+ ``fla.modules``; reinstall from git if the resident copy is incomplete).
900
+ Validated A/B (H100 SXM, Qwen3.5 hidden-2560 LoRA, controlled fla on/off): seq4096 435->105
901
+ ms/step & 9.9->6.1 GB (4.2x / 1.6x); seq8192 7.1x; seq16384 3106->247 ms & 32->17 GB (12.6x /
902
+ 1.9x). Forward loss matches the torch delta to 1.8e-4 (correct). Runs in the worker process,
903
+ after all installs and BEFORE any model import. Non-Hopper:
904
+ no-op (fla's Triton kernel is correct there). Best-effort + FAIL-CLOSED: a failed install
905
+ (pip rc!=0), a missing module, or the wrong resolved ``apache-tvm-ffi`` version all flip the
906
+ gate off and DISABLE fla, leaving the (correct) pure-PyTorch delta rule in place — a worker
907
+ never crashes on a dep hiccup, and it never silently runs fla's broken Hopper GDN kernel.
908
+ """
909
+ import importlib
910
+ import importlib.util
911
+ import subprocess
912
+
913
+ try:
914
+ import torch
915
+
916
+ if not (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 9):
917
+ return # not Hopper: fla's Triton kernel is correct here.
918
+ except Exception:
919
+ return
920
+
921
+ def _have(mod: str) -> bool:
922
+ try:
923
+ return importlib.util.find_spec(mod) is not None
924
+ except Exception:
925
+ return False
926
+
927
+ def _ver(dist: str) -> str | None:
928
+ """Installed version of a distribution (by metadata), or None if absent/unreadable.
929
+
930
+ Distinct from _have (a find_spec import probe): the install can silently leave the WRONG
931
+ version resolved (e.g. tilelang's ``apache-tvm-ffi~=0.1.0`` range happily keeps 0.1.12,
932
+ which still find_spec-imports but aborts ``import tilelang``), so the gate must check the
933
+ actual installed version, not just importability.
934
+ """
935
+ try:
936
+ import importlib.metadata as _md
937
+
938
+ return _md.version(dist)
939
+ except Exception:
940
+ return None
941
+
942
+ def _pip(*args: str) -> bool:
943
+ """Run pip install; return True only if pip exited 0. A failed install (network/build/
944
+ resolver) must NOT be silently treated as success — the caller gates ``ok`` on this."""
945
+ try:
946
+ rc = subprocess.run(
947
+ [sys.executable, "-m", "pip", "install", "-q", *args], check=False
948
+ ).returncode
949
+ except Exception:
950
+ return False
951
+ return rc == 0
952
+
953
+ # The exact apache-tvm-ffi pin the tilelang backend needs (0.1.12 double-registers the TVM-FFI
954
+ # runtime -> `import tilelang` aborts). Kept as a constant so the install spec and the post-
955
+ # install version gate below can't drift apart. Keep in lockstep with WORKER_DEPS / Dockerfile.
956
+ TVM_FFI_PIN = "0.1.11"
957
+ TILELANG_PIN = "0.1.11" # pin the GDN backend too (same rationale as the fla SHA pin)
958
+
959
+ try:
960
+ # 1. tilelang backend (correct GDN chunk_bwd on Triton>=3.4) + the pinned tvm-ffi.
961
+ # Track whether each install actually succeeded — a failed pip (rc!=0) must flip the
962
+ # gate to the pure-PyTorch fallback rather than be ignored. (_have-only would also pass
963
+ # on a stale/partial copy from a previous boot.) tilelang pulls apache-tvm-ffi via a
964
+ # range that allows the broken 0.1.12, so force-reinstall the exact pin AFTER tilelang
965
+ # and verify the resolved version below.
966
+ # Enforce the EXACT pin: (re)install when tilelang is absent OR a different version is
967
+ # resident (a job or the base image may carry another tilelang; _have-only would treat that
968
+ # as healthy and skip the install, leaving the wrong/uncertain GDN backend in place). Mirror
969
+ # the apache-tvm-ffi handling: check the installed version via _ver and reinstall on mismatch.
970
+ tilelang_ok = True
971
+ tilelang_reinstalled = False
972
+ if _ver("tilelang") != TILELANG_PIN:
973
+ tilelang_ok = _pip(f"tilelang=={TILELANG_PIN}")
974
+ tilelang_reinstalled = True
975
+ # Only force the tvm-ffi pin when it's actually wrong OR tilelang was just (re)installed
976
+ # (tilelang's apache-tvm-ffi~=0.1.0 range can have pulled the broken 0.1.12). Skipping the pip
977
+ # when the exact pin is already resident avoids avoidable cold-start latency and a spurious
978
+ # disable on a transient network/resolver failure — the ok gate still re-verifies the version.
979
+ # If this install runs and fails we DON'T trust the resident copy — tvm_ffi_ok gates `ok` below.
980
+ if _ver("apache-tvm-ffi") != TVM_FFI_PIN or tilelang_reinstalled:
981
+ tvm_ffi_ok = _pip(f"apache-tvm-ffi=={TVM_FFI_PIN}")
982
+ else:
983
+ tvm_ffi_ok = True
984
+ # 2. a COMPLETE fla — the PyPI wheel ships a stub without `fla.modules`. Reinstall from git
985
+ # when the resident copy is missing the real package (or absent entirely).
986
+ fla_ok = True
987
+ if not (_have("fla") and _have("fla.modules")):
988
+ _remove_fla_from_disk() # clear any broken stub before the git install
989
+ # Pinned to the same commit as WORKER_DEPS / Dockerfile.worker so a runtime reinstall is
990
+ # reproducible (the moving default branch could pull a broken/incompatible fla).
991
+ fla_ok = _pip(
992
+ "--no-deps",
993
+ "git+https://github.com/fla-org/flash-linear-attention.git"
994
+ "@f0e213dbd8b5fb90c3c7eca869ac1706d5377139",
995
+ )
996
+ importlib.invalidate_caches()
997
+ # Gate on BOTH (a) every install we ran exiting 0 — a failed pip (network/build/resolver)
998
+ # must NOT be treated as healthy just because a stale/partial copy still find_spec-imports —
999
+ # AND (b) the modules importing AND (c) the resolved apache-tvm-ffi being exactly the pin.
1000
+ # (c) matters because tilelang depends on `apache-tvm-ffi~=0.1.0`, so the resolver can keep
1001
+ # the broken 0.1.12 (which find_spec-imports fine but aborts `import tilelang`); checking the
1002
+ # version is the only reliable signal the pin actually landed.
1003
+ tvm_ffi_ver = _ver("apache-tvm-ffi")
1004
+ tilelang_ver = _ver("tilelang")
1005
+ installs_ok = tilelang_ok and tvm_ffi_ok and fla_ok
1006
+ ok = (
1007
+ installs_ok
1008
+ and _have("fla")
1009
+ and _have("fla.modules")
1010
+ and _have("tilelang")
1011
+ and tilelang_ver == TILELANG_PIN
1012
+ and tvm_ffi_ver == TVM_FFI_PIN
1013
+ )
1014
+ if not ok:
1015
+ # The healthy fla+tilelang stack could not be assembled, so fla's GDN chunk_bwd would
1016
+ # still hit the broken Triton>=3.4 path on Hopper (fla #640) and HARD-RAISE. A print
1017
+ # alone does NOT prevent that: transformers gates GDN on is_fla_available() (a
1018
+ # find_spec('fla') probe), so as long as fla stays importable it gets engaged. PHYSICALLY
1019
+ # remove fla so the probe sees it gone and transformers uses the correct pure-PyTorch
1020
+ # delta rule instead of crashing. _remove_fla_from_disk loops over the real sys.path +
1021
+ # invalidates caches, so find_spec('fla') is None afterwards (the gate flips off).
1022
+ _removed, _still = _remove_fla_from_disk()
1023
+ print(
1024
+ "[hopper] fla GDN fast path unavailable -> DISABLING fla "
1025
+ f"(installs_ok={installs_ok} [tilelang={tilelang_ok} tvm_ffi={tvm_ffi_ok} "
1026
+ f"fla={fla_ok}], tilelang_ver={tilelang_ver!r} (want {TILELANG_PIN}), "
1027
+ f"tvm_ffi_ver={tvm_ffi_ver!r} (want {TVM_FFI_PIN}); "
1028
+ f"removed {len(_removed)} copy(ies); still_importable={_still}); "
1029
+ "pure-PyTorch delta fallback",
1030
+ flush=True,
1031
+ )
1032
+ else:
1033
+ print(
1034
+ "[hopper] fla GDN fast path ENABLED (fla+tilelang "
1035
+ f"{tilelang_ver}/tvm-ffi {tvm_ffi_ver}, fla #640 fixed)",
1036
+ flush=True,
1037
+ )
1038
+ except Exception as e: # never let a dep hiccup crash the worker — torch delta still runs
1039
+ # Fail-closed: an unexpected error mid-setup must still leave Hopper on the correct
1040
+ # pure-PyTorch delta path, not a half-configured fla that transformers would engage and
1041
+ # crash on (#640). Best-effort disable fla; never re-raise.
1042
+ with contextlib.suppress(Exception):
1043
+ _remove_fla_from_disk()
1044
+ print(
1045
+ f"[hopper] fla fast-path setup errored ({type(e).__name__}: {e}); "
1046
+ "disabled fla -> pure-PyTorch delta",
1047
+ flush=True,
1048
+ )