freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,311 @@
1
+ """Shared poll-loop scaffolding for provider job pollers.
2
+
3
+ Poll loops share a timestamped ``say()`` logger, a consecutive-poll-error retry/give-up
4
+ counter, and the heartbeat progress-surfacing block (key on (stage, step, ts), log
5
+ ``worker: stage=… step=… reward=…``). Only those neutral pieces live here; each poller
6
+ keeps its own status/terminal handling inline.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import os
12
+ import re
13
+ import time
14
+ from collections.abc import Callable
15
+ from typing import Any
16
+
17
# How long past a preload box's embedded wall deadline an orphan sweep waits before reaping it.
# A healthy warm box bounds itself at its wall cap (the in-box timer calls ``os._exit``), and the
# driver's ``finally`` terminates the instance. The driver is the only thing that tears instance
# providers down, so a box still alive this long after its deadline has lost its driver. It is
# therefore orphaned and safe to reap. The margin is deliberately generous (30 minutes) so clock
# skew, a slow teardown, or a near-deadline box mid-download is never reaped early.
PRELOAD_REAP_GRACE_S = 30 * 60.0
23
+
24
+
25
def preload_instance_run_id(provider: str, region: str, reap_deadline_epoch: int, suffix: str) -> str:
    """Return a ``flash-preload-*`` run id with its wall-clock reap deadline embedded as ``-d<epoch>-``.

    Embedding the deadline lets an orphan sweep decide from the instance NAME alone whether a
    driver-lost warm box is overdue, without needing a provider creation-time field.

    Args:
        provider: provider slug placed verbatim after ``flash-preload-``.
        region: region name; lower-cased before embedding.
        reap_deadline_epoch: the box's wall-cap deadline in epoch seconds (truncated via ``int``).
        suffix: trailing disambiguator appended last.

    The layout must stay in sync with the parser in ``preload_box_reap_due``; change both together.
    """
    segments = (
        "flash-preload",
        f"{provider}",
        region.lower(),
        f"d{int(reap_deadline_epoch)}",
        f"{suffix}",
    )
    return "-".join(segments)
32
+
33
+
34
def preload_box_reap_due(name: str, now: float, grace_s: float = PRELOAD_REAP_GRACE_S) -> bool:
    """Report whether a preload instance is past its embedded reap deadline plus ``grace_s``.

    The deadline is the first ``-d<epoch>-`` token in ``name`` (as written by
    ``preload_instance_run_id``). The Lambda/Hyperstack orphan sweeps use this check. Warm boxes
    are normally driver-owned and exempt, but a driver that died before its
    ``terminate_run_instances`` ``finally`` would leave one billing forever, and reaping after
    deadline + grace bounds that leak.

    Names without a parseable deadline (legacy launches) yield False, so the unconditional
    driver-owned exemption still covers them. Requiring 10+ digits keeps a region segment such as
    ``us-east-1`` from being mistaken for the deadline token.
    """
    match = re.search(r"-d(\d{10,})-", name)
    if match is None:
        return False
    deadline = float(match.group(1))
    return now > deadline + grace_s
47
+
48
+
49
def make_say(log) -> Callable[[str], None]:
    """Build a line logger that prefixes each message with an ``[HH:MM:SS]`` local timestamp.

    Each message is printed to ``log`` and flushed immediately. When ``log`` is None, the returned
    callable silently does nothing.
    """

    def say(msg: str) -> None:
        if log is None:
            return
        stamp = time.strftime("%H:%M:%S")
        print(f"[{stamp}] {msg}", file=log, flush=True)

    return say
57
+
58
+
59
class PollErrorTracker:
    """Track consecutive poll errors and decide when a poller should stop retrying.

    This is the retry block shared by the provider pollers. Every transient fetch error is logged
    through ``say``. Once ``max_errors`` consecutive failures have accumulated, ``record`` tells
    the caller to give up. Otherwise it sleeps a linear backoff (``interval_s * count``, capped at
    60 s) before the caller retries. ``reset`` clears the counter so that only *consecutive*
    failures count toward the limit.
    """

    def __init__(self, say: Callable[[str], None], interval_s: float, max_errors: int = 8) -> None:
        self._count = 0
        self._say = say
        self._interval_s = interval_s
        self._max_errors = max_errors

    def reset(self) -> None:
        """Zero the consecutive-error counter."""
        self._count = 0

    def record(self, exc: Exception) -> bool:
        """Log one poll error and report whether the caller should give up.

        Returns True once the consecutive-error count reaches ``max_errors`` (no sleep happens in
        that case). Otherwise it sleeps the capped linear backoff and returns False, meaning the
        caller should ``continue`` its loop.
        """
        self._count += 1
        attempt = self._count
        self._say(f"poll error ({attempt}): {exc}")
        exhausted = attempt >= self._max_errors
        if not exhausted:
            time.sleep(min(60, self._interval_s * attempt))
        return exhausted
85
+
86
+
87
+ def _num(value: Any) -> float | None:
88
+ try:
89
+ if value is None:
90
+ return None
91
+ return float(value)
92
+ except (TypeError, ValueError):
93
+ return None
94
+
95
+
96
def _fmt_float(value: Any, digits: int = 3) -> str | None:
    """Format ``value`` as a fixed-point string with ``digits`` decimals, or None if non-numeric."""
    number = _num(value)
    return None if number is None else f"{number:.{digits}f}"
101
+
102
+
103
def _fmt_gb(value: Any) -> str | None:
    """Format a gigabyte count as ``"<n.n>GB"``, or None if ``value`` is not numeric."""
    gigabytes = _num(value)
    return None if gigabytes is None else f"{gigabytes:.1f}GB"
108
+
109
+
110
def _fmt_pct(value: Any) -> str | None:
    """Format a percentage as a whole number with a ``%`` suffix, or None if non-numeric."""
    percent = _num(value)
    return None if percent is None else f"{percent:.0f}%"
115
+
116
+
117
def _fmt_watts(value: Any) -> str | None:
    """Format a power reading as whole watts with a ``W`` suffix, or None if non-numeric."""
    watts = _num(value)
    return None if watts is None else f"{watts:.0f}W"
122
+
123
+
124
+ def _short_process_name(name: str) -> str:
125
+ base = os.path.basename(str(name or "").strip())
126
+ return base or "process"
127
+
128
+
129
def format_gpu_status(gpu: Any) -> str:
    """Summarize a GPU telemetry dict as ``" gpu[...]"`` for appending to a heartbeat log line.

    Returns an empty string when ``gpu`` is not a non-empty dict or yields nothing printable.
    Recognized keys include the device name, driver/CUDA versions, utilization, memory (device
    and torch allocator), temperature, power, pstate, and up to three processes. When none of
    those are present, the first 160 characters of ``nvidia_smi`` (or ``nvidia_smi_err``) are
    shown instead.
    """
    if not (isinstance(gpu, dict) and gpu):
        return ""
    fields: list[str] = []

    device = gpu.get("device_name") or gpu.get("name")
    if device:
        fields.append(str(device))

    for label, key in (("driver", "driver_version"), ("cuda", "torch_cuda")):
        raw = gpu.get(key)
        if raw:
            fields.append(f"{label}={raw}")

    for label, key in (("util", "gpu_util_pct"), ("mem_util", "mem_util_pct")):
        pct = _fmt_pct(gpu.get(key))
        if pct:
            fields.append(f"{label}={pct}")

    used = _fmt_gb(gpu.get("memory_used_gb"))
    total = _fmt_gb(gpu.get("memory_total_gb"))
    free = _fmt_gb(gpu.get("memory_free_gb"))
    # Prefer used/total; fall back to free/total. Both forms need a known total.
    if total:
        if used:
            fields.append(f"mem={used}/{total}")
        elif free:
            fields.append(f"free={free}/{total}")

    torch_alloc = _fmt_gb(gpu.get("torch_memory_allocated_gb"))
    torch_reserved = _fmt_gb(gpu.get("torch_memory_reserved_gb"))
    if torch_alloc:
        reserved_part = f"/{torch_reserved}" if torch_reserved else ""
        fields.append(f"torch={torch_alloc}{reserved_part}")

    temp_c = _num(gpu.get("temperature_c"))
    if temp_c is not None:
        fields.append(f"temp={temp_c:.0f}C")

    watts = _fmt_watts(gpu.get("power_w"))
    watts_cap = _fmt_watts(gpu.get("power_limit_w"))
    if watts:
        fields.append(f"power={watts}/{watts_cap}" if watts_cap else f"power={watts}")

    pstate = gpu.get("pstate")
    if pstate:
        fields.append(f"pstate={pstate}")

    procs = gpu.get("processes")
    if isinstance(procs, list) and procs:
        labels: list[str] = []
        # Only the first three entries are shown to keep the log line short.
        for proc in procs[:3]:
            if not isinstance(proc, dict):
                continue
            label = _short_process_name(str(proc.get("process_name") or ""))
            pid = proc.get("pid")
            if pid is not None:
                label += f":{pid}"
            mem = _fmt_gb(proc.get("used_memory_gb"))
            if mem:
                label += f":{mem}"
            labels.append(label)
        if labels:
            fields.append("procs=" + ",".join(labels))

    if not fields:
        # Nothing structured was reported: surface raw nvidia-smi output (or its error) instead.
        for key in ("nvidia_smi", "nvidia_smi_err"):
            if gpu.get(key):
                fields.append(str(gpu[key])[:160])
                break

    if not fields:
        return ""
    return " gpu[" + " ".join(fields) + "]"
196
+
197
+
198
def _format_heartbeat(hb: dict) -> str:
    """Render a worker heartbeat as a ``worker: stage=... step=... reward=...`` log line.

    Numeric fields are printed with a fixed per-field precision; ``step`` is printed as an
    integer. Non-numeric values are printed verbatim, and fields that are absent or None are
    skipped. GPU telemetry from ``hb["gpu"]`` (falling back to ``hb["diag"]``) is appended via
    ``format_gpu_status``.
    """
    msg = f"worker: stage={hb.get('stage')}"
    for key, digits in (
        ("step", 0),
        ("epoch", 3),
        ("reward", 3),
        ("loss", 4),
        ("grad_norm", 3),
        ("learning_rate", 8),
        ("setup_seconds", 1),
        ("train_wall", 1),
    ):
        value = hb.get(key)
        if value is None:
            continue
        if isinstance(value, (int, float)):
            if digits == 0:
                try:
                    msg += f" {key}={int(value)}"
                except (ValueError, OverflowError):
                    # NaN / +-Infinity (legal in Python-emitted JSON) cannot be truncated to int.
                    # The caller formats outside any try, so raising here would abort the poll loop.
                    msg += f" {key}={value}"
            else:
                msg += f" {key}={value:.{digits}f}"
        else:
            msg += f" {key}={value}"
    msg += format_gpu_status(hb.get("gpu") or hb.get("diag"))
    return msg
222
+
223
+
224
+ def _record_heartbeat(hb: dict) -> None:
225
+ run_id = str(hb.get("run_id") or "").strip()
226
+ if not run_id:
227
+ return
228
+ try:
229
+ from flash.runner import record_heartbeat
230
+
231
+ record_heartbeat(run_id, hb)
232
+ except Exception:
233
+ # Status persistence is diagnostic only; polling/liveness must not depend on it.
234
+ pass
235
+
236
+
237
def surface_heartbeat(
    heartbeat_reader: Callable[[], Any] | None,
    last_hb_key: tuple | None,
    say: Callable[[str], None],
) -> tuple[tuple | None, str | None]:
    """Read a heartbeat and, if it advanced, record it and log worker progress.

    Returns ``(hb_key, stage)``. ``hb_key`` is the new ``(stage, step, ts)`` key, or the
    unchanged ``last_hb_key`` when nothing advanced. ``stage`` is the new heartbeat's stage when
    it advanced, else None. Callers use the returned ``stage`` for their own setup-vs-training
    stall bookkeeping.

    A missing reader, a reader that raises, an empty result, or a result that is not a dict
    (e.g. a heartbeat file that parsed to a list or string) all count as "no new heartbeat".
    A malformed diagnostic must not crash the poll loop.
    """
    if heartbeat_reader is None:
        return last_hb_key, None
    try:
        hb = heartbeat_reader()
    except Exception:
        # Reader failures are swallowed so polling never depends on heartbeat availability.
        hb = None
    # A non-dict payload has no ``.get``. Treat it like an absent heartbeat instead of letting
    # AttributeError escape into the poller.
    if not hb or not isinstance(hb, dict):
        return last_hb_key, None
    key = (hb.get("stage"), hb.get("step"), hb.get("ts"))
    if key == last_hb_key:
        return last_hb_key, None
    _record_heartbeat(hb)
    stage = hb.get("stage")
    say(_format_heartbeat(hb))
    return key, stage
264
+
265
+
266
def heartbeat_progress_ts(hb_key: tuple | None, launch_ts: float | None) -> tuple[float, bool]:
    """Return ``(ts, fresh)``: when to credit the last worker progress, and whether the heartbeat
    belongs to THIS attempt.

    The heartbeat's own ``ts`` (``hb_key[2]``, when the worker actually made progress) is credited
    rather than the poll time. After a delayed reattach following a control-plane restart, a
    heartbeat that was already stale before the restart must not buy a fresh full stall window.
    Crediting poll time would hand a hung worker another grace period while the instance keeps
    billing. The result is clamped to ``[launch, now]``, so worker/control-plane clock skew can
    neither make a healthy worker look ancient (premature stall) nor put its progress in the
    future.

    ``fresh`` is False when the heartbeat's ts predates this attempt's launch. That heartbeat is
    LEFTOVER from a prior attempt (retries reuse the same seed heartbeat path), so the caller must
    not treat it as current progress. Otherwise a stale training-stage heartbeat would arm the
    tighter training stall window and fail a healthy new attempt mid-setup, before it has
    overwritten the old file.

    ``launch_ts`` is tested for truthiness, not ``is not None``. Instance handles store the start
    time as a float coerced to 0.0 when missing, so 0.0 means "unknown launch" (a real launch is a
    large epoch). With an unknown launch the floor drops to 0.0 and every heartbeat counts as
    fresh; progress that cannot be dated is never discarded. Using ``now`` as the floor instead
    would mark every normal heartbeat (timestamped before it is read) stale and stall a healthy
    recovered worker.

    If the key has no usable numeric ts, this returns ``(now, False)``.
    """
    now = time.time()
    raw = hb_key[2] if isinstance(hb_key, tuple) and len(hb_key) >= 3 else None
    try:
        hb_ts = float(raw)
    except (TypeError, ValueError):
        return now, False
    floor = float(launch_ts) if launch_ts else 0.0
    credited = min(now, max(floor, hb_ts))
    return credited, hb_ts >= floor
297
+
298
+
299
def surface_forced_heartbeat(
    heartbeat_reader: Callable[..., Any] | None,
    last_hb_key: tuple | None,
    say: Callable[[str], None],
) -> tuple[tuple | None, str | None]:
    """Surface the latest heartbeat using ``heartbeat_reader(force=True)`` to bypass reader rate limits.

    Intended for terminal provider statuses, so that even a fast worker failure leaves its last
    worker/GPU snapshot in the run log and status JSON. Returns the same ``(hb_key, stage)`` pair
    as ``surface_heartbeat``.
    """
    if heartbeat_reader is None:
        return last_hb_key, None

    def forced_read() -> Any:
        return heartbeat_reader(force=True)

    return surface_heartbeat(forced_read, last_hb_key, say)
@@ -0,0 +1,193 @@
1
+ """GPU allocation: the cheapest fitting class across the active providers.
2
+
3
+ Given a base model (+ algorithm), compute the VRAM the FULL run needs — sized for the
4
+ heavier phase, GRPO, since the typical pipeline is SFT followed by GRPO — then rank every
5
+ fitting candidate by $/hr and pick the cheapest:
6
+
7
+ runpod every validated Flash-provisionable class (static $/hr)
8
+ lambda every fitting class that currently has LIVE regional capacity (live $/hr); opt-in,
9
+ available only when LAMBDA_API_KEY is set on the control plane
10
+ hyperstack every fitting class whose single-GPU flavor currently has STOCK (static $/hr); opt-in,
11
+ available only when HYPERSTACK_API_KEY is set on the control plane
12
+
13
+ RunPod's cheaper static rates almost always win on price, so the instance providers (Lambda,
14
+ Hyperstack) join the ranked list as capacity COMPLEMENTS: when RunPod's cheapest fitting class is
15
+ out of capacity (THROTTLED / queue backstop), the runner's gpu-walk steps down the ranked list and
16
+ reaches an in-capacity instance class. Both instance providers are capacity-filtered up front
17
+ (``_lambda_candidates`` / ``_hyperstack_candidates`` only offer a class a region/flavor can supply
18
+ right now), so the walk never lands on a class that would just fail to launch.
19
+
20
+ Allocation happens at SUBMIT time in the runner. The parse-time resolution in schema is a
21
+ RunPod-static provisional for validation/dry-run display.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ from flash._logging import get_logger
27
+ from flash.providers import PROVIDER_NAMES, available_providers, get_provider
28
+ from flash.providers.base import (
29
+ Allocation,
30
+ Candidate,
31
+ UnsupportedGpuError,
32
+ )
33
+
34
# Module logger; used below to warn when a Lambda/Hyperstack capacity lookup fails.
logger = get_logger(__name__)

# Sizing headroom ("comfortably fits"): the open-model VRAM estimate is scaled by the headroom
# multiplier so that a full SFT+GRPO run never lands in check_fit's "tight" band by construction.
# Curated catalog entries already carry measured minimums and are used as-is. The multiplier
# (1.1, stated to equal model_required_vram_gb's own default) is read at call time via
# vram_headroom(), so allocate() and the parse-time provisional_gpu always size identically.
41
+
42
+
43
def vram_headroom() -> float:
    """Return the VRAM sizing headroom multiplier (a fixed 1.1).

    Both the submit-time allocator and the parse-time provisional_gpu read this, so their sizing
    never disagrees.
    """
    headroom = 1.1
    return headroom
47
+
48
+
49
def required_vram_gb(
    model_id: str,
    algorithm: str,
    *,
    train=None,
    thinking: bool = False,
) -> int:
    """Return the VRAM (GB) the full run needs, sized to the run's actual knobs.

    Delegates to the shared ``model_required_vram_gb`` matrix, passing ``train`` (context length,
    LoRA rank, batch/group size) and ``thinking``, with ``vram_headroom()`` as the headroom.

    Catalog GRPO floors stay hard floors, so a validated model is never under-provisioned. From
    there the matrix sizes up for big contexts/groups and down to a cheaper card for small runs.
    Unlisted open models size from HF metadata, falling back to the 24 GB tier when unreadable
    (handled inside model_required_vram_gb).
    """
    from flash.engine.vram import model_required_vram_gb as vram_matrix

    sizing = {"train": train, "thinking": thinking, "headroom": vram_headroom()}
    return vram_matrix(model_id, algorithm, **sizing)
72
+
73
+
74
def _runpod_candidates(need: int) -> list[Candidate]:
    """Return RunPod GPU classes with at least ``need`` GB that are in the validated pool.

    Each class is priced with the provider's static ``hourly_rate``. Only validated classes
    (``g.validated``) are offered because the deployed control plane rejects a submit for any
    other class; allocating one would just fail at submit time.
    """
    runpod = get_provider("runpod")
    fitting: list[Candidate] = []
    for gpu_class in runpod.gpu_classes():
        if not (gpu_class.vram_gb >= need and gpu_class.validated):
            continue
        fitting.append(
            Candidate("runpod", gpu_class.name, runpod.hourly_rate(gpu_class.name), gpu_class.vram_gb)
        )
    return fitting
86
+
87
+
88
def _lambda_candidates(need: int) -> list[Candidate]:
    """Return Lambda GPU classes with at least ``need`` GB that currently have LIVE capacity.

    Each class is priced with the provider's live ``hourly_rate``. A class with no region
    advertising capacity (``usable_instances`` falsy) is excluded, so the runner is never handed
    a Lambda class that would immediately fail to launch and burn a retry.

    Any exception during the lookup (e.g. no key, network blip) is logged as a warning and yields
    an empty list. Allocation then proceeds with the other providers.
    """
    from flash.providers.lambdalabs.jobs import usable_instances

    lam = get_provider("lambda")
    try:
        # usable_instances reads the cached /instance-types, so only the first call hits the API.
        return [
            Candidate("lambda", g.name, lam.hourly_rate(g.name), g.vram_gb)
            for g in lam.gpu_classes()
            if not g.vram_gb < need and usable_instances(g.name)
        ]
    except Exception as exc:
        logger.warning("lambda capacity lookup failed (%s); allocating without lambda", exc)
        return []
112
+
113
+
114
def _hyperstack_candidates(need: int) -> list[Candidate]:
    """Return Hyperstack GPU classes with at least ``need`` GB that currently have flavor STOCK.

    Each class is priced with the provider's static ``hourly_rate``. As with Lambda, a class with
    no in-stock flavor (``usable_instances`` falsy) is excluded, so the runner never walks onto a
    class that would immediately fail to launch.

    Any exception during the lookup is logged as a warning and yields an empty list. Allocation
    then proceeds with the other providers.
    """
    from flash.providers.hyperstack.jobs import usable_instances

    hyperstack = get_provider("hyperstack")
    try:
        # usable_instances reads the cached /core/flavors, so only the first call hits the API.
        return [
            Candidate("hyperstack", g.name, hyperstack.hourly_rate(g.name), g.vram_gb)
            for g in hyperstack.gpu_classes()
            if not g.vram_gb < need and usable_instances(g.name)
        ]
    except Exception as exc:
        logger.warning("hyperstack capacity lookup failed (%s); allocating without hyperstack", exc)
        return []
136
+
137
+
138
def allocate(
    model_id: str,
    algorithm: str,
    *,
    train=None,
    thinking: bool = False,
) -> Allocation:
    """Choose the cheapest fitting (provider, GPU class) able to run the job.

    There is no GPU pin: every fitting class on every available provider is eligible.
      - RunPod offers only its validated pool (the control plane rejects other classes).
      - Lambda (LAMBDA_API_KEY) and Hyperstack (HYPERSTACK_API_KEY), both opt-in, offer only
        fitting classes that currently have live capacity/stock.

    ``train``/``thinking`` size the requirement to the run's actual knobs via the matrix.

    Candidates are ranked by $/hr, then by less VRAM (don't burn a big card on a small job), then
    by provider registry order (``PROVIDER_NAMES``; unknown providers sort last). The first-ranked
    candidate is returned together with the full ranked tuple.

    Raises:
        UnsupportedGpuError: when no available provider offers a fitting class.
    """
    need = required_vram_gb(model_id, algorithm, train=train, thinking=thinking)
    available = available_providers()

    collectors = (
        ("runpod", _runpod_candidates),
        ("lambda", _lambda_candidates),
        ("hyperstack", _hyperstack_candidates),
    )
    candidates: list[Candidate] = []
    for provider_name, collect in collectors:
        if provider_name in available:
            candidates.extend(collect(need))

    if not candidates:
        raise UnsupportedGpuError(
            f"no allocatable GPU (>= {need} GB VRAM for {model_id}) on any available provider "
            f"({', '.join(available) or '(none)'}); the run genuinely exceeds every active GPU class"
        )

    registry_rank = {name: idx for idx, name in enumerate(PROVIDER_NAMES)}

    def rank_key(c: Candidate):
        return (c.hourly_usd, c.vram_gb, registry_rank.get(c.provider, 99))

    ranked = sorted(candidates, key=rank_key)
    winner = ranked[0]
    return Allocation(
        provider=winner.provider,
        gpu=winner.gpu,
        hourly_usd=winner.hourly_usd,
        min_vram_gb=need,
        candidates=tuple(ranked),
    )
182
+
183
+
184
def allocation_summary(a: Allocation) -> str:
    """Describe an allocation in one line: chosen GPU/provider/rate, VRAM need, and runner-up.

    The ``next-best`` clause is included only when the allocation carries more than one ranked
    candidate.
    """
    summary = (
        f"allocated {a.gpu} on {a.provider} at ${a.hourly_usd:.2f}/hr "
        f"(need >= {a.min_vram_gb} GB VRAM)"
    )
    if len(a.candidates) <= 1:
        return summary
    runner_up = a.candidates[1]
    return summary + f"; next-best: {runner_up.gpu}@{runner_up.provider} ${runner_up.hourly_usd:.2f}/hr"