freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,796 @@
1
+ """Pure LoRA-target / VL-checkpoint helpers for the fine-tuning worker.
2
+
3
+ These helpers take the model id as an ARGUMENT and read NONE of the worker's run-scoped
4
+ module globals, so they live here as a leaf module. ``flash.engine.worker`` re-exports
5
+ them; this module must NOT import that package (no cycle). Heavy deps (transformers, peft,
6
+ vllm, the catalog) are imported lazily inside the functions so the module stays
7
+ CPU-importable.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+
13
def _patch_peft_weight_converter_compat() -> None:
    """peft 0.19.1 x transformers 5.6-5.10: make MoE adapter loading work.

    peft's ``build_peft_weight_mapping`` reconstructs transformers ``WeightConverter``
    objects passing ``distributed_operation=`` / ``quantization_operation=`` — kwargs
    the WeightConverter in transformers <5.11 doesn't accept (init=False dataclass
    fields), so loading a LoRA adapter onto any arch WITH weight conversions dies with
    ``TypeError: unexpected keyword argument 'distributed_operation'``. The worker can't
    take transformers>=5.11 (vllm 0.19.1 compat), so the constructor is wrapped to
    accept-and-drop unknown kwargs; on a single GPU those fields are unused.

    The helper is a no-op when:

    * ``transformers.core_model_loading`` cannot be imported or has no ``WeightConverter``;
    * the class was already patched (``_flash_compat`` sentinel);
    * the constructor already names ``distributed_operation`` (signatures match);
    * the constructor takes ``**kwargs`` — it would accept the extra keywords itself, and
      filtering them here would silently discard data it can consume;
    * the constructor signature cannot be introspected at all.
    """
    import inspect

    try:
        from transformers import core_model_loading as cml
    except Exception:  # pragma: no cover - older stacks have no converter module
        return
    converter = getattr(cml, "WeightConverter", None)
    if converter is None or getattr(converter, "_flash_compat", False):
        return
    orig_init = converter.__init__
    try:
        params = inspect.signature(orig_init).parameters
    except (TypeError, ValueError):
        # Best-effort compat shim: an un-introspectable __init__ is left untouched rather than
        # letting the probe itself take the worker down.
        return
    if "distributed_operation" in params:
        return
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return
    accepted = frozenset(params)

    def _compat_init(self, *args, **kwargs):
        # Forward only the keywords the original constructor declares.
        kept = {k: v for k, v in kwargs.items() if k in accepted}
        orig_init(self, *args, **kept)

    converter.__init__ = _compat_init
    converter._flash_compat = True
    print("[compat] WeightConverter patched (peft<->transformers signature drift)")
49
+
50
+
51
# Module-path segments that must never receive LoRA on natively-multimodal checkpoints
# trained text-only: the vision tower / projector / MTP head. Critically, adapters that
# DO touch them cannot be loaded by vLLM in text-only (language_model_only) serving —
# its LoRA loader rejects "unexpected modules" (observed with Qwen3.5-2B).
_VL_EXCLUDE_SEGMENTS = ("visual", "vision_tower", "multi_modal_projector", "mtp")


def lora_exclude_modules(model_id: str) -> str | None:
    """Build the peft ``exclude_modules`` regex that keeps LoRA off the vision tower.

    The model's ``model_type`` is probed through ``AutoConfig``; only the Qwen3.5 / Qwen3.5-MoE /
    Qwen3.6 types carry exclusions. The result is a regex STRING (peft applies fullmatch
    semantics to it) because peft's list-form ``exclude_modules`` uses suffix matching, which
    does NOT match leaf modules under ``visual.*``.

    Returns:
        The regex, or ``None`` when no exclusion applies (pure-text architectures) or when the
        config probe fails for any reason (the failure is printed, never raised).
    """
    vl_model_types = ("qwen3_5", "qwen3_5_moe", "qwen3_6")
    per_type_segments = dict.fromkeys(vl_model_types, _VL_EXCLUDE_SEGMENTS)
    try:
        from transformers import AutoConfig

        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        model_type = getattr(config, "model_type", "") or ""
    except Exception as e:
        print("lora_exclude_modules: config probe failed:", e)
        return None
    chosen = per_type_segments.get(model_type)
    if chosen:
        # <anything.>SEGMENT<.anything>: the segment must be a whole dotted path component.
        return r"(^|.*\.)(" + "|".join(chosen) + r")(\..*|$)"
    return None
83
+
84
+
85
def is_vl_checkpoint(model_id: str) -> bool:
    """Report whether ``model_id`` is a natively-multimodal checkpoint we train/serve text-only.

    Defined as "``lora_exclude_modules`` yields a non-empty exclusion pattern" (Qwen3.5/3.6).
    """
    exclusion_pattern = lora_exclude_modules(model_id)
    return True if exclusion_pattern else False
88
+
89
+
90
def vllm_language_model_only_kwargs(model_id: str) -> dict:
    """Return the vLLM engine kwargs that skip the vision tower (vLLM >= 0.19).

    For VL checkpoints this is ``{"language_model_only": True}``; for everything else an empty
    dict, so the result can always be splatted into the engine constructor.

    Why text-only loading: besides wasting VRAM, the vision tower's attention path hardcodes
    vLLM's bundled flash-attn, whose PTX needs a newer driver JIT than many RTX 5090 hosts have
    ("PTX compiled with unsupported toolchain"). Loading text-only sidesteps it and is the
    officially supported way to run Qwen3.5 as a pure LLM.
    """
    if is_vl_checkpoint(model_id):
        return {"language_model_only": True}
    return {}
99
+
100
+
101
def patch_vllm_language_model_only(model_id: str) -> bool:
    """Default ``language_model_only=True`` on every ``vllm.LLM`` built after this call.

    Needed for engines created by third-party code (TRL's colocated GRPO rollout engine) that we
    cannot hand kwargs to directly. Only applies to VL checkpoints; an explicit
    ``language_model_only=...`` passed by the caller still wins (``setdefault``).

    Returns:
        True when the patch is active (installed now or by an earlier call); False for non-VL
        checkpoints or when patching failed (the failure is printed, never raised).
    """
    if not vllm_language_model_only_kwargs(model_id):
        return False
    try:
        import vllm

        current_init = vllm.LLM.__init__
        if getattr(current_init, "_flash_lmo_patched", False):
            return True  # sentinel set below: never stack a second wrapper

        def patched(self, *args, **kwargs):
            kwargs.setdefault("language_model_only", True)
            return current_init(self, *args, **kwargs)

        patched._flash_lmo_patched = True
        vllm.LLM.__init__ = patched
        print(f"[vllm] language_model_only patch active for {model_id}")
        return True
    except Exception as e:
        print("patch_vllm_language_model_only warn:", e)
        return False
125
+
126
+
127
# Flipped to True only AFTER the GRPO trainer (and its colocated vLLM engine + initial
# checkpoint load) is constructed, but BEFORE ``trainer.train()`` runs the first weight sync.
# See ``patch_vllm_lm_weight_sync``. A module dict (not a bare bool) so the gating flag is shared
# by reference between this module and the worker package that flips it.
_LM_SYNC_REMAP_ON = {"on": False}


def _remap_vl_sync_weights(weights):
    """Translate trainer-side weight names into the vLLM VL engine's names, lazily.

    The trainer (built via ``AutoModelForCausalLM``) names its LM params ``model.layers.*`` /
    ``model.norm`` / ``model.embed_tokens`` / ``lm_head.*``; the colocated vLLM engine loaded the
    same checkpoint as ``Qwen3_5ForConditionalGeneration``, whose LM params live under
    ``language_model.*``. Per incoming ``(name, tensor)`` pair:

    1. a leading peft wrapper prefix ``base_model.model.`` is removed (a continued-adapter /
       merged-adapter sync can surface names through that wrapper);
    2. a name that then starts with ``model.`` or ``lm_head.`` gains a ``language_model.`` prefix.

    Every other name — including ones already under ``language_model.`` — is yielded as-is.
    Implemented as a generator so vLLM's loader still streams one pair at a time.
    """
    peft_wrapper = "base_model.model."
    lm_roots = ("model.", "lm_head.")
    for raw_name, tensor in weights:
        resolved = raw_name
        if resolved.startswith(peft_wrapper):
            resolved = resolved[len(peft_wrapper) :]
        if resolved.startswith(lm_roots):
            resolved = "language_model." + resolved
        yield resolved, tensor
155
+
156
+
157
def patch_vllm_lm_weight_sync(model_id: str) -> bool:
    """Make TRL's GRPO ``sync_weights`` work for ``*ForConditionalGeneration`` checkpoints
    (the whole Qwen3.5/3.6 family).

    The trainer loads via ``AutoModelForCausalLM`` so its params are named ``model.layers.*`` /
    ``model.norm`` / ``model.embed_tokens`` / ``lm_head.*``. vLLM loads the same checkpoint as
    ``Qwen3_5ForConditionalGeneration`` whose LM params live under ``language_model.*``. TRL's
    ``sync_weights`` pushes the trainer names verbatim, so vLLM's loader raises "There is no module
    or parameter named 'model' in Qwen3_5ForConditionalGeneration" at the first generation step and
    GRPO dies (even with ``language_model_only=True``: that only skips loading the vision tower, it
    does NOT rename the surviving LM params out from under ``language_model.``).

    The fix wraps the vLLM model class ``load_weights`` to remap incoming ``model.``/``lm_head.``
    names to ``language_model.*`` (see ``_remap_vl_sync_weights``) — but ONLY while
    ``_LM_SYNC_REMAP_ON`` is set. The INITIAL checkpoint load (during trainer construction) runs
    with it OFF, so vLLM's own ``hf_to_vllm_mapper`` handles the on-disk checkpoint untouched; the
    remap activates only for the train-time TRL syncs. The flag is flipped on between trainer
    construction and ``train()``. Works for BOTH from-base and warm-started (init_from_adapter)
    GRPO.

    Returns:
        True when the remap wrapper is in place on at least one vLLM model class — whether it was
        installed by this call or by an earlier one (idempotent, like the other patch helpers in
        this module). False for non-VL checkpoints or when nothing could be patched; failures are
        printed, never raised.
    """
    if not is_vl_checkpoint(model_id):
        return False

    def _wrap_load_weights(orig):
        def patched(self, weights, *args, **kwargs):
            # The gate is read at call time, so the worker can flip it after construction.
            if _LM_SYNC_REMAP_ON["on"]:
                weights = _remap_vl_sync_weights(weights)
            return orig(self, weights, *args, **kwargs)

        patched._flash_sync_patched = True
        return patched

    # The dense class is REQUIRED for the whole Qwen3.5/3.6 family — if its module/class can't
    # be imported (vLLM not installed where it should be, or the class renamed in a new vLLM)
    # we must NOT silently no-op: the run would crash again at the first ``sync_weights()`` with
    # a far less actionable error. Log loudly for the required one; the MoE class is OPTIONAL
    # (only some models are MoE, and older vLLM lacks the module) so its absence stays quiet.
    targets = (
        ("vllm.model_executor.models.qwen3_5", "Qwen3_5ForConditionalGeneration", True),
        ("vllm.model_executor.models.qwen3_5_moe", "Qwen3_5MoeForConditionalGeneration", False),
    )
    patched_any = False
    try:
        import importlib

        for mod_name, cls_name, required in targets:
            try:
                mod = importlib.import_module(mod_name)
            except Exception as e:
                mod = None
                if required:
                    print(
                        f"[vllm] WARN patch_vllm_lm_weight_sync: could not import required module "
                        f"{mod_name} ({e!r}); GRPO weight-sync will NOT be remapped and the run may "
                        f"crash at the first sync_weights() for this VL checkpoint."
                    )
            cls = getattr(mod, cls_name, None) if mod is not None else None
            if cls is None:
                if required and mod is not None:
                    print(
                        f"[vllm] WARN patch_vllm_lm_weight_sync: module {mod_name} imported but has "
                        f"no {cls_name} (vLLM API changed?); GRPO weight-sync will NOT be remapped "
                        f"and the run may crash at the first sync_weights() for this VL checkpoint."
                    )
                continue
            if getattr(cls.load_weights, "_flash_sync_patched", False):
                # Wrapper already present (earlier call, or inherited from a patched base class):
                # don't stack another one, but DO report the patch as active so a repeated call
                # doesn't look like a failure to the caller.
                patched_any = True
                continue
            cls.load_weights = _wrap_load_weights(cls.load_weights)
            patched_any = True
            print(f"[vllm] LM weight-sync name patch installed for {cls_name} (gated)")
    except Exception as e:
        print("patch_vllm_lm_weight_sync warn:", e)
    return patched_any
229
+
230
+
231
def patch_grpo_mask_aware_lm_head(trainer) -> bool:
    """Skip the 248k-vocab ``lm_head`` projection at MASKED completion positions in the GRPO loss.

    Targets MULTI-TURN GRPO, where the masked set is the env/tool text (the rollout's ``env_mask``
    -> TRL's ``tool_mask``) that EVERY row carries, so the micro-batch has maskable headroom in all
    rows. TRL 1.6's ``compute_liger_loss`` hands the FULL-length hidden states to
    ``liger_grpo_loss``, and the Liger kernel runs the lm_head matmul + log-softmax for EVERY
    position (forward AND backward recompute). Masked positions contribute zero loss and zero
    gradient but still pay the full FLOPs of the most expensive GRPO op. SINGLE-TURN is effectively
    a no-op: its only mask is right-padding and TRL pads to the LONGEST completion in the
    micro-batch, so the deepest row is full-length and there is no shared headroom to gather.

    ``trainer.liger_grpo_loss`` is replaced by a wrapper that GATHERS the unmasked positions — ONE
    shared index applied identically to every per-token tensor (``_input``,
    ``selected_token_ids``, ``attention_mask``, ``old_per_token_logps``, ``ref_per_token_logps``,
    and a 2-D ``vllm_is_ratio``) — before the call, so the kernel only projects the kept positions.
    Per-sequence ``advantages`` ``(B,)`` and the loss object's ``max_completion_length`` are left
    untouched. This is loss-preserving: dr_grpo's numerator only ever summed unmasked positions and
    its normalizer is ``B * max_completion_length`` (a constant on the loss object); rows with
    fewer than T' unmasked positions are filled from their own masked positions, whose gathered
    mask is 0.

    The wrapper forwards the call UNCHANGED (never alters the loss) when:

    * positional arguments are used, or ``attention_mask`` is missing / not a 2-D tensor / empty;
    * nothing can be skipped (every position masked, or the deepest row is full-length);
    * an unknown kwarg looks like a per-token tensor ``(B, T[, *])`` that would be left
      full-length next to gathered tensors;
    * a known per-token tensor that was supplied does not actually have leading dims ``(B, T)``.

    Per-token tensors the caller did not pass are not added to the forwarded kwargs.

    Returns:
        True if the wrapper is installed (now or previously); False when the trainer has no
        ``liger_grpo_loss``.
    """
    orig = getattr(trainer, "liger_grpo_loss", None)
    if orig is None:
        return False
    if getattr(orig, "_flash_mask_aware", False):
        return True  # already wrapped — idempotent (mirrors the other patch helpers' sentinels)
    import torch

    per_token_keys = (
        "_input",
        "selected_token_ids",
        "old_per_token_logps",
        "ref_per_token_logps",
    )
    known_keys = frozenset(per_token_keys) | {"attention_mask", "vllm_is_ratio"}

    def _is_per_token(value, batch, full_t):
        # True for a tensor whose two leading dims are (B, T).
        return (
            isinstance(value, torch.Tensor)
            and value.dim() >= 2
            and value.size(0) == batch
            and value.size(1) == full_t
        )

    def _gather(x, idx):
        # Select the columns in ``idx`` (B, T') along dim 1, broadcasting over trailing dims.
        if x.dim() == 2:  # (B, T) per-token tensor
            return torch.gather(x, 1, idx)
        trailing = tuple(x.shape[2:])
        view = idx.reshape(tuple(idx.shape) + (1,) * len(trailing))
        return torch.gather(x, 1, view.expand(tuple(idx.shape) + trailing))

    def masked_liger_loss(*args, **kwargs):
        mask = kwargs.get("attention_mask")  # loss mask = completion_mask * tool_mask, shape (B, T)
        if args or not isinstance(mask, torch.Tensor) or mask.dim() != 2 or mask.numel() == 0:
            # Unexpected call shape, or an empty batch (``.max()`` below would raise on it).
            return orig(*args, **kwargs)
        batch, full_t = mask.size(0), mask.size(1)
        keep = mask != 0
        tprime = int(keep.sum(dim=1).max().item())
        if tprime == 0 or tprime == full_t:
            # Nothing maskable to skip across the batch: the DEEPEST row is full-length. Standard
            # single-turn vLLM GRPO pads completions to the longest in the micro-batch, so this is
            # its common case — the patch only engages where every row has masked headroom.
            return orig(**kwargs)
        # Defensive: only a KNOWN set of per-token tensors is gathered below. Any OTHER kwarg
        # shaped (B, T[, *]) would stay full-length while the rest shrink to T' -> shape mismatch
        # or misaligned credit. Bail to the unmodified loss instead.
        for name, value in kwargs.items():
            if name not in known_keys and _is_per_token(value, batch, full_t):
                return orig(**kwargs)
        # Same spirit for the known tensors: gathering something that is not laid out (B, T, ...)
        # would fail or silently pick the wrong elements.
        supplied = [key for key in per_token_keys if kwargs.get(key) is not None]
        if not all(_is_per_token(kwargs[key], batch, full_t) for key in supplied):
            return orig(**kwargs)
        # One shared gather index: the unmasked positions first (stable argsort -> their original
        # order preserved), then the remaining masked positions in original order. Keep only the
        # first tprime columns; a row with fewer than tprime unmasked positions takes its filler
        # from masked positions, whose gathered mask is 0 — zero loss/grad, alignment intact.
        order = torch.argsort((~keep).to(torch.int8), dim=1, stable=True)
        idx = order[:, :tprime].contiguous()
        gk = dict(kwargs)
        gk["attention_mask"] = torch.gather(mask, 1, idx)
        for key in supplied:
            gk[key] = _gather(kwargs[key], idx)
        ratio = kwargs.get("vllm_is_ratio")
        if isinstance(ratio, torch.Tensor) and ratio.dim() == 2 and ratio.size(1) == full_t:
            gk["vllm_is_ratio"] = _gather(ratio, idx)
        # The gathered tensors have shape (B, tprime) where tprime varies per micro-batch.
        # torch.compile inside liger_kernel's compiled_compute_loss builds SHAPE_ENV guards keyed
        # on static tensor dimensions; when tprime changes between calls, guard recompilation hits
        # a symbol_to_source IndexError (InternalTorchDynamoError). So the gathered call runs with
        # dynamo disabled; the gather already removed the masked FLOPs.
        # NOTE: keep the ``as`` form — a bare ``import torch._dynamo`` would rebind ``torch`` as a
        # local of this function and break the closure uses above.
        import torch._dynamo as _dynamo

        _disabled_orig = getattr(masked_liger_loss, "_flash_disabled_orig", None)
        if _disabled_orig is None:
            _disabled_orig = _dynamo.disable(orig)
            masked_liger_loss._flash_disabled_orig = _disabled_orig  # build the wrapper once
        return _disabled_orig(**gk)

    masked_liger_loss._flash_mask_aware = True  # sentinel for the idempotency check above
    trainer.liger_grpo_loss = masked_liger_loss
    return True
329
+
330
+
331
def disable_liger_grpo_torch_compile(trainer) -> bool:
    """Run liger's fused GRPO loss EAGER — drop only its ``torch.compile``, keep the memory path.

    ``LigerFusedLinearGRPOLoss`` wraps ONLY the loss math
    (``fused_linear_ppo._compute_loss_from_logps``) in ``torch.compile`` (gated by its ``compiled``
    flag, default True); the memory-efficient part — the chunked custom-autograd ``chunk_forward``
    that never materializes the fp32 ``[batch, seq, ~248k vocab]`` logits — ALWAYS runs eager. On
    torch 2.10 that ``torch.compile`` is BROKEN: guard generation trips a torch bug
    (``symbol_to_source`` IndexError surfaced as ``InternalTorchDynamoError``) that crashes the
    FIRST GRPO step on EVERY path (single-turn, multi-turn, tool). It fires during guard-build
    (after tracing), so neither the multi-turn ``suppress_errors=True`` nor the mask-aware path's
    ``_dynamo.disable`` catches it.

    Clearing ``compiled`` makes liger skip the ``torch.compile`` wrapper while KEEPING the chunked
    memory path, so the fp32-logit OOM fix is retained and only the loss-math JIT is dropped.

    Call this BEFORE ``patch_grpo_mask_aware_lm_head`` (which replaces ``liger_grpo_loss`` with a
    closure) so it lands on the live ``LigerFusedLinearGRPOLoss`` instance.

    Returns:
        True if ``compiled`` was switched off by this call; False when the loss is absent,
        predates the ``compiled`` flag, or already has it off.
    """
    fused_loss = getattr(trainer, "liger_grpo_loss", None)
    if fused_loss is not None and getattr(fused_loss, "compiled", False):
        fused_loss.compiled = False
        return True
    return False
357
+
358
+
359
+ # --------------------------------------------------------------------------------------------
360
+ # Warm-start (init_from_adapter) SFT-adapter key remap for VL checkpoints.
361
+ #
362
+ # SFT (run_sft) trains the FULL multimodal model: ``SFTTrainer(model=model_id,
363
+ # peft_config=make_lora(...))`` loads ``Qwen3_5ForConditionalGeneration`` whose LM modules live
364
+ # under ``language_model.``, so the SAVED adapter's keys are
365
+ # ``base_model.model.model.language_model.layers.X...``. But warm-started GRPO
366
+ # (``_init_adapter_model``) loads the base via ``AutoModelForCausalLM`` — a TEXT-ONLY module tree
367
+ # whose LoRA targets are named ``base_model.model.model.layers.X...`` (no ``language_model.``
368
+ # infix). ``PeftModel.from_pretrained`` then can't match the SFT keys: peft logs a *warning* about
369
+ # missing adapter keys and SILENTLY keeps the fresh zero-init LoRA, so the SFT is thrown away and
370
+ # GRPO restarts from the base model (observed: linkd-search warm-start reward ~= 0.001).
371
+ #
372
+ # Stripping the ``.language_model.`` infix from the saved adapter keys makes them line up with the
373
+ # ``AutoModelForCausalLM`` trainer (proven workaround: remapped adapters train correctly). We keep
374
+ # the trainer as ``AutoModelForCausalLM`` so the train-time vLLM weight-sync remap
375
+ # (``patch_vllm_lm_weight_sync`` / ``_remap_vl_sync_weights``) stays consistent.
376
+ # --------------------------------------------------------------------------------------------
377
+
378
_LANGUAGE_MODEL_INFIX = ".language_model."


def strip_language_model_infix(key: str) -> str:
    """Drop the FIRST ``.language_model.`` path component from a peft adapter weight key.

    ``base_model.model.model.language_model.layers.0.linear_attn.out_proj.lora_A.default.weight``
    -> ``base_model.model.model.layers.0.linear_attn.out_proj.lora_A.default.weight``.

    Only the first occurrence is collapsed (to a single ``.``); any later occurrence is kept. A
    key that does not contain the infix — including one that merely STARTS with
    ``language_model.`` (no leading dot) — is returned unchanged.
    """
    before, found, after = key.partition(_LANGUAGE_MODEL_INFIX)
    if not found:
        return key
    # Re-join the two halves with the one separator dot the infix used to provide twice.
    return f"{before}.{after}"
395
+
396
+
397
def remap_adapter_keys(keys):
    """Compute the rename table ``{old_key: new_key}`` for the adapter keys that change.

    Each key is passed through ``strip_language_model_infix``; keys it leaves untouched are
    omitted from the result. Pure (no I/O): used both by the on-disk rewriter and by tests to
    assert the post-remap key set matches an ``AutoModelForCausalLM``-named LoRA param set.
    """
    candidates = ((old, strip_language_model_infix(old)) for old in keys)
    return {old: new for old, new in candidates if new != old}
409
+
410
+
411
def _rewrite_safetensors_header_keys(path: str, rename) -> int:
    """Rename tensor keys in a ``.safetensors`` file IN PLACE, editing only the header.

    safetensors layout: 8-byte little-endian header length, then a JSON header mapping
    ``name -> {dtype, shape, data_offsets}`` (plus an optional ``__metadata__`` entry), then the
    raw tensor data. ``data_offsets`` are relative to the data section, so a pure key rename leaves
    every byte of the data section valid — only the JSON header and its length prefix change.

    Args:
        path: the ``.safetensors`` file to rewrite.
        rename: callable ``old_key -> new_key``; ``__metadata__`` is never passed to it.

    Returns:
        The number of keys renamed. ``0`` means the file was left untouched.

    Raises:
        ValueError: the file is too small, declares an implausible header length (larger than
            the file or than ``_MAX_SAFETENSORS_HEADER_BYTES``), has a header that is not a JSON
            object, or a renamed key would collide with an existing key.

    No torch / safetensors dependency (keeps this module CPU-importable on the server venv).
    """
    import json
    import os
    import shutil
    import struct

    with open(path, "rb") as f:
        len_bytes = f.read(8)
        if len(len_bytes) < 8:
            raise ValueError(f"{path}: too small to be a safetensors file")
        (hdr_len,) = struct.unpack("<Q", len_bytes)
        # Bound the DECLARED length before reading it: a corrupt/hostile prefix would otherwise
        # make ``read`` slurp the whole (possibly multi-GB) file, or attempt a huge allocation.
        file_size = os.fstat(f.fileno()).st_size
        if hdr_len > file_size - 8 or hdr_len > _MAX_SAFETENSORS_HEADER_BYTES:
            raise ValueError(
                f"{path}: declared safetensors header length {hdr_len} is implausible "
                f"(file is {file_size} bytes) — refusing to read a corrupt/oversized header"
            )
        header_bytes = f.read(hdr_len)
        if len(header_bytes) < hdr_len:
            raise ValueError(f"{path}: truncated safetensors header")
    try:
        header = json.loads(header_bytes)
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        # Re-raise with the file path so a corrupt adapter being rewritten is diagnosable
        # (a bare JSONDecodeError/UnicodeDecodeError names no file). Non-UTF8 header bytes
        # raise UnicodeDecodeError, not JSONDecodeError, so catch both to keep the context.
        raise ValueError(
            f"{path}: safetensors header is not valid JSON "
            f"(corrupt or not a safetensors file): {exc}"
        ) from exc
    if not isinstance(header, dict):
        raise ValueError(
            f"{path}: safetensors header is not a JSON object "
            f"(corrupt or not a safetensors file)"
        )
    data_start = 8 + hdr_len

    new_header = {}
    renamed = 0
    for k, v in header.items():
        if k == "__metadata__":
            new_header[k] = v
            continue
        nk = rename(k)
        if nk != k:
            if nk in header or nk in new_header:
                raise ValueError(
                    f"{path}: remapped key {nk!r} collides with an existing key; refusing to "
                    f"overwrite (adapter may already be remapped or malformed)"
                )
            renamed += 1
        new_header[nk] = v

    if renamed == 0:
        return 0

    # Re-serialize compactly, then pad with trailing spaces (which the format permits) to a
    # multiple of 8 bytes so the tensor data section that follows stays 8-byte aligned even though
    # the header got shorter. data_offsets are unaffected: they are relative to the data section.
    new_header_bytes = json.dumps(new_header, separators=(",", ":")).encode("utf-8")
    new_header_bytes += b" " * (-len(new_header_bytes) % 8)
    # Stream the (possibly multi-GB) tensor data straight from the original to a temp file instead
    # of slurping the whole file into memory; os.replace makes the swap atomic so an interrupted
    # rewrite can't corrupt the adapter.
    tmp = path + ".remap.tmp"
    try:
        with open(path, "rb") as src, open(tmp, "wb") as out:
            src.seek(data_start)
            out.write(struct.pack("<Q", len(new_header_bytes)))
            out.write(new_header_bytes)
            shutil.copyfileobj(src, out, 8 * 1024 * 1024)
        os.replace(tmp, path)
    except BaseException:
        # Covers a failed swap too, so no ``.remap.tmp`` is left next to the adapter.
        if os.path.exists(tmp):
            os.remove(tmp)
        raise
    return renamed
485
+
486
+
487
def _rewrite_bin_keys(path: str, rename) -> int:
    """Rename keys in a PyTorch ``.bin`` (pickled ``state_dict``) adapter IN PLACE.

    Used only when the saved adapter is the legacy ``.bin`` format (no ``.safetensors``). Needs
    torch to (de)serialize; that's fine because this path runs only on the GPU worker.

    Args:
        path: the ``.bin`` file to rewrite.
        rename: callable ``old_key -> new_key``.

    Returns:
        The number of keys renamed. ``0`` means the file was left untouched.

    Raises:
        ValueError: a renamed key would collide with an existing key.

    The new state dict is written to a sibling temp file and swapped in with ``os.replace`` (same
    scheme as the safetensors rewriter), so an interrupted save cannot leave a truncated adapter.
    """
    import os

    import torch

    # weights_only=True: restrict unpickling to tensor payloads.
    sd = torch.load(path, map_location="cpu", weights_only=True)
    new_sd = {}
    renamed = 0
    for k, v in sd.items():
        nk = rename(k)
        if nk != k:
            if nk in sd or nk in new_sd:
                raise ValueError(
                    f"{path}: remapped key {nk!r} collides with an existing key; refusing to "
                    f"overwrite (adapter may already be remapped or malformed)"
                )
            renamed += 1
        new_sd[nk] = v
    if renamed == 0:
        return 0
    tmp = path + ".remap.tmp"
    try:
        torch.save(new_sd, tmp)
        os.replace(tmp, path)
    except BaseException:
        if os.path.exists(tmp):
            os.remove(tmp)
        raise
    return renamed
512
+
513
+
514
# Substrings that identify a peft LoRA weight key (vs a base-model param). The whole adapter file
# is LoRA weights, but a wrong-arch / corrupt checkpoint can contain non-LoRA tensors, so we filter.
# Note the final bare "lora_" marker is a substring of every other entry, so it alone decides the
# match; the dotted forms are kept as documentation of the expected key shapes.
_LORA_KEY_MARKERS = (".lora_A.", ".lora_B.", ".lora_embedding_A.", ".lora_embedding_B.", "lora_")


def _is_lora_key(key: str) -> bool:
    """Return True when ``key`` contains any of the ``_LORA_KEY_MARKERS`` substrings."""
    for marker in _LORA_KEY_MARKERS:
        if marker in key:
            return True
    return False
521
+
522
+
523
+ # A safetensors header is small even for huge models (a few hundred KB at most); 100 MB is a wildly
524
+ # generous ceiling that still refuses a corrupt/hostile file declaring a multi-GB header length
525
+ # before we allocate/read it.
526
+ _MAX_SAFETENSORS_HEADER_BYTES = 100 * 1024 * 1024
527
+
528
+
529
+ def _read_adapter_tensor_keys(adir: str) -> list[str] | None:
530
+ """Tensor key names in the downloaded adapter.
531
+
532
+ For ``.safetensors`` this reads ONLY the JSON header (pure stdlib, no tensor data — keeps this
533
+ module CPU-importable). For the legacy ``.bin`` format the pickled ``state_dict`` must be
534
+ materialized via ``torch.load`` to enumerate its keys (a pickle can't be read key-only without
535
+ unpickling the tensor payloads — GPU-worker only). Returns ``None`` when neither weight file
536
+ exists in ``adir``.
537
+ """
538
+ import json
539
+ import os
540
+ import struct
541
+
542
+ st_path = os.path.join(adir, "adapter_model.safetensors")
543
+ bin_path = os.path.join(adir, "adapter_model.bin")
544
+ if os.path.isfile(st_path):
545
+ # safetensors layout: 8-byte LE header length, then the JSON header, then the tensor data.
546
+ # Bound the DECLARED header length against the real file size (and an absolute ceiling)
547
+ # BEFORE reading it, so a corrupt/hostile file can't trigger a huge allocation / long read.
548
+ file_size = os.path.getsize(st_path)
549
+ with open(st_path, "rb") as f:
550
+ len_bytes = f.read(8)
551
+ if len(len_bytes) < 8:
552
+ raise ValueError(f"{st_path}: too small to be a safetensors file")
553
+ (hdr_len,) = struct.unpack("<Q", len_bytes)
554
+ if hdr_len > file_size - 8 or hdr_len > _MAX_SAFETENSORS_HEADER_BYTES:
555
+ raise ValueError(
556
+ f"{st_path}: declared safetensors header length {hdr_len} is implausible "
557
+ f"(file is {file_size} bytes) — refusing to read a corrupt/oversized header"
558
+ )
559
+ header_bytes = f.read(hdr_len)
560
+ if len(header_bytes) < hdr_len:
561
+ raise ValueError(f"{st_path}: truncated safetensors header")
562
+ try:
563
+ header = json.loads(header_bytes)
564
+ except (json.JSONDecodeError, UnicodeDecodeError) as exc:
565
+ # A bare JSONDecodeError ("Expecting value: line 1 column 1") — or a
566
+ # UnicodeDecodeError from non-UTF8 header bytes — gives no clue WHICH adapter is
567
+ # corrupt. Re-raise with the file path so a bad download is diagnosable.
568
+ raise ValueError(
569
+ f"{st_path}: safetensors header is not valid JSON "
570
+ f"(corrupt or not a safetensors file): {exc}"
571
+ ) from exc
572
+ # The safetensors header MUST be a JSON object keyed by tensor name. A corrupt/hostile file
573
+ # could decode to a list/int/str, which would later blow up with a confusing TypeError in
574
+ # _is_lora_key (substring search on a non-str). (JSON object keys are always str, so only the
575
+ # container type needs checking.) Reject a non-object header early with a clear message.
576
+ if not isinstance(header, dict):
577
+ raise ValueError(
578
+ f"{st_path}: safetensors header is not a JSON object "
579
+ "(corrupt or not a safetensors file)"
580
+ )
581
+ return [k for k in header if k != "__metadata__"]
582
+ if os.path.isfile(bin_path):
583
+ import torch
584
+
585
+ sd = torch.load(bin_path, map_location="cpu", weights_only=True)
586
+ return list(sd.keys())
587
+ return None
588
+
589
+
590
def remap_vl_adapter_dir(adir: str, model_id: str) -> int:
    """Strip the ``.language_model.`` infix from a downloaded SFT adapter's keys for a VL warm-start.

    The decision is taken from the adapter file's own LoRA keys first; ``is_vl_checkpoint`` is
    consulted only when no LoRA key carries the infix. That way a failed or flaky config probe
    cannot turn a required remap into a silent no-op (issue #286).

    Outcomes:
      - no weight file in ``adir``: prints a note and returns 0;
      - infixed LoRA keys present: verifies every one loses the infix after a single strip
        (raising ``RuntimeError`` before touching disk otherwise), rewrites the file in place and
        returns the number of renamed keys;
      - no infixed LoRA keys and the probe says text-only: returns 0;
      - no infixed LoRA keys, probe says VL, and the adapter has no LoRA keys at all: raises
        ``RuntimeError``;
      - no infixed LoRA keys, probe says VL, LoRA keys present: prints a diagnostic with a sample
        key prefix and returns 0 (treated as already remapped / text-only SFT).
    """
    import os

    tensor_names = _read_adapter_tensor_keys(adir)
    if tensor_names is None:
        print(
            f"[init-adapter] remap_vl_adapter_dir: no adapter_model.safetensors/.bin in {adir!r}; "
            "nothing to remap"
        )
        return 0

    # Single pass: collect the LoRA keys, and among them the ones still carrying the infix.
    lora_names = []
    needs_strip = []
    for name in tensor_names:
        if not _is_lora_key(name):
            continue
        lora_names.append(name)
        if _LANGUAGE_MODEL_INFIX in name:
            needs_strip.append(name)

    if needs_strip:
        # The file's own keys prove it was saved with the infix, so the config probe is not
        # consulted on this path. Predict the post-strip names from memory and fail closed before
        # any write if one of them would still carry the infix.
        still_infixed = []
        for name in needs_strip:
            stripped = strip_language_model_infix(name)
            if _LANGUAGE_MODEL_INFIX in stripped:
                still_infixed.append(stripped)
        if still_infixed:
            raise RuntimeError(
                f"remap_vl_adapter_dir: {len(still_infixed)} LoRA key(s) in {adir!r} for {model_id} would "
                f"still carry '.language_model.' after the remap (e.g. {still_infixed[0]!r}) — they will NOT "
                "match the AutoModelForCausalLM trainer and would be silently discarded -> all-zero "
                "lora_B. The adapter's key layout is unexpected; verify it was saved by this SFT pipeline."
            )

        safetensors_file = os.path.join(adir, "adapter_model.safetensors")
        if os.path.isfile(safetensors_file):
            renamed_count = _rewrite_safetensors_header_keys(
                safetensors_file, strip_language_model_infix
            )
        else:
            # Keys were read from one of the two files, so the .bin must be the one present.
            renamed_count = _rewrite_bin_keys(
                os.path.join(adir, "adapter_model.bin"), strip_language_model_infix
            )

        print(
            f"[init-adapter] remapped {renamed_count} VL SFT adapter key(s): stripped '.language_model.' infix "
            f"to match the AutoModelForCausalLM trainer for {model_id}"
        )
        return renamed_count

    # Nothing to strip from the file itself; only here is the config probe needed, to tell a
    # text-only model apart from a VL checkpoint whose adapter is already in text-only form.
    if not is_vl_checkpoint(model_id):
        return 0

    if not lora_names:
        # A VL warm-start adapter without any LoRA weights cannot hold an SFT delta.
        raise RuntimeError(
            f"warm-start adapter in {adir!r} for {model_id} contains NO LoRA weight keys "
            f"(found {len(tensor_names)} tensor(s), 0 with a lora_ marker) — the adapter is corrupt, "
            "incomplete, or from a different architecture, so GRPO would train from the base "
            "model. Re-export the SFT adapter, or omit train.init_from_adapter for a fresh LoRA."
        )

    # Report the module prefix of the first '.lora_' key (or the first LoRA key verbatim) so a
    # later all-zero lora_B abort can be traced back to a key/base mismatch.
    example_prefix = lora_names[0]
    for name in lora_names:
        if ".lora_" in name:
            example_prefix = name.split(".lora_")[0]
            break
    print(
        f"[init-adapter] remap_vl_adapter_dir: 0 '.language_model.' keys to strip for VL "
        f"checkpoint {model_id} ({len(lora_names)} LoRA key(s); e.g. prefix {example_prefix!r}) — "
        "treating as already-remapped/text-only. If the warm-start later aborts with all-zero "
        "lora_B, these keys did not match the base model."
    )
    return 0
685
+
686
+
687
def assert_lora_applied(model, model_id: str) -> int:
    """Check that a warm-started model actually carries LoRA submodules.

    Walks ``model.named_modules()`` and counts entries whose name ends in ``lora_A.default`` or
    ``lora_B.default``.

    Args:
        model: Object exposing ``named_modules()`` yielding ``(name, module)`` pairs.
        model_id: Model identifier, used only in messages.

    Returns:
        The number of matching submodules.

    Raises:
        RuntimeError: If no such submodule exists, i.e. the adapter was not applied.
    """
    lora_suffixes = ("lora_A.default", "lora_B.default")
    count = sum(1 for name, _module in model.named_modules() if name.endswith(lora_suffixes))
    if not count:
        raise RuntimeError(
            f"warm-start adapter for {model_id} loaded ZERO LoRA modules — the SFT adapter was NOT "
            "applied (key mismatch). GRPO would silently restart from the base model. For Qwen3.5/"
            "3.6 VL this is usually the '.language_model.' key-mismatch (check remap_vl_adapter_dir "
            "ran on the adapter); otherwise verify the adapter's keys match the model."
        )
    print(f"[init-adapter] verified {count} LoRA submodule(s) applied for {model_id}")
    return count
709
+
710
+
711
def assert_adapter_load_clean(load_result, model_id: str) -> None:
    """Fail closed when a peft adapter load left LoRA keys unmatched.

    ``load_result`` is expected to expose ``missing_keys`` and ``unexpected_keys`` (either may be
    absent or ``None``, which is treated as empty). Only keys containing ``lora_`` are considered:
    base-model parameters reported as missing by a non-strict adapter-only load are benign and
    must not abort a correct warm-start.

    Args:
        load_result: The load outcome object (attributes read via ``getattr``).
        model_id: Model identifier, used only in messages.

    Raises:
        RuntimeError: If any LoRA key is missing (an injected module got no saved weight) or
            unexpected (a saved key matched no module). The message includes up to three
            examples of each.
    """
    lora_missing = [
        key for key in (getattr(load_result, "missing_keys", None) or []) if "lora_" in key
    ]
    lora_unexpected = [
        key for key in (getattr(load_result, "unexpected_keys", None) or []) if "lora_" in key
    ]

    if not (lora_missing or lora_unexpected):
        print(
            f"[init-adapter] adapter load matched all saved keys for {model_id} (no missing/unexpected)"
        )
        return

    raise RuntimeError(
        f"warm-start adapter for {model_id} did NOT load cleanly: {len(lora_missing)} injected LoRA "
        f"module(s) got no saved weight (missing) and {len(lora_unexpected)} saved adapter key(s) "
        "matched no module (unexpected). The adapter was silently discarded -> GRPO would restart "
        "from the base model. For Qwen3.5/3.6 VL this is the '.language_model.' key mismatch "
        "(check remap_vl_adapter_dir ran on the adapter); otherwise the adapter's keys don't match "
        f"the base. missing[:3]={lora_missing[:3]} unexpected[:3]={lora_unexpected[:3]}"
    )
749
+
750
+
751
def assert_adapter_delta_nonzero(model, model_id: str) -> int:
    """Verify that at least one ``lora_B`` weight holds a non-zero entry.

    Inspects every module whose name ends in ``lora_B.default`` and that has a non-``None``
    ``weight`` attribute. If such modules exist but all of their weights are entirely zero, the
    adapter contributes nothing and the function raises. When no such module is found it does not
    raise (that situation is left to ``assert_lora_applied``).

    Args:
        model: Object exposing ``named_modules()`` yielding ``(name, module)`` pairs.
        model_id: Model identifier, used only in messages.

    Returns:
        The number of inspected ``lora_B`` modules with at least one non-zero weight entry.

    Raises:
        RuntimeError: If one or more ``lora_B`` weights were inspected and every one is all-zero.
    """
    candidate_weights = (
        getattr(module, "weight", None)
        for name, module in model.named_modules()
        if name.endswith("lora_B.default")
    )
    # One boolean per inspected weight: True when it contains any non-zero element.
    has_nonzero = [bool(w.detach().ne(0).any()) for w in candidate_weights if w is not None]
    seen = len(has_nonzero)
    nonzero = sum(has_nonzero)

    if seen and nonzero == 0:
        raise RuntimeError(
            f"warm-start adapter for {model_id} has ALL-ZERO lora_B weights across {seen} module(s) — "
            "the adapter delta is identically zero (an unloaded / silently-discarded adapter). GRPO "
            "would train from the base model. Verify the adapter's keys match the base (see "
            "remap_vl_adapter_dir)."
        )
    print(f"[init-adapter] verified non-zero lora_B in {nonzero}/{seen} module(s) for {model_id}")
    return nonzero
781
+
782
+
783
def model_quant(model_id: str) -> str:
    """Return the quantization tier for ``model_id``.

    Looks the model up in ``flash.catalog.MODELS`` and returns that entry's ``quant`` attribute.
    Falls back to ``"bf16"`` when the model is not in the catalog, or when the lookup raises (in
    which case the failure is printed rather than propagated).
    """
    fallback = "bf16"
    try:
        from flash.catalog import MODELS

        entry = MODELS.get(model_id)
        if entry is not None:
            # Kept inside the try so a malformed entry also degrades to the fallback.
            return entry.quant
    except Exception as exc:
        print(f"model_quant: catalog probe failed: {exc}")
    return fallback