freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,1048 @@
|
|
|
1
|
+
"""Pure GPU/perf/optimizer probes for the fine-tuning worker.
|
|
2
|
+
|
|
3
|
+
These helpers take the model id / max length / capability as ARGUMENTS and read NONE of
|
|
4
|
+
the worker's run-scoped module globals (HF_REPO/RUN_ID/SEED/RUN_MODE/PHASE/JOB_SPEC/
|
|
5
|
+
ACTIVE_ENV/THINKING or the _HB_* heartbeat family), so they live here as a leaf module.
|
|
6
|
+
``flash.engine.worker`` re-exports them; this module must NOT import that package (no cycle).
|
|
7
|
+
Torch and other heavy deps are imported lazily inside the functions (CPU-importable).
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import contextlib
|
|
13
|
+
import csv
|
|
14
|
+
import os
|
|
15
|
+
import sys
|
|
16
|
+
import time
|
|
17
|
+
|
|
18
|
+
# Fused-CE (Liger) gate thresholds live in ONE place — flash.engine.vram — so the worker's run-time
|
|
19
|
+
# gate and the cost estimator's offline mirror (sft_logits_fused) can never drift. vram is a pure
|
|
20
|
+
# leaf (no worker import), so this is cycle-free.
|
|
21
|
+
from flash.engine.vram import _LIGER_LONG_CTX_TOKENS, _LIGER_MIN_PARAMS_B
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _attn_impl_for_capability(
|
|
25
|
+
major: int, minor: int = 0, *, fa3_available: bool = False, fa2_available: bool = False
|
|
26
|
+
) -> str | None:
|
|
27
|
+
"""Map a CUDA compute capability to the trainer ``attn_implementation`` — the best per-arch
|
|
28
|
+
FlashAttention kernel for the model's FULL-attention (softmax) layers, so SFT *and* GRPO use a
|
|
29
|
+
real flash kernel on every arch where one exists (not plain SDPA). The Qwen3.5/3.6
|
|
30
|
+
Gated-DeltaNet *linear*-attention layers always keep their own path (fla, or the native
|
|
31
|
+
pure-PyTorch delta rule once fla is dropped on Hopper) — FlashAttention does not apply to linear
|
|
32
|
+
attention.
|
|
33
|
+
|
|
34
|
+
Each arch maps to its ONE best flash kernel; the fallback is UNIFORM — plain SDPA on every arch
|
|
35
|
+
when that kernel's package is absent (no special FA3->FA2 chain on Hopper):
|
|
36
|
+
* Hopper (sm90, H100/H200): "flash_attention_3" — FA3's warp-specialized async kernels are the
|
|
37
|
+
fastest exact attention on Hopper; transformers routes it to the LOCAL ``flash_attn_interface``
|
|
38
|
+
(no HF Kernels-Hub, whose torch2.10 versions break ``import transformers``). FA3 is baked into
|
|
39
|
+
the worker image by default (Dockerfile FLASH_ATTN_3_SPEC), so ``fa3_available`` is normally
|
|
40
|
+
True; absent -> plain SDPA, same as every other arch.
|
|
41
|
+
* Ampere (sm80 A100 / sm86 3090·A6000) + Ada (sm89 4090·L40S): "flash_attention_2" when the
|
|
42
|
+
``flash_attn`` wheel is importable (``fa2_available``) — FA3 does NOT support these archs.
|
|
43
|
+
* consumer Blackwell (sm120 5090 / RTX Pro): "sdpa" forced to the cuDNN backend. THE ONE arch
|
|
44
|
+
with no flash: FA3/FA4 need TMEM/tcgen05 that sm120 lacks, and the prebuilt FA2 CUDA wheel's
|
|
45
|
+
sm120 coverage is unverified, so cuDNN SDPA is the validated best here.
|
|
46
|
+
* anything else / flash unavailable -> None: transformers picks SDPA (already flash-backed on
|
|
47
|
+
Ampere/Ada/Hopper).
|
|
48
|
+
Pure function (no torch / no imports) so it's unit-testable on CPU; ``fa2_available`` /
|
|
49
|
+
``fa3_available`` are the caller's probes (``optimal_attn_impl``). The big LoRA win is still the
|
|
50
|
+
Liger/chalk fused kernels; flash helps only the ~25% full-attention layers of the hybrid arch."""
|
|
51
|
+
if major == 9 and fa3_available: # Hopper: FA3 is the arch's best flash kernel
|
|
52
|
+
return "flash_attention_3"
|
|
53
|
+
if major == 8 and minor in (0, 6, 9) and fa2_available: # Ampere 8.0/8.6 + Ada 8.9 ONLY: FA2
|
|
54
|
+
# (gate the minor so an unsupported sm8x like sm87 Jetson Orin doesn't get FA2 forced on it)
|
|
55
|
+
return "flash_attention_2"
|
|
56
|
+
if (
|
|
57
|
+
major == 12
|
|
58
|
+
): # consumer Blackwell: cuDNN SDPA (the one exception — FA3/FA4 need TMEM/tcgen05)
|
|
59
|
+
return "sdpa"
|
|
60
|
+
return None # the arch's flash kernel is absent -> plain SDPA (the SAME fallback on every arch)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _flash_attn_3_available() -> bool:
    """Report whether FlashAttention-3 can be used by transformers on this worker.

    transformers' ``flash_attention_3`` path imports ``flash_attn_func`` and friends from the
    ``flash_attn_interface`` module (the ``flash-attn-3`` Hopper build). That module being
    importable is what makes ``attn_implementation="flash_attention_3"`` resolve without the HF
    Kernels-Hub.

    Two probes are tried in order:
      1. transformers' own ``is_flash_attn_3_available`` (it checks real importability);
      2. if that probe cannot be imported or raises, a guarded ``import flash_attn_interface``.
    A real import is used rather than a bare ``find_spec``, so an on-disk-but-broken install
    (ABI mismatch, missing .so) reads as unavailable instead of crashing transformers later at
    model load. There is no disable knob: FA3 is used whenever it is importable.
    """

    def _via_transformers() -> bool:
        from transformers.utils import is_flash_attn_3_available

        return bool(is_flash_attn_3_available())

    def _via_direct_import() -> bool:
        import flash_attn_interface  # noqa: F401 (guarded: verifies real importability)

        return True

    for probe in (_via_transformers, _via_direct_import):
        try:
            return probe()
        except Exception:
            continue  # fall through to the next (or final "unavailable") answer
    return False
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _flash_attn_available() -> bool:
    """Report whether the ``flash_attn`` (FA2) package is discoverable on ``sys.path``.

    This is a ``find_spec`` lookup only, not a real import. That means no import side effects
    and no CUDA init, but a broken on-disk install still reads as True. Any lookup error
    reads as False. There is no disable knob.

    Consumers:
      * ``_attn_impl_for_capability`` picks FA2 on Ampere/Ada from this. It never picks FA2 on
        Hopper (FA3, else uniform SDPA).
      * The SFT packing default uses it on every arch. FA2 re-enters on Hopper only there, when
        ``optimal_attn_impl`` returned None (Hopper without FA3) and packing forces FA2 varlen.
        On sm120 the selector returns ``"sdpa"`` and run_sft disables packing instead, so sm120
        never forces FA2.

    Why packing needs a varlen kernel: TRL's ``packing_strategy='bfd'`` produces flattened,
    padding-free batches. Example boundaries are carried only by ``position_ids`` and are enforced
    only by a varlen-capable attention impl (FA2/FA3/flex). Under plain SDPA, packed examples
    attend across boundaries, a silent quality loss.
    """
    try:
        from importlib.util import find_spec

        spec = find_spec("flash_attn")
    except Exception:
        return False
    return spec is not None
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def optimal_attn_impl() -> str | None:
|
|
111
|
+
"""Best ``attn_implementation`` for the live GPU (None = leave transformers' default)."""
|
|
112
|
+
try:
|
|
113
|
+
import torch
|
|
114
|
+
|
|
115
|
+
if not torch.cuda.is_available():
|
|
116
|
+
return None
|
|
117
|
+
major, minor = torch.cuda.get_device_capability(0)
|
|
118
|
+
except Exception as e:
|
|
119
|
+
print("optimal_attn_impl probe failed:", e)
|
|
120
|
+
return None
|
|
121
|
+
fa2 = _flash_attn_available() # FA2 wheel importable (Ampere/Ada/Hopper)
|
|
122
|
+
# Probe FA3 only on Hopper (the only arch it selects it for) so a non-Hopper run never imports
|
|
123
|
+
# the transformers FA3 helpers needlessly.
|
|
124
|
+
fa3 = _flash_attn_3_available() if major == 9 else False
|
|
125
|
+
impl = _attn_impl_for_capability(major, minor, fa3_available=fa3, fa2_available=fa2)
|
|
126
|
+
if impl in ("flash_attention_2", "flash_attention_3"):
|
|
127
|
+
ver = "FlashAttention-3" if impl == "flash_attention_3" else "FlashAttention-2"
|
|
128
|
+
print(
|
|
129
|
+
f"[attn] sm{major}{minor} -> attn_implementation={impl} ({ver}, full-attention layers)"
|
|
130
|
+
)
|
|
131
|
+
elif major == 9 and not fa3:
|
|
132
|
+
# Hopper but FA3 not selected -> plain SDPA (uniform fallback). FA3 is baked into the worker
|
|
133
|
+
# image by default, so this means flash_attn_interface is absent/broken — check the install.
|
|
134
|
+
print(f"[attn] sm{major}{minor}: FA3 unavailable (flash_attn_interface absent) -> SDPA")
|
|
135
|
+
elif major == 12: # the only arch that returns impl=="sdpa" -> this branch covers all of it
|
|
136
|
+
print(
|
|
137
|
+
f"[attn] sm{major}{minor} (consumer Blackwell) -> SDPA/cuDNN (FA3/FA4 need TMEM; n/a on sm120)"
|
|
138
|
+
)
|
|
139
|
+
elif not fa2:
|
|
140
|
+
print(f"[attn] sm{major}{minor}: flash_attn wheel absent -> SDPA")
|
|
141
|
+
return impl
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# Liger's fused linear cross-entropy saves MEMORY: the fp32 [B, T, vocab] logits tensor is never
# materialized. It is not a fixed-batch speed-up.
#
# PR #174's ledger, 1B model at fixed batch, min-of-3 speed ratio: A100 0.86x, H100 0.90x,
# RTX 3090 0.78x, RTX 4090 0.83x, RTX 5090 0.79x. That is a net loss on every arch: a small
# model's logits don't dominate memory, so the per-step Triton overhead is never repaid. The
# payoff appears on LARGE models, where it lets a bigger batch fit or avoids an OOM.
#
# So Liger is gated on estimated model size, and 1B-class models stay below the cut-off. The
# threshold lives in billions in flash.engine.vram (single source of truth); this is the same
# cut-off (~3B) expressed in raw parameters.
_LIGER_MIN_PARAMS = _LIGER_MIN_PARAMS_B * 1_000_000_000.0
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _estimate_params(cfg) -> float:
    """Roughly estimate a model's parameter count from its HF config.

    Estimate = embeddings (doubled when the lm_head is untied) + ~12·h² per transformer block
    (attention + MLP). Multimodal checkpoints (e.g. Qwen3.5-VL) keep the LM dimensions under
    ``text_config``. Those are used when the top-level ``hidden_size`` is missing or falsy;
    otherwise the gate would underestimate and wrongly disable the memory path (GC/Liger) for
    the 4B/9B tiers.

    ``vocab_size`` and ``tie_word_embeddings`` fall back to the top-level config when the chosen
    dims object lacks them. Missing or None dimensions count as 0, so an empty config gives 0.0.
    """
    text_cfg = getattr(cfg, "text_config", None)
    if getattr(cfg, "hidden_size", 0):
        dims = cfg
    else:
        dims = text_cfg or cfg

    hidden = getattr(dims, "hidden_size", 0) or 0
    vocab = getattr(dims, "vocab_size", 0) or getattr(cfg, "vocab_size", 0) or 0
    layers = getattr(dims, "num_hidden_layers", 0) or 0
    tied = getattr(dims, "tie_word_embeddings", getattr(cfg, "tie_word_embeddings", False))

    embed_copies = 1 if tied else 2  # untied -> a separate lm_head matrix of the same size
    embeddings = vocab * hidden * embed_copies
    transformer = layers * 12 * hidden * hidden  # ~12 h^2 per block (attn + MLP)
    return float(embeddings + transformer)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _liger_default_for_model(model_id: str) -> bool:
    """Default Liger ON only when the model is large enough for fused-CE's memory win to pay off.

    "Large enough" means ``_estimate_params(config) >= _LIGER_MIN_PARAMS`` (~3B). 1B-class models
    were measured net-negative. Any failure (transformers missing, config fetch/parse error)
    prints a message and returns False (Liger off).
    """
    try:
        from transformers import AutoConfig

        # NOTE: trust_remote_code executes config code shipped with the repo; model_id must be
        # a trusted source.
        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        big_enough = _estimate_params(config) >= _LIGER_MIN_PARAMS
    except Exception as exc:
        print("liger model-size probe failed (default off):", exc)
        return False
    return big_enough
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def liger_on(default_on: bool) -> bool:
    """Decide whether to enable a Liger kernel path.

    ``default_on`` is the model-size decision: on only for models large enough for fused-CE's
    memory win; 1B-class is a measured net loss. Even then, Liger is enabled only when a CUDA GPU
    is present AND ``liger_kernel`` is discoverable. The local ``flash[gpu]`` extra doesn't ship
    it, so blindly setting ``use_liger_kernel`` would crash a local GPU run. Any probe error reads
    as off.
    """
    if not default_on:
        return False
    try:
        import importlib.util

        import torch

        if not torch.cuda.is_available():
            return False
        return importlib.util.find_spec("liger_kernel") is not None
    except Exception:
        return False
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def setup_perf_backends() -> None:
    """Turn on universal, arch-agnostic throughput knobs. No JIT/compile cost.

    Currently that is TF32 for fp32 matmuls and cuDNN. On Ampere+ it routes the residual fp32
    work in a bf16 LoRA run (some norms, the optimizer's fp32 master step, any fp32 GEMM) onto
    TF32 tensor cores at ~no accuracy cost. It is a no-op on pre-Ampere parts.

    Does nothing without CUDA. Any error is printed and swallowed.
    """
    try:
        import torch

        if torch.cuda.is_available():
            torch.set_float32_matmul_precision("high")  # TF32 for fp32 matmuls
            backends = torch.backends
            backends.cuda.matmul.allow_tf32 = True
            backends.cudnn.allow_tf32 = True
            print("[perf] TF32 matmul/cuDNN enabled")
    except Exception as exc:
        print("setup_perf_backends skipped:", exc)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _remove_fla_from_disk() -> tuple[list[str], bool]:
    """Physically delete every discoverable ``fla`` package directory from the real sys.path.

    ``pip uninstall`` alone is unreliable here. It targets one site-packages, while the base image
    bakes ``fla`` into another dir on the path, and it can report success while leaving the
    package dir behind. Used by the Hopper auto-drop.

    Behaviour:
      * Runs up to 6 passes, because removing one copy can expose another further down the path.
      * Each pass resolves ``find_spec('fla')`` to its package dir(s) and ``rmtree``s each one
        whose basename is exactly ``fla``.
      * Stops early once ``fla`` is no longer found or a pass removes nothing.
      * Invalidates import caches so transformers' ``is_fla_available()`` probe sees the removal.
      * A directory that cannot be removed is reported and skipped.

    Returns ``(removed_dirs, still_importable)``.

    NOTE: if ``fla`` is already in ``sys.modules``, ``find_spec`` returns that module's cached
    spec. ``still_importable`` can then read True even after the files are gone, so call this
    before anything imports fla.
    """
    import importlib
    import importlib.util
    import shutil

    def _package_dirs(spec) -> list[str]:
        """Directories backing ``spec``: its package search path, else the dir of its file."""
        dirs = list(getattr(spec, "submodule_search_locations", None) or [])
        if not dirs and spec.origin:
            dirs = [os.path.dirname(spec.origin)]
        return dirs

    deleted: list[str] = []
    for _attempt in range(6):
        importlib.invalidate_caches()
        found = importlib.util.find_spec("fla")
        if found is None:
            break
        made_progress = False
        for pkg_dir in _package_dirs(found):
            if not (pkg_dir and os.path.isdir(pkg_dir)):
                continue
            if os.path.basename(pkg_dir.rstrip("/")) != "fla":
                continue  # never delete a shared parent dir (e.g. a single-file module's dir)
            try:
                shutil.rmtree(pkg_dir)
            except Exception as exc:
                print(f"[fla] could not remove {pkg_dir}: {exc}", flush=True)
            else:
                deleted.append(pkg_dir)
                made_progress = True
        if not made_progress:
            break
    importlib.invalidate_caches()
    return deleted, importlib.util.find_spec("fla") is not None
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def _find_real_libcudart() -> str | None:
|
|
263
|
+
"""Path to a REAL ``libcudart.so.<major>`` that exports ``cudaDeviceReset`` (the symbol
|
|
264
|
+
tilelang's stub lacks), or None if none can be found. Prefers the nvidia-cuda-runtime wheel, then
|
|
265
|
+
the CUDA toolkit baked into the worker's -devel base image, then the system loader's own resolver
|
|
266
|
+
— and VERIFIES the symbol is actually present (a path is only returned if ``CDLL(path)`` exposes
|
|
267
|
+
``cudaDeviceReset``). Never raises.
|
|
268
|
+
|
|
269
|
+
CUDA-major-agnostic: the worker image pins cu12 today, but the runtime wheel's import path and
|
|
270
|
+
soname differ across CUDA majors — the cu12 wheel ships ``nvidia/cuda_runtime/lib/libcudart.so.12``
|
|
271
|
+
while the cu13 wheel ships ``nvidia/cu13/lib/libcudart.so.13`` (and has NO ``nvidia.cuda_runtime``
|
|
272
|
+
module at all). A ``.so.12``-only probe silently returns None on cu13, leaving the stub shadow in
|
|
273
|
+
place. So we probe every ``nvidia/*/lib`` subdir and any ``libcudart.so.*`` major; the symlink
|
|
274
|
+
repoint that consumes this works for any real libcudart (``_verify`` still gates on the symbol)."""
|
|
275
|
+
import ctypes
|
|
276
|
+
import ctypes.util
|
|
277
|
+
import glob
|
|
278
|
+
|
|
279
|
+
def _verify(cand: str) -> str | None:
|
|
280
|
+
"""Absolute path to ``cand`` if it loads AND exports cudaDeviceReset, else None. Handles both
|
|
281
|
+
absolute paths (glob results) and bare sonames like ``libcudart.so.12`` (find_library, which
|
|
282
|
+
the loader resolves but ``os.path.exists`` would reject)."""
|
|
283
|
+
try:
|
|
284
|
+
lib = ctypes.CDLL(cand) # an abs path opens directly; a bare soname is loader-resolved
|
|
285
|
+
except OSError:
|
|
286
|
+
return None
|
|
287
|
+
if not hasattr(lib, "cudaDeviceReset"):
|
|
288
|
+
return None
|
|
289
|
+
if os.path.isabs(cand) and os.path.exists(cand):
|
|
290
|
+
return os.path.realpath(cand)
|
|
291
|
+
# Bare soname: resolve to the file the loader actually opened, via /proc/self/maps.
|
|
292
|
+
base = os.path.basename(cand)
|
|
293
|
+
try:
|
|
294
|
+
with open("/proc/self/maps") as f:
|
|
295
|
+
for line in f:
|
|
296
|
+
if base in line and "/" in line:
|
|
297
|
+
p = line[line.index("/"):].rstrip()
|
|
298
|
+
if os.path.basename(p).startswith(base) and os.path.exists(p):
|
|
299
|
+
return os.path.realpath(p)
|
|
300
|
+
except OSError:
|
|
301
|
+
pass
|
|
302
|
+
return None
|
|
303
|
+
|
|
304
|
+
candidates: list[str] = []
|
|
305
|
+
# 1. nvidia cuda-runtime PyPI wheel (a torch/vLLM dep on many images), any CUDA major. Import the
|
|
306
|
+
# ``nvidia`` namespace package (it resolves even when a specific ``nvidia.cuda_runtime`` subpkg
|
|
307
|
+
# is absent — e.g. the cu13 wheel has none) and glob every ``nvidia/*/lib`` for any libcudart
|
|
308
|
+
# soname, so both the cu12 layout (nvidia/cuda_runtime/lib/libcudart.so.12) and the cu13 layout
|
|
309
|
+
# (nvidia/cu13/lib/libcudart.so.13) are found. ``sorted`` keeps candidate order deterministic.
|
|
310
|
+
try:
|
|
311
|
+
import nvidia # type: ignore # namespace package; subpkg import may fail, this won't
|
|
312
|
+
|
|
313
|
+
for base in sorted(map(str, getattr(nvidia, "__path__", []) or [])):
|
|
314
|
+
candidates += sorted(glob.glob(os.path.join(base, "*", "lib", "libcudart.so.*")))
|
|
315
|
+
except Exception:
|
|
316
|
+
pass
|
|
317
|
+
# 2. CUDA toolkit in a -devel base image (Dockerfile.worker today: cuda12.8-cudnn9-devel), any major.
|
|
318
|
+
for pat in (
|
|
319
|
+
"/usr/local/cuda*/lib64/libcudart.so.*",
|
|
320
|
+
"/usr/local/cuda*/targets/*/lib/libcudart.so.*",
|
|
321
|
+
"/usr/lib/x86_64-linux-gnu/libcudart.so.*",
|
|
322
|
+
):
|
|
323
|
+
candidates += sorted(glob.glob(pat))
|
|
324
|
+
# 3. The loader's own resolver (LD_LIBRARY_PATH / ldconfig) — returns a bare soname, handled above.
|
|
325
|
+
found = ctypes.util.find_library("cudart")
|
|
326
|
+
if found:
|
|
327
|
+
candidates.append(found)
|
|
328
|
+
for cand in candidates:
|
|
329
|
+
real = _verify(cand)
|
|
330
|
+
if real is not None:
|
|
331
|
+
return real
|
|
332
|
+
return None
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def _swap_in_symlink(stub: str, target: str) -> None:
    """Atomically replace ``stub`` with a symlink to ``target``, preserving the original once.

    Steps:
      1. The first time, the original stub is COPIED (not moved) to ``<stub>.orig`` for
         reversibility. A dangling backup link is re-taken.
      2. The new symlink is created under a per-process temporary name.
      3. It is ``os.replace``d over ``stub``. That is a single rename within one directory, so
         ``stub`` is never missing: if any step fails, tilelang keeps the file it already had.
         Concurrent worker processes don't clobber each other's temp link.

    Raises ``OSError`` on failure, after removing its temporary link.
    """
    import shutil

    backup = stub + ".orig"
    if not os.path.exists(backup):  # exists() follows links, so a dangling backup is re-taken
        with contextlib.suppress(FileNotFoundError):
            os.remove(backup)
        shutil.copy2(stub, backup, follow_symlinks=False)
    tmp = f"{stub}.{os.getpid()}.tmp"  # per-process name: parallel workers don't collide
    with contextlib.suppress(FileNotFoundError):
        os.remove(tmp)  # leftover from an interrupted earlier pass
    try:
        os.symlink(target, tmp)
        os.replace(tmp, stub)  # atomic swap; also overwrites a dangling symlink in place
    except BaseException:
        with contextlib.suppress(OSError):
            os.remove(tmp)
        raise


def _neutralize_tilelang_cudart_stub() -> None:
    """Stop tilelang's bundled ``libcudart_stub.so`` from shadowing the real CUDA runtime in vLLM.

    The problem (flash #184):
      * tilelang ships a minimal ``libcudart_stub.so`` that ``libtilelang.so`` / ``libtvm.so``
        link against. It exports only a SUBSET of the CUDA runtime and notably lacks
        ``cudaDeviceReset``.
      * vLLM's ``vllm/device_allocator/cumem.py`` builds ``CudaRTLibrary()`` at module top level.
        That module is imported on EVERY vLLM init (``gpu_worker.load_model`` ->
        ``_maybe_get_memory_pool_context``), so the crash is not gated on sleep mode or model size.
      * ``CudaRTLibrary`` picks libcudart from the FIRST matching line of a substring scan of
        ``/proc/self/maps``, which is address-ordered and so host-dependent.
      * Once tilelang is loaded (the Hopper fla fast path, or fla's backend probe on any arch),
        the stub can win that scan. vLLM then dlopens it and dies with ``undefined symbol:
        cudaDeviceReset`` before step 0.

    The fix: before anything imports tilelang/fla/vllm, repoint the stub path at a REAL
    ``libcudart.so.<major>`` via a symlink, swapped in atomically (see ``_swap_in_symlink``).
      * Whichever copy the loader or vLLM's scan resolves then has the full symbol set.
      * The real lib's soname dedupes against a copy torch already loaded, so no second runtime
        instance appears and the stub-named mapping drops out of ``/proc/self/maps``.
      * tilelang keeps working, since the real runtime is a superset of the stub.
    Must run AFTER ``_ensure_fla_fastpath_on_hopper`` (a tilelang reinstall would rewrite the
    stub) and BEFORE the model/vLLM import. No GPU required.

    Idempotent and best-effort. These cases are clean no-ops:
      * tilelang missing;
      * stub missing;
      * stub already redirected (a resolving symlink);
      * no discoverable real runtime (logged).
    Any other error is logged and swallowed. The worker must never crash on this hygiene step.
    """
    import importlib.util

    try:
        spec = importlib.util.find_spec("tilelang")
    except Exception:
        spec = None
    locs = list(getattr(spec, "submodule_search_locations", None) or []) if spec else []
    if not locs:
        return  # tilelang not installed -> nothing can shadow libcudart
    stub = os.path.join(locs[0], "lib", "libcudart_stub.so")
    if not os.path.lexists(stub):  # lexists: a dangling symlink still counts as present
        return
    # Idempotency WITHOUT loading the stub. Only this function ever turns the stub into a
    # symlink; pristine tilelang ships it as a regular file. So a RESOLVING symlink means an
    # earlier pass already fixed it.
    # Never probe the stub with ctypes.CDLL: dlopen succeeds under lazy binding and maps the stub
    # into /proc/self/maps, which is exactly what vLLM's scan would then pick up.
    # A DANGLING symlink (target moved or removed) still breaks tilelang, so fall through and
    # re-point it. os.path.exists follows links, so it is False for a dangling one.
    if os.path.islink(stub) and os.path.exists(stub):
        return
    real = _find_real_libcudart()
    if real is None:
        print(
            "[worker] libcudart stub shadow: no real libcudart found; left as-is (flash #184)",
            flush=True,
        )
        return
    try:
        _swap_in_symlink(stub, real)
        print(f"[worker] redirected tilelang libcudart_stub.so -> {real} (flash #184)", flush=True)
    except Exception as e:
        print(f"[worker] libcudart stub neutralize failed: {e}", flush=True)
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
# Long-context runs are memory-bound: activations and the vLLM KV cache grow with sequence length.
# So they need the memory features even on a SMALL model. PR #174 measured a 1B model OOM on GRPO
# at 4096 ctx in speed mode that fit in memory mode. Hence "memory mode" = large model OR long
# context. The canonical threshold lives in flash.engine.vram.
_LONG_CONTEXT_TOKENS = _LIGER_LONG_CTX_TOKENS  # canonical value in flash.engine.vram


def _memory_mode(model_id: str, max_length: int = 0) -> bool:
    """Decide whether the memory-saving features default ON (Liger, grad-checkpointing, vLLM sleep).

    On when the context is long (``max_length >= _LONG_CONTEXT_TOKENS``, since activations/KV
    dominate) or the model is large (``_liger_default_for_model``, the fused-CE memory win).
    A small model with a short or unset (0) context stays off, optimizing for speed.
    """
    long_context = bool(max_length) and max_length >= _LONG_CONTEXT_TOKENS
    if long_context:
        return True
    return _liger_default_for_model(model_id)
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def grad_checkpointing_on(model_id: str, max_length: int = 0) -> bool:
    """Decide whether gradient checkpointing is enabled.

    Checkpointing recomputes the forward pass during backward (~25% slower) to save activation
    memory, so it is a MEMORY feature, not a speed one. It follows ``_memory_mode``: ON for
    large models or long contexts that need the headroom, OFF for small, short runs that fit
    without it.
    """
    needs_memory_headroom = _memory_mode(model_id, max_length)
    return needs_memory_headroom
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def grpo_sleep_mode(
    model_id: str,
    *,
    max_length: int = 0,
    group_size: int = 8,
    max_tokens: int | None = None,
    lora_rank: int = 32,
    thinking: bool = False,
    card_vram_gb: float = 0.0,
) -> bool:
    """Decide whether colocated-vLLM GRPO enables vLLM sleep mode (rollout engine offloaded between steps).

    Sleep mode buys memory at a large per-step cost. On the large-model GRPO path the sleep/wake
    cycle also STALLS the colocated rollout: it produces unparseable completions and then the
    worker hangs. So it is enabled only when the run genuinely cannot stay resident.

    Decision order:
      1. False when ``_memory_mode`` says the run is small and short. This is judged on the
         rollout length run_rl() actually launches (``grpo_rollout_seq_len``), not the raw
         ``max_length``.
      2. With a known ``card_vram_gb`` (> 0): False when ``grpo_fits_resident`` reports that
         policy + rollout engine + training peak fit on the card. If that check raises, the
         error is printed and we fall through.
      3. Otherwise True.
    """
    from flash.engine.vram import grpo_fits_resident, grpo_rollout_seq_len

    # With [train].max_length unset (0), the raw value would look like a "short" run and
    # early-exit a sub-3B model before the resident-fit check a long max_tokens rollout needs.
    # Use the launched rollout length instead (e.g. 2368 default / 3584 thinking).
    rollout_len = grpo_rollout_seq_len(max_length, max_tokens, thinking)
    if not _memory_mode(model_id, rollout_len):
        return False  # small model AND genuinely short rollout -> never needed

    if not (card_vram_gb and card_vram_gb > 0):
        return True  # card size unknown -> keep the size/context decision

    try:
        fits = bool(
            grpo_fits_resident(
                model_id,
                seq_len=rollout_len,
                max_tokens=max_tokens,
                lora_rank=lora_rank,
                group_size=group_size,
                thinking=thinking,
                card_vram_gb=card_vram_gb,
            )
        )
    except Exception as exc:
        print("[rl] grpo sleep-mode resident check skipped:", exc)
        return True
    # Fits resident -> skip the (buggy, slow) sleep/wake cycle.
    return not fits
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def fused_optim_name() -> str:
|
|
478
|
+
"""TRL/HF ``optim`` value: 8-bit paged AdamW (bitsandbytes int8 optimizer state paged to host
|
|
479
|
+
RAM). It fits a smaller/cheaper GPU and is the better default across the catalog."""
|
|
480
|
+
return "paged_adamw_8bit"
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def _reset_peak_gpu() -> None:
|
|
484
|
+
"""Reset the CUDA peak-memory counter so a subsequent ``_peak_gpu_gb`` measures only the work
|
|
485
|
+
that follows (e.g. the train loop, isolating the optimizer-state A/B from setup/model load)."""
|
|
486
|
+
try:
|
|
487
|
+
import torch
|
|
488
|
+
|
|
489
|
+
if torch.cuda.is_available():
|
|
490
|
+
torch.cuda.reset_peak_memory_stats()
|
|
491
|
+
except Exception:
|
|
492
|
+
pass
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def _peak_gpu_gb() -> float:
|
|
496
|
+
"""Peak torch-allocated CUDA memory (GB) since the last reset; 0.0 if no CUDA. Note: bnb paged
|
|
497
|
+
8-bit optimizer state lives in unified/managed memory outside torch's caching allocator and is
|
|
498
|
+
NOT counted here — so this OVERSTATES the 8-bit saving. _GpuPeakSampler measures the true
|
|
499
|
+
device footprint (incl. bnb managed pages) for the honest A/B number."""
|
|
500
|
+
try:
|
|
501
|
+
import torch
|
|
502
|
+
|
|
503
|
+
if torch.cuda.is_available():
|
|
504
|
+
return round(torch.cuda.max_memory_allocated() / 1e9, 3)
|
|
505
|
+
except Exception:
|
|
506
|
+
pass
|
|
507
|
+
return 0.0
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
class _GpuPeakSampler:
|
|
511
|
+
"""Background sampler of true device memory (GB) = (total - free) from cuda.mem_get_info, which
|
|
512
|
+
DOES include bitsandbytes managed/paged optimizer pages while they're GPU-resident (torch's
|
|
513
|
+
max_memory_allocated does not). This is the honest peak for the fp32-vs-8-bit optimizer A/B."""
|
|
514
|
+
|
|
515
|
+
def __init__(self, interval: float = 0.25):
|
|
516
|
+
self.interval = interval
|
|
517
|
+
self.peak_used = 0
|
|
518
|
+
self._stop = False
|
|
519
|
+
self._thread = None
|
|
520
|
+
|
|
521
|
+
def _run(self):
|
|
522
|
+
import torch
|
|
523
|
+
|
|
524
|
+
while not self._stop:
|
|
525
|
+
try:
|
|
526
|
+
free, total = torch.cuda.mem_get_info()
|
|
527
|
+
self.peak_used = max(self.peak_used, total - free)
|
|
528
|
+
except Exception:
|
|
529
|
+
pass
|
|
530
|
+
time.sleep(self.interval)
|
|
531
|
+
|
|
532
|
+
def start(self):
|
|
533
|
+
try:
|
|
534
|
+
import threading
|
|
535
|
+
|
|
536
|
+
import torch
|
|
537
|
+
|
|
538
|
+
if not torch.cuda.is_available():
|
|
539
|
+
return self
|
|
540
|
+
self._thread = threading.Thread(target=self._run, daemon=True)
|
|
541
|
+
self._thread.start()
|
|
542
|
+
except Exception:
|
|
543
|
+
pass
|
|
544
|
+
return self
|
|
545
|
+
|
|
546
|
+
def stop_gb(self) -> float:
|
|
547
|
+
self._stop = True
|
|
548
|
+
if self._thread is not None:
|
|
549
|
+
self._thread.join(timeout=2)
|
|
550
|
+
return round(self.peak_used / 1e9, 3)
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
def loraplus_optimizer_cls(optim_name: str):
|
|
554
|
+
"""Optimizer class for the LoRA+ ``create_optimizer`` override (returns ``(cls, extra_kwargs)``).
|
|
555
|
+
|
|
556
|
+
The LoRA+ override has to *build* the optimizer itself (PEFT splits the LoRA A/B matrices into
|
|
557
|
+
separate param groups with different LRs), so it cannot inherit TRL's ``optim=`` string — it has
|
|
558
|
+
to choose a concrete class. Historically it always built a full-precision ``torch.optim.AdamW``,
|
|
559
|
+
which silently discarded the catalog's ``paged_adamw_8bit`` setting whenever LoRA+ was on.
|
|
560
|
+
|
|
561
|
+
PEFT's ``create_loraplus_optimizer`` accepts ANY ``optimizer_cls`` — including bitsandbytes 8-bit
|
|
562
|
+
optimizers (it registers embedding overrides with bnb's ``GlobalOptimManager`` to keep them
|
|
563
|
+
32-bit) — so LoRA+ and the 8-bit paged optimizer state coexist. An ``8bit`` ``optim`` value
|
|
564
|
+
(the fleet default; ``fused_optim_name`` -> ``paged_adamw_8bit``) selects
|
|
565
|
+
``bnb.optim.PagedAdamW8bit``; a non-8-bit ``optim`` keeps fp32 AdamW. This simply mirrors the
|
|
566
|
+
configured ``optim`` — there is no separate toggle: an on-GPU A/B (Qwen3.5-4B SFT, rank-128
|
|
567
|
+
LoRA, same seed/data/init) measured the 8-bit paged state at -75% optimizer memory
|
|
568
|
+
(1359 -> 346 MB) and -0.72 GB peak with NO convergence penalty (final loss 10.64 vs 11.16 from
|
|
569
|
+
an identical start), so it's unconditionally the default wherever ``optim`` is 8-bit. Falls
|
|
570
|
+
back to fp32 AdamW only if bitsandbytes is missing."""
|
|
571
|
+
import torch as _torch
|
|
572
|
+
|
|
573
|
+
# case-insensitive + str-safe: TRL normalizes optim to an OptimizerNames enum whose str() is
|
|
574
|
+
# "OptimizerNames.PAGED_ADAMW_8BIT" (uppercase), so a bare `"8bit" in optim_name` would miss it.
|
|
575
|
+
if "8bit" in str(optim_name or "").lower():
|
|
576
|
+
try:
|
|
577
|
+
import bitsandbytes as bnb
|
|
578
|
+
|
|
579
|
+
return bnb.optim.PagedAdamW8bit, {}
|
|
580
|
+
except Exception as e: # bnb missing / no CUDA build -> safe fp32 fallback
|
|
581
|
+
print(f"[lora+] bitsandbytes 8-bit optimizer unavailable ({e}); using fp32 AdamW")
|
|
582
|
+
return _torch.optim.AdamW, {}
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
def _sdpa_cudnn_ctx(attn_impl: str | None):
|
|
586
|
+
"""Context forcing the cuDNN SDPA backend (real Blackwell-consumer kernels) when we fell
|
|
587
|
+
back to plain SDPA on sm120; a no-op context otherwise. Best-effort."""
|
|
588
|
+
if attn_impl != "sdpa":
|
|
589
|
+
return contextlib.nullcontext()
|
|
590
|
+
try:
|
|
591
|
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
|
592
|
+
|
|
593
|
+
# Priority-ordered: prefer the fast cuDNN/flash/efficient kernels, but ALWAYS include MATH
|
|
594
|
+
# as the final fallback. Restricting to only [CUDNN, EFFICIENT] makes sm120 GRPO crash with
|
|
595
|
+
# "RuntimeError: No available kernel" when neither has a kernel for the completion-batch
|
|
596
|
+
# attention shape (MEASURED: Qwen3.5 GRPO on RTX 5090). MATH is universal, so the candidate
|
|
597
|
+
# set is never empty; set_priority keeps cuDNN first whenever it CAN serve the shape (SFT
|
|
598
|
+
# fast path unchanged), only falling through for the shapes cuDNN/efficient reject.
|
|
599
|
+
return sdpa_kernel(
|
|
600
|
+
[
|
|
601
|
+
SDPBackend.CUDNN_ATTENTION,
|
|
602
|
+
SDPBackend.FLASH_ATTENTION,
|
|
603
|
+
SDPBackend.EFFICIENT_ATTENTION,
|
|
604
|
+
SDPBackend.MATH,
|
|
605
|
+
],
|
|
606
|
+
set_priority=True,
|
|
607
|
+
)
|
|
608
|
+
except Exception as e:
|
|
609
|
+
print("[attn] cuDNN SDPA backend unavailable, using default SDPA:", e)
|
|
610
|
+
return contextlib.nullcontext()
|
|
611
|
+
|
|
612
|
+
|
|
613
|
+
def _float_or_none(value) -> float | None:
|
|
614
|
+
try:
|
|
615
|
+
text = str(value).strip()
|
|
616
|
+
if not text or text.upper() in {"N/A", "[N/A]", "NOT SUPPORTED", "[NOT SUPPORTED]"}:
|
|
617
|
+
return None
|
|
618
|
+
return float(text)
|
|
619
|
+
except (TypeError, ValueError):
|
|
620
|
+
return None
|
|
621
|
+
|
|
622
|
+
|
|
623
|
+
def _int_or_none(value) -> int | None:
|
|
624
|
+
num = _float_or_none(value)
|
|
625
|
+
return int(num) if num is not None else None
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
def _round_gb_from_mib(value) -> float | None:
|
|
629
|
+
num = _float_or_none(value)
|
|
630
|
+
if num is None:
|
|
631
|
+
return None
|
|
632
|
+
return round(num / 1024.0, 3)
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
def _clean_diag(diag: dict) -> dict:
|
|
636
|
+
return {k: v for k, v in diag.items() if v is not None and v != ""}
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
def _query_nvidia_gpu() -> dict:
|
|
640
|
+
import subprocess
|
|
641
|
+
|
|
642
|
+
fields = [
|
|
643
|
+
"index",
|
|
644
|
+
"uuid",
|
|
645
|
+
"driver_version",
|
|
646
|
+
"name",
|
|
647
|
+
"utilization.gpu",
|
|
648
|
+
"utilization.memory",
|
|
649
|
+
"memory.total",
|
|
650
|
+
"memory.used",
|
|
651
|
+
"memory.free",
|
|
652
|
+
"temperature.gpu",
|
|
653
|
+
"power.draw",
|
|
654
|
+
"power.limit",
|
|
655
|
+
"pstate",
|
|
656
|
+
"clocks.sm",
|
|
657
|
+
"clocks.mem",
|
|
658
|
+
"pcie.link.gen.current",
|
|
659
|
+
"pcie.link.width.current",
|
|
660
|
+
]
|
|
661
|
+
out = subprocess.run(
|
|
662
|
+
["nvidia-smi", f"--query-gpu={','.join(fields)}", "--format=csv,noheader,nounits"],
|
|
663
|
+
capture_output=True,
|
|
664
|
+
text=True,
|
|
665
|
+
timeout=8.0, # nvidia-smi diag timeout (fixed; flash is fully managed)
|
|
666
|
+
)
|
|
667
|
+
raw = (out.stdout or out.stderr).strip()
|
|
668
|
+
if out.returncode != 0:
|
|
669
|
+
return {"nvidia_smi_err": raw[:300]}
|
|
670
|
+
rows = list(csv.reader(raw.splitlines()))
|
|
671
|
+
if not rows:
|
|
672
|
+
return {}
|
|
673
|
+
first = [cell.strip() for cell in rows[0]]
|
|
674
|
+
row = dict(zip(fields, first, strict=False))
|
|
675
|
+
diag = {
|
|
676
|
+
"index": _int_or_none(row.get("index")),
|
|
677
|
+
"uuid": row.get("uuid"),
|
|
678
|
+
"driver_version": row.get("driver_version"),
|
|
679
|
+
"device_name": row.get("name"),
|
|
680
|
+
"gpu_util_pct": _int_or_none(row.get("utilization.gpu")),
|
|
681
|
+
"mem_util_pct": _int_or_none(row.get("utilization.memory")),
|
|
682
|
+
"memory_total_gb": _round_gb_from_mib(row.get("memory.total")),
|
|
683
|
+
"memory_used_gb": _round_gb_from_mib(row.get("memory.used")),
|
|
684
|
+
"memory_free_gb": _round_gb_from_mib(row.get("memory.free")),
|
|
685
|
+
"temperature_c": _int_or_none(row.get("temperature.gpu")),
|
|
686
|
+
"power_w": _float_or_none(row.get("power.draw")),
|
|
687
|
+
"power_limit_w": _float_or_none(row.get("power.limit")),
|
|
688
|
+
"pstate": row.get("pstate"),
|
|
689
|
+
"sm_clock_mhz": _int_or_none(row.get("clocks.sm")),
|
|
690
|
+
"mem_clock_mhz": _int_or_none(row.get("clocks.mem")),
|
|
691
|
+
"pcie_gen": _int_or_none(row.get("pcie.link.gen.current")),
|
|
692
|
+
"pcie_width": _int_or_none(row.get("pcie.link.width.current")),
|
|
693
|
+
}
|
|
694
|
+
clean = _clean_diag(diag)
|
|
695
|
+
clean["nvidia_smi"] = raw[:300]
|
|
696
|
+
return clean
|
|
697
|
+
|
|
698
|
+
|
|
699
|
+
def _query_nvidia_processes() -> list[dict]:
|
|
700
|
+
import subprocess
|
|
701
|
+
|
|
702
|
+
out = subprocess.run(
|
|
703
|
+
[
|
|
704
|
+
"nvidia-smi",
|
|
705
|
+
"--query-compute-apps=pid,process_name,used_memory",
|
|
706
|
+
"--format=csv,noheader,nounits",
|
|
707
|
+
],
|
|
708
|
+
capture_output=True,
|
|
709
|
+
text=True,
|
|
710
|
+
timeout=8.0, # nvidia-smi diag timeout (fixed; flash is fully managed)
|
|
711
|
+
)
|
|
712
|
+
if out.returncode != 0 or not out.stdout.strip():
|
|
713
|
+
return []
|
|
714
|
+
rows = []
|
|
715
|
+
for row in csv.reader(out.stdout.splitlines()):
|
|
716
|
+
if len(row) < 3:
|
|
717
|
+
continue
|
|
718
|
+
rows.append(
|
|
719
|
+
_clean_diag(
|
|
720
|
+
{
|
|
721
|
+
"pid": _int_or_none(row[0]),
|
|
722
|
+
"process_name": row[1].strip(),
|
|
723
|
+
"used_memory_gb": _round_gb_from_mib(row[2]),
|
|
724
|
+
}
|
|
725
|
+
)
|
|
726
|
+
)
|
|
727
|
+
return sorted(rows, key=lambda r: float(r.get("used_memory_gb") or 0.0), reverse=True)[:8]
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
def gpu_diagnostics(include_torch: bool = True) -> dict:
|
|
731
|
+
"""Collect live CUDA/GPU telemetry for run logs and status."""
|
|
732
|
+
diag = {}
|
|
733
|
+
if include_torch:
|
|
734
|
+
try:
|
|
735
|
+
import torch
|
|
736
|
+
|
|
737
|
+
diag["torch"] = torch.__version__
|
|
738
|
+
diag["torch_cuda"] = torch.version.cuda
|
|
739
|
+
diag["cuda_available"] = torch.cuda.is_available()
|
|
740
|
+
try:
|
|
741
|
+
diag["device_count"] = torch.cuda.device_count()
|
|
742
|
+
if torch.cuda.is_available():
|
|
743
|
+
diag["device_name"] = torch.cuda.get_device_name(0)
|
|
744
|
+
free, total = torch.cuda.mem_get_info()
|
|
745
|
+
diag["torch_memory_free_gb"] = round(free / (1024**3), 3)
|
|
746
|
+
diag["torch_memory_total_gb"] = round(total / (1024**3), 3)
|
|
747
|
+
diag["torch_memory_allocated_gb"] = round(
|
|
748
|
+
torch.cuda.memory_allocated() / (1024**3), 3
|
|
749
|
+
)
|
|
750
|
+
diag["torch_memory_reserved_gb"] = round(
|
|
751
|
+
torch.cuda.memory_reserved() / (1024**3), 3
|
|
752
|
+
)
|
|
753
|
+
except Exception as e:
|
|
754
|
+
diag["device_query_err"] = str(e)[:160]
|
|
755
|
+
except Exception as e:
|
|
756
|
+
diag["torch_import_err"] = str(e)[:160]
|
|
757
|
+
try:
|
|
758
|
+
diag.update(_query_nvidia_gpu())
|
|
759
|
+
processes = _query_nvidia_processes()
|
|
760
|
+
if processes:
|
|
761
|
+
diag["processes"] = processes
|
|
762
|
+
except Exception as e:
|
|
763
|
+
diag["nvidia_smi_err"] = str(e)[:160]
|
|
764
|
+
return _clean_diag(diag)
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
# Human-readable sentinel embedded in the error message (debug tag only — the runner classifies
|
|
768
|
+
# structurally off the worker's heartbeat ``retriable`` flag, not by matching this phrase).
|
|
769
|
+
RETRIABLE_INFRA_MARKER = "RETRIABLE_INFRA_GPU"
|
|
770
|
+
|
|
771
|
+
|
|
772
|
+
class RetriableInfraError(RuntimeError):
|
|
773
|
+
"""An infrastructure failure the control plane should RETRY on a fresh worker.
|
|
774
|
+
|
|
775
|
+
Raised for a host the run can never train on — e.g. a GPU that never comes up
|
|
776
|
+
(``wait_for_gpu`` times out) or a required-upload failure. The worker's top-level handler
|
|
777
|
+
stamps ``retriable=True`` into heartbeat.json so the runner retries on a fresh worker.
|
|
778
|
+
"""
|
|
779
|
+
|
|
780
|
+
def __init__(self, reason: str):
|
|
781
|
+
super().__init__(f"{RETRIABLE_INFRA_MARKER}: {reason}")
|
|
782
|
+
|
|
783
|
+
|
|
784
|
+
def detect_mig_slice() -> str | None:
|
|
785
|
+
"""Return a reason string if this worker was handed a MIG slice (a partitioned GPU), else None.
|
|
786
|
+
|
|
787
|
+
A MIG slice NVML-asserts PyTorch's CUDA allocator — observed when a provider substitutes a
|
|
788
|
+
requested GPU type with a Blackwell MIG slice — which crashes the run with an opaque allocator
|
|
789
|
+
assert partway through setup. Detect it up front (via nvidia-smi, before the first real CUDA op)
|
|
790
|
+
so the worker can fail fast as RETRIABLE infra and the runner re-provisions a fresh FULL GPU,
|
|
791
|
+
rather than letting the run die mid-setup. Best-effort: never raises (a missing/odd nvidia-smi
|
|
792
|
+
just means "no MIG detected", which the subsequent CUDA readiness probe still guards)."""
|
|
793
|
+
import subprocess
|
|
794
|
+
|
|
795
|
+
try:
|
|
796
|
+
out = subprocess.run(
|
|
797
|
+
["nvidia-smi", "-L"], capture_output=True, text=True, timeout=20
|
|
798
|
+
).stdout
|
|
799
|
+
# A MIG slice appears as a nested device line, e.g.
|
|
800
|
+
# " MIG 1g.10gb Device 0: (UUID: MIG-xxxx)" (or any "UUID: MIG-..." entry).
|
|
801
|
+
for line in out.splitlines():
|
|
802
|
+
s = line.strip()
|
|
803
|
+
if s.startswith("MIG ") or "UUID: MIG-" in s:
|
|
804
|
+
return f"MIG slice detected (nvidia-smi -L: {s[:120]!r})"
|
|
805
|
+
except Exception:
|
|
806
|
+
pass
|
|
807
|
+
try:
|
|
808
|
+
q = subprocess.run(
|
|
809
|
+
["nvidia-smi", "--query-gpu=mig.mode.current", "--format=csv,noheader"],
|
|
810
|
+
capture_output=True, text=True, timeout=20,
|
|
811
|
+
).stdout.strip()
|
|
812
|
+
# "Enabled" => the assigned GPU is partitioned into MIG instances (no full-GPU access).
|
|
813
|
+
# "Disabled"/"N/A"/"[Not Supported]" (consumer + MIG-incapable cards) => fine.
|
|
814
|
+
if q and "enabled" in q.lower():
|
|
815
|
+
return f"MIG mode enabled on the assigned GPU (mig.mode.current={q!r})"
|
|
816
|
+
except Exception:
|
|
817
|
+
pass
|
|
818
|
+
return None
|
|
819
|
+
|
|
820
|
+
|
|
821
|
+
def wait_for_gpu():
|
|
822
|
+
"""Rented nodes sometimes report 'CUDA device not ready' transiently at startup.
|
|
823
|
+
Poll a trivial CUDA op until it succeeds before doing real work; raise if never ready.
|
|
824
|
+
|
|
825
|
+
Also fails fast (retriable) if the assigned GPU is a MIG slice — a partitioned GPU crashes the
|
|
826
|
+
CUDA allocator, so we re-provision on a fresh FULL GPU instead of dying mid-setup."""
|
|
827
|
+
import time as _t
|
|
828
|
+
|
|
829
|
+
mig = detect_mig_slice()
|
|
830
|
+
if mig:
|
|
831
|
+
# Infra-shaped: a MIG slice can never train this workload -> retry on a fresh full GPU.
|
|
832
|
+
raise RetriableInfraError(f"{mig}; retrying on a fresh full (non-MIG) GPU")
|
|
833
|
+
|
|
834
|
+
last = None
|
|
835
|
+
for i in range(12):
|
|
836
|
+
try:
|
|
837
|
+
import torch
|
|
838
|
+
|
|
839
|
+
if torch.cuda.is_available():
|
|
840
|
+
# Force an actual kernel launch (alloc + add) to confirm the GPU is live.
|
|
841
|
+
_ = torch.zeros(8, device="cuda") + 1
|
|
842
|
+
torch.cuda.synchronize()
|
|
843
|
+
print(f"GPU ready after {i} retries: {torch.cuda.get_device_name(0)}")
|
|
844
|
+
return True
|
|
845
|
+
last = "cuda not available"
|
|
846
|
+
except Exception as e:
|
|
847
|
+
last = str(e)[:160]
|
|
848
|
+
print(f"GPU not ready (try {i + 1}/12): {last}; sleeping 10s")
|
|
849
|
+
_t.sleep(10)
|
|
850
|
+
# Infra-shaped: a host whose GPU never comes up is dead, not a code bug -> retry on a fresh one.
|
|
851
|
+
raise RetriableInfraError(f"GPU never became ready after 12 tries: {last}")
|
|
852
|
+
|
|
853
|
+
|
|
854
|
+
def free_gpu(trainer=None):
|
|
855
|
+
try:
|
|
856
|
+
import gc
|
|
857
|
+
|
|
858
|
+
import torch
|
|
859
|
+
|
|
860
|
+
try:
|
|
861
|
+
if trainer is not None and hasattr(trainer, "model"):
|
|
862
|
+
trainer.model = None
|
|
863
|
+
except Exception:
|
|
864
|
+
# Best-effort VRAM release before gc; any failure here is non-fatal cleanup.
|
|
865
|
+
pass
|
|
866
|
+
gc.collect()
|
|
867
|
+
if torch.cuda.is_available():
|
|
868
|
+
torch.cuda.empty_cache()
|
|
869
|
+
except Exception as e:
|
|
870
|
+
print("free_gpu warn:", e)
|
|
871
|
+
|
|
872
|
+
|
|
873
|
+
def _metric_curve(trainer, key: str) -> list:
|
|
874
|
+
"""The logged values of `key` (e.g. 'loss' or 'reward') from the trainer's log history,
|
|
875
|
+
rounded + capped. Lets metrics.json carry the convergence/reward curve for an A/B without
|
|
876
|
+
relying on a checkpoint's trainer_state.json (only written on save_steps) or the console
|
|
877
|
+
(only uploaded on failure). Never raises."""
|
|
878
|
+
try:
|
|
879
|
+
vals = [round(float(h[key]), 4) for h in trainer.state.log_history if key in h]
|
|
880
|
+
return vals[:400]
|
|
881
|
+
except Exception:
|
|
882
|
+
return []
|
|
883
|
+
|
|
884
|
+
|
|
885
|
+
def _ensure_fla_fastpath_on_hopper() -> None:
|
|
886
|
+
"""Make flash-linear-attention's GatedDeltaNet fast path CORRECT + fast on Hopper (sm90)
|
|
887
|
+
instead of dropping it.
|
|
888
|
+
|
|
889
|
+
fla's gated chunk_bwd Triton kernel is miscomputed on Hopper with Triton>=3.4 and HARD-RAISES
|
|
890
|
+
(fla #640). The worker historically DROPPED fla here and fell back to the pure-PyTorch delta
|
|
891
|
+
rule — correct but slow + memory-heavy. The real fix is fla's **tilelang** backend, which is
|
|
892
|
+
correct on Triton>=3.4. So on Hopper we ensure the working stack is present rather than
|
|
893
|
+
removing fla:
|
|
894
|
+
* the pinned ``tilelang==0.1.11`` (the correct GDN chunk_bwd backend) + the pinned
|
|
895
|
+
``apache-tvm-ffi==0.1.11`` (0.1.12 double-registers the TVM-FFI runtime -> ``import
|
|
896
|
+
tilelang`` aborts; and tilelang's own ``apache-tvm-ffi~=0.1.0`` range would let 0.1.12
|
|
897
|
+
back in, so the pin is force-installed last and its resolved version is verified), and
|
|
898
|
+
* a COMPLETE ``fla`` (the PyPI ``flash-linear-attention`` wheel is a broken stub missing
|
|
899
|
+
``fla.modules``; reinstall from git if the resident copy is incomplete).
|
|
900
|
+
Validated A/B (H100 SXM, Qwen3.5 hidden-2560 LoRA, controlled fla on/off): seq4096 435->105
|
|
901
|
+
ms/step & 9.9->6.1 GB (4.2x / 1.6x); seq8192 7.1x; seq16384 3106->247 ms & 32->17 GB (12.6x /
|
|
902
|
+
1.9x). Forward loss matches the torch delta to 1.8e-4 (correct). Runs in the worker process,
|
|
903
|
+
after all installs and BEFORE any model import. Non-Hopper:
|
|
904
|
+
no-op (fla's Triton kernel is correct there). Best-effort + FAIL-CLOSED: a failed install
|
|
905
|
+
(pip rc!=0), a missing module, or the wrong resolved ``apache-tvm-ffi`` version all flip the
|
|
906
|
+
gate off and DISABLE fla, leaving the (correct) pure-PyTorch delta rule in place — a worker
|
|
907
|
+
never crashes on a dep hiccup, and it never silently runs fla's broken Hopper GDN kernel.
|
|
908
|
+
"""
|
|
909
|
+
import importlib
|
|
910
|
+
import importlib.util
|
|
911
|
+
import subprocess
|
|
912
|
+
|
|
913
|
+
try:
|
|
914
|
+
import torch
|
|
915
|
+
|
|
916
|
+
if not (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 9):
|
|
917
|
+
return # not Hopper: fla's Triton kernel is correct here.
|
|
918
|
+
except Exception:
|
|
919
|
+
return
|
|
920
|
+
|
|
921
|
+
def _have(mod: str) -> bool:
|
|
922
|
+
try:
|
|
923
|
+
return importlib.util.find_spec(mod) is not None
|
|
924
|
+
except Exception:
|
|
925
|
+
return False
|
|
926
|
+
|
|
927
|
+
def _ver(dist: str) -> str | None:
|
|
928
|
+
"""Installed version of a distribution (by metadata), or None if absent/unreadable.
|
|
929
|
+
|
|
930
|
+
Distinct from _have (a find_spec import probe): the install can silently leave the WRONG
|
|
931
|
+
version resolved (e.g. tilelang's ``apache-tvm-ffi~=0.1.0`` range happily keeps 0.1.12,
|
|
932
|
+
which still find_spec-imports but aborts ``import tilelang``), so the gate must check the
|
|
933
|
+
actual installed version, not just importability.
|
|
934
|
+
"""
|
|
935
|
+
try:
|
|
936
|
+
import importlib.metadata as _md
|
|
937
|
+
|
|
938
|
+
return _md.version(dist)
|
|
939
|
+
except Exception:
|
|
940
|
+
return None
|
|
941
|
+
|
|
942
|
+
def _pip(*args: str) -> bool:
|
|
943
|
+
"""Run pip install; return True only if pip exited 0. A failed install (network/build/
|
|
944
|
+
resolver) must NOT be silently treated as success — the caller gates ``ok`` on this."""
|
|
945
|
+
try:
|
|
946
|
+
rc = subprocess.run(
|
|
947
|
+
[sys.executable, "-m", "pip", "install", "-q", *args], check=False
|
|
948
|
+
).returncode
|
|
949
|
+
except Exception:
|
|
950
|
+
return False
|
|
951
|
+
return rc == 0
|
|
952
|
+
|
|
953
|
+
# The exact apache-tvm-ffi pin the tilelang backend needs (0.1.12 double-registers the TVM-FFI
|
|
954
|
+
# runtime -> `import tilelang` aborts). Kept as a constant so the install spec and the post-
|
|
955
|
+
# install version gate below can't drift apart. Keep in lockstep with WORKER_DEPS / Dockerfile.
|
|
956
|
+
TVM_FFI_PIN = "0.1.11"
|
|
957
|
+
TILELANG_PIN = "0.1.11" # pin the GDN backend too (same rationale as the fla SHA pin)
|
|
958
|
+
|
|
959
|
+
try:
|
|
960
|
+
# 1. tilelang backend (correct GDN chunk_bwd on Triton>=3.4) + the pinned tvm-ffi.
|
|
961
|
+
# Track whether each install actually succeeded — a failed pip (rc!=0) must flip the
|
|
962
|
+
# gate to the pure-PyTorch fallback rather than be ignored. (_have-only would also pass
|
|
963
|
+
# on a stale/partial copy from a previous boot.) tilelang pulls apache-tvm-ffi via a
|
|
964
|
+
# range that allows the broken 0.1.12, so force-reinstall the exact pin AFTER tilelang
|
|
965
|
+
# and verify the resolved version below.
|
|
966
|
+
# Enforce the EXACT pin: (re)install when tilelang is absent OR a different version is
|
|
967
|
+
# resident (a job or the base image may carry another tilelang; _have-only would treat that
|
|
968
|
+
# as healthy and skip the install, leaving the wrong/uncertain GDN backend in place). Mirror
|
|
969
|
+
# the apache-tvm-ffi handling: check the installed version via _ver and reinstall on mismatch.
|
|
970
|
+
tilelang_ok = True
|
|
971
|
+
tilelang_reinstalled = False
|
|
972
|
+
if _ver("tilelang") != TILELANG_PIN:
|
|
973
|
+
tilelang_ok = _pip(f"tilelang=={TILELANG_PIN}")
|
|
974
|
+
tilelang_reinstalled = True
|
|
975
|
+
# Only force the tvm-ffi pin when it's actually wrong OR tilelang was just (re)installed
|
|
976
|
+
# (tilelang's apache-tvm-ffi~=0.1.0 range can have pulled the broken 0.1.12). Skipping the pip
|
|
977
|
+
# when the exact pin is already resident avoids avoidable cold-start latency and a spurious
|
|
978
|
+
# disable on a transient network/resolver failure — the ok gate still re-verifies the version.
|
|
979
|
+
# If this install runs and fails we DON'T trust the resident copy — tvm_ffi_ok gates `ok` below.
|
|
980
|
+
if _ver("apache-tvm-ffi") != TVM_FFI_PIN or tilelang_reinstalled:
|
|
981
|
+
tvm_ffi_ok = _pip(f"apache-tvm-ffi=={TVM_FFI_PIN}")
|
|
982
|
+
else:
|
|
983
|
+
tvm_ffi_ok = True
|
|
984
|
+
# 2. a COMPLETE fla — the PyPI wheel ships a stub without `fla.modules`. Reinstall from git
|
|
985
|
+
# when the resident copy is missing the real package (or absent entirely).
|
|
986
|
+
fla_ok = True
|
|
987
|
+
if not (_have("fla") and _have("fla.modules")):
|
|
988
|
+
_remove_fla_from_disk() # clear any broken stub before the git install
|
|
989
|
+
# Pinned to the same commit as WORKER_DEPS / Dockerfile.worker so a runtime reinstall is
|
|
990
|
+
# reproducible (the moving default branch could pull a broken/incompatible fla).
|
|
991
|
+
fla_ok = _pip(
|
|
992
|
+
"--no-deps",
|
|
993
|
+
"git+https://github.com/fla-org/flash-linear-attention.git"
|
|
994
|
+
"@f0e213dbd8b5fb90c3c7eca869ac1706d5377139",
|
|
995
|
+
)
|
|
996
|
+
importlib.invalidate_caches()
|
|
997
|
+
# Gate on BOTH (a) every install we ran exiting 0 — a failed pip (network/build/resolver)
|
|
998
|
+
# must NOT be treated as healthy just because a stale/partial copy still find_spec-imports —
|
|
999
|
+
# AND (b) the modules importing AND (c) the resolved apache-tvm-ffi being exactly the pin.
|
|
1000
|
+
# (c) matters because tilelang depends on `apache-tvm-ffi~=0.1.0`, so the resolver can keep
|
|
1001
|
+
# the broken 0.1.12 (which find_spec-imports fine but aborts `import tilelang`); checking the
|
|
1002
|
+
# version is the only reliable signal the pin actually landed.
|
|
1003
|
+
tvm_ffi_ver = _ver("apache-tvm-ffi")
|
|
1004
|
+
tilelang_ver = _ver("tilelang")
|
|
1005
|
+
installs_ok = tilelang_ok and tvm_ffi_ok and fla_ok
|
|
1006
|
+
ok = (
|
|
1007
|
+
installs_ok
|
|
1008
|
+
and _have("fla")
|
|
1009
|
+
and _have("fla.modules")
|
|
1010
|
+
and _have("tilelang")
|
|
1011
|
+
and tilelang_ver == TILELANG_PIN
|
|
1012
|
+
and tvm_ffi_ver == TVM_FFI_PIN
|
|
1013
|
+
)
|
|
1014
|
+
if not ok:
|
|
1015
|
+
# The healthy fla+tilelang stack could not be assembled, so fla's GDN chunk_bwd would
|
|
1016
|
+
# still hit the broken Triton>=3.4 path on Hopper (fla #640) and HARD-RAISE. A print
|
|
1017
|
+
# alone does NOT prevent that: transformers gates GDN on is_fla_available() (a
|
|
1018
|
+
# find_spec('fla') probe), so as long as fla stays importable it gets engaged. PHYSICALLY
|
|
1019
|
+
# remove fla so the probe sees it gone and transformers uses the correct pure-PyTorch
|
|
1020
|
+
# delta rule instead of crashing. _remove_fla_from_disk loops over the real sys.path +
|
|
1021
|
+
# invalidates caches, so find_spec('fla') is None afterwards (the gate flips off).
|
|
1022
|
+
_removed, _still = _remove_fla_from_disk()
|
|
1023
|
+
print(
|
|
1024
|
+
"[hopper] fla GDN fast path unavailable -> DISABLING fla "
|
|
1025
|
+
f"(installs_ok={installs_ok} [tilelang={tilelang_ok} tvm_ffi={tvm_ffi_ok} "
|
|
1026
|
+
f"fla={fla_ok}], tilelang_ver={tilelang_ver!r} (want {TILELANG_PIN}), "
|
|
1027
|
+
f"tvm_ffi_ver={tvm_ffi_ver!r} (want {TVM_FFI_PIN}); "
|
|
1028
|
+
f"removed {len(_removed)} copy(ies); still_importable={_still}); "
|
|
1029
|
+
"pure-PyTorch delta fallback",
|
|
1030
|
+
flush=True,
|
|
1031
|
+
)
|
|
1032
|
+
else:
|
|
1033
|
+
print(
|
|
1034
|
+
"[hopper] fla GDN fast path ENABLED (fla+tilelang "
|
|
1035
|
+
f"{tilelang_ver}/tvm-ffi {tvm_ffi_ver}, fla #640 fixed)",
|
|
1036
|
+
flush=True,
|
|
1037
|
+
)
|
|
1038
|
+
except Exception as e: # never let a dep hiccup crash the worker — torch delta still runs
|
|
1039
|
+
# Fail-closed: an unexpected error mid-setup must still leave Hopper on the correct
|
|
1040
|
+
# pure-PyTorch delta path, not a half-configured fla that transformers would engage and
|
|
1041
|
+
# crash on (#640). Best-effort disable fla; never re-raise.
|
|
1042
|
+
with contextlib.suppress(Exception):
|
|
1043
|
+
_remove_fla_from_disk()
|
|
1044
|
+
print(
|
|
1045
|
+
f"[hopper] fla fast-path setup errored ({type(e).__name__}: {e}); "
|
|
1046
|
+
"disabled fla -> pure-PyTorch delta",
|
|
1047
|
+
flush=True,
|
|
1048
|
+
)
|