freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,2916 @@
1
+ """On-GPU fine-tuning worker (RunPod). Modes: sft | rl.
2
+
3
+ This module runs on the provisioned RunPod GPU. It uses the shared recipe
4
+ (``flash.engine.recipe``) so SFT targets and RL rewards are rendered and scored
5
+ consistently.
6
+
7
+ Artifacts (adapter, metrics.json, heartbeat.json, checkpoints) are streamed to a
8
+ Hugging Face dataset repo. HF checkpoints give preemption resilience: if a worker is
9
+ recycled mid-run we resume from the latest uploaded checkpoint. Metrics are also
10
+ returned directly to the caller by the launching provider.
11
+
12
+ Core environment variables (set by the launching provider / runner):
13
+ RUN_MODE sft|rl
14
+ SEED int
15
+ HF_REPO Hugging Face dataset repo for artifacts, populated per-run from the
16
+ JobSpec's [train] hf_repo by whichever provider launches the worker
17
+ HF_TOKEN
18
+ RUN_ID unique id for this run (namespacing in the repo)
19
+
20
+ The FLASH_*/RL_*/SFT_* env vars are A/B overrides documented at their use sites; the
21
+ JobSpec [train] table is the source of truth for per-run knobs.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import contextlib
27
+ import faulthandler
28
+ import json
29
+ import math
30
+ import os
31
+ import random
32
+ import re
33
+ import sys
34
+ import tempfile
35
+ import threading
36
+ import time
37
+ import traceback
38
+
39
+ from flash.engine.accounting import RunMetrics
40
+
41
+ # Shared, substrate-neutral fine-tuning internals (live in this same package).
42
+ from flash.engine.chalk_kernels import active_kernels, install_chalk_kernels
43
+ from flash.engine.recipe import RECIPE
44
+
45
+ # Re-export the pure helpers split into the leaf submodules ``.perf`` and ``.lora``.
46
+ # CRITICAL: the readers below (run_sft / run_rl / make_lora / _init_adapter_model / ...) call
47
+ # these by their bare name, which resolves through THIS module's namespace — so a test's
48
+ # ``monkeypatch.setattr(worker, "<name>", ...)`` still reaches the readers. Names actually used
49
+ # by the retained readers are imported plainly; names re-exported only for API / test access
50
+ # (no retained reader uses them) are marked unused for the linter.
51
+ from flash.engine.worker.lora import (
52
+ _LM_SYNC_REMAP_ON,
53
+ _VL_EXCLUDE_SEGMENTS, # noqa: F401
54
+ _patch_peft_weight_converter_compat, # noqa: F401
55
+ _remap_vl_sync_weights, # noqa: F401
56
+ assert_adapter_delta_nonzero,
57
+ assert_adapter_load_clean,
58
+ assert_lora_applied,
59
+ disable_liger_grpo_torch_compile,
60
+ is_vl_checkpoint,
61
+ lora_exclude_modules,
62
+ model_quant, # noqa: F401
63
+ patch_grpo_mask_aware_lm_head,
64
+ patch_vllm_language_model_only,
65
+ patch_vllm_lm_weight_sync,
66
+ remap_adapter_keys, # noqa: F401
67
+ remap_vl_adapter_dir,
68
+ strip_language_model_infix, # noqa: F401
69
+ vllm_language_model_only_kwargs, # noqa: F401
70
+ )
71
+ from flash.engine.worker.packing import (
72
+ BlockDiagonalCollator,
73
+ gdn_packing_available,
74
+ model_is_gdn_hybrid,
75
+ model_is_pure_attention,
76
+ pack_token_ids,
77
+ packing_efficiency,
78
+ tokenize_for_packing,
79
+ )
80
+ from flash.engine.worker.perf import (
81
+ RetriableInfraError,
82
+ _attn_impl_for_capability, # noqa: F401
83
+ _ensure_fla_fastpath_on_hopper,
84
+ _estimate_params, # noqa: F401
85
+ _flash_attn_3_available, # noqa: F401
86
+ _flash_attn_available,
87
+ _GpuPeakSampler,
88
+ _liger_default_for_model, # noqa: F401
89
+ _memory_mode,
90
+ _metric_curve,
91
+ _neutralize_tilelang_cudart_stub,
92
+ _peak_gpu_gb,
93
+ _remove_fla_from_disk, # noqa: F401
94
+ _reset_peak_gpu,
95
+ _sdpa_cudnn_ctx,
96
+ free_gpu,
97
+ fused_optim_name,
98
+ gpu_diagnostics,
99
+ grad_checkpointing_on,
100
+ grpo_sleep_mode,
101
+ liger_on,
102
+ loraplus_optimizer_cls,
103
+ optimal_attn_impl,
104
+ setup_perf_backends,
105
+ wait_for_gpu,
106
+ )
107
+ from flash.envs.adapter import GitHubRateLimitError
108
+ from flash.envs.registry import load_environment
109
+ from flash.spec import load_job_spec_from_env
110
+
111
# Run identity/config, read once at import time from the process environment.
HF_REPO = os.getenv("HF_REPO", "")
RUN_ID = os.getenv("RUN_ID", "local")
SEED = int(os.getenv("SEED", "0"))
RUN_MODE = os.getenv("RUN_MODE", "sft")
ATTEMPT = os.getenv("ATTEMPT", "")
# None when no JobSpec is supplied; the consumers below all guard on that.
JOB_SPEC = load_job_spec_from_env()
# PHASE is the stable artifact namespace (sft|rl) and matches RUN_MODE for a train run.
# An explicit PHASE env var wins; otherwise the JobSpec's phase, then RUN_MODE when it is a
# known phase, and finally "sft".
PHASE = os.getenv(
    "PHASE",
    JOB_SPEC.phase if JOB_SPEC else (RUN_MODE if RUN_MODE in ("sft", "rl") else "sft"),
)
122
+
123
+
124
def _load_active_env():
    """Load the Freesolo environment named by the run's JobSpec.

    Returns ``None`` when no JobSpec was loaded (``JOB_SPEC`` is None), which keeps the
    module importable outside a real run (e.g. a unit test). When a JobSpec is present it
    MUST name an environment: there is no default/builtin fallback, so an empty id raises
    rather than letting a paid worker train/evaluate the wrong task.

    Raises:
        RuntimeError: a JobSpec is present but its ``[environment] id`` is empty.
    """
    spec = JOB_SPEC
    if spec is None:
        # Nothing to select; the hard requirement is enforced by the branch below (and by
        # ``require_active_env``) once a JobSpec exists.
        return None

    environment = spec.environment
    if environment.id:
        # Forward the control-plane-pinned commit sha (resolve-once hook) when present so the
        # adapter skips the GitHub ref->sha resolve; "" (the default) keeps the worker
        # resolving it itself.
        return load_environment(
            environment.id,
            environment.params,
            resolved_sha=environment.resolved_sha,
        )

    # Every supported algorithm (sft/grpo) runs against a Freesolo env, so a missing env is
    # always a misconfigured spec. Fail loudly instead of falling back to a default.
    raise RuntimeError(
        "JobSpec sets no environment: provide [environment] id "
        "(a Freesolo environment id like 'your-name/your-env', returned by "
        "`flash env push --name <name>`)."
    )
151
+
152
+
153
+ ACTIVE_ENV = None
154
+
155
+
156
def require_active_env():
    """Return the run's environment, loading it on first use; raise when there is none.

    ``ACTIVE_ENV`` starts as None and is filled lazily from ``_load_active_env()``. That
    loader returns None on the no-JobSpec path, and every train/eval consumer needs a real
    env, so this guard turns what would otherwise be an opaque
    ``AttributeError: 'NoneType' object has no attribute ...`` into an actionable error.

    Raises:
        RuntimeError: no environment could be loaded because there is no JobSpec.
    """
    global ACTIVE_ENV
    env = ACTIVE_ENV
    if env is None:
        # First call (or a previous call found nothing): try to load and cache it.
        env = ACTIVE_ENV = _load_active_env()
    if env is not None:
        return env
    raise RuntimeError(
        "no environment is loaded: this worker was started without a JobSpec "
        "(FLASH_JOB_SPEC_JSON / FLASH_JOB_SPEC_PATH is unset). A train/eval run must "
        "carry a JobSpec naming [environment] id "
        "(a Freesolo environment id like 'your-name/your-env', returned by "
        "`flash env push --name <name>`)."
    )
178
+
179
+
180
+ # Thinking/reasoning mode: one flag per run from the run config (TOML `thinking`), consumed
181
+ # identically by SFT rendering, RL rollouts, and serving. Defaults off without a JobSpec.
182
+ THINKING = JOB_SPEC.thinking if JOB_SPEC else False
183
+
184
+
185
+ # ---------------------------------------------------------------------------
186
+ # HF helpers (code-delivery + artifact channel; works without inbound network)
187
+ # ---------------------------------------------------------------------------
188
def error_artifact_name(mode: str) -> str:
    """Return the per-mode error artifact filename, e.g. ``error_sft.txt``.

    Gives a run's traceback a stable upload name per mode (heartbeat.json, by contrast,
    is a single file that is overwritten).
    """
    return "error_{}.txt".format(mode)
192
+
193
+
194
def hf_api():
    """Build a Hugging Face Hub client authenticated with ``HF_TOKEN`` (None if unset).

    ``huggingface_hub`` is imported lazily so importing this module does not require it.
    """
    import huggingface_hub

    token = os.environ.get("HF_TOKEN")
    return huggingface_hub.HfApi(token=token)
198
+
199
+
200
def hf_prefix() -> str:
    """Return this run's artifact path prefix in the HF repo: ``<PHASE>/<RUN_ID>/seed<SEED>``."""
    return "{}/{}/seed{}".format(PHASE, RUN_ID, SEED)
202
+
203
+
204
def _hf_upload(do_upload, repo_subpath: str, required: bool, label: str) -> None:
    """Run one HF upload callable with the shared guard and retry policy.

    No-op when ``HF_REPO`` is empty. A ``required`` upload (completion artifacts such as
    DONE/metrics.json, the trained adapter) gets three attempts with a growing pause
    (5 s, then 10 s) and finally raises, because a swallowed failure would leave the control
    plane with a finished run it cannot complete or deploy. An optional upload
    (generations, logs) is tried once and only warns.

    Args:
        do_upload: zero-argument callable performing the upload.
        repo_subpath: path below the run prefix; used in the error message only.
        required: whether failure must be retried and then raised.
        label: prefix for the retry/warn log lines.

    Raises:
        RetriableInfraError: a required upload failed on every attempt.
    """
    if not HF_REPO:
        return
    attempts = 3 if required else 1
    attempt = 0
    while attempt < attempts:
        attempt += 1  # 1-based attempt number
        try:
            do_upload()
        except Exception as e:
            if not required:
                print(f"{label} warn:", e)
                return
            if attempt >= attempts:
                # Already retried 3x -> the host/network is bad, not the run. Infra-shaped.
                raise RetriableInfraError(f"required upload of {repo_subpath!r} failed: {e}") from e
            print(f"{label} retry {attempt}/{attempts}: {e}")
            time.sleep(5 * attempt)
        else:
            return
229
+
230
+
231
def hf_upload_file(local_path: str, repo_subpath: str, required: bool = False):
    """Upload one local file to ``<hf_prefix()>/<repo_subpath>`` in the HF dataset repo.

    Guarding on ``HF_REPO`` and the retry/raise-or-warn policy are delegated to
    ``_hf_upload``; ``required`` selects that policy.
    """

    def _push() -> None:
        hf_api().upload_file(
            path_or_fileobj=local_path,
            path_in_repo=f"{hf_prefix()}/{repo_subpath}",
            repo_id=HF_REPO,
            repo_type="dataset",
        )

    _hf_upload(_push, repo_subpath, required, "hf_upload_file")
244
+
245
+
246
+ _DEBUG_UPLOAD_LOCK = threading.Lock()
247
+
248
+
249
def upload_debug_jsonl(name: str, rows: list[dict], *, keep_last: int = 200) -> None:
    """Append bounded JSONL debug rows and upload them as an optional artifact.

    The local file ``/tmp/<name>.jsonl`` is rewritten as the last ``keep_last`` previously
    stored lines followed by ``rows``, then uploaded under the same basename.
    This is intentionally best-effort: debug visibility must not fail a paid run, so every
    error is reported with a warning and swallowed.

    Args:
        name: artifact name; ``.jsonl`` is appended when missing and only the basename is used.
        rows: new rows to append; each is serialized with ``default=str``.
        keep_last: how many previously stored lines to retain. ``0`` (or a negative value)
            retains none.
    """
    if not rows or not HF_REPO:
        return
    repo_name = os.path.basename(name if name.endswith(".jsonl") else f"{name}.jsonl")
    path = os.path.join("/tmp", repo_name)
    try:
        # Serialize BEFORE the file is truncated: a row that cannot be encoded must not
        # cost us the rows that are already stored.
        new_lines = [
            json.dumps(row, default=str, ensure_ascii=True, sort_keys=True) + "\n" for row in rows
        ]
        # The lock is held across the upload so a concurrent caller cannot rewrite the
        # file while it is being read for upload.
        with _DEBUG_UPLOAD_LOCK:
            existing: list[str] = []
            # ``lines[-0:]`` is the WHOLE list, so keep_last <= 0 needs an explicit guard to
            # mean "retain nothing".
            if keep_last > 0:
                with contextlib.suppress(FileNotFoundError), open(path) as f:
                    existing = f.readlines()[-keep_last:]
            with open(path, "w") as f:
                f.writelines(existing)
                f.writelines(new_lines)
            hf_upload_file(path, repo_name)
    except Exception as e:
        print(f"debug upload warn ({repo_name}): {e}")
270
+
271
+
272
def hf_upload_folder(local_dir: str, repo_subpath: str, required: bool = False):
    """Upload a local folder to ``<hf_prefix()>/<repo_subpath>`` in the HF dataset repo.

    Guarding on ``HF_REPO`` and the retry/raise-or-warn policy are delegated to
    ``_hf_upload``; ``required`` selects that policy.
    """

    def _push() -> None:
        hf_api().upload_folder(
            folder_path=local_dir,
            path_in_repo=f"{hf_prefix()}/{repo_subpath}",
            repo_id=HF_REPO,
            repo_type="dataset",
        )

    _hf_upload(_push, repo_subpath, required, "hf_upload_folder")
285
+
286
+
287
def hf_resume_checkpoint() -> str | None:
    """Download this run's streamed trainer checkpoints and return the newest one.

    Checkpoints are uploaded DURING the run by ``make_checkpoint_upload_callback`` as
    ``<prefix>/checkpoint/checkpoint-<step>/``; a replacement worker downloads them into
    ``/tmp/resume`` and resumes from the highest step, so a mid-run preemption costs at most
    one save interval.

    Only directories named exactly ``checkpoint-<digits>`` are considered. A stray entry
    (e.g. ``checkpoint-tmp`` or a plain file) is ignored rather than aborting the whole
    lookup.

    Returns:
        Local path of the newest checkpoint directory, or ``None`` when ``HF_REPO`` is
        unset, nothing was found, or the download failed (a warning is printed).
    """
    if not HF_REPO:
        return None
    try:
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id=HF_REPO,
            repo_type="dataset",
            allow_patterns=[f"{hf_prefix()}/checkpoint/**"],
            local_dir="/tmp/resume",
            token=os.environ.get("HF_TOKEN"),
        )
        base = os.path.join("/tmp/resume", hf_prefix(), "checkpoint")
        if not os.path.isdir(base):
            return None
        # step -> directory name, for well-formed checkpoint directories only.
        by_step: dict[int, str] = {}
        for entry in os.listdir(base):
            match = re.fullmatch(r"checkpoint-(\d+)", entry)
            if match and os.path.isdir(os.path.join(base, entry)):
                by_step[int(match.group(1))] = entry
        if not by_step:
            return None
        path = os.path.join(base, by_step[max(by_step)])
        print(f"[resume] found streamed checkpoint: {path}")
        return path
    except Exception as e:
        # Best-effort: a failed resume lookup falls back to a fresh start.
        print("hf_resume_checkpoint warn:", e)
        return None
319
+
320
+
321
def prefetch_model(model_id: str) -> float:
    """Download the model into the local HF cache up front and return the seconds spent.

    Doing the download explicitly makes it a timed stage in the heartbeat stream
    (``model_prefetched``) and surfaces disk/network problems before trainer construction.
    A download error is only printed: gated/local-only models can still load through the
    normal ``from_pretrained`` path the trainer uses next. The heartbeat is emitted whether
    or not the download succeeded.
    """
    from huggingface_hub import snapshot_download

    # weights + tokenizer/config only (same exclusions as the image bake)
    skipped = ["*.pth", "*.gguf", "original/*", "*.onnx", "*.msgpack", "*.h5"]

    started = time.time()
    try:
        snapshot_download(repo_id=model_id, ignore_patterns=skipped)
    except Exception as e:
        print("prefetch_model warn:", e)
    elapsed = round(time.time() - started, 1)

    heartbeat(
        "model_prefetched",
        model=model_id,
        download_seconds=elapsed,
        hf_transfer=os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", ""),
        gpu=gpu_diagnostics(),
    )
    return elapsed
351
+
352
+
353
+ # Trainer-state files a serving engine never needs: optimizer/scheduler/rng/loss-curve
354
+ # state. Excluded when publishing the deployable per-step adapter so each step's snapshot is
355
+ # just the LoRA weights + config (a few MB), small enough to KEEP every step (no pruning).
356
+ _CHECKPOINT_TRAINER_STATE = (
357
+ "optimizer.pt",
358
+ "optimizer.bin",
359
+ "scheduler.pt",
360
+ "scaler.pt",
361
+ "rng_state*.pth",
362
+ "trainer_state.json",
363
+ "training_args.bin",
364
+ "*.distcp",
365
+ "global_step*/**",
366
+ "latest",
367
+ "zero_to_fp32.py",
368
+ )
369
+
370
+ # The PEFT adapter weights file a checkpoint must carry to be loadable/servable (safetensors is
371
+ # the default; .bin is the legacy fallback). A step with adapter_config.json but no weights is
372
+ # NOT deployable, so it's never published/listed.
373
+ _ADAPTER_WEIGHT_FILES = ("adapter_model.safetensors", "adapter_model.bin")
374
+
375
+
376
def publish_deployable_checkpoint(ckpt_dir: str, step: int) -> str | None:
    """Mirror a trainer checkpoint's LoRA adapter to a per-step path that is never pruned.

    Uploads ``ckpt_dir`` minus the trainer-state files (``_CHECKPOINT_TRAINER_STATE``) to
    ``<prefix>/checkpoints/step-<step>/adapter`` and emits a ``checkpoint_deployable``
    heartbeat, so a run cancelled mid-RL is still deployable from its last good step.
    Unlike the resume checkpoint under ``checkpoint/**``, these per-step copies accumulate.

    Returns:
        The repo subfolder holding the adapter, or ``None`` when ``HF_REPO`` is unset, the
        checkpoint lacks ``adapter_config.json`` or an adapter weights file, or the upload
        failed. Best-effort: a failure here never fails a paid run.
    """
    if not HF_REPO:
        return None

    # Only publish a checkpoint that actually carries a loadable adapter (config AND
    # weights) -- never advertise a non-deployable step.
    if not os.path.isfile(os.path.join(ckpt_dir, "adapter_config.json")):
        return None
    weight_paths = (os.path.join(ckpt_dir, fname) for fname in _ADAPTER_WEIGHT_FILES)
    if not any(os.path.isfile(candidate) for candidate in weight_paths):
        return None

    subfolder = f"{hf_prefix()}/checkpoints/step-{step}/adapter"
    try:
        hf_api().upload_folder(
            folder_path=ckpt_dir,
            path_in_repo=subfolder,
            repo_id=HF_REPO,
            repo_type="dataset",
            ignore_patterns=list(_CHECKPOINT_TRAINER_STATE),
        )
        heartbeat("checkpoint_deployable", step=step, subfolder=subfolder)
    except Exception as e:
        print(f"[ckpt] deployable publish warn (step {step}):", e)
        return None
    return subfolder
409
+
410
+
411
def make_checkpoint_upload_callback():
    """Build a ``TrainerCallback`` that streams each trainer save to HF.

    Streaming every save means a preemption loses at most one save interval. Uploads run
    in a background daemon thread, so the train loop never blocks on the network. If an
    upload is still in flight when the next save fires, that save is skipped (the
    following one catches up).

    The resume checkpoint is kept latest-only: each commit uploads
    ``<prefix>/checkpoint/checkpoint-<step>/`` and deletes the older ``checkpoint-*``
    folders. ``upload_folder`` matches ``delete_patterns`` against paths RELATIVE to
    ``path_in_repo``, so the upload is rooted at ``<prefix>/checkpoint`` with
    ``allow_patterns`` selecting the new step. Rooting it at the step folder with an
    absolute delete pattern would match nothing and let old checkpoints pile up.

    Each save also publishes a deployable per-step adapter snapshot
    (``publish_deployable_checkpoint``) so a run cancelled mid-RL can still be deployed
    from its latest step.
    """
    from transformers import TrainerCallback

    # Held from on_save until the background upload finishes: at most one upload in flight.
    lock = threading.Lock()

    class _CheckpointUpload(TrainerCallback):
        def on_save(self, args, state, control, **kwargs):
            if not HF_REPO:
                return
            step = int(state.global_step)
            ckpt_name = f"checkpoint-{step}"
            ckpt_dir = os.path.join(args.output_dir, ckpt_name)
            if not os.path.isdir(ckpt_dir):
                return
            if not lock.acquire(blocking=False):
                print(f"[ckpt] upload busy; skipping step {step}")
                return

            def _upload():
                try:
                    hf_api().upload_folder(
                        folder_path=args.output_dir,
                        path_in_repo=f"{hf_prefix()}/checkpoint",
                        repo_id=HF_REPO,
                        repo_type="dataset",
                        # Upload only this step's folder out of the trainer output dir...
                        allow_patterns=[f"{ckpt_name}/**"],
                        # ...and drop every older step (patterns are relative to path_in_repo).
                        delete_patterns=["checkpoint-*/**"],
                    )
                    heartbeat("checkpoint_uploaded", step=step)
                    # Mirror this step's adapter to its own kept-forever path so the run
                    # stays deployable even if it never reaches "done".
                    publish_deployable_checkpoint(ckpt_dir, step)
                except Exception as e:
                    print("ckpt upload warn:", e)
                finally:
                    lock.release()

            try:
                threading.Thread(target=_upload, daemon=True).start()
            except Exception as e:
                # The thread never ran, so its ``finally`` will not release the lock;
                # release it here or every later save would be skipped as "busy".
                lock.release()
                print("ckpt upload warn:", e)

    return _CheckpointUpload()
458
+
459
+
460
+ # Heartbeat HF-commit throttle. Each heartbeat() commits heartbeat.json to the HF artifact
461
+ # repo; committing every training step (the reward callback fires per step) blows HuggingFace's
462
+ # per-repo commit rate limit (128/hour), especially when several runs share one HF_REPO. Only
463
+ # the per-step "rl_step" stage is high-frequency, so throttle JUST that one to once per
464
+ # 60s; every other stage — including milestones and the terminal done/already_done — always
465
+ # commits so the control plane never misses a transition.
466
+ # The local file + stdout line are always written regardless.
467
+ _HB_LAST_UPLOAD = 0.0
468
+
469
+
470
+ # The rl_step heartbeat-upload throttle, in seconds (fixed 60s) — keeps GRPO under HF's
471
+ # 128 commits/hour-per-repo limit when concurrent runs share one HF_REPO.
472
+ _HB_MIN_INTERVAL_S = 60.0
473
+ _HB_THROTTLED_STAGES = frozenset({"rl_step"})
474
+ # Terminal transitions the control plane must never miss — always committed.
475
+ _HB_TERMINAL_STAGES = frozenset({"done", "already_done"})
476
+ _HB_TERMINAL_ONLY = False
477
+ # Even in terminal-only mode, emit a SLOW heartbeat at this cadence so the control plane's stall
478
+ # detector keeps seeing progress through a long training phase and doesn't false-stall the run.
479
+ # 600s -> ~6 commits/hr, far under the 128/hr cap.
480
+ _HB_TERMINAL_ONLY_INTERVAL_S = 600.0
481
+
482
+
483
+ # Serializes heartbeat.json writes and _HB_LAST_UPLOAD reads/updates. During GRPO,
484
+ # heartbeat() is called concurrently from the trainer thread (reward callback) and the
485
+ # checkpoint-upload daemon thread; without this lock two writers can interleave and
486
+ # truncate/garble heartbeat.json (and race _HB_LAST_UPLOAD).
487
+ _HB_LOCK = threading.Lock()
488
+ # Serializes the actual HF upload (a slow network commit) SEPARATELY from _HB_LOCK so the
489
+ # trainer's frequent local writes never block on the network. Without it, two heartbeat
490
+ # threads can upload heartbeat.json concurrently: a slower upload could land AFTER a newer
491
+ # one on HF (reorder). The lock keeps one commit in flight at a time; NOTE that it does not by itself order waiting uploads by snapshot age.
492
+ _HB_UPLOAD_LOCK = threading.Lock()
493
+
494
+ # Stall diagnostics: when FLASH_STALL_FAULTHANDLER_S > 0, arm a faulthandler watchdog that dumps
495
+ # every thread's Python stack (then exits, so the run FAILS instead of hanging until the
496
+ # control-plane stall watchdog kills it ~25 min later, and the dump is uploaded with
497
+ # console_<phase>.txt). The timer is re-armed on every heartbeat, so it only fires when NO progress
498
+ # heartbeat lands for the whole window -- i.e. a real hang. OFF by default (0); opt-in per run via
499
+ # [worker_env]. Used to localize the GRPO sleep-mode rollout hang.
500
# Stall-watchdog window in seconds, from FLASH_STALL_FAULTHANDLER_S. Unset, empty, or
# unparsable values all mean 0 (watchdog off).
try:
    _STALL_FAULTHANDLER_S = int(os.environ.get("FLASH_STALL_FAULTHANDLER_S", "0") or 0)
except Exception:
    _STALL_FAULTHANDLER_S = 0


def _rearm_stall_faulthandler() -> None:
    """Restart the faulthandler stall timer; no-op when the watchdog is disabled.

    Cancels any pending dump and schedules a new one ``_STALL_FAULTHANDLER_S`` seconds
    out with ``exit=True``. Any error while (re)arming is suppressed.
    """
    if _STALL_FAULTHANDLER_S > 0:
        with contextlib.suppress(Exception):
            faulthandler.cancel_dump_traceback_later()
            faulthandler.dump_traceback_later(_STALL_FAULTHANDLER_S, exit=True)
511
+
512
+
513
def heartbeat(stage: str, **kw):
    """Record a progress heartbeat locally, on stdout, and (when due) in the HF repo.

    The payload is ``stage``, a timestamp, and the run identity (run_id/mode/seed/attempt)
    plus any ``kw`` (which override the built-in keys). It is always written atomically to
    ``/tmp/hb/heartbeat.json`` and printed as a ``HEARTBEAT`` line. It is uploaded as
    ``heartbeat.json`` when:

    - the stage is terminal (``_HB_TERMINAL_STAGES``) or starts with ``error_``: always;
    - ``_HB_TERMINAL_ONLY`` is set: on the first upload, then every
      ``_HB_TERMINAL_ONLY_INTERVAL_S`` seconds;
    - otherwise: always, except stages in ``_HB_THROTTLED_STAGES``, which upload at most
      once per ``_HB_MIN_INTERVAL_S`` seconds.

    Every call also re-arms the stall watchdog.
    """
    global _HB_LAST_UPLOAD
    payload = {
        "stage": stage,
        "ts": time.time(),
        "run_id": RUN_ID,
        "mode": RUN_MODE,
        "seed": SEED,
        "attempt": ATTEMPT,
        **kw,
    }
    # Datacenter diagnostic: only added when RUNPOD_DC_ID is set and non-empty, and never
    # overrides a caller-supplied "dc".
    datacenter = os.environ.get("RUNPOD_DC_ID")
    if datacenter:
        payload.setdefault("dc", datacenter)

    os.makedirs("/tmp/hb", exist_ok=True)
    hb_path = "/tmp/hb/heartbeat.json"
    # Unique per process+thread so concurrent callers never share a scratch file.
    scratch_tag = f".{os.getpid()}.{threading.get_ident()}"

    # _HB_LOCK guards ONLY the fast local work (atomic write + _HB_LAST_UPLOAD + snapshot
    # capture); the slow HF commit runs outside it so the trainer's per-step callback never
    # waits on the network.
    with _HB_LOCK:
        tmp_path = hb_path + scratch_tag + ".tmp"
        snapshot = json.dumps(payload)
        with open(tmp_path, "w") as f:
            f.write(snapshot)
        # temp file + os.replace(): a concurrent reader never sees a partial file.
        os.replace(tmp_path, hb_path)

        now = time.time()
        since_last = now - _HB_LAST_UPLOAD
        if stage in _HB_TERMINAL_STAGES or stage.startswith("error_"):
            upload_due = True  # never miss a terminal transition
        elif _HB_TERMINAL_ONLY:
            # Slow keep-alive so the control-plane stall detector still sees progress.
            upload_due = _HB_LAST_UPLOAD == 0.0 or since_last >= _HB_TERMINAL_ONLY_INTERVAL_S
        elif stage in _HB_THROTTLED_STAGES:
            upload_due = since_last >= _HB_MIN_INTERVAL_S
        else:
            upload_due = True
        if upload_due:
            _HB_LAST_UPLOAD = now  # claim the slot under the lock (throttle stays atomic)

    if upload_due:
        # Upload the captured snapshot through a private temp file (hf_upload_file takes a
        # path) instead of re-reading hb_path, which a newer heartbeat may have overwritten.
        # NOTE(review): _HB_UPLOAD_LOCK allows one commit at a time, but threads may acquire
        # it in a different order than they claimed their slot -- confirm whether an older
        # snapshot landing last is acceptable.
        with _HB_UPLOAD_LOCK:
            upload_path = hb_path + scratch_tag + ".upload.tmp"
            with open(upload_path, "w") as f:
                f.write(snapshot)
            try:
                hf_upload_file(upload_path, "heartbeat.json")
            finally:
                with contextlib.suppress(OSError):
                    os.remove(upload_path)

    # Re-arm the stall watchdog: progress landed, so reset the no-heartbeat timer.
    _rearm_stall_faulthandler()
    print("HEARTBEAT", json.dumps(payload))
575
+
576
+
577
+ # ---------------------------------------------------------------------------
578
+ # Decoding parity: render with the model's own chat template and one run-wide thinking
579
+ # flag (off by default), so SFT targets and RL rollouts use identical prompt
580
+ # formatting within a run.
581
+ # ---------------------------------------------------------------------------
582
def render_prompt(tokenizer, item) -> str:
    """Render one dataset item into the prompt string via the tokenizer's chat template.

    A non-dict ``item`` is wrapped as ``{"question": item}``. The active environment turns
    the item into chat messages, and the template is applied untokenized with a generation
    prompt and the run-wide ``THINKING`` flag.

    Raises:
        RuntimeError: no environment is loaded (see ``require_active_env``).
    """
    if not isinstance(item, dict):
        item = {"question": item}
    messages = require_active_env().prompt_messages(item)
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=THINKING,
    )
588
+
589
+
590
def strip_think(completion: str | None) -> str | None:
    """Drop <think>...</think> reasoning before the environment grades/rewards a
    thinking-mode completion.

    - closed block(s): keep only the text after the LAST </think>. This also covers
      always-thinking templates that pre-open <think> inside the generation prompt,
      whose completions contain </think> with no opening tag.
    - unclosed <think> (completion budget exhausted): keep only the text before it
      (usually empty), so answer extraction fails and the completion scores 0 --
      deliberate reward pressure to close thinking within budget, and it keeps a
      last-number fallback from matching numbers inside the reasoning. The same rule is
      applied to a <think> that is re-opened after the last </think> and never closed,
      so its reasoning cannot leak into the graded text either.
    - no tags: unchanged.
    - ``None``: returned as ``None``.
    """
    if completion is None:
        return None
    # Text after the last </think>; rpartition yields the whole string when there is none.
    after_last_close = completion.rpartition("</think>")[2]
    # Cut at any still-open <think>; partition yields the whole string when there is none.
    return after_last_close.partition("<think>")[0]
610
+
611
+
612
def graded_text(completion: str | None) -> str | None:
    """Return the text the env grader/reward should see for ``completion``.

    When the run-wide ``THINKING`` flag is on, <think> blocks are stripped first via
    ``strip_think`` (a completion whose reasoning never closes therefore grades on its
    pre-think text only). With thinking off, the completion is passed through untouched.
    """
    if THINKING:
        return strip_think(completion)
    return completion
617
+
618
+
619
+ # ---------------------------------------------------------------------------
620
+ # SFT
621
+ # ---------------------------------------------------------------------------
622
+
623
+
624
def force_vllm_backend_for_sm120() -> str | None:
    """Force vLLM's FLASHINFER attention backend when GPU 0 reports compute capability 12.x.

    Rationale (from the original authors): on RTX 5090 / consumer Blackwell hosts, vLLM's
    default flash-attn backend relies on pre-built PTX that can fail to JIT on older
    drivers, making the colocated rollout silently produce no completions. FLASHINFER has
    no such dependency. Fixed -- there is no operator override.

    Returns:
        ``"FLASHINFER"`` after setting ``VLLM_ATTENTION_BACKEND``, or ``None`` when CUDA is
        unavailable, the device is not capability 12.x, or the probe raised (the error is
        printed and the environment is left untouched).
    """
    try:
        import torch

        is_sm120 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] == 12
    except Exception as e:
        print("[rl] sm120 vLLM backend probe skipped:", e)
        return None
    if not is_sm120:
        return None

    backend = "FLASHINFER"
    os.environ["VLLM_ATTENTION_BACKEND"] = backend
    print(
        "[rl] sm120 (RTX 5090): VLLM_ATTENTION_BACKEND=FLASHINFER (flash-attn PTX is unreliable "
        "on consumer Blackwell hosts -> empty-rollout failures)"
    )
    return backend
648
+
649
+
650
def finalize_alloc_conf_for_sleep() -> None:
    """Align the CUDA allocator conf with the worker's RESOLVED vLLM sleep default (RL only).

    The launcher (providers/*/train.py build_worker_env) chooses the sleep-SAFE non-expandable
    PYTORCH_ALLOC_CONF for RL before this process starts, without knowing the GRPO sleep
    decision. For a small model the worker resolves sleep OFF, where the non-expandable conf is
    safe but fragments a long colocate run. With the model config and GPU available here, the
    same sleep decision is resolved through ``grpo_sleep_mode``; when it is OFF the allocator is
    switched to expandable_segments, which only crashes WITH sleep on.

    PYTORCH_ALLOC_CONF is read lazily at the first CUDA allocation, so this has to run before
    any allocation (it is called at boot). Any failure is logged and leaves the launcher's conf
    in place; nothing is raised.
    """
    if PHASE != "rl":
        return
    try:
        model_id = JOB_SPEC.model if JOB_SPEC else ""
        # Inputs are resolved the way run_rl's gate does (grpo_sleep_mode: size/context gate
        # PLUS the resident-fit check against the live card) so the alloc conf matches the
        # sleep mode the trainer will actually use.
        train_cfg = JOB_SPEC.train if JOB_SPEC else None

        # Context length: 0 when unset or unreadable.
        try:
            ctx = int(train_cfg.max_length) if train_cfg and train_cfg.max_length else 0
        except Exception:
            ctx = 0

        # Card VRAM in binary GiB to match grpo_fits_resident (see run_rl); /1e9 over-reports ~7%.
        card_gb = 0.0
        try:
            import torch as _torch_card

            if _torch_card.cuda.is_available():
                total_bytes = _torch_card.cuda.get_device_properties(0).total_memory
                card_gb = total_bytes / (1024**3)
        except Exception:
            card_gb = 0.0

        # group_size follows run_rl (gcfg override, else the recipe default) rather than a flat
        # 8: a differing recipe rl.group_size would otherwise make this sleep decision diverge
        # from the trainer's and pick the wrong expandable/non-expandable conf.
        from flash.engine.recipe import RECIPE as _RECIPE

        overrides = grpo_overrides()
        gate_inputs = {
            "max_length": ctx,
            "group_size": int(overrides.get("group_size") or _RECIPE.rl.group_size),
            "max_tokens": (train_cfg.max_tokens if train_cfg else None),
            "lora_rank": int(train_cfg.lora_rank) if train_cfg and train_cfg.lora_rank else 32,
            "thinking": THINKING,
            "card_vram_gb": card_gb,
        }
        if grpo_sleep_mode(model_id, **gate_inputs):
            print("[alloc] sleep resolves ON -> keeping launcher's non-expandable conf")
        else:
            # Sleep resolves OFF -> expandable is safe + better. Both env var spellings are set.
            conf = "expandable_segments:True"
            for env_key in ("PYTORCH_ALLOC_CONF", "PYTORCH_CUDA_ALLOC_CONF"):
                os.environ[env_key] = conf
            print(f"[alloc] sleep resolves OFF -> {conf} (anti-fragmentation, matches worker gate)")
    except Exception as e:
        print("[alloc] auto-conf skipped:", e)
709
+
710
+
711
def wandb_report_to() -> list[str]:
    """Resolve the TRL/HF ``report_to`` targets: ``["wandb"]`` or ``[]``.

    W&B logging is enabled whenever WANDB_API_KEY is present (restoring what the legacy freesolo
    training path had). Without a key the result is ``[]`` — silent, since the metrics.json
    artifact remains the source of truth.

    Project and run name come ONLY from the typed ``[wandb]`` config (``JOB_SPEC.wandb``); there
    is NO WANDB_PROJECT / WANDB_NAME environment variable. HF's WandbCallback has no project
    argument and would read WANDB_PROJECT from the env, so the run is initialized here directly
    through the wandb SDK (``wandb.init(project=..., name=...)``) and the Trainer's callback
    reuses it. ``entity=`` is not passed, so the run lands on the API key's default account/team.

    Never raises for W&B problems: a missing package, a broken install, or an init failure
    (auth, network, runtime import error) all degrade to ``[]``.
    """
    no_wandb: list[str] = []
    if not os.environ.get("WANDB_API_KEY"):
        return no_wandb
    import importlib.util

    wandb_spec = importlib.util.find_spec("wandb")
    if wandb_spec is None:
        print("[wandb] WANDB_API_KEY set but the wandb package is missing; skipping W&B logging")
        return no_wandb
    # Best-effort: W&B logging is optional, so nothing in here may abort training.
    try:
        import wandb

        if wandb.run is None:  # init from the spec so the project needs no WANDB_PROJECT env
            spec_project = JOB_SPEC.wandb.project if JOB_SPEC else None
            wandb.init(project=spec_project or "flash", name=wandb_run_name())
    except Exception as e:
        print(
            f"[wandb] W&B init failed ({e}); skipping W&B logging (metrics.json is still written)"
        )
        return no_wandb
    else:
        return ["wandb"]
744
+
745
+
746
def wandb_run_name() -> str:
    """Return the W&B run name.

    The only source is the typed ``[wandb] run_name`` config (``JOB_SPEC.wandb.run_name``) — no
    WANDB_NAME environment variable. A non-blank configured name is returned with surrounding
    whitespace stripped and otherwise as given (the user owns the naming). With no usable
    configured name the result is a stable id tying the dashboard run to the Flash run:
    ``flash-<phase>-<run_id>-seed<N>``. The value is handed to the Trainer via
    ``TrainingArguments.run_name`` and to ``wandb.init``.
    """
    raw_name = JOB_SPEC.wandb.run_name if JOB_SPEC else None
    trimmed = raw_name.strip() if raw_name else ""
    if trimmed:
        return trimmed
    return f"flash-{PHASE}-{RUN_ID}-seed{SEED}"
756
+
757
+
758
def wandb_run_info() -> dict:
    """Describe the live W&B run for metrics.json.

    Returns ``{"wandb_url", "wandb_id", "wandb_project"}`` read off ``wandb.run`` when a run is
    active (an attribute the run lacks is reported as None), else ``{}``. Recording it makes the
    W&B run verifiable and lets the freesolo agent's `wandb_runs` / the SDK's link_wandb point
    at the real dashboard URL. Never raises: any failure, including wandb not being importable,
    yields ``{}``.
    """
    # (metrics.json key, attribute on the wandb run), in output order.
    exported = (
        ("wandb_url", "url"),
        ("wandb_id", "id"),
        ("wandb_project", "project"),
    )
    try:
        import wandb

        active_run = getattr(wandb, "run", None)
        if active_run is None:
            return {}
        return {key: getattr(active_run, attr, None) for key, attr in exported}
    except Exception:
        return {}
775
+
776
+
777
def make_lora(model_id: str | None = None):
    """Build the PEFT ``LoraConfig`` used for training.

    Targets 'all-linear' (every nn.Linear) rather than a hardcoded q/k/v/o list: it is
    architecture-agnostic, so the same recipe works for the dense default
    (Qwen3-4B-Instruct-2507) and for newer models with extra projection types (e.g. the
    Qwen3.5 hybrid Gated-DeltaNet) without missing any adapters. For natively-multimodal
    checkpoints the vision tower is excluded (see ``lora_exclude_modules``).

    Args:
        model_id: Optional model id; when given, ``lora_exclude_modules(model_id)`` decides
            which modules to leave un-adapted.

    Returns:
        A ``peft.LoraConfig`` with rank/alpha from the run's ``[train]`` knobs when set, else
        the recipe defaults; standard zero-B init; standard (non-rs) scaling.
    """
    from peft import LoraConfig

    # Adapt every linear projection. "all-linear" is a PEFT SPECIAL string (not a module name)
    # that PEFT expands to all linear layers — the right managed default across the catalog.
    targets = "all-linear"
    # Fall back to the recipe PER FIELD: a JobSpec can be present while lora_rank / lora_alpha
    # are left unset, and passing None through to LoraConfig is not a valid rank/alpha.
    _t = JOB_SPEC.train if JOB_SPEC else None
    rank = _t.lora_rank if _t and _t.lora_rank is not None else RECIPE.lora.rank
    alpha = _t.lora_alpha if _t and _t.lora_alpha is not None else RECIPE.lora.alpha
    kwargs = {
        "r": rank,
        "lora_alpha": alpha,
        "lora_dropout": RECIPE.lora.dropout,
        "target_modules": targets,
        "task_type": "CAUSAL_LM",
    }
    # Adapter initialization: standard zero-B init (the LoRA delta starts at zero, so the saved
    # adapter is a plain residual that loads correctly onto the ORIGINAL base).
    # PiSSA was removed: it mutates the effective base during training, so its saved adapter only
    # reconstructs against the PiSSA-residual base. Loading that adapter onto the unmodified base
    # at SERVING or GRPO WARM-START (which is exactly our flow) corrupts the model -> the served
    # model emits only whitespace and warm-start GRPO hangs. peft can convert PiSSA->standard on
    # save, but the simpler, robust choice is the default init (the convergence gain isn't worth
    # silently breaking serve + warm-start).
    kwargs["init_lora_weights"] = True
    print(
        "[lora] init_lora_weights=True (standard zero-B; PiSSA removed for serve/warm-start safety)"
    )
    # Standard LoRA scaling (alpha/r). rsLoRA was removed: it scales by alpha/sqrt(r) (~5.6x larger
    # for r=32/alpha=64), so with the usual LoRA LR (e.g. 2e-4) the effective update is ~5.6x too
    # large -> SFT diverges to a degenerate adapter (served model repeats a single token / emits
    # whitespace) and the adapter is also fragile under vLLM's rsLoRA handling at serve time.
    # Standard scaling keeps the catalog LRs sane and the saved adapter serve-safe.
    kwargs["use_rslora"] = False
    if model_id and targets == "all-linear":
        exclude = lora_exclude_modules(model_id)
        if exclude:
            kwargs["exclude_modules"] = exclude
            print(f"[lora] excluding modules for {model_id}: {exclude}")
    return LoraConfig(**kwargs)
822
+
823
+
824
def require_vllm_for_rollout_func(use_rollout_func: bool, use_vllm: bool, model_id: str) -> None:
    """Reject a multi-turn GRPO run whose colocated vLLM engine is disabled.

    The multi-turn rollout closure (``multiturn_rollout.build_rollout_func``) generates through
    ``trainer.vllm_generation.llm``, an engine TRL only creates when ``use_vllm`` is True; with
    vLLM off the rollout would AttributeError at the first turn. GRPO now always colocates vLLM
    (``use_vllm`` is unconditionally True), so this check is defensive: it stays so a future
    tier that disables the rollout engine fails fast with an actionable message.

    Args:
        use_rollout_func: Whether the run uses the multi-turn rollout function.
        use_vllm: Whether colocated vLLM is enabled.
        model_id: Model id, used only in the error message.

    Raises:
        RuntimeError: If the rollout function is requested while vLLM is disabled.
    """
    rollout_has_engine = use_vllm or not use_rollout_func
    if rollout_has_engine:
        return
    raise RuntimeError(
        f"multi-turn GRPO needs colocated vLLM, which is disabled for {model_id}. "
        "Use a single-turn environment for this model, or a model tier that keeps "
        "vLLM enabled for rollouts."
    )
839
+
840
+
841
+ def run_sft():
842
+ from datasets import Dataset
843
+ from transformers import AutoTokenizer
844
+ from trl import SFTConfig as TRLSFTConfig
845
+ from trl import SFTTrainer
846
+
847
+ env = require_active_env() # fail loudly (not AttributeError: NoneType) on the no-JobSpec path
848
+ t_start = time.time()
849
+ heartbeat("sft_start", gpu=gpu_diagnostics())
850
+ # SFT on a multi-turn env: rows whose target completion is a full trajectory train on the whole
851
+ # transcript (proper multi-turn SFT, handled below); rows with a single-turn target completion
852
+ # collapse to one assistant turn. Warn only for the collapsing case (computed during the
853
+ # dataset build below), not unconditionally.
854
+ wait_for_gpu()
855
+ setup_perf_backends()
856
+ model_id = JOB_SPEC.model if JOB_SPEC else RECIPE.hf_model_id
857
+ download_seconds = prefetch_model(model_id)
858
+ tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
859
+ if tok.pad_token is None:
860
+ tok.pad_token = tok.eos_token
861
+
862
+ # Build SFT text dataset (seeded shuffle for reproducibility)
863
+ train = env.dataset()
864
+ rng = random.Random(SEED)
865
+ rng.shuffle(train)
866
+ max_examples = int(
867
+ JOB_SPEC.train.max_examples or 0
868
+ if JOB_SPEC and JOB_SPEC.train and JOB_SPEC.train.max_examples is not None
869
+ else 0
870
+ )
871
+ if max_examples > 0:
872
+ train = train[:max_examples]
873
+ texts = []
874
+ multiturn_targets = 0
875
+ for ex in train:
876
+ # The env (via the freesolo-sdk Environment.sft_completion) owns the target completion: the
877
+ # full multi-turn target trajectory (assistant turns + tool calls + tool results + replies)
878
+ # when the row ships one, else a single target assistant turn. Training on the whole
879
+ # transcript is what makes SFT actually multi-turn (the tool-call protocol + replies) — the
880
+ # warm start the GRPO recipe expects. A >1-message completion is a multi-turn trajectory.
881
+ completion = env.sft_completion(ex)
882
+ if len(completion) > 1: # a multi-turn target trajectory (vs a single assistant turn)
883
+ multiturn_targets += 1
884
+ msgs = [*env.prompt_messages(ex), *completion]
885
+ texts.append(
886
+ {
887
+ "text": tok.apply_chat_template(
888
+ msgs, tokenize=False, add_generation_prompt=False, enable_thinking=THINKING
889
+ )
890
+ }
891
+ )
892
+ if multiturn_targets:
893
+ print(f"[sft] multi-turn SFT: {multiturn_targets}/{len(train)} rows train on a full target transcript")
894
+ elif getattr(env, "multi_turn", False):
895
+ print(
896
+ "[sft][warn] this is a multi-turn Freesolo environment but no row ships a multi-turn "
897
+ "target completion; SFT collapses to a single assistant turn per row (tool/env turns "
898
+ "ignored). Provide target transcripts (output={\"messages\": [...]}) for proper multi-turn SFT."
899
+ )
900
+ if THINKING and not any("<think>" in t["text"] for t in texts[:256]):
901
+ print(
902
+ "WARN: thinking mode is ON but no sampled SFT target contains a <think> "
903
+ "trace — training on non-reasoning targets teaches the model to SKIP "
904
+ "thinking. Use a dataset with reasoning traces, or set thinking = false."
905
+ )
906
+ ds = Dataset.from_list(texts)
907
+
908
+ setup_seconds = time.time() - t_start
909
+ heartbeat("sft_model_load", setup_seconds=setup_seconds, gpu=gpu_diagnostics())
910
+
911
+ # Epochs come from the run's [train] epochs (already in JOB_SPEC), else the recipe default.
912
+ epochs = int(
913
+ JOB_SPEC.train.epochs
914
+ if JOB_SPEC and JOB_SPEC.train.epochs is not None
915
+ else RECIPE.sft.num_epochs
916
+ )
917
+ # SDK [train] knobs override the recipe default.
918
+ from flash.catalog import vocab_size_for
919
+ from flash.engine.vram import resolve_params_b, sft_grad_accum, sft_logits_fused
920
+
921
+ _t = JOB_SPEC.train if JOB_SPEC else None
922
+ sft_lr = _t.learning_rate if _t and _t.learning_rate is not None else RECIPE.sft.learning_rate
923
+ sft_max_len = (
924
+ _t.max_length
925
+ if _t and _t.max_length is not None
926
+ else (RECIPE.sft.max_seq_len_thinking if THINKING else RECIPE.sft.max_seq_len)
927
+ )
928
+ # batch_size is the GLOBAL/effective batch; sft_grad_accum sizes the per-device micro-batch +
929
+ # grad-accum to realize it (shared with the cost estimator's step count, see engine.vram).
930
+ effective_batch = (
931
+ _t.batch_size if _t and _t.batch_size is not None else RECIPE.sft.effective_batch
932
+ )
933
+ # Large-vocab OOM guard: when the fused CE (Liger) is OFF, the SFTTrainer materializes the full
934
+ # [per_device, seq, vocab] fp32 logits + grad — at Qwen3.5's ~248k vocab a 0.8B SFT OOM'd a
935
+ # 24 GB card in backward. Cap the per-device micro-batch by the real model vocab + seq so those
936
+ # logits stay within the logits budget; grad-accum rises to keep the effective batch unchanged
937
+ # (the SFT mirror of rl_per_device_comps' GRPO cap). fused mirrors liger_on(_memory_mode(...))
938
+ # below, so the cap binds exactly when the worker won't fuse the CE.
939
+ _sft_params_b = resolve_params_b(model_id) # catalog stat else HF safetensors (open models)
940
+ _sft_vocab = vocab_size_for(model_id)
941
+ # Actual fused-CE decision == what `use_liger_kernel` is set from below (line ~879). sft_logits_fused
942
+ # is the offline size/ctx mirror; liger_on(...) adds the runtime CUDA + liger_kernel-importable
943
+ # check, so the cap binds exactly when the fused CE is NOT really taken.
944
+ _sft_fused = sft_logits_fused(_sft_params_b, sft_max_len) and liger_on(
945
+ _memory_mode(model_id, sft_max_len)
946
+ )
947
+ per_device_bs, grad_accum = sft_grad_accum(
948
+ effective_batch, seq_len=sft_max_len, vocab=_sft_vocab, fused=_sft_fused
949
+ )
950
+ if not _sft_fused and per_device_bs < min(effective_batch, 4):
951
+ print(
952
+ f"[sft] large-vocab logits cap: per_device={per_device_bs} grad_accum={grad_accum} "
953
+ f"(seq={sft_max_len}, vocab={_sft_vocab}; realized batch "
954
+ f"{per_device_bs * grad_accum} >= requested {effective_batch})"
955
+ )
956
+ sft_save_default = _t.save_every if _t and _t.save_every is not None else 50
957
+ out_dir = f"/tmp/sft_seed{SEED}"
958
+ resume_ckpt = hf_resume_checkpoint()
959
+
960
+ # [train].max_steps>0 caps optimizer steps (used by the cheap pre-flight smoke).
961
+ max_steps = int(_t.max_steps or 0 if _t and _t.max_steps is not None else 0)
962
+ cfg_kwargs = {
963
+ "output_dir": out_dir,
964
+ "num_train_epochs": epochs,
965
+ "per_device_train_batch_size": per_device_bs,
966
+ "gradient_accumulation_steps": grad_accum,
967
+ "learning_rate": sft_lr,
968
+ "warmup_ratio": RECIPE.sft.warmup_frac,
969
+ "logging_steps": 10,
970
+ "save_steps": sft_save_default,
971
+ "save_total_limit": 1,
972
+ # Resumable checkpoints: save the optimizer / scheduler / RNG state alongside the (small)
973
+ # LoRA adapter. We DO resume mid-run — make_checkpoint_upload_callback streams each save to
974
+ # HF and a replacement worker calls resume_from_checkpoint(hf_resume_checkpoint()) after a
975
+ # preemption — so without this the resumed run would re-initialize the optimizer (Adam
976
+ # moments) and LR schedule instead of truly continuing. For LoRA the optimizer state is tiny
977
+ # (it covers only the trainable adapter params), so the save spike is negligible. The
978
+ # deployable per-step snapshot (publish_deployable_checkpoint) strips this trainer state
979
+ # separately, so serving still gets adapter-only files.
980
+ "save_only_model": False,
981
+ "max_length": sft_max_len,
982
+ "bf16": True,
983
+ "report_to": wandb_report_to(), # W&B when WANDB_API_KEY present (restored post-flash-migration)
984
+ "run_name": wandb_run_name(),
985
+ # Dataloader parallelism: overlap host-side collation/tokenization with GPU compute so a
986
+ # real (large) training set isn't dataloader-bound. Pure throughput, zero quality change.
987
+ # Negligible on the tiny benchmark (pre-tokenized, in-memory); a real win at production
988
+ # dataset sizes.
989
+ "dataloader_num_workers": 4,
990
+ "dataloader_pin_memory": True,
991
+ "dataloader_persistent_workers": True,
992
+ "seed": SEED,
993
+ "gradient_checkpointing": grad_checkpointing_on(model_id, sft_max_len),
994
+ # Non-reentrant checkpointing: composes cleanly with autograd hooks (verl #3629) and is
995
+ # required by TRL for correct grad flow through the LoRA adapters.
996
+ "gradient_checkpointing_kwargs": {"use_reentrant": False},
997
+ "completion_only_loss": False,
998
+ # Optimizer: 8-bit paged AdamW (int8 state paged to host RAM -> fits a smaller GPU).
999
+ "optim": fused_optim_name(),
1000
+ }
1001
+ if max_steps > 0:
1002
+ cfg_kwargs["max_steps"] = max_steps
1003
+ # Example packing: concatenate short examples into full max_length sequences so a batch isn't
1004
+ # mostly pad tokens — PR #174 measured a 4.4-10.7x SFT speedup (h100 8.2x, 4090 10.7x) because
1005
+ # instruction targets are far shorter than max_seq_len; unpacked batches waste most of their
1006
+ # FLOPs on padding. TRL's 'bfd' strategy makes padding-free batches whose example boundaries are
1007
+ # honored ONLY by an attention impl that reads them — under plain SDPA packed examples
1008
+ # cross-contaminate (silent quality loss). The boundary-correct backend is FlashAttention-2
1009
+ # varlen (reads position_ids), which the worker image bakes in best-effort: Dockerfile.worker
1010
+ # installs FLASH_ATTN_SPEC (a community cu128/torch2.10/cp312 wheel preferred, source build as a
1011
+ # fallback) and tolerates a build failure -> SDPA. So _fa_ok is True whenever that install landed;
1012
+ # packing is ON then (varlen keeps 'bfd' example boundaries correct). If the best-effort install
1013
+ # failed, _fa_ok is False and we SKIP packing — without a boundary-correct attn backend examples
1014
+ # would cross-contaminate under SDPA.
1015
+ # Pure full-attention vs GatedDeltaNet hybrid (Qwen3.5/3.6) — probed ONCE here and reused across
1016
+ # the whole packing decision (each probe reads the cached HF config). TRL 'bfd' packing keeps
1017
+ # example boundaries via position_ids that a varlen attn honors, but it provides NO seq_idx, so it
1018
+ # can't reset a GDN hybrid's causal conv -> bfd-packing a GDN model silently cross-contaminates its
1019
+ # linear-attention layers. So bfd is enabled for PURE full-attention models only; GDN hybrids pack
1020
+ # via the cu_seqlens/seq_idx varlen collator branch below (when their kernels are present).
1021
+ _pure_attn = model_is_pure_attention(model_id)
1022
+ _gdn = model_is_gdn_hybrid(model_id)
1023
+ _fa_ok = _flash_attn_available()
1024
+ if _fa_ok and _pure_attn:
1025
+ cfg_kwargs["packing"] = True
1026
+ print("[sft] example packing enabled (FA2 varlen)")
1027
+ elif _fa_ok and _gdn:
1028
+ print(
1029
+ "[sft] TRL bfd packing NOT used for the GatedDeltaNet hybrid (bfd can't reset the conv); "
1030
+ "the cu_seqlens/seq_idx varlen collator handles its packing when both kernels are present."
1031
+ )
1032
+ else:
1033
+ # FA2 bfd packing not enabled here — either flash_attn isn't importable, or it is but the arch
1034
+ # isn't bfd-safe (e.g. sliding-window). This is NOT the final word: the SDPA block-diagonal /
1035
+ # GDN-varlen block below may still turn packing on for a pure-attention or GDN-hybrid model.
1036
+ _bfd_why = "flash_attn not importable" if not _fa_ok else "arch not bfd-safe under FA2 varlen"
1037
+ print(f"[sft] TRL bfd (FA2) packing not used ({_bfd_why}); the SDPA-mask path decides packing below.")
1038
+ # Liger fused CE/RMSNorm/RoPE kernels, gated by model size (_memory_mode). The fused linear
1039
+ # cross-entropy is the big large-vocab (Qwen3.5 ~248k) memory/throughput win.
1040
+ if liger_on(_memory_mode(model_id, sft_max_len)):
1041
+ cfg_kwargs["use_liger_kernel"] = True
1042
+ print("[sft] liger fused kernels enabled")
1043
+ _attn = optimal_attn_impl() # arch-best FlashAttention (FA3 Hopper / FA2 Ampere·Ada) or SDPA
1044
+ # Packing correctness: 'bfd' packed batches are boundary-correct ONLY under a varlen-capable attn
1045
+ # (FA2 and FA3 both expose flash_attn_varlen_func; plain SDPA cross-contaminates packed examples).
1046
+ # Use the ARCH-BEST flash impl optimal_attn_impl already picked (so Hopper packs under FA3, not
1047
+ # FA2). Cases when it did NOT pick a flash impl:
1048
+ # * _attn == "sdpa" (sm120, the deliberate no-flash exception): DISABLE packing — consumer
1049
+ # Blackwell stays plain SDPA; do NOT force FA2 (its sm120 kernel coverage is unverified).
1050
+ # * _attn is None (Hopper without FA3): force FA2 for boundary-correct varlen IF the wheel is
1051
+ # importable; else drop packing rather than silently cross-contaminate.
1052
+ if cfg_kwargs.get("packing"):
1053
+ if _attn in ("flash_attention_2", "flash_attention_3"):
1054
+ print(f"[sft] attn_implementation={_attn} (packing boundary-correct varlen)")
1055
+ elif _attn == "sdpa":
1056
+ cfg_kwargs["packing"] = False
1057
+ print("[sft] packing disabled: selected attn_implementation=sdpa (no varlen flash backend)")
1058
+ elif _fa_ok:
1059
+ _attn = "flash_attention_2"
1060
+ print("[sft] attn_implementation=flash_attention_2 (packing boundary-correct varlen)")
1061
+ else:
1062
+ cfg_kwargs["packing"] = False
1063
+ print("[sft] packing disabled: no varlen flash backend (FA2/FA3) available -> plain SDPA")
1064
+
1065
+ # --- True token packing via a 4D block-diagonal SDPA mask (no flash-attn / no flex) ---------
1066
+ # When the run lands on plain SDPA (no varlen flash backend) the block above left packing OFF —
1067
+ # notably on sm120 (RTX 5090, flash's DEFAULT GPU), and anywhere the best-effort flash-attn
1068
+ # build didn't land. For a PURE full-attention model we can still pack: concatenate examples
1069
+ # into max_length blocks and feed a 4D block-diagonal causal mask SDPA honors natively, so
1070
+ # packed examples never attend across boundaries (boundary-correct, numerically identical to
1071
+ # unpacked — verified on a tiny Qwen3/Llama: |packed-separate| logits ~1e-7). This reclaims the packing
1072
+ # throughput win on the default GPU with neither flash-attn nor flex_attention. GatedDeltaNet
1073
+ # hybrids (Qwen3.5/3.6) take the NEXT branch instead — a mask alone can't reset their linear-
1074
+ # attention state, so they also need the cu_seqlens/seq_idx varlen kwargs.
1075
+ _collator = None
1076
+ # The mask paths materialize a dense [B, 1, T, T] mask — O(T^2) memory. At very long context that
1077
+ # tax (hundreds of MB to >1 GB) can OOM a run that previously fit under memory-efficient SDPA, and
1078
+ # packing buys little there anyway (long rows already fill a block). Above this cap, leave packing
1079
+ # off (train unpacked, as today). 16384: the dense bf16/bool mask stays <=~256 MB at bsz=1.
1080
+ _PACK_MASK_MAX_LEN = 16384
1081
+ _mask_pack_ok = sft_max_len <= _PACK_MASK_MAX_LEN
1082
+ _sdpa_pack = bool(not cfg_kwargs.get("packing") and _pure_attn and _mask_pack_ok)
1083
+ if _sdpa_pack:
1084
+ # The 4D mask requires a MASK-READING attn (SDPA). DOWNGRADE any flash impl optimal_attn_impl
1085
+ # picked — e.g. FA3 on a Hopper worker whose FA2 wheel didn't build — to SDPA: a flash varlen
1086
+ # kernel SILENTLY IGNORES the 4D mask, so packed examples would attend across boundaries. (A
1087
+ # bare ``_attn or "sdpa"`` would leave the truthy flash string in place — the bug this avoids.)
1088
+ if _attn in ("flash_attention_2", "flash_attention_3"):
1089
+ print(f"[sft] packing under SDPA: downgrading {_attn} -> sdpa (a flash kernel ignores the 4D mask)")
1090
+ _attn = "sdpa"
1091
+ cfg_kwargs["packing"] = False # we own the packing; TRL must not also pack
1092
+ # Hand TRL pre-tokenized, pre-packed rows + our collator: skip its dataset prep and stop the
1093
+ # signature-based column pruning from dropping our seq_lengths column before collation.
1094
+ _dk = dict(cfg_kwargs.get("dataset_kwargs") or {})
1095
+ _dk["skip_prepare_dataset"] = True
1096
+ cfg_kwargs["dataset_kwargs"] = _dk
1097
+ cfg_kwargs["remove_unused_columns"] = False
1098
+ # Tokenize EXACTLY like TRL's non-packed prep (EOS-append parity so the model still learns to
1099
+ # stop; batched; truncate to max_length) then bin-pack into <= max_length blocks.
1100
+ _tokenized = tokenize_for_packing([t["text"] for t in texts], tok, sft_max_len)
1101
+ _packed_rows = pack_token_ids(_tokenized, sft_max_len)
1102
+ ds = Dataset.from_list(_packed_rows)
1103
+ _collator = BlockDiagonalCollator(pad_token_id=tok.pad_token_id)
1104
+ # Memory: re-size the per-device micro-batch (in BLOCKS) for the full-block [pd, max_length,
1105
+ # vocab] fp32 logits budget — a no-op under Liger's fused CE. Quality: each block holds
1106
+ # ~ex_per_block examples, so KEEP the effective batch in EXAMPLES at the configured value by
1107
+ # re-deriving grad_accum from the block count. Without this, packing balloons the effective
1108
+ # batch ~ex_per_block-fold (fewer, larger updates -> mild undertraining at the same epochs:
1109
+ # an A/B measured +5.2% held-out loss vs unpacked, closed to +0.1% once matched).
1110
+ _pd_pack, _ = sft_grad_accum(
1111
+ effective_batch, seq_len=sft_max_len, vocab=vocab_size_for(model_id),
1112
+ fused=bool(cfg_kwargs.get("use_liger_kernel")),
1113
+ )
1114
+ # The dense [pd, 1, T, T] bool mask is pd*T^2 bytes — under Liger the logits cap doesn't bind
1115
+ # so pd can be 4, and at long context that mask alone is GBs. Cap pd so the mask stays <=512MB
1116
+ # (a no-op at short ctx: at T=2048 it allows pd up to ~125; it only bites past ~12k tokens).
1117
+ _pd_pack = max(1, min(_pd_pack, (512 * 1024 * 1024) // (sft_max_len * sft_max_len)))
1118
+ _ex_per_block = len(_tokenized) / max(1, len(_packed_rows))
1119
+ cfg_kwargs["per_device_train_batch_size"] = _pd_pack
1120
+ cfg_kwargs["gradient_accumulation_steps"] = max(
1121
+ 1, math.ceil(effective_batch / max(1.0, _pd_pack * _ex_per_block))
1122
+ )
1123
+ print(
1124
+ "[sft] true token packing ENABLED (4D block-diagonal SDPA mask): "
1125
+ f"{len(_tokenized)} examples -> {len(_packed_rows)} blocks (~{_ex_per_block:.1f} ex/block, "
1126
+ f"{packing_efficiency(_packed_rows, sft_max_len):.0%} dense) of <= {sft_max_len} tok; "
1127
+ f"pd={_pd_pack} ga={cfg_kwargs['gradient_accumulation_steps']} (effective batch kept "
1128
+ f"~{effective_batch} ex); no flash-attn / no flex_attention"
1129
+ )
1130
+ elif not cfg_kwargs.get("packing") and _gdn and gdn_packing_available(model_id) and _mask_pack_ok:
1131
+ # GatedDeltaNet hybrid (Qwen3.5/3.6, flash's flagship tier): the 4D block-diagonal mask makes
1132
+ # the FULL-attention layers boundary-correct, and the linear-attention (DeltaNet) layers reset
1133
+ # their recurrence + causal conv at example boundaries via cu_seq_lens_q (fla kernel) + seq_idx
1134
+ # (causal_conv1d). GPU-validated on Qwen3.5-0.8B (RTX 5090): a packed example's output is
1135
+ # byte-identical regardless of its neighbors' content (ZERO cross-example leakage); the only
1136
+ # diff vs unpacked is benign bf16 GDN-kernel tiling numerics (~0.3 on logits). Gated on BOTH
1137
+ # kernels being importable (gdn_packing_available) so a worker without them stays unpacked.
1138
+ # Pin SDPA for the full-attn layers (downgrade any flash impl, e.g. FA3 on Hopper — it would
1139
+ # ignore the 4D mask); the DeltaNet layers are unaffected (they use cu_seqlens/seq_idx).
1140
+ if _attn in ("flash_attention_2", "flash_attention_3"):
1141
+ print(f"[sft] GDN packing under SDPA: downgrading {_attn} -> sdpa for the full-attn layers")
1142
+ _attn = "sdpa"
1143
+ cfg_kwargs["packing"] = False
1144
+ _dk = dict(cfg_kwargs.get("dataset_kwargs") or {})
1145
+ _dk["skip_prepare_dataset"] = True
1146
+ cfg_kwargs["dataset_kwargs"] = _dk
1147
+ cfg_kwargs["remove_unused_columns"] = False
1148
+ # EOS-append parity + batched + truncated tokenization (same as the unpacked path), then pack.
1149
+ _tokenized = tokenize_for_packing([t["text"] for t in texts], tok, sft_max_len)
1150
+ _packed_rows = pack_token_ids(_tokenized, sft_max_len)
1151
+ ds = Dataset.from_list(_packed_rows)
1152
+ _collator = BlockDiagonalCollator(pad_token_id=tok.pad_token_id, emit_varlen=True)
1153
+ # cu_seqlens spans ONE packed block, so per-device is a single block; keep the effective batch
1154
+ # in EXAMPLES at the configured value via grad-accum (each block holds ~ex_per_block examples —
1155
+ # without this the effective batch would balloon ~ex_per_block-fold -> undertraining).
1156
+ _ex_per_block = len(_tokenized) / max(1, len(_packed_rows))
1157
+ cfg_kwargs["per_device_train_batch_size"] = 1
1158
+ cfg_kwargs["gradient_accumulation_steps"] = max(1, math.ceil(effective_batch / max(1.0, _ex_per_block)))
1159
+ print(
1160
+ "[sft] true token packing ENABLED for GatedDeltaNet hybrid (4D mask + cu_seqlens/seq_idx "
1161
+ f"varlen): {len(_tokenized)} examples -> {len(_packed_rows)} blocks (~{_ex_per_block:.1f} "
1162
+ f"ex/block, {packing_efficiency(_packed_rows, sft_max_len):.0%} dense) of <= {sft_max_len} "
1163
+ f"tok; pd=1 ga={cfg_kwargs['gradient_accumulation_steps']} (effective batch kept ~{effective_batch} ex)"
1164
+ )
1165
+ elif not cfg_kwargs.get("packing") and (_pure_attn or _gdn) and not _mask_pack_ok:
1166
+ print(
1167
+ f"[sft] packing stays OFF: max_length {sft_max_len} > {_PACK_MASK_MAX_LEN} — the dense "
1168
+ "O(T^2) block-diagonal mask gets too large at long context (unpacked is more memory-"
1169
+ "efficient there, and long rows already fill a block)."
1170
+ )
1171
+ elif not cfg_kwargs.get("packing") and not _pure_attn:
1172
+ _why = (
1173
+ "hybrid GatedDeltaNet but the fla/causal_conv1d varlen kernels aren't both importable"
1174
+ if _gdn
1175
+ else "non-full-attention arch (e.g. sliding-window) a block-diagonal mask can't pack"
1176
+ )
1177
+ print(f"[sft] packing stays OFF: {_why}. (Pure full-attention models pack via the SDPA mask.)")
1178
+ # Explicit bf16 + no auto device-map: TRL/transformers-5 string loading can
1179
+ # otherwise fall back to fp32 (2x VRAM; observed 18.6 GB for a 4.66B model) or
1180
+ # accelerate-offload large models to meta ("expected device meta but got
1181
+ # cuda:0" in backward on the 9B).
1182
+ mik = {"dtype": "bfloat16", "device_map": None}
1183
+ if _attn:
1184
+ mik["attn_implementation"] = _attn
1185
+ cfg_kwargs["model_init_kwargs"] = mik
1186
+ cfg = TRLSFTConfig(**cfg_kwargs)
1187
+
1188
+ # LoRA+ (convergence lever, arXiv 2402.12354; always-on: measured -52% train loss in A/B
1189
+ # (gpu-bench)): give the LoRA B matrices a higher LR than A (ratio 16). Reported ~2x fewer steps
1190
+ # to target at identical per-step FLOPs. TRL builds the model from a string inside __init__, so
1191
+ # the optimizer (which needs the instantiated params) can't be pre-built — override
1192
+ # create_optimizer to construct it from self.model once it exists.
1193
+ _lp_ratio = 16
1194
+ _SFT = SFTTrainer
1195
+ if _lp_ratio > 1:
1196
+
1197
+ class _SFT(SFTTrainer): # local LoRA+ subclass
1198
+ _loraplus_applied = False # True only once the LoRA+ grouping actually installs
1199
+
1200
+ def create_optimizer(self):
1201
+ if self.optimizer is None:
1202
+ try:
1203
+ from peft.optimizers import create_loraplus_optimizer
1204
+
1205
+ # Mirror the configured `optim` so LoRA+ and the 8-bit paged optimizer state
1206
+ # coexist (instead of silently forcing fp32 AdamW); see loraplus_optimizer_cls.
1207
+ # .value (not str()): self.args.optim is a TRL OptimizerNames enum whose
1208
+ # str() is "OptimizerNames.PAGED_ADAMW_8BIT"; pass the raw value
1209
+ # ("paged_adamw_8bit") so the 8-bit match works.
1210
+ opt_cls, extra = loraplus_optimizer_cls(
1211
+ getattr(self.args.optim, "value", self.args.optim)
1212
+ )
1213
+ # Forward the TrainingArguments optimizer config that the default HF
1214
+ # create_optimizer path would have applied. Building the optimizer
1215
+ # ourselves means we must replicate it explicitly, or LoRA+ runs would
1216
+ # silently use the optimizer class's own defaults instead of the
1217
+ # configured betas/eps/weight_decay. betas/eps go straight to the optimizer
1218
+ # constructor (alongside any `extra` from loraplus_optimizer_cls);
1219
+ # weight_decay is handled separately below.
1220
+ fwd = dict(extra)
1221
+ _betas = (
1222
+ getattr(self.args, "adam_beta1", None),
1223
+ getattr(self.args, "adam_beta2", None),
1224
+ )
1225
+ if None not in _betas:
1226
+ fwd.setdefault("betas", _betas)
1227
+ _eps = getattr(self.args, "adam_epsilon", None)
1228
+ if _eps is not None:
1229
+ fwd.setdefault("eps", _eps)
1230
+ # PEFT does NOT read args.weight_decay; it applies decay via its own LoRA+
1231
+ # param groups, keyed off the loraplus_weight_decay kwarg (which it pops
1232
+ # before constructing the optimizer). Pass it as a top-level kwarg so it
1233
+ # isn't forwarded into the optimizer constructor.
1234
+ lp_extra: dict[str, object] = {}
1235
+ _wd = getattr(self.args, "weight_decay", None)
1236
+ if _wd is not None:
1237
+ lp_extra["loraplus_weight_decay"] = _wd
1238
+ # PEFT's create_loraplus_optimizer forwards extra kwargs to the optimizer;
1239
+ # the lr keyword name has shifted across PEFT versions, so pass it via
1240
+ # optimizer_kwargs (the stable form) and fall back to a top-level lr=.
1241
+ try:
1242
+ self.optimizer = create_loraplus_optimizer(
1243
+ model=self.model,
1244
+ optimizer_cls=opt_cls,
1245
+ optimizer_kwargs={"lr": self.args.learning_rate, **fwd},
1246
+ loraplus_lr_ratio=_lp_ratio,
1247
+ **lp_extra,
1248
+ )
1249
+ except TypeError:
1250
+ self.optimizer = create_loraplus_optimizer(
1251
+ model=self.model,
1252
+ optimizer_cls=opt_cls,
1253
+ lr=self.args.learning_rate,
1254
+ loraplus_lr_ratio=_lp_ratio,
1255
+ **fwd,
1256
+ **lp_extra,
1257
+ )
1258
+ self._loraplus_applied = True
1259
+ print(
1260
+ f"[lora+] optimizer enabled (B-matrix LR ratio={_lp_ratio}, "
1261
+ f"cls={opt_cls.__name__})"
1262
+ )
1263
+ return self.optimizer
1264
+ except Exception as e: # never block training on the LoRA+ wiring
1265
+ print("[lora+] setup failed, falling back to default optimizer:", e)
1266
+ return super().create_optimizer()
1267
+
1268
+ # Pass model as a string id + tokenizer as processing_class so TRL takes the
1269
+ # text/causal-LM path (not the VLM processor path) for this multimodal checkpoint.
1270
+ # SFTTrainer.__init__ blocks for 10-15 min on first use (FA2 CUDA kernel JIT compilation);
1271
+ # without a heartbeat the control plane can't distinguish this from a real hang and may
1272
+ # recycle the worker. A daemon thread pings every 30s so the stall detector stays quiet.
1273
+ _sft_init_done = threading.Event()
1274
+
1275
+ def _sft_init_heartbeat() -> None:
1276
+ while not _sft_init_done.wait(30.0):
1277
+ heartbeat("sft_initializing", gpu=gpu_diagnostics())
1278
+
1279
+ _sft_init_hb = threading.Thread(target=_sft_init_heartbeat, daemon=True)
1280
+ _sft_init_hb.start()
1281
+ try:
1282
+ trainer = _SFT(
1283
+ model=model_id,
1284
+ args=cfg,
1285
+ train_dataset=ds,
1286
+ peft_config=make_lora(model_id),
1287
+ processing_class=tok,
1288
+ # Our block-diagonal collator on the SDPA-packing path; None elsewhere == TRL default.
1289
+ data_collator=_collator,
1290
+ callbacks=[make_sft_heartbeat_callback(), make_checkpoint_upload_callback()],
1291
+ )
1292
+ finally:
1293
+ _sft_init_done.set()
1294
+ # Apply chalk's gap-filling kernels (RoPE/LoRA-delta/embedding, like Liger) on the materialized
1295
+ # SFT trainer.model — chalk's apply patches the LIVE module, so it must run AFTER TRL builds the
1296
+ # model (chalk composes on top of TRL's Liger). No-op unless a FLASH_* kernel flag selects it and
1297
+ # freesolo-chalk is installed.
1298
+ _chalk_report = install_chalk_kernels(getattr(trainer, "model", None))
1299
+
1300
+ _reset_peak_gpu() # so peak_gpu_gb reflects the train loop (optimizer-state A/B is measurable)
1301
+ _gpu_sampler = _GpuPeakSampler().start() # true device peak incl. bnb managed optimizer pages
1302
+ t_train = time.time()
1303
+ with _sdpa_cudnn_ctx(_attn): # force cuDNN SDPA on sm120 (no-op otherwise)
1304
+ trainer.train(resume_from_checkpoint=resume_ckpt)
1305
+ train_wall = time.time() - t_train
1306
+ sft_peak_gpu_gb = _peak_gpu_gb()
1307
+ sft_device_peak_gpu_gb = _gpu_sampler.stop_gb()
1308
+
1309
+ adapter_dir = f"{out_dir}/adapter"
1310
+ trainer.model.save_pretrained(adapter_dir)
1311
+ tok.save_pretrained(adapter_dir)
1312
+ hf_upload_folder(adapter_dir, "adapter", required=True)
1313
+ heartbeat("sft_trained", train_wall=train_wall, gpu=gpu_diagnostics())
1314
+
1315
+ # count train tokens
1316
+ train_tokens = int(sum(len(tok(t["text"])["input_ids"]) for t in texts) * epochs)
1317
+
1318
+ # Write train metadata + the completion sentinel (metrics.json/DONE) for this phase.
1319
+ write_train_meta(
1320
+ phase="sft",
1321
+ adapter_dir=adapter_dir,
1322
+ model_id=model_id,
1323
+ train_wall=train_wall,
1324
+ setup_seconds=setup_seconds,
1325
+ train_tokens=train_tokens,
1326
+ generated_tokens=0,
1327
+ notes={
1328
+ "epochs": epochs,
1329
+ "resumed": bool(resume_ckpt),
1330
+ "download_seconds": download_seconds,
1331
+ "hf_transfer": os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", ""),
1332
+ "thinking": THINKING,
1333
+ # Persist the loss curve so a CONVERGENCE A/B (PiSSA / LoRA+ init, etc.) is measurable
1334
+ # without a checkpoint: trainer_state.json is only written on a save_step, and the
1335
+ # console is only uploaded on failure, so a short successful run otherwise drops its
1336
+ # loss history entirely.
1337
+ "loss_curve": _metric_curve(trainer, "loss"),
1338
+ # Peak torch-allocated GPU memory during the train loop (excludes bnb managed pages, so
1339
+ # it overstates the 8-bit saving — use device_peak_gpu_gb for the true footprint).
1340
+ "peak_gpu_gb": sft_peak_gpu_gb,
1341
+ # True peak device memory (total-free, incl. bnb managed optimizer pages): the honest
1342
+ # headline for the fp32-vs-8-bit LoRA+ optimizer A/B.
1343
+ "device_peak_gpu_gb": sft_device_peak_gpu_gb,
1344
+ # Report the optimizer ACTUALLY built on the trainer, not the planned class: if the
1345
+ # LoRA+ create_optimizer override failed, training falls back to TRL's configured
1346
+ # optimizer without LoRA+ grouping. loraplus_applied records which path actually ran.
1347
+ # Accelerate wraps the optimizer (AcceleratedOptimizer) under transformers 5.x, so unwrap
1348
+ # via `.optimizer` to record the underlying PagedAdamW8bit/AdamW the A/B cares about, not
1349
+ # the wrapper name.
1350
+ "loraplus_optim": (
1351
+ type(getattr(trainer.optimizer, "optimizer", trainer.optimizer)).__name__
1352
+ if getattr(trainer, "optimizer", None) is not None
1353
+ else loraplus_optimizer_cls(fused_optim_name())[0].__name__
1354
+ ),
1355
+ "loraplus_applied": getattr(trainer, "_loraplus_applied", False),
1356
+ # Which chalk gap-filling kernels actually ENGAGED (empty/None = chalk not installed or
1357
+ # every kernel fell back) — verifies the chalk stack without the console.
1358
+ "chalk_kernels": active_kernels(_chalk_report) or None,
1359
+ **wandb_run_info(),
1360
+ },
1361
+ )
1362
+ free_gpu(trainer)
1363
+
1364
+
1365
+ # ---------------------------------------------------------------------------
1366
+ # RL (GRPO) with TRL + colocated vLLM
1367
+ # ---------------------------------------------------------------------------
1368
def compute_grpo_batching(prompts_per_step: int, group_size: int, per_device_comps: int) -> dict:
    """Turn a desired number of prompts per optimizer step into TRL GRPO batch settings.

    TRL denominates GRPO batch sizes in **completions** (prompt-completion pairs), not in
    prompts: the unique prompts optimized per step are

        (per_device_train_batch_size * gradient_accumulation_steps * num_processes)
        / num_generations

    Optimizing ``prompts_per_step`` prompts therefore needs a global completion batch of
    ``prompts_per_step * group_size``. The per-device micro-batch is kept small (it, not
    grad-accum, sets peak VRAM) and the remainder goes into gradient accumulation.

    Historical bug this replaces: ``grad_accum = prompts_per_step // per_device`` treated the
    per-device batch as a *prompt* count and dropped the ``* group_size`` factor, so a run
    meant to use 64 prompts/step optimized only ``64 / group_size = 8`` prompts/step.

    Args:
        prompts_per_step: Intended unique prompts per optimizer step (clamped to >= 1).
        group_size: Completions sampled per prompt, i.e. TRL's ``num_generations``
            (clamped to >= 1).
        per_device_comps: Requested per-device completion micro-batch (clamped to >= 1); it is
            only ever lowered here, never raised.

    Returns:
        A dict with ``per_device_train_batch_size``, ``gradient_accumulation_steps``,
        ``generations_per_step``, ``unique_prompts_per_step`` and ``divisible_by_group``.
    """
    num_generations = max(1, int(group_size))
    wanted_prompts = max(1, int(prompts_per_step))
    # Total completions one optimizer step must cover.
    completions_needed = wanted_prompts * num_generations

    # The micro-batch may never exceed the completion target: a small prompts_per_step would
    # otherwise overshoot it (mirrors run_sft's `min(per_device_bs, effective_batch)`).
    micro_batch = min(max(1, int(per_device_comps)), completions_needed)

    # When the requested micro-batch does not divide the target, neither rounding of grad-accum
    # is right: floor silently optimizes FEWER prompts than requested, ceil asks TRL for MORE
    # unique prompts than the (already dataset-capped) prompts_per_step, which on a small
    # retained dataset yields no batches after the paid worker is provisioned. So walk down to
    # the largest divisor of the target that is <= the request: peak VRAM only goes down and
    # micro_batch * accum_steps == completions_needed exactly (16 with target 40 -> 10, accum 4).
    # The search always terminates because 1 divides everything.
    micro_batch = next(
        candidate
        for candidate in range(micro_batch, 0, -1)
        if completions_needed % candidate == 0
    )
    # micro_batch divides completions_needed and is <= it, so the quotient is already >= 1.
    accum_steps = completions_needed // micro_batch

    # The product equals completions_needed, which is a multiple of num_generations by
    # construction, so TRL's divisibility requirement holds without further rounding.
    completions_per_step = micro_batch * accum_steps
    return {
        "per_device_train_batch_size": micro_batch,
        "gradient_accumulation_steps": accum_steps,
        "generations_per_step": completions_per_step,
        "unique_prompts_per_step": completions_per_step // num_generations,
        # TRL requires the global completion batch be divisible by num_generations.
        "divisible_by_group": completions_per_step % num_generations == 0,
    }
1419
+
1420
+
1421
def resolve_grpo_prompts_per_step(requested: int, available_prompts: int) -> int:
    """Clamp the GRPO prompt batch so it never exceeds the retained dataset.

    When the configured prompt batch is larger than what survives prompt-budget filtering,
    TRL's GRPO dataloader can yield zero batches. That only surfaces late, as "There seems not
    to be a single sample in your epoch_iterator", after which the no-reward guard blames the
    wrong cause. Small smoke envs should still train, so every retained prompt is used per step
    rather than asking TRL for an impossible larger batch.

    Args:
        requested: Desired prompts per step; values below 1 are treated as 1.
        available_prompts: Number of prompts left after filtering.

    Returns:
        The smaller of the (clamped) request and ``available_prompts``.

    Raises:
        ValueError: If no training prompt was retained.
    """
    wanted = max(1, int(requested))
    retained = int(available_prompts)
    if retained < 1:
        raise ValueError("GRPO needs at least one retained training prompt")
    return retained if retained < wanted else wanted
1435
+
1436
+
1437
def build_grpo_prompt_dataset(prompts: list[dict]) -> tuple[list[dict], list]:
    """Split rollout prompts into Arrow-safe dataset rows and a parallel list of examples.

    ``Dataset.from_list`` has PyArrow infer ONE type per (nested) field across ALL rows. If the
    rich per-example record were embedded in the rows, a *valid* env whose ``info``/``metadata``
    legitimately mixes types per row would crash dataset construction with ``ArrowInvalid``,
    killing the RL phase at startup (after the paid GPU is provisioned) on input that passed
    offline single-example validation. Observed with ifeval-lite: ``metadata.param`` is an int
    word count on some rows and a required-word string such as ``'gentle'`` on others, so Arrow
    infers ``int64`` from the leading rows and fails on the first string.

    The rows therefore carry only trivially typed columns (the TRL-required ``prompt`` and an
    integer ``example_idx``), while the untouched example objects are returned alongside so
    ``reward_fn`` can map the index back to the EXACT record with no JSON/Arrow round-trip.

    Args:
        prompts: Records that each provide a ``"prompt"`` and an ``"example"`` entry.

    Returns:
        ``(rows, examples)`` where ``rows[i]["example_idx"] == i`` and ``examples[i]`` is the
        very object stored under ``prompts[i]["example"]``.
    """
    # Examples are collected by reference: nothing is copied or coerced.
    examples = [record["example"] for record in prompts]
    rows = []
    for position, record in enumerate(prompts):
        rows.append({"prompt": record["prompt"], "example_idx": position})
    return rows, examples
1456
+
1457
+
1458
# --- Tunables read by rl_per_device_comps (GRPO per-device completion micro-batch sizing) ---
# Hard ceiling on the per-device completion micro-batch when growing on a SHORT-seq run. MEASURED
# (RunPod, Qwen3.5-0.8B GRPO, group8, gsm8k, seq1024, 6 steps): trainer throughput rises from
# per_device 4 -> 8 (~+12%) and plateaus 8..16 (A100 80GB: 375/407/411 tok/s at pd 4/8/16), then
# REGRESSES at pd 32 (326 tok/s, -20%) as the larger forward stops buying MFU. So we never grow
# past the top of that plateau, even on a card with VRAM to spare. (Reward histories at pd 4 and
# 16 were identical -> per_device is a pure speed/VRAM knob, not an optimization change.)
_RL_PER_DEVICE_MAX = 16
# Reference sequence length the activation/VRAM divisor is calibrated at. The colocate activation
# peak grows with the training sequence length; the cap is scaled by seq_len/_RL_ACT_SEQ_REF so a
# short-seq run (the underfed regime) is allowed a proportionally bigger micro-batch.
_RL_ACT_SEQ_REF = 2048.0
# VRAM-per-(micro-batch element) divisor at the reference seq, normalized to ~2B width (1.41).
# MEASURED: Qwen3.5-2B group8 seq2048 OOMs a 32 GB card at per_device=8 but trains at 4 ->
# 32 / (7.5 * 1.0 * 1.0) = 4. (Unchanged from the historical colocate cap, so at/above the
# reference seq the value is byte-for-byte the old one — no regression.)
_RL_ACT_DIVISOR = 7.5
# Floor on the seq scale: caps how far a short sequence may grow the micro-batch. Set so the
# underfed case that motivated this — Qwen3.5-0.8B GRPO on a 24 GB card at seq<=1024 — lands on
# the MEASURED-SAFE per_device 8 (RunPod RTX 4090 24 GB: pd8 fits at 19.0 GB and is +12.6% over
# pd4, while the old seq-independent cap under-fed it at ~5; pd16 there would need ~27 GB -> OOM).
# 24 / (7.5 * (0.894/1.41) * 0.63) = 8.0. Bounds short-seq growth to ~1.6x the reference cap.
_RL_ACT_SEQ_SCALE_FLOOR = 0.63
# Clamp the seq scale at 1.0 (never ABOVE the reference). Combined with the short_seq growth gate,
# this makes a seq>=reference run byte-for-byte the old value: seq_scale==1.0 -> vram_cap == the
# old colocate cap, and the ceiling falls back to the historical default, so min(default, ...) is
# exactly what the old code returned. We deliberately do NOT tighten long-seq below the historical
# value (grad checkpointing makes activations sub-linear in seq there, so the linear model would
# over-cap), nor grow above it (unvalidated — the regression is in tokens-in-flight = pd x seq).
_RL_ACT_SEQ_SCALE_CEIL = 1.0
1487
+
1488
+
1489
def rl_per_device_comps(
    completion_len: int = 0,
    vocab: int = 248_320,
    *,
    use_vllm: bool = True,
    params_b: float | None = None,
    seq_len: int = 0,
) -> int:
    """Choose the per-device *completion* micro-batch for GRPO (TRL counts completions).

    The micro-batch, not grad-accum, drives both peak trainer VRAM and the trainer step's MFU:
    a larger one means fewer, bigger GEMMs at the same effective batch (compute_grpo_batching
    moves the remainder into grad-accum, so only speed/VRAM change, not the optimization).
    MEASURED on RunPod (Qwen3.5-0.8B GRPO, group8, seq1024): the old seq-independent colocate
    cap under-fed a 24 GB card at per_device ~5, whereas per_device 8 fits (19.0 GB) and is
    +12.6% throughput; on an 80 GB card throughput plateaus at per_device 8..16 and regresses by
    per_device 32. Hence a SHORT-seq run may grow into the card's VRAM headroom up to the
    plateau ceiling.

    Growth is GATED to seq < the reference. At/above the reference the result is byte-for-byte
    the historical one: larger per_device at long context is unvalidated and the regression is
    driven by tokens-in-flight (per_device x seq), which a fixed per_device ceiling cannot see.

    Two upper bounds apply:

    * **logits budget (6 GB)**: a HARD correctness cap. The logprob pass can materialize fp32
      logits of shape [per_device, completion_len, vocab]; at Qwen3.5's ~248k vocab that is huge
      for long completions (8 x 4096 tok x 248k x 4 B = ~30 GiB). Liger normally fuses these
      away; this remains the safety net for the fallback path.
    * **activation/VRAM cap**: the per-device forward holds attention/activation memory (the
      Qwen3.5 GDN/FLA kernels peak per micro-batch even with grad checkpointing) that the logits
      term cannot see and Liger does not touch. It is calibrated on the live card's VRAM, model
      width (~sqrt(params)) and the training sequence length. MEASURED at seq_ref=2048:
      Qwen3.5-2B (width ~1.41) group8 OOMs a 32 GB card at per_device=8 but trains at 4.

    Without a VRAM signal (no CUDA device, ``use_vllm`` false, or a failed probe) the result is
    the conservative historical default (8, or 2 with thinking) bounded by the logits budget.

    Args:
        completion_len: Completion token budget; ``<= 0`` disables the logits-derived bound.
        vocab: Vocabulary size used to size the fp32 logits tensor.
        use_vllm: When false, the VRAM probe is skipped entirely.
        params_b: Model size in billions of parameters; falsy means "assume width 1.41".
        seq_len: Training sequence length; falsy means "assume the reference length".

    Returns:
        A micro-batch size that is always >= 1.
    """
    baseline = 2 if THINKING else 8

    # Logits budget: hard upper bound on the fp32 [per_device, completion, vocab] tensor.
    if completion_len > 0:
        bytes_per_completion = max(1, completion_len) * vocab * 4
        logits_limit = max(1, int(6.0e9 / bytes_per_completion))
    else:
        logits_limit = _RL_PER_DEVICE_MAX

    # An unset sequence length is treated as the reference length, i.e. NOT short.
    effective_seq = seq_len or _RL_ACT_SEQ_REF
    is_short_seq = effective_seq < _RL_ACT_SEQ_REF

    # Activation/VRAM cap: only computable on a live card. It caps DOWN (big model / small card
    # / long seq) and, on a short-seq run, is what lets the micro-batch GROW into spare VRAM.
    activation_limit = None
    if use_vllm:
        try:
            import torch

            if torch.cuda.is_available():
                card_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                model_width = (max(float(params_b), 0.1) ** 0.5) if params_b else 1.41
                # Clamp the linear seq scaling into [floor, ceil].
                seq_factor = min(
                    _RL_ACT_SEQ_SCALE_CEIL,
                    max(_RL_ACT_SEQ_SCALE_FLOOR, effective_seq / _RL_ACT_SEQ_REF),
                )
                gb_per_element = _RL_ACT_DIVISOR * (model_width / 1.41) * seq_factor
                activation_limit = max(1, int(card_gb / gb_per_element))
        except Exception as e:
            # Best-effort probe: a failure just leaves the VRAM cap unknown.
            print("rl_per_device_comps colocate cap probe failed (keeping logits cap):", e)

    if activation_limit is None:
        # No live card (allocator / offline / unit tests): conservative default, logits-bounded.
        return max(1, min(baseline, logits_limit))

    # Short seq -> grow into measured VRAM headroom up to the plateau ceiling. At/above the
    # reference seq the ceiling stays at the historical default and seq_factor is clamped to
    # 1.0, so the outcome matches the old value exactly.
    #
    # THINKING runs never take the growth path: their long completions cost activation/logprob
    # memory the prompt-only `seq_len` gate cannot see, so raising the ceiling to
    # _RL_PER_DEVICE_MAX would override the conservative thinking default (2) and risk OOM.
    growth_allowed = is_short_seq and not THINKING
    upper_bound = _RL_PER_DEVICE_MAX if growth_allowed else baseline
    return max(1, min(upper_bound, logits_limit, activation_limit))
1582
+
1583
+
1584
# Minimum seconds between GPU diagnostics snapshots attached to per-step heartbeats (used by
# both the reward and the SFT heartbeat callbacks below).
_STEP_GPU_DIAG_INTERVAL_S = 300.0
# Minimum seconds between two "sft_step" heartbeats; log events arriving sooner are dropped.
_SFT_HEARTBEAT_INTERVAL_S = 60.0
1586
+
1587
+
1588
def make_reward_heartbeat_callback():
    """Build a trainer callback that forwards each logged mean reward as an ``rl_step`` heartbeat.

    This gives the worker a live RL signal (there is no pod log API) and accumulates every
    reward in ``reward_history`` on the returned instance. ``transformers`` is imported lazily
    so the module can be imported without it installed.

    Returns:
        A ``TrainerCallback`` instance exposing ``reward_history`` (list of floats).
    """
    from transformers import TrainerCallback

    class _RewardHeartbeat(TrainerCallback):
        def __init__(self):
            self.reward_history = []
            # 0.0 doubles as "no GPU diagnostics sent yet".
            self.last_gpu_diag_at = 0.0

        def on_log(self, args, state, control, logs=None, **kwargs):
            raw_reward = logs.get("reward") if logs else None
            if raw_reward is None:
                return
            try:
                reward = float(raw_reward)
            except (TypeError, ValueError):
                # Non-numeric reward entry: nothing to report for this log event.
                return

            history = self.reward_history
            history.append(reward)
            payload = {
                # Falls back to the number of rewards seen when the state has no global_step.
                "step": int(getattr(state, "global_step", len(history))),
                "reward": reward,
                "reward_last": history[-8:],
            }

            # GPU diagnostics ride along on the first heartbeat and then at most once per
            # _STEP_GPU_DIAG_INTERVAL_S.
            now = time.monotonic()
            diagnostics_due = (
                self.last_gpu_diag_at == 0.0
                or now - self.last_gpu_diag_at >= _STEP_GPU_DIAG_INTERVAL_S
            )
            if diagnostics_due:
                payload["gpu"] = gpu_diagnostics()
                self.last_gpu_diag_at = now
            heartbeat("rl_step", **payload)

    return _RewardHeartbeat()
1626
+
1627
+
1628
def make_sft_heartbeat_callback():
    """Build a trainer callback that relays SFT log events as rate-limited ``sft_step`` heartbeats.

    Keeps a run from being silent between model load and completion. ``transformers`` is
    imported lazily inside the factory.

    Returns:
        A ``TrainerCallback`` instance.
    """
    from transformers import TrainerCallback

    class _SFTHeartbeat(TrainerCallback):
        def __init__(self):
            # 0.0 means "never sent" for both timestamps.
            self.last_heartbeat_at = 0.0
            self.last_gpu_diag_at = 0.0

        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs:
                return
            now = time.monotonic()
            previous = self.last_heartbeat_at
            # Rate limit: drop log events that arrive within the heartbeat interval.
            if previous and now - previous < _SFT_HEARTBEAT_INTERVAL_S:
                return
            self.last_heartbeat_at = now

            payload = {"step": int(getattr(state, "global_step", 0) or 0)}
            for field in ("epoch", "loss", "grad_norm", "learning_rate"):
                payload[field] = logs.get(field)

            # GPU diagnostics are attached on the first heartbeat and then at most once per
            # _STEP_GPU_DIAG_INTERVAL_S.
            diagnostics_due = (
                self.last_gpu_diag_at == 0.0
                or now - self.last_gpu_diag_at >= _STEP_GPU_DIAG_INTERVAL_S
            )
            if diagnostics_due:
                payload["gpu"] = gpu_diagnostics()
                self.last_gpu_diag_at = now

            # Only fields that actually carry a value are sent.
            heartbeat("sft_step", **{k: v for k, v in payload.items() if v is not None})

    return _SFTHeartbeat()
1660
+
1661
+
1662
def grpo_overrides() -> dict:
    """Collect the GRPO recipe knobs that the job spec's ``[train]`` table sets explicitly.

    Knobs: group_size, temperature, max_tokens (completion budget), kl_penalty_coef (the KL
    beta), advantage_clip (centered-advantage clip) and thinking_length_penalty_coef (a
    per-<think>-token reward deduction). They live in ``[train]`` (``TrainSpec``), NOT in
    ``[environment.params]``, which is forwarded verbatim to the Freesolo env loader.

    Returns:
        A dict holding only the knobs whose value is not ``None``, so the recipe default
        applies downstream for the rest; empty when there is no job spec.
    """
    if not JOB_SPEC:
        return {}
    train = JOB_SPEC.train
    knob_names = (
        "group_size",
        "temperature",
        "max_tokens",
        "kl_penalty_coef",
        "advantage_clip",
        "thinking_length_penalty_coef",
    )
    overrides = {}
    for knob in knob_names:
        value = getattr(train, knob)
        if value is not None:
            overrides[knob] = value
    return overrides
1682
+
1683
+
1684
+ def think_token_count(completion: str | None, tokenizer) -> int:
1685
+ """Number of tokens inside the completion's <think>...</think> span (0 if none).
1686
+
1687
+ Used for the thinking-length reward deduction: long reasoning is penalized in
1688
+ proportion to the tokens it spent, mirroring the SDK's thinking_length_penalty_coef.
1689
+ """
1690
+ if not completion or "<think>" not in completion:
1691
+ return 0
1692
+ after = completion.split("<think>", 1)[1]
1693
+ think_text = after.split("</think>", 1)[0] if "</think>" in after else after
1694
+ if not think_text:
1695
+ return 0
1696
+ return len(tokenizer(think_text, add_special_tokens=False)["input_ids"])
1697
+
1698
+
1699
def _init_adapter_model(model_id: str):
    """Resolve what GRPO should train: a fresh LoRA, or the ``train.init_from_adapter`` adapter.

    Returns:
        ``(model_id, make_lora(model_id))`` when no adapter prefix is configured, so TRL builds
        the model from the string id and attaches a fresh LoRA. Otherwise
        ``(peft_model, None)``: the base model with the downloaded adapter loaded as a trainable
        PeftModel, so TRL trains the LOADED adapter (peft_config=None) instead of a new one.

    Raises:
        RuntimeError: If an adapter prefix is configured but the download yields nothing.
    """
    adapter_prefix = JOB_SPEC.train.init_from_adapter if JOB_SPEC else ""
    if not adapter_prefix:
        return model_id, make_lora(model_id)

    adapter_dir = _download_adapter(adapter_prefix)
    if not adapter_dir:
        # The user explicitly asked GRPO to continue from this adapter; quietly falling back to
        # a fresh base-model LoRA would spend a full paid run on the wrong starting point.
        raise RuntimeError(
            f"train.init_from_adapter={adapter_prefix!r} could not be downloaded from the artifact "
            "store (wrong/missing prefix or no access); refusing to silently start GRPO from "
            "the base model. Fix the adapter prefix / HF credentials, or omit "
            "init_from_adapter to train a fresh LoRA."
        )

    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    print(f"[init-adapter] initializing LoRA from {adapter_prefix}")
    # VL checkpoints (Qwen3.5/3.6): SFT saved the adapter against the FULL multimodal model (keys
    # under ``base_model.model.model.language_model.layers.*``) while the base is loaded here via
    # AutoModelForCausalLM (text-only tree, ``base_model.model.model.layers.*``). Stripping the
    # ``.language_model.`` infix on disk lets PeftModel.from_pretrained match the SFT keys;
    # otherwise peft only WARNS about missing keys and silently trains a fresh LoRA. No-op for
    # non-VL checkpoints. See flash/engine/worker/lora.py.
    remap_vl_adapter_dir(adapter_dir, model_id)

    attn_impl = optimal_attn_impl()
    # NOTE(review): trust_remote_code=True executes code shipped with the checkpoint.
    load_kwargs = {"dtype": "bfloat16", "trust_remote_code": True}
    if attn_impl:
        load_kwargs["attn_implementation"] = attn_impl
    base_model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    peft_model = PeftModel.from_pretrained(base_model, adapter_dir, is_trainable=True)

    # Fail loudly if the adapter did not actually apply (a key mismatch would otherwise restart
    # GRPO from the base model). from_pretrained loads with load_state_dict(strict=False) and
    # only WARNS on a mismatch, discarding the load result, so load_adapter is run again to
    # CAPTURE which keys matched. peft injects the LoRA modules from target_modules BEFORE
    # loading weights, so a module-count check alone cannot see a silent weight discard. The
    # reload is idempotent: same weights into the same "default" adapter.
    # The reload must mirror from_pretrained's key_mapping: for models defining a
    # ``_checkpoint_conversion_mapping`` (renamed-arch checkpoints) the adapter keys are remapped
    # before loading, and skipping that here would misread valid keys as mismatched and falsely
    # abort. peft reads the mapping off the base model.
    key_mapping = getattr(base_model, "_checkpoint_conversion_mapping", None)
    load_result = peft_model.load_adapter(
        adapter_dir, adapter_name="default", is_trainable=True, key_mapping=key_mapping
    )
    assert_adapter_load_clean(load_result, model_id)
    assert_lora_applied(peft_model, model_id)
    assert_adapter_delta_nonzero(peft_model, model_id)
    return peft_model, None
1756
+
1757
+
1758
+ def _grpo_resume_already_complete(resume_ckpt, target_steps: int, steps_run: int) -> bool:
1759
+ """True when this worker resumed a checkpoint that already reached the target step count.
1760
+
1761
+ Such a resume legitimately performs ZERO new optimizer steps (so the fresh hb_cb has an empty
1762
+ reward_history) yet the policy IS fully trained — it must NOT be flagged as a no-op failure.
1763
+ """
1764
+ return bool(resume_ckpt) and target_steps > 0 and steps_run >= target_steps
1765
+
1766
+
1767
def _grpo_is_no_op_failure(reward_history, resume_ckpt, target_steps: int, steps_run: int) -> bool:
    """Decide whether a GRPO run trained nothing and so must fail loudly rather than report done.

    An empty ``reward_history`` means the reward callback never fired: the rollout scored
    nothing (for example vLLM silently returning no completions), so no real training took
    place. The one exception is a resume that had already reached its target steps (see
    ``_grpo_resume_already_complete``); its fresh history is empty but the policy is fully
    trained, so it is not a failure.

    Args:
        reward_history: Rewards recorded during this run; empty/falsy means none were recorded.
        resume_ckpt: The checkpoint this worker resumed from, if any.
        target_steps: Number of steps the run was asked to reach.
        steps_run: Number of steps the run has reached.

    Returns:
        True when the run should be treated as a no-op failure, False otherwise.
    """
    if not reward_history:
        # Nothing was scored: only an already-finished resume excuses the empty history.
        finished_before_resume = _grpo_resume_already_complete(
            resume_ckpt, target_steps, steps_run
        )
        return not finished_before_resume
    # At least one reward was recorded, so training really ran.
    return False
1778
+
1779
+
1780
+ def run_rl():
1781
+ from datasets import Dataset
1782
+ from transformers import AutoTokenizer
1783
+ from trl import GRPOConfig, GRPOTrainer
1784
+
1785
+ env = require_active_env() # fail loudly (not AttributeError: NoneType) on the no-JobSpec path
1786
+ t_start = time.time()
1787
+ heartbeat("rl_start", gpu=gpu_diagnostics())
1788
+ # GRPO rollout strategy by env shape (trl 1.6 adds the hooks these need):
1789
+ # * single-turn -> TRL single-shot generation + per-completion reward (below);
1790
+ # * tool (ToolEnv & subs:
1791
+ # Stateful/Sandbox/Python) -> TRL drives the tool-call loop natively via
1792
+ # GRPOTrainer(tools=...) (it parses tool calls, executes the tools, and masks the
1793
+ # tool-result tokens itself); the reward scores the full transcript;
1794
+ # * pure multi-turn -> a custom rollout_func (flash.engine.multiturn_rollout)
1795
+ # drives THIS env's turn loop on the colocate engine and returns the interleaved
1796
+ # token sequence with an env_mask so only the model's tokens are trained.
1797
+ is_tool_env = getattr(env, "is_tool_env", False)
1798
+ is_multi_turn = getattr(env, "multi_turn", False)
1799
+ conversational = is_multi_turn # message-list prompts (tool + pure multi-turn) vs strings
1800
+ if is_multi_turn:
1801
+ # The Liger fused GRPO loss (use_liger_kernel, kept ON to avoid the 248k-vocab fp32-logits
1802
+ # OOM) torch.compiles, and on the VARIABLE-length multi-turn completions its dynamo guard
1803
+ # build trips a torch 2.10 bug (symbol_to_source IndexError) that crashes the first
1804
+ # training step. Let dynamo FALL BACK TO EAGER for the offending function instead of
1805
+ # raising. This is NOT `TORCHDYNAMO_DISABLE` (which would also break the colocate vLLM
1806
+ # engine's required compilation) — dynamo stays enabled; only erroring graphs run eager.
1807
+ try:
1808
+ import torch._dynamo
1809
+
1810
+ torch._dynamo.config.suppress_errors = True
1811
+ print("[rl] multi-turn: torch._dynamo suppress_errors=True (Liger loss falls back to eager on dynamic shapes)")
1812
+ except Exception as exc: # never let a torch internals change block the run
1813
+ print(f"[rl] could not set torch._dynamo.suppress_errors: {exc!r}")
1814
+ wait_for_gpu()
1815
+ setup_perf_backends()
1816
+ model_id = JOB_SPEC.model if JOB_SPEC else RECIPE.hf_model_id
1817
+ download_seconds = prefetch_model(model_id)
1818
+ rl = RECIPE.rl
1819
+ # Steps come from the run's [train] steps (already in JOB_SPEC), else the recipe default.
1820
+ steps = int(
1821
+ JOB_SPEC.train.steps if JOB_SPEC and JOB_SPEC.train.steps is not None else rl.num_steps
1822
+ )
1823
+ # Throughput/quality knobs: the number of prompts optimized per step, completions per
1824
+ # prompt, and whether vLLM offloads weights between steps. Sleep mode frees memory for the
1825
+ # optimizer but reloads ~weights each step (a large per-step cost); it's gated OFF by model
1826
+ # size when both the policy and rollout engine fit resident.
1827
+ gcfg = grpo_overrides()
1828
+ _t = JOB_SPEC.train if JOB_SPEC else None
1829
+ # batch_size = prompts per optimizer step for GRPO.
1830
+ # prompts per optimizer step = the run config's [train].batch_size (recipe default otherwise).
1831
+ prompts_per_step = int(
1832
+ _t.batch_size if _t and _t.batch_size is not None else rl.prompts_per_step
1833
+ )
1834
+ group_size = int(gcfg.get("group_size") or rl.group_size)
1835
+ # temperature: explicit None check, NOT `or` — a configured 0.0 (greedy/deterministic
1836
+ # rollouts) must be honored, not fall back to the recipe sampling temperature.
1837
+ _gcfg_temp = gcfg.get("temperature")
1838
+ _temperature = float(_gcfg_temp if _gcfg_temp is not None else rl.sampling_temperature)
1839
+ _kl_beta = float(gcfg.get("kl_penalty_coef") or 0.0)
1840
+ _adv_clip = float(gcfg.get("advantage_clip") or 0.0)
1841
+ _think_penalty = float(gcfg.get("thinking_length_penalty_coef") or 0.0)
1842
+ # vLLM sleep mode offloads the rollout engine's weights between steps to free memory for the
1843
+ # optimizer, but reloading each step is a large per-step cost (PR #174 measured ~2-2.6x faster
1844
+ # GRPO with it OFF on models that fit) AND on the large-model GRPO path the sleep/wake cycle
1845
+ # STALLS the colocated rollout (the rollout emits unparseable completions, then the worker
1846
+ # hangs mid-training). So enable sleep only when the run genuinely can't fit RESIDENT on THIS
1847
+ # card: large/long-context AND the policy + colocated rollout engine + training peak don't fit
1848
+ # on the live GPU. When they fit (the common allocator-sized case), skip sleep entirely.
1849
+ _grpo_ctx = int(_t.max_length if _t and _t.max_length else 0)
1850
+ _card_vram_gb = 0.0
1851
+ try:
1852
+ import torch as _torch_card
1853
+
1854
+ if _torch_card.cuda.is_available():
1855
+ # Binary GiB (/(1024**3)), NOT decimal GB (/1e9 over-reports ~7%): grpo_fits_resident's
1856
+ # VRAM estimate is in GiB, so a decimal card size would make a marginal card look big
1857
+ # enough to fit resident and wrongly disable sleep, risking OOM.
1858
+ _card_vram_gb = _torch_card.cuda.get_device_properties(0).total_memory / (1024**3)
1859
+ except Exception as _e:
1860
+ print("[rl] card VRAM probe failed (sleep-mode gate falls back to size/context):", _e)
1861
+ _lora_rank = int(_t.lora_rank) if _t and _t.lora_rank else 32
1862
+ sleep_mode = grpo_sleep_mode(
1863
+ model_id,
1864
+ max_length=_grpo_ctx,
1865
+ group_size=group_size,
1866
+ max_tokens=gcfg.get("max_tokens"),
1867
+ lora_rank=_lora_rank,
1868
+ thinking=THINKING,
1869
+ card_vram_gb=_card_vram_gb,
1870
+ )
1871
+ print(
1872
+ f"[rl] vLLM sleep mode = {sleep_mode} "
1873
+ f"(model={model_id}, ctx={_grpo_ctx}, card={_card_vram_gb:.0f}GB)"
1874
+ )
1875
+ # Rollout backend: always colocated vLLM (fast). The whole supported catalog runs GRPO with
1876
+ # colocated vLLM; there is no transformers-generation fallback.
1877
+ use_vllm = True
1878
+ print("[rl] rollout backend: colocated vLLM")
1879
+ from flash.catalog import MODELS as _CATALOG
1880
+
1881
+ _info = _CATALOG.get(model_id)
1882
+ tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
1883
+ if tok.pad_token is None:
1884
+ tok.pad_token = tok.eos_token
1885
+
1886
+ train = env.dataset()
1887
+ rng = random.Random(SEED)
1888
+ rng.shuffle(train)
1889
+ if conversational:
1890
+ # Message-list prompts so the chat template applies roles + (for tool envs) the tool
1891
+ # schemas; per-turn length is managed by the tool loop / rollout_func, not a flat budget.
1892
+ prompts = [{"prompt": env.prompt_messages(ex), "example": ex} for ex in train]
1893
+ else:
1894
+ prompts = [{"prompt": render_prompt(tok, ex), "example": ex} for ex in train]
1895
+ # The colocated vLLM engine's model length is the hard cap on prompt+completion at
1896
+ # rollout. Size it from [train].max_length and derive the prompt budget from it so a
1897
+ # bigger engine or a smaller completion automatically admits longer prompts (rather than
1898
+ # a fixed rl.max_prompt_len that no env override could lift).
1899
+ _max_completion = int(
1900
+ gcfg.get("max_tokens")
1901
+ or (rl.max_completion_len_thinking if THINKING else rl.max_completion_len)
1902
+ )
1903
+ # Engine context = the run's [train].max_length (so a long-context GRPO config sized/paid for
1904
+ # by the allocator actually RUNS at that length), else the recipe default. Without the
1905
+ # train.max_length fallback the allocator provisions a big GPU for the long context but the
1906
+ # engine runs short — paying for headroom we never use.
1907
+ _train_ctx = _t.max_length if (_t and _t.max_length) else 0
1908
+ vllm_max_len = int(_train_ctx or max(1024, rl.max_prompt_len + _max_completion))
1909
+ # The engine must fit completion + at least some prompt. If [train].max_length is below the
1910
+ # completion budget, no prompt can ever fit — fail fast here rather than passing a 1-token
1911
+ # budget that lets prompts through and then OOMs/overflows mid-rollout.
1912
+ if vllm_max_len <= _max_completion:
1913
+ raise ValueError(
1914
+ f"engine length {vllm_max_len} leaves no room for the {_max_completion}-token "
1915
+ "completion; raise [train].max_length or lower [train].max_tokens"
1916
+ )
1917
+ prompt_budget = vllm_max_len - _max_completion
1918
+
1919
+ # TRL 1.5's GRPOConfig has no max_prompt_length and does NOT truncate prompts, so a prompt
1920
+ # that leaves no room for the completion within the engine length would fail mid-rollout
1921
+ # AFTER the paid worker is provisioned. Drop prompts that don't fit the budget up front.
1922
+ # render_prompt returns an apply_chat_template(tokenize=False) string that already carries
1923
+ # the special tokens, so tokenize with add_special_tokens=False (the default re-adds
1924
+ # BOS/EOS and over-counts).
1925
+ # Drop prompts that leave no room for the completion within the engine length — applies to
1926
+ # BOTH single-turn (string prompts) and conversational (message-list) prompts, so a tool /
1927
+ # multi-turn rollout can't overflow the colocate engine mid-generation. Conversational
1928
+ # prompts are length-checked via the chat template (with the generation prompt).
1929
+ # Tool schemas TRL injects into the prompt for native tools= GRPO — include them in the
1930
+ # budget for a tool env so a prompt isn't undercounted at filter time vs. rollout time.
1931
+ _oai_tools = (
1932
+ getattr(getattr(env, "_env", None), "oai_tools", None) if is_tool_env else None
1933
+ )
1934
+
1935
+ def _prompt_tokens(p) -> int:
1936
+ if conversational:
1937
+ # Render to text then tokenize — the SAME path the rollout uses — so the filter
1938
+ # count matches the rollout's count (avoids a tokenize=True vs text mismatch).
1939
+ kw = {"tools": _oai_tools} if _oai_tools else {}
1940
+ try:
1941
+ text = tok.apply_chat_template(
1942
+ p["prompt"],
1943
+ add_generation_prompt=True,
1944
+ tokenize=False,
1945
+ enable_thinking=THINKING,
1946
+ **kw,
1947
+ )
1948
+ except Exception as exc:
1949
+ # Fail fast WITH context: a tokenizer/template incompatibility would render every
1950
+ # prompt uncountable and otherwise surface as a misleading "all prompts exceed
1951
+ # budget" — raise so the model/template can be fixed before a paid run trains on
1952
+ # a degenerate dataset.
1953
+ raise RuntimeError(
1954
+ "failed to render a conversational prompt with this model's chat template "
1955
+ f"(fix the model/template or the env's prompts): {exc}"
1956
+ ) from exc
1957
+ return len(tok(text, add_special_tokens=False).input_ids)
1958
+ return len(tok(p["prompt"], add_special_tokens=False).input_ids)
1959
+
1960
+ kept = [p for p in prompts if 0 < _prompt_tokens(p) <= prompt_budget]
1961
+ if len(kept) < len(prompts):
1962
+ print(
1963
+ f"[rl] dropped {len(prompts) - len(kept)} prompts over the {prompt_budget}-token "
1964
+ f"prompt budget (engine {vllm_max_len} - completion {_max_completion})"
1965
+ )
1966
+ if not kept:
1967
+ raise ValueError(
1968
+ f"every training prompt exceeds the {prompt_budget}-token prompt budget (engine "
1969
+ f"{vllm_max_len} - completion {_max_completion}); raise [train].max_length, lower "
1970
+ "[train].max_tokens, or shorten the environment's prompts"
1971
+ )
1972
+ prompts = kept
1973
+ resolved_prompts_per_step = resolve_grpo_prompts_per_step(prompts_per_step, len(prompts))
1974
+ if resolved_prompts_per_step != prompts_per_step:
1975
+ print(
1976
+ f"[rl] lowering prompts_per_step from {prompts_per_step} to "
1977
+ f"{resolved_prompts_per_step}: only {len(prompts)} prompt(s) fit after filtering"
1978
+ )
1979
+ prompts_per_step = resolved_prompts_per_step
1980
+ # Carry a stable integer index instead of the rich record so PyArrow can't crash on an env whose
1981
+ # per-row info/metadata legitimately mixes types (see build_grpo_prompt_dataset). reward_fn maps
1982
+ # the index back to the original example object below.
1983
+ ds_rows, rollout_examples = build_grpo_prompt_dataset(prompts)
1984
+ ds = Dataset.from_list(ds_rows)
1985
+
1986
+ def reward_fn(completions, **kwargs):
1987
+ # rollout_func (pure multi-turn) path: the per-rollout reward is computed by the env
1988
+ # during the rollout and forwarded as the "reward" extra field — pass it through.
1989
+ if kwargs.get("reward") is not None:
1990
+ return [float(r) for r in kwargs["reward"]]
1991
+ # Score the <think>-stripped text (graded_text), then — datums parity — deduct
1992
+ # the thinking-length penalty computed from the RAW completion's <think> span.
1993
+ # The dataset carries example_idx (not the record); map each back to its original object.
1994
+ # Fail LOUD if TRL stops forwarding example_idx (column pruning / a TRL change): defaulting to
1995
+ # [] would zip to ZERO examples -> empty rewards -> silent no-op / broken training (issues
1996
+ # #206 / #210). A reward over the wrong/empty examples is far worse than crashing the run.
1997
+ example_idx = kwargs.get("example_idx")
1998
+ if example_idx is None:
1999
+ raise RuntimeError(
2000
+ "GRPO reward_fn received no 'example_idx' column from TRL — the reward cannot be "
2001
+ "mapped back to its training example, so every reward would be empty/misaligned "
2002
+ f"(got kwargs keys {sorted(kwargs)}). This usually means TRL dropped the dataset "
2003
+ "column (remove_unused_columns / a TRL version change); the run is aborted rather "
2004
+ "than silently training on no signal."
2005
+ )
2006
+ if len(example_idx) != len(completions):
2007
+ raise RuntimeError(
2008
+ f"GRPO reward_fn example_idx/completions length mismatch "
2009
+ f"({len(example_idx)} vs {len(completions)}) — rewards would be misaligned with "
2010
+ "the sampled completions; aborting rather than training on a shifted reward signal."
2011
+ )
2012
+ examples = [rollout_examples[int(i)] for i in example_idx]
2013
+ rewards = []
2014
+ debug_rows = []
2015
+ for idx, (comp, ex) in enumerate(zip(completions, examples, strict=False)):
2016
+ if isinstance(comp, list):
2017
+ # Tool / conversational transcript (TRL passes a list of messages): score the
2018
+ # whole transcript via the environment reward (no <think> stripping —
2019
+ # multi-turn content).
2020
+ r = env.reward_from_messages(comp, ex)
2021
+ rewards.append(r)
2022
+ continue
2023
+ graded = graded_text(comp)
2024
+ breakdown = None
2025
+ if hasattr(env, "scores_breakdown"):
2026
+ breakdown = env.scores_breakdown(graded, ex)
2027
+ r = float(breakdown.get("total", 0.0))
2028
+ else:
2029
+ r = env.reward(graded, ex)
2030
+ if _think_penalty > 0 and THINKING:
2031
+ r -= _think_penalty * think_token_count(comp, tok)
2032
+ rewards.append(r)
2033
+ if idx < 8:
2034
+ debug_rows.append(
2035
+ {
2036
+ "ts": time.time(),
2037
+ "attempt": ATTEMPT,
2038
+ "run_id": RUN_ID,
2039
+ "mode": RUN_MODE,
2040
+ "seed": SEED,
2041
+ "reward": r,
2042
+ "breakdown": breakdown,
2043
+ "completion_prefix": str(comp or "")[:1000],
2044
+ "graded_prefix": str(graded or "")[:1000],
2045
+ "example_id": (ex or {}).get("id") if isinstance(ex, dict) else None,
2046
+ "example_input": (ex or {}).get("input") if isinstance(ex, dict) else None,
2047
+ }
2048
+ )
2049
+ upload_debug_jsonl("reward_debug.jsonl", debug_rows)
2050
+ return rewards
2051
+
2052
+ # TRL's per_device_train_batch_size counts COMPLETIONS, not prompts. Size grad-accum so
2053
+ # the global completion batch = prompts_per_step * group_size, i.e. each optimizer step
2054
+ # actually optimizes `prompts_per_step` prompts. The per-device *completion* micro-batch
2055
+ # is the VRAM knob (thinking-aware; see rl_per_device_comps).
2056
+ from flash.engine.vram import resolve_params_b
2057
+
2058
+ # Open-model (uncataloged) GRPO: size the colocate activation cap from the catalog stat, else
2059
+ # the HF safetensors metadata (no download). Without a real count a large open model falls back
2060
+ # to the ~2B-width default in rl_per_device_comps and gets too LOOSE a per-device cap ->
2061
+ # colocate OOM. Best-effort: stays None offline, keeping prior behavior.
2062
+ _params_b = resolve_params_b(model_id)
2063
+ from flash.catalog import vocab_size_for
2064
+
2065
+ # Per-device completion-logits cap: a multi-turn rollout accumulates a FULL transcript (model
2066
+ # turns + masked env tokens) up to the engine context — far longer than the single-turn per-turn
2067
+ # budget `_max_completion` — and the trainer's logprob forward processes that whole completion.
2068
+ # So size the fp32 [per_device, completion, vocab] cap against the WORST-CASE multi-turn
2069
+ # completion length (the engine context) instead of `_max_completion`, or a long multi-turn run
2070
+ # OOMs the trainer forward. Single-turn keeps `_max_completion` (its true completion length).
2071
+ _cap_completion_len = vllm_max_len if is_multi_turn else _max_completion
2072
+ per_device_comps = rl_per_device_comps(
2073
+ _cap_completion_len,
2074
+ vocab=vocab_size_for(model_id),
2075
+ use_vllm=use_vllm,
2076
+ params_b=_params_b,
2077
+ # The trainer forward processes prompt+completion up to the engine context, so the
2078
+ # activation/VRAM cap is sized against the worst-case training sequence length.
2079
+ seq_len=vllm_max_len,
2080
+ )
2081
+ if is_multi_turn and _cap_completion_len != _max_completion:
2082
+ print(
2083
+ f"[rl] multi-turn: sizing the per-device logits cap against the full transcript length "
2084
+ f"{_cap_completion_len} (engine context), not the per-turn budget {_max_completion}"
2085
+ )
2086
+ batching = compute_grpo_batching(prompts_per_step, group_size, per_device_comps)
2087
+ if not batching["divisible_by_group"]:
2088
+ print(
2089
+ "WARN: generation batch not divisible by group size; check prompts_per_step/group_size"
2090
+ )
2091
+ print(
2092
+ f"[rl] GRPO batching: per_device={batching['per_device_train_batch_size']} "
2093
+ f"grad_accum={batching['gradient_accumulation_steps']} "
2094
+ f"generations/step={batching['generations_per_step']} "
2095
+ f"unique_prompts/step={batching['unique_prompts_per_step']} "
2096
+ f"(target prompts/step={prompts_per_step}, group={group_size}, sleep={sleep_mode})"
2097
+ )
2098
+ out_dir = f"/tmp/rl_seed{SEED}"
2099
+ resume_ckpt = hf_resume_checkpoint()
2100
+
2101
+ grpo_kwargs = {
2102
+ "output_dir": out_dir,
2103
+ "learning_rate": (
2104
+ _t.learning_rate if _t and _t.learning_rate is not None else rl.learning_rate
2105
+ ),
2106
+ "per_device_train_batch_size": batching["per_device_train_batch_size"],
2107
+ "gradient_accumulation_steps": batching["gradient_accumulation_steps"],
2108
+ "num_generations": group_size,
2109
+ # NB: GRPOConfig has no max_prompt_length field (TRL 1.5) and does not truncate
2110
+ # prompts; the dataset is pre-filtered above to prompts that fit prompt_budget
2111
+ # (vllm_max_len - completion), so every prompt fits the engine sized here.
2112
+ "max_completion_length": _max_completion,
2113
+ "max_steps": steps,
2114
+ "temperature": _temperature,
2115
+ "top_p": rl.sampling_top_p,
2116
+ "use_vllm": use_vllm,
2117
+ "logging_steps": 1,
2118
+ "save_steps": _t.save_every if _t and _t.save_every is not None else 20,
2119
+ "save_total_limit": 1,
2120
+ # Resumable checkpoints: keep the optimizer/scheduler/RNG state with the LoRA adapter so a
2121
+ # preempted GRPO run resumed via resume_from_checkpoint(hf_resume_checkpoint()) continues
2122
+ # with intact optimizer state + step instead of a fresh optimizer. For LoRA this state is
2123
+ # small (trainable adapter params only). The deployable per-step snapshot strips it
2124
+ # separately, so serving still gets adapter-only files.
2125
+ "save_only_model": False,
2126
+ "bf16": True,
2127
+ "report_to": wandb_report_to(), # W&B when WANDB_API_KEY present (restored post-flash-migration)
2128
+ "run_name": wandb_run_name(),
2129
+ "seed": SEED,
2130
+ "gradient_checkpointing": grad_checkpointing_on(model_id, vllm_max_len),
2131
+ # Non-reentrant checkpointing: the modern path that composes correctly with autograd
2132
+ # saved-tensor hooks and avoids the reentrant path's extra graph retention. (verl #3629.)
2133
+ "gradient_checkpointing_kwargs": {"use_reentrant": False},
2134
+ # Pin a stable, well-conditioned GRPO recipe instead of inheriting TRL's defaults
2135
+ # (which on a short run suppress the lift): constant LR (TRL default 'linear' decays
2136
+ # to 0 over the run), advantages centered by group mean only (no std scaling, which
2137
+ # biases by difficulty/length — matches datums.centered_advantages), and no
2138
+ # length-normalized loss. beta is the KL-to-reference coef (datums kl_masks ->
2139
+ # kl_penalty_coef).
2140
+ "lr_scheduler_type": "constant",
2141
+ "warmup_ratio": 0.0,
2142
+ "beta": _kl_beta,
2143
+ "scale_rewards": "none",
2144
+ "loss_type": "dr_grpo",
2145
+ # Optimizer: 8-bit paged AdamW (int8 state paged to host RAM -> fits a smaller GPU);
2146
+ # colocated GRPO (trainer + vLLM on one GPU) is memory-tight, so this is the right default.
2147
+ "optim": fused_optim_name(),
2148
+ }
2149
+ # Liger fused GRPO loss: fuses the lm_head + per-token logprob so the fp32
2150
+ # [batch, seq, ~248k vocab] logits never materialize — the documented GRPO OOM driver.
2151
+ # TRL 1.6's GRPOConfig flag is `use_liger_kernel` (NOT `use_liger_loss`, which doesn't
2152
+ # exist in 1.6). DEFAULT ON for the GRPO path regardless of model size: MEASURED that
2153
+ # WITHOUT it even Qwen3.5-0.8B GRPO OOMs a 24 GB (and 32 GB) card because the per-completion
2154
+ # logits over the 248k vocab dominate — the small-scale JIT cost is far cheaper than the OOM.
2155
+ # (This differs from SFT, where Liger is gated by size since 1B-class SFT can be net-negative.)
2156
+ if liger_on(True):
2157
+ grpo_kwargs["use_liger_kernel"] = True
2158
+ print("[rl] liger fused GRPO loss enabled")
2159
+ if use_vllm:
2160
+ # RTX 5090 / sm120: pin a PTX-independent vLLM attention backend (FLASHINFER) BEFORE TRL
2161
+ # builds the colocated engine — else the rollout can silently produce no completions on
2162
+ # old-driver Blackwell hosts (flash-attn PTX JIT failure). No-op off sm120 / if pinned.
2163
+ force_vllm_backend_for_sm120()
2164
+ # Colocate shares one GPU between the policy model and the vLLM rollout engine.
2165
+ # vllm_max_model_length bounds the KV cache to what GRPO needs (else vLLM sizes for
2166
+ # the model's FULL context and won't start on a consumer GPU).
2167
+ # vllm_gpu_memory_utilization sizes vLLM's KV pool. The blanket sleep-path 0.45 was a
2168
+ # misjudgement: on an 80 GB A100 it reserves 0.45 x 80 = 36 GB of KV, but a GRPO rollout only
2169
+ # holds ~num_generations x context tokens. MEASURED (Qwen3.5-4B colocate): that 36 GB
2170
+ # reservation is the dominant resident allocation and sets the step peak (~46 GB) — exactly why
2171
+ # trainer-side optimisations (mask-aware lm_head, fused layers) moved nothing. colocate_kv_util
2172
+ # sizes both paths from flash's per-model KV estimate instead (vram.py); MEASURED 4B/80 GB peak
2173
+ # 46 -> 26 GB, reward byte-identical, train_wall neutral.
2174
+ try:
2175
+ import torch as _torch_vram
2176
+
2177
+ from flash.engine.vram import colocate_kv_util
2178
+
2179
+ _total_vram_gb = _torch_vram.cuda.get_device_properties(0).total_memory / 1e9
2180
+ _vllm_gpu_mem_util = colocate_kv_util(
2181
+ _params_b, vllm_max_len, _total_vram_gb, sleep_mode, num_generations=group_size
2182
+ )
2183
+ except Exception:
2184
+ _vllm_gpu_mem_util = 0.45 if sleep_mode else 0.10 # safe fallback to the old constants
2185
+ grpo_kwargs.update(
2186
+ vllm_mode="colocate",
2187
+ vllm_max_model_length=vllm_max_len,
2188
+ vllm_gpu_memory_utilization=_vllm_gpu_mem_util,
2189
+ vllm_enable_sleep_mode=sleep_mode,
2190
+ )
2191
+ # Rollout-memory + throughput knobs, applied ONLY if this TRL exposes the field (so an
2192
+ # older TRL never crashes on an unknown kwarg). All verl-validated for GRPO colocate (#174).
2193
+ _grpo_fields = set(getattr(GRPOConfig, "__dataclass_fields__", {}))
2194
+
2195
+ def _set_vllm_field(names, value, label):
2196
+ for _f in names:
2197
+ if _f in _grpo_fields:
2198
+ grpo_kwargs[_f] = value
2199
+ print(f"[rl] {label} ({_f}={value})")
2200
+ return True
2201
+ return False
2202
+
2203
+ # fp8 KV cache only where the silicon has native fp8 (compute capability >= 8.9: Ada /
2204
+ # Hopper / Blackwell) — ~halves the rollout KV pool. Ampere (A100/A6000/3090) lacks
2205
+ # fp8, so it stays fp16 there (forcing it on would error / silently emulate).
2206
+ try:
2207
+ import torch as _torch
2208
+
2209
+ _want_fp8 = _torch.cuda.get_device_capability() >= (8, 9)
2210
+ except Exception:
2211
+ _want_fp8 = False
2212
+ if _want_fp8:
2213
+ _set_vllm_field(("vllm_kv_cache_dtype", "kv_cache_dtype"), "fp8", "fp8 KV cache")
2214
+ # PREFIX CACHING: every GRPO group of `num_generations` rollouts shares the SAME prompt
2215
+ # prefix, so caching the prompt KV computes it once and reuses it — the dominant rollout win
2216
+ # on one GPU. CHUNKED PREFILL interleaves prefill with decode so a long prompt doesn't stall
2217
+ # the batch. CUDAGRAPH MODE sets verl's full-graph-decode + piecewise-fallback rollout mode.
2218
+ _set_vllm_field(
2219
+ ("vllm_enable_prefix_caching", "enable_prefix_caching"),
2220
+ True,
2221
+ "vLLM prefix caching (shared GRPO prompt KV reuse)",
2222
+ )
2223
+ _set_vllm_field(
2224
+ ("vllm_enable_chunked_prefill", "enable_chunked_prefill"),
2225
+ True,
2226
+ "vLLM chunked prefill",
2227
+ )
2228
+ # vLLM 0.19.1 regressed the Triton _compute_slot_mapping_kernel: it launches
2229
+ # (num_reqs + 1) thread blocks but the block table only has num_reqs rows, so the
2230
+ # extra block causes an illegal memory access (cudaErrorIllegalAddress) on the first
2231
+ # generation step. CUDA graph compilation triggers this path. Skip FULL_AND_PIECEWISE
2232
+ # for vLLM versions outside TRL's supported range (0.12.0-0.19.0) until a fix lands.
2233
+ _cudagraph_safe = True
2234
+ try:
2235
+ import vllm as _vllm_mod
2236
+
2237
+ _ver_base = _vllm_mod.__version__.split("+")[0] # strip PEP440 local (e.g. +cu121)
2238
+ _vllm_ver = tuple(int(x) for x in _ver_base.split(".")[:3])
2239
+ if _vllm_ver > (0, 19, 0):
2240
+ _cudagraph_safe = False
2241
+ print(
2242
+ f"[rl][warn] vLLM {_vllm_mod.__version__} > 0.19.0: skipping "
2243
+ "FULL_AND_PIECEWISE CUDA graph compilation (Triton slot-mapping "
2244
+ "crash workaround; update vLLM to a TRL-supported version to re-enable)"
2245
+ )
2246
+ # vLLM 0.19.1 ALSO hits `RuntimeError: aot_compile is not supported by the
2247
+ # current configuration` through its DEFAULT torch.compile path on some GPU
2248
+ # arches (Ampere sm_86: A6000, A100) — it fires from vllm/compilation/wrapper.py
2249
+ # when torch._dynamo.is_compiling() is False inside the CUDA-graph capture path.
2250
+ # Skipping FULL_AND_PIECEWISE above is not enough (the vllm_compilation_config
2251
+ # GRPOConfig field doesn't exist in this TRL, so that _set_vllm_field is a no-op).
2252
+ # VLLM_TORCH_COMPILE_LEVEL=0 (NO_COMPILATION) forces vLLM to execute the model
2253
+ # eagerly, preventing the AOT path entirely. Official vLLM env var (vllm/envs.py);
2254
+ # a no-op on a vLLM that doesn't define it. Don't override an operator-set value.
2255
+ if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
2256
+ os.environ["VLLM_TORCH_COMPILE_LEVEL"] = "0"
2257
+ print("[rl][warn] VLLM_TORCH_COMPILE_LEVEL=0 (prevent aot_compile on vLLM 0.19.1)")
2258
+ except Exception:
2259
+ pass
2260
+ if _cudagraph_safe:
2261
+ _set_vllm_field(
2262
+ ("vllm_compilation_config", "compilation_config"),
2263
+ {"cudagraph_mode": "FULL_AND_PIECEWISE"},
2264
+ "vLLM cudagraph_mode (verl rollout default)",
2265
+ )
2266
+ # Adapter init: continue training the SFT adapter (peft_config=None, model is the
2267
+ # loaded PeftModel) when train.init_from_adapter is set, else a fresh LoRA on the
2268
+ # string model id (model_init_kwargs forces bf16 — TRL string-loading can fall back
2269
+ # to fp32 and double VRAM).
2270
+ init_model, init_peft = _init_adapter_model(model_id)
2271
+ # chalk's kernels are applied AFTER construction (below) against trainer.model: chalk's apply
2272
+ # patches the LIVE nn.Module, so there is nothing to install pre-build. On the fresh-LoRA path
2273
+ # init_model is just the model-id string (TRL builds the module), and even on the
2274
+ # continue-adapter path TRL may rebuild/wrap the PeftModel, so trainer.model is the
2275
+ # authoritative target.
2276
+ if init_peft is not None:
2277
+ # Fresh LoRA: TRL loads the string model id with these kwargs, then attaches the
2278
+ # adapter. Force bf16 (TRL string-loading can fall back to fp32 and double VRAM).
2279
+ _attn = optimal_attn_impl() # arch-aware FlashAttention (Kernels Hub) / SDPA
2280
+ grpo_kwargs["model_init_kwargs"] = {"dtype": "bfloat16"}
2281
+ if _attn:
2282
+ grpo_kwargs["model_init_kwargs"]["attn_implementation"] = _attn
2283
+ else:
2284
+ _attn = optimal_attn_impl()
2285
+ # stop_sequences: TRL forwards generation_kwargs to the (vLLM) sampler, whose
2286
+ # SamplingParams.stop truncates each rollout at the requested delimiter — so the reward
2287
+ # sees the same completion the config intends, instead of generating to max_completion.
2288
+ if _t and _t.stop_sequences:
2289
+ grpo_kwargs["generation_kwargs"] = {"stop": list(_t.stop_sequences)}
2290
+ # advantage_clip>0 is the datums centered-advantage clamp; TRL has no advantage-value
2291
+ # clip knob (it clips the importance ratio), so honor the default (clip off ==
2292
+ # centered) and surface a note when a config asks for an explicit clamp.
2293
+ if _adv_clip > 0:
2294
+ print(f"[rl] advantage_clip={_adv_clip} recorded; TRL centers advantages (no value clip)")
2295
+ # num_iterations (the one promoted GRPO speed lever, measured 1.38x faster) is feature-detected
2296
+ # so an older TRL that lacks the field is simply skipped (GRPOConfig rejects unknown kwargs).
2297
+ # Generation dominates GRPO wall-clock, so reusing each rollout batch for 2 optimizer steps is
2298
+ # the cheapest large speedup; mu=2 is the standard GRPO config and TRL's importance-sampling
2299
+ # correction (on by default) keeps the step stable. (The GSPO/DAPO A/B levers were dropped: the
2300
+ # framework-scan in gpu-bench/RESEARCH_FINDINGS.md measured no robust win over baseline.)
2301
+ import dataclasses as _dc
2302
+
2303
+ try:
2304
+ _grpo_fields = {f.name for f in _dc.fields(GRPOConfig)}
2305
+ except TypeError:
2306
+ _grpo_fields = set() # not a dataclass on this TRL -> skip the feature-detected knob
2307
+ if "num_iterations" in _grpo_fields:
2308
+ grpo_kwargs["num_iterations"] = 2
2309
+ print("[rl] rollout amortization: num_iterations=2 (reuse each generation batch)")
2310
+ # truncated importance sampling (tis): trl's grpo applies an importance-sampling correction by
2311
+ # default, but with mode="sequence_mask" and clip_max=3.0. the verl/openrlhf recipe for the
2312
+ # rollout(vllm)-vs-training token-distribution mismatch is TOKEN-LEVEL truncated is with the
2313
+ # per-token ratio clipped at c=2 (verl rollout_is_threshold=2.0). adopt that recipe here:
2314
+ # token_truncate + c_max=2.0. feature-detected against this trl's GRPOConfig fields (canonical
2315
+ # clip field first, then the pre-2.0 deprecated alias), so a trl that lacks a field is skipped.
2316
+ # note: this deliberately changes trl's defaults (sequence_mask / 3.0) to the recipe values.
2317
+ if "vllm_importance_sampling_mode" in _grpo_fields:
2318
+ grpo_kwargs["vllm_importance_sampling_mode"] = "token_truncate"
2319
+ print("[rl] tis mode=token_truncate (token-level truncated importance sampling)")
2320
+ _tis_c = 2.0
2321
+ _tis_clip_field = next(
2322
+ (
2323
+ f
2324
+ for f in ("vllm_importance_sampling_clip_max", "vllm_importance_sampling_cap")
2325
+ if f in _grpo_fields
2326
+ ),
2327
+ None,
2328
+ )
2329
+ if _tis_clip_field:
2330
+ grpo_kwargs[_tis_clip_field] = _tis_c
2331
+ print(f"[rl] tis clip c_max={_tis_c} ({_tis_clip_field})")
2332
+ else:
2333
+ print("[rl] tis: trl default importance-sampling correction in effect; no clip field on this trl")
2334
+ cfg = GRPOConfig(**grpo_kwargs)
2335
+ setup_seconds = time.time() - t_start
2336
+ heartbeat("rl_train_start", setup_seconds=setup_seconds, gpu=gpu_diagnostics())
2337
+
2338
+ # VL checkpoints (Qwen3.5/3.6) train text-only: make TRL's colocated rollout
2339
+ # engine skip the vision tower (VRAM + 5090 PTX-compat; see the patch docstring).
2340
+ # Only relevant when vLLM drives rollouts; transformers generation uses the trainer
2341
+ # model (already text-only via the LoRA target/exclude config).
2342
+ if use_vllm:
2343
+ patch_vllm_language_model_only(model_id)
2344
+ # Install (but do NOT yet activate) the TRL->vLLM weight-sync name remap for Qwen3.5/3.6:
2345
+ # the trainer pushes ``model.*`` names but the VL engine's LM params live under
2346
+ # ``language_model.*``, so the first sync_weights() would raise without this. Activated
2347
+ # below, after the trainer + its initial checkpoint load are built.
2348
+ patch_vllm_lm_weight_sync(model_id)
2349
+ hb_cb = make_reward_heartbeat_callback()
2350
+ # Multi-turn / tool wiring (trl 1.6): tool envs hand TRL the tool callables so it runs the
2351
+ # tool-call loop natively; pure multi-turn envs hand TRL a rollout_func that drives the
2352
+ # env's own turn loop on the colocate engine (env_mask masks the non-model tokens).
2353
+ extra_trainer_kwargs: dict = {}
2354
+ tools = env.tools() if is_tool_env else []
2355
+ # A tool env exposing NO tools would silently degrade to single-shot under tools=[]; drive
2356
+ # it through the rollout_func turn loop instead so it isn't mis-trained as single-turn.
2357
+ if is_tool_env and not tools:
2358
+ print("[rl][warn] tool env exposes no tools — using the multi-turn rollout_func path")
2359
+ use_rollout_func = is_multi_turn and not (is_tool_env and tools)
2360
+ require_vllm_for_rollout_func(use_rollout_func, use_vllm, model_id)
2361
+ if is_tool_env and tools:
2362
+ extra_trainer_kwargs["tools"] = tools
2363
+ print(f"[rl] tool env: handing {len(tools)} tool(s) to TRL's native tool loop")
2364
+ if use_rollout_func:
2365
+ from flash.engine.multiturn_rollout import (
2366
+ build_examples_index,
2367
+ build_rollout_func,
2368
+ index_collisions,
2369
+ )
2370
+
2371
+ examples_by_key = build_examples_index(train, env.prompt_messages)
2372
+ ncol = index_collisions(train, env.prompt_messages)
2373
+ if ncol:
2374
+ print(
2375
+ f"[rl][warn] {ncol} duplicate prompt(s) collide in the reward index; the shared "
2376
+ "prompt scores against the last example's answer/info"
2377
+ )
2378
+ extra_trainer_kwargs["rollout_func"] = build_rollout_func(
2379
+ active_env=env,
2380
+ tok=tok,
2381
+ examples_by_key=examples_by_key,
2382
+ max_completion=_max_completion,
2383
+ max_turns=getattr(env, "max_turns", 10),
2384
+ temperature=_temperature,
2385
+ top_p=rl.sampling_top_p,
2386
+ stop=(list(_t.stop_sequences) if _t and _t.stop_sequences else None),
2387
+ thinking=THINKING,
2388
+ engine_max_len=vllm_max_len,
2389
+ )
2390
+ print("[rl] multi-turn env: driving the turn loop via rollout_func")
2391
+ # GRPOTrainer.__init__ blocks during model/vLLM init + FA2 kernel compilation (can be
2392
+ # 10-20 min on first use). Background heartbeats keep the stall detector quiet.
2393
+ _rl_init_done = threading.Event()
2394
+
2395
+ def _rl_init_heartbeat() -> None:
2396
+ while not _rl_init_done.wait(30.0):
2397
+ heartbeat("rl_initializing", gpu=gpu_diagnostics())
2398
+
2399
+ _rl_init_hb = threading.Thread(target=_rl_init_heartbeat, daemon=True)
2400
+ _rl_init_hb.start()
2401
+ try:
2402
+ trainer = GRPOTrainer(
2403
+ model=init_model,
2404
+ args=cfg,
2405
+ train_dataset=ds,
2406
+ reward_funcs=reward_fn,
2407
+ peft_config=init_peft,
2408
+ processing_class=tok,
2409
+ callbacks=[hb_cb, make_checkpoint_upload_callback()],
2410
+ **extra_trainer_kwargs,
2411
+ )
2412
+ finally:
2413
+ _rl_init_done.set()
2414
+ # Apply chalk's gap-filling kernels (RoPE/LoRA-delta/embedding, like Liger) on the module
2415
+ # GRPOTrainer actually optimizes (trainer.model) — the fresh-LoRA path only passes the model-id
2416
+ # string to TRL, so trainer.model is the authoritative target. chalk composes on top of Liger.
2417
+ # Capture the install report so the engaged kernels land in metrics (active_kernels below).
2418
+ _chalk_report = install_chalk_kernels(getattr(trainer, "model", None))
2419
+ # Liger fused-loss chunk_size: TRL leaves it at the default 1, so the fused GRPO loss runs its
2420
+ # whole detach -> chunk_forward -> compiled-loss -> autograd.grad cycle ONCE PER SEQUENCE
2421
+ # (per_device_train_batch_size times) — Python/kernel-launch/compile-guard overhead that
2422
+ # dominates at small-model scale where the GEMMs are tiny. Collapse it to ONE invocation over the
2423
+ # whole per-device micro-batch. Numerically identical (every loss_type normalizes by the GLOBAL
2424
+ # token count, not the chunk-local size, and chunk losses are summed). Must run BEFORE the
2425
+ # mask-aware wrap below, which replaces trainer.liger_grpo_loss with a closure that has no
2426
+ # chunk_size attribute.
2427
+ _liger_loss = getattr(trainer, "liger_grpo_loss", None)
2428
+ if _liger_loss is not None and hasattr(_liger_loss, "chunk_size"):
2429
+ _cs = max(1, int(getattr(trainer.args, "per_device_train_batch_size", 1)))
2430
+ if _cs > int(getattr(_liger_loss, "chunk_size", 1)):
2431
+ _liger_loss.chunk_size = _cs
2432
+ print(f"[rl] liger fused-loss chunk_size -> {_cs} (one invocation, not one per sequence)")
2433
+ # Run liger's fused GRPO loss EAGER: drop ONLY its torch.compile (BROKEN on torch 2.10 — its
2434
+ # dynamo guard-gen trips a symbol_to_source IndexError that crashes the first GRPO step on every
2435
+ # path), keep the chunked memory path that prevents the 248k-vocab fp32-logit OOM. Must run BEFORE
2436
+ # the mask-aware wrap below, which replaces trainer.liger_grpo_loss with a closure. See the helper.
2437
+ if disable_liger_grpo_torch_compile(trainer):
2438
+ print(
2439
+ "[rl] liger GRPO loss: torch.compile DISABLED (eager loss math; chunked memory path "
2440
+ "retained) — dodges the torch 2.10 dynamo guard-gen crash (symbol_to_source IndexError)"
2441
+ )
2442
+ # Mask-aware lm_head: skip the 248k-vocab projection at MASKED completion positions in the GRPO
2443
+ # loss — its most expensive op, and the trainer step dominates train_wall. For MULTI-TURN that
2444
+ # masked set is the ~half-to-most of the transcript that is env/tool text; for SINGLE-TURN it is
2445
+ # the right-PADDING (GRPO samples variable-length completions, padded to the batch max). Either
2446
+ # way those positions add zero loss/gradient but pay full FLOPs. Loss-preserving; applies to ALL
2447
+ # GRPO with the Liger fused loss; no-op when nothing is masked (uniform-length single-turn).
2448
+ if grpo_kwargs.get("use_liger_kernel") and patch_grpo_mask_aware_lm_head(trainer):
2449
+ _masked_kind = "env + padding" if use_rollout_func else "padding"
2450
+ print(f"[rl] mask-aware lm_head: skipping masked ({_masked_kind}) positions in the GRPO loss")
2451
+ # The trainer (and its colocated vLLM engine + initial checkpoint load) is now built. Activate
2452
+ # the TRL->vLLM weight-sync name remap ONLY now (see patch_vllm_lm_weight_sync) so the initial
2453
+ # checkpoint load stayed untouched while the train-time syncs get remapped. No-op unless the VL
2454
+ # patch above was installed.
2455
+ if use_vllm:
2456
+ _LM_SYNC_REMAP_ON["on"] = True
2457
+ if is_vl_checkpoint(model_id):
2458
+ print("[vllm] LM weight-sync remap activated for training syncs")
2459
+ # Mid-run eval is intentionally NOT run during training: held-out evaluation happens on the
2460
+ # deploy/serving side (against the trained adapter), keeping training pure (no eval-phase cost
2461
+ # or eval-boundary stalls). Training streams only the per-step reward heartbeat.
2462
+ _reset_peak_gpu() # peak_gpu_gb reflects the train loop (verifies the micro-batch headroom)
2463
+ _gpu_sampler = _GpuPeakSampler().start() # true device peak incl. vLLM colocate + bnb pages
2464
+ t_train = time.time()
2465
+ with _sdpa_cudnn_ctx(_attn): # force cuDNN SDPA on sm120 (no-op otherwise)
2466
+ trainer.train(resume_from_checkpoint=resume_ckpt)
2467
+ train_wall = time.time() - t_train
2468
+ rl_peak_gpu_gb = _peak_gpu_gb()
2469
+ rl_device_peak_gpu_gb = _gpu_sampler.stop_gb()
2470
+ reward_history = list(getattr(hb_cb, "reward_history", []))
2471
+ # A GRPO run that finishes WITHOUT the reward callback ever firing (empty reward_history)
2472
+ # produced NO real training — the rollout scored nothing (e.g. vLLM generation silently
2473
+ # returning no completions, observed on RTX 5090 / sm120: ~1.4 s wall, empty reward + loss
2474
+ # curves, but the run otherwise "succeeds"). That is a FAILURE, not a success: a no-op run with
2475
+ # an unchanged adapter must not be reported as done — fail loudly so the operator/agent doesn't
2476
+ # trust it. (An env returning all-zero rewards still appends 0.0s, so an EMPTY history uniquely
2477
+ # means the reward path never ran.)
2478
+ _steps_run = int(getattr(trainer.state, "global_step", 0) or 0)
2479
+ # A resume that already reached the target steps legitimately performs ZERO new optimizer
2480
+ # steps: the previous worker uploaded the final checkpoint (and scored its rewards) but died
2481
+ # before writing metrics/DONE, so this worker's fresh hb_cb has an empty reward_history even
2482
+ # though the policy IS fully trained. Don't fail those — finalize from the resumed state. The
2483
+ # no-op guard below is only for a run that genuinely trained nothing (no resume, or the resume
2484
+ # didn't reach the target steps).
2485
+ _resumed_complete = _grpo_resume_already_complete(resume_ckpt, steps, _steps_run)
2486
+ if _grpo_is_no_op_failure(reward_history, resume_ckpt, steps, _steps_run):
2487
+ if _steps_run == 0:
2488
+ raise RuntimeError(
2489
+ "GRPO trainer completed zero optimizer steps before any reward was scored. "
2490
+ f"retained_prompts={len(prompts)}, prompts_per_step={prompts_per_step}, "
2491
+ f"generations_per_step={batching['generations_per_step']}. This usually means "
2492
+ "TRL built an empty dataloader; add training examples, lower [train].batch_size, "
2493
+ "or reduce prompt length/max_tokens so more examples fit."
2494
+ )
2495
+ raise RuntimeError(
2496
+ f"GRPO scored no reward in {train_wall:.1f}s over {_steps_run} step(s) — the rollout "
2497
+ "produced no completions, so the policy was never actually trained. Failing loudly "
2498
+ "instead of reporting a no-op run as done (seen on RTX 5090/sm120 vLLM rollout)."
2499
+ )
2500
+ if not reward_history and _resumed_complete:
2501
+ print(
2502
+ f"[resume] no new reward in this worker but resumed checkpoint already reached "
2503
+ f"{_steps_run}/{steps} step(s) — finalizing the completed policy instead of failing."
2504
+ )
2505
+ adapter_dir = f"{out_dir}/adapter"
2506
+ trainer.model.save_pretrained(adapter_dir)
2507
+ tok.save_pretrained(adapter_dir)
2508
+ hf_upload_folder(adapter_dir, "adapter", required=True)
2509
+ heartbeat("rl_trained", train_wall=train_wall, gpu=gpu_diagnostics())
2510
+
2511
+ # Upper bound on generated tokens: completions actually optimized (the intended
2512
+ # prompts_per_step after the batch fix) x the max completion length. Over-counts (most
2513
+ # completions are shorter); reported as an upper bound, used only for a rough throughput.
2514
+ gen_tokens = steps * batching["unique_prompts_per_step"] * group_size * _max_completion
2515
+ write_train_meta(
2516
+ phase="rl",
2517
+ adapter_dir=adapter_dir,
2518
+ model_id=model_id,
2519
+ train_wall=train_wall,
2520
+ setup_seconds=setup_seconds,
2521
+ train_tokens=0,
2522
+ generated_tokens=gen_tokens,
2523
+ notes={
2524
+ "steps": steps,
2525
+ "resumed": bool(resume_ckpt),
2526
+ "download_seconds": download_seconds,
2527
+ "hf_transfer": os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", ""),
2528
+ "reward_history": reward_history,
2529
+ "loss_curve": _metric_curve(trainer, "loss"),
2530
+ # Peak torch-allocated GPU memory during the GRPO train loop (excludes bnb managed
2531
+ # pages). device_peak_gpu_gb is the TRUE device footprint (total-free, incl. the vLLM
2532
+ # colocate engine + bnb pages): the headline for verifying the per-device micro-batch
2533
+ # left the card with headroom (no OOM) at the sized batch.
2534
+ "peak_gpu_gb": rl_peak_gpu_gb,
2535
+ "device_peak_gpu_gb": rl_device_peak_gpu_gb,
2536
+ # Which chalk gap-filling kernels actually ENGAGED (None = chalk not installed or every
2537
+ # kernel fell back) — verifies the chalk stack on a GRPO run without the console.
2538
+ "chalk_kernels": active_kernels(_chalk_report) or None,
2539
+ **wandb_run_info(),
2540
+ "gen_tokens_is_upper_bound": True,
2541
+ "thinking": THINKING,
2542
+ "max_completion_len": _max_completion,
2543
+ "prompts_per_step": batching["unique_prompts_per_step"],
2544
+ "generations_per_step": batching["generations_per_step"],
2545
+ "group_size": group_size,
2546
+ "per_device_train_batch_size": batching["per_device_train_batch_size"],
2547
+ "gradient_accumulation_steps": batching["gradient_accumulation_steps"],
2548
+ "grpo_recipe": {
2549
+ "lr_scheduler": "constant",
2550
+ "beta": _kl_beta,
2551
+ "scale_rewards": "none",
2552
+ "loss_type": "dr_grpo",
2553
+ "temperature": _temperature,
2554
+ "advantage_clip": _adv_clip,
2555
+ "thinking_length_penalty_coef": _think_penalty,
2556
+ "init_from_adapter": JOB_SPEC.train.init_from_adapter if JOB_SPEC else "",
2557
+ },
2558
+ },
2559
+ )
2560
+ free_gpu(trainer)
2561
+
2562
+
2563
+ # ---------------------------------------------------------------------------
2564
+ # Completion: train phase writes metrics.json + the DONE sentinel (see _finalize).
2565
+ # ---------------------------------------------------------------------------
2566
+
2567
+
2568
def write_train_meta(
    phase, adapter_dir, model_id, train_wall, setup_seconds, train_tokens, generated_tokens, notes
):
    """Persist the training summary, announce completion, and finalize the run.

    Writes the summary to ``/tmp/train_meta.json``, uploads it as ``train_meta.json``, emits a
    ``<phase>_train_done`` heartbeat, then builds the ``RunMetrics`` record and hands it to
    ``_finalize`` (there is no separate eval phase).

    Args:
        phase: Training phase label; also used as the heartbeat prefix.
        adapter_dir: Directory the trained adapter was saved to.
        model_id: Identifier of the trained base model.
        train_wall: Wall-clock seconds spent in the train loop.
        setup_seconds: Wall-clock seconds spent before training started.
        train_tokens: Number of tokens trained on.
        generated_tokens: Number of generated tokens (preferred for throughput when non-zero).
        notes: Extra metadata merged into the persisted records; a falsy value means no notes.
    """
    active_env = require_active_env()
    run_notes = notes or {}
    summary = {
        "phase": phase,
        "adapter_dir": adapter_dir,
        "model_id": model_id,
        "train_wall": train_wall,
        "setup_seconds": setup_seconds,
        "train_tokens": train_tokens,
        "generated_tokens": generated_tokens,
        "notes": run_notes,
    }
    meta_path = "/tmp/train_meta.json"
    with open(meta_path, "w") as handle:
        json.dump(summary, handle)
    hf_upload_file(meta_path, "train_meta.json")
    heartbeat(
        f"{phase}_train_done",
        train_wall=train_wall,
        train_tokens=train_tokens,
        generated_tokens=generated_tokens,
        gpu=gpu_diagnostics(),
    )

    # Throughput prefers generated tokens and falls back to trained tokens; a zero/missing
    # wall time yields 0.0 instead of dividing by zero.
    if train_wall:
        throughput = (generated_tokens or train_tokens) / train_wall
    else:
        throughput = 0.0

    # Finalize straight from the training phase: the record carries training metrics only
    # (reward_history, when present, travels inside the notes).
    record = RunMetrics(
        # Substrate the worker ran on: FLASH_ARM when set, otherwise "runpod".
        arm=os.environ.get("FLASH_ARM", "runpod"),
        phase=phase,
        seed=SEED,
        model_id=model_id,
        wall_seconds=train_wall,
        setup_seconds=setup_seconds,
        train_throughput_toks_per_s=throughput,
        train_tokens=train_tokens,
        generated_tokens=generated_tokens,
        notes={
            **run_notes,
            "renderer": "flash_env",
            "thinking": THINKING,
            "train_wall": train_wall,
            "model_id": model_id,
            "environment": active_env.id,
            "job_spec": JOB_SPEC.to_dict() if JOB_SPEC else None,
        },
    )
    _finalize(record)
2618
+
2619
+
2620
+ def _resolve_adapter_ref(adapter_ref: str) -> tuple[str, str] | None:
2621
+ """Resolve init_from_adapter into (repo, prefix).
2622
+
2623
+ The only public form is the exact adapter_ref emitted by ``flash status``:
2624
+ ``<owner>/<repo>:<phase>/<run_id>/seed<N>``.
2625
+ """
2626
+ adapter_ref = adapter_ref.strip()
2627
+ match = re.fullmatch(
2628
+ r"(?P<repo>[A-Za-z0-9][A-Za-z0-9._-]*/[A-Za-z0-9][A-Za-z0-9._-]*):"
2629
+ r"(?P<phase>sft|rl)/(?P<run_id>[A-Za-z0-9][A-Za-z0-9._-]{0,127})/seed(?P<seed>\d+)",
2630
+ adapter_ref,
2631
+ )
2632
+ if not match:
2633
+ return None
2634
+ repo, phase, run_id, seed = match.groups()
2635
+ return repo, f"{phase}/{run_id}/seed{seed}"
2636
+
2637
+
2638
+ def _download_adapter(adapter_prefix: str | None) -> str | None:
2639
+ """Download an init_from_adapter LoRA to /tmp/evdl/<prefix>/adapter and return its dir.
2640
+
2641
+ ``adapter_prefix`` must be the full ``adapter_ref`` string emitted by ``flash status``:
2642
+ ``<owner>/<repo>:<phase>/<run_id>/seed<N>``.
2643
+ """
2644
+ if not adapter_prefix:
2645
+ return None
2646
+ resolved = _resolve_adapter_ref(adapter_prefix)
2647
+ if not resolved:
2648
+ return None
2649
+ repo, prefix = resolved
2650
+ from huggingface_hub import snapshot_download
2651
+
2652
+ snapshot_download(
2653
+ repo_id=repo,
2654
+ repo_type="dataset",
2655
+ allow_patterns=[f"{prefix}/adapter/*"],
2656
+ local_dir="/tmp/evdl",
2657
+ token=os.environ.get("HF_TOKEN"),
2658
+ )
2659
+ adir = os.path.join("/tmp/evdl", prefix, "adapter")
2660
+ return adir if os.path.isdir(adir) else None
2661
+
2662
+
2663
def _finalize(metrics: RunMetrics):
    """Persist the final metrics, write the DONE sentinel, and announce completion.

    Both uploads are issued with ``required=True``: a swallowed upload would make the control
    plane fail/retry a run that actually finished.
    """
    metrics_path = "/tmp/metrics.json"
    done_path = "/tmp/DONE"

    metrics.save(metrics_path)
    hf_upload_file(metrics_path, "metrics.json", required=True)

    # DONE sentinel (its content is the current timestamp) tells the controller that it is
    # safe to tear the worker down.
    with open(done_path, "w") as marker:
        marker.write(str(time.time()))
    hf_upload_file(done_path, "DONE", required=True)

    heartbeat("done", gpu=gpu_diagnostics())
    print("NODE DONE:", metrics.to_json())
2673
+
2674
+
2675
# Upper bounds (seconds) on waiting for wandb.finish() to flush.
#   * SUCCESS: the whole run has to sync; a slow network / large run can exceed the old 5s
#     cut-off and leave the run "crashed", so the window is generous but still bounded.
#   * FAILURE: abort fast - the run is failing regardless and the worker is hard-exiting.
_WANDB_FINISH_WAIT_S = 120.0
_WANDB_FINISH_FAIL_WAIT_S = 5.0
2680
+
2681
+
2682
# Location of the optional baked compiled-kernel cache (see Dockerfile.worker and
# flash/engine/worker/kernel_warmup.py). The Dockerfile points TRITON_CACHE_DIR and
# TORCHINDUCTOR_CACHE_DIR at this directory and, when built with
# --build-arg BUILD_KERNEL_CACHE=true, bakes in a portable mega-cache produced on a real GPU.
# Keep these names in lockstep with kernel_warmup.DEFAULT_CACHE_DIR / MEGA_CACHE_FILENAME.
_KERNEL_CACHE_DIR = "/opt/flash/kernelcache"
_KERNEL_CACHE_FILE = os.path.join(_KERNEL_CACHE_DIR, "mega_cache.bin")
_KERNEL_CACHE_META_FILE = os.path.join(_KERNEL_CACHE_DIR, "mega_cache.json")
2689
+
2690
+
2691
+ def _current_cuda_sm(torch) -> str | None:
2692
+ try:
2693
+ if not torch.cuda.is_available():
2694
+ return None
2695
+ cap = torch.cuda.get_device_capability(0)
2696
+ return f"sm{cap[0]}{cap[1]}"
2697
+ except Exception:
2698
+ return None
2699
+
2700
+
2701
def _load_kernel_cache_if_present() -> bool:
    """Best-effort: if a baked mega-cache blob exists, load it so the worker skips first-run JIT.

    Loads the portable cache that kernel_warmup.py wrote on a GPU builder via
    ``torch.compiler.load_cache_artifacts()`` — measured cold compile ~124s -> warm load ~0.2s.
    OPT-IN: when no baked cache is present (the default image build), this is a no-op and the worker
    JITs on first use exactly as before (#163's init heartbeat covers that stall). Never raises:
    a missing torch / missing file / unusable blob / unusable scratch directory just logs and
    leaves the JIT path intact.

    Returns:
        True when the blob was handed to ``torch.compiler.load_cache_artifacts``; False on every
        fallback path (no cache, missing/garbled metadata, unknown or mismatched arch, load error).
    """

    def _reject(reason: str) -> bool:
        # a baked cache is present but unusable (no/garbled metadata or wrong arch): repoint
        # triton/inductor OFF the baked trees (Dockerfile points them at /opt/flash/kernelcache)
        # so the JIT fallback compiles fresh into scratch instead of reusing wrong-arch baked
        # entries that would collide with this worker's arch.
        print(f"[kernel-cache] {reason} -> first-run JIT fallback")
        cache_vars = (("triton", "TRITON_CACHE_DIR"), ("inductor", "TORCHINDUCTOR_CACHE_DIR"))
        try:
            scratch = os.path.join(tempfile.gettempdir(), "flash-kernelcache-jit")
            for sub, var in cache_vars:
                d = os.path.join(scratch, sub)
                os.makedirs(d, exist_ok=True)
                os.environ[var] = d
        except OSError as e:
            # _reject also runs from the outer exception handler, so it must not raise itself
            # (that would break the "never block boot" contract). With no usable scratch dir,
            # drop the overrides so the cache dirs at least stop pointing at the baked trees.
            # NOTE(review): presumably triton/inductor then use their own default cache
            # locations — confirm.
            print(f"[kernel-cache] scratch cache dir unavailable ({e}); clearing cache-dir overrides")
            for _, var in cache_vars:
                os.environ.pop(var, None)
        return False

    if not os.path.isfile(_KERNEL_CACHE_FILE):
        print(f"[kernel-cache] no baked cache at {_KERNEL_CACHE_FILE} -> first-run JIT (expected default)")
        return False
    try:
        import torch

        current_sm = _current_cuda_sm(torch)
        try:
            with open(_KERNEL_CACHE_META_FILE) as f:
                meta = json.load(f)
        except FileNotFoundError:
            return _reject("baked cache has no metadata")
        except Exception as e:
            return _reject(f"metadata unreadable ({e})")
        cached_sm = str(meta.get("sm") or "")
        if not current_sm:
            # can't verify the worker's GPU arch -> don't risk loading a wrong-arch blob; JIT instead.
            return _reject("worker GPU arch undetermined")
        if cached_sm != current_sm:
            return _reject(
                f"baked cache arch {cached_sm or 'unknown'} does not match worker arch {current_sm}"
            )
        with open(_KERNEL_CACHE_FILE, "rb") as f:
            blob = f.read()
        torch.compiler.load_cache_artifacts(blob)
        print(
            f"[kernel-cache] loaded baked mega-cache for {cached_sm or 'unknown'} "
            f"({len(blob)} bytes) -> skipping first-run JIT"
        )
        return True
    except Exception as e:
        # never block boot on a bad/absent cache: fall back to the normal JIT path. repoint off the
        # baked trees too — if the mega blob was present + arch-matched but load raised, the on-disk
        # triton/inductor entries may be partial/corrupt, so JIT fresh into scratch.
        return _reject(f"load skipped ({e})")
2758
+
2759
+
2760
def wandb_finish(exit_code: int = 0) -> None:
    """Close out the W&B run ahead of the worker's hard ``os._exit()``.

    The worker hard-exits to dodge the colocated-vLLM teardown deadlock (see main), which
    skips wandb's atexit sync — so a run that completed successfully was left dangling and
    W&B eventually marked it ``crashed`` even though every metric had been logged. Finishing
    the run explicitly (we own it: ``wandb.init`` is called in ``wandb_report_to``) makes it
    show ``finished``. Best-effort and never raises: W&B is optional, metrics.json is the
    source of truth.

    Args:
        exit_code: Forwarded to ``wandb.finish``; ``0`` selects the long (success) flush
            window, anything else the short (failure) one.
    """
    if not os.environ.get("WANDB_API_KEY"):
        return
    import importlib.util

    # find_spec can RAISE (not just return None) when wandb already sits in sys.modules with an
    # absent/partial __spec__ (e.g. a namespace package or a partially-initialized import). Letting
    # that escape would abort the shutdown path and skip the hard exit, so an ambiguous probe is
    # treated as "present enough to try"; only a definitive None (module truly absent) bails out.
    try:
        wandb_absent = importlib.util.find_spec("wandb") is None
    except Exception:
        wandb_absent = False
    if wandb_absent:
        return

    try:
        import wandb

        if getattr(wandb, "run", None) is None:
            return

        failures: list[Exception] = []

        def _finish() -> None:
            try:
                wandb.finish(exit_code=exit_code)
            except Exception as exc:
                failures.append(exc)

        finisher = threading.Thread(target=_finish, daemon=True)
        finisher.start()
        # SUCCESS (exit_code == 0): finish() has to flush the full run, which on a slow network /
        # large run takes well over 5s — cutting it short is what leaves the run dangling ->
        # "crashed" — so wait longer (still bounded). FAILURE: keep the short cut-off; we want to
        # abort fast and the run is failing anyway.
        if exit_code == 0:
            wait_s = _WANDB_FINISH_WAIT_S
        else:
            wait_s = _WANDB_FINISH_FAIL_WAIT_S
        finisher.join(timeout=wait_s)
        if finisher.is_alive():
            print(f"[wandb] finish() did not complete within {wait_s}s; continuing with hard exit")
        elif failures:
            print(f"[wandb] finish() warning: {failures[0]}")
    except Exception as e:  # pragma: no cover - logging-only path
        print(f"[wandb] finish() warning: {e}")
2811
+
2812
+
2813
def main():
    """Worker entry point: short-circuit a finished run, otherwise train and hard-exit.

    Flow:
      1. If ``HF_REPO`` is set and the run's ``DONE`` sentinel can be downloaded, copy the
         persisted ``metrics.json`` to ``/tmp/metrics.json`` and hard-exit 0.
      2. Otherwise prepare the environment, load the optional baked kernel cache, and dispatch
         to the handler registered for ``RUN_MODE``.
      3. On success, finish the W&B run, flush stdio, and ``os._exit(0)``.

    Raises:
        SystemExit: ``DONE`` is present but ``metrics.json`` cannot be fetched, or ``RUN_MODE``
            is not a registered mode. (Not an ``Exception`` subclass, so it bypasses the
            error-reporting handler below.)
        Exception: Any handler failure is re-raised after the traceback upload, the error
            heartbeat, and ``wandb_finish(exit_code=1)`` have been attempted.
    """
    # Idempotency: if DONE was already uploaded, a re-delivered job re-fetches the final
    # metrics from HF and returns them immediately. (The previous behavior — sleeping in
    # an infinite loop — kept a billable GPU worker alive until the execution timeout.)
    try:
        # Idempotency FIRST — before any env-mutating pip install / package removal: a re-delivered
        # job whose DONE already exists must return the persisted metrics and exit WITHOUT running
        # _ensure_fla_fastpath_on_hopper() (mutates the env: pip-installs tilelang/fla) — that wasted
        # a worker mutating its env on an already-complete run. It runs after the DONE check below.
        if HF_REPO:
            from huggingface_hub import hf_hub_download

            try:
                hf_hub_download(
                    repo_id=HF_REPO,
                    repo_type="dataset",
                    filename=f"{hf_prefix()}/DONE",
                    token=os.environ.get("HF_TOKEN"),
                )
                done = True
            except Exception:
                # Any download failure is treated as "DONE not present" -> proceed to train.
                done = False
            if done:
                print("Run already complete (DONE present); returning persisted metrics.")
                heartbeat("already_done", gpu=gpu_diagnostics(include_torch=False))
                try:
                    got = hf_hub_download(
                        repo_id=HF_REPO,
                        repo_type="dataset",
                        filename=f"{hf_prefix()}/metrics.json",
                        token=os.environ.get("HF_TOKEN"),
                    )
                    import shutil

                    shutil.copy(got, "/tmp/metrics.json")
                    # os._exit skips interpreter cleanup, so flush both streams explicitly
                    # (same as the success path at the end of this function).
                    sys.stdout.flush()
                    sys.stderr.flush()
                    os._exit(0)
                except Exception as e:
                    raise SystemExit(f"DONE present but metrics.json unavailable: {e}") from e
        # Not a DONE re-delivery -> this worker will train. These must run before any model import:
        _ensure_fla_fastpath_on_hopper()  # Hopper: enable fla+tilelang GDN fast path (see perf.py)
        # Repoint tilelang's libcudart_stub.so at the real CUDA runtime so it can't shadow libcudart
        # in vLLM's CudaRTLibrary (intermittent `undefined symbol: cudaDeviceReset` on GRPO vLLM
        # init, any model size/arch). AFTER the fla fast path (a tilelang reinstall there rewrites
        # the stub) and BEFORE the model/vLLM import. See perf.py / flash #184.
        _neutralize_tilelang_cudart_stub()
        heartbeat("boot", gpu=gpu_diagnostics(include_torch=False))
        finalize_alloc_conf_for_sleep()  # sync CUDA alloc conf to resolved sleep (before first CUDA alloc)
        # Opt-in: load a baked compiled-kernel mega-cache (if the image shipped one) so the worker
        # skips the ~10-15 min first-run JIT. Best-effort + no-op when absent (the default), so the
        # normal JIT path is untouched. Runs AFTER finalize_alloc_conf_for_sleep: _load probes CUDA
        # (_current_cuda_sm -> get_device_capability triggers CUDA init), so the allocator conf must be
        # resolved first; still before any model/kernel import that would otherwise trigger compilation.
        _load_kernel_cache_if_present()
        # Dispatch table — register new algorithms (e.g. ppo) here as they land.
        modes = {
            "sft": run_sft,  # SFT (TRL SFTTrainer)
            "rl": run_rl,  # GRPO (TRL GRPOTrainer + colocated vLLM)
        }
        handler = modes.get(RUN_MODE)
        if handler is None:
            raise SystemExit(f"unknown RUN_MODE {RUN_MODE}; known: {sorted(modes)}")
        handler()
        # All artifacts (adapter, train_meta, metrics, DONE) are uploaded to HF *inside* the
        # handler. The RL trainer's colocated vLLM can DEADLOCK at interpreter shutdown
        # during NCCL/IPC/CUDA teardown — not segfault-and-exit (which `check=False` on the
        # train subprocess already tolerates), but hang forever. That would block the Flash
        # handler's *blocking* `subprocess.run` (heartbeat frozen at "rl_train_done") and the
        # whole run stalls until the wall-clock cap. Hard-exit to bypass the hanging teardown now that
        # every output is safely persisted.
        wandb_finish(exit_code=0)  # mark the W&B run finished BEFORE os._exit (which skips wandb's atexit sync)
        sys.stdout.flush()
        sys.stderr.flush()
        os._exit(0)
    except Exception as e:
        # Structured retry signal both pollers read: an infra failure -> retry on a fresh worker.
        # GitHubRateLimitError (env ref resolution hit a persistent GitHub rate limit) is retriable:
        # reschedule on a fresh worker once the limit window resets rather than hard-failing. Env
        # resolution runs lazily inside this try (require_active_env, called by the handlers above),
        # never at import, so a rate-limit raise reaches here and is classified correctly.
        retriable = isinstance(e, (RetriableInfraError, GitHubRateLimitError))
        tb = traceback.format_exc()
        traceback.print_exc()
        try:
            err_name = error_artifact_name(RUN_MODE)
            err_path = f"/tmp/{err_name}"
            with open(err_path, "w") as f:
                f.write(tb)
            hf_upload_file(err_path, err_name)
        except Exception as up_err:
            print("error-upload warn:", up_err)
        hb_flags = {"retriable": retriable}
        try:
            heartbeat(f"error_{RUN_MODE}", error=str(e)[:500], **hb_flags, diag=gpu_diagnostics())
        except Exception:
            # Retry without GPU diagnostics. Keep the retry best-effort as well: a failing
            # heartbeat must not replace the original exception or skip the W&B finalization
            # and re-raise below.
            try:
                heartbeat(f"error_{RUN_MODE}", error=str(e)[:500], **hb_flags)
            except Exception as hb_err:
                print("error-heartbeat warn:", hb_err)
        # keep container alive briefly so logs flush, then exit non-zero -> restart
        wandb_finish(exit_code=1)  # finalize the W&B run as failed (don't leave it dangling -> "crashed")
        time.sleep(10)
        raise


if __name__ == "__main__":
    main()