freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,796 @@
|
|
|
1
|
+
"""Pure LoRA-target / VL-checkpoint helpers for the fine-tuning worker.
|
|
2
|
+
|
|
3
|
+
These helpers take the model id as an ARGUMENT and read NONE of the worker's run-scoped
|
|
4
|
+
module globals, so they live here as a leaf module. ``flash.engine.worker`` re-exports
|
|
5
|
+
them; this module must NOT import that package (no cycle). Heavy deps (transformers, peft,
|
|
6
|
+
vllm, the catalog) are imported lazily inside the functions so the module stays
|
|
7
|
+
CPU-importable.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _patch_peft_weight_converter_compat() -> None:
|
|
14
|
+
"""peft 0.19.1 x transformers 5.6-5.10: make MoE adapter loading work.
|
|
15
|
+
|
|
16
|
+
peft's ``build_peft_weight_mapping`` reconstructs transformers ``WeightConverter``
|
|
17
|
+
objects passing ``distributed_operation=`` / ``quantization_operation=`` — kwargs
|
|
18
|
+
the WeightConverter in transformers <5.11 doesn't accept (init=False dataclass
|
|
19
|
+
fields), so loading a LoRA adapter onto any arch WITH weight conversions dies with
|
|
20
|
+
``TypeError: unexpected keyword argument 'distributed_operation'`` (observed on a
|
|
21
|
+
weight-converting checkpoint eval). The
|
|
22
|
+
worker can't take transformers>=5.11 (vllm 0.19.1 compat), so accept-and-drop
|
|
23
|
+
unknown kwargs; on a single GPU those fields are unused. No-op once signatures
|
|
24
|
+
match.
|
|
25
|
+
"""
|
|
26
|
+
import inspect
|
|
27
|
+
|
|
28
|
+
try:
|
|
29
|
+
from transformers import core_model_loading as cml
|
|
30
|
+
except Exception: # pragma: no cover - older stacks have no converter module
|
|
31
|
+
return
|
|
32
|
+
converter = getattr(cml, "WeightConverter", None)
|
|
33
|
+
if converter is None or getattr(converter, "_flash_compat", False):
|
|
34
|
+
return
|
|
35
|
+
accepted = set(inspect.signature(converter.__init__).parameters)
|
|
36
|
+
if "distributed_operation" in accepted:
|
|
37
|
+
return
|
|
38
|
+
orig_init = converter.__init__
|
|
39
|
+
|
|
40
|
+
def _compat_init(self, *args, **kwargs):
|
|
41
|
+
dropped = [k for k in kwargs if k not in accepted]
|
|
42
|
+
for k in dropped:
|
|
43
|
+
kwargs.pop(k)
|
|
44
|
+
orig_init(self, *args, **kwargs)
|
|
45
|
+
|
|
46
|
+
converter.__init__ = _compat_init
|
|
47
|
+
converter._flash_compat = True
|
|
48
|
+
print("[compat] WeightConverter patched (peft<->transformers signature drift)")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# Module-path segments that must never receive LoRA on natively-multimodal checkpoints
|
|
52
|
+
# trained text-only: the vision tower / projector / MTP head. Critically, adapters that
|
|
53
|
+
# DO touch them cannot be loaded by vLLM in text-only (language_model_only) serving —
|
|
54
|
+
# its LoRA loader rejects "unexpected modules" (observed with Qwen3.5-2B).
|
|
55
|
+
_VL_EXCLUDE_SEGMENTS = ("visual", "vision_tower", "multi_modal_projector", "mtp")
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def lora_exclude_modules(model_id: str) -> str | None:
    """Build a peft ``exclude_modules`` regex that keeps LoRA off the vision/MTP modules.

    The pattern targets peft's fullmatch regex semantics. ``None`` means nothing needs excluding:
    either the config's ``model_type`` is not one of the natively-multimodal families listed
    below, or the config probe failed (the failure is printed). A regex STRING is required
    because peft's list-form ``exclude_modules`` matches by suffix (like ``target_modules``) and
    would not match leaf modules nested under ``visual.*``.

    NOTE: ``trust_remote_code=True`` lets the checkpoint repo supply config code — only pass
    trusted model ids.
    """
    vl_model_types = {
        "qwen3_5": _VL_EXCLUDE_SEGMENTS,
        "qwen3_5_moe": _VL_EXCLUDE_SEGMENTS,
        "qwen3_6": _VL_EXCLUDE_SEGMENTS,
    }
    try:
        from transformers import AutoConfig

        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        arch = getattr(config, "model_type", "") or ""
    except Exception as e:
        print("lora_exclude_modules: config probe failed:", e)
        return None

    blocked = vl_model_types.get(arch)
    if not blocked:
        return None
    # Match a blocked segment at the start of the path or after any dot, followed by a dot or end.
    return r"(^|.*\.)(" + "|".join(blocked) + r")(\..*|$)"
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def is_vl_checkpoint(model_id: str) -> bool:
    """Whether ``model_id`` is a natively-multimodal checkpoint trained/served text-only.

    Decided by ``lora_exclude_modules``: a VL checkpoint is one that needs a vision exclusion
    regex (the Qwen3.5/3.6 family).
    """
    pattern = lora_exclude_modules(model_id)
    return True if pattern else False
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def vllm_language_model_only_kwargs(model_id: str) -> dict:
    """Extra vLLM engine kwargs that skip the vision tower for VL checkpoints (vLLM >= 0.19).

    Returns ``{"language_model_only": True}`` for VL checkpoints and ``{}`` otherwise. Skipping the
    tower saves VRAM. It also avoids the tower's attention path, which hardcodes vLLM's bundled
    flash-attn. That PTX needs a newer driver JIT than many RTX 5090 hosts have ("PTX compiled
    with unsupported toolchain"). Text-only loading is the officially supported way to run
    Qwen3.5 as a pure LLM.
    """
    if not is_vl_checkpoint(model_id):
        return {}
    return {"language_model_only": True}
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def patch_vllm_language_model_only(model_id: str) -> bool:
    """Default ``language_model_only=True`` on every ``vllm.LLM`` built in this process.

    Aimed at engines constructed by third-party code (TRL's colocated GRPO rollout engine) when
    ``model_id`` is a VL checkpoint. The wrapper uses ``setdefault``, so an explicit
    ``language_model_only=`` from the caller still wins. The patch replaces
    ``vllm.LLM.__init__`` process-wide and is never undone.

    Returns True when the patch is active (newly installed or already present). Returns False
    for non-VL checkpoints or when patching raised; the error is printed.
    """
    if not vllm_language_model_only_kwargs(model_id):
        return False
    try:
        import vllm

        current_init = vllm.LLM.__init__
        if getattr(current_init, "_flash_lmo_patched", False):
            return True

        def _init_text_only(self, *args, **kwargs):
            kwargs.setdefault("language_model_only", True)
            return current_init(self, *args, **kwargs)

        _init_text_only._flash_lmo_patched = True  # idempotency sentinel
        vllm.LLM.__init__ = _init_text_only
        print(f"[vllm] language_model_only patch active for {model_id}")
        return True
    except Exception as e:
        print("patch_vllm_language_model_only warn:", e)
        return False
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# Flipped to True only AFTER the GRPO trainer (and its colocated vLLM engine + initial
|
|
128
|
+
# checkpoint load) is constructed, but BEFORE ``trainer.train()`` runs the first weight sync.
|
|
129
|
+
# See ``patch_vllm_lm_weight_sync``. A module dict (not a bare bool) so the gating flag is shared
|
|
130
|
+
# by reference between this module and the worker package that flips it.
|
|
131
|
+
_LM_SYNC_REMAP_ON = {"on": False}
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _remap_vl_sync_weights(weights):
|
|
135
|
+
"""Rewrite TRL's trainer weight names to vLLM's VL-engine names for the train-time sync.
|
|
136
|
+
|
|
137
|
+
The trainer (built via ``AutoModelForCausalLM``) names its LM params ``model.layers.*`` /
|
|
138
|
+
``model.norm`` / ``model.embed_tokens`` / ``lm_head.*``; the colocated vLLM engine loaded the
|
|
139
|
+
same checkpoint as ``Qwen3_5ForConditionalGeneration`` whose LM params live under
|
|
140
|
+
``language_model.*``. Prefix incoming ``model.``/``lm_head.`` names with ``language_model.`` so
|
|
141
|
+
they resolve. Also tolerate a peft ``base_model.model.`` prefix (a merged-adapter sync can yield
|
|
142
|
+
base-model names through that wrapper) by stripping it before the language_model. prefix is
|
|
143
|
+
added. Names that already start with ``language_model.`` (or anything else) pass through
|
|
144
|
+
untouched. A generator so vLLM's loader still streams one (name, tensor) at a time.
|
|
145
|
+
"""
|
|
146
|
+
for name, tensor in weights:
|
|
147
|
+
# A continued-adapter (PeftModel) sync can surface names through the peft wrapper as
|
|
148
|
+
# ``base_model.model.model.layers.*`` / ``base_model.model.lm_head.*``; strip the wrapper
|
|
149
|
+
# so the same model./lm_head. rule applies.
|
|
150
|
+
if name.startswith("base_model.model."):
|
|
151
|
+
name = name[len("base_model.model.") :]
|
|
152
|
+
if name.startswith(("model.", "lm_head.")):
|
|
153
|
+
name = "language_model." + name
|
|
154
|
+
yield name, tensor
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def patch_vllm_lm_weight_sync(model_id: str) -> bool:
    """Make TRL's GRPO ``sync_weights`` work for ``*ForConditionalGeneration`` checkpoints
    (the whole Qwen3.5/3.6 family).

    The trainer loads via ``AutoModelForCausalLM`` so its params are named ``model.layers.*`` /
    ``model.norm`` / ``model.embed_tokens`` / ``lm_head.*``. vLLM loads the same checkpoint as
    ``Qwen3_5ForConditionalGeneration`` whose LM params live under ``language_model.*``. TRL's
    ``sync_weights`` pushes the trainer names verbatim. vLLM's loader then raises "There is no
    module or parameter named 'model' in Qwen3_5ForConditionalGeneration" at the first generation
    step and GRPO dies. This happens even with ``language_model_only=True``: that only skips
    loading the vision tower, it does NOT rename the surviving LM params out from under
    ``language_model.``.

    The fix wraps the vLLM model class ``load_weights`` to remap incoming ``model.``/``lm_head.``
    names to ``language_model.*`` (see ``_remap_vl_sync_weights``), but ONLY while
    ``_LM_SYNC_REMAP_ON["on"]`` is set. The INITIAL checkpoint load (during trainer construction)
    runs with it OFF, so vLLM's own ``hf_to_vllm_mapper`` handles the on-disk checkpoint
    untouched. The remap activates only for the train-time TRL syncs; the flag is flipped on
    between trainer construction and ``train()``. Works for BOTH from-base and warm-started
    (init_from_adapter) GRPO.

    Returns True when the gated wrapper is in place on at least one vLLM model class. That
    includes the case where it was installed by an earlier call, matching the "already patched ->
    True" convention of ``patch_vllm_language_model_only`` / ``patch_grpo_mask_aware_lm_head``.
    Returns False for non-VL checkpoints or when no class could be patched (reasons are printed).
    """
    if not is_vl_checkpoint(model_id):
        return False
    patched_any = False
    try:
        import importlib

        # The dense class is REQUIRED for the whole Qwen3.5/3.6 family. If its module/class can't
        # be imported (vLLM not installed where it should be, or the class renamed in a new vLLM),
        # we must NOT silently no-op: the run would crash again at the first ``sync_weights()``
        # with a far less actionable error. So log loudly for the required one. The MoE class is
        # OPTIONAL (only some models are MoE, and older vLLM lacks the module), so its absence
        # stays quiet.
        for mod_name, cls_name, required in (
            ("vllm.model_executor.models.qwen3_5", "Qwen3_5ForConditionalGeneration", True),
            ("vllm.model_executor.models.qwen3_5_moe", "Qwen3_5MoeForConditionalGeneration", False),
        ):
            try:
                mod = importlib.import_module(mod_name)
            except Exception as e:
                mod = None
                if required:
                    print(
                        f"[vllm] WARN patch_vllm_lm_weight_sync: could not import required module "
                        f"{mod_name} ({e!r}); GRPO weight-sync will NOT be remapped and the run may "
                        f"crash at the first sync_weights() for this VL checkpoint."
                    )
            cls = getattr(mod, cls_name, None) if mod is not None else None
            if cls is None:
                if required and mod is not None:
                    print(
                        f"[vllm] WARN patch_vllm_lm_weight_sync: module {mod_name} imported but has "
                        f"no {cls_name} (vLLM API changed?); GRPO weight-sync will NOT be remapped "
                        f"and the run may crash at the first sync_weights() for this VL checkpoint."
                    )
                continue
            if getattr(cls.load_weights, "_flash_sync_patched", False):
                # Already wrapped, e.g. by an earlier call or inherited from a patched parent
                # class. The gated remap is active, so report it instead of a spurious False.
                patched_any = True
                continue
            orig_load = cls.load_weights

            def _make_patched(orig):
                # Factory so each class's wrapper closes over ITS OWN original load_weights.
                def patched(self, weights, *args, **kwargs):
                    if _LM_SYNC_REMAP_ON["on"]:
                        weights = _remap_vl_sync_weights(weights)
                    return orig(self, weights, *args, **kwargs)

                patched._flash_sync_patched = True  # idempotency sentinel
                return patched

            cls.load_weights = _make_patched(orig_load)
            patched_any = True
            print(f"[vllm] LM weight-sync name patch installed for {cls_name} (gated)")
    except Exception as e:
        print("patch_vllm_lm_weight_sync warn:", e)
    return patched_any
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def patch_grpo_mask_aware_lm_head(trainer) -> bool:
    """Skip the 248k-vocab ``lm_head`` projection at MASKED completion positions in the GRPO loss.

    Targets MULTI-TURN GRPO. There the masked set is the env/tool text (~half-to-most of the
    transcript: the rollout's ``env_mask`` -> TRL's ``tool_mask``), and EVERY row carries it, so
    the micro-batch has maskable headroom in all rows. TRL 1.6's ``compute_liger_loss`` hands the
    FULL-length hidden states to ``liger_grpo_loss``. The Liger kernel then runs the lm_head
    matmul + log-softmax for EVERY position, in the forward AND the backward recompute. Masked
    positions contribute zero loss and zero gradient but still pay the full FLOPs of the single
    most expensive GRPO op (the 248k-vocab projection Liger exists to tame). The saving scales
    with the env-masked fraction.

    SINGLE-TURN is effectively a no-op. Its only mask is right-padding, and TRL pads completions
    to the LONGEST in the micro-batch. So the deepest row has ``keep.sum() == full_t`` and the
    across-batch no-op below triggers — there is no shared headroom to gather. It would engage
    only if ``pad_to_multiple_of`` padded every row past the longest completion.

    The wrapper around ``trainer.liger_grpo_loss`` GATHERs the unmasked positions before the
    call, so the kernel only projects the kept positions. ONE shared index is applied identically
    to every per-token tensor the caller passed: ``_input``, ``selected_token_ids``,
    ``attention_mask``, ``old_per_token_logps``, ``ref_per_token_logps`` and a 2-D
    ``vllm_is_ratio``. Arguments the caller did not pass are NOT injected as ``None``. Per-sequence
    ``advantages`` ``(B,)`` and the loss object's ``max_completion_length`` are left untouched.

    This is EXACTLY loss-preserving. dr_grpo's numerator only ever summed unmasked positions. Its
    normalizer is ``B * max_completion_length``, a config constant on the loss object, independent
    of the gathered length. The gathered sequence is re-padded with masked positions whose new
    mask is 0. So loss and credit assignment are unchanged, while the gathered length T' < T cuts
    the projection FLOPs by ~the masked fraction.

    Falls through to the unmodified loss when:
      * the call uses positional args or has no 2-D ``attention_mask``;
      * the batch is empty;
      * the deepest row is full-length (``max(unmasked) == T``) or nothing is unmasked;
      * an unrecognized ``(B, T[, *])`` tensor is present.

    Returns True if wrapped (or already wrapped), False when the loss object isn't present.
    """
    orig = getattr(trainer, "liger_grpo_loss", None)
    if orig is None:
        return False
    if getattr(orig, "_flash_mask_aware", False):
        return True  # already wrapped — idempotent (mirrors the other patch helpers' sentinels)
    import torch

    # Per-token tensors this wrapper knows how to gather (or deliberately leaves alone).
    known_per_token = frozenset(
        {
            "attention_mask",
            "_input",
            "selected_token_ids",
            "old_per_token_logps",
            "ref_per_token_logps",
            "vllm_is_ratio",
        }
    )
    gathered_keys = ("_input", "selected_token_ids", "old_per_token_logps", "ref_per_token_logps")

    def _gather(x, idx, tprime):
        """Select columns ``idx`` (B, T') along dim 1 of a (B, T) or (B, T, H) tensor."""
        if x is None:
            return None
        if x.dim() == 2:  # (B, T) per-token tensor
            return torch.gather(x, 1, idx)
        return torch.gather(x, 1, idx.unsqueeze(-1).expand(idx.size(0), tprime, x.size(-1)))

    def masked_liger_loss(*args, **kwargs):
        mask = kwargs.get("attention_mask")  # loss mask = completion_mask * tool_mask, shape (B, T)
        if args or mask is None or mask.dim() != 2 or mask.numel() == 0:
            # Unexpected call shape or empty batch (``.max()`` of an empty tensor would raise)
            # -> never alter the loss.
            return orig(*args, **kwargs)
        keep = mask != 0
        full_t = mask.size(1)
        tprime = int(keep.sum(dim=1).max().item())
        if tprime == 0 or tprime == full_t:
            # Nothing maskable to skip across the batch: the DEEPEST row is full-length (max
            # unmasked == T). Standard single-turn vLLM GRPO pads completions to the longest in
            # the micro-batch, so this is its common case — the patch only engages where every row
            # has masked headroom.
            return orig(**kwargs)
        # Defensive: we gather a KNOWN set of per-token tensors below. If TRL/Liger starts passing
        # any OTHER per-token tensor shaped (B, T[, *]), it would stay full-length while the rest
        # are gathered to T' -> a shape mismatch or misaligned credit. Bail to the unmodified loss.
        for key, value in kwargs.items():
            if (
                key not in known_per_token
                and isinstance(value, torch.Tensor)
                and value.dim() >= 2
                and value.size(0) == mask.size(0)
                and value.size(1) == full_t
            ):
                return orig(**kwargs)
        # One shared gather index: the unmasked positions first (stable argsort -> their original
        # order preserved), then the remaining masked positions in original order. Keep only the
        # first tprime columns. A sequence with fewer than tprime unmasked positions has its
        # filler entries taken from its masked positions, whose gathered mask is 0 — so they add
        # zero loss/grad and can't perturb the per-token ratio/KL alignment.
        order = torch.argsort((~keep).to(torch.int8), dim=1, stable=True)
        idx = order[:, :tprime].contiguous()
        gk = dict(kwargs)
        gk["attention_mask"] = torch.gather(mask, 1, idx)
        # Only rewrite tensors the caller actually passed: adding ``key=None`` for an absent
        # argument would change the call signature ``orig`` sees.
        for key in gathered_keys:
            if kwargs.get(key) is not None:
                gk[key] = _gather(kwargs[key], idx, tprime)
        ratio = kwargs.get("vllm_is_ratio")
        if ratio is not None and ratio.dim() == 2 and ratio.size(1) == full_t:
            gk["vllm_is_ratio"] = _gather(ratio, idx, tprime)
        # The gathered tensors have shape (B, tprime), and tprime varies per micro-batch. The
        # torch.compile inside liger_kernel's compiled_compute_loss builds SHAPE_ENV guards keyed
        # on static tensor dimensions. When tprime changes between calls, guard recompilation hits
        # a symbol_to_source IndexError (InternalTorchDynamoError). So run the gathered call
        # without torch.compile. It is still faster than the unmasked path: the gather already
        # eliminated the masked FLOPs, and eager overhead is negligible at 0.8B scale.
        import torch._dynamo as _dynamo

        disabled_orig = getattr(masked_liger_loss, "_flash_disabled_orig", None)
        if disabled_orig is None:
            disabled_orig = _dynamo.disable(orig)
            masked_liger_loss._flash_disabled_orig = disabled_orig
        return disabled_orig(**gk)

    masked_liger_loss._flash_mask_aware = True  # sentinel for the idempotency check above
    trainer.liger_grpo_loss = masked_liger_loss
    return True
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def disable_liger_grpo_torch_compile(trainer) -> bool:
    """Run liger's fused GRPO loss EAGER by clearing its ``compiled`` flag; keep the memory path.

    ``LigerFusedLinearGRPOLoss`` wraps ONLY the loss math
    (``fused_linear_ppo._compute_loss_from_logps``) in ``torch.compile``, gated by its
    ``compiled`` flag (default True). The memory-efficient part is the chunked custom-autograd
    ``chunk_forward`` that never materializes the fp32 ``[batch, seq, ~248k vocab]`` logits, and
    it ALWAYS runs eager.

    On torch 2.10 that ``torch.compile`` is BROKEN: guard generation trips a torch bug
    (``symbol_to_source`` IndexError surfaced as ``InternalTorchDynamoError``, "list index out of
    range" at ``symbolic_shapes.issue_guard``). This crashes the FIRST GRPO step on every path
    (single-turn, multi-turn, tool). It fires during guard-build (after tracing), so neither the
    multi-turn ``suppress_errors=True`` nor the mask-aware path's ``_dynamo.disable`` catches it.

    With ``compiled=False`` liger skips the ``torch.compile`` wrapper but KEEPS the chunked memory
    path, so the 248k-vocab fp32-logit OOM fix is fully retained. That fix is why
    ``use_liger_kernel`` stays on for GRPO. Call this BEFORE ``patch_grpo_mask_aware_lm_head``,
    which replaces ``liger_grpo_loss`` with a closure, so the flag lands on the live
    ``LigerFusedLinearGRPOLoss`` instance.

    Returns True if the flag was flipped. Returns False when there is no loss object, it lacks
    the flag, or the flag is already falsy.
    """
    loss_module = getattr(trainer, "liger_grpo_loss", None)
    nothing_to_flip = loss_module is None or not getattr(loss_module, "compiled", False)
    if nothing_to_flip:
        return False
    loss_module.compiled = False
    return True
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
# --------------------------------------------------------------------------------------------
|
|
360
|
+
# Warm-start (init_from_adapter) SFT-adapter key remap for VL checkpoints.
|
|
361
|
+
#
|
|
362
|
+
# SFT (run_sft) trains the FULL multimodal model: ``SFTTrainer(model=model_id,
|
|
363
|
+
# peft_config=make_lora(...))`` loads ``Qwen3_5ForConditionalGeneration`` whose LM modules live
|
|
364
|
+
# under ``language_model.``, so the SAVED adapter's keys are
|
|
365
|
+
# ``base_model.model.model.language_model.layers.X...``. But warm-started GRPO
|
|
366
|
+
# (``_init_adapter_model``) loads the base via ``AutoModelForCausalLM`` — a TEXT-ONLY module tree
|
|
367
|
+
# whose LoRA targets are named ``base_model.model.model.layers.X...`` (no ``language_model.``
|
|
368
|
+
# infix). ``PeftModel.from_pretrained`` then can't match the SFT keys: peft logs a *warning* about
|
|
369
|
+
# missing adapter keys and SILENTLY keeps the fresh zero-init LoRA, so the SFT is thrown away and
|
|
370
|
+
# GRPO restarts from the base model (observed: linkd-search warm-start reward ~= 0.001).
|
|
371
|
+
#
|
|
372
|
+
# Stripping the ``.language_model.`` infix from the saved adapter keys makes them line up with the
|
|
373
|
+
# ``AutoModelForCausalLM`` trainer (proven workaround: remapped adapters train correctly). We keep
|
|
374
|
+
# the trainer as ``AutoModelForCausalLM`` so the train-time vLLM weight-sync remap
|
|
375
|
+
# (``patch_vllm_lm_weight_sync`` / ``_remap_vl_sync_weights``) stays consistent.
|
|
376
|
+
# --------------------------------------------------------------------------------------------
|
|
377
|
+
|
|
378
|
+
# Path segment separating the multimodal wrapper from the LM sub-tree in SFT-saved adapter keys
# (``base_model.model.model.language_model.layers...``); ``strip_language_model_infix`` drops it.
_LANGUAGE_MODEL_INFIX = ".language_model."
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def strip_language_model_infix(key: str) -> str:
    """Drop the first ``.language_model.`` segment from a peft adapter weight key.

    Example::

        base_model.model.model.language_model.layers.0.linear_attn.out_proj.lora_A.default.weight
        -> base_model.model.model.layers.0.linear_attn.out_proj.lora_A.default.weight

    Only the first occurrence is collapsed (the LM-vs-VL boundary appears once in the path);
    a key without the infix is returned unchanged.
    """
    before, found, after = key.partition(_LANGUAGE_MODEL_INFIX)
    if not found:
        return key
    # Collapse ".language_model." down to a single "." separator.
    return f"{before}.{after}"
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def remap_adapter_keys(keys):
    """Return ``{old_key: new_key}`` for every adapter key that ``strip_language_model_infix`` changes.

    Unchanged keys are omitted. Pure (no I/O): shared by the on-disk rewriter and by tests that
    check the post-remap key set against an ``AutoModelForCausalLM``-named LoRA param set.
    """
    return {
        old: new
        for old in keys
        if (new := strip_language_model_infix(old)) != old
    }
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def _rewrite_safetensors_header_keys(path: str, rename) -> int:
|
|
412
|
+
"""Rename tensor keys in a ``.safetensors`` file IN PLACE, editing only the header.
|
|
413
|
+
|
|
414
|
+
safetensors layout: 8-byte little-endian header length, then a JSON header mapping
|
|
415
|
+
``name -> {dtype, shape, data_offsets}`` (plus an optional ``__metadata__`` entry), then the
|
|
416
|
+
raw tensor data. ``data_offsets`` are relative to the data section, so a pure key rename leaves
|
|
417
|
+
every byte of the data section valid — we only rewrite the JSON header and its length prefix.
|
|
418
|
+
|
|
419
|
+
``rename`` is a callable ``old_key -> new_key``. Returns the number of keys renamed. No torch /
|
|
420
|
+
safetensors dependency (keeps this module CPU-importable on the server venv).
|
|
421
|
+
"""
|
|
422
|
+
import json
|
|
423
|
+
import os
|
|
424
|
+
import shutil
|
|
425
|
+
import struct
|
|
426
|
+
|
|
427
|
+
with open(path, "rb") as f:
|
|
428
|
+
len_bytes = f.read(8)
|
|
429
|
+
if len(len_bytes) < 8:
|
|
430
|
+
raise ValueError(f"{path}: too small to be a safetensors file")
|
|
431
|
+
(hdr_len,) = struct.unpack("<Q", len_bytes)
|
|
432
|
+
header_bytes = f.read(hdr_len)
|
|
433
|
+
if len(header_bytes) < hdr_len:
|
|
434
|
+
raise ValueError(f"{path}: truncated safetensors header")
|
|
435
|
+
try:
|
|
436
|
+
header = json.loads(header_bytes)
|
|
437
|
+
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
|
|
438
|
+
# Re-raise with the file path so a corrupt adapter being rewritten is diagnosable
|
|
439
|
+
# (a bare JSONDecodeError/UnicodeDecodeError names no file). Non-UTF8 header bytes
|
|
440
|
+
# raise UnicodeDecodeError, not JSONDecodeError, so catch both to keep the context.
|
|
441
|
+
raise ValueError(
|
|
442
|
+
f"{path}: safetensors header is not valid JSON "
|
|
443
|
+
f"(corrupt or not a safetensors file): {exc}"
|
|
444
|
+
) from exc
|
|
445
|
+
data_start = 8 + hdr_len
|
|
446
|
+
|
|
447
|
+
new_header = {}
|
|
448
|
+
renamed = 0
|
|
449
|
+
for k, v in header.items():
|
|
450
|
+
if k == "__metadata__":
|
|
451
|
+
new_header[k] = v
|
|
452
|
+
continue
|
|
453
|
+
nk = rename(k)
|
|
454
|
+
if nk != k:
|
|
455
|
+
if nk in header or nk in new_header:
|
|
456
|
+
raise ValueError(
|
|
457
|
+
f"{path}: remapped key {nk!r} collides with an existing key; refusing to "
|
|
458
|
+
f"overwrite (adapter may already be remapped or malformed)"
|
|
459
|
+
)
|
|
460
|
+
renamed += 1
|
|
461
|
+
new_header[nk] = v
|
|
462
|
+
|
|
463
|
+
if renamed == 0:
|
|
464
|
+
return 0
|
|
465
|
+
|
|
466
|
+
# Re-serialize compactly. safetensors does not require any specific key order or padding; the
|
|
467
|
+
# only constraint is that data_offsets stay consistent with the (unchanged) data section.
|
|
468
|
+
new_header_bytes = json.dumps(new_header, separators=(",", ":")).encode("utf-8")
|
|
469
|
+
# Stream the (possibly multi-GB) tensor data straight from the original to a temp file instead
|
|
470
|
+
# of slurping the whole file into memory; os.replace makes the swap atomic so an interrupted
|
|
471
|
+
# rewrite can't corrupt the adapter.
|
|
472
|
+
tmp = path + ".remap.tmp"
|
|
473
|
+
try:
|
|
474
|
+
with open(path, "rb") as src, open(tmp, "wb") as out:
|
|
475
|
+
src.seek(data_start)
|
|
476
|
+
out.write(struct.pack("<Q", len(new_header_bytes)))
|
|
477
|
+
out.write(new_header_bytes)
|
|
478
|
+
shutil.copyfileobj(src, out, 8 * 1024 * 1024)
|
|
479
|
+
except BaseException:
|
|
480
|
+
if os.path.exists(tmp):
|
|
481
|
+
os.remove(tmp)
|
|
482
|
+
raise
|
|
483
|
+
os.replace(tmp, path)
|
|
484
|
+
return renamed
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
def _rewrite_bin_keys(path: str, rename) -> int:
    """Rename keys in a PyTorch ``.bin`` (pickled ``state_dict``) adapter IN PLACE.

    Used only when the saved adapter is the legacy ``.bin`` format (no ``.safetensors``). Needs
    torch to (de)serialize; that's fine because this path runs only on the GPU worker.

    Args:
        path: the ``adapter_model.bin`` file to rewrite.
        rename: callable mapping an existing key to its new key; returning the key unchanged
            leaves it alone.

    Returns:
        The number of keys renamed. ``0`` means nothing changed and the file was not rewritten.

    Raises:
        ValueError: the file does not unpickle to a dict ``state_dict``, or a renamed key would
            collide with an existing key (adapter already remapped or malformed).
    """
    import os

    import torch

    sd = torch.load(path, map_location="cpu", weights_only=True)
    # A corrupt/foreign pickle can hold a list or bare tensor; fail with a clear message instead
    # of an opaque AttributeError on ``.items()``.
    if not isinstance(sd, dict):
        raise ValueError(
            f"{path}: expected a state_dict (dict) in the .bin adapter, got "
            f"{type(sd).__name__} (corrupt or not a PyTorch adapter file)"
        )
    new_sd = {}
    renamed = 0
    for k, v in sd.items():
        nk = rename(k)
        if nk != k:
            if nk in sd or nk in new_sd:
                raise ValueError(
                    f"{path}: remapped key {nk!r} collides with an existing key; refusing to "
                    f"overwrite (adapter may already be remapped or malformed)"
                )
            renamed += 1
        new_sd[nk] = v
    if renamed == 0:
        return 0
    # Save to a sibling temp file and atomically swap it in (same scheme as the safetensors
    # rewriter): an interrupted or failed torch.save must never leave a truncated adapter behind.
    tmp = path + ".remap.tmp"
    try:
        torch.save(new_sd, tmp)
    except BaseException:
        if os.path.exists(tmp):
            os.remove(tmp)
        raise
    os.replace(tmp, path)
    return renamed
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
# Substrings that mark a peft LoRA weight key, as opposed to a base-model parameter. A healthy
# adapter file holds only LoRA tensors, but a corrupt or wrong-architecture checkpoint may carry
# others, so callers filter with ``_is_lora_key``. NOTE: the bare ``"lora_"`` entry already
# matches every other marker, so in practice this is a ``"lora_" in key`` test; the specific
# markers are kept to document the expected key shapes.
_LORA_KEY_MARKERS = (".lora_A.", ".lora_B.", ".lora_embedding_A.", ".lora_embedding_B.", "lora_")


def _is_lora_key(key: str) -> bool:
    """Return True when ``key`` contains at least one of ``_LORA_KEY_MARKERS``."""
    for marker in _LORA_KEY_MARKERS:
        if marker in key:
            return True
    return False
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
# Upper bound on the DECLARED length of a safetensors JSON header. Real headers stay small even
# for very large models (a few hundred KB at most), so 100 MiB is a very generous ceiling that
# still refuses a corrupt or hostile file claiming a multi-GB header before anything is
# allocated or read.
_MAX_SAFETENSORS_HEADER_BYTES = 100 << 20  # 100 MiB
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def _read_adapter_tensor_keys(adir: str) -> list[str] | None:
|
|
530
|
+
"""Tensor key names in the downloaded adapter.
|
|
531
|
+
|
|
532
|
+
For ``.safetensors`` this reads ONLY the JSON header (pure stdlib, no tensor data — keeps this
|
|
533
|
+
module CPU-importable). For the legacy ``.bin`` format the pickled ``state_dict`` must be
|
|
534
|
+
materialized via ``torch.load`` to enumerate its keys (a pickle can't be read key-only without
|
|
535
|
+
unpickling the tensor payloads — GPU-worker only). Returns ``None`` when neither weight file
|
|
536
|
+
exists in ``adir``.
|
|
537
|
+
"""
|
|
538
|
+
import json
|
|
539
|
+
import os
|
|
540
|
+
import struct
|
|
541
|
+
|
|
542
|
+
st_path = os.path.join(adir, "adapter_model.safetensors")
|
|
543
|
+
bin_path = os.path.join(adir, "adapter_model.bin")
|
|
544
|
+
if os.path.isfile(st_path):
|
|
545
|
+
# safetensors layout: 8-byte LE header length, then the JSON header, then the tensor data.
|
|
546
|
+
# Bound the DECLARED header length against the real file size (and an absolute ceiling)
|
|
547
|
+
# BEFORE reading it, so a corrupt/hostile file can't trigger a huge allocation / long read.
|
|
548
|
+
file_size = os.path.getsize(st_path)
|
|
549
|
+
with open(st_path, "rb") as f:
|
|
550
|
+
len_bytes = f.read(8)
|
|
551
|
+
if len(len_bytes) < 8:
|
|
552
|
+
raise ValueError(f"{st_path}: too small to be a safetensors file")
|
|
553
|
+
(hdr_len,) = struct.unpack("<Q", len_bytes)
|
|
554
|
+
if hdr_len > file_size - 8 or hdr_len > _MAX_SAFETENSORS_HEADER_BYTES:
|
|
555
|
+
raise ValueError(
|
|
556
|
+
f"{st_path}: declared safetensors header length {hdr_len} is implausible "
|
|
557
|
+
f"(file is {file_size} bytes) — refusing to read a corrupt/oversized header"
|
|
558
|
+
)
|
|
559
|
+
header_bytes = f.read(hdr_len)
|
|
560
|
+
if len(header_bytes) < hdr_len:
|
|
561
|
+
raise ValueError(f"{st_path}: truncated safetensors header")
|
|
562
|
+
try:
|
|
563
|
+
header = json.loads(header_bytes)
|
|
564
|
+
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
|
|
565
|
+
# A bare JSONDecodeError ("Expecting value: line 1 column 1") — or a
|
|
566
|
+
# UnicodeDecodeError from non-UTF8 header bytes — gives no clue WHICH adapter is
|
|
567
|
+
# corrupt. Re-raise with the file path so a bad download is diagnosable.
|
|
568
|
+
raise ValueError(
|
|
569
|
+
f"{st_path}: safetensors header is not valid JSON "
|
|
570
|
+
f"(corrupt or not a safetensors file): {exc}"
|
|
571
|
+
) from exc
|
|
572
|
+
# The safetensors header MUST be a JSON object keyed by tensor name. A corrupt/hostile file
|
|
573
|
+
# could decode to a list/int/str, which would later blow up with a confusing TypeError in
|
|
574
|
+
# _is_lora_key (substring search on a non-str). (JSON object keys are always str, so only the
|
|
575
|
+
# container type needs checking.) Reject a non-object header early with a clear message.
|
|
576
|
+
if not isinstance(header, dict):
|
|
577
|
+
raise ValueError(
|
|
578
|
+
f"{st_path}: safetensors header is not a JSON object "
|
|
579
|
+
"(corrupt or not a safetensors file)"
|
|
580
|
+
)
|
|
581
|
+
return [k for k in header if k != "__metadata__"]
|
|
582
|
+
if os.path.isfile(bin_path):
|
|
583
|
+
import torch
|
|
584
|
+
|
|
585
|
+
sd = torch.load(bin_path, map_location="cpu", weights_only=True)
|
|
586
|
+
return list(sd.keys())
|
|
587
|
+
return None
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def remap_vl_adapter_dir(adir: str, model_id: str) -> int:
    """Strip the ``.language_model.`` infix from a VL warm-start SFT adapter in ``adir``, in place.

    The rewritten keys then match the ``AutoModelForCausalLM`` trainer used by
    ``_init_adapter_model``.

    Whether to strip is decided by the adapter file's OWN LoRA keys, not only by the
    ``is_vl_checkpoint`` config probe. That probe calls ``AutoConfig.from_pretrained`` and returns
    False on every exception, so an HF rate-limit / network hiccup / uncached config used to turn
    a genuine VL warm-start into a silent no-op. The ``.language_model.`` keys stayed in place, the
    text-only base could not match them and peft kept the zero-init LoRA. GRPO then aborted in
    ``assert_adapter_delta_nonzero`` with all-zero ``lora_B`` (issue #286). The probe is therefore
    consulted only when the file has nothing to strip.

    Outcomes:
      - no ``adapter_model.safetensors`` / ``.bin`` in ``adir``: log and return 0;
      - no infixed LoRA keys and the probe says text-only: return 0;
      - no infixed LoRA keys, probe says VL, and no LoRA keys at all: raise (corrupt /
        wrong-architecture adapter);
      - no infixed LoRA keys, probe says VL, LoRA keys present: log a diagnostic and return 0
        (already remapped, or a text-only SFT);
      - infixed LoRA keys present: raise if any would still carry the infix after
        ``strip_language_model_infix`` (checked before touching disk); otherwise rewrite the
        weight file and return the number of renamed keys.

    Idempotent: a second call finds nothing left to strip.

    Raises:
        RuntimeError: in the failure cases listed above.
        ValueError: propagated from the key reader or the rewriters (e.g. corrupt header,
            renamed-key collision).
    """
    import os

    all_keys = _read_adapter_tensor_keys(adir)
    if all_keys is None:
        print(
            f"[init-adapter] remap_vl_adapter_dir: no adapter_model.safetensors/.bin in {adir!r}; "
            "nothing to remap"
        )
        return 0

    lora_keys = list(filter(_is_lora_key, all_keys))
    to_strip = [key for key in lora_keys if _LANGUAGE_MODEL_INFIX in key]

    if to_strip:
        # The file itself proves the adapter was saved against the multimodal layout, so strip it
        # regardless of the config probe (never called on this path). Fail closed BEFORE writing:
        # per the original design note, strip_language_model_infix removes only the first infix,
        # so a key carrying it twice would still match no text-only module and be silently
        # discarded (the #286 all-zero-lora_B failure).
        still_infixed = []
        for key in to_strip:
            stripped = strip_language_model_infix(key)
            if _LANGUAGE_MODEL_INFIX in stripped:
                still_infixed.append(stripped)
        if still_infixed:
            raise RuntimeError(
                f"remap_vl_adapter_dir: {len(still_infixed)} LoRA key(s) in {adir!r} for {model_id} would "
                f"still carry '.language_model.' after the remap (e.g. {still_infixed[0]!r}) — they will NOT "
                "match the AutoModelForCausalLM trainer and would be silently discarded -> all-zero "
                "lora_B. The adapter's key layout is unexpected; verify it was saved by this SFT pipeline."
            )

        safetensors_file = os.path.join(adir, "adapter_model.safetensors")
        if os.path.isfile(safetensors_file):
            n_renamed = _rewrite_safetensors_header_keys(
                safetensors_file, strip_language_model_infix
            )
        else:
            # Keys came from one of the two weight files, so the .bin is the one present.
            n_renamed = _rewrite_bin_keys(
                os.path.join(adir, "adapter_model.bin"), strip_language_model_infix
            )
        print(
            f"[init-adapter] remapped {n_renamed} VL SFT adapter key(s): stripped '.language_model.' infix "
            f"to match the AutoModelForCausalLM trainer for {model_id}"
        )
        return n_renamed

    # Nothing to strip in the file: only now is the (network-backed) config probe worth running.
    if not is_vl_checkpoint(model_id):
        return 0  # genuinely text-only model with text-only adapter keys
    if not lora_keys:
        # A VL warm-start with no LoRA weights can't carry a real SFT delta; fail before the
        # base-model download instead of loading the all-zero identity.
        raise RuntimeError(
            f"warm-start adapter in {adir!r} for {model_id} contains NO LoRA weight keys "
            f"(found {len(all_keys)} tensor(s), 0 with a lora_ marker) — the adapter is corrupt, "
            "incomplete, or from a different architecture, so GRPO would train from the base "
            "model. Re-export the SFT adapter, or omit train.init_from_adapter for a fresh LoRA."
        )

    # Report a representative module prefix so a real key mismatch isn't a silent no-op.
    sample_prefix = lora_keys[0]
    for key in lora_keys:
        if ".lora_" in key:
            sample_prefix = key.partition(".lora_")[0]
            break
    print(
        f"[init-adapter] remap_vl_adapter_dir: 0 '.language_model.' keys to strip for VL "
        f"checkpoint {model_id} ({len(lora_keys)} LoRA key(s); e.g. prefix {sample_prefix!r}) — "
        "treating as already-remapped/text-only. If the warm-start later aborts with all-zero "
        "lora_B, these keys did not match the base model."
    )
    return 0
|
|
685
|
+
|
|
686
|
+
|
|
687
|
+
def assert_lora_applied(model, model_id: str, *, adapter_name: str = "default") -> int:
    """After ``PeftModel.from_pretrained``, verify the adapter's LoRA actually loaded (non-empty)
    so a future key-mismatch regression fails LOUDLY instead of silently training a fresh LoRA.

    Counts the LoRA A/B submodules present on the PeftModel. Raises for ANY warm-start that ended
    up with ZERO LoRA modules (a key mismatch from any cause; the VL ``.language_model.`` mismatch
    this remap fixes is the common one).

    Args:
        model: the loaded model; only ``named_modules()`` is used.
        model_id: model identifier, used in messages only.
        adapter_name: peft adapter name the LoRA was loaded under. peft names the per-target
            submodules ``...lora_A.<adapter_name>`` / ``...lora_B.<adapter_name>``; the default
            ``"default"`` preserves the original behavior.

    Returns:
        The number of matching LoRA A/B submodules.

    Raises:
        RuntimeError: no LoRA submodule for ``adapter_name`` was found.
    """
    suffixes = (f"lora_A.{adapter_name}", f"lora_B.{adapter_name}")
    count = sum(1 for name, _ in model.named_modules() if name.endswith(suffixes))
    if count == 0:
        raise RuntimeError(
            f"warm-start adapter for {model_id} loaded ZERO LoRA modules — the SFT adapter was NOT "
            "applied (key mismatch). GRPO would silently restart from the base model. For Qwen3.5/"
            "3.6 VL this is usually the '.language_model.' key-mismatch (check remap_vl_adapter_dir "
            "ran on the adapter); otherwise verify the adapter's keys match the model."
        )
    print(f"[init-adapter] verified {count} LoRA submodule(s) applied for {model_id}")
    return count
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
def assert_adapter_load_clean(load_result, model_id: str) -> None:
    """Fail closed unless a peft adapter load matched every saved LoRA key.

    ``PeftModel.from_pretrained`` loads adapter weights via ``load_state_dict(strict=False)`` and
    merely WARNS on a key mismatch (the load result is thrown away). An SFT adapter whose keys don't
    line up with the target base is therefore dropped silently, and GRPO restarts from the base
    model (bug #67). ``assert_lora_applied`` cannot see this: peft INJECTS LoRA modules from
    ``target_modules`` before loading weights, so the module count is non-zero even when nothing
    matched.

    ``load_result`` is what ``PeftModel.load_adapter`` returns (an ``_IncompatibleKeys`` exposing
    ``missing_keys`` / ``unexpected_keys``). An absent or ``None`` attribute counts as empty. Only
    keys containing ``lora_`` are considered. An adapter-only checkpoint loaded with
    ``strict=False`` can legitimately report base-model params as missing, and peft's own
    tuner-prefix filtering might change, so the LoRA filter is re-applied here.

    Raises:
        RuntimeError: any LoRA key is missing (an injected module got no saved weight) or
            unexpected (a saved key matched no module).
    """

    def _lora_keys_of(field: str) -> list:
        # The #67 mismatch keys (e.g. ...lora_A.default.weight) all contain "lora_"; base-model
        # params do not, so this drops benign base misses reported under strict=False.
        raw = getattr(load_result, field, None)
        return [key for key in (raw or ()) if "lora_" in key]

    missing = _lora_keys_of("missing_keys")
    unexpected = _lora_keys_of("unexpected_keys")
    if not (missing or unexpected):
        print(
            f"[init-adapter] adapter load matched all saved keys for {model_id} (no missing/unexpected)"
        )
        return
    raise RuntimeError(
        f"warm-start adapter for {model_id} did NOT load cleanly: {len(missing)} injected LoRA "
        f"module(s) got no saved weight (missing) and {len(unexpected)} saved adapter key(s) "
        "matched no module (unexpected). The adapter was silently discarded -> GRPO would restart "
        "from the base model. For Qwen3.5/3.6 VL this is the '.language_model.' key mismatch "
        "(check remap_vl_adapter_dir ran on the adapter); otherwise the adapter's keys don't match "
        f"the base. missing[:3]={missing[:3]} unexpected[:3]={unexpected[:3]}"
    )
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
def assert_adapter_delta_nonzero(model, model_id: str, *, adapter_name: str = "default") -> int:
    """Assert at least one ``lora_B`` weight is non-zero — the adapter is not an identity no-op.

    With standard zero-B init (``init_lora_weights=True``), a freshly-injected-but-unloaded adapter
    has ``lora_B == 0`` everywhere, so the effective delta ``(B @ A) * scaling`` is identically zero
    and the warm-started model equals the base. A real SFT adapter that actually loaded has non-zero
    ``lora_B``. This is an API-independent backstop to ``assert_adapter_load_clean``: it catches a
    silent discard even if peft's load-result shape changes.

    Args:
        model: the loaded model; only ``named_modules()`` and each module's ``weight`` are used.
        model_id: model identifier, used in messages only.
        adapter_name: peft adapter name; modules named ``...lora_B.<adapter_name>`` are checked.
            The default ``"default"`` preserves the original behavior.

    Returns:
        The count of ``lora_B`` modules with a non-zero weight. When no ``lora_B`` modules (with a
        ``weight``) exist at all, returns 0 without raising and defers to ``assert_lora_applied``.

    Raises:
        RuntimeError: ``lora_B`` modules exist but every one of them is all-zero.
    """
    suffix = f"lora_B.{adapter_name}"
    seen = 0
    nonzero = 0
    for name, module in model.named_modules():
        if not name.endswith(suffix):
            continue
        weight = getattr(module, "weight", None)
        if weight is None:
            continue
        seen += 1
        if bool(weight.detach().ne(0).any()):
            nonzero += 1
    if seen == 0:
        # Nothing was checked, so don't log "verified 0/0"; assert_lora_applied owns this case.
        print(
            f"[init-adapter] no lora_B modules found for {model_id}; skipped the non-zero "
            "delta check (assert_lora_applied covers the zero-module case)"
        )
        return 0
    if nonzero == 0:
        raise RuntimeError(
            f"warm-start adapter for {model_id} has ALL-ZERO lora_B weights across {seen} module(s) — "
            "the adapter delta is identically zero (an unloaded / silently-discarded adapter). GRPO "
            "would train from the base model. Verify the adapter's keys match the base (see "
            "remap_vl_adapter_dir)."
        )
    print(f"[init-adapter] verified non-zero lora_B in {nonzero}/{seen} module(s) for {model_id}")
    return nonzero
|
|
781
|
+
|
|
782
|
+
|
|
783
|
+
def model_quant(model_id: str) -> str:
    """Return the quantization tier for ``model_id``: the catalog entry's ``quant``, else ``"bf16"``.

    Managed setting with no user override. Any failure while probing the catalog (import error,
    missing attribute, ...) is logged and falls back to ``"bf16"``. Per the original note, the
    whole catalog is bf16 today, so this currently always yields ``"bf16"``. It is kept as the
    single source of truth a future non-bf16 tier could feed; no caller branches on it yet.
    """
    try:
        from flash.catalog import MODELS

        entry = MODELS.get(model_id)
        quant = "bf16" if entry is None else entry.quant
    except Exception as exc:
        print("model_quant: catalog probe failed:", exc)
        quant = "bf16"
    return quant
|