freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
flash/providers/_poll.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
"""Shared poll-loop scaffolding for provider job pollers.
|
|
2
|
+
|
|
3
|
+
Poll loops share a timestamped ``say()`` logger, a consecutive-poll-error retry/give-up
counter, and the heartbeat progress-surfacing block (key on (stage, step, ts), log
``worker: stage=… step=… reward=…``). The module also holds the ``flash-preload-*`` run-id
reap-deadline helpers used by the Lambda/Hyperstack orphan sweeps. Only those neutral pieces
live here; each poller keeps its own status/terminal handling inline.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import os
|
|
12
|
+
import re
|
|
13
|
+
import time
|
|
14
|
+
from collections.abc import Callable
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
# How long past a preload box's embedded wall-clock deadline an orphan sweep waits before reaping
# it. A healthy warm box bounds itself at its wall cap (the in-box timer ``os._exit``s) and its
# driver's ``finally`` terminates the instance. A box still alive this far past its deadline has
# lost its driver (the only thing that tears instance-provider boxes down), so it is orphaned and
# safe to reap. Deliberately generous (30 minutes) so clock skew, a slow teardown, or a
# near-deadline box still mid-download is never reaped early.
PRELOAD_REAP_GRACE_S = 30 * 60.0


def preload_instance_run_id(provider: str, region: str, reap_deadline_epoch: int, suffix: str) -> str:
    """Return a ``flash-preload-*`` run id carrying its wall-clock reap deadline as ``-d<epoch>-``.

    Embedding the deadline lets an orphan sweep decide whether a driver-lost warm box is due for
    reaping from its NAME alone, with no provider creation-time field needed.

    Args:
        provider: provider slug, placed verbatim after ``flash-preload-``.
        region: region name; lower-cased before embedding.
        reap_deadline_epoch: the box's wall-cap deadline in epoch seconds (truncated to an int).
        suffix: trailing disambiguator, appended verbatim.

    The layout must stay parseable by ``preload_box_reap_due``; change both together.
    """
    deadline = int(reap_deadline_epoch)
    lowered_region = region.lower()
    return f"flash-preload-{provider}-{lowered_region}-d{deadline}-{suffix}"


def preload_box_reap_due(name: str, now: float, grace_s: float = PRELOAD_REAP_GRACE_S) -> bool:
    """Return True when ``name`` embeds a reap deadline that passed more than ``grace_s`` ago.

    The deadline is the ``-d<epoch>-`` token written by ``preload_instance_run_id``, and ``now`` is
    in epoch seconds. The Lambda/Hyperstack orphan sweeps use this check. Warm boxes are normally
    driver-owned and exempt, but a driver that died before its ``terminate_run_instances``
    ``finally`` would leave the box billing forever; reaping past deadline+grace bounds that leak.

    Names WITHOUT a parseable deadline (legacy launches) return False. The unconditional
    driver-owned exemption still applies to them.
    """
    # Require 10+ digits so a short ``-d<n>-`` fragment elsewhere in the name (e.g. inside a
    # region segment) is never mistaken for the epoch token.
    found = re.search(r"-d(\d{10,})-", name)
    if found is None:
        return False
    reap_after = float(found.group(1)) + grace_s
    return now > reap_after
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def make_say(log) -> Callable[[str], None]:
    """Build a ``say(msg)`` callback that writes ``[HH:MM:SS] msg`` lines to ``log``.

    ``log`` is passed as ``print(file=...)`` and each line is flushed immediately. When ``log`` is
    None the returned callback does nothing.
    """

    def say(msg: str) -> None:
        if log is None:
            return
        stamp = time.strftime("%H:%M:%S")
        print(f"[{stamp}] {msg}", file=log, flush=True)

    return say
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class PollErrorTracker:
    """Count consecutive poll errors and decide when a poller should give up.

    Encapsulates the retry block the pollers share. On a transient fetch error it logs the error,
    gives up once ``max_errors`` consecutive failures have accumulated, and otherwise sleeps a
    linear backoff before the caller retries. The backoff is ``interval_s * consecutive_errors``,
    clamped to ``[0, max_backoff_s]``. Call ``reset()`` after a successful poll so only
    *consecutive* errors count.
    """

    def __init__(
        self,
        say: Callable[[str], None],
        interval_s: float,
        max_errors: int = 8,
        *,
        max_backoff_s: float = 60.0,
    ) -> None:
        self._say = say
        self._interval_s = interval_s
        self._max_errors = max_errors
        # Upper bound on a single backoff sleep (previously a hard-coded 60 s; default unchanged).
        self._max_backoff_s = max_backoff_s
        self._count = 0

    def reset(self) -> None:
        """Forget prior errors; call after any successful poll."""
        self._count = 0

    def record(self, exc: Exception) -> bool:
        """Register a poll error and report whether the caller should give up.

        Logs ``poll error (<n>): <exc>``. Returns True, without sleeping, once ``max_errors``
        consecutive errors have been recorded. Otherwise sleeps the backoff and returns False,
        meaning the caller should retry (typically ``continue``).
        """
        self._count += 1
        self._say(f"poll error ({self._count}): {exc}")
        if self._count >= self._max_errors:
            return True
        # Clamp at 0: time.sleep() raises ValueError for a negative duration, which would turn a
        # misconfigured interval into a crash inside the error handler itself.
        backoff = min(self._max_backoff_s, self._interval_s * self._count)
        time.sleep(max(0.0, backoff))
        return False
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _num(value: Any) -> float | None:
    """Coerce ``value`` to float, or return None when it is None or not numeric.

    Telemetry fields may be missing, non-numeric text, or outside float range. Every such case
    yields None, so the formatters below simply omit the field.
    """
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() of an int beyond float range (e.g. 10**400) raises it instead of
        # ValueError. One bad telemetry value must not crash heartbeat formatting.
        return None


def _fmt_float(value: Any, digits: int = 3) -> str | None:
    """``value`` rendered with ``digits`` decimals, or None when it is not numeric."""
    num = _num(value)
    return None if num is None else f"{num:.{digits}f}"


def _fmt_gb(value: Any) -> str | None:
    """``value`` as gigabytes with one decimal (``"12.3GB"``), or None when not numeric."""
    num = _num(value)
    return None if num is None else f"{num:.1f}GB"


def _fmt_pct(value: Any) -> str | None:
    """``value`` as a whole-number percentage (``"88%"``), or None when not numeric."""
    num = _num(value)
    return None if num is None else f"{num:.0f}%"


def _fmt_watts(value: Any) -> str | None:
    """``value`` as whole watts (``"249W"``), or None when not numeric."""
    num = _num(value)
    return None if num is None else f"{num:.0f}W"
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _short_process_name(name: str) -> str:
    """Final path component of a GPU process name, or ``"process"`` when that component is empty.

    Accepts None or non-string input (stringified) and ignores surrounding whitespace.
    """
    cleaned = str(name or "").strip()
    tail = os.path.basename(cleaned)
    if not tail:
        return "process"
    return tail
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _gpu_processes_part(processes: Any) -> str | None:
    """Render up to the first three GPU process entries as ``procs=name:pid:mem,...``.

    Non-dict entries among those three are skipped. Returns None when nothing renders.
    """
    if not isinstance(processes, list) or not processes:
        return None
    labels: list[str] = []
    for proc in processes[:3]:
        if not isinstance(proc, dict):
            continue
        label = _short_process_name(str(proc.get("process_name") or ""))
        pid = proc.get("pid")
        if pid is not None:
            label = f"{label}:{pid}"
        mem = _fmt_gb(proc.get("used_memory_gb"))
        if mem:
            label = f"{label}:{mem}"
        labels.append(label)
    return "procs=" + ",".join(labels) if labels else None


def _one_line(text: Any, limit: int = 160) -> str:
    """``str(text)`` with every whitespace run (newlines included) collapsed to one space.

    The result is truncated to ``limit`` characters.
    """
    return " ".join(str(text).split())[:limit]


def format_gpu_status(gpu: Any) -> str:
    """Summarize a GPU telemetry dict as `` gpu[...]`` for appending to a heartbeat log line.

    Reads optional keys: device/driver/CUDA identity, utilization, memory, PyTorch
    allocated/reserved memory, temperature, power, pstate, and up to three processes. Each field
    is emitted only when present and, for numeric fields, parseable. When none of them render,
    falls back to the raw ``nvidia_smi`` text (else ``nvidia_smi_err``), collapsed to one line.

    Returns "" for a non-dict or empty input, or when nothing at all is available. A non-empty
    result starts with a space so callers can concatenate it directly.
    """
    if not isinstance(gpu, dict) or not gpu:
        return ""
    parts: list[str] = []

    # Identity.
    name = gpu.get("device_name") or gpu.get("name")
    if name:
        parts.append(str(name))
    driver = gpu.get("driver_version")
    if driver:
        parts.append(f"driver={driver}")
    cuda = gpu.get("torch_cuda")
    if cuda:
        parts.append(f"cuda={cuda}")

    # Utilization.
    util = _fmt_pct(gpu.get("gpu_util_pct"))
    if util:
        parts.append(f"util={util}")
    mem_util = _fmt_pct(gpu.get("mem_util_pct"))
    if mem_util:
        parts.append(f"mem_util={mem_util}")

    # Device memory: prefer used/total, fall back to free/total.
    used = _fmt_gb(gpu.get("memory_used_gb"))
    total = _fmt_gb(gpu.get("memory_total_gb"))
    free = _fmt_gb(gpu.get("memory_free_gb"))
    if used and total:
        parts.append(f"mem={used}/{total}")
    elif free and total:
        parts.append(f"free={free}/{total}")

    # PyTorch view: allocated[/reserved].
    torch_alloc = _fmt_gb(gpu.get("torch_memory_allocated_gb"))
    if torch_alloc:
        torch_reserved = _fmt_gb(gpu.get("torch_memory_reserved_gb"))
        if torch_reserved:
            parts.append(f"torch={torch_alloc}/{torch_reserved}")
        else:
            parts.append(f"torch={torch_alloc}")

    # Thermals, power, performance state.
    temp = _num(gpu.get("temperature_c"))
    if temp is not None:
        parts.append(f"temp={temp:.0f}C")
    power = _fmt_watts(gpu.get("power_w"))
    if power:
        power_limit = _fmt_watts(gpu.get("power_limit_w"))
        if power_limit:
            parts.append(f"power={power}/{power_limit}")
        else:
            parts.append(f"power={power}")
    pstate = gpu.get("pstate")
    if pstate:
        parts.append(f"pstate={pstate}")

    procs = _gpu_processes_part(gpu.get("processes"))
    if procs:
        parts.append(procs)

    if not parts:
        # Nothing structured rendered: show the raw nvidia-smi output (or its error) instead.
        # Collapse it to a single line so a multi-line dump cannot break the one-line log entry.
        raw = gpu.get("nvidia_smi") or gpu.get("nvidia_smi_err")
        if raw:
            snippet = _one_line(raw)
            if snippet:
                parts.append(snippet)
    return " gpu[" + " ".join(parts) + "]" if parts else ""
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _format_heartbeat(hb: dict) -> str:
    """Render a heartbeat as ``worker: stage=<stage> step=… reward=… …`` plus any GPU summary.

    Optional fields are emitted in a fixed order and skipped when absent/None. Numeric values use
    the per-field precision below (``step`` as an integer); non-numeric values are shown verbatim.
    GPU telemetry comes from ``hb["gpu"]``, falling back to ``hb["diag"]``.
    """
    msg = f"worker: stage={hb.get('stage')}"
    for key, digits in (
        ("step", 0),
        ("epoch", 3),
        ("reward", 3),
        ("loss", 4),
        ("grad_norm", 3),
        ("learning_rate", 8),
        ("setup_seconds", 1),
        ("train_wall", 1),
    ):
        value = hb.get(key)
        if value is None:
            continue
        if not isinstance(value, (int, float)):
            msg += f" {key}={value}"
        elif digits == 0:
            try:
                msg += f" {key}={int(value)}"
            except (ValueError, OverflowError):
                # int() rejects NaN (ValueError) and +/-inf (OverflowError). Show the raw value
                # rather than raising out of the heartbeat-logging path.
                msg += f" {key}={value}"
        else:
            msg += f" {key}={value:.{digits}f}"
    msg += format_gpu_status(hb.get("gpu") or hb.get("diag"))
    return msg
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _record_heartbeat(hb: dict) -> None:
    """Best-effort: persist ``hb`` under its ``run_id`` via ``flash.runner.record_heartbeat``.

    Does nothing when the heartbeat has no non-blank ``run_id``. Any failure, including failing
    to import the runner, is swallowed: status persistence is diagnostic only, and polling and
    liveness must not depend on it.
    """
    run_id = str(hb.get("run_id") or "").strip()
    if run_id:
        try:
            from flash.runner import record_heartbeat

            record_heartbeat(run_id, hb)
        except Exception:
            pass
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def surface_heartbeat(
    heartbeat_reader: Callable[[], Any] | None,
    last_hb_key: tuple | None,
    say: Callable[[str], None],
) -> tuple[tuple | None, str | None]:
    """Read a heartbeat and, if it advanced, record it and log worker progress.

    Returns ``(hb_key, stage)``:

    - ``hb_key`` is the new ``(stage, step, ts)`` key, or the unchanged ``last_hb_key`` when
      nothing advanced.
    - ``stage`` is the new heartbeat's stage when it advanced, else None.

    Callers use the returned ``stage`` for their own setup-vs-training stall bookkeeping. A
    missing reader, a reader that raises, and an empty or malformed payload all count as "nothing
    advanced".
    """
    unchanged = (last_hb_key, None)
    if heartbeat_reader is None:
        return unchanged
    try:
        hb = heartbeat_reader()
    except Exception:
        return unchanged
    # Only a non-empty mapping-like payload is usable. Anything else (a list, a string, ...)
    # would crash on ``.get`` below even though reader failures are deliberately tolerated.
    if not hb or not hasattr(hb, "get"):
        return unchanged
    key = (hb.get("stage"), hb.get("step"), hb.get("ts"))
    if key == last_hb_key:
        return unchanged
    _record_heartbeat(hb)
    stage = hb.get("stage")
    say(_format_heartbeat(hb))
    return key, stage
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def heartbeat_progress_ts(hb_key: tuple | None, launch_ts: float | None) -> tuple[float, bool]:
    """Wall-clock to credit as 'last worker progress' for a just-surfaced heartbeat, plus whether
    that heartbeat actually belongs to THIS attempt.

    Uses the heartbeat's OWN ``ts`` (``hb_key[2]``, when the worker actually made progress), not
    the poll time. On a delayed reattach after a control-plane restart, a heartbeat that was
    already stale BEFORE the restart must not buy a fresh full stall window. Crediting the poll
    time would hand a hung worker another grace period while the instance keeps billing.

    The ts is clamped to ``[launch, now]`` so worker/control-plane clock skew can neither make a
    healthy worker look ancient (premature stall) nor land its progress in the future.

    Returns ``(ts, fresh)``. ``fresh`` is False when:

    - the heartbeat's ts predates this attempt's launch. That is a LEFTOVER heartbeat from a
      prior attempt (retries reuse the same seed heartbeat path), so the caller must NOT treat it
      as current progress. Otherwise a stale training-stage heartbeat would arm the tighter
      training stall window and fail a healthy new attempt mid-setup before it has overwritten
      the old file.
    - the ts is missing, unparseable, or NaN. In that case ``(now, False)`` is returned.

    ``launch_ts`` uses truthiness (not ``is not None``): the instance handles store started_ts as
    a non-Optional float coerced to 0.0 when missing, so 0.0 means "unknown launch" (a real launch
    is a large epoch ts). When the launch is UNKNOWN, heartbeats cannot be dated relative to it,
    so the clamp floor drops to 0.0 and every heartbeat counts as fresh. This is the safe default:
    don't discard progress we can't date. Clamping the floor to ``now`` instead would mark every
    normal heartbeat (timestamped before it is read) stale and stall a healthy recovered worker.
    """
    import math

    now = time.time()
    raw = hb_key[2] if (isinstance(hb_key, tuple) and len(hb_key) >= 3) else None
    try:
        ts = float(raw)
    except (TypeError, ValueError, OverflowError):
        return now, False
    if math.isnan(ts):
        # NaN compares False against everything, so it would fall through the clamp as the launch
        # floor. Treat it like any other undatable heartbeat.
        return now, False
    lo = float(launch_ts) if launch_ts else 0.0  # unknown launch -> floor 0.0 (all heartbeats fresh)
    fresh = ts >= lo
    return min(now, max(lo, ts)), fresh
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def surface_forced_heartbeat(
    heartbeat_reader: Callable[..., Any] | None,
    last_hb_key: tuple | None,
    say: Callable[[str], None],
) -> tuple[tuple | None, str | None]:
    """Surface the newest heartbeat via ``heartbeat_reader(force=True)``, bypassing reader rate limits.

    Intended for terminal provider statuses, so even a worker that fails fast leaves its last
    worker/GPU snapshot in the run log and status JSON. Same return contract as
    ``surface_heartbeat``.
    """
    if heartbeat_reader is not None:

        def forced_read():
            return heartbeat_reader(force=True)

        return surface_heartbeat(forced_read, last_hb_key, say)
    return last_hb_key, None
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
"""GPU allocation: the cheapest fitting class across the active providers.
|
|
2
|
+
|
|
3
|
+
Given a base model (+ algorithm), compute the VRAM the FULL run needs — sized for the
|
|
4
|
+
heavier phase, GRPO, since the typical pipeline is SFT followed by GRPO — then rank every
|
|
5
|
+
fitting candidate by $/hr and pick the cheapest:
|
|
6
|
+
|
|
7
|
+
runpod every validated Flash-provisionable class (static $/hr)
|
|
8
|
+
lambda every fitting class that currently has LIVE regional capacity (live $/hr); opt-in,
|
|
9
|
+
available only when LAMBDA_API_KEY is set on the control plane
|
|
10
|
+
hyperstack every fitting class whose single-GPU flavor currently has STOCK (static $/hr); opt-in,
|
|
11
|
+
available only when HYPERSTACK_API_KEY is set on the control plane
|
|
12
|
+
|
|
13
|
+
RunPod's cheaper static rates almost always win on price, so the instance providers (Lambda,
|
|
14
|
+
Hyperstack) join the ranked list as capacity COMPLEMENTS: when RunPod's cheapest fitting class is
|
|
15
|
+
out of capacity (THROTTLED / queue backstop), the runner's gpu-walk steps down the ranked list and
|
|
16
|
+
reaches an in-capacity instance class. Both instance providers are capacity-filtered up front
|
|
17
|
+
(``_lambda_candidates`` / ``_hyperstack_candidates`` only offer a class a region/flavor can supply
|
|
18
|
+
right now), so the walk never lands on a class that would just fail to launch.
|
|
19
|
+
|
|
20
|
+
Allocation happens at SUBMIT time in the runner. The parse-time resolution in schema is a
|
|
21
|
+
RunPod-static provisional for validation/dry-run display.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
from flash._logging import get_logger
|
|
27
|
+
from flash.providers import PROVIDER_NAMES, available_providers, get_provider
|
|
28
|
+
from flash.providers.base import (
|
|
29
|
+
Allocation,
|
|
30
|
+
Candidate,
|
|
31
|
+
UnsupportedGpuError,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# Module-level logger. The instance-provider candidate helpers warn through it when a capacity
# lookup fails and allocation proceeds without that provider.
logger = get_logger(__name__)
|
|
35
|
+
|
|
36
|
+
# Sizing headroom: a run is sized to the open-model VRAM estimate plus this headroom, so a full
# SFT+GRPO run never lands in check_fit's "tight" band by construction. Curated catalog entries
# already carry measured minimums and are used as-is. The default 1.1 equals
# model_required_vram_gb's own default. It is read at call time via vram_headroom() so allocate()
# and the parse-time provisional_gpu size identically.


def vram_headroom() -> float:
    """Return the VRAM sizing headroom multiplier (the constant ``1.1``).

    The submit-time allocator and the parse-time ``provisional_gpu`` both use it, so they never
    disagree about how much VRAM a run needs.
    """
    headroom = 1.1
    return headroom
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def required_vram_gb(
    model_id: str,
    algorithm: str,
    *,
    train=None,
    thinking: bool = False,
) -> int:
    """VRAM (GB) the full run needs, sized to the run's actual knobs.

    Delegates to the shared ``model_required_vram_gb`` matrix, passing context length, LoRA rank,
    batch/group size (via ``train``) and ``thinking``, with ``vram_headroom()`` applied.

    Catalog GRPO floors stay hard floors, so a validated model is never under-provisioned. The
    matrix sizes up from there for big contexts/groups, and down to a cheaper card for small runs.
    Unlisted open models size from HF metadata and fall back to the 24 GB tier when that is
    unreadable; this is handled inside model_required_vram_gb.
    """
    from flash.engine.vram import model_required_vram_gb

    headroom = vram_headroom()
    return model_required_vram_gb(
        model_id,
        algorithm,
        train=train,
        thinking=thinking,
        headroom=headroom,
    )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _runpod_candidates(need: int) -> list[Candidate]:
    """RunPod classes with at least ``need`` GB VRAM that are in the validated pool.

    Classes are priced from the provider's static hourly table. Only validated classes
    (``g.validated``) are offered: the deployed control plane rejects a submit for any
    non-validated class, so allocating one would only fail at submit time.
    """
    runpod = get_provider("runpod")
    fitting = []
    for gpu_class in runpod.gpu_classes():
        if gpu_class.vram_gb >= need and gpu_class.validated:
            rate = runpod.hourly_rate(gpu_class.name)
            fitting.append(Candidate("runpod", gpu_class.name, rate, gpu_class.vram_gb))
    return fitting
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _capacity_filtered_candidates(provider_name: str, need: int, usable_instances) -> list[Candidate]:
    """Fitting classes of instance provider ``provider_name`` that currently have capacity.

    ``usable_instances(gpu_name)`` is the provider's capacity probe: truthy when some
    region/flavor can supply the class right now. Classes below ``need`` GB VRAM are skipped
    without probing. Any exception while listing classes, probing capacity, or pricing degrades
    to ``[]`` with a warning, so allocation falls back to the other providers instead of failing.
    """
    provider = get_provider(provider_name)
    out: list[Candidate] = []
    try:
        for g in provider.gpu_classes():
            if g.vram_gb < need:
                continue
            # The capacity probes read a cached listing, so only the first call hits the API.
            if usable_instances(g.name):
                out.append(Candidate(provider_name, g.name, provider.hourly_rate(g.name), g.vram_gb))
    except Exception as exc:
        logger.warning(
            "%s capacity lookup failed (%s); allocating without %s", provider_name, exc, provider_name
        )
        return []
    return out


def _lambda_candidates(need: int) -> list[Candidate]:
    """Lambda's fitting classes that currently have LIVE capacity, priced live.

    Capacity-aware by design: a Lambda class with no region advertising capacity is EXCLUDED.
    The allocator therefore never hands the runner a Lambda class that would immediately fail to
    launch (and burn a retry). A capacity-lookup failure (no key / network blip) degrades to the
    other providers; it is non-fatal as long as another provider can supply a fitting class.
    """
    from flash.providers.lambdalabs.jobs import usable_instances

    return _capacity_filtered_candidates("lambda", need, usable_instances)


def _hyperstack_candidates(need: int) -> list[Candidate]:
    """Hyperstack's fitting classes that currently have flavor STOCK, priced statically.

    Capacity-aware, exactly like Lambda: a class with no in-stock flavor is excluded, so the
    runner never walks onto a class that would immediately fail to launch. A capacity-lookup
    failure degrades to the other providers.
    """
    from flash.providers.hyperstack.jobs import usable_instances

    return _capacity_filtered_candidates("hyperstack", need, usable_instances)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def allocate(
    model_id: str,
    algorithm: str,
    *,
    train=None,
    thinking: bool = False,
) -> Allocation:
    """Pick the cheapest fitting (provider, GPU class) able to run the job.

    There is no GPU pin: every fitting class on every available provider is eligible, and the
    cheapest wins.

    - RunPod is restricted to its validated pool (``GpuClass.validated``), because the deployed
      control plane rejects a submit for any non-validated class.
    - The instance providers (Lambda via LAMBDA_API_KEY, Hyperstack via HYPERSTACK_API_KEY, both
      opt-in) each contribute their fitting classes that currently have live capacity/stock.

    RunPod's cheaper static rates usually win, with Lambda and Hyperstack joining as capacity
    complements lower in the ranked list. ``train``/``thinking`` size the requirement to the run's
    actual knobs (context, group, rank, batch) via the matrix.

    Raises:
        UnsupportedGpuError: no available provider offers a class with enough VRAM.
    """
    need = required_vram_gb(model_id, algorithm, train=train, thinking=thinking)
    available = available_providers()

    # Gather in a fixed provider order; the stable sort below preserves it for full ties.
    gatherers = (
        ("runpod", _runpod_candidates),
        ("lambda", _lambda_candidates),
        ("hyperstack", _hyperstack_candidates),
    )
    candidates = []
    for provider_name, gather in gatherers:
        if provider_name in available:
            candidates.extend(gather(need))

    if not candidates:
        raise UnsupportedGpuError(
            f"no allocatable GPU (>= {need} GB VRAM for {model_id}) on any available provider "
            f"({', '.join(available) or '(none)'}); the run genuinely exceeds every active GPU class"
        )

    # Rank: cheapest first; equal rates prefer less VRAM (don't burn a big card on a small job),
    # then registry order.
    registry_rank = {name: idx for idx, name in enumerate(PROVIDER_NAMES)}

    def rank_key(cand):
        return (cand.hourly_usd, cand.vram_gb, registry_rank.get(cand.provider, 99))

    ranked = sorted(candidates, key=rank_key)
    cheapest = ranked[0]
    return Allocation(
        provider=cheapest.provider,
        gpu=cheapest.gpu,
        hourly_usd=cheapest.hourly_usd,
        min_vram_gb=need,
        candidates=tuple(ranked),
    )
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def allocation_summary(a: Allocation) -> str:
    """One-line human summary of an allocation.

    Covers the chosen GPU, provider, hourly rate and VRAM requirement, plus the next-best ranked
    candidate when there is one.
    """
    summary = (
        f"allocated {a.gpu} on {a.provider} at ${a.hourly_usd:.2f}/hr "
        f"(need >= {a.min_vram_gb} GB VRAM)"
    )
    if len(a.candidates) <= 1:
        return summary
    runner_up = a.candidates[1]
    return f"{summary}; next-best: {runner_up.gpu}@{runner_up.provider} ${runner_up.hourly_usd:.2f}/hr"
|