freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,2916 @@
|
|
|
1
|
+
"""On-GPU fine-tuning worker (RunPod). Modes: sft | rl.
|
|
2
|
+
|
|
3
|
+
This module runs on the provisioned GPU worker (RunPod or an instance provider). It uses the shared recipe
|
|
4
|
+
(``flash.engine.recipe``) so SFT targets and RL rewards are rendered and scored
|
|
5
|
+
consistently.
|
|
6
|
+
|
|
7
|
+
Artifacts (adapter, metrics.json, heartbeat.json, checkpoints) are streamed to a
|
|
8
|
+
Hugging Face dataset repo. HF checkpoints give preemption resilience: if a worker is
|
|
9
|
+
recycled mid-run we resume from the latest uploaded checkpoint. Metrics are also
|
|
10
|
+
returned directly to the caller by the launching provider.
|
|
11
|
+
|
|
12
|
+
Core environment variables (set by the launching provider / runner):
|
|
13
|
+
RUN_MODE sft|rl
|
|
14
|
+
SEED int
|
|
15
|
+
HF_REPO Hugging Face dataset repo for artifacts, populated per-run from the
|
|
16
|
+
JobSpec's [train] hf_repo by whichever provider launches the worker
|
|
17
|
+
HF_TOKEN
|
|
18
|
+
RUN_ID unique id for this run (namespacing in the repo)
|
|
19
|
+
|
|
20
|
+
The FLASH_*/RL_*/SFT_* env vars are A/B overrides documented at their use sites; the
|
|
21
|
+
JobSpec [train] table is the source of truth for per-run knobs.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
import contextlib
|
|
27
|
+
import faulthandler
|
|
28
|
+
import json
|
|
29
|
+
import math
|
|
30
|
+
import os
|
|
31
|
+
import random
|
|
32
|
+
import re
|
|
33
|
+
import sys
|
|
34
|
+
import tempfile
|
|
35
|
+
import threading
|
|
36
|
+
import time
|
|
37
|
+
import traceback
|
|
38
|
+
|
|
39
|
+
from flash.engine.accounting import RunMetrics
|
|
40
|
+
|
|
41
|
+
# Shared, substrate-neutral fine-tuning internals (live in this same package).
|
|
42
|
+
from flash.engine.chalk_kernels import active_kernels, install_chalk_kernels
|
|
43
|
+
from flash.engine.recipe import RECIPE
|
|
44
|
+
|
|
45
|
+
# Re-export the pure helpers split into the leaf submodules ``.perf`` and ``.lora``.
|
|
46
|
+
# CRITICAL: the readers below (run_sft / run_rl / make_lora / _init_adapter_model / ...) call
|
|
47
|
+
# these by their bare name, which resolves through THIS module's namespace — so a test's
|
|
48
|
+
# ``monkeypatch.setattr(worker, "<name>", ...)`` still reaches the readers. Names actually used
|
|
49
|
+
# by the retained readers are imported plainly; names re-exported only for API / test access
|
|
50
|
+
# (no retained reader uses them) are marked unused for the linter.
|
|
51
|
+
from flash.engine.worker.lora import (
|
|
52
|
+
_LM_SYNC_REMAP_ON,
|
|
53
|
+
_VL_EXCLUDE_SEGMENTS, # noqa: F401
|
|
54
|
+
_patch_peft_weight_converter_compat, # noqa: F401
|
|
55
|
+
_remap_vl_sync_weights, # noqa: F401
|
|
56
|
+
assert_adapter_delta_nonzero,
|
|
57
|
+
assert_adapter_load_clean,
|
|
58
|
+
assert_lora_applied,
|
|
59
|
+
disable_liger_grpo_torch_compile,
|
|
60
|
+
is_vl_checkpoint,
|
|
61
|
+
lora_exclude_modules,
|
|
62
|
+
model_quant, # noqa: F401
|
|
63
|
+
patch_grpo_mask_aware_lm_head,
|
|
64
|
+
patch_vllm_language_model_only,
|
|
65
|
+
patch_vllm_lm_weight_sync,
|
|
66
|
+
remap_adapter_keys, # noqa: F401
|
|
67
|
+
remap_vl_adapter_dir,
|
|
68
|
+
strip_language_model_infix, # noqa: F401
|
|
69
|
+
vllm_language_model_only_kwargs, # noqa: F401
|
|
70
|
+
)
|
|
71
|
+
from flash.engine.worker.packing import (
|
|
72
|
+
BlockDiagonalCollator,
|
|
73
|
+
gdn_packing_available,
|
|
74
|
+
model_is_gdn_hybrid,
|
|
75
|
+
model_is_pure_attention,
|
|
76
|
+
pack_token_ids,
|
|
77
|
+
packing_efficiency,
|
|
78
|
+
tokenize_for_packing,
|
|
79
|
+
)
|
|
80
|
+
from flash.engine.worker.perf import (
|
|
81
|
+
RetriableInfraError,
|
|
82
|
+
_attn_impl_for_capability, # noqa: F401
|
|
83
|
+
_ensure_fla_fastpath_on_hopper,
|
|
84
|
+
_estimate_params, # noqa: F401
|
|
85
|
+
_flash_attn_3_available, # noqa: F401
|
|
86
|
+
_flash_attn_available,
|
|
87
|
+
_GpuPeakSampler,
|
|
88
|
+
_liger_default_for_model, # noqa: F401
|
|
89
|
+
_memory_mode,
|
|
90
|
+
_metric_curve,
|
|
91
|
+
_neutralize_tilelang_cudart_stub,
|
|
92
|
+
_peak_gpu_gb,
|
|
93
|
+
_remove_fla_from_disk, # noqa: F401
|
|
94
|
+
_reset_peak_gpu,
|
|
95
|
+
_sdpa_cudnn_ctx,
|
|
96
|
+
free_gpu,
|
|
97
|
+
fused_optim_name,
|
|
98
|
+
gpu_diagnostics,
|
|
99
|
+
grad_checkpointing_on,
|
|
100
|
+
grpo_sleep_mode,
|
|
101
|
+
liger_on,
|
|
102
|
+
loraplus_optimizer_cls,
|
|
103
|
+
optimal_attn_impl,
|
|
104
|
+
setup_perf_backends,
|
|
105
|
+
wait_for_gpu,
|
|
106
|
+
)
|
|
107
|
+
from flash.envs.adapter import GitHubRateLimitError
|
|
108
|
+
from flash.envs.registry import load_environment
|
|
109
|
+
from flash.spec import load_job_spec_from_env
|
|
110
|
+
|
|
111
|
+
# Per-run identity and mode, read once at import time from the launch environment.
HF_REPO = os.environ.get("HF_REPO", "")
RUN_ID = os.environ.get("RUN_ID", "local")
SEED = int(os.environ.get("SEED", "0"))
RUN_MODE = os.environ.get("RUN_MODE", "sft")
ATTEMPT = os.environ.get("ATTEMPT", "")
JOB_SPEC = load_job_spec_from_env()
# PHASE is the stable artifact namespace (sft|rl) and matches RUN_MODE for a train run.
# An explicit PHASE env var always wins; otherwise it comes from the JobSpec, and with no
# JobSpec from RUN_MODE (any value other than sft/rl falls back to "sft").
if JOB_SPEC:
    PHASE = os.environ.get("PHASE", JOB_SPEC.phase)
else:
    PHASE = os.environ.get("PHASE", RUN_MODE if RUN_MODE in ("sft", "rl") else "sft")
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _load_active_env():
    """Resolve and load the Freesolo environment named by the run's JobSpec.

    Returns ``None`` when the module was imported without any JobSpec (a non-run path such
    as a unit test), so the module stays importable; a real run always carries one.

    Raises:
        RuntimeError: the JobSpec is present but names no ``[environment] id``. There is no
            default/builtin environment, and failing here keeps a paid worker from
            training/evaluating the wrong task.
    """
    spec = JOB_SPEC
    if spec is None:
        return None
    env_cfg = spec.environment
    env_id = env_cfg.id
    if not env_id:
        # Every supported algorithm (sft/grpo) needs a Freesolo env, so a missing one is
        # always a misconfigured spec -- never fall back to a default.
        raise RuntimeError(
            "JobSpec sets no environment: provide [environment] id "
            "(a Freesolo environment id like 'your-name/your-env', returned by "
            "`flash env push --name <name>`)."
        )
    # resolved_sha is the control-plane-pinned commit (resolve-once hook); an empty value
    # leaves the adapter resolving the GitHub ref->sha itself.
    return load_environment(env_id, env_cfg.params, resolved_sha=env_cfg.resolved_sha)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
# This run's environment, loaded lazily on the first require_active_env() call.
ACTIVE_ENV = None


def require_active_env():
    """Return this run's environment, loading it on first use; raise if none is available.

    The loaded environment is cached in the module-level ``ACTIVE_ENV``. When the worker was
    launched without a JobSpec, ``_load_active_env`` yields ``None``. In that case this raises
    a RuntimeError that says what is missing. Otherwise the first ``ACTIVE_ENV.<attr>`` access
    would fail with an opaque ``AttributeError`` on ``NoneType``. A JobSpec that is present
    but names no environment is already rejected inside ``_load_active_env``.
    """
    global ACTIVE_ENV
    if ACTIVE_ENV is None:
        ACTIVE_ENV = _load_active_env()
    if ACTIVE_ENV is not None:
        return ACTIVE_ENV
    raise RuntimeError(
        "no environment is loaded: this worker was started without a JobSpec "
        "(FLASH_JOB_SPEC_JSON / FLASH_JOB_SPEC_PATH is unset). A train/eval run must "
        "carry a JobSpec naming [environment] id "
        "(a Freesolo environment id like 'your-name/your-env', returned by "
        "`flash env push --name <name>`)."
    )
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
# Thinking/reasoning mode: a single per-run flag taken from the run config (TOML `thinking`)
# and used identically by SFT rendering, RL rollouts, and serving. Off when there is no JobSpec.
if JOB_SPEC:
    THINKING = JOB_SPEC.thinking
else:
    THINKING = False
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
# ---------------------------------------------------------------------------
# HF helpers (code-delivery + artifact channel; works without inbound network)
# ---------------------------------------------------------------------------
def error_artifact_name(mode: str) -> str:
    """Return the per-mode traceback filename, e.g. ``error_sft.txt`` for ``mode="sft"``.

    heartbeat.json is a single, repeatedly overwritten file. A run's traceback therefore gets
    its own stable, mode-specific upload name instead.
    """
    return "error_{}.txt".format(mode)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def hf_api():
    """Build a Hugging Face Hub client authenticated with ``HF_TOKEN`` (``None`` if unset)."""
    from huggingface_hub import HfApi

    token = os.environ.get("HF_TOKEN")
    return HfApi(token=token)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def hf_prefix() -> str:
    """Repo path under which this run's artifacts live: ``<PHASE>/<RUN_ID>/seed<SEED>``."""
    return "{}/{}/seed{}".format(PHASE, RUN_ID, SEED)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _hf_upload(do_upload, repo_subpath: str, required: bool, label: str) -> None:
    """Run ``do_upload`` under the shared HF_REPO guard and retry/raise-or-warn policy.

    Does nothing when ``HF_REPO`` is unset.

    ``required=True`` is used for completion artifacts (DONE/metrics.json) and the trained
    adapter. It allows 3 attempts, sleeping 5s then 10s between them, and finally raises
    ``RetriableInfraError``. A swallowed failure there would let the control plane mark a
    finished run failed/retried, or mark it done while deployment can never download the
    missing adapter.

    Optional artifacts (generations, logs) get one attempt and only print a warning.
    """
    if not HF_REPO:
        return
    max_tries = 3 if required else 1
    tries = 0
    while True:
        try:
            do_upload()
            return
        except Exception as exc:
            tries += 1
            if required and tries < max_tries:
                print(f"{label} retry {tries}/{max_tries}: {exc}")
                time.sleep(5 * tries)
                continue
            if required:
                # Out of retries -> the host/network is bad, not the run. Infra-shaped.
                raise RetriableInfraError(
                    f"required upload of {repo_subpath!r} failed: {exc}"
                ) from exc
            print(f"{label} warn:", exc)
            return
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def hf_upload_file(local_path: str, repo_subpath: str, required: bool = False):
    """Upload the file ``local_path`` to ``<hf_prefix()>/<repo_subpath>`` in HF_REPO.

    Delegates to ``_hf_upload``: skipped when HF_REPO is unset, retried and then raised when
    ``required``, otherwise a single best-effort attempt.
    """

    def _commit():
        return hf_api().upload_file(
            path_or_fileobj=local_path,
            path_in_repo=f"{hf_prefix()}/{repo_subpath}",
            repo_id=HF_REPO,
            repo_type="dataset",
        )

    _hf_upload(_commit, repo_subpath, required, "hf_upload_file")
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
# Serializes the read-trim-rewrite-upload cycle of the /tmp debug JSONL files.
_DEBUG_UPLOAD_LOCK = threading.Lock()


def upload_debug_jsonl(name: str, rows: list[dict], *, keep_last: int = 200) -> None:
    """Append bounded JSONL debug rows and upload them as an optional artifact.

    The local file ``/tmp/<name>.jsonl`` keeps at most ``keep_last`` previously written lines,
    followed by the new ``rows`` (one JSON object per line). The whole file is then uploaded
    to the run's HF prefix. ``keep_last <= 0`` drops all prior lines. A bare ``[-keep_last:]``
    slice would do the opposite for 0 and keep the entire file, growing it without bound.

    This is intentionally best-effort: debug visibility must not fail a paid run, so every
    error is printed and swallowed. Nothing happens when ``rows`` is empty or HF_REPO is unset.
    """
    if not rows or not HF_REPO:
        return
    repo_name = os.path.basename(name if name.endswith(".jsonl") else f"{name}.jsonl")
    path = os.path.join("/tmp", repo_name)
    try:
        with _DEBUG_UPLOAD_LOCK:
            existing: list[str] = []
            if keep_last > 0:
                with contextlib.suppress(FileNotFoundError), open(path) as f:
                    existing = f.readlines()[-keep_last:]
            with open(path, "w") as f:
                f.writelines(existing)
                for row in rows:
                    f.write(json.dumps(row, default=str, ensure_ascii=True, sort_keys=True) + "\n")
            # Upload after the file is closed/flushed, still under the lock so a concurrent
            # caller can't rewrite it mid-upload.
            hf_upload_file(path, repo_name)
    except Exception as e:
        print(f"debug upload warn ({repo_name}): {e}")
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def hf_upload_folder(local_dir: str, repo_subpath: str, required: bool = False):
    """Upload the directory ``local_dir`` to ``<hf_prefix()>/<repo_subpath>`` in HF_REPO.

    Delegates to ``_hf_upload``: skipped when HF_REPO is unset, retried and then raised when
    ``required``, otherwise a single best-effort attempt.
    """

    def _push():
        return hf_api().upload_folder(
            folder_path=local_dir,
            path_in_repo=f"{hf_prefix()}/{repo_subpath}",
            repo_id=HF_REPO,
            repo_type="dataset",
        )

    _hf_upload(_push, repo_subpath, required, "hf_upload_folder")
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def hf_resume_checkpoint() -> str | None:
    """Return the local path of the latest streamed trainer checkpoint for this run, or None.

    ``make_checkpoint_upload_callback`` uploads checkpoints DURING the run as
    ``<prefix>/checkpoint/checkpoint-<step>/``. A replacement worker downloads them into
    ``/tmp/resume``, so a mid-run preemption costs at most one save interval.

    Only entries named exactly ``checkpoint-<int>`` that are directories are considered.
    Any other name is ignored rather than making ``int()`` raise, which would abort the
    lookup and silently restart training from scratch.

    Best-effort: any error is printed and yields ``None`` (a fresh start).
    """
    if not HF_REPO:
        return None
    try:
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id=HF_REPO,
            repo_type="dataset",
            allow_patterns=[f"{hf_prefix()}/checkpoint/**"],
            local_dir="/tmp/resume",
            token=os.environ.get("HF_TOKEN"),
        )
        base = os.path.join("/tmp/resume", hf_prefix(), "checkpoint")
        if not os.path.isdir(base):
            return None
        by_step: dict[int, str] = {}
        for entry in os.listdir(base):
            match = re.fullmatch(r"checkpoint-(\d+)", entry)
            if match and os.path.isdir(os.path.join(base, entry)):
                by_step[int(match.group(1))] = entry
        if not by_step:
            return None
        path = os.path.join(base, by_step[max(by_step)])
        print(f"[resume] found streamed checkpoint: {path}")
        return path
    except Exception as e:
        print("hf_resume_checkpoint warn:", e)
        return None
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def prefetch_model(model_id: str) -> float:
    """Download ``model_id`` into the local HF cache ahead of training; return seconds spent.

    The trainer / vLLM would fetch the weights lazily anyway. Pulling them explicitly:
    (a) makes the download its own timed stage in the heartbeat stream (the cold-start
        metric the speed work tracks);
    (b) surfaces disk/network trouble before trainer construction.

    A warm cache makes this nearly free. Download errors are printed, not raised: gated or
    local-only models can still load through the trainer's own ``from_pretrained`` path.
    """
    from huggingface_hub import snapshot_download

    # weights + tokenizer/config only (same exclusions as the image bake)
    skipped = ["*.pth", "*.gguf", "original/*", "*.onnx", "*.msgpack", "*.h5"]
    started = time.time()
    try:
        snapshot_download(repo_id=model_id, ignore_patterns=skipped)
    except Exception as err:
        print("prefetch_model warn:", err)
    elapsed = round(time.time() - started, 1)
    heartbeat(
        "model_prefetched",
        model=model_id,
        download_seconds=elapsed,
        hf_transfer=os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", ""),
        gpu=gpu_diagnostics(),
    )
    return elapsed
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
# Trainer-state files a serving engine never needs: optimizer/scheduler/rng/loss-curve
# state. Excluded when publishing the deployable per-step adapter so each step's snapshot is
# just the LoRA weights + config (a few MB), small enough to KEEP every step (no pruning).
_CHECKPOINT_TRAINER_STATE = (
    "optimizer.pt",
    "optimizer.bin",
    "scheduler.pt",
    "scaler.pt",
    "rng_state*.pth",
    "trainer_state.json",
    "training_args.bin",
    "*.distcp",
    "global_step*/**",
    "latest",
    "zero_to_fp32.py",
)

# The PEFT adapter weights file a checkpoint must carry to be loadable/servable (safetensors is
# the default; .bin is the legacy fallback). A step with adapter_config.json but no weights is
# NOT deployable, so it's never published/listed.
_ADAPTER_WEIGHT_FILES = ("adapter_model.safetensors", "adapter_model.bin")


def publish_deployable_checkpoint(ckpt_dir: str, step: int) -> str | None:
    """Mirror a trainer checkpoint's LoRA adapter to a kept-forever per-step path.

    This lets a run cancelled mid-RL still be deployed with one command from its last good
    step. The checkpoint folder already holds the PEFT adapter (``adapter_config.json`` +
    weights) that ``deploy_adapter`` serves. Everything except the trainer-state files
    listed in ``_CHECKPOINT_TRAINER_STATE`` is uploaded to
    ``<prefix>/checkpoints/step-<step>/adapter``.

    Unlike the resume checkpoint (``checkpoint/**``, latest-only), these snapshots
    accumulate, so every step stays deployable.

    Returns the deployable adapter subfolder. Returns ``None`` when HF_REPO is unset, the
    checkpoint lacks a complete adapter, or the upload fails. Best-effort: failures are
    printed, never raised.
    """
    if not HF_REPO:
        return None

    def _has(filename: str) -> bool:
        return os.path.isfile(os.path.join(ckpt_dir, filename))

    # Never advertise a step without BOTH the adapter config and its weights.
    deployable = _has("adapter_config.json") and any(_has(w) for w in _ADAPTER_WEIGHT_FILES)
    if not deployable:
        return None

    subfolder = f"{hf_prefix()}/checkpoints/step-{step}/adapter"
    try:
        hf_api().upload_folder(
            folder_path=ckpt_dir,
            path_in_repo=subfolder,
            repo_id=HF_REPO,
            repo_type="dataset",
            ignore_patterns=list(_CHECKPOINT_TRAINER_STATE),
        )
        heartbeat("checkpoint_deployable", step=step, subfolder=subfolder)
    except Exception as exc:
        print(f"[ckpt] deployable publish warn (step {step}):", exc)
        return None
    return subfolder
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def make_checkpoint_upload_callback():
    """Return a TrainerCallback that streams each trainer save to HF.

    Preemption then loses at most one save interval.

    Each upload runs in a background daemon thread, so the train loop never blocks on the
    network. It commits ``<output_dir>/checkpoint-<step>/`` to
    ``<hf_prefix()>/checkpoint/checkpoint-<step>/``. In the SAME commit it deletes every
    other file under ``<hf_prefix()>/checkpoint/``, keeping the resume area latest-only
    (``hf_resume_checkpoint`` downloads that whole area).

    If an upload is still in flight when the next save fires, that save is skipped (the
    following one catches up).

    After each successful upload, ``publish_deployable_checkpoint`` mirrors the step's
    adapter to a kept-forever path, so a run cancelled mid-RL can still be deployed from its
    latest step. Upload failures are printed, never raised.
    """
    from transformers import TrainerCallback

    # Held for the whole lifetime of one background upload; a save that can't take it is skipped.
    lock = threading.Lock()

    class _CheckpointUpload(TrainerCallback):
        def on_save(self, args, state, control, **kwargs):
            if not HF_REPO:
                return
            step = int(state.global_step)
            ckpt_name = f"checkpoint-{step}"
            ckpt_dir = os.path.join(args.output_dir, ckpt_name)
            if not os.path.isdir(ckpt_dir):
                return
            if not lock.acquire(blocking=False):
                print(f"[ckpt] upload busy; skipping step {step}")
                return

            def _upload():
                try:
                    # huggingface_hub matches ``delete_patterns`` RELATIVE to ``path_in_repo``,
                    # and only against remote files under it. Uploading straight into
                    # ``.../checkpoint/checkpoint-<step>`` can never reach the older
                    # ``checkpoint-<n>`` siblings, so they would pile up.
                    # Instead: commit from ``output_dir`` into the shared ``checkpoint/`` root,
                    # restricted to this step's folder, and delete everything else there.
                    # Files re-added by this commit are not deleted.
                    hf_api().upload_folder(
                        folder_path=args.output_dir,
                        path_in_repo=f"{hf_prefix()}/checkpoint",
                        repo_id=HF_REPO,
                        repo_type="dataset",
                        allow_patterns=[f"{ckpt_name}/*"],
                        delete_patterns=["*"],
                    )
                    heartbeat("checkpoint_uploaded", step=step)
                    # Mirror this step's adapter to its own kept-forever path so the run
                    # stays deployable even if it never reaches "done".
                    publish_deployable_checkpoint(ckpt_dir, step)
                except Exception as e:
                    print("ckpt upload warn:", e)
                finally:
                    lock.release()

            try:
                threading.Thread(target=_upload, daemon=True).start()
            except Exception:
                # The thread never ran, so its ``finally`` can't release the lock. Release it
                # here, or every later save would be skipped as "busy".
                lock.release()
                raise

    return _CheckpointUpload()
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
# Heartbeat HF-commit throttle. Each heartbeat() commits heartbeat.json to the HF artifact
# repo. Committing every training step (the reward callback fires per step) blows HuggingFace's
# per-repo commit rate limit (128/hour), especially when several runs share one HF_REPO.
# Only the per-step "rl_step" stage is high-frequency, so (outside terminal-only mode) JUST
# that one is throttled to once per 60s. Every other stage — including milestones and the
# terminal done/already_done — always commits, so the control plane never misses a transition.
# The local file + stdout line are always written regardless.
# _HB_LAST_UPLOAD is the time.time() of the last claimed upload slot (0.0 = none yet); it is
# read and updated under _HB_LOCK.
_HB_LAST_UPLOAD = 0.0


# The rl_step heartbeat-upload throttle, in seconds (fixed 60s) — keeps GRPO under HF's
# 128 commits/hour-per-repo limit when concurrent runs share one HF_REPO.
_HB_MIN_INTERVAL_S = 60.0
# Stages subject to the _HB_MIN_INTERVAL_S throttle; all other stages commit every time.
_HB_THROTTLED_STAGES = frozenset({"rl_step"})
# Terminal transitions the control plane must never miss — always committed. (heartbeat() also
# always commits any stage whose name starts with "error_".)
_HB_TERMINAL_STAGES = frozenset({"done", "already_done"})
# Benchmark fan-out switch. When True, heartbeat() commits non-terminal stages at most once per
# _HB_TERMINAL_ONLY_INTERVAL_S instead of applying the rl_step throttle.
# NOTE(review): False here and never set True in this chunk — presumably toggled elsewhere; confirm.
_HB_TERMINAL_ONLY = False
# Even in terminal-only mode, emit a SLOW heartbeat at this cadence. The control plane's stall
# detector then keeps seeing progress through a long training phase and doesn't false-stall
# the run. 600s -> ~6 commits/hr, far under the 128/hr cap.
_HB_TERMINAL_ONLY_INTERVAL_S = 600.0


# Serializes heartbeat.json writes and _HB_LAST_UPLOAD reads/updates. During GRPO,
# heartbeat() is called concurrently from the trainer thread (reward callback) and the
# checkpoint-upload daemon thread. Without this lock two writers can interleave and
# truncate/garble heartbeat.json (and race _HB_LAST_UPLOAD).
_HB_LOCK = threading.Lock()
# Serializes the actual HF upload (a slow network commit) SEPARATELY from _HB_LOCK, so the
# trainer's frequent local writes never block on the network. It guarantees at most one
# heartbeat.json commit is in flight at a time.
# NOTE(review): it does NOT guarantee commits land in heartbeat order. A thread releases
# _HB_LOCK before acquiring this lock, and threading.Lock wakes waiters in an unspecified
# order, so an older snapshot can still be committed after a newer one. Confirm the
# control plane tolerates that.
_HB_UPLOAD_LOCK = threading.Lock()
|
|
493
|
+
|
|
494
|
+
# Stall diagnostics: when FLASH_STALL_FAULTHANDLER_S > 0, arm a faulthandler watchdog. It dumps
# every thread's Python stack and then exits the process, so the run FAILS instead of hanging
# until the control-plane stall watchdog kills it ~25 min later; the dump is uploaded with
# console_<phase>.txt.
# The timer is re-armed on every heartbeat, so it only fires when NO progress heartbeat lands
# for the whole window -- i.e. a real hang. OFF by default (0); opt in per run via
# [worker_env]. Used to localize the GRPO sleep-mode rollout hang.
# A missing, empty, or non-integer value leaves the watchdog disabled.
_STALL_FAULTHANDLER_S = 0
with contextlib.suppress(Exception):
    _STALL_FAULTHANDLER_S = int(os.environ.get("FLASH_STALL_FAULTHANDLER_S") or 0)


def _rearm_stall_faulthandler() -> None:
    """Restart the stall watchdog's countdown; a no-op while the watchdog is disabled.

    Any faulthandler error is swallowed: diagnostics must never break a heartbeat.
    """
    if _STALL_FAULTHANDLER_S > 0:
        with contextlib.suppress(Exception):
            faulthandler.cancel_dump_traceback_later()
            faulthandler.dump_traceback_later(_STALL_FAULTHANDLER_S, exit=True)
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
def heartbeat(stage: str, **kw) -> None:
    """Publish a progress heartbeat: local file, stdout line, and (when due) an HF commit.

    The JSON payload holds ``stage`` plus run identity (ts, run_id, mode, seed, attempt) and
    any extra ``kw`` fields; ``kw`` entries override base keys of the same name. ``dc`` is
    added from RUNPOD_DC_ID when that is set and ``kw`` did not supply one. Values must be
    JSON-serializable (``json.dumps`` is called without ``default=``).

    Steps:
    1. atomically rewrite /tmp/hb/heartbeat.json under ``_HB_LOCK``;
    2. decide whether an HF commit is due, and claim the slot if so:
       - terminal and ``error_*`` stages: always;
       - terminal-only mode: at most once per _HB_TERMINAL_ONLY_INTERVAL_S;
       - otherwise: every stage except throttled ``rl_step``;
    3. if due, commit the captured snapshot to ``<hf_prefix()>/heartbeat.json`` under
       ``_HB_UPLOAD_LOCK`` via ``hf_upload_file`` (an optional upload: failures only warn);
    4. re-arm the stall faulthandler and print ``HEARTBEAT <json>``.

    Local file-system errors (creating /tmp/hb, writing temp files) are not caught here.
    """
    global _HB_LAST_UPLOAD
    payload = {
        "stage": stage,
        "ts": time.time(),
        "run_id": RUN_ID,
        "mode": RUN_MODE,
        "seed": SEED,
        "attempt": ATTEMPT,
        **kw,
    }
    # The datacenter the worker actually landed in (RunPod serverless sets RUNPOD_DC_ID). This is
    # a diagnostic, so the control plane / logs show which region a run hit (the eager
    # weight-cache fleet already has a volume in every storage DC). Empty/absent on non-RunPod
    # (instance) workers and harmless; only emitted when present.
    _dc = os.environ.get("RUNPOD_DC_ID") or ""
    if _dc:
        payload.setdefault("dc", _dc)
    os.makedirs("/tmp/hb", exist_ok=True)
    p = "/tmp/hb/heartbeat.json"
    # _HB_LOCK guards ONLY the fast local work (atomic write + _HB_LAST_UPLOAD + snapshot capture).
    # The slow HF commit runs OUTSIDE it, so the trainer's per-step reward callback never blocks
    # on the network behind the checkpoint daemon's commit (a GRPO perf regression).
    with _HB_LOCK:
        # Atomic write: temp file + os.replace() so a concurrent reader never sees a partial file.
        # The pid/thread-id suffix keeps concurrent writers off each other's temp file.
        tmp = p + f".{os.getpid()}.{threading.get_ident()}.tmp"
        snapshot = json.dumps(payload)
        with open(tmp, "w") as f:
            f.write(snapshot)
        os.replace(tmp, p)
        now = time.time()
        if stage in _HB_TERMINAL_STAGES or stage.startswith("error_"):
            upload_due = True  # never miss a terminal transition
        elif _HB_TERMINAL_ONLY:
            # Benchmark fan-out: keep commits far under the 128/hour cap, but still emit a SLOW
            # heartbeat (~every _HB_TERMINAL_ONLY_INTERVAL_S). The control-plane stall detector
            # then sees progress during a long training phase and doesn't false-stall the run.
            upload_due = (
                _HB_LAST_UPLOAD == 0.0 or (now - _HB_LAST_UPLOAD) >= _HB_TERMINAL_ONLY_INTERVAL_S
            )
        else:
            throttled = stage in _HB_THROTTLED_STAGES
            upload_due = not throttled or (now - _HB_LAST_UPLOAD) >= _HB_MIN_INTERVAL_S
        if upload_due:
            _HB_LAST_UPLOAD = now  # claim the slot under the lock (throttle stays atomic)
    if upload_due:
        # Serialize the network commit under a SEPARATE lock so two uploads never overlap.
        # NOTE(review): this serializes but does not order commits — another thread may win
        # this lock first, so an older snapshot can land after a newer one.
        # Upload the captured snapshot rather than re-reading p, which a newer heartbeat may
        # already have overwritten between our slot-claim and this upload. It goes through a
        # private temp file because hf_upload_file takes a path.
        with _HB_UPLOAD_LOCK:
            up = p + f".{os.getpid()}.{threading.get_ident()}.upload.tmp"
            with open(up, "w") as f:
                f.write(snapshot)
            try:
                hf_upload_file(up, "heartbeat.json")
            finally:
                with contextlib.suppress(OSError):
                    os.remove(up)
    # Re-arm the stall watchdog: progress landed, so reset the no-heartbeat timer.
    _rearm_stall_faulthandler()
    print("HEARTBEAT", json.dumps(payload))
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
# ---------------------------------------------------------------------------
|
|
578
|
+
# Decoding parity: render with the model's own chat template and one run-wide thinking
|
|
579
|
+
# flag (off by default), so SFT targets and RL rollouts use identical prompt
|
|
580
|
+
# formatting within a run.
|
|
581
|
+
# ---------------------------------------------------------------------------
|
|
582
|
+
def render_prompt(tokenizer, item) -> str:
    """Render the generation prompt for one dataset item with the model's own chat template.

    A non-dict ``item`` is wrapped as ``{"question": item}`` before the active environment
    turns it into chat messages. The run-wide ``THINKING`` flag is forwarded as
    ``enable_thinking``, so SFT and RL share identical prompt formatting.
    """
    if not isinstance(item, dict):
        item = {"question": item}
    messages = require_active_env().prompt_messages(item)
    rendered = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=THINKING,
    )
    return rendered
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def strip_think(completion: str | None) -> str | None:
|
|
591
|
+
"""Drop <think>...</think> reasoning before the environment grades/rewards a
|
|
592
|
+
thinking-mode completion.
|
|
593
|
+
|
|
594
|
+
- closed block(s): keep only the text after the LAST </think>. This also covers
|
|
595
|
+
always-thinking templates that pre-open <think> inside the generation prompt,
|
|
596
|
+
whose completions contain </think> with no opening tag.
|
|
597
|
+
- unclosed <think> (completion budget exhausted): keep only the pre-think text
|
|
598
|
+
(usually empty), so answer extraction fails and the completion scores 0 —
|
|
599
|
+
deliberate reward pressure to close thinking within budget, and it keeps a
|
|
600
|
+
last-number fallback from matching numbers inside the reasoning.
|
|
601
|
+
- no tags: unchanged.
|
|
602
|
+
"""
|
|
603
|
+
if completion is None:
|
|
604
|
+
return None
|
|
605
|
+
if "</think>" in completion:
|
|
606
|
+
return completion.rsplit("</think>", 1)[1]
|
|
607
|
+
if "<think>" in completion:
|
|
608
|
+
return completion.split("<think>", 1)[0]
|
|
609
|
+
return completion
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
def graded_text(completion: str | None) -> str | None:
    """Return the text the environment grader/reward actually sees.

    When the run-wide ``THINKING`` flag is on, the completion first goes through
    ``strip_think`` (so a completion whose reasoning never closes grades 0). Otherwise
    it is passed through untouched. Doing this once here, ahead of
    ACTIVE_ENV.grade/reward, makes it apply to every environment.
    """
    if not THINKING:
        return completion
    return strip_think(completion)
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
# ---------------------------------------------------------------------------
|
|
620
|
+
# SFT (plus shared worker-setup helpers: vLLM attention backend, allocator conf, W&B, LoRA)
|
|
621
|
+
# ---------------------------------------------------------------------------
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
def force_vllm_backend_for_sm120() -> str | None:
|
|
625
|
+
"""On RTX 5090 / consumer Blackwell (sm120), force a PTX-independent vLLM attention backend.
|
|
626
|
+
|
|
627
|
+
vLLM's default rollout backend is flash-attn, whose PRE-BUILT PTX needs a newer driver JIT than
|
|
628
|
+
many 5090 RunPod hosts have — when the JIT fails the colocated rollout silently produces NO
|
|
629
|
+
completions (empty reward_history, ~1.4 s "done"; a whole 22-run sweep hit this on every 5090).
|
|
630
|
+
FLASHINFER is vLLM's Blackwell-native backend (no flash-attn PTX dependency) and trains on a 5090
|
|
631
|
+
(measured: FLASHINFER/TORCH_SDPA/TRITON_ATTN all train, ~116 s). This mirrors the trainer's
|
|
632
|
+
cuDNN-SDPA forcing on sm120 (``_attn_impl_for_capability``). The GRPO no-op guard remains the
|
|
633
|
+
backstop. Returns the backend set (None if not sm120). Fixed — no operator override."""
|
|
634
|
+
try:
|
|
635
|
+
import torch
|
|
636
|
+
|
|
637
|
+
if not torch.cuda.is_available() or torch.cuda.get_device_capability(0)[0] != 12:
|
|
638
|
+
return None
|
|
639
|
+
except Exception as e:
|
|
640
|
+
print("[rl] sm120 vLLM backend probe skipped:", e)
|
|
641
|
+
return None
|
|
642
|
+
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
|
|
643
|
+
print(
|
|
644
|
+
"[rl] sm120 (RTX 5090): VLLM_ATTENTION_BACKEND=FLASHINFER (flash-attn PTX is unreliable "
|
|
645
|
+
"on consumer Blackwell hosts -> empty-rollout failures)"
|
|
646
|
+
)
|
|
647
|
+
return "FLASHINFER"
|
|
648
|
+
|
|
649
|
+
|
|
650
|
+
def finalize_alloc_conf_for_sleep() -> None:
    """Align the CUDA allocator conf with the worker's resolved vLLM sleep decision (RL runs only).

    The launcher (providers/*/train.py build_worker_env) chooses the sleep-SAFE non-expandable
    PYTORCH_ALLOC_CONF for RL before this process starts. It cannot know the GRPO sleep
    decision, though. For a small model the worker resolves sleep OFF, and the non-expandable
    conf is then safe but fragments a long colocate run.

    Here, with the job spec and the live GPU available, the sleep decision is resolved through
    ``grpo_sleep_mode``: the same inputs run_rl uses (size/context gate PLUS a resident-fit
    check against the card's total memory). When sleep resolves OFF, both PYTORCH_ALLOC_CONF
    and PYTORCH_CUDA_ALLOC_CONF switch to ``expandable_segments:True``; that setting only
    crashes WITH sleep on, a case ruled out here.

    PYTORCH_ALLOC_CONF is read lazily at the first CUDA allocation, so this must run before any
    allocation (it is called at boot). Best-effort: any failure is printed and the launcher's
    conf is kept.
    """
    if PHASE != "rl":
        return
    try:
        model_id = JOB_SPEC.model if JOB_SPEC else ""
        train_cfg = JOB_SPEC.train if JOB_SPEC else None

        # Context length from [train].max_length; a missing or unparsable value means 0.
        try:
            max_len = int(train_cfg.max_length) if train_cfg and train_cfg.max_length else 0
        except Exception:
            max_len = 0

        # Total memory of GPU 0 in binary GiB, to match grpo_fits_resident (see run_rl);
        # /1e9 would over-report by ~7%. No CUDA or a failed probe -> 0.0.
        try:
            import torch as _torch_card

            card_gib = (
                _torch_card.cuda.get_device_properties(0).total_memory / (1024**3)
                if _torch_card.cuda.is_available()
                else 0.0
            )
        except Exception:
            card_gib = 0.0

        # group_size resolved like run_rl (gcfg override, else the recipe default), not a flat
        # 8: a recipe rl.group_size != 8 would otherwise make this sleep decision diverge from
        # the trainer's and pick the wrong expandable/non-expandable conf.
        from flash.engine.recipe import RECIPE as _RECIPE

        group_size = int(grpo_overrides().get("group_size") or _RECIPE.rl.group_size)

        sleep_on = grpo_sleep_mode(
            model_id,
            max_length=max_len,
            group_size=group_size,
            max_tokens=(train_cfg.max_tokens if train_cfg else None),
            lora_rank=int(train_cfg.lora_rank) if train_cfg and train_cfg.lora_rank else 32,
            thinking=THINKING,
            card_vram_gb=card_gib,
        )
        if sleep_on:
            print("[alloc] sleep resolves ON -> keeping launcher's non-expandable conf")
            return
        # Sleep OFF: expandable segments are safe here and avoid long-run fragmentation.
        conf = "expandable_segments:True"
        for var_name in ("PYTORCH_ALLOC_CONF", "PYTORCH_CUDA_ALLOC_CONF"):
            os.environ[var_name] = conf
        print(f"[alloc] sleep resolves OFF -> {conf} (anti-fragmentation, matches worker gate)")
    except Exception as e:
        print("[alloc] auto-conf skipped:", e)
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
def wandb_report_to() -> list[str]:
    """Return the TRL/HF ``report_to`` targets: ``["wandb"]`` when W&B logging is live, else ``[]``.

    This restores the W&B logging that the legacy freesolo training path had and the flash
    migration dropped. It is enabled only when WANDB_API_KEY is set. Without the key the result
    is silently ``[]``, and the metrics.json artifact remains the source of truth.

    The project and run name come ONLY from the typed ``[wandb]`` config (``JOB_SPEC.wandb``).
    There is no WANDB_PROJECT / WANDB_NAME environment variable. HF's WandbCallback has no
    project argument and would read WANDB_PROJECT from the env. So the run is started here with
    ``wandb.init(project=..., name=...)`` (project defaulting to ``"flash"``), and the Trainer's
    callback then reuses it. No ``entity=`` is passed, so the run lands in the API key's default
    account/team.

    Best-effort: a missing package, a broken install, or an init failure (auth, network, import
    error) prints a message and disables W&B instead of aborting training.
    """
    if not os.environ.get("WANDB_API_KEY"):
        return []

    import importlib.util

    package_missing = importlib.util.find_spec("wandb") is None
    if package_missing:
        print("[wandb] WANDB_API_KEY set but the wandb package is missing; skipping W&B logging")
        return []

    try:
        import wandb

        if wandb.run is None:
            spec_project = JOB_SPEC.wandb.project if JOB_SPEC else None
            wandb.init(project=spec_project or "flash", name=wandb_run_name())
    except Exception as e:
        print(f"[wandb] W&B init failed ({e}); skipping W&B logging (metrics.json is still written)")
        return []
    return ["wandb"]
|
|
744
|
+
|
|
745
|
+
|
|
746
|
+
def wandb_run_name() -> str:
    """Return the W&B run name.

    It comes from the typed ``[wandb] run_name`` config (``JOB_SPEC.wandb.run_name``) only;
    there is no WANDB_NAME environment variable. A non-blank configured name is used verbatim,
    minus surrounding whitespace, because the user owns the naming. Otherwise a stable id ties
    the dashboard run to the Flash run: ``flash-<phase>-<run_id>-seed<N>``.

    The result is passed to the Trainer via ``TrainingArguments.run_name`` and to
    ``wandb.init``. It is requested even when W&B is off, so a job spec without a ``[wandb]``
    section (``JOB_SPEC.wandb`` missing or None) must fall back to the generated id rather than
    raise AttributeError.
    """
    wandb_cfg = getattr(JOB_SPEC, "wandb", None) if JOB_SPEC else None
    configured = getattr(wandb_cfg, "run_name", None)
    if configured and configured.strip():
        return configured.strip()
    return f"flash-{PHASE}-{RUN_ID}-seed{SEED}"
|
|
756
|
+
|
|
757
|
+
|
|
758
|
+
def wandb_run_info() -> dict:
    """Describe the live W&B run as ``{wandb_url, wandb_id, wandb_project}``, or ``{}`` if none.

    The result is recorded in metrics.json. That makes the W&B run verifiable, and lets the
    freesolo agent's `wandb_runs` and the SDK's link_wandb point at the real dashboard URL (the
    link the flash migration otherwise dropped).

    Attributes missing on the run object come back as None. Never raises: an import failure or
    any other error yields ``{}``.
    """
    try:
        import wandb

        active_run = getattr(wandb, "run", None)
        if active_run is None:
            return {}
        field_map = (
            ("wandb_url", "url"),
            ("wandb_id", "id"),
            ("wandb_project", "project"),
        )
        return {key: getattr(active_run, attr, None) for key, attr in field_map}
    except Exception:
        return {}
|
|
775
|
+
|
|
776
|
+
|
|
777
|
+
def make_lora(model_id: str | None = None):
    """Build the PEFT ``LoraConfig`` for this run.

    Adapters target 'all-linear' (every nn.Linear) rather than a hardcoded q/k/v/o list. That
    is architecture-agnostic, so the same recipe covers the dense default
    (Qwen3-4B-Instruct-2507) and newer models with extra projection types (e.g. the Qwen3.5
    hybrid Gated-DeltaNet) without missing any adapters. For natively-multimodal checkpoints
    the vision tower is excluded (see ``lora_exclude_modules``).

    Rank and alpha come from ``[train] lora_rank`` / ``lora_alpha`` when set. They fall back to
    the recipe defaults (``RECIPE.lora.rank`` / ``.alpha``) when there is no job spec, no
    ``train`` section, or the value is unset (None). This matches how every other [train] knob
    falls back to the recipe, instead of handing ``r=None`` to PEFT.

    Args:
        model_id: Optional model id; when given, model-specific modules are excluded.

    Returns:
        A ``peft.LoraConfig``.
    """
    from peft import LoraConfig

    # Adapt every linear projection. "all-linear" is a PEFT SPECIAL string (not a module name)
    # that PEFT expands to all linear layers — the right managed default across the catalog.
    targets = "all-linear"
    train_cfg = JOB_SPEC.train if JOB_SPEC else None
    rank = (
        train_cfg.lora_rank
        if train_cfg and train_cfg.lora_rank is not None
        else RECIPE.lora.rank
    )
    alpha = (
        train_cfg.lora_alpha
        if train_cfg and train_cfg.lora_alpha is not None
        else RECIPE.lora.alpha
    )
    kwargs = {
        "r": rank,
        "lora_alpha": alpha,
        "lora_dropout": RECIPE.lora.dropout,
        "target_modules": targets,
        "task_type": "CAUSAL_LM",
    }
    # Adapter initialization: standard zero-B init (the LoRA delta starts at zero, so the saved
    # adapter is a plain residual that loads correctly onto the ORIGINAL base).
    # PiSSA was removed: it mutates the effective base during training, so its saved adapter only
    # reconstructs against the PiSSA-residual base. Loading that adapter onto the unmodified base
    # at SERVING or GRPO WARM-START (which is exactly our flow) corrupts the model -> the served
    # model emits only whitespace and warm-start GRPO hangs.
    kwargs["init_lora_weights"] = True
    print(
        "[lora] init_lora_weights=True (standard zero-B; PiSSA removed for serve/warm-start safety)"
    )
    # Standard LoRA scaling (alpha/r). rsLoRA was removed: it scales by alpha/sqrt(r) (~5.6x larger
    # for r=32/alpha=64), so with the usual LoRA LR (e.g. 2e-4) the effective update is ~5.6x too
    # large -> SFT diverges to a degenerate adapter (served model repeats a single token / emits
    # whitespace) and the adapter is also fragile under vLLM's rsLoRA handling at serve time.
    kwargs["use_rslora"] = False
    if model_id and targets == "all-linear":
        exclude = lora_exclude_modules(model_id)
        if exclude:
            kwargs["exclude_modules"] = exclude
            print(f"[lora] excluding modules for {model_id}: {exclude}")
    return LoraConfig(**kwargs)
|
|
822
|
+
|
|
823
|
+
|
|
824
|
+
def require_vllm_for_rollout_func(use_rollout_func: bool, use_vllm: bool, model_id: str) -> None:
    """Raise early if a multi-turn GRPO run needs colocated vLLM but it is turned off.

    The multi-turn rollout closure (``multiturn_rollout.build_rollout_func``) generates through
    ``trainer.vllm_generation.llm``. TRL only builds that engine when ``use_vllm`` is True, so
    without it the rollout would hit an AttributeError on the first turn.

    GRPO currently always colocates vLLM, so this check is defensive. It stays so that a future
    tier disabling the rollout engine fails with an actionable message.

    Raises:
        RuntimeError: if ``use_rollout_func`` is truthy while ``use_vllm`` is falsy.
    """
    if not use_rollout_func or use_vllm:
        return
    message = (
        f"multi-turn GRPO needs colocated vLLM, which is disabled for {model_id}. "
        "Use a single-turn environment for this model, or a model tier that keeps "
        "vLLM enabled for rollouts."
    )
    raise RuntimeError(message)
|
|
839
|
+
|
|
840
|
+
|
|
841
|
+
def run_sft():
|
|
842
|
+
from datasets import Dataset
|
|
843
|
+
from transformers import AutoTokenizer
|
|
844
|
+
from trl import SFTConfig as TRLSFTConfig
|
|
845
|
+
from trl import SFTTrainer
|
|
846
|
+
|
|
847
|
+
env = require_active_env() # fail loudly (not AttributeError: NoneType) on the no-JobSpec path
|
|
848
|
+
t_start = time.time()
|
|
849
|
+
heartbeat("sft_start", gpu=gpu_diagnostics())
|
|
850
|
+
# SFT on a multi-turn env: rows whose target completion is a full trajectory train on the whole
|
|
851
|
+
# transcript (proper multi-turn SFT, handled below); rows with a single-turn target completion
|
|
852
|
+
# collapse to one assistant turn. Warn only for the collapsing case (computed during the
|
|
853
|
+
# dataset build below), not unconditionally.
|
|
854
|
+
wait_for_gpu()
|
|
855
|
+
setup_perf_backends()
|
|
856
|
+
model_id = JOB_SPEC.model if JOB_SPEC else RECIPE.hf_model_id
|
|
857
|
+
download_seconds = prefetch_model(model_id)
|
|
858
|
+
tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
|
859
|
+
if tok.pad_token is None:
|
|
860
|
+
tok.pad_token = tok.eos_token
|
|
861
|
+
|
|
862
|
+
# Build SFT text dataset (seeded shuffle for reproducibility)
|
|
863
|
+
train = env.dataset()
|
|
864
|
+
rng = random.Random(SEED)
|
|
865
|
+
rng.shuffle(train)
|
|
866
|
+
max_examples = int(
|
|
867
|
+
JOB_SPEC.train.max_examples or 0
|
|
868
|
+
if JOB_SPEC and JOB_SPEC.train and JOB_SPEC.train.max_examples is not None
|
|
869
|
+
else 0
|
|
870
|
+
)
|
|
871
|
+
if max_examples > 0:
|
|
872
|
+
train = train[:max_examples]
|
|
873
|
+
texts = []
|
|
874
|
+
multiturn_targets = 0
|
|
875
|
+
for ex in train:
|
|
876
|
+
# The env (via the freesolo-sdk Environment.sft_completion) owns the target completion: the
|
|
877
|
+
# full multi-turn target trajectory (assistant turns + tool calls + tool results + replies)
|
|
878
|
+
# when the row ships one, else a single target assistant turn. Training on the whole
|
|
879
|
+
# transcript is what makes SFT actually multi-turn (the tool-call protocol + replies) — the
|
|
880
|
+
# warm start the GRPO recipe expects. A >1-message completion is a multi-turn trajectory.
|
|
881
|
+
completion = env.sft_completion(ex)
|
|
882
|
+
if len(completion) > 1: # a multi-turn target trajectory (vs a single assistant turn)
|
|
883
|
+
multiturn_targets += 1
|
|
884
|
+
msgs = [*env.prompt_messages(ex), *completion]
|
|
885
|
+
texts.append(
|
|
886
|
+
{
|
|
887
|
+
"text": tok.apply_chat_template(
|
|
888
|
+
msgs, tokenize=False, add_generation_prompt=False, enable_thinking=THINKING
|
|
889
|
+
)
|
|
890
|
+
}
|
|
891
|
+
)
|
|
892
|
+
if multiturn_targets:
|
|
893
|
+
print(f"[sft] multi-turn SFT: {multiturn_targets}/{len(train)} rows train on a full target transcript")
|
|
894
|
+
elif getattr(env, "multi_turn", False):
|
|
895
|
+
print(
|
|
896
|
+
"[sft][warn] this is a multi-turn Freesolo environment but no row ships a multi-turn "
|
|
897
|
+
"target completion; SFT collapses to a single assistant turn per row (tool/env turns "
|
|
898
|
+
"ignored). Provide target transcripts (output={\"messages\": [...]}) for proper multi-turn SFT."
|
|
899
|
+
)
|
|
900
|
+
if THINKING and not any("<think>" in t["text"] for t in texts[:256]):
|
|
901
|
+
print(
|
|
902
|
+
"WARN: thinking mode is ON but no sampled SFT target contains a <think> "
|
|
903
|
+
"trace — training on non-reasoning targets teaches the model to SKIP "
|
|
904
|
+
"thinking. Use a dataset with reasoning traces, or set thinking = false."
|
|
905
|
+
)
|
|
906
|
+
ds = Dataset.from_list(texts)
|
|
907
|
+
|
|
908
|
+
setup_seconds = time.time() - t_start
|
|
909
|
+
heartbeat("sft_model_load", setup_seconds=setup_seconds, gpu=gpu_diagnostics())
|
|
910
|
+
|
|
911
|
+
# Epochs come from the run's [train] epochs (already in JOB_SPEC), else the recipe default.
|
|
912
|
+
epochs = int(
|
|
913
|
+
JOB_SPEC.train.epochs
|
|
914
|
+
if JOB_SPEC and JOB_SPEC.train.epochs is not None
|
|
915
|
+
else RECIPE.sft.num_epochs
|
|
916
|
+
)
|
|
917
|
+
# SDK [train] knobs override the recipe default.
|
|
918
|
+
from flash.catalog import vocab_size_for
|
|
919
|
+
from flash.engine.vram import resolve_params_b, sft_grad_accum, sft_logits_fused
|
|
920
|
+
|
|
921
|
+
_t = JOB_SPEC.train if JOB_SPEC else None
|
|
922
|
+
sft_lr = _t.learning_rate if _t and _t.learning_rate is not None else RECIPE.sft.learning_rate
|
|
923
|
+
sft_max_len = (
|
|
924
|
+
_t.max_length
|
|
925
|
+
if _t and _t.max_length is not None
|
|
926
|
+
else (RECIPE.sft.max_seq_len_thinking if THINKING else RECIPE.sft.max_seq_len)
|
|
927
|
+
)
|
|
928
|
+
# batch_size is the GLOBAL/effective batch; sft_grad_accum sizes the per-device micro-batch +
|
|
929
|
+
# grad-accum to realize it (shared with the cost estimator's step count, see engine.vram).
|
|
930
|
+
effective_batch = (
|
|
931
|
+
_t.batch_size if _t and _t.batch_size is not None else RECIPE.sft.effective_batch
|
|
932
|
+
)
|
|
933
|
+
# Large-vocab OOM guard: when the fused CE (Liger) is OFF, the SFTTrainer materializes the full
|
|
934
|
+
# [per_device, seq, vocab] fp32 logits + grad — at Qwen3.5's ~248k vocab a 0.8B SFT OOM'd a
|
|
935
|
+
# 24 GB card in backward. Cap the per-device micro-batch by the real model vocab + seq so those
|
|
936
|
+
# logits stay within the logits budget; grad-accum rises to keep the effective batch unchanged
|
|
937
|
+
# (the SFT mirror of rl_per_device_comps' GRPO cap). fused mirrors liger_on(_memory_mode(...))
|
|
938
|
+
# below, so the cap binds exactly when the worker won't fuse the CE.
|
|
939
|
+
_sft_params_b = resolve_params_b(model_id) # catalog stat else HF safetensors (open models)
|
|
940
|
+
_sft_vocab = vocab_size_for(model_id)
|
|
941
|
+
# Actual fused-CE decision == what `use_liger_kernel` is set from below (line ~879). sft_logits_fused
|
|
942
|
+
# is the offline size/ctx mirror; liger_on(...) adds the runtime CUDA + liger_kernel-importable
|
|
943
|
+
# check, so the cap binds exactly when the fused CE is NOT really taken.
|
|
944
|
+
_sft_fused = sft_logits_fused(_sft_params_b, sft_max_len) and liger_on(
|
|
945
|
+
_memory_mode(model_id, sft_max_len)
|
|
946
|
+
)
|
|
947
|
+
per_device_bs, grad_accum = sft_grad_accum(
|
|
948
|
+
effective_batch, seq_len=sft_max_len, vocab=_sft_vocab, fused=_sft_fused
|
|
949
|
+
)
|
|
950
|
+
if not _sft_fused and per_device_bs < min(effective_batch, 4):
|
|
951
|
+
print(
|
|
952
|
+
f"[sft] large-vocab logits cap: per_device={per_device_bs} grad_accum={grad_accum} "
|
|
953
|
+
f"(seq={sft_max_len}, vocab={_sft_vocab}; realized batch "
|
|
954
|
+
f"{per_device_bs * grad_accum} >= requested {effective_batch})"
|
|
955
|
+
)
|
|
956
|
+
sft_save_default = _t.save_every if _t and _t.save_every is not None else 50
|
|
957
|
+
out_dir = f"/tmp/sft_seed{SEED}"
|
|
958
|
+
resume_ckpt = hf_resume_checkpoint()
|
|
959
|
+
|
|
960
|
+
# [train].max_steps>0 caps optimizer steps (used by the cheap pre-flight smoke).
|
|
961
|
+
max_steps = int(_t.max_steps or 0 if _t and _t.max_steps is not None else 0)
|
|
962
|
+
cfg_kwargs = {
|
|
963
|
+
"output_dir": out_dir,
|
|
964
|
+
"num_train_epochs": epochs,
|
|
965
|
+
"per_device_train_batch_size": per_device_bs,
|
|
966
|
+
"gradient_accumulation_steps": grad_accum,
|
|
967
|
+
"learning_rate": sft_lr,
|
|
968
|
+
"warmup_ratio": RECIPE.sft.warmup_frac,
|
|
969
|
+
"logging_steps": 10,
|
|
970
|
+
"save_steps": sft_save_default,
|
|
971
|
+
"save_total_limit": 1,
|
|
972
|
+
# Resumable checkpoints: save the optimizer / scheduler / RNG state alongside the (small)
|
|
973
|
+
# LoRA adapter. We DO resume mid-run — make_checkpoint_upload_callback streams each save to
|
|
974
|
+
# HF and a replacement worker calls resume_from_checkpoint(hf_resume_checkpoint()) after a
|
|
975
|
+
# preemption — so without this the resumed run would re-initialize the optimizer (Adam
|
|
976
|
+
# moments) and LR schedule instead of truly continuing. For LoRA the optimizer state is tiny
|
|
977
|
+
# (it covers only the trainable adapter params), so the save spike is negligible. The
|
|
978
|
+
# deployable per-step snapshot (publish_deployable_checkpoint) strips this trainer state
|
|
979
|
+
# separately, so serving still gets adapter-only files.
|
|
980
|
+
"save_only_model": False,
|
|
981
|
+
"max_length": sft_max_len,
|
|
982
|
+
"bf16": True,
|
|
983
|
+
"report_to": wandb_report_to(), # W&B when WANDB_API_KEY present (restored post-flash-migration)
|
|
984
|
+
"run_name": wandb_run_name(),
|
|
985
|
+
# Dataloader parallelism: overlap host-side collation/tokenization with GPU compute so a
|
|
986
|
+
# real (large) training set isn't dataloader-bound. Pure throughput, zero quality change.
|
|
987
|
+
# Negligible on the tiny benchmark (pre-tokenized, in-memory); a real win at production
|
|
988
|
+
# dataset sizes.
|
|
989
|
+
"dataloader_num_workers": 4,
|
|
990
|
+
"dataloader_pin_memory": True,
|
|
991
|
+
"dataloader_persistent_workers": True,
|
|
992
|
+
"seed": SEED,
|
|
993
|
+
"gradient_checkpointing": grad_checkpointing_on(model_id, sft_max_len),
|
|
994
|
+
# Non-reentrant checkpointing: composes cleanly with autograd hooks (verl #3629) and is
|
|
995
|
+
# required by TRL for correct grad flow through the LoRA adapters.
|
|
996
|
+
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
|
997
|
+
"completion_only_loss": False,
|
|
998
|
+
# Optimizer: 8-bit paged AdamW (int8 state paged to host RAM -> fits a smaller GPU).
|
|
999
|
+
"optim": fused_optim_name(),
|
|
1000
|
+
}
|
|
1001
|
+
if max_steps > 0:
|
|
1002
|
+
cfg_kwargs["max_steps"] = max_steps
|
|
1003
|
+
# Example packing: concatenate short examples into full max_length sequences so a batch isn't
|
|
1004
|
+
# mostly pad tokens — PR #174 measured a 4.4-10.7x SFT speedup (h100 8.2x, 4090 10.7x) because
|
|
1005
|
+
# instruction targets are far shorter than max_seq_len; unpacked batches waste most of their
|
|
1006
|
+
# FLOPs on padding. TRL's 'bfd' strategy makes padding-free batches whose example boundaries are
|
|
1007
|
+
# honored ONLY by an attention impl that reads them — under plain SDPA packed examples
|
|
1008
|
+
# cross-contaminate (silent quality loss). The boundary-correct backend is FlashAttention-2
|
|
1009
|
+
# varlen (reads position_ids), which the worker image bakes in best-effort: Dockerfile.worker
|
|
1010
|
+
# installs FLASH_ATTN_SPEC (a community cu128/torch2.10/cp312 wheel preferred, source build as a
|
|
1011
|
+
# fallback) and tolerates a build failure -> SDPA. So _fa_ok is True whenever that install landed;
|
|
1012
|
+
# packing is ON then (varlen keeps 'bfd' example boundaries correct). If the best-effort install
|
|
1013
|
+
# failed, _fa_ok is False and we SKIP packing — without a boundary-correct attn backend examples
|
|
1014
|
+
# would cross-contaminate under SDPA.
|
|
1015
|
+
# Pure full-attention vs GatedDeltaNet hybrid (Qwen3.5/3.6) — probed ONCE here and reused across
|
|
1016
|
+
# the whole packing decision (each probe reads the cached HF config). TRL 'bfd' packing keeps
|
|
1017
|
+
# example boundaries via position_ids that a varlen attn honors, but it provides NO seq_idx, so it
|
|
1018
|
+
# can't reset a GDN hybrid's causal conv -> bfd-packing a GDN model silently cross-contaminates its
|
|
1019
|
+
# linear-attention layers. So bfd is enabled for PURE full-attention models only; GDN hybrids pack
|
|
1020
|
+
# via the cu_seqlens/seq_idx varlen collator branch below (when their kernels are present).
|
|
1021
|
+
_pure_attn = model_is_pure_attention(model_id)
|
|
1022
|
+
_gdn = model_is_gdn_hybrid(model_id)
|
|
1023
|
+
_fa_ok = _flash_attn_available()
|
|
1024
|
+
if _fa_ok and _pure_attn:
|
|
1025
|
+
cfg_kwargs["packing"] = True
|
|
1026
|
+
print("[sft] example packing enabled (FA2 varlen)")
|
|
1027
|
+
elif _fa_ok and _gdn:
|
|
1028
|
+
print(
|
|
1029
|
+
"[sft] TRL bfd packing NOT used for the GatedDeltaNet hybrid (bfd can't reset the conv); "
|
|
1030
|
+
"the cu_seqlens/seq_idx varlen collator handles its packing when both kernels are present."
|
|
1031
|
+
)
|
|
1032
|
+
else:
|
|
1033
|
+
# FA2 bfd packing not enabled here — either flash_attn isn't importable, or it is but the arch
|
|
1034
|
+
# isn't bfd-safe (e.g. sliding-window). This is NOT the final word: the SDPA block-diagonal /
|
|
1035
|
+
# GDN-varlen block below may still turn packing on for a pure-attention or GDN-hybrid model.
|
|
1036
|
+
_bfd_why = "flash_attn not importable" if not _fa_ok else "arch not bfd-safe under FA2 varlen"
|
|
1037
|
+
print(f"[sft] TRL bfd (FA2) packing not used ({_bfd_why}); the SDPA-mask path decides packing below.")
|
|
1038
|
+
# Liger fused CE/RMSNorm/RoPE kernels, gated by model size (_memory_mode). The fused linear
|
|
1039
|
+
# cross-entropy is the big large-vocab (Qwen3.5 ~248k) memory/throughput win.
|
|
1040
|
+
if liger_on(_memory_mode(model_id, sft_max_len)):
|
|
1041
|
+
cfg_kwargs["use_liger_kernel"] = True
|
|
1042
|
+
print("[sft] liger fused kernels enabled")
|
|
1043
|
+
_attn = optimal_attn_impl() # arch-best FlashAttention (FA3 Hopper / FA2 Ampere·Ada) or SDPA
|
|
1044
|
+
# Packing correctness: 'bfd' packed batches are boundary-correct ONLY under a varlen-capable attn
|
|
1045
|
+
# (FA2 and FA3 both expose flash_attn_varlen_func; plain SDPA cross-contaminates packed examples).
|
|
1046
|
+
# Use the ARCH-BEST flash impl optimal_attn_impl already picked (so Hopper packs under FA3, not
|
|
1047
|
+
# FA2). Cases when it did NOT pick a flash impl:
|
|
1048
|
+
# * _attn == "sdpa" (sm120, the deliberate no-flash exception): DISABLE packing — consumer
|
|
1049
|
+
# Blackwell stays plain SDPA; do NOT force FA2 (its sm120 kernel coverage is unverified).
|
|
1050
|
+
# * _attn is None (Hopper without FA3): force FA2 for boundary-correct varlen IF the wheel is
|
|
1051
|
+
# importable; else drop packing rather than silently cross-contaminate.
|
|
1052
|
+
if cfg_kwargs.get("packing"):
|
|
1053
|
+
if _attn in ("flash_attention_2", "flash_attention_3"):
|
|
1054
|
+
print(f"[sft] attn_implementation={_attn} (packing boundary-correct varlen)")
|
|
1055
|
+
elif _attn == "sdpa":
|
|
1056
|
+
cfg_kwargs["packing"] = False
|
|
1057
|
+
print("[sft] packing disabled: selected attn_implementation=sdpa (no varlen flash backend)")
|
|
1058
|
+
elif _fa_ok:
|
|
1059
|
+
_attn = "flash_attention_2"
|
|
1060
|
+
print("[sft] attn_implementation=flash_attention_2 (packing boundary-correct varlen)")
|
|
1061
|
+
else:
|
|
1062
|
+
cfg_kwargs["packing"] = False
|
|
1063
|
+
print("[sft] packing disabled: no varlen flash backend (FA2/FA3) available -> plain SDPA")
|
|
1064
|
+
|
|
1065
|
+
# --- True token packing via a 4D block-diagonal SDPA mask (no flash-attn / no flex) ---------
|
|
1066
|
+
# When the run lands on plain SDPA (no varlen flash backend) the block above left packing OFF —
|
|
1067
|
+
# notably on sm120 (RTX 5090, flash's DEFAULT GPU), and anywhere the best-effort flash-attn
|
|
1068
|
+
# build didn't land. For a PURE full-attention model we can still pack: concatenate examples
|
|
1069
|
+
# into max_length blocks and feed a 4D block-diagonal causal mask SDPA honors natively, so
|
|
1070
|
+
# packed examples never attend across boundaries (boundary-correct, numerically identical to
|
|
1071
|
+
# unpacked — verified on a tiny Qwen3/Llama: |packed-separate| logits ~1e-7). This reclaims the packing
|
|
1072
|
+
# throughput win on the default GPU with neither flash-attn nor flex_attention. GatedDeltaNet
|
|
1073
|
+
# hybrids (Qwen3.5/3.6) take the NEXT branch instead — a mask alone can't reset their linear-
|
|
1074
|
+
# attention state, so they also need the cu_seqlens/seq_idx varlen kwargs.
|
|
1075
|
+
_collator = None
|
|
1076
|
+
# The mask paths materialize a dense [B, 1, T, T] mask — O(T^2) memory. At very long context that
|
|
1077
|
+
# tax (hundreds of MB to >1 GB) can OOM a run that previously fit under memory-efficient SDPA, and
|
|
1078
|
+
# packing buys little there anyway (long rows already fill a block). Above this cap, leave packing
|
|
1079
|
+
# off (train unpacked, as today). 16384: the dense bf16/bool mask stays <=~256 MB at bsz=1.
|
|
1080
|
+
_PACK_MASK_MAX_LEN = 16384
|
|
1081
|
+
_mask_pack_ok = sft_max_len <= _PACK_MASK_MAX_LEN
|
|
1082
|
+
_sdpa_pack = bool(not cfg_kwargs.get("packing") and _pure_attn and _mask_pack_ok)
|
|
1083
|
+
if _sdpa_pack:
|
|
1084
|
+
# The 4D mask requires a MASK-READING attn (SDPA). DOWNGRADE any flash impl optimal_attn_impl
|
|
1085
|
+
# picked — e.g. FA3 on a Hopper worker whose FA2 wheel didn't build — to SDPA: a flash varlen
|
|
1086
|
+
# kernel SILENTLY IGNORES the 4D mask, so packed examples would attend across boundaries. (A
|
|
1087
|
+
# bare ``_attn or "sdpa"`` would leave the truthy flash string in place — the bug this avoids.)
|
|
1088
|
+
if _attn in ("flash_attention_2", "flash_attention_3"):
|
|
1089
|
+
print(f"[sft] packing under SDPA: downgrading {_attn} -> sdpa (a flash kernel ignores the 4D mask)")
|
|
1090
|
+
_attn = "sdpa"
|
|
1091
|
+
cfg_kwargs["packing"] = False # we own the packing; TRL must not also pack
|
|
1092
|
+
# Hand TRL pre-tokenized, pre-packed rows + our collator: skip its dataset prep and stop the
|
|
1093
|
+
# signature-based column pruning from dropping our seq_lengths column before collation.
|
|
1094
|
+
_dk = dict(cfg_kwargs.get("dataset_kwargs") or {})
|
|
1095
|
+
_dk["skip_prepare_dataset"] = True
|
|
1096
|
+
cfg_kwargs["dataset_kwargs"] = _dk
|
|
1097
|
+
cfg_kwargs["remove_unused_columns"] = False
|
|
1098
|
+
# Tokenize EXACTLY like TRL's non-packed prep (EOS-append parity so the model still learns to
|
|
1099
|
+
# stop; batched; truncate to max_length) then bin-pack into <= max_length blocks.
|
|
1100
|
+
_tokenized = tokenize_for_packing([t["text"] for t in texts], tok, sft_max_len)
|
|
1101
|
+
_packed_rows = pack_token_ids(_tokenized, sft_max_len)
|
|
1102
|
+
ds = Dataset.from_list(_packed_rows)
|
|
1103
|
+
_collator = BlockDiagonalCollator(pad_token_id=tok.pad_token_id)
|
|
1104
|
+
# Memory: re-size the per-device micro-batch (in BLOCKS) for the full-block [pd, max_length,
|
|
1105
|
+
# vocab] fp32 logits budget — a no-op under Liger's fused CE. Quality: each block holds
|
|
1106
|
+
# ~ex_per_block examples, so KEEP the effective batch in EXAMPLES at the configured value by
|
|
1107
|
+
# re-deriving grad_accum from the block count. Without this, packing balloons the effective
|
|
1108
|
+
# batch ~ex_per_block-fold (fewer, larger updates -> mild undertraining at the same epochs:
|
|
1109
|
+
# an A/B measured +5.2% held-out loss vs unpacked, closed to +0.1% once matched).
|
|
1110
|
+
_pd_pack, _ = sft_grad_accum(
|
|
1111
|
+
effective_batch, seq_len=sft_max_len, vocab=vocab_size_for(model_id),
|
|
1112
|
+
fused=bool(cfg_kwargs.get("use_liger_kernel")),
|
|
1113
|
+
)
|
|
1114
|
+
# The dense [pd, 1, T, T] bool mask is pd*T^2 bytes — under Liger the logits cap doesn't bind
|
|
1115
|
+
# so pd can be 4, and at long context that mask alone is GBs. Cap pd so the mask stays <=512MB
|
|
1116
|
+
# (a no-op at short ctx: at T=2048 it allows pd up to ~125; it only bites past ~12k tokens).
|
|
1117
|
+
_pd_pack = max(1, min(_pd_pack, (512 * 1024 * 1024) // (sft_max_len * sft_max_len)))
|
|
1118
|
+
_ex_per_block = len(_tokenized) / max(1, len(_packed_rows))
|
|
1119
|
+
cfg_kwargs["per_device_train_batch_size"] = _pd_pack
|
|
1120
|
+
cfg_kwargs["gradient_accumulation_steps"] = max(
|
|
1121
|
+
1, math.ceil(effective_batch / max(1.0, _pd_pack * _ex_per_block))
|
|
1122
|
+
)
|
|
1123
|
+
print(
|
|
1124
|
+
"[sft] true token packing ENABLED (4D block-diagonal SDPA mask): "
|
|
1125
|
+
f"{len(_tokenized)} examples -> {len(_packed_rows)} blocks (~{_ex_per_block:.1f} ex/block, "
|
|
1126
|
+
f"{packing_efficiency(_packed_rows, sft_max_len):.0%} dense) of <= {sft_max_len} tok; "
|
|
1127
|
+
f"pd={_pd_pack} ga={cfg_kwargs['gradient_accumulation_steps']} (effective batch kept "
|
|
1128
|
+
f"~{effective_batch} ex); no flash-attn / no flex_attention"
|
|
1129
|
+
)
|
|
1130
|
+
elif not cfg_kwargs.get("packing") and _gdn and gdn_packing_available(model_id) and _mask_pack_ok:
|
|
1131
|
+
# GatedDeltaNet hybrid (Qwen3.5/3.6, flash's flagship tier): the 4D block-diagonal mask makes
|
|
1132
|
+
# the FULL-attention layers boundary-correct, and the linear-attention (DeltaNet) layers reset
|
|
1133
|
+
# their recurrence + causal conv at example boundaries via cu_seq_lens_q (fla kernel) + seq_idx
|
|
1134
|
+
# (causal_conv1d). GPU-validated on Qwen3.5-0.8B (RTX 5090): a packed example's output is
|
|
1135
|
+
# byte-identical regardless of its neighbors' content (ZERO cross-example leakage); the only
|
|
1136
|
+
# diff vs unpacked is benign bf16 GDN-kernel tiling numerics (~0.3 on logits). Gated on BOTH
|
|
1137
|
+
# kernels being importable (gdn_packing_available) so a worker without them stays unpacked.
|
|
1138
|
+
# Pin SDPA for the full-attn layers (downgrade any flash impl, e.g. FA3 on Hopper — it would
|
|
1139
|
+
# ignore the 4D mask); the DeltaNet layers are unaffected (they use cu_seqlens/seq_idx).
|
|
1140
|
+
if _attn in ("flash_attention_2", "flash_attention_3"):
|
|
1141
|
+
print(f"[sft] GDN packing under SDPA: downgrading {_attn} -> sdpa for the full-attn layers")
|
|
1142
|
+
_attn = "sdpa"
|
|
1143
|
+
cfg_kwargs["packing"] = False
|
|
1144
|
+
_dk = dict(cfg_kwargs.get("dataset_kwargs") or {})
|
|
1145
|
+
_dk["skip_prepare_dataset"] = True
|
|
1146
|
+
cfg_kwargs["dataset_kwargs"] = _dk
|
|
1147
|
+
cfg_kwargs["remove_unused_columns"] = False
|
|
1148
|
+
# EOS-append parity + batched + truncated tokenization (same as the unpacked path), then pack.
|
|
1149
|
+
_tokenized = tokenize_for_packing([t["text"] for t in texts], tok, sft_max_len)
|
|
1150
|
+
_packed_rows = pack_token_ids(_tokenized, sft_max_len)
|
|
1151
|
+
ds = Dataset.from_list(_packed_rows)
|
|
1152
|
+
_collator = BlockDiagonalCollator(pad_token_id=tok.pad_token_id, emit_varlen=True)
|
|
1153
|
+
# cu_seqlens spans ONE packed block, so per-device is a single block; keep the effective batch
|
|
1154
|
+
# in EXAMPLES at the configured value via grad-accum (each block holds ~ex_per_block examples —
|
|
1155
|
+
# without this the effective batch would balloon ~ex_per_block-fold -> undertraining).
|
|
1156
|
+
_ex_per_block = len(_tokenized) / max(1, len(_packed_rows))
|
|
1157
|
+
cfg_kwargs["per_device_train_batch_size"] = 1
|
|
1158
|
+
cfg_kwargs["gradient_accumulation_steps"] = max(1, math.ceil(effective_batch / max(1.0, _ex_per_block)))
|
|
1159
|
+
print(
|
|
1160
|
+
"[sft] true token packing ENABLED for GatedDeltaNet hybrid (4D mask + cu_seqlens/seq_idx "
|
|
1161
|
+
f"varlen): {len(_tokenized)} examples -> {len(_packed_rows)} blocks (~{_ex_per_block:.1f} "
|
|
1162
|
+
f"ex/block, {packing_efficiency(_packed_rows, sft_max_len):.0%} dense) of <= {sft_max_len} "
|
|
1163
|
+
f"tok; pd=1 ga={cfg_kwargs['gradient_accumulation_steps']} (effective batch kept ~{effective_batch} ex)"
|
|
1164
|
+
)
|
|
1165
|
+
elif not cfg_kwargs.get("packing") and (_pure_attn or _gdn) and not _mask_pack_ok:
|
|
1166
|
+
print(
|
|
1167
|
+
f"[sft] packing stays OFF: max_length {sft_max_len} > {_PACK_MASK_MAX_LEN} — the dense "
|
|
1168
|
+
"O(T^2) block-diagonal mask gets too large at long context (unpacked is more memory-"
|
|
1169
|
+
"efficient there, and long rows already fill a block)."
|
|
1170
|
+
)
|
|
1171
|
+
elif not cfg_kwargs.get("packing") and not _pure_attn:
|
|
1172
|
+
_why = (
|
|
1173
|
+
"hybrid GatedDeltaNet but the fla/causal_conv1d varlen kernels aren't both importable"
|
|
1174
|
+
if _gdn
|
|
1175
|
+
else "non-full-attention arch (e.g. sliding-window) a block-diagonal mask can't pack"
|
|
1176
|
+
)
|
|
1177
|
+
print(f"[sft] packing stays OFF: {_why}. (Pure full-attention models pack via the SDPA mask.)")
|
|
1178
|
+
# Explicit bf16 + no auto device-map: TRL/transformers-5 string loading can
|
|
1179
|
+
# otherwise fall back to fp32 (2x VRAM; observed 18.6 GB for a 4.66B model) or
|
|
1180
|
+
# accelerate-offload large models to meta ("expected device meta but got
|
|
1181
|
+
# cuda:0" in backward on the 9B).
|
|
1182
|
+
mik = {"dtype": "bfloat16", "device_map": None}
|
|
1183
|
+
if _attn:
|
|
1184
|
+
mik["attn_implementation"] = _attn
|
|
1185
|
+
cfg_kwargs["model_init_kwargs"] = mik
|
|
1186
|
+
cfg = TRLSFTConfig(**cfg_kwargs)
|
|
1187
|
+
|
|
1188
|
+
# LoRA+ (convergence lever, arXiv 2402.12354; always-on: measured -52% train loss in A/B
|
|
1189
|
+
# (gpu-bench)): give the LoRA B matrices a higher LR than A (ratio 16). Reported ~2x fewer steps
|
|
1190
|
+
# to target at identical per-step FLOPs. TRL builds the model from a string inside __init__, so
|
|
1191
|
+
# the optimizer (which needs the instantiated params) can't be pre-built — override
|
|
1192
|
+
# create_optimizer to construct it from self.model once it exists.
|
|
1193
|
+
_lp_ratio = 16
|
|
1194
|
+
_SFT = SFTTrainer
|
|
1195
|
+
if _lp_ratio > 1:
|
|
1196
|
+
|
|
1197
|
+
class _SFT(SFTTrainer): # local LoRA+ subclass
|
|
1198
|
+
_loraplus_applied = False # True only once the LoRA+ grouping actually installs
|
|
1199
|
+
|
|
1200
|
+
def create_optimizer(self):
|
|
1201
|
+
if self.optimizer is None:
|
|
1202
|
+
try:
|
|
1203
|
+
from peft.optimizers import create_loraplus_optimizer
|
|
1204
|
+
|
|
1205
|
+
# Mirror the configured `optim` so LoRA+ and the 8-bit paged optimizer state
|
|
1206
|
+
# coexist (instead of silently forcing fp32 AdamW); see loraplus_optimizer_cls.
|
|
1207
|
+
# .value (not str()): self.args.optim is a TRL OptimizerNames enum whose
|
|
1208
|
+
# str() is "OptimizerNames.PAGED_ADAMW_8BIT"; pass the raw value
|
|
1209
|
+
# ("paged_adamw_8bit") so the 8-bit match works.
|
|
1210
|
+
opt_cls, extra = loraplus_optimizer_cls(
|
|
1211
|
+
getattr(self.args.optim, "value", self.args.optim)
|
|
1212
|
+
)
|
|
1213
|
+
# Forward the TrainingArguments optimizer config that the default HF
|
|
1214
|
+
# create_optimizer path would have applied. Building the optimizer
|
|
1215
|
+
# ourselves means we must replicate it explicitly, or LoRA+ runs would
|
|
1216
|
+
# silently use the optimizer class's own defaults instead of the
|
|
1217
|
+
# configured betas/eps/weight_decay. betas/eps go straight to the optimizer
|
|
1218
|
+
# constructor (alongside any `extra` from loraplus_optimizer_cls);
|
|
1219
|
+
# weight_decay is handled separately below.
|
|
1220
|
+
fwd = dict(extra)
|
|
1221
|
+
_betas = (
|
|
1222
|
+
getattr(self.args, "adam_beta1", None),
|
|
1223
|
+
getattr(self.args, "adam_beta2", None),
|
|
1224
|
+
)
|
|
1225
|
+
if None not in _betas:
|
|
1226
|
+
fwd.setdefault("betas", _betas)
|
|
1227
|
+
_eps = getattr(self.args, "adam_epsilon", None)
|
|
1228
|
+
if _eps is not None:
|
|
1229
|
+
fwd.setdefault("eps", _eps)
|
|
1230
|
+
# PEFT does NOT read args.weight_decay; it applies decay via its own LoRA+
|
|
1231
|
+
# param groups, keyed off the loraplus_weight_decay kwarg (which it pops
|
|
1232
|
+
# before constructing the optimizer). Pass it as a top-level kwarg so it
|
|
1233
|
+
# isn't forwarded into the optimizer constructor.
|
|
1234
|
+
lp_extra: dict[str, object] = {}
|
|
1235
|
+
_wd = getattr(self.args, "weight_decay", None)
|
|
1236
|
+
if _wd is not None:
|
|
1237
|
+
lp_extra["loraplus_weight_decay"] = _wd
|
|
1238
|
+
# PEFT's create_loraplus_optimizer forwards extra kwargs to the optimizer;
|
|
1239
|
+
# the lr keyword name has shifted across PEFT versions, so pass it via
|
|
1240
|
+
# optimizer_kwargs (the stable form) and fall back to a top-level lr=.
|
|
1241
|
+
try:
|
|
1242
|
+
self.optimizer = create_loraplus_optimizer(
|
|
1243
|
+
model=self.model,
|
|
1244
|
+
optimizer_cls=opt_cls,
|
|
1245
|
+
optimizer_kwargs={"lr": self.args.learning_rate, **fwd},
|
|
1246
|
+
loraplus_lr_ratio=_lp_ratio,
|
|
1247
|
+
**lp_extra,
|
|
1248
|
+
)
|
|
1249
|
+
except TypeError:
|
|
1250
|
+
self.optimizer = create_loraplus_optimizer(
|
|
1251
|
+
model=self.model,
|
|
1252
|
+
optimizer_cls=opt_cls,
|
|
1253
|
+
lr=self.args.learning_rate,
|
|
1254
|
+
loraplus_lr_ratio=_lp_ratio,
|
|
1255
|
+
**fwd,
|
|
1256
|
+
**lp_extra,
|
|
1257
|
+
)
|
|
1258
|
+
self._loraplus_applied = True
|
|
1259
|
+
print(
|
|
1260
|
+
f"[lora+] optimizer enabled (B-matrix LR ratio={_lp_ratio}, "
|
|
1261
|
+
f"cls={opt_cls.__name__})"
|
|
1262
|
+
)
|
|
1263
|
+
return self.optimizer
|
|
1264
|
+
except Exception as e: # never block training on the LoRA+ wiring
|
|
1265
|
+
print("[lora+] setup failed, falling back to default optimizer:", e)
|
|
1266
|
+
return super().create_optimizer()
|
|
1267
|
+
|
|
1268
|
+
# Pass model as a string id + tokenizer as processing_class so TRL takes the
|
|
1269
|
+
# text/causal-LM path (not the VLM processor path) for this multimodal checkpoint.
|
|
1270
|
+
# SFTTrainer.__init__ blocks for 10-15 min on first use (FA2 CUDA kernel JIT compilation);
|
|
1271
|
+
# without a heartbeat the control plane can't distinguish this from a real hang and may
|
|
1272
|
+
# recycle the worker. A daemon thread pings every 30s so the stall detector stays quiet.
|
|
1273
|
+
_sft_init_done = threading.Event()
|
|
1274
|
+
|
|
1275
|
+
def _sft_init_heartbeat() -> None:
|
|
1276
|
+
while not _sft_init_done.wait(30.0):
|
|
1277
|
+
heartbeat("sft_initializing", gpu=gpu_diagnostics())
|
|
1278
|
+
|
|
1279
|
+
_sft_init_hb = threading.Thread(target=_sft_init_heartbeat, daemon=True)
|
|
1280
|
+
_sft_init_hb.start()
|
|
1281
|
+
try:
|
|
1282
|
+
trainer = _SFT(
|
|
1283
|
+
model=model_id,
|
|
1284
|
+
args=cfg,
|
|
1285
|
+
train_dataset=ds,
|
|
1286
|
+
peft_config=make_lora(model_id),
|
|
1287
|
+
processing_class=tok,
|
|
1288
|
+
# Our block-diagonal collator on the SDPA-packing path; None elsewhere == TRL default.
|
|
1289
|
+
data_collator=_collator,
|
|
1290
|
+
callbacks=[make_sft_heartbeat_callback(), make_checkpoint_upload_callback()],
|
|
1291
|
+
)
|
|
1292
|
+
finally:
|
|
1293
|
+
_sft_init_done.set()
|
|
1294
|
+
# Apply chalk's gap-filling kernels (RoPE/LoRA-delta/embedding, like Liger) on the materialized
|
|
1295
|
+
# SFT trainer.model — chalk's apply patches the LIVE module, so it must run AFTER TRL builds the
|
|
1296
|
+
# model (chalk composes on top of TRL's Liger). No-op unless a FLASH_* kernel flag selects it and
|
|
1297
|
+
# freesolo-chalk is installed.
|
|
1298
|
+
_chalk_report = install_chalk_kernels(getattr(trainer, "model", None))
|
|
1299
|
+
|
|
1300
|
+
_reset_peak_gpu() # so peak_gpu_gb reflects the train loop (optimizer-state A/B is measurable)
|
|
1301
|
+
_gpu_sampler = _GpuPeakSampler().start() # true device peak incl. bnb managed optimizer pages
|
|
1302
|
+
t_train = time.time()
|
|
1303
|
+
with _sdpa_cudnn_ctx(_attn): # force cuDNN SDPA on sm120 (no-op otherwise)
|
|
1304
|
+
trainer.train(resume_from_checkpoint=resume_ckpt)
|
|
1305
|
+
train_wall = time.time() - t_train
|
|
1306
|
+
sft_peak_gpu_gb = _peak_gpu_gb()
|
|
1307
|
+
sft_device_peak_gpu_gb = _gpu_sampler.stop_gb()
|
|
1308
|
+
|
|
1309
|
+
adapter_dir = f"{out_dir}/adapter"
|
|
1310
|
+
trainer.model.save_pretrained(adapter_dir)
|
|
1311
|
+
tok.save_pretrained(adapter_dir)
|
|
1312
|
+
hf_upload_folder(adapter_dir, "adapter", required=True)
|
|
1313
|
+
heartbeat("sft_trained", train_wall=train_wall, gpu=gpu_diagnostics())
|
|
1314
|
+
|
|
1315
|
+
# count train tokens
|
|
1316
|
+
train_tokens = int(sum(len(tok(t["text"])["input_ids"]) for t in texts) * epochs)
|
|
1317
|
+
|
|
1318
|
+
# Write train metadata + the completion sentinel (metrics.json/DONE) for this phase.
|
|
1319
|
+
write_train_meta(
|
|
1320
|
+
phase="sft",
|
|
1321
|
+
adapter_dir=adapter_dir,
|
|
1322
|
+
model_id=model_id,
|
|
1323
|
+
train_wall=train_wall,
|
|
1324
|
+
setup_seconds=setup_seconds,
|
|
1325
|
+
train_tokens=train_tokens,
|
|
1326
|
+
generated_tokens=0,
|
|
1327
|
+
notes={
|
|
1328
|
+
"epochs": epochs,
|
|
1329
|
+
"resumed": bool(resume_ckpt),
|
|
1330
|
+
"download_seconds": download_seconds,
|
|
1331
|
+
"hf_transfer": os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", ""),
|
|
1332
|
+
"thinking": THINKING,
|
|
1333
|
+
# Persist the loss curve so a CONVERGENCE A/B (PiSSA / LoRA+ init, etc.) is measurable
|
|
1334
|
+
# without a checkpoint: trainer_state.json is only written on a save_step, and the
|
|
1335
|
+
# console is only uploaded on failure, so a short successful run otherwise drops its
|
|
1336
|
+
# loss history entirely.
|
|
1337
|
+
"loss_curve": _metric_curve(trainer, "loss"),
|
|
1338
|
+
# Peak torch-allocated GPU memory during the train loop (excludes bnb managed pages, so
|
|
1339
|
+
# it overstates the 8-bit saving — use device_peak_gpu_gb for the true footprint).
|
|
1340
|
+
"peak_gpu_gb": sft_peak_gpu_gb,
|
|
1341
|
+
# True peak device memory (total-free, incl. bnb managed optimizer pages): the honest
|
|
1342
|
+
# headline for the fp32-vs-8-bit LoRA+ optimizer A/B.
|
|
1343
|
+
"device_peak_gpu_gb": sft_device_peak_gpu_gb,
|
|
1344
|
+
# Report the optimizer ACTUALLY built on the trainer, not the planned class: if the
|
|
1345
|
+
# LoRA+ create_optimizer override failed, training falls back to TRL's configured
|
|
1346
|
+
# optimizer without LoRA+ grouping. loraplus_applied records which path actually ran.
|
|
1347
|
+
# Accelerate wraps the optimizer (AcceleratedOptimizer) under transformers 5.x, so unwrap
|
|
1348
|
+
# via `.optimizer` to record the underlying PagedAdamW8bit/AdamW the A/B cares about, not
|
|
1349
|
+
# the wrapper name.
|
|
1350
|
+
"loraplus_optim": (
|
|
1351
|
+
type(getattr(trainer.optimizer, "optimizer", trainer.optimizer)).__name__
|
|
1352
|
+
if getattr(trainer, "optimizer", None) is not None
|
|
1353
|
+
else loraplus_optimizer_cls(fused_optim_name())[0].__name__
|
|
1354
|
+
),
|
|
1355
|
+
"loraplus_applied": getattr(trainer, "_loraplus_applied", False),
|
|
1356
|
+
# Which chalk gap-filling kernels actually ENGAGED (empty/None = chalk not installed or
|
|
1357
|
+
# every kernel fell back) — verifies the chalk stack without the console.
|
|
1358
|
+
"chalk_kernels": active_kernels(_chalk_report) or None,
|
|
1359
|
+
**wandb_run_info(),
|
|
1360
|
+
},
|
|
1361
|
+
)
|
|
1362
|
+
free_gpu(trainer)
|
|
1363
|
+
|
|
1364
|
+
|
|
1365
|
+
# ---------------------------------------------------------------------------
|
|
1366
|
+
# RL (GRPO) with TRL + colocated vLLM
|
|
1367
|
+
# ---------------------------------------------------------------------------
|
|
1368
|
+
def compute_grpo_batching(prompts_per_step: int, group_size: int, per_device_comps: int) -> dict:
    """Map an intended ``prompts_per_step`` onto TRL's completion-denominated GRPO batch knobs.

    TRL sizes GRPO batches in *completions* (prompt-completion pairs), not prompts. The number
    of unique prompts optimized per optimizer step is

        (per_device_train_batch_size * gradient_accumulation_steps * num_processes)
            / num_generations

    So optimizing ``prompts_per_step`` prompts per step needs a global completion batch of
    ``prompts_per_step * group_size``. The per-device micro-batch (which sets peak VRAM) stays
    small and the remainder is pushed into gradient accumulation.

    Historical note: an earlier ``grad_accum = prompts_per_step // per_device`` treated the
    per-device batch as a *prompt* count and dropped the ``* group_size`` factor. A run meant
    for 64 prompts/step therefore optimized only ``64 / group_size = 8`` prompts/step.

    Args:
        prompts_per_step: unique prompts wanted per optimizer step (int-coerced, floored at 1).
        group_size: completions per prompt, i.e. TRL ``num_generations`` (floored at 1).
        per_device_comps: requested per-device completion micro-batch (floored at 1). It may be
            lowered so that it divides the target completion batch exactly.

    Returns:
        A dict with ``per_device_train_batch_size``, ``gradient_accumulation_steps``,
        ``generations_per_step``, ``unique_prompts_per_step`` and ``divisible_by_group``.
    """
    groups = max(1, int(group_size))
    prompts = max(1, int(prompts_per_step))
    total = prompts * groups  # completions per optimizer step

    # First clamp the micro-batch to the target (a tiny prompts_per_step would overshoot it).
    # Then walk down to the largest divisor of `total` at or below the request.
    # - Flooring grad_accum under-shoots the requested prompt count.
    # - Ceiling it asks TRL for more prompts than the dataset-capped count, which can yield
    #   zero batches on a small dataset.
    # An exact divisor makes micro * accum == total, and only ever lowers peak VRAM.
    # The search terminates because 1 divides everything.
    cap = min(max(1, int(per_device_comps)), total)
    micro = next(d for d in range(cap, 0, -1) if total % d == 0)
    accum = max(1, total // micro)

    # micro * accum == total == prompts * groups, so TRL's divisibility rule holds by construction.
    completions = micro * accum
    return {
        "per_device_train_batch_size": micro,
        "gradient_accumulation_steps": accum,
        "generations_per_step": completions,
        "unique_prompts_per_step": completions // groups,
        # TRL requires the global completion batch be divisible by num_generations.
        "divisible_by_group": completions % groups == 0,
    }
|
|
1419
|
+
|
|
1420
|
+
|
|
1421
|
+
def resolve_grpo_prompts_per_step(requested: int, available_prompts: int) -> int:
    """Clamp GRPO's per-step prompt batch to the number of retained training prompts.

    When the configured prompt batch exceeds what survives prompt-budget filtering, TRL's GRPO
    dataloader can produce zero batches. That only surfaces late, as "There seems not to be a
    single sample in your epoch_iterator", and the no-reward guard then misreports the cause.
    Using every retained prompt per step lets small smoke envs still train.

    Args:
        requested: desired prompts per step (int-coerced, floored at 1).
        available_prompts: prompts retained after filtering (int-coerced).

    Returns:
        ``min(requested, available_prompts)``.

    Raises:
        ValueError: if no training prompt was retained.
    """
    wanted = max(1, int(requested))
    retained = int(available_prompts)
    if retained > 0:
        return min(wanted, retained)
    raise ValueError("GRPO needs at least one retained training prompt")
|
|
1435
|
+
|
|
1436
|
+
|
|
1437
|
+
def build_grpo_prompt_dataset(prompts: list[dict]) -> tuple[list[dict], list]:
    """Split GRPO prompts into Arrow-safe dataset rows plus a parallel example lookup.

    ``Dataset.from_list`` has PyArrow infer a single column type per (nested) field across all
    rows. Embedding each rich per-example record therefore breaks valid envs whose per-row
    ``info``/``metadata`` legitimately mixes types: construction raises ``ArrowInvalid``. The
    RL phase then dies at startup, after the GPU is paid for, on input that passed offline
    single-example validation. Example: in ifeval-lite, ``metadata.param`` is an int word count
    in some rows and a string such as ``'gentle'`` in others.

    The rows therefore carry only the TRL-required ``prompt`` and a stable integer
    ``example_idx``. The untouched example objects come back in a separate list that
    ``reward_fn`` indexes, so the env sees its exact record with no Arrow/JSON round-trip.

    Args:
        prompts: items each holding a ``"prompt"`` and an ``"example"`` key.

    Returns:
        ``(rows, examples)`` where ``rows[i] == {"prompt": ..., "example_idx": i}`` and
        ``examples[i]`` is that row's original record object.
    """
    examples = [entry["example"] for entry in prompts]
    rows = [dict(prompt=entry["prompt"], example_idx=idx) for idx, entry in enumerate(prompts)]
    return rows, examples
|
|
1456
|
+
|
|
1457
|
+
|
|
1458
|
+
# Per-device completion micro-batch ceiling for the SHORT-seq growth path.
# Measured on RunPod: Qwen3.5-0.8B GRPO, group 8, gsm8k, seq 1024, 6 steps, A100 80GB.
# - Trainer throughput was 375 / 407 / 411 tok/s at per_device 4 / 8 / 16: about +12% from
#   4 to 8, then a plateau.
# - It fell to 326 tok/s (-20%) at per_device 32.
# Growth therefore stops at the top of the plateau, even when VRAM is spare. Reward histories
# at per_device 4 and 16 matched, so this is a speed/VRAM knob, not an optimization change.
_RL_PER_DEVICE_MAX: int = 16

# Sequence length at which the activation/VRAM divisor below was calibrated. The colocate
# activation peak grows with training sequence length, so the VRAM cap is scaled by
# seq_len / _RL_ACT_SEQ_REF. Short-seq (under-fed) runs thus get a proportionally bigger
# micro-batch.
_RL_ACT_SEQ_REF: float = 2048.0

# VRAM (GB) per micro-batch element at the reference seq, normalized to ~2B model width
# (sqrt(2) ~= 1.41).
# Measured: Qwen3.5-2B, group 8, seq 2048 OOMs a 32 GB card at per_device 8 but trains at 4,
# and 32 / 7.5 -> 4.
# This equals the historical colocate cap, so at and above the reference seq the result is
# unchanged.
_RL_ACT_DIVISOR: float = 7.5

# Lower clamp on the seq scale. It bounds how far a short sequence may enlarge the micro-batch
# (at most 1 / 0.63 ~= 1.6x the reference cap).
# Chosen so the motivating case lands on the measured-safe per_device 8:
#   Qwen3.5-0.8B GRPO on a 24 GB card at seq <= 1024
#   24 / (7.5 * (0.894 / 1.41) * 0.63) ~= 8.0
# Evidence on an RTX 4090 (24 GB):
# - per_device 8 fit at 19.0 GB and ran +12.6% faster than 4.
# - The old seq-independent cap under-fed it at ~5.
# - per_device 16 would need ~27 GB and OOM.
_RL_ACT_SEQ_SCALE_FLOOR: float = 0.63

# Upper clamp on the seq scale: never above the reference. Together with the short_seq growth
# gate in rl_per_device_comps, a run at seq >= reference reproduces the historical value
# exactly: the scale is 1.0, so vram_cap is the old colocate cap, and the ceiling falls back to
# the historical default.
# Long sequences are deliberately neither tightened nor grown:
# - Not tightened: gradient checkpointing makes activations sub-linear in seq there, so a
#   linear model would over-cap.
# - Not grown: unvalidated, and the throughput regression tracks tokens-in-flight
#   (per_device x seq).
_RL_ACT_SEQ_SCALE_CEIL: float = 1.0
|
|
1487
|
+
|
|
1488
|
+
|
|
1489
|
+
def rl_per_device_comps(
    completion_len: int = 0,
    vocab: int = 248_320,
    *,
    use_vllm: bool = True,
    params_b: float | None = None,
    seq_len: int = 0,
) -> int:
    """Per-device *completion* micro-batch for GRPO (TRL counts completions, not prompts).

    The micro-batch, not grad-accum, drives peak trainer VRAM and the trainer step's MFU. A
    larger micro-batch means fewer, bigger GEMMs at the same effective batch.
    compute_grpo_batching moves the remainder into grad-accum, so only speed and VRAM change.

    Measured on RunPod (Qwen3.5-0.8B GRPO, group 8, seq 1024):
    - On a 24 GB card the old seq-independent colocate cap under-fed at per_device ~5, while
      per_device 8 fits (19.0 GB) and is +12.6% faster.
    - On an 80 GB card throughput plateaus at per_device 8..16 and regresses by 32.

    The result is the minimum of one ceiling and two caps, floored at 1:

    * **Ceiling.** ``_RL_PER_DEVICE_MAX`` when the run is short-seq
      (``seq_len`` < ``_RL_ACT_SEQ_REF``), is not ``THINKING``, and has a VRAM reading.
      Otherwise it is the historical default: 8, or 2 when ``THINKING``.
    * **Logits cap.** Applies when ``completion_len > 0``. It keeps the fp32
      [per_device, completion_len, vocab] logprob tensor under ~6e9 bytes; for example,
      8 x 4096 x 248k x 4 B ~= 30 GiB would OOM a small card. Liger normally fuses these
      logits away, so this guards the fallback path.
    * **VRAM cap.** Only when ``use_vllm`` is set and CUDA is available. It is device-0 total
      memory divided by ``_RL_ACT_DIVISOR``, scaled by model width relative to 1.41 and by the
      clamped seq scale. Calibrated at seq 2048: Qwen3.5-2B group 8 OOMs a 32 GB card at
      per_device 8 but trains at 4.

    With no VRAM reading the conservative default, bounded by the logits cap, is returned. This
    happens without CUDA, with ``use_vllm=False``, or when the probe raises (its error is
    printed). At or above the reference seq the result equals the historical value.

    Args:
        completion_len: max completion tokens. When <= 0 the logits cap is
            ``_RL_PER_DEVICE_MAX``.
        vocab: vocabulary size used by the logits cap.
        use_vllm: whether to run the colocate VRAM probe.
        params_b: model size in billions. Falsy values use the ~2B reference width (1.41).
        seq_len: training sequence length. 0 is treated as the reference length.

    Returns:
        Per-device completion count, always >= 1.
    """
    fallback = 2 if THINKING else 8

    # Hard bound from the fp32 [per_device, completion_len, vocab] logprob tensor (~6e9 bytes).
    if completion_len > 0:
        logits_cap = max(1, int(6.0e9 / (max(1, completion_len) * vocab * 4)))
    else:
        logits_cap = _RL_PER_DEVICE_MAX

    # Growth past the historical default is allowed only for sequences below the reference.
    # Bigger per_device at long context is unvalidated. The throughput regression tracks
    # tokens-in-flight (per_device x seq), which a fixed per_device ceiling would not catch.
    effective_seq = seq_len or _RL_ACT_SEQ_REF
    grow_ok = effective_seq < _RL_ACT_SEQ_REF

    # VRAM cap from the live card. It can shrink the batch (big model, small card, long seq)
    # and, for short sequences, let it grow into spare memory.
    vram_cap = None
    if use_vllm:
        try:
            import torch

            if torch.cuda.is_available():
                total_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                width = (max(float(params_b), 0.1) ** 0.5) if params_b else 1.41
                seq_scale = min(
                    _RL_ACT_SEQ_SCALE_CEIL,
                    max(_RL_ACT_SEQ_SCALE_FLOOR, effective_seq / _RL_ACT_SEQ_REF),
                )
                vram_cap = max(
                    1, int(total_gb / (_RL_ACT_DIVISOR * (width / 1.41) * seq_scale))
                )
        except Exception as e:
            print("rl_per_device_comps colocate cap probe failed (keeping logits cap):", e)

    if vram_cap is None:
        # No live-card signal (allocator / offline / unit tests): logits-bounded default.
        return max(1, min(fallback, logits_cap))

    # THINKING runs keep the conservative default ceiling. Their long completions cost
    # activation/logprob memory that the prompt-side seq_len gate cannot see.
    ceiling = _RL_PER_DEVICE_MAX if (grow_ok and not THINKING) else fallback
    return max(1, min(ceiling, logits_cap, vram_cap))
|
|
1582
|
+
|
|
1583
|
+
|
|
1584
|
+
# Minimum gap, in time.monotonic() seconds, between GPU-diagnostic snapshots attached to the
# per-step "rl_step" heartbeat (see make_reward_heartbeat_callback).
_STEP_GPU_DIAG_INTERVAL_S: float = 300.0
# SFT heartbeat cadence in seconds. Presumably it throttles make_sft_heartbeat_callback's
# emits, but that body is not fully visible here — confirm there.
_SFT_HEARTBEAT_INTERVAL_S: float = 60.0
|
|
1586
|
+
|
|
1587
|
+
|
|
1588
|
+
def make_reward_heartbeat_callback():
    """Build a transformers ``TrainerCallback`` that streams each logged mean reward.

    Every trainer log carrying a numeric ``reward`` entry is appended to the callback's
    ``reward_history`` and forwarded as an ``rl_step`` heartbeat containing the step,
    the reward and the last 8 rewards. This gives the worker a live RL signal without
    a pod-log API. A ``gpu_diagnostics()`` snapshot rides along on the first heartbeat,
    then at most once every ``_STEP_GPU_DIAG_INTERVAL_S`` seconds.

    ``transformers`` is imported here, not at module level, so the module still
    imports when transformers is not installed.

    Returns:
        A fresh ``_RewardHeartbeat`` callback instance.
    """
    from transformers import TrainerCallback

    class _RewardHeartbeat(TrainerCallback):
        def __init__(self):
            # Parsed rewards, in the order the trainer logged them.
            self.reward_history = []
            # time.monotonic() of the last GPU snapshot; 0.0 means "never sent".
            self.last_gpu_diag_at = 0.0

        def on_log(self, args, state, control, logs=None, **kwargs):
            raw_reward = (logs or {}).get("reward")
            if raw_reward is None:
                return
            try:
                value = float(raw_reward)
            except (TypeError, ValueError):
                # Non-numeric reward entries are ignored rather than crashing training.
                return

            history = self.reward_history
            history.append(value)
            payload = {
                # Fall back to the number of rewards seen if the state has no step.
                "step": int(getattr(state, "global_step", len(history))),
                "reward": value,
                "reward_last": history[-8:],
            }

            now = time.monotonic()
            never_sent = self.last_gpu_diag_at == 0.0
            if never_sent or now - self.last_gpu_diag_at >= _STEP_GPU_DIAG_INTERVAL_S:
                payload["gpu"] = gpu_diagnostics()
                self.last_gpu_diag_at = now
            heartbeat("rl_step", **payload)

    return _RewardHeartbeat()
|
|
1626
|
+
|
|
1627
|
+
|
|
1628
|
+
def make_sft_heartbeat_callback():
    """Build a transformers ``TrainerCallback`` that streams throttled SFT progress.

    Non-empty trainer logs become ``sft_step`` heartbeats (step, epoch, loss, grad_norm,
    learning_rate), so a run is not silent between model load and completion.

    Throttling and payload rules:
    - At most one heartbeat is emitted per ``_SFT_HEARTBEAT_INTERVAL_S`` seconds;
      logs inside that window are dropped.
    - A ``gpu_diagnostics()`` snapshot is attached to the first emitted heartbeat,
      then at most once every ``_STEP_GPU_DIAG_INTERVAL_S`` seconds.
    - Fields whose value is None are omitted from the heartbeat.

    Returns:
        A fresh ``_SFTHeartbeat`` callback instance.
    """
    from transformers import TrainerCallback

    class _SFTHeartbeat(TrainerCallback):
        def __init__(self):
            # time.monotonic() stamps; 0.0 means "nothing sent yet".
            self.last_heartbeat_at = 0.0
            self.last_gpu_diag_at = 0.0

        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs:
                return
            now = time.monotonic()
            within_window = bool(self.last_heartbeat_at) and (
                now - self.last_heartbeat_at < _SFT_HEARTBEAT_INTERVAL_S
            )
            if within_window:
                return
            self.last_heartbeat_at = now

            fields = {"step": int(getattr(state, "global_step", 0) or 0)}
            for key in ("epoch", "loss", "grad_norm", "learning_rate"):
                fields[key] = logs.get(key)

            gpu_due = (
                self.last_gpu_diag_at == 0.0
                or now - self.last_gpu_diag_at >= _STEP_GPU_DIAG_INTERVAL_S
            )
            if gpu_due:
                fields["gpu"] = gpu_diagnostics()
                self.last_gpu_diag_at = now
            heartbeat("sft_step", **{k: v for k, v in fields.items() if v is not None})

    return _SFTHeartbeat()
|
|
1660
|
+
|
|
1661
|
+
|
|
1662
|
+
def grpo_overrides() -> dict:
    """Collect the GRPO recipe knobs set on the job spec's ``[train]`` table (``TrainSpec``).

    Knobs that are unset (None) are left out, so the recipe default applies downstream.

    Knobs:
    - group_size
    - temperature
    - max_tokens (completion budget)
    - kl_penalty_coef (the KL beta)
    - advantage_clip (centered-advantage clip)
    - thinking_length_penalty_coef (a per-<think>-token reward deduction)

    These live in ``[train]``, NOT in ``[environment.params]``. The latter is forwarded
    verbatim to the Freesolo env loader.

    Returns:
        A dict mapping knob name to its configured value. It is empty when there is no
        JOB_SPEC.
    """
    if not JOB_SPEC:
        return {}
    train_spec = JOB_SPEC.train
    knob_names = (
        "group_size",
        "temperature",
        "max_tokens",
        "kl_penalty_coef",
        "advantage_clip",
        "thinking_length_penalty_coef",
    )
    overrides = {}
    for name in knob_names:
        value = getattr(train_spec, name)
        if value is not None:
            overrides[name] = value
    return overrides
|
|
1682
|
+
|
|
1683
|
+
|
|
1684
|
+
def think_token_count(completion: str | None, tokenizer) -> int:
|
|
1685
|
+
"""Number of tokens inside the completion's <think>...</think> span (0 if none).
|
|
1686
|
+
|
|
1687
|
+
Used for the thinking-length reward deduction: long reasoning is penalized in
|
|
1688
|
+
proportion to the tokens it spent, mirroring the SDK's thinking_length_penalty_coef.
|
|
1689
|
+
"""
|
|
1690
|
+
if not completion or "<think>" not in completion:
|
|
1691
|
+
return 0
|
|
1692
|
+
after = completion.split("<think>", 1)[1]
|
|
1693
|
+
think_text = after.split("</think>", 1)[0] if "</think>" in after else after
|
|
1694
|
+
if not think_text:
|
|
1695
|
+
return 0
|
|
1696
|
+
return len(tokenizer(think_text, add_special_tokens=False)["input_ids"])
|
|
1697
|
+
|
|
1698
|
+
|
|
1699
|
+
def _init_adapter_model(model_id: str):
    """Resolve the (model, peft_config) pair GRPO should train.

    Two cases:
    - ``train.init_from_adapter`` is unset: return ``(model_id, make_lora(model_id))``,
      so TRL loads the base model itself and attaches a fresh LoRA.
    - It is set: download that adapter and load it onto a bfloat16 base model as a
      trainable PeftModel, then return ``(peft_model, None)``. The None peft_config
      makes TRL train the LOADED adapter (GRPO continuing an SFT adapter) instead of
      attaching a fresh one.

    Raises:
        RuntimeError: The requested adapter could not be downloaded. A paid run must not
            silently fall back to a fresh base-model LoRA.
        The ``assert_*`` helpers called at the end may also raise when the adapter did
            not actually apply (see flash/engine/worker/lora.py).
    """
    adapter_prefix = JOB_SPEC.train.init_from_adapter if JOB_SPEC else ""
    if not adapter_prefix:
        return model_id, make_lora(model_id)

    adapter_dir = _download_adapter(adapter_prefix)
    if not adapter_dir:
        # The user explicitly asked GRPO to continue from this adapter. Falling back to a
        # fresh base-model LoRA would spend a full paid run on the wrong starting point.
        raise RuntimeError(
            f"train.init_from_adapter={adapter_prefix!r} could not be downloaded from the artifact "
            "store (wrong/missing prefix or no access); refusing to silently start GRPO from "
            "the base model. Fix the adapter prefix / HF credentials, or omit "
            "init_from_adapter to train a fresh LoRA."
        )

    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    print(f"[init-adapter] initializing LoRA from {adapter_prefix}")
    # VL checkpoints (Qwen3.5/3.6): the SFT step saved the adapter against the FULL
    # multimodal tree (``base_model.model.model.language_model.layers.*``). Here the base
    # is loaded text-only via AutoModelForCausalLM (``base_model.model.model.layers.*``),
    # so the ``.language_model.`` infix is stripped on disk first. Otherwise peft only
    # WARNS about missing keys and trains a fresh LoRA. No-op for non-VL checkpoints.
    remap_vl_adapter_dir(adapter_dir, model_id)

    base_kwargs = {"dtype": "bfloat16", "trust_remote_code": True}
    attn_impl = optimal_attn_impl()
    if attn_impl:
        base_kwargs["attn_implementation"] = attn_impl
    base = AutoModelForCausalLM.from_pretrained(model_id, **base_kwargs)
    model = PeftModel.from_pretrained(base, adapter_dir, is_trainable=True)

    # Why reload the adapter:
    # - from_pretrained loads with strict=False and only WARNS on a key mismatch,
    #   discarding the load result.
    # - Re-loading the same weights into the same "default" adapter is idempotent and
    #   CAPTURES which keys matched, so the checks below can assert matched == saved.
    # - The reload must reuse from_pretrained's key mapping (the base model's
    #   ``_checkpoint_conversion_mapping``). Otherwise valid keys on renamed-arch
    #   checkpoints would read as mismatched and falsely abort the run.
    key_mapping = getattr(base, "_checkpoint_conversion_mapping", None)
    load_result = model.load_adapter(
        adapter_dir, adapter_name="default", is_trainable=True, key_mapping=key_mapping
    )
    assert_adapter_load_clean(load_result, model_id)
    assert_lora_applied(model, model_id)
    assert_adapter_delta_nonzero(model, model_id)
    return model, None
|
|
1756
|
+
|
|
1757
|
+
|
|
1758
|
+
def _grpo_resume_already_complete(resume_ckpt, target_steps: int, steps_run: int) -> bool:
    """Tell whether this worker resumed a checkpoint that already met the step target.

    Such a resume legitimately performs ZERO new optimizer steps, so the fresh hb_cb has
    an empty reward_history. Yet the policy IS fully trained, so it must NOT be flagged
    as a no-op failure.

    Args:
        resume_ckpt: The resumed checkpoint reference; any falsy value means "no resume".
        target_steps: Configured total optimizer steps; non-positive never counts as complete.
        steps_run: Steps already reached by the resumed run.
    """
    if not resume_ckpt:
        return False
    if target_steps <= 0:
        return False
    return steps_run >= target_steps
|
|
1765
|
+
|
|
1766
|
+
|
|
1767
|
+
def _grpo_is_no_op_failure(reward_history, resume_ckpt, target_steps: int, steps_run: int) -> bool:
    """Decide whether a GRPO run trained NOTHING and must fail loudly instead of reporting done.

    An empty ``reward_history`` means the reward callback never fired. The rollout scored
    nothing (e.g. vLLM silently returning no completions), so no real training happened.

    The only exception is a resume that had already reached the target steps (see
    ``_grpo_resume_already_complete``). Its fresh history is empty but its policy is
    fully trained, so it is NOT a failure.
    """
    saw_rewards = bool(reward_history)
    if saw_rewards:
        return False
    finished_before_resume = _grpo_resume_already_complete(resume_ckpt, target_steps, steps_run)
    return not finished_before_resume
|
|
1778
|
+
|
|
1779
|
+
|
|
1780
|
+
def run_rl():
|
|
1781
|
+
from datasets import Dataset
|
|
1782
|
+
from transformers import AutoTokenizer
|
|
1783
|
+
from trl import GRPOConfig, GRPOTrainer
|
|
1784
|
+
|
|
1785
|
+
env = require_active_env() # fail loudly (not AttributeError: NoneType) on the no-JobSpec path
|
|
1786
|
+
t_start = time.time()
|
|
1787
|
+
heartbeat("rl_start", gpu=gpu_diagnostics())
|
|
1788
|
+
# GRPO rollout strategy by env shape (trl 1.6 adds the hooks these need):
|
|
1789
|
+
# * single-turn -> TRL single-shot generation + per-completion reward (below);
|
|
1790
|
+
# * tool (ToolEnv & subs:
|
|
1791
|
+
# Stateful/Sandbox/Python) -> TRL drives the tool-call loop natively via
|
|
1792
|
+
# GRPOTrainer(tools=...) (it parses tool calls, executes the tools, and masks the
|
|
1793
|
+
# tool-result tokens itself); the reward scores the full transcript;
|
|
1794
|
+
# * pure multi-turn -> a custom rollout_func (flash.engine.multiturn_rollout)
|
|
1795
|
+
# drives THIS env's turn loop on the colocate engine and returns the interleaved
|
|
1796
|
+
# token sequence with an env_mask so only the model's tokens are trained.
|
|
1797
|
+
is_tool_env = getattr(env, "is_tool_env", False)
|
|
1798
|
+
is_multi_turn = getattr(env, "multi_turn", False)
|
|
1799
|
+
conversational = is_multi_turn # message-list prompts (tool + pure multi-turn) vs strings
|
|
1800
|
+
if is_multi_turn:
|
|
1801
|
+
# The Liger fused GRPO loss (use_liger_kernel, kept ON to avoid the 248k-vocab fp32-logits
|
|
1802
|
+
# OOM) torch.compiles, and on the VARIABLE-length multi-turn completions its dynamo guard
|
|
1803
|
+
# build trips a torch 2.10 bug (symbol_to_source IndexError) that crashes the first
|
|
1804
|
+
# training step. Let dynamo FALL BACK TO EAGER for the offending function instead of
|
|
1805
|
+
# raising. This is NOT `TORCHDYNAMO_DISABLE` (which would also break the colocate vLLM
|
|
1806
|
+
# engine's required compilation) — dynamo stays enabled; only erroring graphs run eager.
|
|
1807
|
+
try:
|
|
1808
|
+
import torch._dynamo
|
|
1809
|
+
|
|
1810
|
+
torch._dynamo.config.suppress_errors = True
|
|
1811
|
+
print("[rl] multi-turn: torch._dynamo suppress_errors=True (Liger loss falls back to eager on dynamic shapes)")
|
|
1812
|
+
except Exception as exc: # never let a torch internals change block the run
|
|
1813
|
+
print(f"[rl] could not set torch._dynamo.suppress_errors: {exc!r}")
|
|
1814
|
+
wait_for_gpu()
|
|
1815
|
+
setup_perf_backends()
|
|
1816
|
+
model_id = JOB_SPEC.model if JOB_SPEC else RECIPE.hf_model_id
|
|
1817
|
+
download_seconds = prefetch_model(model_id)
|
|
1818
|
+
rl = RECIPE.rl
|
|
1819
|
+
# Steps come from the run's [train] steps (already in JOB_SPEC), else the recipe default.
|
|
1820
|
+
steps = int(
|
|
1821
|
+
JOB_SPEC.train.steps if JOB_SPEC and JOB_SPEC.train.steps is not None else rl.num_steps
|
|
1822
|
+
)
|
|
1823
|
+
# Throughput/quality knobs: the number of prompts optimized per step, completions per
|
|
1824
|
+
# prompt, and whether vLLM offloads weights between steps. Sleep mode frees memory for the
|
|
1825
|
+
# optimizer but reloads ~weights each step (a large per-step cost); it's gated OFF by model
|
|
1826
|
+
# size when both the policy and rollout engine fit resident.
|
|
1827
|
+
gcfg = grpo_overrides()
|
|
1828
|
+
_t = JOB_SPEC.train if JOB_SPEC else None
|
|
1829
|
+
# batch_size = prompts per optimizer step for GRPO.
|
|
1830
|
+
# prompts per optimizer step = the run config's [train].batch_size (recipe default otherwise).
|
|
1831
|
+
prompts_per_step = int(
|
|
1832
|
+
_t.batch_size if _t and _t.batch_size is not None else rl.prompts_per_step
|
|
1833
|
+
)
|
|
1834
|
+
group_size = int(gcfg.get("group_size") or rl.group_size)
|
|
1835
|
+
# temperature: explicit None check, NOT `or` — a configured 0.0 (greedy/deterministic
|
|
1836
|
+
# rollouts) must be honored, not fall back to the recipe sampling temperature.
|
|
1837
|
+
_gcfg_temp = gcfg.get("temperature")
|
|
1838
|
+
_temperature = float(_gcfg_temp if _gcfg_temp is not None else rl.sampling_temperature)
|
|
1839
|
+
_kl_beta = float(gcfg.get("kl_penalty_coef") or 0.0)
|
|
1840
|
+
_adv_clip = float(gcfg.get("advantage_clip") or 0.0)
|
|
1841
|
+
_think_penalty = float(gcfg.get("thinking_length_penalty_coef") or 0.0)
|
|
1842
|
+
# vLLM sleep mode offloads the rollout engine's weights between steps to free memory for the
|
|
1843
|
+
# optimizer, but reloading each step is a large per-step cost (PR #174 measured ~2-2.6x faster
|
|
1844
|
+
# GRPO with it OFF on models that fit) AND on the large-model GRPO path the sleep/wake cycle
|
|
1845
|
+
# STALLS the colocated rollout (the rollout emits unparseable completions, then the worker
|
|
1846
|
+
# hangs mid-training). So enable sleep only when the run genuinely can't fit RESIDENT on THIS
|
|
1847
|
+
# card: large/long-context AND the policy + colocated rollout engine + training peak don't fit
|
|
1848
|
+
# on the live GPU. When they fit (the common allocator-sized case), skip sleep entirely.
|
|
1849
|
+
_grpo_ctx = int(_t.max_length if _t and _t.max_length else 0)
|
|
1850
|
+
_card_vram_gb = 0.0
|
|
1851
|
+
try:
|
|
1852
|
+
import torch as _torch_card
|
|
1853
|
+
|
|
1854
|
+
if _torch_card.cuda.is_available():
|
|
1855
|
+
# Binary GiB (/(1024**3)), NOT decimal GB (/1e9 over-reports ~7%): grpo_fits_resident's
|
|
1856
|
+
# VRAM estimate is in GiB, so a decimal card size would make a marginal card look big
|
|
1857
|
+
# enough to fit resident and wrongly disable sleep, risking OOM.
|
|
1858
|
+
_card_vram_gb = _torch_card.cuda.get_device_properties(0).total_memory / (1024**3)
|
|
1859
|
+
except Exception as _e:
|
|
1860
|
+
print("[rl] card VRAM probe failed (sleep-mode gate falls back to size/context):", _e)
|
|
1861
|
+
_lora_rank = int(_t.lora_rank) if _t and _t.lora_rank else 32
|
|
1862
|
+
sleep_mode = grpo_sleep_mode(
|
|
1863
|
+
model_id,
|
|
1864
|
+
max_length=_grpo_ctx,
|
|
1865
|
+
group_size=group_size,
|
|
1866
|
+
max_tokens=gcfg.get("max_tokens"),
|
|
1867
|
+
lora_rank=_lora_rank,
|
|
1868
|
+
thinking=THINKING,
|
|
1869
|
+
card_vram_gb=_card_vram_gb,
|
|
1870
|
+
)
|
|
1871
|
+
print(
|
|
1872
|
+
f"[rl] vLLM sleep mode = {sleep_mode} "
|
|
1873
|
+
f"(model={model_id}, ctx={_grpo_ctx}, card={_card_vram_gb:.0f}GB)"
|
|
1874
|
+
)
|
|
1875
|
+
# Rollout backend: always colocated vLLM (fast). The whole supported catalog runs GRPO with
|
|
1876
|
+
# colocated vLLM; there is no transformers-generation fallback.
|
|
1877
|
+
use_vllm = True
|
|
1878
|
+
print("[rl] rollout backend: colocated vLLM")
|
|
1879
|
+
from flash.catalog import MODELS as _CATALOG
|
|
1880
|
+
|
|
1881
|
+
_info = _CATALOG.get(model_id)
|
|
1882
|
+
tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
|
1883
|
+
if tok.pad_token is None:
|
|
1884
|
+
tok.pad_token = tok.eos_token
|
|
1885
|
+
|
|
1886
|
+
train = env.dataset()
|
|
1887
|
+
rng = random.Random(SEED)
|
|
1888
|
+
rng.shuffle(train)
|
|
1889
|
+
if conversational:
|
|
1890
|
+
# Message-list prompts so the chat template applies roles + (for tool envs) the tool
|
|
1891
|
+
# schemas; per-turn length is managed by the tool loop / rollout_func, not a flat budget.
|
|
1892
|
+
prompts = [{"prompt": env.prompt_messages(ex), "example": ex} for ex in train]
|
|
1893
|
+
else:
|
|
1894
|
+
prompts = [{"prompt": render_prompt(tok, ex), "example": ex} for ex in train]
|
|
1895
|
+
# The colocated vLLM engine's model length is the hard cap on prompt+completion at
|
|
1896
|
+
# rollout. Size it from [train].max_length and derive the prompt budget from it so a
|
|
1897
|
+
# bigger engine or a smaller completion automatically admits longer prompts (rather than
|
|
1898
|
+
# a fixed rl.max_prompt_len that no env override could lift).
|
|
1899
|
+
_max_completion = int(
|
|
1900
|
+
gcfg.get("max_tokens")
|
|
1901
|
+
or (rl.max_completion_len_thinking if THINKING else rl.max_completion_len)
|
|
1902
|
+
)
|
|
1903
|
+
# Engine context = the run's [train].max_length (so a long-context GRPO config sized/paid for
|
|
1904
|
+
# by the allocator actually RUNS at that length), else the recipe default. Without the
|
|
1905
|
+
# train.max_length fallback the allocator provisions a big GPU for the long context but the
|
|
1906
|
+
# engine runs short — paying for headroom we never use.
|
|
1907
|
+
_train_ctx = _t.max_length if (_t and _t.max_length) else 0
|
|
1908
|
+
vllm_max_len = int(_train_ctx or max(1024, rl.max_prompt_len + _max_completion))
|
|
1909
|
+
# The engine must fit completion + at least some prompt. If [train].max_length is below the
|
|
1910
|
+
# completion budget, no prompt can ever fit — fail fast here rather than passing a 1-token
|
|
1911
|
+
# budget that lets prompts through and then OOMs/overflows mid-rollout.
|
|
1912
|
+
if vllm_max_len <= _max_completion:
|
|
1913
|
+
raise ValueError(
|
|
1914
|
+
f"engine length {vllm_max_len} leaves no room for the {_max_completion}-token "
|
|
1915
|
+
"completion; raise [train].max_length or lower [train].max_tokens"
|
|
1916
|
+
)
|
|
1917
|
+
prompt_budget = vllm_max_len - _max_completion
|
|
1918
|
+
|
|
1919
|
+
# TRL 1.5's GRPOConfig has no max_prompt_length and does NOT truncate prompts, so a prompt
|
|
1920
|
+
# that leaves no room for the completion within the engine length would fail mid-rollout
|
|
1921
|
+
# AFTER the paid worker is provisioned. Drop prompts that don't fit the budget up front.
|
|
1922
|
+
# render_prompt returns an apply_chat_template(tokenize=False) string that already carries
|
|
1923
|
+
# the special tokens, so tokenize with add_special_tokens=False (the default re-adds
|
|
1924
|
+
# BOS/EOS and over-counts).
|
|
1925
|
+
# Drop prompts that leave no room for the completion within the engine length — applies to
|
|
1926
|
+
# BOTH single-turn (string prompts) and conversational (message-list) prompts, so a tool /
|
|
1927
|
+
# multi-turn rollout can't overflow the colocate engine mid-generation. Conversational
|
|
1928
|
+
# prompts are length-checked via the chat template (with the generation prompt).
|
|
1929
|
+
# Tool schemas TRL injects into the prompt for native tools= GRPO — include them in the
|
|
1930
|
+
# budget for a tool env so a prompt isn't undercounted at filter time vs. rollout time.
|
|
1931
|
+
_oai_tools = (
|
|
1932
|
+
getattr(getattr(env, "_env", None), "oai_tools", None) if is_tool_env else None
|
|
1933
|
+
)
|
|
1934
|
+
|
|
1935
|
+
def _prompt_tokens(p) -> int:
|
|
1936
|
+
if conversational:
|
|
1937
|
+
# Render to text then tokenize — the SAME path the rollout uses — so the filter
|
|
1938
|
+
# count matches the rollout's count (avoids a tokenize=True vs text mismatch).
|
|
1939
|
+
kw = {"tools": _oai_tools} if _oai_tools else {}
|
|
1940
|
+
try:
|
|
1941
|
+
text = tok.apply_chat_template(
|
|
1942
|
+
p["prompt"],
|
|
1943
|
+
add_generation_prompt=True,
|
|
1944
|
+
tokenize=False,
|
|
1945
|
+
enable_thinking=THINKING,
|
|
1946
|
+
**kw,
|
|
1947
|
+
)
|
|
1948
|
+
except Exception as exc:
|
|
1949
|
+
# Fail fast WITH context: a tokenizer/template incompatibility would render every
|
|
1950
|
+
# prompt uncountable and otherwise surface as a misleading "all prompts exceed
|
|
1951
|
+
# budget" — raise so the model/template can be fixed before a paid run trains on
|
|
1952
|
+
# a degenerate dataset.
|
|
1953
|
+
raise RuntimeError(
|
|
1954
|
+
"failed to render a conversational prompt with this model's chat template "
|
|
1955
|
+
f"(fix the model/template or the env's prompts): {exc}"
|
|
1956
|
+
) from exc
|
|
1957
|
+
return len(tok(text, add_special_tokens=False).input_ids)
|
|
1958
|
+
return len(tok(p["prompt"], add_special_tokens=False).input_ids)
|
|
1959
|
+
|
|
1960
|
+
kept = [p for p in prompts if 0 < _prompt_tokens(p) <= prompt_budget]
|
|
1961
|
+
if len(kept) < len(prompts):
|
|
1962
|
+
print(
|
|
1963
|
+
f"[rl] dropped {len(prompts) - len(kept)} prompts over the {prompt_budget}-token "
|
|
1964
|
+
f"prompt budget (engine {vllm_max_len} - completion {_max_completion})"
|
|
1965
|
+
)
|
|
1966
|
+
if not kept:
|
|
1967
|
+
raise ValueError(
|
|
1968
|
+
f"every training prompt exceeds the {prompt_budget}-token prompt budget (engine "
|
|
1969
|
+
f"{vllm_max_len} - completion {_max_completion}); raise [train].max_length, lower "
|
|
1970
|
+
"[train].max_tokens, or shorten the environment's prompts"
|
|
1971
|
+
)
|
|
1972
|
+
prompts = kept
|
|
1973
|
+
resolved_prompts_per_step = resolve_grpo_prompts_per_step(prompts_per_step, len(prompts))
|
|
1974
|
+
if resolved_prompts_per_step != prompts_per_step:
|
|
1975
|
+
print(
|
|
1976
|
+
f"[rl] lowering prompts_per_step from {prompts_per_step} to "
|
|
1977
|
+
f"{resolved_prompts_per_step}: only {len(prompts)} prompt(s) fit after filtering"
|
|
1978
|
+
)
|
|
1979
|
+
prompts_per_step = resolved_prompts_per_step
|
|
1980
|
+
# Carry a stable integer index instead of the rich record so PyArrow can't crash on an env whose
|
|
1981
|
+
# per-row info/metadata legitimately mixes types (see build_grpo_prompt_dataset). reward_fn maps
|
|
1982
|
+
# the index back to the original example object below.
|
|
1983
|
+
ds_rows, rollout_examples = build_grpo_prompt_dataset(prompts)
|
|
1984
|
+
ds = Dataset.from_list(ds_rows)
|
|
1985
|
+
|
|
1986
|
+
def reward_fn(completions, **kwargs):
|
|
1987
|
+
# rollout_func (pure multi-turn) path: the per-rollout reward is computed by the env
|
|
1988
|
+
# during the rollout and forwarded as the "reward" extra field — pass it through.
|
|
1989
|
+
if kwargs.get("reward") is not None:
|
|
1990
|
+
return [float(r) for r in kwargs["reward"]]
|
|
1991
|
+
# Score the <think>-stripped text (graded_text), then — datums parity — deduct
|
|
1992
|
+
# the thinking-length penalty computed from the RAW completion's <think> span.
|
|
1993
|
+
# The dataset carries example_idx (not the record); map each back to its original object.
|
|
1994
|
+
# Fail LOUD if TRL stops forwarding example_idx (column pruning / a TRL change): defaulting to
|
|
1995
|
+
# [] would zip to ZERO examples -> empty rewards -> silent no-op / broken training (issues
|
|
1996
|
+
# #206 / #210). A reward over the wrong/empty examples is far worse than crashing the run.
|
|
1997
|
+
example_idx = kwargs.get("example_idx")
|
|
1998
|
+
if example_idx is None:
|
|
1999
|
+
raise RuntimeError(
|
|
2000
|
+
"GRPO reward_fn received no 'example_idx' column from TRL — the reward cannot be "
|
|
2001
|
+
"mapped back to its training example, so every reward would be empty/misaligned "
|
|
2002
|
+
f"(got kwargs keys {sorted(kwargs)}). This usually means TRL dropped the dataset "
|
|
2003
|
+
"column (remove_unused_columns / a TRL version change); the run is aborted rather "
|
|
2004
|
+
"than silently training on no signal."
|
|
2005
|
+
)
|
|
2006
|
+
if len(example_idx) != len(completions):
|
|
2007
|
+
raise RuntimeError(
|
|
2008
|
+
f"GRPO reward_fn example_idx/completions length mismatch "
|
|
2009
|
+
f"({len(example_idx)} vs {len(completions)}) — rewards would be misaligned with "
|
|
2010
|
+
"the sampled completions; aborting rather than training on a shifted reward signal."
|
|
2011
|
+
)
|
|
2012
|
+
examples = [rollout_examples[int(i)] for i in example_idx]
|
|
2013
|
+
rewards = []
|
|
2014
|
+
debug_rows = []
|
|
2015
|
+
for idx, (comp, ex) in enumerate(zip(completions, examples, strict=False)):
|
|
2016
|
+
if isinstance(comp, list):
|
|
2017
|
+
# Tool / conversational transcript (TRL passes a list of messages): score the
|
|
2018
|
+
# whole transcript via the environment reward (no <think> stripping —
|
|
2019
|
+
# multi-turn content).
|
|
2020
|
+
r = env.reward_from_messages(comp, ex)
|
|
2021
|
+
rewards.append(r)
|
|
2022
|
+
continue
|
|
2023
|
+
graded = graded_text(comp)
|
|
2024
|
+
breakdown = None
|
|
2025
|
+
if hasattr(env, "scores_breakdown"):
|
|
2026
|
+
breakdown = env.scores_breakdown(graded, ex)
|
|
2027
|
+
r = float(breakdown.get("total", 0.0))
|
|
2028
|
+
else:
|
|
2029
|
+
r = env.reward(graded, ex)
|
|
2030
|
+
if _think_penalty > 0 and THINKING:
|
|
2031
|
+
r -= _think_penalty * think_token_count(comp, tok)
|
|
2032
|
+
rewards.append(r)
|
|
2033
|
+
if idx < 8:
|
|
2034
|
+
debug_rows.append(
|
|
2035
|
+
{
|
|
2036
|
+
"ts": time.time(),
|
|
2037
|
+
"attempt": ATTEMPT,
|
|
2038
|
+
"run_id": RUN_ID,
|
|
2039
|
+
"mode": RUN_MODE,
|
|
2040
|
+
"seed": SEED,
|
|
2041
|
+
"reward": r,
|
|
2042
|
+
"breakdown": breakdown,
|
|
2043
|
+
"completion_prefix": str(comp or "")[:1000],
|
|
2044
|
+
"graded_prefix": str(graded or "")[:1000],
|
|
2045
|
+
"example_id": (ex or {}).get("id") if isinstance(ex, dict) else None,
|
|
2046
|
+
"example_input": (ex or {}).get("input") if isinstance(ex, dict) else None,
|
|
2047
|
+
}
|
|
2048
|
+
)
|
|
2049
|
+
upload_debug_jsonl("reward_debug.jsonl", debug_rows)
|
|
2050
|
+
return rewards
|
|
2051
|
+
|
|
2052
|
+
# TRL's per_device_train_batch_size counts COMPLETIONS, not prompts. Size grad-accum so
|
|
2053
|
+
# the global completion batch = prompts_per_step * group_size, i.e. each optimizer step
|
|
2054
|
+
# actually optimizes `prompts_per_step` prompts. The per-device *completion* micro-batch
|
|
2055
|
+
# is the VRAM knob (thinking-aware; see rl_per_device_comps).
|
|
2056
|
+
from flash.engine.vram import resolve_params_b
|
|
2057
|
+
|
|
2058
|
+
# Open-model (uncataloged) GRPO: size the colocate activation cap from the catalog stat, else
|
|
2059
|
+
# the HF safetensors metadata (no download). Without a real count a large open model falls back
|
|
2060
|
+
# to the ~2B-width default in rl_per_device_comps and gets too LOOSE a per-device cap ->
|
|
2061
|
+
# colocate OOM. Best-effort: stays None offline, keeping prior behavior.
|
|
2062
|
+
_params_b = resolve_params_b(model_id)
|
|
2063
|
+
from flash.catalog import vocab_size_for
|
|
2064
|
+
|
|
2065
|
+
# Per-device completion-logits cap: a multi-turn rollout accumulates a FULL transcript (model
|
|
2066
|
+
# turns + masked env tokens) up to the engine context — far longer than the single-turn per-turn
|
|
2067
|
+
# budget `_max_completion` — and the trainer's logprob forward processes that whole completion.
|
|
2068
|
+
# So size the fp32 [per_device, completion, vocab] cap against the WORST-CASE multi-turn
|
|
2069
|
+
# completion length (the engine context) instead of `_max_completion`, or a long multi-turn run
|
|
2070
|
+
# OOMs the trainer forward. Single-turn keeps `_max_completion` (its true completion length).
|
|
2071
|
+
_cap_completion_len = vllm_max_len if is_multi_turn else _max_completion
|
|
2072
|
+
per_device_comps = rl_per_device_comps(
|
|
2073
|
+
_cap_completion_len,
|
|
2074
|
+
vocab=vocab_size_for(model_id),
|
|
2075
|
+
use_vllm=use_vllm,
|
|
2076
|
+
params_b=_params_b,
|
|
2077
|
+
# The trainer forward processes prompt+completion up to the engine context, so the
|
|
2078
|
+
# activation/VRAM cap is sized against the worst-case training sequence length.
|
|
2079
|
+
seq_len=vllm_max_len,
|
|
2080
|
+
)
|
|
2081
|
+
if is_multi_turn and _cap_completion_len != _max_completion:
|
|
2082
|
+
print(
|
|
2083
|
+
f"[rl] multi-turn: sizing the per-device logits cap against the full transcript length "
|
|
2084
|
+
f"{_cap_completion_len} (engine context), not the per-turn budget {_max_completion}"
|
|
2085
|
+
)
|
|
2086
|
+
batching = compute_grpo_batching(prompts_per_step, group_size, per_device_comps)
|
|
2087
|
+
if not batching["divisible_by_group"]:
|
|
2088
|
+
print(
|
|
2089
|
+
"WARN: generation batch not divisible by group size; check prompts_per_step/group_size"
|
|
2090
|
+
)
|
|
2091
|
+
print(
|
|
2092
|
+
f"[rl] GRPO batching: per_device={batching['per_device_train_batch_size']} "
|
|
2093
|
+
f"grad_accum={batching['gradient_accumulation_steps']} "
|
|
2094
|
+
f"generations/step={batching['generations_per_step']} "
|
|
2095
|
+
f"unique_prompts/step={batching['unique_prompts_per_step']} "
|
|
2096
|
+
f"(target prompts/step={prompts_per_step}, group={group_size}, sleep={sleep_mode})"
|
|
2097
|
+
)
|
|
2098
|
+
out_dir = f"/tmp/rl_seed{SEED}"
|
|
2099
|
+
resume_ckpt = hf_resume_checkpoint()
|
|
2100
|
+
|
|
2101
|
+
grpo_kwargs = {
|
|
2102
|
+
"output_dir": out_dir,
|
|
2103
|
+
"learning_rate": (
|
|
2104
|
+
_t.learning_rate if _t and _t.learning_rate is not None else rl.learning_rate
|
|
2105
|
+
),
|
|
2106
|
+
"per_device_train_batch_size": batching["per_device_train_batch_size"],
|
|
2107
|
+
"gradient_accumulation_steps": batching["gradient_accumulation_steps"],
|
|
2108
|
+
"num_generations": group_size,
|
|
2109
|
+
# NB: GRPOConfig has no max_prompt_length field (TRL 1.5) and does not truncate
|
|
2110
|
+
# prompts; the dataset is pre-filtered above to prompts that fit prompt_budget
|
|
2111
|
+
# (vllm_max_len - completion), so every prompt fits the engine sized here.
|
|
2112
|
+
"max_completion_length": _max_completion,
|
|
2113
|
+
"max_steps": steps,
|
|
2114
|
+
"temperature": _temperature,
|
|
2115
|
+
"top_p": rl.sampling_top_p,
|
|
2116
|
+
"use_vllm": use_vllm,
|
|
2117
|
+
"logging_steps": 1,
|
|
2118
|
+
"save_steps": _t.save_every if _t and _t.save_every is not None else 20,
|
|
2119
|
+
"save_total_limit": 1,
|
|
2120
|
+
# Resumable checkpoints: keep the optimizer/scheduler/RNG state with the LoRA adapter so a
|
|
2121
|
+
# preempted GRPO run resumed via resume_from_checkpoint(hf_resume_checkpoint()) continues
|
|
2122
|
+
# with intact optimizer state + step instead of a fresh optimizer. For LoRA this state is
|
|
2123
|
+
# small (trainable adapter params only). The deployable per-step snapshot strips it
|
|
2124
|
+
# separately, so serving still gets adapter-only files.
|
|
2125
|
+
"save_only_model": False,
|
|
2126
|
+
"bf16": True,
|
|
2127
|
+
"report_to": wandb_report_to(), # W&B when WANDB_API_KEY present (restored post-flash-migration)
|
|
2128
|
+
"run_name": wandb_run_name(),
|
|
2129
|
+
"seed": SEED,
|
|
2130
|
+
"gradient_checkpointing": grad_checkpointing_on(model_id, vllm_max_len),
|
|
2131
|
+
# Non-reentrant checkpointing: the modern path that composes correctly with autograd
|
|
2132
|
+
# saved-tensor hooks and avoids the reentrant path's extra graph retention. (verl #3629.)
|
|
2133
|
+
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
|
2134
|
+
# Pin a stable, well-conditioned GRPO recipe instead of inheriting TRL's defaults
|
|
2135
|
+
# (which on a short run suppress the lift): constant LR (TRL default 'linear' decays
|
|
2136
|
+
# to 0 over the run), advantages centered by group mean only (no std scaling, which
|
|
2137
|
+
# biases by difficulty/length — matches datums.centered_advantages), and no
|
|
2138
|
+
# length-normalized loss. beta is the KL-to-reference coef (datums kl_masks ->
|
|
2139
|
+
# kl_penalty_coef).
|
|
2140
|
+
"lr_scheduler_type": "constant",
|
|
2141
|
+
"warmup_ratio": 0.0,
|
|
2142
|
+
"beta": _kl_beta,
|
|
2143
|
+
"scale_rewards": "none",
|
|
2144
|
+
"loss_type": "dr_grpo",
|
|
2145
|
+
# Optimizer: 8-bit paged AdamW (int8 state paged to host RAM -> fits a smaller GPU);
|
|
2146
|
+
# colocated GRPO (trainer + vLLM on one GPU) is memory-tight, so this is the right default.
|
|
2147
|
+
"optim": fused_optim_name(),
|
|
2148
|
+
}
|
|
2149
|
+
# Liger fused GRPO loss: fuses the lm_head + per-token logprob so the fp32
|
|
2150
|
+
# [batch, seq, ~248k vocab] logits never materialize — the documented GRPO OOM driver.
|
|
2151
|
+
# TRL 1.6's GRPOConfig flag is `use_liger_kernel` (NOT `use_liger_loss`, which doesn't
|
|
2152
|
+
# exist in 1.6). DEFAULT ON for the GRPO path regardless of model size: MEASURED that
|
|
2153
|
+
# WITHOUT it even Qwen3.5-0.8B GRPO OOMs a 24 GB (and 32 GB) card because the per-completion
|
|
2154
|
+
# logits over the 248k vocab dominate — the small-scale JIT cost is far cheaper than the OOM.
|
|
2155
|
+
# (This differs from SFT, where Liger is gated by size since 1B-class SFT can be net-negative.)
|
|
2156
|
+
if liger_on(True):
|
|
2157
|
+
grpo_kwargs["use_liger_kernel"] = True
|
|
2158
|
+
print("[rl] liger fused GRPO loss enabled")
|
|
2159
|
+
if use_vllm:
|
|
2160
|
+
# RTX 5090 / sm120: pin a PTX-independent vLLM attention backend (FLASHINFER) BEFORE TRL
|
|
2161
|
+
# builds the colocated engine — else the rollout can silently produce no completions on
|
|
2162
|
+
# old-driver Blackwell hosts (flash-attn PTX JIT failure). No-op off sm120 / if pinned.
|
|
2163
|
+
force_vllm_backend_for_sm120()
|
|
2164
|
+
# Colocate shares one GPU between the policy model and the vLLM rollout engine.
|
|
2165
|
+
# vllm_max_model_length bounds the KV cache to what GRPO needs (else vLLM sizes for
|
|
2166
|
+
# the model's FULL context and won't start on a consumer GPU).
|
|
2167
|
+
# vllm_gpu_memory_utilization sizes vLLM's KV pool. The blanket sleep-path 0.45 was a
|
|
2168
|
+
# misjudgement: on an 80 GB A100 it reserves 0.45 x 80 = 36 GB of KV, but a GRPO rollout only
|
|
2169
|
+
# holds ~num_generations x context tokens. MEASURED (Qwen3.5-4B colocate): that 36 GB
|
|
2170
|
+
# reservation is the dominant resident allocation and sets the step peak (~46 GB) — exactly why
|
|
2171
|
+
# trainer-side optimisations (mask-aware lm_head, fused layers) moved nothing. colocate_kv_util
|
|
2172
|
+
# sizes both paths from flash's per-model KV estimate instead (vram.py); MEASURED 4B/80 GB peak
|
|
2173
|
+
# 46 -> 26 GB, reward byte-identical, train_wall neutral.
|
|
2174
|
+
try:
|
|
2175
|
+
import torch as _torch_vram
|
|
2176
|
+
|
|
2177
|
+
from flash.engine.vram import colocate_kv_util
|
|
2178
|
+
|
|
2179
|
+
_total_vram_gb = _torch_vram.cuda.get_device_properties(0).total_memory / 1e9
|
|
2180
|
+
_vllm_gpu_mem_util = colocate_kv_util(
|
|
2181
|
+
_params_b, vllm_max_len, _total_vram_gb, sleep_mode, num_generations=group_size
|
|
2182
|
+
)
|
|
2183
|
+
except Exception:
|
|
2184
|
+
_vllm_gpu_mem_util = 0.45 if sleep_mode else 0.10 # safe fallback to the old constants
|
|
2185
|
+
grpo_kwargs.update(
|
|
2186
|
+
vllm_mode="colocate",
|
|
2187
|
+
vllm_max_model_length=vllm_max_len,
|
|
2188
|
+
vllm_gpu_memory_utilization=_vllm_gpu_mem_util,
|
|
2189
|
+
vllm_enable_sleep_mode=sleep_mode,
|
|
2190
|
+
)
|
|
2191
|
+
# Rollout-memory + throughput knobs, applied ONLY if this TRL exposes the field (so an
|
|
2192
|
+
# older TRL never crashes on an unknown kwarg). All verl-validated for GRPO colocate (#174).
|
|
2193
|
+
_grpo_fields = set(getattr(GRPOConfig, "__dataclass_fields__", {}))
|
|
2194
|
+
|
|
2195
|
+
def _set_vllm_field(names, value, label):
|
|
2196
|
+
for _f in names:
|
|
2197
|
+
if _f in _grpo_fields:
|
|
2198
|
+
grpo_kwargs[_f] = value
|
|
2199
|
+
print(f"[rl] {label} ({_f}={value})")
|
|
2200
|
+
return True
|
|
2201
|
+
return False
|
|
2202
|
+
|
|
2203
|
+
# fp8 KV cache only where the silicon has native fp8 (compute capability >= 8.9: Ada /
|
|
2204
|
+
# Hopper / Blackwell) — ~halves the rollout KV pool. Ampere (A100/A6000/3090) lacks
|
|
2205
|
+
# fp8, so it stays fp16 there (forcing it on would error / silently emulate).
|
|
2206
|
+
try:
|
|
2207
|
+
import torch as _torch
|
|
2208
|
+
|
|
2209
|
+
_want_fp8 = _torch.cuda.get_device_capability() >= (8, 9)
|
|
2210
|
+
except Exception:
|
|
2211
|
+
_want_fp8 = False
|
|
2212
|
+
if _want_fp8:
|
|
2213
|
+
_set_vllm_field(("vllm_kv_cache_dtype", "kv_cache_dtype"), "fp8", "fp8 KV cache")
|
|
2214
|
+
# PREFIX CACHING: every GRPO group of `num_generations` rollouts shares the SAME prompt
|
|
2215
|
+
# prefix, so caching the prompt KV computes it once and reuses it — the dominant rollout win
|
|
2216
|
+
# on one GPU. CHUNKED PREFILL interleaves prefill with decode so a long prompt doesn't stall
|
|
2217
|
+
# the batch. CUDAGRAPH MODE sets verl's full-graph-decode + piecewise-fallback rollout mode.
|
|
2218
|
+
_set_vllm_field(
|
|
2219
|
+
("vllm_enable_prefix_caching", "enable_prefix_caching"),
|
|
2220
|
+
True,
|
|
2221
|
+
"vLLM prefix caching (shared GRPO prompt KV reuse)",
|
|
2222
|
+
)
|
|
2223
|
+
_set_vllm_field(
|
|
2224
|
+
("vllm_enable_chunked_prefill", "enable_chunked_prefill"),
|
|
2225
|
+
True,
|
|
2226
|
+
"vLLM chunked prefill",
|
|
2227
|
+
)
|
|
2228
|
+
# vLLM 0.19.1 regressed the Triton _compute_slot_mapping_kernel: it launches
|
|
2229
|
+
# (num_reqs + 1) thread blocks but the block table only has num_reqs rows, so the
|
|
2230
|
+
# extra block causes an illegal memory access (cudaErrorIllegalAddress) on the first
|
|
2231
|
+
# generation step. CUDA graph compilation triggers this path. Skip FULL_AND_PIECEWISE
|
|
2232
|
+
# for vLLM versions outside TRL's supported range (0.12.0-0.19.0) until a fix lands.
|
|
2233
|
+
_cudagraph_safe = True
|
|
2234
|
+
try:
|
|
2235
|
+
import vllm as _vllm_mod
|
|
2236
|
+
|
|
2237
|
+
_ver_base = _vllm_mod.__version__.split("+")[0] # strip PEP440 local (e.g. +cu121)
|
|
2238
|
+
_vllm_ver = tuple(int(x) for x in _ver_base.split(".")[:3])
|
|
2239
|
+
if _vllm_ver > (0, 19, 0):
|
|
2240
|
+
_cudagraph_safe = False
|
|
2241
|
+
print(
|
|
2242
|
+
f"[rl][warn] vLLM {_vllm_mod.__version__} > 0.19.0: skipping "
|
|
2243
|
+
"FULL_AND_PIECEWISE CUDA graph compilation (Triton slot-mapping "
|
|
2244
|
+
"crash workaround; update vLLM to a TRL-supported version to re-enable)"
|
|
2245
|
+
)
|
|
2246
|
+
# vLLM 0.19.1 ALSO hits `RuntimeError: aot_compile is not supported by the
|
|
2247
|
+
# current configuration` through its DEFAULT torch.compile path on some GPU
|
|
2248
|
+
# arches (Ampere sm_86: A6000, A100) — it fires from vllm/compilation/wrapper.py
|
|
2249
|
+
# when torch._dynamo.is_compiling() is False inside the CUDA-graph capture path.
|
|
2250
|
+
# Skipping FULL_AND_PIECEWISE above is not enough (the vllm_compilation_config
|
|
2251
|
+
# GRPOConfig field doesn't exist in this TRL, so that _set_vllm_field is a no-op).
|
|
2252
|
+
# VLLM_TORCH_COMPILE_LEVEL=0 (NO_COMPILATION) forces vLLM to execute the model
|
|
2253
|
+
# eagerly, preventing the AOT path entirely. Official vLLM env var (vllm/envs.py);
|
|
2254
|
+
# a no-op on a vLLM that doesn't define it. Don't override an operator-set value.
|
|
2255
|
+
if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
|
|
2256
|
+
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = "0"
|
|
2257
|
+
print("[rl][warn] VLLM_TORCH_COMPILE_LEVEL=0 (prevent aot_compile on vLLM 0.19.1)")
|
|
2258
|
+
except Exception:
|
|
2259
|
+
pass
|
|
2260
|
+
if _cudagraph_safe:
|
|
2261
|
+
_set_vllm_field(
|
|
2262
|
+
("vllm_compilation_config", "compilation_config"),
|
|
2263
|
+
{"cudagraph_mode": "FULL_AND_PIECEWISE"},
|
|
2264
|
+
"vLLM cudagraph_mode (verl rollout default)",
|
|
2265
|
+
)
|
|
2266
|
+
# Adapter init: continue training the SFT adapter (peft_config=None, model is the
|
|
2267
|
+
# loaded PeftModel) when train.init_from_adapter is set, else a fresh LoRA on the
|
|
2268
|
+
# string model id (model_init_kwargs forces bf16 — TRL string-loading can fall back
|
|
2269
|
+
# to fp32 and double VRAM).
|
|
2270
|
+
init_model, init_peft = _init_adapter_model(model_id)
|
|
2271
|
+
# chalk's kernels are applied AFTER construction (below) against trainer.model: chalk's apply
|
|
2272
|
+
# patches the LIVE nn.Module, so there is nothing to install pre-build. On the fresh-LoRA path
|
|
2273
|
+
# init_model is just the model-id string (TRL builds the module), and even on the
|
|
2274
|
+
# continue-adapter path TRL may rebuild/wrap the PeftModel, so trainer.model is the
|
|
2275
|
+
# authoritative target.
|
|
2276
|
+
if init_peft is not None:
|
|
2277
|
+
# Fresh LoRA: TRL loads the string model id with these kwargs, then attaches the
|
|
2278
|
+
# adapter. Force bf16 (TRL string-loading can fall back to fp32 and double VRAM).
|
|
2279
|
+
_attn = optimal_attn_impl() # arch-aware FlashAttention (Kernels Hub) / SDPA
|
|
2280
|
+
grpo_kwargs["model_init_kwargs"] = {"dtype": "bfloat16"}
|
|
2281
|
+
if _attn:
|
|
2282
|
+
grpo_kwargs["model_init_kwargs"]["attn_implementation"] = _attn
|
|
2283
|
+
else:
|
|
2284
|
+
_attn = optimal_attn_impl()
|
|
2285
|
+
# stop_sequences: TRL forwards generation_kwargs to the (vLLM) sampler, whose
|
|
2286
|
+
# SamplingParams.stop truncates each rollout at the requested delimiter — so the reward
|
|
2287
|
+
# sees the same completion the config intends, instead of generating to max_completion.
|
|
2288
|
+
if _t and _t.stop_sequences:
|
|
2289
|
+
grpo_kwargs["generation_kwargs"] = {"stop": list(_t.stop_sequences)}
|
|
2290
|
+
# advantage_clip>0 is the datums centered-advantage clamp; TRL has no advantage-value
|
|
2291
|
+
# clip knob (it clips the importance ratio), so honor the default (clip off ==
|
|
2292
|
+
# centered) and surface a note when a config asks for an explicit clamp.
|
|
2293
|
+
if _adv_clip > 0:
|
|
2294
|
+
print(f"[rl] advantage_clip={_adv_clip} recorded; TRL centers advantages (no value clip)")
|
|
2295
|
+
# num_iterations (the one promoted GRPO speed lever, measured 1.38x faster) is feature-detected
|
|
2296
|
+
# so an older TRL that lacks the field is simply skipped (GRPOConfig rejects unknown kwargs).
|
|
2297
|
+
# Generation dominates GRPO wall-clock, so reusing each rollout batch for 2 optimizer steps is
|
|
2298
|
+
# the cheapest large speedup; mu=2 is the standard GRPO config and TRL's importance-sampling
|
|
2299
|
+
# correction (on by default) keeps the step stable. (The GSPO/DAPO A/B levers were dropped: the
|
|
2300
|
+
# framework-scan in gpu-bench/RESEARCH_FINDINGS.md measured no robust win over baseline.)
|
|
2301
|
+
import dataclasses as _dc
|
|
2302
|
+
|
|
2303
|
+
try:
|
|
2304
|
+
_grpo_fields = {f.name for f in _dc.fields(GRPOConfig)}
|
|
2305
|
+
except TypeError:
|
|
2306
|
+
_grpo_fields = set() # not a dataclass on this TRL -> skip the feature-detected knob
|
|
2307
|
+
if "num_iterations" in _grpo_fields:
|
|
2308
|
+
grpo_kwargs["num_iterations"] = 2
|
|
2309
|
+
print("[rl] rollout amortization: num_iterations=2 (reuse each generation batch)")
|
|
2310
|
+
# truncated importance sampling (tis): trl's grpo applies an importance-sampling correction by
|
|
2311
|
+
# default, but with mode="sequence_mask" and clip_max=3.0. the verl/openrlhf recipe for the
|
|
2312
|
+
# rollout(vllm)-vs-training token-distribution mismatch is TOKEN-LEVEL truncated is with the
|
|
2313
|
+
# per-token ratio clipped at c=2 (verl rollout_is_threshold=2.0). adopt that recipe here:
|
|
2314
|
+
# token_truncate + c_max=2.0. feature-detected against this trl's GRPOConfig fields (canonical
|
|
2315
|
+
# clip field first, then the pre-2.0 deprecated alias), so a trl that lacks a field is skipped.
|
|
2316
|
+
# note: this deliberately changes trl's defaults (sequence_mask / 3.0) to the recipe values.
|
|
2317
|
+
if "vllm_importance_sampling_mode" in _grpo_fields:
|
|
2318
|
+
grpo_kwargs["vllm_importance_sampling_mode"] = "token_truncate"
|
|
2319
|
+
print("[rl] tis mode=token_truncate (token-level truncated importance sampling)")
|
|
2320
|
+
_tis_c = 2.0
|
|
2321
|
+
_tis_clip_field = next(
|
|
2322
|
+
(
|
|
2323
|
+
f
|
|
2324
|
+
for f in ("vllm_importance_sampling_clip_max", "vllm_importance_sampling_cap")
|
|
2325
|
+
if f in _grpo_fields
|
|
2326
|
+
),
|
|
2327
|
+
None,
|
|
2328
|
+
)
|
|
2329
|
+
if _tis_clip_field:
|
|
2330
|
+
grpo_kwargs[_tis_clip_field] = _tis_c
|
|
2331
|
+
print(f"[rl] tis clip c_max={_tis_c} ({_tis_clip_field})")
|
|
2332
|
+
else:
|
|
2333
|
+
print("[rl] tis: trl default importance-sampling correction in effect; no clip field on this trl")
|
|
2334
|
+
cfg = GRPOConfig(**grpo_kwargs)
|
|
2335
|
+
setup_seconds = time.time() - t_start
|
|
2336
|
+
heartbeat("rl_train_start", setup_seconds=setup_seconds, gpu=gpu_diagnostics())
|
|
2337
|
+
|
|
2338
|
+
# VL checkpoints (Qwen3.5/3.6) train text-only: make TRL's colocated rollout
|
|
2339
|
+
# engine skip the vision tower (VRAM + 5090 PTX-compat; see the patch docstring).
|
|
2340
|
+
# Only relevant when vLLM drives rollouts; transformers generation uses the trainer
|
|
2341
|
+
# model (already text-only via the LoRA target/exclude config).
|
|
2342
|
+
if use_vllm:
|
|
2343
|
+
patch_vllm_language_model_only(model_id)
|
|
2344
|
+
# Install (but do NOT yet activate) the TRL->vLLM weight-sync name remap for Qwen3.5/3.6:
|
|
2345
|
+
# the trainer pushes ``model.*`` names but the VL engine's LM params live under
|
|
2346
|
+
# ``language_model.*``, so the first sync_weights() would raise without this. Activated
|
|
2347
|
+
# below, after the trainer + its initial checkpoint load are built.
|
|
2348
|
+
patch_vllm_lm_weight_sync(model_id)
|
|
2349
|
+
hb_cb = make_reward_heartbeat_callback()
|
|
2350
|
+
# Multi-turn / tool wiring (trl 1.6): tool envs hand TRL the tool callables so it runs the
|
|
2351
|
+
# tool-call loop natively; pure multi-turn envs hand TRL a rollout_func that drives the
|
|
2352
|
+
# env's own turn loop on the colocate engine (env_mask masks the non-model tokens).
|
|
2353
|
+
extra_trainer_kwargs: dict = {}
|
|
2354
|
+
tools = env.tools() if is_tool_env else []
|
|
2355
|
+
# A tool env exposing NO tools would silently degrade to single-shot under tools=[]; drive
|
|
2356
|
+
# it through the rollout_func turn loop instead so it isn't mis-trained as single-turn.
|
|
2357
|
+
if is_tool_env and not tools:
|
|
2358
|
+
print("[rl][warn] tool env exposes no tools — using the multi-turn rollout_func path")
|
|
2359
|
+
use_rollout_func = is_multi_turn and not (is_tool_env and tools)
|
|
2360
|
+
require_vllm_for_rollout_func(use_rollout_func, use_vllm, model_id)
|
|
2361
|
+
if is_tool_env and tools:
|
|
2362
|
+
extra_trainer_kwargs["tools"] = tools
|
|
2363
|
+
print(f"[rl] tool env: handing {len(tools)} tool(s) to TRL's native tool loop")
|
|
2364
|
+
if use_rollout_func:
|
|
2365
|
+
from flash.engine.multiturn_rollout import (
|
|
2366
|
+
build_examples_index,
|
|
2367
|
+
build_rollout_func,
|
|
2368
|
+
index_collisions,
|
|
2369
|
+
)
|
|
2370
|
+
|
|
2371
|
+
examples_by_key = build_examples_index(train, env.prompt_messages)
|
|
2372
|
+
ncol = index_collisions(train, env.prompt_messages)
|
|
2373
|
+
if ncol:
|
|
2374
|
+
print(
|
|
2375
|
+
f"[rl][warn] {ncol} duplicate prompt(s) collide in the reward index; the shared "
|
|
2376
|
+
"prompt scores against the last example's answer/info"
|
|
2377
|
+
)
|
|
2378
|
+
extra_trainer_kwargs["rollout_func"] = build_rollout_func(
|
|
2379
|
+
active_env=env,
|
|
2380
|
+
tok=tok,
|
|
2381
|
+
examples_by_key=examples_by_key,
|
|
2382
|
+
max_completion=_max_completion,
|
|
2383
|
+
max_turns=getattr(env, "max_turns", 10),
|
|
2384
|
+
temperature=_temperature,
|
|
2385
|
+
top_p=rl.sampling_top_p,
|
|
2386
|
+
stop=(list(_t.stop_sequences) if _t and _t.stop_sequences else None),
|
|
2387
|
+
thinking=THINKING,
|
|
2388
|
+
engine_max_len=vllm_max_len,
|
|
2389
|
+
)
|
|
2390
|
+
print("[rl] multi-turn env: driving the turn loop via rollout_func")
|
|
2391
|
+
# GRPOTrainer.__init__ blocks during model/vLLM init + FA2 kernel compilation (can be
|
|
2392
|
+
# 10-20 min on first use). Background heartbeats keep the stall detector quiet.
|
|
2393
|
+
_rl_init_done = threading.Event()
|
|
2394
|
+
|
|
2395
|
+
def _rl_init_heartbeat() -> None:
|
|
2396
|
+
while not _rl_init_done.wait(30.0):
|
|
2397
|
+
heartbeat("rl_initializing", gpu=gpu_diagnostics())
|
|
2398
|
+
|
|
2399
|
+
_rl_init_hb = threading.Thread(target=_rl_init_heartbeat, daemon=True)
|
|
2400
|
+
_rl_init_hb.start()
|
|
2401
|
+
try:
|
|
2402
|
+
trainer = GRPOTrainer(
|
|
2403
|
+
model=init_model,
|
|
2404
|
+
args=cfg,
|
|
2405
|
+
train_dataset=ds,
|
|
2406
|
+
reward_funcs=reward_fn,
|
|
2407
|
+
peft_config=init_peft,
|
|
2408
|
+
processing_class=tok,
|
|
2409
|
+
callbacks=[hb_cb, make_checkpoint_upload_callback()],
|
|
2410
|
+
**extra_trainer_kwargs,
|
|
2411
|
+
)
|
|
2412
|
+
finally:
|
|
2413
|
+
_rl_init_done.set()
|
|
2414
|
+
# Apply chalk's gap-filling kernels (RoPE/LoRA-delta/embedding, like Liger) on the module
|
|
2415
|
+
# GRPOTrainer actually optimizes (trainer.model) — the fresh-LoRA path only passes the model-id
|
|
2416
|
+
# string to TRL, so trainer.model is the authoritative target. chalk composes on top of Liger.
|
|
2417
|
+
# Capture the install report so the engaged kernels land in metrics (active_kernels below).
|
|
2418
|
+
_chalk_report = install_chalk_kernels(getattr(trainer, "model", None))
|
|
2419
|
+
# Liger fused-loss chunk_size: TRL leaves it at the default 1, so the fused GRPO loss runs its
|
|
2420
|
+
# whole detach -> chunk_forward -> compiled-loss -> autograd.grad cycle ONCE PER SEQUENCE
|
|
2421
|
+
# (per_device_train_batch_size times) — Python/kernel-launch/compile-guard overhead that
|
|
2422
|
+
# dominates at small-model scale where the GEMMs are tiny. Collapse it to ONE invocation over the
|
|
2423
|
+
# whole per-device micro-batch. Numerically identical (every loss_type normalizes by the GLOBAL
|
|
2424
|
+
# token count, not the chunk-local size, and chunk losses are summed). Must run BEFORE the
|
|
2425
|
+
# mask-aware wrap below, which replaces trainer.liger_grpo_loss with a closure that has no
|
|
2426
|
+
# chunk_size attribute.
|
|
2427
|
+
_liger_loss = getattr(trainer, "liger_grpo_loss", None)
|
|
2428
|
+
if _liger_loss is not None and hasattr(_liger_loss, "chunk_size"):
|
|
2429
|
+
_cs = max(1, int(getattr(trainer.args, "per_device_train_batch_size", 1)))
|
|
2430
|
+
if _cs > int(getattr(_liger_loss, "chunk_size", 1)):
|
|
2431
|
+
_liger_loss.chunk_size = _cs
|
|
2432
|
+
print(f"[rl] liger fused-loss chunk_size -> {_cs} (one invocation, not one per sequence)")
|
|
2433
|
+
# Run liger's fused GRPO loss EAGER: drop ONLY its torch.compile (BROKEN on torch 2.10 — its
|
|
2434
|
+
# dynamo guard-gen trips a symbol_to_source IndexError that crashes the first GRPO step on every
|
|
2435
|
+
# path), keep the chunked memory path that prevents the 248k-vocab fp32-logit OOM. Must run BEFORE
|
|
2436
|
+
# the mask-aware wrap below, which replaces trainer.liger_grpo_loss with a closure. See the helper.
|
|
2437
|
+
if disable_liger_grpo_torch_compile(trainer):
|
|
2438
|
+
print(
|
|
2439
|
+
"[rl] liger GRPO loss: torch.compile DISABLED (eager loss math; chunked memory path "
|
|
2440
|
+
"retained) — dodges the torch 2.10 dynamo guard-gen crash (symbol_to_source IndexError)"
|
|
2441
|
+
)
|
|
2442
|
+
# Mask-aware lm_head: skip the 248k-vocab projection at MASKED completion positions in the GRPO
|
|
2443
|
+
# loss — its most expensive op, and the trainer step dominates train_wall. For MULTI-TURN that
|
|
2444
|
+
# masked set is the ~half-to-most of the transcript that is env/tool text; for SINGLE-TURN it is
|
|
2445
|
+
# the right-PADDING (GRPO samples variable-length completions, padded to the batch max). Either
|
|
2446
|
+
# way those positions add zero loss/gradient but pay full FLOPs. Loss-preserving; applies to ALL
|
|
2447
|
+
# GRPO with the Liger fused loss; no-op when nothing is masked (uniform-length single-turn).
|
|
2448
|
+
if grpo_kwargs.get("use_liger_kernel") and patch_grpo_mask_aware_lm_head(trainer):
|
|
2449
|
+
_masked_kind = "env + padding" if use_rollout_func else "padding"
|
|
2450
|
+
print(f"[rl] mask-aware lm_head: skipping masked ({_masked_kind}) positions in the GRPO loss")
|
|
2451
|
+
# The trainer (and its colocated vLLM engine + initial checkpoint load) is now built. Activate
|
|
2452
|
+
# the TRL->vLLM weight-sync name remap ONLY now (see patch_vllm_lm_weight_sync) so the initial
|
|
2453
|
+
# checkpoint load stayed untouched while the train-time syncs get remapped. No-op unless the VL
|
|
2454
|
+
# patch above was installed.
|
|
2455
|
+
if use_vllm:
|
|
2456
|
+
_LM_SYNC_REMAP_ON["on"] = True
|
|
2457
|
+
if is_vl_checkpoint(model_id):
|
|
2458
|
+
print("[vllm] LM weight-sync remap activated for training syncs")
|
|
2459
|
+
# Mid-run eval is intentionally NOT run during training: held-out evaluation happens on the
|
|
2460
|
+
# deploy/serving side (against the trained adapter), keeping training pure (no eval-phase cost
|
|
2461
|
+
# or eval-boundary stalls). Training streams only the per-step reward heartbeat.
|
|
2462
|
+
_reset_peak_gpu() # peak_gpu_gb reflects the train loop (verifies the micro-batch headroom)
|
|
2463
|
+
_gpu_sampler = _GpuPeakSampler().start() # true device peak incl. vLLM colocate + bnb pages
|
|
2464
|
+
t_train = time.time()
|
|
2465
|
+
with _sdpa_cudnn_ctx(_attn): # force cuDNN SDPA on sm120 (no-op otherwise)
|
|
2466
|
+
trainer.train(resume_from_checkpoint=resume_ckpt)
|
|
2467
|
+
train_wall = time.time() - t_train
|
|
2468
|
+
rl_peak_gpu_gb = _peak_gpu_gb()
|
|
2469
|
+
rl_device_peak_gpu_gb = _gpu_sampler.stop_gb()
|
|
2470
|
+
reward_history = list(getattr(hb_cb, "reward_history", []))
|
|
2471
|
+
# A GRPO run that finishes WITHOUT the reward callback ever firing (empty reward_history)
|
|
2472
|
+
# produced NO real training — the rollout scored nothing (e.g. vLLM generation silently
|
|
2473
|
+
# returning no completions, observed on RTX 5090 / sm120: ~1.4 s wall, empty reward + loss
|
|
2474
|
+
# curves, but the run otherwise "succeeds"). That is a FAILURE, not a success: a no-op run with
|
|
2475
|
+
# an unchanged adapter must not be reported as done — fail loudly so the operator/agent doesn't
|
|
2476
|
+
# trust it. (An env returning all-zero rewards still appends 0.0s, so an EMPTY history uniquely
|
|
2477
|
+
# means the reward path never ran.)
|
|
2478
|
+
_steps_run = int(getattr(trainer.state, "global_step", 0) or 0)
|
|
2479
|
+
# A resume that already reached the target steps legitimately performs ZERO new optimizer
|
|
2480
|
+
# steps: the previous worker uploaded the final checkpoint (and scored its rewards) but died
|
|
2481
|
+
# before writing metrics/DONE, so this worker's fresh hb_cb has an empty reward_history even
|
|
2482
|
+
# though the policy IS fully trained. Don't fail those — finalize from the resumed state. The
|
|
2483
|
+
# no-op guard below is only for a run that genuinely trained nothing (no resume, or the resume
|
|
2484
|
+
# didn't reach the target steps).
|
|
2485
|
+
_resumed_complete = _grpo_resume_already_complete(resume_ckpt, steps, _steps_run)
|
|
2486
|
+
if _grpo_is_no_op_failure(reward_history, resume_ckpt, steps, _steps_run):
|
|
2487
|
+
if _steps_run == 0:
|
|
2488
|
+
raise RuntimeError(
|
|
2489
|
+
"GRPO trainer completed zero optimizer steps before any reward was scored. "
|
|
2490
|
+
f"retained_prompts={len(prompts)}, prompts_per_step={prompts_per_step}, "
|
|
2491
|
+
f"generations_per_step={batching['generations_per_step']}. This usually means "
|
|
2492
|
+
"TRL built an empty dataloader; add training examples, lower [train].batch_size, "
|
|
2493
|
+
"or reduce prompt length/max_tokens so more examples fit."
|
|
2494
|
+
)
|
|
2495
|
+
raise RuntimeError(
|
|
2496
|
+
f"GRPO scored no reward in {train_wall:.1f}s over {_steps_run} step(s) — the rollout "
|
|
2497
|
+
"produced no completions, so the policy was never actually trained. Failing loudly "
|
|
2498
|
+
"instead of reporting a no-op run as done (seen on RTX 5090/sm120 vLLM rollout)."
|
|
2499
|
+
)
|
|
2500
|
+
if not reward_history and _resumed_complete:
|
|
2501
|
+
print(
|
|
2502
|
+
f"[resume] no new reward in this worker but resumed checkpoint already reached "
|
|
2503
|
+
f"{_steps_run}/{steps} step(s) — finalizing the completed policy instead of failing."
|
|
2504
|
+
)
|
|
2505
|
+
adapter_dir = f"{out_dir}/adapter"
|
|
2506
|
+
trainer.model.save_pretrained(adapter_dir)
|
|
2507
|
+
tok.save_pretrained(adapter_dir)
|
|
2508
|
+
hf_upload_folder(adapter_dir, "adapter", required=True)
|
|
2509
|
+
heartbeat("rl_trained", train_wall=train_wall, gpu=gpu_diagnostics())
|
|
2510
|
+
|
|
2511
|
+
# Upper bound on generated tokens: completions actually optimized (the intended
|
|
2512
|
+
# prompts_per_step after the batch fix) x the max completion length. Over-counts (most
|
|
2513
|
+
# completions are shorter); reported as an upper bound, used only for a rough throughput.
|
|
2514
|
+
gen_tokens = steps * batching["unique_prompts_per_step"] * group_size * _max_completion
|
|
2515
|
+
write_train_meta(
|
|
2516
|
+
phase="rl",
|
|
2517
|
+
adapter_dir=adapter_dir,
|
|
2518
|
+
model_id=model_id,
|
|
2519
|
+
train_wall=train_wall,
|
|
2520
|
+
setup_seconds=setup_seconds,
|
|
2521
|
+
train_tokens=0,
|
|
2522
|
+
generated_tokens=gen_tokens,
|
|
2523
|
+
notes={
|
|
2524
|
+
"steps": steps,
|
|
2525
|
+
"resumed": bool(resume_ckpt),
|
|
2526
|
+
"download_seconds": download_seconds,
|
|
2527
|
+
"hf_transfer": os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", ""),
|
|
2528
|
+
"reward_history": reward_history,
|
|
2529
|
+
"loss_curve": _metric_curve(trainer, "loss"),
|
|
2530
|
+
# Peak torch-allocated GPU memory during the GRPO train loop (excludes bnb managed
|
|
2531
|
+
# pages). device_peak_gpu_gb is the TRUE device footprint (total-free, incl. the vLLM
|
|
2532
|
+
# colocate engine + bnb pages): the headline for verifying the per-device micro-batch
|
|
2533
|
+
# left the card with headroom (no OOM) at the sized batch.
|
|
2534
|
+
"peak_gpu_gb": rl_peak_gpu_gb,
|
|
2535
|
+
"device_peak_gpu_gb": rl_device_peak_gpu_gb,
|
|
2536
|
+
# Which chalk gap-filling kernels actually ENGAGED (None = chalk not installed or every
|
|
2537
|
+
# kernel fell back) — verifies the chalk stack on a GRPO run without the console.
|
|
2538
|
+
"chalk_kernels": active_kernels(_chalk_report) or None,
|
|
2539
|
+
**wandb_run_info(),
|
|
2540
|
+
"gen_tokens_is_upper_bound": True,
|
|
2541
|
+
"thinking": THINKING,
|
|
2542
|
+
"max_completion_len": _max_completion,
|
|
2543
|
+
"prompts_per_step": batching["unique_prompts_per_step"],
|
|
2544
|
+
"generations_per_step": batching["generations_per_step"],
|
|
2545
|
+
"group_size": group_size,
|
|
2546
|
+
"per_device_train_batch_size": batching["per_device_train_batch_size"],
|
|
2547
|
+
"gradient_accumulation_steps": batching["gradient_accumulation_steps"],
|
|
2548
|
+
"grpo_recipe": {
|
|
2549
|
+
"lr_scheduler": "constant",
|
|
2550
|
+
"beta": _kl_beta,
|
|
2551
|
+
"scale_rewards": "none",
|
|
2552
|
+
"loss_type": "dr_grpo",
|
|
2553
|
+
"temperature": _temperature,
|
|
2554
|
+
"advantage_clip": _adv_clip,
|
|
2555
|
+
"thinking_length_penalty_coef": _think_penalty,
|
|
2556
|
+
"init_from_adapter": JOB_SPEC.train.init_from_adapter if JOB_SPEC else "",
|
|
2557
|
+
},
|
|
2558
|
+
},
|
|
2559
|
+
)
|
|
2560
|
+
free_gpu(trainer)
|
|
2561
|
+
|
|
2562
|
+
|
|
2563
|
+
# ---------------------------------------------------------------------------
|
|
2564
|
+
# Completion: train phase writes metrics.json + the DONE sentinel (see _finalize).
|
|
2565
|
+
# ---------------------------------------------------------------------------
|
|
2566
|
+
|
|
2567
|
+
|
|
2568
|
+
def write_train_meta(
|
|
2569
|
+
phase, adapter_dir, model_id, train_wall, setup_seconds, train_tokens, generated_tokens, notes
|
|
2570
|
+
):
|
|
2571
|
+
env = require_active_env()
|
|
2572
|
+
meta = {
|
|
2573
|
+
"phase": phase,
|
|
2574
|
+
"adapter_dir": adapter_dir,
|
|
2575
|
+
"model_id": model_id,
|
|
2576
|
+
"train_wall": train_wall,
|
|
2577
|
+
"setup_seconds": setup_seconds,
|
|
2578
|
+
"train_tokens": train_tokens,
|
|
2579
|
+
"generated_tokens": generated_tokens,
|
|
2580
|
+
"notes": notes or {},
|
|
2581
|
+
}
|
|
2582
|
+
with open("/tmp/train_meta.json", "w") as f:
|
|
2583
|
+
json.dump(meta, f)
|
|
2584
|
+
hf_upload_file("/tmp/train_meta.json", "train_meta.json")
|
|
2585
|
+
heartbeat(
|
|
2586
|
+
f"{phase}_train_done",
|
|
2587
|
+
**{k: meta[k] for k in ("train_wall", "train_tokens", "generated_tokens")},
|
|
2588
|
+
gpu=gpu_diagnostics(),
|
|
2589
|
+
)
|
|
2590
|
+
# Finalize directly from the training phase: build the run-metrics record (training
|
|
2591
|
+
# metrics only — loss/reward are streamed by the trainer; reward_history is in notes)
|
|
2592
|
+
# and write the completion sentinel. There is no separate eval phase.
|
|
2593
|
+
m = RunMetrics(
|
|
2594
|
+
# Substrate the worker actually ran on. The RunPod launcher sets FLASH_ARM; default to
|
|
2595
|
+
# "runpod" when unset so persisted metrics correctly attribute the compute backend.
|
|
2596
|
+
arm=os.environ.get("FLASH_ARM", "runpod"),
|
|
2597
|
+
phase=phase,
|
|
2598
|
+
seed=SEED,
|
|
2599
|
+
model_id=model_id,
|
|
2600
|
+
wall_seconds=train_wall,
|
|
2601
|
+
setup_seconds=setup_seconds,
|
|
2602
|
+
train_throughput_toks_per_s=(
|
|
2603
|
+
(generated_tokens or train_tokens) / train_wall if train_wall else 0.0
|
|
2604
|
+
),
|
|
2605
|
+
train_tokens=train_tokens,
|
|
2606
|
+
generated_tokens=generated_tokens,
|
|
2607
|
+
notes={
|
|
2608
|
+
**(notes or {}),
|
|
2609
|
+
"renderer": "flash_env",
|
|
2610
|
+
"thinking": THINKING,
|
|
2611
|
+
"train_wall": train_wall,
|
|
2612
|
+
"model_id": model_id,
|
|
2613
|
+
"environment": env.id,
|
|
2614
|
+
"job_spec": JOB_SPEC.to_dict() if JOB_SPEC else None,
|
|
2615
|
+
},
|
|
2616
|
+
)
|
|
2617
|
+
_finalize(m)
|
|
2618
|
+
|
|
2619
|
+
|
|
2620
|
+
def _resolve_adapter_ref(adapter_ref: str) -> tuple[str, str] | None:
|
|
2621
|
+
"""Resolve init_from_adapter into (repo, prefix).
|
|
2622
|
+
|
|
2623
|
+
The only public form is the exact adapter_ref emitted by ``flash status``:
|
|
2624
|
+
``<owner>/<repo>:<phase>/<run_id>/seed<N>``.
|
|
2625
|
+
"""
|
|
2626
|
+
adapter_ref = adapter_ref.strip()
|
|
2627
|
+
match = re.fullmatch(
|
|
2628
|
+
r"(?P<repo>[A-Za-z0-9][A-Za-z0-9._-]*/[A-Za-z0-9][A-Za-z0-9._-]*):"
|
|
2629
|
+
r"(?P<phase>sft|rl)/(?P<run_id>[A-Za-z0-9][A-Za-z0-9._-]{0,127})/seed(?P<seed>\d+)",
|
|
2630
|
+
adapter_ref,
|
|
2631
|
+
)
|
|
2632
|
+
if not match:
|
|
2633
|
+
return None
|
|
2634
|
+
repo, phase, run_id, seed = match.groups()
|
|
2635
|
+
return repo, f"{phase}/{run_id}/seed{seed}"
|
|
2636
|
+
|
|
2637
|
+
|
|
2638
|
+
def _download_adapter(adapter_prefix: str | None) -> str | None:
|
|
2639
|
+
"""Download an init_from_adapter LoRA to /tmp/evdl/<prefix>/adapter and return its dir.
|
|
2640
|
+
|
|
2641
|
+
``adapter_prefix`` must be the full ``adapter_ref`` string emitted by ``flash status``:
|
|
2642
|
+
``<owner>/<repo>:<phase>/<run_id>/seed<N>``.
|
|
2643
|
+
"""
|
|
2644
|
+
if not adapter_prefix:
|
|
2645
|
+
return None
|
|
2646
|
+
resolved = _resolve_adapter_ref(adapter_prefix)
|
|
2647
|
+
if not resolved:
|
|
2648
|
+
return None
|
|
2649
|
+
repo, prefix = resolved
|
|
2650
|
+
from huggingface_hub import snapshot_download
|
|
2651
|
+
|
|
2652
|
+
snapshot_download(
|
|
2653
|
+
repo_id=repo,
|
|
2654
|
+
repo_type="dataset",
|
|
2655
|
+
allow_patterns=[f"{prefix}/adapter/*"],
|
|
2656
|
+
local_dir="/tmp/evdl",
|
|
2657
|
+
token=os.environ.get("HF_TOKEN"),
|
|
2658
|
+
)
|
|
2659
|
+
adir = os.path.join("/tmp/evdl", prefix, "adapter")
|
|
2660
|
+
return adir if os.path.isdir(adir) else None
|
|
2661
|
+
|
|
2662
|
+
|
|
2663
|
+
def _finalize(metrics: RunMetrics):
|
|
2664
|
+
metrics.save("/tmp/metrics.json")
|
|
2665
|
+
# Required: a swallowed upload would make the control plane fail/retry a finished run.
|
|
2666
|
+
hf_upload_file("/tmp/metrics.json", "metrics.json", required=True)
|
|
2667
|
+
# DONE sentinel so the controller knows it's safe to tear down
|
|
2668
|
+
with open("/tmp/DONE", "w") as f:
|
|
2669
|
+
f.write(str(time.time()))
|
|
2670
|
+
hf_upload_file("/tmp/DONE", "DONE", required=True)
|
|
2671
|
+
heartbeat("done", gpu=gpu_diagnostics())
|
|
2672
|
+
print("NODE DONE:", metrics.to_json())
|
|
2673
|
+
|
|
2674
|
+
|
|
2675
|
+
# How long to wait for wandb.finish() to flush. On SUCCESS the full run must sync (a slow network /
|
|
2676
|
+
# large run can exceed the old 5s and leave the run "crashed"), so give it a generous-but-bounded
|
|
2677
|
+
# window; on FAILURE abort fast (the run is failing regardless and the worker is hard-exiting).
|
|
2678
|
+
_WANDB_FINISH_WAIT_S = 120.0
|
|
2679
|
+
_WANDB_FINISH_FAIL_WAIT_S = 5.0
|
|
2680
|
+
|
|
2681
|
+
|
|
2682
|
+
# Baked compiled-kernel cache (opt-in; see Dockerfile.worker + flash/engine/worker/kernel_warmup.py).
|
|
2683
|
+
# The Dockerfile points TRITON_CACHE_DIR/TORCHINDUCTOR_CACHE_DIR here and, when built with
|
|
2684
|
+
# --build-arg BUILD_KERNEL_CACHE=true, bakes a portable mega-cache produced on a real GPU. These
|
|
2685
|
+
# names are kept in lockstep with kernel_warmup.DEFAULT_CACHE_DIR / MEGA_CACHE_FILENAME.
|
|
2686
|
+
_KERNEL_CACHE_DIR = "/opt/flash/kernelcache"
|
|
2687
|
+
_KERNEL_CACHE_FILE = os.path.join(_KERNEL_CACHE_DIR, "mega_cache.bin")
|
|
2688
|
+
_KERNEL_CACHE_META_FILE = os.path.join(_KERNEL_CACHE_DIR, "mega_cache.json")
|
|
2689
|
+
|
|
2690
|
+
|
|
2691
|
+
def _current_cuda_sm(torch) -> str | None:
|
|
2692
|
+
try:
|
|
2693
|
+
if not torch.cuda.is_available():
|
|
2694
|
+
return None
|
|
2695
|
+
cap = torch.cuda.get_device_capability(0)
|
|
2696
|
+
return f"sm{cap[0]}{cap[1]}"
|
|
2697
|
+
except Exception:
|
|
2698
|
+
return None
|
|
2699
|
+
|
|
2700
|
+
|
|
2701
|
+
def _load_kernel_cache_if_present() -> bool:
|
|
2702
|
+
"""Best-effort: if a baked mega-cache blob exists, load it so the worker skips first-run JIT.
|
|
2703
|
+
|
|
2704
|
+
Loads the portable cache that kernel_warmup.py wrote on a GPU builder via
|
|
2705
|
+
``torch.compiler.load_cache_artifacts()`` — measured cold compile ~124s -> warm load ~0.2s.
|
|
2706
|
+
OPT-IN: when no baked cache is present (the default image build), this is a no-op and the worker
|
|
2707
|
+
JITs on first use exactly as before (#163's init heartbeat covers that stall). Never raises:
|
|
2708
|
+
a missing torch / missing file / unusable blob just logs and leaves the JIT path intact.
|
|
2709
|
+
"""
|
|
2710
|
+
def _reject(reason: str) -> bool:
|
|
2711
|
+
# a baked cache is present but unusable (no/garbled metadata or wrong arch): repoint
|
|
2712
|
+
# triton/inductor OFF the baked trees (Dockerfile points them at /opt/flash/kernelcache)
|
|
2713
|
+
# so the JIT fallback compiles fresh into scratch instead of reusing wrong-arch baked
|
|
2714
|
+
# entries that would collide with this worker's arch.
|
|
2715
|
+
print(f"[kernel-cache] {reason} -> first-run JIT fallback")
|
|
2716
|
+
scratch = os.path.join(tempfile.gettempdir(), "flash-kernelcache-jit")
|
|
2717
|
+
for sub, var in (("triton", "TRITON_CACHE_DIR"), ("inductor", "TORCHINDUCTOR_CACHE_DIR")):
|
|
2718
|
+
d = os.path.join(scratch, sub)
|
|
2719
|
+
os.makedirs(d, exist_ok=True)
|
|
2720
|
+
os.environ[var] = d
|
|
2721
|
+
return False
|
|
2722
|
+
|
|
2723
|
+
if not os.path.isfile(_KERNEL_CACHE_FILE):
|
|
2724
|
+
print(f"[kernel-cache] no baked cache at {_KERNEL_CACHE_FILE} -> first-run JIT (expected default)")
|
|
2725
|
+
return False
|
|
2726
|
+
try:
|
|
2727
|
+
import torch
|
|
2728
|
+
|
|
2729
|
+
current_sm = _current_cuda_sm(torch)
|
|
2730
|
+
try:
|
|
2731
|
+
with open(_KERNEL_CACHE_META_FILE) as f:
|
|
2732
|
+
meta = json.load(f)
|
|
2733
|
+
except FileNotFoundError:
|
|
2734
|
+
return _reject("baked cache has no metadata")
|
|
2735
|
+
except Exception as e:
|
|
2736
|
+
return _reject(f"metadata unreadable ({e})")
|
|
2737
|
+
cached_sm = str(meta.get("sm") or "")
|
|
2738
|
+
if not current_sm:
|
|
2739
|
+
# can't verify the worker's GPU arch -> don't risk loading a wrong-arch blob; JIT instead.
|
|
2740
|
+
return _reject("worker GPU arch undetermined")
|
|
2741
|
+
if cached_sm != current_sm:
|
|
2742
|
+
return _reject(
|
|
2743
|
+
f"baked cache arch {cached_sm or 'unknown'} does not match worker arch {current_sm}"
|
|
2744
|
+
)
|
|
2745
|
+
with open(_KERNEL_CACHE_FILE, "rb") as f:
|
|
2746
|
+
blob = f.read()
|
|
2747
|
+
torch.compiler.load_cache_artifacts(blob)
|
|
2748
|
+
print(
|
|
2749
|
+
f"[kernel-cache] loaded baked mega-cache for {cached_sm or 'unknown'} "
|
|
2750
|
+
f"({len(blob)} bytes) -> skipping first-run JIT"
|
|
2751
|
+
)
|
|
2752
|
+
return True
|
|
2753
|
+
except Exception as e:
|
|
2754
|
+
# never block boot on a bad/absent cache: fall back to the normal JIT path. repoint off the
|
|
2755
|
+
# baked trees too — if the mega blob was present + arch-matched but load raised, the on-disk
|
|
2756
|
+
# triton/inductor entries may be partial/corrupt, so JIT fresh into scratch.
|
|
2757
|
+
return _reject(f"load skipped ({e})")
|
|
2758
|
+
|
|
2759
|
+
|
|
2760
|
+
def wandb_finish(exit_code: int = 0) -> None:
|
|
2761
|
+
"""Finalize the W&B run before the worker's hard ``os._exit()``.
|
|
2762
|
+
|
|
2763
|
+
The worker hard-exits to dodge the colocated-vLLM teardown deadlock (see main),
|
|
2764
|
+
which skips wandb's atexit sync — so a *successfully completed* run was left
|
|
2765
|
+
dangling and W&B eventually marked it ``crashed`` even though all metrics were
|
|
2766
|
+
logged. Explicitly finish the run (we own it: we called ``wandb.init`` in
|
|
2767
|
+
``wandb_report_to``) so it shows ``finished``. Best-effort; never raises (W&B is
|
|
2768
|
+
optional, metrics.json is the source of truth)."""
|
|
2769
|
+
if not os.environ.get("WANDB_API_KEY"):
|
|
2770
|
+
return
|
|
2771
|
+
import importlib.util
|
|
2772
|
+
|
|
2773
|
+
# find_spec can RAISE (not just return None) when wandb is already in sys.modules with an
|
|
2774
|
+
# absent/partial __spec__ (e.g. a namespace-package or a partially-initialized import) — that
|
|
2775
|
+
# would propagate out of the shutdown path and skip the hard exit. Keep it best-effort: treat any
|
|
2776
|
+
# probe failure as "wandb present enough to try", and let the import + finish below (already
|
|
2777
|
+
# wrapped) decide. Only a definitive None (probe succeeded, module truly absent) returns early.
|
|
2778
|
+
try:
|
|
2779
|
+
if importlib.util.find_spec("wandb") is None:
|
|
2780
|
+
return
|
|
2781
|
+
except Exception:
|
|
2782
|
+
pass # ambiguous probe -> fall through and try to finish (still fully guarded below)
|
|
2783
|
+
try:
|
|
2784
|
+
import wandb
|
|
2785
|
+
|
|
2786
|
+
if getattr(wandb, "run", None) is None:
|
|
2787
|
+
return
|
|
2788
|
+
|
|
2789
|
+
errs: list[Exception] = []
|
|
2790
|
+
|
|
2791
|
+
def _finish() -> None:
|
|
2792
|
+
try:
|
|
2793
|
+
wandb.finish(exit_code=exit_code)
|
|
2794
|
+
except Exception as e:
|
|
2795
|
+
errs.append(e)
|
|
2796
|
+
|
|
2797
|
+
t = threading.Thread(target=_finish, daemon=True)
|
|
2798
|
+
t.start()
|
|
2799
|
+
# On SUCCESS (exit_code == 0) wandb.finish() must flush the full run; a slow network / large
|
|
2800
|
+
# run can take well over 5s, and cutting it off there is what leaves the run dangling ->
|
|
2801
|
+
# "crashed". Allow a longer, still-bounded wait on success; keep the short cut-off on the
|
|
2802
|
+
# FAILURE path (exit_code != 0) where we want to abort fast and the run is failing anyway.
|
|
2803
|
+
wait_s = _WANDB_FINISH_WAIT_S if exit_code == 0 else _WANDB_FINISH_FAIL_WAIT_S
|
|
2804
|
+
t.join(timeout=wait_s)
|
|
2805
|
+
if t.is_alive():
|
|
2806
|
+
print(f"[wandb] finish() did not complete within {wait_s}s; continuing with hard exit")
|
|
2807
|
+
elif errs:
|
|
2808
|
+
print(f"[wandb] finish() warning: {errs[0]}")
|
|
2809
|
+
except Exception as e: # pragma: no cover - logging-only path
|
|
2810
|
+
print(f"[wandb] finish() warning: {e}")
|
|
2811
|
+
|
|
2812
|
+
|
|
2813
|
+
def main():
|
|
2814
|
+
# Idempotency: if DONE was already uploaded, a re-delivered job re-fetches the final
|
|
2815
|
+
# metrics from HF and returns them immediately. (The previous behavior — sleeping in
|
|
2816
|
+
# an infinite loop — kept a billable GPU worker alive until the execution timeout.)
|
|
2817
|
+
try:
|
|
2818
|
+
# Idempotency FIRST — before any env-mutating pip install / package removal: a re-delivered
|
|
2819
|
+
# job whose DONE already exists must return the persisted metrics and exit WITHOUT running
|
|
2820
|
+
# _ensure_fla_fastpath_on_hopper() (mutates the env: pip-installs tilelang/fla) — that wasted
|
|
2821
|
+
# a worker mutating its env on an already-complete run. It runs after the DONE check below.
|
|
2822
|
+
if HF_REPO:
|
|
2823
|
+
from huggingface_hub import hf_hub_download
|
|
2824
|
+
|
|
2825
|
+
try:
|
|
2826
|
+
hf_hub_download(
|
|
2827
|
+
repo_id=HF_REPO,
|
|
2828
|
+
repo_type="dataset",
|
|
2829
|
+
filename=f"{hf_prefix()}/DONE",
|
|
2830
|
+
token=os.environ.get("HF_TOKEN"),
|
|
2831
|
+
)
|
|
2832
|
+
done = True
|
|
2833
|
+
except Exception:
|
|
2834
|
+
done = False
|
|
2835
|
+
if done:
|
|
2836
|
+
print("Run already complete (DONE present); returning persisted metrics.")
|
|
2837
|
+
heartbeat("already_done", gpu=gpu_diagnostics(include_torch=False))
|
|
2838
|
+
try:
|
|
2839
|
+
got = hf_hub_download(
|
|
2840
|
+
repo_id=HF_REPO,
|
|
2841
|
+
repo_type="dataset",
|
|
2842
|
+
filename=f"{hf_prefix()}/metrics.json",
|
|
2843
|
+
token=os.environ.get("HF_TOKEN"),
|
|
2844
|
+
)
|
|
2845
|
+
import shutil
|
|
2846
|
+
|
|
2847
|
+
shutil.copy(got, "/tmp/metrics.json")
|
|
2848
|
+
sys.stdout.flush()
|
|
2849
|
+
os._exit(0)
|
|
2850
|
+
except Exception as e:
|
|
2851
|
+
raise SystemExit(f"DONE present but metrics.json unavailable: {e}") from e
|
|
2852
|
+
# Not a DONE re-delivery -> this worker will train. These must run before any model import:
|
|
2853
|
+
_ensure_fla_fastpath_on_hopper() # Hopper: enable fla+tilelang GDN fast path (see perf.py)
|
|
2854
|
+
# Repoint tilelang's libcudart_stub.so at the real CUDA runtime so it can't shadow libcudart
|
|
2855
|
+
# in vLLM's CudaRTLibrary (intermittent `undefined symbol: cudaDeviceReset` on GRPO vLLM
|
|
2856
|
+
# init, any model size/arch). AFTER the fla fast path (a tilelang reinstall there rewrites
|
|
2857
|
+
# the stub) and BEFORE the model/vLLM import. See perf.py / flash #184.
|
|
2858
|
+
_neutralize_tilelang_cudart_stub()
|
|
2859
|
+
heartbeat("boot", gpu=gpu_diagnostics(include_torch=False))
|
|
2860
|
+
finalize_alloc_conf_for_sleep() # sync CUDA alloc conf to resolved sleep (before first CUDA alloc)
|
|
2861
|
+
# Opt-in: load a baked compiled-kernel mega-cache (if the image shipped one) so the worker
|
|
2862
|
+
# skips the ~10-15 min first-run JIT. Best-effort + no-op when absent (the default), so the
|
|
2863
|
+
# normal JIT path is untouched. Runs AFTER finalize_alloc_conf_for_sleep: _load probes CUDA
|
|
2864
|
+
# (_current_cuda_sm -> get_device_capability triggers CUDA init), so the allocator conf must be
|
|
2865
|
+
# resolved first; still before any model/kernel import that would otherwise trigger compilation.
|
|
2866
|
+
_load_kernel_cache_if_present()
|
|
2867
|
+
# Dispatch table — register new algorithms (e.g. ppo) here as they land.
|
|
2868
|
+
modes = {
|
|
2869
|
+
"sft": run_sft, # SFT (TRL SFTTrainer)
|
|
2870
|
+
"rl": run_rl, # GRPO (TRL GRPOTrainer + colocated vLLM)
|
|
2871
|
+
}
|
|
2872
|
+
handler = modes.get(RUN_MODE)
|
|
2873
|
+
if handler is None:
|
|
2874
|
+
raise SystemExit(f"unknown RUN_MODE {RUN_MODE}; known: {sorted(modes)}")
|
|
2875
|
+
handler()
|
|
2876
|
+
# All artifacts (adapter, train_meta, metrics, DONE) are uploaded to HF *inside* the
|
|
2877
|
+
# handler. The RL trainer's colocated vLLM can DEADLOCK at interpreter shutdown
|
|
2878
|
+
# during NCCL/IPC/CUDA teardown — not segfault-and-exit (which `check=False` on the
|
|
2879
|
+
# train subprocess already tolerates), but hang forever. That would block the Flash
|
|
2880
|
+
# handler's *blocking* `subprocess.run` (heartbeat frozen at "rl_train_done") and the
|
|
2881
|
+
# whole run stalls until the wall-clock cap. Hard-exit to bypass the hanging teardown now that
|
|
2882
|
+
# every output is safely persisted.
|
|
2883
|
+
wandb_finish(exit_code=0) # mark the W&B run finished BEFORE os._exit (which skips wandb's atexit sync)
|
|
2884
|
+
sys.stdout.flush()
|
|
2885
|
+
sys.stderr.flush()
|
|
2886
|
+
os._exit(0)
|
|
2887
|
+
except Exception as e:
|
|
2888
|
+
# Structured retry signal both pollers read: an infra failure -> retry on a fresh worker.
|
|
2889
|
+
# GitHubRateLimitError (env ref resolution hit a persistent GitHub rate limit) is retriable:
|
|
2890
|
+
# reschedule on a fresh worker once the limit window resets rather than hard-failing. Env
|
|
2891
|
+
# resolution runs lazily inside this try (require_active_env, called by the handlers above),
|
|
2892
|
+
# never at import, so a rate-limit raise reaches here and is classified correctly.
|
|
2893
|
+
retriable = isinstance(e, (RetriableInfraError, GitHubRateLimitError))
|
|
2894
|
+
tb = traceback.format_exc()
|
|
2895
|
+
traceback.print_exc()
|
|
2896
|
+
try:
|
|
2897
|
+
err_name = error_artifact_name(RUN_MODE)
|
|
2898
|
+
err_path = f"/tmp/{err_name}"
|
|
2899
|
+
with open(err_path, "w") as f:
|
|
2900
|
+
f.write(tb)
|
|
2901
|
+
hf_upload_file(err_path, err_name)
|
|
2902
|
+
except Exception as up_err:
|
|
2903
|
+
print("error-upload warn:", up_err)
|
|
2904
|
+
hb_flags = {"retriable": retriable}
|
|
2905
|
+
try:
|
|
2906
|
+
heartbeat(f"error_{RUN_MODE}", error=str(e)[:500], **hb_flags, diag=gpu_diagnostics())
|
|
2907
|
+
except Exception:
|
|
2908
|
+
heartbeat(f"error_{RUN_MODE}", error=str(e)[:500], **hb_flags)
|
|
2909
|
+
# keep container alive briefly so logs flush, then exit non-zero -> restart
|
|
2910
|
+
wandb_finish(exit_code=1) # finalize the W&B run as failed (don't leave it dangling -> "crashed")
|
|
2911
|
+
time.sleep(10)
|
|
2912
|
+
raise
|
|
2913
|
+
|
|
2914
|
+
|
|
2915
|
+
if __name__ == "__main__":
|
|
2916
|
+
main()
|