freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
flash/engine/vram.py
ADDED
|
@@ -0,0 +1,603 @@
|
|
|
1
|
+
"""Coarse VRAM-fit estimation for one-consumer-GPU LoRA jobs.
|
|
2
|
+
|
|
3
|
+
Used by the open-model policy (``model_policy = "allow"``) to sanity-check that an
|
|
4
|
+
unlisted HF model can plausibly run on the requested GPU before provisioning it.
|
|
5
|
+
|
|
6
|
+
These are deliberately coarse heuristics (documented ±20%): they exist to catch
|
|
7
|
+
*provably impossible* configurations (70B bf16 on a 24 GB card) and to warn on tight
|
|
8
|
+
fits — not to guarantee success. Calibrated against the measured catalog entries
|
|
9
|
+
(Qwen3-0.6B/4B/8B, Qwen3.5 dense).
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import math
|
|
15
|
+
import os
|
|
16
|
+
import re
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _gpu_vram_table() -> dict[str, int]:
|
|
21
|
+
try:
|
|
22
|
+
from flash.providers.base import GPU_INFO
|
|
23
|
+
|
|
24
|
+
return {name: info.vram_gb for name, info in GPU_INFO.items()}
|
|
25
|
+
except Exception:
|
|
26
|
+
return {"RTX 4090": 24, "RTX 5090": 32}
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
GPU_VRAM_GB = _gpu_vram_table()
|
|
30
|
+
|
|
31
|
+
_BYTES_PER_PARAM = {
|
|
32
|
+
"bf16": 2.0,
|
|
33
|
+
"fp16": 2.0,
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
# Fixed overheads (GB): CUDA context + activations w/ gradient checkpointing +
|
|
37
|
+
# LoRA params/grads/Adam states (tiny at rank<=64) + fragmentation headroom.
|
|
38
|
+
_BASE_OVERHEAD_GB = 4.0
|
|
39
|
+
# Activations with gradient checkpointing scale ~linearly with tokens-in-flight
|
|
40
|
+
# (batch x seq) and model width (~sqrt of params). Coef calibrated so 4.7B SFT at
|
|
41
|
+
# seq 32k / batch 1 lands ~22 GB (measured: fits a 32 GB 5090).
|
|
42
|
+
_ACT_COEF = 0.12
|
|
43
|
+
# SFT activations + logits peak on the worker's PER-DEVICE micro-batch, not [train].batch_size
|
|
44
|
+
# (which is the global/effective batch realized via gradient accumulation). The worker caps the
|
|
45
|
+
# micro-batch at 4 and, when the fused CE is off, vocab-sizes it further to the logits budget (see
|
|
46
|
+
# ``sft_per_device``). Mirror that here so an unset/long-context SFT run still reserves the
|
|
47
|
+
# micro-batch peak, and a large effective batch isn't mis-counted as resident VRAM (it's grad-accum,
|
|
48
|
+
# not in-flight activations).
|
|
49
|
+
_SFT_PER_DEVICE_BS_DEFAULT = 4
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _sft_per_device_bs() -> int:
|
|
53
|
+
"""The worker's BASE per-device SFT micro-batch cap (before the big-vocab logits cap layered on
|
|
54
|
+
by ``sft_per_device``) — the activation-peak driver to size against.
|
|
55
|
+
|
|
56
|
+
SFT micro-batch is a MANAGED default: the control plane no longer forwards
|
|
57
|
+
``SFT_PER_DEVICE_BS`` to the worker (build_worker_env dropped the tuning allowlist), and the
|
|
58
|
+
worker's own process env never carries it, so the worker always runs the fixed default. The
|
|
59
|
+
allocator must size against that SAME fixed value — reading the control-plane process env here
|
|
60
|
+
would size a card for a micro-batch the worker never uses, under-routing an
|
|
61
|
+
``SFT_PER_DEVICE_BS=1`` operator env to a too-small GPU that then OOMs at the default
|
|
62
|
+
micro-batch 4 (the asymmetry the env-knobs cleanup removed everywhere else)."""
|
|
63
|
+
return _SFT_PER_DEVICE_BS_DEFAULT
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def sft_grad_accum(
    batch_size: int, *, seq_len: int = 0, vocab: int = 0, fused: bool = True
) -> tuple[int, int]:
    """Return ``(per_device, grad_accum)`` as the worker realizes them for a GLOBAL ``batch_size``.

    ``per_device`` comes from ``sft_per_device``: the micro-batch default cap, ADDITIONALLY
    vocab-sized to the logits budget when the fused CE is off. ``grad_accum`` is rounded UP, so
    the realized global batch is never BELOW the request. Example: batch 6 -> per_device 4 x
    ceil(6/4)=2 grad-accum = realized 8 (>= 6).

    ``seq_len``/``vocab``/``fused`` only feed the big-vocab logits cap. Left at their defaults
    (or with ``fused=True``), the result is the plain fixed per-device cap, so existing callers
    are unchanged.
    """
    requested = max(1, int(batch_size))
    micro = sft_per_device(requested, seq_len=seq_len, vocab=vocab, fused=fused)
    full_steps, leftover = divmod(requested, micro)
    # Round UP: a partial final micro-batch still costs a whole accumulation step.
    accum = full_steps + 1 if leftover else full_steps
    return micro, max(1, accum)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def sft_realized_batch(
    batch_size: int, *, seq_len: int = 0, vocab: int = 0, fused: bool = True
) -> int:
    """Return the realized SFT global batch (per_device x grad_accum) for a requested ``batch_size``.

    Mirrors the worker so the cost step-count matches what actually runs. Pass
    ``seq_len``/``vocab``/``fused`` to honor the big-vocab per-device cap, as the cost path
    does. Without them, this is the old fixed-cap behavior.
    """
    micro, accum = sft_grad_accum(batch_size, seq_len=seq_len, vocab=vocab, fused=fused)
    realized = micro * accum
    return realized
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# Colocated-GRPO vLLM KV pool: grows with the engine's max context (seq) and model
|
|
94
|
+
# width, but vLLM bounds the pool to a fraction of the card and PAGES rather than OOMs,
|
|
95
|
+
# so it's capped (_KV_CAP) instead of growing without bound at long context.
|
|
96
|
+
_KV_COEF = 2.0
|
|
97
|
+
_KV_CAP = 8.0
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def grpo_rollout_seq_len(
    max_length: int = 0,
    max_tokens: int | None = None,
    thinking: bool = False,
) -> int:
    """Return the vLLM engine context (tokens) a GRPO run ACTUALLY launches, mirroring run_rl().

    The result is ``[train].max_length`` when set. Otherwise it is
    ``max(1024, RLConfig.max_prompt_len + completion)``, where ``completion`` is
    ``[train].max_tokens`` or the recipe's thinking/non-thinking completion default.

    The allocator sizing, the sleep-mode resident gate, and the colocate KV budget all resolve
    the SAME length here. That way a run whose max_length is unset is not sized as a 1024-token
    rollout while the worker launches a ~2368-token (3584 with thinking) engine.
    """
    from flash.engine.recipe import RECIPE

    rl_cfg = RECIPE.rl
    if max_tokens:
        completion = int(max_tokens)
    elif thinking:
        completion = int(rl_cfg.max_completion_len_thinking)
    else:
        completion = int(rl_cfg.max_completion_len)
    if max_length:
        return int(max_length)
    return int(max(1024, rl_cfg.max_prompt_len + completion))
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _resident_kv_gb(params_b: float | None, vllm_max_len: int, num_generations: int = 8) -> float:
    """Return the KV cache (GB) a colocated rollout engine keeps RESIDENT for its context and group.

    The value grows with ``vllm_max_len``, because vLLM's cache blocks must span the max model
    length. It also grows with ``num_generations``, because that many sequences decode
    concurrently. Model width enters as sqrt(params).

    Unlike the sleep-mode rollout term in ``estimate_vram_gb``, this is NOT capped at
    ``_KV_CAP``: that cap applies only when the engine is offloaded during the backward. The
    resident-fit estimate and the non-sleep colocate budget both use this helper, so they size
    the SAME KV.

    An unknown ``params_b`` (None/0) is treated as 1B, and the width is floored at sqrt(0.1).
    Context and group are each floored at 1.
    """
    width = math.sqrt(max(float(params_b or 1.0), 0.1))
    ctx_scale = max(1, vllm_max_len) / 1024.0
    group_scale = max(1, num_generations) / 8.0
    return _KV_COEF * ctx_scale * width * group_scale
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def colocate_kv_util(
    params_b: float | None,
    vllm_max_len: int,
    total_vram_gb: float,
    sleep_mode: bool,
    num_generations: int = 8,
) -> float:
    """Return ``vllm_gpu_memory_utilization`` for the colocated GRPO rollout engine.

    The value is sized to the ACTUAL need rather than a blanket fraction of the card.

    vLLM's ``gpu_memory_utilization`` is its WHOLE model-executor budget: its own (2nd) bf16
    weight copy PLUS the KV cache. Both are budgeted here. Budgeting KV alone starves the weights
    and, for big models, under-sizes the engine. The KV need comes from ``_resident_kv_gb``,
    which scales with the engine context AND the concurrent generation group
    (``num_generations``).

    Resident (``sleep_mode=False``):
        weights + max(_KV_CAP, resident KV), as a fraction of the card, clamped to
        [0.10, 0.45]. This matches the resident-fit estimate (``estimate_vram_gb`` with
        ``sleep_offload=False``), so the sleep-mode gate and this budget size the SAME KV.
    Sleep mode:
        weights + max(_KV_CAP, 1.5 x resident KV), capped at 0.45 with no lower clamp. The
        engine is offloaded during the backward, so the bigger rollout-phase pool does not
        compete with the training peak.

    Measured: the old blanket sleep-path 0.45 reserved ~36 GB on an 80 GB A100. That was the
    dominant resident allocation setting the ~46 GB GRPO step peak. At 4B / group 8 / 2k ctx,
    0.25 util cut the peak from 46 to 26 GB, with byte-identical reward and neutral train_wall.
    A tighter 12 GB budget preempts, confirming this as the floor.

    Args:
        params_b: model size in billions (None/0 -> treated as 1B; weights floored at 0.5B).
        vllm_max_len: engine context in tokens.
        total_vram_gb: card VRAM in GB (floored at 1.0 to keep the division safe).
        sleep_mode: whether the engine is offloaded during the backward.
        num_generations: concurrent sequences per prompt group.

    Returns:
        The utilization fraction to hand to vLLM.
    """
    # vLLM's own bf16 weight copy lives inside the same budget as the KV cache.
    weights_gb = max(0.5, float(params_b or 1.0)) * 2.0
    if sleep_mode:
        # Larger pool (1.5x margin, floored at _KV_CAP). The engine is offloaded during the
        # backward, so a bigger rollout-phase KV never competes with the training peak.
        kv_pool_gb = max(_KV_CAP, 1.5 * _resident_kv_gb(params_b, vllm_max_len, num_generations))
        return min(0.45, (weights_gb + kv_pool_gb) / max(1.0, total_vram_gb))
    # Resident engine: the KV must span the rollout context for the whole group. A flat _KV_CAP
    # starves the cache blocks at long context. Budgeting KV without the weights made vLLM raise
    # "No available memory for the cache blocks" on >=3B models. Floor at _KV_CAP (the validated
    # short-context lean point) and clamp the resulting fraction to [0.10, 0.45].
    kv_gb = max(_KV_CAP, _resident_kv_gb(params_b, vllm_max_len, num_generations))
    return max(0.10, min(0.45, (weights_gb + kv_gb) / max(1.0, total_vram_gb)))
|
|
167
|
+
# GRPO backward (activations + fp32 logits over the completion micro-batch) per unit
|
|
168
|
+
# context x model width. Grad checkpointing makes this MILD in seq -- calibrated to
|
|
169
|
+
# measured boundaries: 0.8B GRPO fits 24 GB up to seq 32k (seq ~free), while 4.7B GRPO
|
|
170
|
+
# steps off a 32 GB card between seq 16k and 32k. group size scales it sublinearly.
|
|
171
|
+
_TRAIN_COEF = 0.27
|
|
172
|
+
# Fixed floor for colocated-vLLM GRPO: the vLLM engine's CUDA context + KV pool (sized to the
|
|
173
|
+
# CARD's VRAM via gpu_util, not the model) + the 2nd resident weight copy is ~model-independent
|
|
174
|
+
# for small models and dominates their param estimate, so tiny/mid models all need the 32 GB tier.
|
|
175
|
+
# MEASURED at the default group_size=8: 0.8B GRPO OOMs a 20 GB card; 2B GRPO OOMs a 24 GB card
|
|
176
|
+
# (-> both need 32); 4B GRPO fits 32 (param est ~31 already clears this floor, so it's untouched).
|
|
177
|
+
_VLLM_COLOCATE_FLOOR_GB = 28.0
|
|
178
|
+
# Fallback output vocab (lm_head / logits width) for estimate_vram_gb when no model vocab is
|
|
179
|
+
# passed; the model-aware path (model_required_vram_gb) resolves the real per-model value
|
|
180
|
+
# from flash.catalog via vocab_size_for(). Mirrors catalog._DEFAULT_VOCAB_SIZE.
|
|
181
|
+
_VOCAB_DEFAULT = 248_320
|
|
182
|
+
# Matches the worker's logits budget (6 GB): the per-device fp32 logits are capped to this
|
|
183
|
+
# (rl_per_device_comps spills the rest into grad-accum), so the estimator never reserves above it.
|
|
184
|
+
_LOGITS_BUDGET_GB = 6.0
|
|
185
|
+
|
|
186
|
+
# ---- SFT big-vocab logits: the SFT analog of the GRPO fp32-logits term above ----
|
|
187
|
+
# When the worker's fused cross-entropy (Liger) is OFF, an SFT forward materializes the FULL-sequence
|
|
188
|
+
# [per_device, seq_len, vocab] logits AND keeps their gradient live through the backward. At
|
|
189
|
+
# Qwen3.5's ~248k vocab this is the documented big-vocab SFT OOM driver (a 0.8B SFT OOM'd a 24 GB
|
|
190
|
+
# card). The worker fuses CE only for a >=3B model OR a >=2048-token context (mirrors
|
|
191
|
+
# engine.worker.perf._memory_mode); BELOW that the term is real and was previously ignored entirely.
|
|
192
|
+
# An SFT step holds, AT ONCE, the fp32 logits (4) + their fp32 grad (4) + the bf16 logits the model
|
|
193
|
+
# emits (2) + the bf16 grad (2) + the cross-entropy log_softmax temp (4) ~= 16 B/elem. (8 B/elem --
|
|
194
|
+
# fp32 logits+grad only -- UNDER-counted: a live 2B SFT seq1024 at per_device=2 peaked ~15.8 GiB and
|
|
195
|
+
# OOM'd a 16 GB card whose usable is ~15.6 GiB.) At 16 B/elem the per-device cap drops to 1 for a
|
|
196
|
+
# big-vocab un-fused SFT, so the worker materializes far less and the real peak clears even the
|
|
197
|
+
# tightest 16 GB card. The worker vocab-sizes the per-device micro-batch so these logits never
|
|
198
|
+
# exceed _LOGITS_BUDGET_GB while pd CAN be reduced; the estimator reserves the TRUE per-device-capped
|
|
199
|
+
# term (no budget clamp -- the irreducible pd=1 floor can exceed the budget at a near-2048 ctx) -- so
|
|
200
|
+
# the allocator provably covers the worker's real peak. VALIDATED by a live re-run.
|
|
201
|
+
_SFT_LOGITS_BYTES_PER_ELEM = 16.0
|
|
202
|
+
# Canonical fused-CE (Liger) gate thresholds: the worker fuses the SFT cross-entropy for a >=3B
|
|
203
|
+
# model OR a >=2048-token context. SINGLE SOURCE OF TRUTH -- engine.worker.perf imports these (its
|
|
204
|
+
# _LIGER_MIN_PARAMS / _LONG_CONTEXT_TOKENS derive from them) and sft_logits_fused mirrors the gate
|
|
205
|
+
# offline (no network AutoConfig probe) so the cost estimator stays deterministic.
|
|
206
|
+
_LIGER_MIN_PARAMS_B = 3.0
|
|
207
|
+
_LIGER_LONG_CTX_TOKENS = 2048
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def sft_logits_fused(params_b: float | None, seq_len: int) -> bool:
|
|
211
|
+
"""Whether the worker fuses the SFT cross-entropy (Liger), so the [per_device, seq, vocab] logits
|
|
212
|
+
never materialize. Mirrors engine.worker.perf._memory_mode without a network probe: fused for a
|
|
213
|
+
>=3B model OR a >=2048-token context. (The worker image bakes liger-kernel, so True here means
|
|
214
|
+
the fused kernel is actually used; if it were ever absent the per-device cap still bounds the
|
|
215
|
+
logits.)"""
|
|
216
|
+
if seq_len >= _LIGER_LONG_CTX_TOKENS:
|
|
217
|
+
return True
|
|
218
|
+
return (params_b or 0.0) >= _LIGER_MIN_PARAMS_B
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def sft_logits_per_device_cap(seq_len: int, vocab: int) -> int:
|
|
222
|
+
"""Largest SFT per-device micro-batch whose un-fused [per_device, seq, vocab] fp32 logits (+grad)
|
|
223
|
+
fit _LOGITS_BUDGET_GB. The SFT mirror of rl_per_device_comps' completion cap, sizing the FULL
|
|
224
|
+
sequence (not just the completion): the worker spills the remainder into grad-accum so the
|
|
225
|
+
realized global batch is unchanged, and the estimator reserves the same bounded term."""
|
|
226
|
+
denom = max(1, int(seq_len)) * max(1, int(vocab)) * _SFT_LOGITS_BYTES_PER_ELEM
|
|
227
|
+
return max(1, int(_LOGITS_BUDGET_GB * 1e9 / denom))
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def sft_per_device(batch_size: int, *, seq_len: int = 0, vocab: int = 0, fused: bool = True) -> int:
    """Return the per-device SFT micro-batch the worker runs.

    The requested global batch is capped at the micro-batch default (4). An ADDITIONAL limit
    from ``sft_logits_per_device_cap`` applies only when:
    - the fused CE is OFF (small model AND short context), and
    - both ``seq_len`` and ``vocab`` are given.
    That limit keeps the big-vocab [per_device, seq, vocab] logits from OOMing the card. With
    ``seq_len``/``vocab`` unset (or ``fused=True``), this is just the fixed cap. Always >= 1.
    """
    micro = max(1, min(_SFT_PER_DEVICE_BS_DEFAULT, max(1, int(batch_size))))
    if fused or not (seq_len and vocab):
        return micro
    return min(micro, sft_logits_per_device_cap(seq_len, vocab))
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def grpo_seq_escalation_gb(params_b: float | None, seq_len: int) -> int:
|
|
242
|
+
"""Extra GB a long-context GRPO run needs beyond its base footprint.
|
|
243
|
+
|
|
244
|
+
Big-model GRPO is tight: colocate holds 2 weight copies + a KV pool, so headroom shrinks
|
|
245
|
+
with model size and long context overflows it. Calibrated on a bf16 9.7B GRPO run (RunPod):
|
|
246
|
+
fits 80 GB to seq 4096 but OOMs at 8192. Safe headroom ~ 48500/params_b tokens; past that
|
|
247
|
+
escalate, STEEPER for bigger models. Applies to both catalog and open-model GRPO so neither
|
|
248
|
+
under-provisions.
|
|
249
|
+
"""
|
|
250
|
+
coef = 0.9
|
|
251
|
+
if not params_b:
|
|
252
|
+
return 0
|
|
253
|
+
seq_thresh = 48_500.0 / params_b
|
|
254
|
+
if seq_len <= seq_thresh:
|
|
255
|
+
return 0
|
|
256
|
+
return math.ceil(coef * params_b * (seq_len / seq_thresh - 1))
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def params_b_from_str(s: str | None) -> float | None:
|
|
260
|
+
"""Leading param count (billions) from a catalog ``params`` string, e.g.
|
|
261
|
+
"4.7B (text-only fine-tune)" -> 4.7, "9.7B (text-only fine-tune)" -> 9.7."""
|
|
262
|
+
if not s:
|
|
263
|
+
return None
|
|
264
|
+
m = re.search(r"([0-9]+(?:\.[0-9]+)?)\s*B", s)
|
|
265
|
+
return float(m.group(1)) if m else None
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
@dataclass(frozen=True)
|
|
269
|
+
class VramEstimate:
|
|
270
|
+
params_b: float | None
|
|
271
|
+
algorithm: str
|
|
272
|
+
quant: str
|
|
273
|
+
est_gb: float | None
|
|
274
|
+
gpu: str
|
|
275
|
+
gpu_gb: int
|
|
276
|
+
verdict: str # "fits" | "tight" | "too_big" | "unknown"
|
|
277
|
+
|
|
278
|
+
def describe(self) -> str:
|
|
279
|
+
if self.est_gb is None:
|
|
280
|
+
return f"{self.gpu}: VRAM need unknown (could not read model size)"
|
|
281
|
+
return (
|
|
282
|
+
f"{self.gpu} ({self.gpu_gb} GB): estimated ~{self.est_gb:.0f} GB needed "
|
|
283
|
+
f"({self.params_b:.1f}B params, {self.quant}, {self.algorithm}) -> {self.verdict}"
|
|
284
|
+
)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def estimate_vram_gb(
    params_b: float,
    algorithm: str,
    quant: str = "bf16",
    *,
    seq_len: int = 1024,
    max_tokens: int | None = None,
    lora_rank: int = 32,
    batch_size: int = 1,
    group_size: int = 8,
    thinking: bool = False,
    use_vllm: bool = True,
    vocab: int = _VOCAB_DEFAULT,
    sleep_offload: bool = True,
) -> float:
    """Estimate peak VRAM (GB) for a single-GPU LoRA job across the full knob matrix.

    Common terms (GB):
        weights      params x bytes/param (bf16/fp16 = 2; an unknown quant also uses 2)
        base         weights + fixed CUDA/framework/fragmentation overhead + LoRA adapter,
                     grads and Adam states (rank-linear, model-scaled)

    SFT (any ``algorithm`` other than "grpo"/"rl"):
        activations  grad-checkpointed, ~ per_device x seq x sqrt(params). per_device is the
                     worker's micro-batch (``sft_per_device``), NOT the global ``batch_size``,
                     which grad-accum realizes.
        logits       0 when the worker fuses CE (``sft_logits_fused``). Otherwise it is
                     per_device x seq_len x vocab x _SFT_LOGITS_BYTES_PER_ELEM, deliberately
                     not clamped to the logits budget.
        peak = base + activations + logits

    GRPO ("grpo"/"rl"):
        rollout      colocated vLLM 2nd weight copy + KV pool; 0 when ``use_vllm`` is False
                     (transformers generation, single weight copy). With ``sleep_offload``
                     the KV is ~ seq x sqrt(params), capped at _KV_CAP. Without it, the KV is
                     sized to the real context and generation group (``_resident_kv_gb``).
        train        backward activations (mild in seq; scaled by group size and thinking),
                     plus per_device=1 fp32 logits over the completion length, capped at the
                     logits budget.
        peak = base + max(rollout, train) with sleep offload (the phases do not overlap),
               or base + rollout + train when the engine stays resident.
    """
    weights = params_b * _BYTES_PER_PARAM.get(quant, 2.0)
    is_grpo = (algorithm or "").lower() in ("grpo", "rl")
    width = math.sqrt(max(params_b, 0.1))
    ctx = seq_len / 1024.0
    lora_opt = (lora_rank / 16.0) * (0.3 + 0.04 * params_b)
    base = weights + _BASE_OVERHEAD_GB + lora_opt

    if not is_grpo:
        # SFT: activations and logits both follow the worker's per-device micro-batch (the
        # same ``sft_per_device`` it runs), so the estimate tracks what actually executes.
        fused = sft_logits_fused(params_b, seq_len)
        pd = sft_per_device(batch_size, seq_len=seq_len, vocab=vocab, fused=fused)
        activations = _ACT_COEF * pd * ctx * width
        # Un-clamped on purpose. The budget only chooses pd; once pd floors at 1 the logits
        # are irreducible, and clamping would under-reserve and OOM.
        logits = 0.0 if fused else pd * seq_len * vocab * _SFT_LOGITS_BYTES_PER_ELEM / 1e9
        return base + activations + logits

    # ---- GRPO rollout phase: colocated vLLM 2nd weight copy + KV pool ----
    if not use_vllm:
        rollout = 0.0
    elif sleep_offload:
        # Engine offloaded during the backward, so the capped rollout KV never competes with
        # training. This is the allocator's estimate; keep it calibrated.
        rollout = weights + min(_KV_COEF * ctx * width, _KV_CAP)
    else:
        # Engine resident through the backward: its KV must cover the full rollout context for
        # the whole group, matching colocate_kv_util's non-sleep budget.
        rollout = weights + _resident_kv_gb(params_b, seq_len, group_size)

    # ---- GRPO train phase: backward activations + fp32 logits ----
    group_factor = max(1.0, (max(1, group_size) / 4.0) ** 0.5)
    think_factor = 1.3 if thinking else 1.0
    activations = _TRAIN_COEF * ctx * width * group_factor * think_factor
    # The worker memory-caps per_device (rl_per_device_comps), so the irreducible floor is the
    # per_device=1 logits over the completion length. That length is max_tokens, or the recipe
    # default of ~min(seq_len, 1024). The floor is capped at the logits budget.
    completion = max_tokens if max_tokens else min(seq_len, 1024)
    logits = min(completion * vocab * 4 / 1e9, _LOGITS_BUDGET_GB)
    train = activations + logits

    if sleep_offload:
        return base + max(rollout, train)
    # Resident engine: rollout and train allocations are live at the same time.
    return base + (rollout + train)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def grpo_fits_resident(
    model_id: str,
    *,
    seq_len: int = 1024,
    max_tokens: int | None = None,
    lora_rank: int = 32,
    group_size: int = 8,
    thinking: bool = False,
    card_vram_gb: float = 0.0,
    margin: float = 1.15,
) -> bool:
    """Report whether a colocated-vLLM GRPO run fits RESIDENT on a ``card_vram_gb`` card.

    Resident means no vLLM sleep-mode offload. The check compares the resident estimate
    (``estimate_vram_gb`` with ``sleep_offload=False``), scaled by ``margin``, against
    ``card_vram_gb``. When the run fits, sleep mode is unnecessary, and the worker can skip the
    sleep/wake cycle that stalls large-model GRPO rollouts.

    Conservative: an unknown/non-positive card size, a model missing from the catalog, or a
    catalog entry without a positive ``params_b`` all return False, keeping the memory-safe
    sleep default.
    """
    if not card_vram_gb or card_vram_gb <= 0:
        return False
    from flash.catalog import MODELS, vocab_size_for

    info = MODELS.get(model_id)
    if not info:
        return False  # not in the catalog (open-model path) -> keep the safe default
    params_b = float(getattr(info, "params_b", 0.0) or 0.0)
    if params_b <= 0:
        return False  # unknown size -> keep the safe default
    quant = getattr(info, "quant", "bf16") or "bf16"
    resident_gb = estimate_vram_gb(
        params_b,
        "grpo",
        quant,
        seq_len=max(1, int(seq_len or 1024)),
        max_tokens=max_tokens,
        lora_rank=lora_rank,
        group_size=group_size,
        thinking=thinking,
        use_vllm=True,
        vocab=vocab_size_for(model_id),
        sleep_offload=False,
    )
    needed_gb = resident_gb * margin
    return needed_gb <= card_vram_gb
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def model_required_vram_gb(
    model_id: str,
    algorithm: str,
    *,
    train=None,
    thinking: bool = False,
    headroom: float = 1.1,
) -> int:
    """Cheapest-sufficient VRAM (GB) for a specific run -- the matrix the allocator and
    ``provisional_gpu`` both size against.

    Catalog models size from their curated param count (``params_b``, else the ``params``
    display string -- the same resolution ``resolve_params_b`` uses) + the run's actual
    knobs (``train`` may be a TrainSpec, a dict, or None for recipe defaults). Curated GRPO
    floors (``grpo_min_vram_gb``) stay as HARD floors so we never under-provision a validated
    model; the matrix only ever sizes UP from there. Unlisted open models size from HF
    metadata, falling back to the 24 GB tier when the size can't be read.

    Args:
        model_id: catalog key / HF repo id of the policy model.
        algorithm: ``"grpo"`` / ``"rl"`` (case-insensitive) select GRPO sizing; anything
            else sizes as a non-GRPO run.
        train: optional run knobs (``max_tokens``, ``max_length``, ``lora_rank``,
            ``group_size``, ``batch_size``), read by key from a dict or by attribute from an
            object. Malformed values fall back to defaults instead of raising.
        thinking: forwarded to the rollout-length helper and the estimator.
        headroom: multiplicative safety margin applied to the raw estimate.

    Returns:
        Required VRAM in whole GB.
    """
    is_grpo = (algorithm or "").lower() in ("grpo", "rl")

    # Best-effort knob extraction: this provisional sizing runs at parse time BEFORE the
    # dedicated train validators, so malformed knobs (nan/inf/strings/<=0/huge) must fall
    # back to a default here, never crash -- config_schema raises the proper ConfigError next.
    def _g(obj, key):
        # Read ``key`` from a dict or an attribute-bearing object; None-safe.
        if obj is None:
            return None
        return obj.get(key) if isinstance(obj, dict) else getattr(obj, key, None)

    def _pos_int(v, default):
        # Coerce ``v`` to a positive int, else return ``default``.
        if isinstance(v, bool):
            # bool is an int subclass; True/False are never meaningful sizes.
            return default
        try:
            f = float(v)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: an int too large for a float (e.g. 10**400) is just another
            # malformed knob -- fall back instead of crashing at parse time.
            return default
        return int(f) if math.isfinite(f) and f >= 1 else default

    max_tokens = _pos_int(_g(train, "max_tokens"), None)
    # Default sequence length when [train].max_length is unset. For GRPO this must MIRROR what
    # run_rl() actually starts vLLM at — max(1024, RLConfig.max_prompt_len + completion) — not a
    # flat 1024, or the allocator can pick a GPU sized for 1024 tokens while the worker launches a
    # ~2368-token (3584 with thinking) engine and OOMs after provisioning. Completion = the run's
    # [train].max_tokens override, else the recipe's thinking/non-thinking completion default.
    if is_grpo:
        # Same engine context run_rl() launches (max_length, else max(1024, prompt+completion)) via
        # the shared helper, so the allocator and the worker never disagree on the rollout length.
        default_seq_len = grpo_rollout_seq_len(0, max_tokens, thinking)
    else:
        default_seq_len = 1024
    seq_len = _pos_int(_g(train, "max_length"), default_seq_len)
    lora_rank = _pos_int(_g(train, "lora_rank"), 32)
    group_size = _pos_int(_g(train, "group_size"), 8)
    # Default to the worker's per-device SFT micro-batch (4): an unset
    # [train].batch_size still realizes that micro-batch on the worker, so size for it
    # rather than 1 (which would under-route a long-context SFT run to a too-small card).
    batch_size = _pos_int(_g(train, "batch_size"), _sft_per_device_bs())

    def _need(
        params_b: float,
        algorithm: str,
        *,
        quant: str = "bf16",
        use_vllm: bool = True,
        vocab: int = _VOCAB_DEFAULT,
    ) -> int:
        # Estimate over the run's full knob matrix, then apply the safety headroom. Both the
        # catalog and open-model paths size through here so they stay in sync on the knob set.
        est = estimate_vram_gb(
            params_b,
            algorithm,
            quant,
            seq_len=seq_len,
            max_tokens=max_tokens,
            lora_rank=lora_rank,
            batch_size=batch_size,
            group_size=group_size,
            thinking=thinking,
            use_vllm=use_vllm,
            vocab=vocab,
        )
        return math.ceil(est * headroom)

    def _colocate_floor_gb(params_b) -> int:
        # vLLM-colocate floor: the engine (CUDA context + KV pool sized to the CARD's VRAM +
        # framework) + the 2nd resident weight copy add a ~constant the param estimate misses,
        # so small-model GRPO under-provisions. MEASURED at the default group_size=8: 0.8B GRPO
        # fits a 24 GB card but OOMs 20 (est ~18, ~6 GB headroom on 24); 2B GRPO OOMs a 24 GB
        # card (est ~20 but the colocate cost tips it past 24 -> needs the 32 tier). So sub-~1B
        # models floor at 24, while larger small-models that the param estimate still under-shoots
        # floor at the 32 tier. 4B+ already exceed this via their param estimate, so untouched.
        return 24 if (params_b or 0.0) <= 1.0 else int(_VLLM_COLOCATE_FLOOR_GB)

    from flash.catalog import MODELS, vocab_size_for

    info = MODELS.get(model_id)
    # Per-model output vocab (lm_head / logits width) sizes the fp32-logits term; resolved
    # from the catalog (curated value, else open-model fallback) instead of a hardcoded const.
    model_vocab = vocab_size_for(model_id)
    if info is not None:
        # Same size resolution as resolve_params_b (curated params_b, else the display
        # string) so the allocator sizes off the number the worker actually uses.
        params_b = getattr(info, "params_b", 0.0) or params_b_from_str(
            getattr(info, "params", None)
        )
        quant = getattr(info, "quant", "bf16") or "bf16"
        # GRPO always runs the rollout on a colocated vLLM engine, so sizing must reserve room for
        # the 2nd (rollout) weight copy on the same card.
        use_vllm = True
        need = _need(params_b or 4.0, algorithm, quant=quant, use_vllm=use_vllm, vocab=model_vocab)
        # Hard floor the param-based matrix can't see: a curated GRPO floor.
        floor = 0
        if is_grpo and getattr(info, "grpo_min_vram_gb", 0):
            floor = int(info.grpo_min_vram_gb)
        # Big-model GRPO is TIGHT at its floor (2 weight copies + KV pool), so long context
        # overflows it -> escalate to a bigger tier. See grpo_seq_escalation_gb.
        if is_grpo and floor:
            floor += grpo_seq_escalation_gb(params_b, seq_len)
        need = max(need, floor)
        if is_grpo and use_vllm:
            need = max(need, _colocate_floor_gb(params_b))
        return need

    # Unlisted open model: size from HF metadata (GRPO is the heavier phase).
    params_b = fetch_hf_params_b(model_id)
    if params_b is None:
        return 24
    # Open models size against the heavier GRPO phase regardless of the requested algorithm.
    need = _need(params_b, "grpo", vocab=model_vocab)
    if is_grpo:
        # Same long-context GRPO escalation as the catalog path so a big open model isn't
        # under-provisioned at long context either.
        need += grpo_seq_escalation_gb(params_b, seq_len)
        # ...and the same colocate floor: the engine/2nd-weight-copy cost is physical and does
        # not depend on whether the model is in the catalog.
        need = max(need, _colocate_floor_gb(params_b))
    return need
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
def fetch_hf_params_b(model_id: str) -> float | None:
|
|
544
|
+
"""Total params (billions) from the HF API safetensors metadata (no download).
|
|
545
|
+
|
|
546
|
+
Best-effort: returns ``None`` when the size can't be read (no network / no HF metadata),
|
|
547
|
+
so callers fall back to the offline heuristic rather than failing.
|
|
548
|
+
"""
|
|
549
|
+
try:
|
|
550
|
+
from huggingface_hub import HfApi
|
|
551
|
+
|
|
552
|
+
info = HfApi(token=os.environ.get("HF_TOKEN")).model_info(
|
|
553
|
+
model_id, expand=["safetensors"]
|
|
554
|
+
)
|
|
555
|
+
total = getattr(getattr(info, "safetensors", None), "total", None)
|
|
556
|
+
if total:
|
|
557
|
+
return float(total) / 1e9
|
|
558
|
+
except Exception:
|
|
559
|
+
# Best-effort size probe (network/HF-metadata may be unavailable); fall through
|
|
560
|
+
# to None so callers report "size unknown" rather than failing.
|
|
561
|
+
pass
|
|
562
|
+
return None
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
def resolve_params_b(model_id: str) -> float | None:
    """Return the model size in billions of parameters, or None when it can't be resolved.

    This is the single source of truth for "how big is this model": run_sft, run_rl and
    cost.spec all call it so they can never drift. Resolution order:

      1. the catalog entry's curated ``params_b``;
      2. else the catalog entry's ``params`` display string, via ``params_b_from_str``;
      3. else the HF safetensors param count via ``fetch_hf_params_b`` (best-effort).

    None means neither the catalog nor HF metadata yielded a size. That happens for an
    uncataloged model, or a cataloged entry carrying no size, when HF metadata is also
    unavailable. Callers then degrade to the size-unknown path (e.g. the fused-CE gate
    stays memory-safe, the colocate cap stays loose).
    """
    from flash.catalog import MODELS

    entry = MODELS.get(model_id)
    if entry is None:
        return fetch_hf_params_b(model_id)
    curated = getattr(entry, "params_b", 0.0)
    if not curated:
        curated = params_b_from_str(getattr(entry, "params", None))
    return curated if curated else fetch_hf_params_b(model_id)
|
|
581
|
+
|
|
582
|
+
|
|
583
|
+
def check_fit(
    model_id: str,
    algorithm: str,
    gpu: str,
    quant: str = "bf16",
    params_b: float | None = None,
) -> VramEstimate:
    """Estimate whether ``model_id`` plausibly trains on ``gpu``; never raises.

    Args:
        model_id: catalog key / HF repo id of the model.
        algorithm: training algorithm forwarded to ``estimate_vram_gb``.
        gpu: GPU name looked up in ``GPU_VRAM_GB``; unknown names assume a 32 GB card.
        quant: weight precision forwarded to ``estimate_vram_gb``.
        params_b: model size in billions; when None it is resolved via
            ``resolve_params_b`` (catalog first, then HF metadata).

    Returns:
        A ``VramEstimate``. The verdict is ``"unknown"`` when the size can't be resolved,
        ``"too_big"`` when the estimate exceeds the card by more than 15%, ``"tight"`` when
        it is above 85% of the card, and ``"fits"`` otherwise.
    """
    gpu_gb = GPU_VRAM_GB.get(gpu, 32)
    if params_b is None:
        # Catalog first, then HF -- the same resolution the worker and cost estimator use,
        # so a cataloged model gets a verdict offline instead of "unknown".
        params_b = resolve_params_b(model_id)
    if params_b is None:
        return VramEstimate(None, algorithm, quant, None, gpu, gpu_gb, "unknown")
    est = estimate_vram_gb(params_b, algorithm, quant)
    if est > gpu_gb * 1.15:
        verdict = "too_big"
    elif est > gpu_gb * 0.85:
        verdict = "tight"
    else:
        verdict = "fits"
    return VramEstimate(params_b, algorithm, quant, est, gpu, gpu_gb, verdict)
|