freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
flash/engine/vram.py ADDED
@@ -0,0 +1,603 @@
1
+ """Coarse VRAM-fit estimation for one-consumer-GPU LoRA jobs.
2
+
3
+ Used by the open-model policy (``model_policy = "allow"``) to sanity-check that an
4
+ unlisted HF model can plausibly run on the requested GPU before provisioning it.
5
+
6
+ These are deliberately coarse heuristics (documented ±20%): they exist to catch
7
+ *provably impossible* configurations (70B bf16 on a 24 GB card) and to warn on tight
8
+ fits — not to guarantee success. Calibrated against the measured catalog entries
9
+ (Qwen3-0.6B/4B/8B, Qwen3.5 dense).
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import math
15
+ import os
16
+ import re
17
+ from dataclasses import dataclass
18
+
19
+
20
+ def _gpu_vram_table() -> dict[str, int]:
21
+ try:
22
+ from flash.providers.base import GPU_INFO
23
+
24
+ return {name: info.vram_gb for name, info in GPU_INFO.items()}
25
+ except Exception:
26
+ return {"RTX 4090": 24, "RTX 5090": 32}
27
+
28
+
29
# GPU name -> VRAM in GB, computed once at import time by _gpu_vram_table().
GPU_VRAM_GB = _gpu_vram_table()

# Bytes per model parameter for each known quant label; estimate_vram_gb falls back to 2.0
# for any label not listed here.
_BYTES_PER_PARAM = {
    "bf16": 2.0,
    "fp16": 2.0,
}
35
+
36
# Fixed overheads (GB): CUDA context + activations w/ gradient checkpointing +
# LoRA params/grads/Adam states (tiny at rank<=64) + fragmentation headroom.
# NOTE(review): estimate_vram_gb adds its own separate ``lora_opt`` and ``activations`` terms on
# top of this constant -- confirm whether the LoRA/activation wording above is still accurate.
_BASE_OVERHEAD_GB = 4.0
# Activations with gradient checkpointing scale ~linearly with tokens-in-flight
# (batch x seq) and model width (~sqrt of params). Coef calibrated so 4.7B SFT at
# seq 32k / batch 1 lands ~22 GB (measured: fits a 32 GB 5090).
_ACT_COEF = 0.12
# SFT activations + logits peak on the worker's PER-DEVICE micro-batch, not [train].batch_size
# (which is the global/effective batch realized via gradient accumulation). The worker caps the
# micro-batch at 4 and, when the fused CE is off, vocab-sizes it further to the logits budget (see
# ``sft_per_device``). Mirror that here so an unset/long-context SFT run still reserves the
# micro-batch peak, and a large effective batch isn't mis-counted as resident VRAM (it's grad-accum,
# not in-flight activations).
_SFT_PER_DEVICE_BS_DEFAULT = 4
50
+
51
+
52
def _sft_per_device_bs() -> int:
    """Return the worker's BASE per-device SFT micro-batch cap.

    This is the cap before ``sft_per_device`` layers the big-vocab logits cap on top, i.e.
    the activation-peak driver the allocator sizes against.

    The value is deliberately the fixed module default and is NOT read from the process
    environment: per the original design note, the control plane no longer forwards
    ``SFT_PER_DEVICE_BS`` to the worker, so the worker always runs the fixed default.
    Sizing from a control-plane env override (e.g. ``SFT_PER_DEVICE_BS=1``) would pick a
    card for a micro-batch the worker never uses and under-route the run to a GPU that
    then OOMs at the real micro-batch.
    """
    managed_default = _SFT_PER_DEVICE_BS_DEFAULT
    return managed_default
64
+
65
+
66
def sft_grad_accum(
    batch_size: int, *, seq_len: int = 0, vocab: int = 0, fused: bool = True
) -> tuple[int, int]:
    """Return ``(per_device, grad_accum)`` for a requested GLOBAL SFT ``batch_size``.

    ``per_device`` comes from ``sft_per_device`` (capped at the micro-batch default and,
    when the fused CE is off, additionally vocab-sized to the logits budget).
    ``grad_accum`` is the CEILING of ``batch_size / per_device`` so the realized global
    batch is never below the request: batch 6 -> per_device 4 x 2 steps = realized 8.

    Args:
        batch_size: Requested global batch; values below 1 are treated as 1.
        seq_len: Sequence length for the big-vocab logits cap (0 = not applied).
        vocab: Output vocab size for the big-vocab logits cap (0 = not applied).
        fused: Whether the fused cross-entropy is on; when True only the fixed cap applies.
    """
    requested = max(1, int(batch_size))
    micro_batch = sft_per_device(requested, seq_len=seq_len, vocab=vocab, fused=fused)
    # Integer ceiling division: one extra step whenever the request does not divide evenly.
    whole_steps, leftover = divmod(requested, micro_batch)
    steps = whole_steps + (1 if leftover else 0)
    return micro_batch, max(1, steps)
81
+
82
+
83
def sft_realized_batch(
    batch_size: int, *, seq_len: int = 0, vocab: int = 0, fused: bool = True
) -> int:
    """Return the realized SFT global batch: ``per_device x grad_accum``.

    Uses the same split as ``sft_grad_accum`` so a cost step-count derived from it matches
    the batch actually run. Pass ``seq_len``/``vocab``/``fused`` to honor the big-vocab
    per-device cap; with the defaults only the fixed micro-batch cap applies.
    """
    split = sft_grad_accum(batch_size, seq_len=seq_len, vocab=vocab, fused=fused)
    micro_batch, steps = split
    return micro_batch * steps
91
+
92
+
93
# Colocated-GRPO vLLM KV pool: grows with the engine's max context (seq) and model
# width, but vLLM bounds the pool to a fraction of the card and PAGES rather than OOMs,
# so it's capped (_KV_CAP) instead of growing without bound at long context.
# _KV_COEF is GB per 1024 tokens per unit model width (sqrt of params in billions), as used by
# _resident_kv_gb and the sleep-mode rollout term of estimate_vram_gb.
_KV_COEF = 2.0
# Cap (GB) on the sleep-mode rollout KV term; also used as the KV floor in colocate_kv_util.
_KV_CAP = 8.0
98
+
99
+
100
def grpo_rollout_seq_len(
    max_length: int = 0,
    max_tokens: int | None = None,
    thinking: bool = False,
) -> int:
    """Resolve the vLLM engine context length a GRPO run uses.

    Returns ``max_length`` when it is set (truthy). Otherwise returns
    ``max(1024, RECIPE.rl.max_prompt_len + completion)``, where ``completion`` is
    ``max_tokens`` when set, else the recipe's thinking / non-thinking completion default.

    Per the original note this is meant to mirror ``run_rl()``, so that the allocator
    sizing, the sleep-mode resident gate and the colocate KV budget all resolve the same
    length instead of treating an unset ``max_length`` as a 1024-token rollout.
    """
    from flash.engine.recipe import RECIPE

    recipe_rl = RECIPE.rl
    # Completion budget: explicit override first, then the recipe default for the mode.
    if max_tokens:
        completion_len = int(max_tokens)
    elif thinking:
        completion_len = int(recipe_rl.max_completion_len_thinking)
    else:
        completion_len = int(recipe_rl.max_completion_len)
    if max_length:
        return int(max_length)
    return int(max(1024, recipe_rl.max_prompt_len + completion_len))
118
+
119
+
120
def _resident_kv_gb(params_b: float | None, vllm_max_len: int, num_generations: int = 8) -> float:
    """KV cache (GB) a colocated rollout engine holds RESIDENT.

    Scales with the engine context (per 1024 tokens), the model width (sqrt of the param
    count in billions) and the generation group (relative to 8). It is not capped here,
    unlike the sleep-mode rollout estimate, which caps its KV term at ``_KV_CAP``. Shared
    by the resident-fit estimate and the non-sleep colocate budget so both size the same KV.

    Args:
        params_b: Model size in billions; a falsy value is treated as 1.0, floored at 0.1.
        vllm_max_len: Engine context length in tokens; floored at 1.
        num_generations: Concurrent sequences per group; floored at 1.
    """
    model_width = math.sqrt(max(float(params_b or 1.0), 0.1))
    context_units = max(1, vllm_max_len) / 1024.0
    group_scale = max(1, num_generations) / 8.0
    return _KV_COEF * context_units * model_width * group_scale
128
+
129
+
130
def colocate_kv_util(
    params_b: float | None,
    vllm_max_len: int,
    total_vram_gb: float,
    sleep_mode: bool,
    num_generations: int = 8,
) -> float:
    """``vllm_gpu_memory_utilization`` for the colocated GRPO rollout engine.

    The fraction is sized to the actual need rather than a blanket share of the card.
    ``gpu_memory_utilization`` is vLLM's WHOLE model-executor budget (its bf16 weight copy
    PLUS the KV cache), so both are budgeted: ``(weights + kv) / total_vram_gb``.

    * Non-sleep: KV is the resident KV for the context + group, floored at ``_KV_CAP``;
      the result is clamped to ``[0.10, 0.45]``. Per the original note, budgeting KV alone
      starved the weights and vLLM raised "No available memory for the cache blocks" on
      >=3B models.
    * Sleep: KV is 1.5x the resident KV, floored at ``_KV_CAP``; the result is capped at
      0.45 with no lower clamp. The engine is offloaded during the backward, so a bigger
      rollout-phase pool does not compete with the training peak.

    Original calibration note: at 4B / group 8 / 2k ctx, 0.25 util dropped the peak from
    46 GB to 26 GB with identical reward, and a tighter 12 GB budget preempts.

    Args:
        params_b: Model size in billions; falsy is treated as 1.0, and floored at 0.5.
        vllm_max_len: Engine context length in tokens.
        total_vram_gb: Card VRAM in GB; floored at 1.0 to avoid division by zero.
        sleep_mode: Whether vLLM sleep-mode offload is in use.
        num_generations: Concurrent sequences per generation group.
    """
    # vLLM's bf16 weight copy (2 bytes/param) lives inside the utilization budget.
    weights_gb = max(0.5, float(params_b or 1.0)) * 2.0
    resident_kv = _resident_kv_gb(params_b, vllm_max_len, num_generations)
    card_gb = max(1.0, total_vram_gb)
    if sleep_mode:
        # Larger pool (1.5x margin), bounded only from above.
        kv_pool_gb = max(_KV_CAP, 1.5 * resident_kv)
        return min(0.45, (weights_gb + kv_pool_gb) / card_gb)
    # Resident KV on top of the weight copy, scaled with context + group and floored at _KV_CAP.
    kv_gb = max(_KV_CAP, resident_kv)
    return max(0.10, min(0.45, (weights_gb + kv_gb) / card_gb))
167


# GRPO backward (activations + fp32 logits over the completion micro-batch) per unit
# context x model width. Grad checkpointing makes this MILD in seq -- calibrated to
# measured boundaries: 0.8B GRPO fits 24 GB up to seq 32k (seq ~free), while 4.7B GRPO
# steps off a 32 GB card between seq 16k and 32k. group size scales it sublinearly.
_TRAIN_COEF = 0.27
# Fixed floor for colocated-vLLM GRPO: the vLLM engine's CUDA context + KV pool (sized to the
# CARD's VRAM via gpu_util, not the model) + the 2nd resident weight copy is ~model-independent
# for small models and dominates their param estimate, so tiny/mid models all need the 32 GB tier.
# MEASURED at the default group_size=8: 0.8B GRPO OOMs a 20 GB card; 2B GRPO OOMs a 24 GB card
# (-> both need 32); 4B GRPO fits 32 (param est ~31 already clears this floor, so it's untouched).
_VLLM_COLOCATE_FLOOR_GB = 28.0
# Fallback output vocab (lm_head / logits width) for estimate_vram_gb when no model vocab is
# passed; the model-aware path (model_required_vram_gb) resolves the real per-model value
# from flash.catalog via vocab_size_for(). Mirrors catalog._DEFAULT_VOCAB_SIZE.
_VOCAB_DEFAULT = 248_320
# Matches the worker's logits budget (6 GB): the per-device fp32 logits are capped to this
# (rl_per_device_comps spills the rest into grad-accum), so the estimator never reserves above it.
_LOGITS_BUDGET_GB = 6.0

# ---- SFT big-vocab logits: the SFT analog of the GRPO fp32-logits term above ----
# When the worker's fused cross-entropy (Liger) is OFF, an SFT forward materializes the FULL-sequence
# [per_device, seq_len, vocab] logits AND keeps their gradient live through the backward. At
# Qwen3.5's ~248k vocab this is the documented big-vocab SFT OOM driver (a 0.8B SFT OOM'd a 24 GB
# card). The worker fuses CE only for a >=3B model OR a >=2048-token context (mirrors
# engine.worker.perf._memory_mode); BELOW that the term is real and was previously ignored entirely.
# An SFT step holds, AT ONCE, the fp32 logits (4) + their fp32 grad (4) + the bf16 logits the model
# emits (2) + the bf16 grad (2) + the cross-entropy log_softmax temp (4) ~= 16 B/elem. (8 B/elem --
# fp32 logits+grad only -- UNDER-counted: a live 2B SFT seq1024 at per_device=2 peaked ~15.8 GiB and
# OOM'd a 16 GB card whose usable is ~15.6 GiB.) At 16 B/elem the per-device cap drops to 1 for a
# big-vocab un-fused SFT, so the worker materializes far less and the real peak clears even the
# tightest 16 GB card. The worker vocab-sizes the per-device micro-batch so these logits never
# exceed _LOGITS_BUDGET_GB while pd CAN be reduced; the estimator reserves the TRUE per-device-capped
# term (no budget clamp -- the irreducible pd=1 floor can exceed the budget at a near-2048 ctx) -- so
# the allocator provably covers the worker's real peak. VALIDATED by a live re-run.
_SFT_LOGITS_BYTES_PER_ELEM = 16.0
# Canonical fused-CE (Liger) gate thresholds: the worker fuses the SFT cross-entropy for a >=3B
# model OR a >=2048-token context. SINGLE SOURCE OF TRUTH -- engine.worker.perf imports these (its
# _LIGER_MIN_PARAMS / _LONG_CONTEXT_TOKENS derive from them) and sft_logits_fused mirrors the gate
# offline (no network AutoConfig probe) so the cost estimator stays deterministic.
_LIGER_MIN_PARAMS_B = 3.0  # model size threshold, in billions of params
_LIGER_LONG_CTX_TOKENS = 2048  # context-length threshold, in tokens
208
+
209
+
210
def sft_logits_fused(params_b: float | None, seq_len: int) -> bool:
    """Whether the worker is expected to fuse the SFT cross-entropy (Liger).

    When fused, the ``[per_device, seq, vocab]`` logits never materialize. The gate is
    evaluated offline, with no network probe: fused for a context of at least
    ``_LIGER_LONG_CTX_TOKENS`` tokens OR a model of at least ``_LIGER_MIN_PARAMS_B``
    billion params. An unknown (falsy) ``params_b`` counts as 0 and so only fuses on the
    context criterion.
    """
    long_context = seq_len >= _LIGER_LONG_CTX_TOKENS
    return long_context or (params_b or 0.0) >= _LIGER_MIN_PARAMS_B
219
+
220
+
221
def sft_logits_per_device_cap(seq_len: int, vocab: int) -> int:
    """Largest SFT per-device micro-batch whose un-fused logits fit the logits budget.

    One sample costs ``seq_len x vocab x _SFT_LOGITS_BYTES_PER_ELEM`` bytes over the FULL
    sequence (not just the completion). The cap is how many such samples fit in
    ``_LOGITS_BUDGET_GB`` (decimal GB), never less than 1 even when a single sample
    already exceeds the budget.

    Args:
        seq_len: Sequence length in tokens; floored at 1.
        vocab: Output vocabulary size; floored at 1.
    """
    tokens = max(1, int(seq_len))
    vocab_width = max(1, int(vocab))
    bytes_per_sample = tokens * vocab_width * _SFT_LOGITS_BYTES_PER_ELEM
    samples_that_fit = int(_LOGITS_BUDGET_GB * 1e9 / bytes_per_sample)
    return max(1, samples_that_fit)
228
+
229
+
230
def sft_per_device(batch_size: int, *, seq_len: int = 0, vocab: int = 0, fused: bool = True) -> int:
    """Per-device SFT micro-batch for a requested global ``batch_size``.

    The request (floored at 1) is capped at ``_SFT_PER_DEVICE_BS_DEFAULT``. When the
    fused CE is OFF and both ``seq_len`` and ``vocab`` are given, it is additionally
    capped by ``sft_logits_per_device_cap`` so the ``[per_device, seq, vocab]`` logits stay
    within the logits budget. With ``seq_len``/``vocab`` unset, or with ``fused`` true,
    only the fixed cap applies.
    """
    requested = max(1, int(batch_size))
    micro_batch = max(1, min(_SFT_PER_DEVICE_BS_DEFAULT, requested))
    # The logits cap is only relevant for an un-fused run with known seq_len and vocab.
    if fused or not seq_len or not vocab:
        return micro_batch
    return min(micro_batch, sft_logits_per_device_cap(seq_len, vocab))
239
+
240
+
241
def grpo_seq_escalation_gb(params_b: float | None, seq_len: int) -> int:
    """Extra GB a long-context GRPO run needs beyond its base footprint.

    The safe context is ``48_500 / params_b`` tokens. At or below it, and whenever
    ``params_b`` is unknown (falsy), no escalation is added. Past it the extra is
    ``ceil(0.9 x params_b x (seq_len / safe_context - 1))``, which grows faster for
    bigger models.

    Original calibration note: a bf16 9.7B GRPO run fits 80 GB up to seq 4096 but OOMs at
    8192. It applies to both catalog and open-model GRPO so neither under-provisions.

    Args:
        params_b: Model size in billions of parameters, or None/0 when unknown.
        seq_len: Rollout context length in tokens.

    Returns:
        Whole GB to add on top of the curated floor (0 when no escalation applies).
    """
    if not params_b:
        return 0
    slope = 0.9
    safe_context = 48_500.0 / params_b
    if seq_len > safe_context:
        overshoot = seq_len / safe_context - 1
        return math.ceil(slope * params_b * overshoot)
    return 0
257
+
258
+
259
def params_b_from_str(s: str | None) -> float | None:
    """Parse a param count (in billions) out of a catalog ``params`` string.

    Takes the first ``<number>B`` occurrence (optional whitespace before an upper-case
    ``B``), e.g. "4.7B (text-only fine-tune)" -> 4.7. Returns None for an empty/None
    input or when no such token is present.
    """
    if not s:
        return None
    found = re.search(r"([0-9]+(?:\.[0-9]+)?)\s*B", s)
    if found is None:
        return None
    return float(found.group(1))
266
+
267
+
268
@dataclass(frozen=True)
class VramEstimate:
    """Immutable result of a VRAM-fit check for one model on one GPU."""

    # Model size in billions of parameters, or None when it could not be read.
    params_b: float | None
    # Training algorithm label as shown in the description (e.g. "sft", "grpo").
    algorithm: str
    # Weight quantization / dtype label (e.g. "bf16").
    quant: str
    # Estimated peak VRAM in GB, or None when no estimate could be made.
    est_gb: float | None
    # GPU name and its VRAM capacity in GB.
    gpu: str
    gpu_gb: int
    verdict: str  # "fits" | "tight" | "too_big" | "unknown"

    def describe(self) -> str:
        """Return a one-line human-readable summary of the estimate.

        When ``est_gb`` is None the summary only states that the need is unknown.
        Otherwise it reports the estimate, the inputs and the verdict. A missing
        ``params_b`` is rendered as "unknown params" instead of raising: the field is
        optional, and formatting None with ``:.1f`` would raise ``TypeError``.
        """
        if self.est_gb is None:
            return f"{self.gpu}: VRAM need unknown (could not read model size)"
        if self.params_b is None:
            size_label = "unknown params"
        else:
            size_label = f"{self.params_b:.1f}B params"
        return (
            f"{self.gpu} ({self.gpu_gb} GB): estimated ~{self.est_gb:.0f} GB needed "
            f"({size_label}, {self.quant}, {self.algorithm}) -> {self.verdict}"
        )
285
+
286
+
287
def estimate_vram_gb(
    params_b: float,
    algorithm: str,
    quant: str = "bf16",
    *,
    seq_len: int = 1024,
    max_tokens: int | None = None,
    lora_rank: int = 32,
    batch_size: int = 1,
    group_size: int = 8,
    thinking: bool = False,
    use_vllm: bool = True,
    vocab: int = _VOCAB_DEFAULT,
    sleep_offload: bool = True,
) -> float:
    """Estimated peak VRAM (GB) for a LoRA job on one GPU, over the full knob matrix.

    Terms (all in GB):
      weights      params x bytes/param (bf16=2; unknown quant labels also use 2)
      base         weights + fixed overhead + lora_opt
      lora_opt     LoRA adapter + grads + Adam states (rank-linear, model-scaled)
      activations  grad-checkpointed activations ~ batch x seq x sqrt(params)
    GRPO only (algorithm "grpo" or "rl", case-insensitive):
      rollout      2nd resident weight copy held by the colocated vLLM engine + its KV
                   pool; 0 when ``use_vllm`` is False
      train        backward activations + fp32 logits [per_device=1, completion, vocab],
                   the logits capped at the logits budget
      peak         base + max(rollout, train) with sleep offload (the phases do not peak
                   together); base + rollout + train without it (engine stays resident)
    SFT (any other algorithm):
      peak         base + activations + un-fused logits, both driven by the per-device
                   micro-batch from ``sft_per_device``, not the global batch

    Args:
        params_b: Model size in billions of parameters.
        algorithm: "grpo"/"rl" selects the GRPO path; anything else is sized as SFT.
        quant: Weight dtype label looked up in ``_BYTES_PER_PARAM``.
        seq_len: Context length in tokens.
        max_tokens: GRPO completion length; defaults to ``min(seq_len, 1024)``.
        lora_rank: LoRA rank (scales ``lora_opt`` linearly, relative to rank 16).
        batch_size: Requested global SFT batch (SFT path only).
        group_size: GRPO generation group size.
        thinking: Applies a 1.3x factor to GRPO training activations.
        use_vllm: Whether a colocated vLLM engine holds the rollout copy (GRPO only).
        vocab: Output vocabulary size for the logits terms.
        sleep_offload: Whether vLLM sleep mode offloads the engine during the backward.

    Returns:
        Estimated peak VRAM in GB, before any caller-side headroom.
    """
    weights = params_b * _BYTES_PER_PARAM.get(quant, 2.0)
    is_grpo = (algorithm or "").lower() in ("grpo", "rl")
    width = math.sqrt(max(params_b, 0.1))
    lora_opt = (lora_rank / 16.0) * (0.3 + 0.04 * params_b)
    base = weights + _BASE_OVERHEAD_GB + lora_opt
    seq_units = seq_len / 1024.0

    if not is_grpo:
        # SFT: activations and logits follow the worker's per-device micro-batch (capped, and
        # vocab-sized when the fused CE is off); grad-accum realizes the rest of the batch.
        fused = sft_logits_fused(params_b, seq_len)
        micro_batch = sft_per_device(batch_size, seq_len=seq_len, vocab=vocab, fused=fused)
        sft_activations = _ACT_COEF * micro_batch * seq_units * width
        if fused:
            # The fused CE never materializes [B, T, vocab].
            sft_logits = 0.0
        else:
            # TRUE per-device term over the FULL sequence, deliberately NOT clamped to the
            # budget: once the micro-batch floors at 1 the logits are irreducible.
            sft_logits = micro_batch * seq_len * vocab * _SFT_LOGITS_BYTES_PER_ELEM / 1e9
        return base + sft_activations + sft_logits

    # ---- GRPO: rollout phase (colocated vLLM) ----
    if not use_vllm:
        rollout = 0.0
    elif sleep_offload:
        # Engine is offloaded during the backward, so the rollout KV is capped at _KV_CAP.
        rollout = weights + min(_KV_COEF * seq_units * width, _KV_CAP)
    else:
        # Engine stays live through the backward: size KV to the real context + group.
        rollout = weights + _resident_kv_gb(params_b, seq_len, group_size)

    # ---- GRPO: training phase (backward activations + fp32 logits) ----
    group_factor = max(1.0, (max(1, group_size) / 4.0) ** 0.5)
    think_factor = 1.3 if thinking else 1.0
    activations = _TRAIN_COEF * seq_units * width * group_factor * think_factor
    # Irreducible per_device=1 logits over the completion (4 bytes/elem), capped at the budget;
    # it scales with max_tokens, not seq_len.
    completion = max_tokens if max_tokens else min(seq_len, 1024)
    logits = min(completion * vocab * 4 / 1e9, _LOGITS_BUDGET_GB)
    train = activations + logits

    if sleep_offload:
        return base + max(rollout, train)
    return base + (rollout + train)
373
+
374
+
375
def grpo_fits_resident(
    model_id: str,
    *,
    seq_len: int = 1024,
    max_tokens: int | None = None,
    lora_rank: int = 32,
    group_size: int = 8,
    thinking: bool = False,
    card_vram_gb: float = 0.0,
    margin: float = 1.15,
) -> bool:
    """Whether a colocated-vLLM GRPO run fits RESIDENT (no sleep-mode offload) on a card.

    Computes the resident peak with ``estimate_vram_gb(..., sleep_offload=False)`` and
    checks ``resident * margin <= card_vram_gb``. When it fits, the worker can skip the
    sleep/wake cycle.

    Conservative by design: a missing or non-positive ``card_vram_gb``, a model that is not
    in the catalog, or a catalog entry without a positive ``params_b`` all return False,
    which keeps the memory-safe sleep default.

    Args:
        model_id: Catalog model identifier.
        seq_len: Rollout context length; a falsy value falls back to 1024, floored at 1.
        max_tokens: Completion length override passed through to the estimator.
        lora_rank: LoRA rank.
        group_size: GRPO generation group size.
        thinking: Whether thinking mode is enabled.
        card_vram_gb: VRAM of the target card in GB.
        margin: Safety multiplier applied to the resident estimate.
    """
    if not card_vram_gb:
        return False
    if card_vram_gb <= 0:
        return False
    from flash.catalog import MODELS, vocab_size_for

    entry = MODELS.get(model_id)
    if not entry:
        return False  # not in the catalog (open-model path) -> keep the safe default
    # NOTE(review): this reads ``entry.params_b`` whereas model_required_vram_gb parses
    # ``info.params`` via params_b_from_str -- confirm catalog entries actually expose
    # ``params_b``, otherwise this gate always returns False.
    size_b = float(getattr(entry, "params_b", 0.0) or 0.0)
    if size_b <= 0:
        return False  # unknown size -> keep the safe default
    dtype_label = getattr(entry, "quant", "bf16") or "bf16"
    context_len = max(1, int(seq_len or 1024))
    resident_gb = estimate_vram_gb(
        size_b,
        "grpo",
        dtype_label,
        seq_len=context_len,
        max_tokens=max_tokens,
        lora_rank=lora_rank,
        group_size=group_size,
        thinking=thinking,
        use_vllm=True,
        vocab=vocab_size_for(model_id),
        sleep_offload=False,
    )
    return resident_gb * margin <= card_vram_gb
414
+
415
+
416
+ def model_required_vram_gb(
417
+ model_id: str,
418
+ algorithm: str,
419
+ *,
420
+ train=None,
421
+ thinking: bool = False,
422
+ headroom: float = 1.1,
423
+ ) -> int:
424
+ """Cheapest-sufficient VRAM (GB) for a specific run -- the matrix the allocator and
425
+ ``provisional_gpu`` both size against.
426
+
427
+ Catalog models size from their known param count + the run's actual knobs (``train``
428
+ may be a TrainSpec, a dict, or None for recipe defaults). Curated GRPO floors
429
+ (``grpo_min_vram_gb``) stay as HARD floors so we never under-provision a validated
430
+ model; the matrix only ever sizes UP from there. Unlisted open models size from HF
431
+ metadata, falling back to the 24 GB tier when the size can't be read.
432
+ """
433
+
434
+ # Best-effort knob extraction: this provisional sizing runs at parse time BEFORE the
435
+ # dedicated train validators, so malformed knobs (nan/inf/strings/<=0) must fall back
436
+ # to a default here, never crash -- config_schema raises the proper ConfigError next.
437
+ def _g(obj, key):
438
+ if obj is None:
439
+ return None
440
+ return obj.get(key) if isinstance(obj, dict) else getattr(obj, key, None)
441
+
442
+ def _pos_int(v, default):
443
+ try:
444
+ if isinstance(v, bool):
445
+ return default
446
+ f = float(v)
447
+ return int(f) if math.isfinite(f) and f >= 1 else default
448
+ except (TypeError, ValueError):
449
+ return default
450
+
451
+ max_tokens = _pos_int(_g(train, "max_tokens"), None)
452
+ # Default sequence length when [train].max_length is unset. For GRPO this must MIRROR what
453
+ # run_rl() actually starts vLLM at — max(1024, RLConfig.max_prompt_len + completion) — not a
454
+ # flat 1024, or the allocator can pick a GPU sized for 1024 tokens while the worker launches a
455
+ # ~2368-token (3584 with thinking) engine and OOMs after provisioning. Completion = the run's
456
+ # [train].max_tokens override, else the recipe's thinking/non-thinking completion default.
457
+ if (algorithm or "").lower() in ("grpo", "rl"):
458
+ # Same engine context run_rl() launches (max_length, else max(1024, prompt+completion)) via
459
+ # the shared helper, so the allocator and the worker never disagree on the rollout length.
460
+ _grpo_default_len = grpo_rollout_seq_len(0, max_tokens, thinking)
461
+ else:
462
+ _grpo_default_len = 1024
463
+ seq_len = _pos_int(_g(train, "max_length"), _grpo_default_len)
464
+ lora_rank = _pos_int(_g(train, "lora_rank"), 32)
465
+ group_size = _pos_int(_g(train, "group_size"), 8)
466
+ # Default to the worker's per-device SFT micro-batch (4): an unset
467
+ # [train].batch_size still realizes that micro-batch on the worker, so size for it
468
+ # rather than 1 (which would under-route a long-context SFT run to a too-small card).
469
+ batch_size = _pos_int(_g(train, "batch_size"), _sft_per_device_bs())
470
+
471
def _need(
    params_b: float,
    algorithm: str,
    *,
    quant: str = "bf16",
    use_vllm: bool = True,
    vocab: int = _VOCAB_DEFAULT,
) -> int:
    """Return the VRAM requirement for one model/algorithm pair, headroom applied.

    The run-level knobs (``seq_len``, ``max_tokens``, ``lora_rank``,
    ``batch_size``, ``group_size``, ``thinking``) and ``headroom`` are read
    from the enclosing scope, so every caller of this helper sizes against
    the same knob set.

    Returns:
        ``estimate_vram_gb(...)`` multiplied by ``headroom``, rounded up to an int.
    """
    # Knobs forwarded to the estimator; the first six are resolved once by the
    # enclosing function, the last two come from this call.
    run_knobs = {
        "seq_len": seq_len,
        "max_tokens": max_tokens,
        "lora_rank": lora_rank,
        "batch_size": batch_size,
        "group_size": group_size,
        "thinking": thinking,
        "use_vllm": use_vllm,
        "vocab": vocab,
    }
    raw_estimate = estimate_vram_gb(params_b, algorithm, quant, **run_knobs)
    return math.ceil(raw_estimate * headroom)
495
+
496
+ from flash.catalog import MODELS, vocab_size_for
497
+
498
+ info = MODELS.get(model_id)
499
+ # Per-model output vocab (lm_head / logits width) sizes the fp32-logits term; resolved
500
+ # from the catalog (curated value, else open-model fallback) instead of a hardcoded const.
501
+ model_vocab = vocab_size_for(model_id)
502
+ is_grpo = (algorithm or "").lower() in ("grpo", "rl")
503
+ if info is not None:
504
+ params_b = params_b_from_str(info.params)
505
+ quant = getattr(info, "quant", "bf16") or "bf16"
506
+ # GRPO always runs the rollout on a colocated vLLM engine, so sizing must reserve room for
507
+ # the 2nd (rollout) weight copy on the same card.
508
+ use_vllm = True
509
+ need = _need(params_b or 4.0, algorithm, quant=quant, use_vllm=use_vllm, vocab=model_vocab)
510
+ # Hard floor the param-based matrix can't see: a curated GRPO floor.
511
+ floor = 0
512
+ if is_grpo and getattr(info, "grpo_min_vram_gb", 0):
513
+ floor = int(info.grpo_min_vram_gb)
514
+ # Big-model GRPO is TIGHT at its floor (2 weight copies + KV pool), so long context
515
+ # overflows it -> escalate to a bigger tier. See grpo_seq_escalation_gb.
516
+ if is_grpo and floor:
517
+ floor += grpo_seq_escalation_gb(params_b, seq_len)
518
+ need = max(need, floor)
519
+ # vLLM-colocate floor: the engine (CUDA context + KV pool sized to the CARD's VRAM +
520
+ # framework) + the 2nd resident weight copy add a ~constant the param estimate misses,
521
+ # so small-model GRPO under-provisions. MEASURED at the default group_size=8: 0.8B GRPO
522
+ # fits a 24 GB card but OOMs 20 (est ~18, ~6 GB headroom on 24); 2B GRPO OOMs a 24 GB
523
+ # card (est ~20 but the colocate cost tips it past 24 -> needs the 32 tier). So sub-~1B
524
+ # models floor at 24, while larger small-models that the param estimate still under-shoots
525
+ # floor at the 32 tier. 4B+ already exceed this via their param estimate, so untouched.
526
+ if is_grpo and use_vllm:
527
+ floor_gb = 24 if (params_b or 0.0) <= 1.0 else int(_VLLM_COLOCATE_FLOOR_GB)
528
+ need = max(need, floor_gb)
529
+ return need
530
+ # Unlisted open model: size from HF metadata (GRPO is the heavier phase).
531
+ params_b = fetch_hf_params_b(model_id)
532
+ if params_b is None:
533
+ return 24
534
+ # Open models size against the heavier GRPO phase regardless of the requested algorithm.
535
+ need = _need(params_b, "grpo", vocab=model_vocab)
536
+ # Same long-context GRPO escalation as the catalog path so a big open model isn't
537
+ # under-provisioned at long context either.
538
+ if is_grpo:
539
+ need += grpo_seq_escalation_gb(params_b, seq_len)
540
+ return need
541
+
542
+
543
+ def fetch_hf_params_b(model_id: str) -> float | None:
544
+ """Total params (billions) from the HF API safetensors metadata (no download).
545
+
546
+ Best-effort: returns ``None`` when the size can't be read (no network / no HF metadata),
547
+ so callers fall back to the offline heuristic rather than failing.
548
+ """
549
+ try:
550
+ from huggingface_hub import HfApi
551
+
552
+ info = HfApi(token=os.environ.get("HF_TOKEN")).model_info(
553
+ model_id, expand=["safetensors"]
554
+ )
555
+ total = getattr(getattr(info, "safetensors", None), "total", None)
556
+ if total:
557
+ return float(total) / 1e9
558
+ except Exception:
559
+ # Best-effort size probe (network/HF-metadata may be unavailable); fall through
560
+ # to None so callers report "size unknown" rather than failing.
561
+ pass
562
+ return None
563
+
564
+
565
def resolve_params_b(model_id: str) -> float | None:
    """Resolve a model's size in billions of parameters.

    Lookup order:
      1. the catalog entry's ``params_b`` attribute, when present and non-zero;
      2. the catalog entry's ``params`` display string, parsed by
         ``params_b_from_str``;
      3. the HF safetensors parameter count from ``fetch_hf_params_b`` (also
         used when the model is not in the catalog at all).

    Returns:
        The size in billions, or ``None`` when no step yields a value.  The
        original note describes this as the single shared answer to "how big
        is this model" for run_sft, run_rl and cost.spec, so that they cannot
        drift apart.
    """
    from flash.catalog import MODELS

    entry = MODELS.get(model_id)
    if entry is None:
        # Uncataloged (open) model: HF metadata is the only source.
        return fetch_hf_params_b(model_id)

    size_b = getattr(entry, "params_b", 0.0)
    if not size_b:
        size_b = params_b_from_str(getattr(entry, "params", None))
    # A cataloged model with no usable size still falls back to HF.
    return size_b if size_b else fetch_hf_params_b(model_id)
581
+
582
+
583
def check_fit(
    model_id: str,
    algorithm: str,
    gpu: str,
    quant: str = "bf16",
    params_b: float | None = None,
) -> VramEstimate:
    """Classify whether ``model_id`` plausibly trains on ``gpu``.

    Args:
        model_id: Model identifier; only used to look up the size when
            ``params_b`` is not supplied.
        algorithm: Training algorithm name forwarded to ``estimate_vram_gb``.
        gpu: GPU key into ``GPU_VRAM_GB``; unknown keys are treated as 32 GB.
        quant: Quantization label forwarded to ``estimate_vram_gb``.
        params_b: Model size in billions; fetched from HF metadata when ``None``.

    Returns:
        A ``VramEstimate`` whose verdict is ``"unknown"`` (size unavailable),
        ``"too_big"`` (estimate above 115% of the card), ``"tight"`` (above
        85%), or ``"fits"``.
    """
    capacity_gb = GPU_VRAM_GB.get(gpu, 32)

    size_b = params_b
    if size_b is None:
        size_b = fetch_hf_params_b(model_id)
    if size_b is None:
        # No size available from the caller or HF: report an unknown verdict.
        return VramEstimate(None, algorithm, quant, None, gpu, capacity_gb, "unknown")

    needed_gb = estimate_vram_gb(size_b, algorithm, quant)
    verdict = (
        "too_big"
        if needed_gb > capacity_gb * 1.15
        else "tight"
        if needed_gb > capacity_gb * 0.85
        else "fits"
    )
    return VramEstimate(size_b, algorithm, quant, needed_gb, gpu, capacity_gb, verdict)