freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
flash/providers/base.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
1
|
+
"""Shared GPU-provider interface + GPU registry.
|
|
2
|
+
|
|
3
|
+
RunPod is the managed serverless GPU substrate; Lambda and Hyperstack are instance-based complements to it. This module owns the parts that are not specific to the
|
|
4
|
+
RunPod transport:
|
|
5
|
+
|
|
6
|
+
* ``GpuClass`` — one managed GPU class with its RunPod Flash identity.
|
|
7
|
+
* ``JobHandle`` / ``PollResult`` — the persisted-handle + poll-outcome shapes the
|
|
8
|
+
orchestrator round-trips through the provider.
|
|
9
|
+
* ``Candidate`` / ``Allocation`` — the allocation result.
|
|
10
|
+
* The canonicalization / alias / policy helpers every call site already used.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass, field
|
|
16
|
+
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from collections.abc import Callable
|
|
20
|
+
|
|
21
|
+
from flash.spec import JobSpec
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# ---------------------------------------------------------------------------
|
|
25
|
+
# GPU class registry (provider-agnostic identity)
|
|
26
|
+
# ---------------------------------------------------------------------------
|
|
27
|
+
@dataclass(frozen=True)
class GpuClass:
    """One managed GPU class: a friendly name plus per-provider identity and metadata.

    A class may be provisionable on RunPod (``enum_member``), Lambda (``lambda_name``) and/or
    Hyperstack (``hyperstack_name``); a ``None`` mapping means the class is not offered on that
    provider. Instances are frozen, so a registry row cannot be mutated after construction.
    """

    name: str  # canonical friendly name used in configs / the catalog
    enum_member: str | None  # runpod_flash GpuType member name; None -> not on RunPod
    vram_gb: int
    short: str  # endpoint-name-safe token (e.g. "4090", "a5000")
    sm: str  # CUDA arch (informational; sm80+ only)
    hourly_usd: float  # static rate used by pricing, cost projection, and ranking
    # Min host CUDA (driver) on the modern stack. None -> 12.8. Blackwell (sm120/sm100)
    # needs CUDA-13 drivers to JIT the wheels' PTX (no SASS shipped).
    min_cuda_modern: str | None = None
    # Whether this class has passed Flash's LIVE validation smoke (a real train+eval run on the
    # card). The deployed control plane REJECTS a submit for a non-validated class ("gpu type 'X'
    # has not passed Flash's live validation smoke"), so client-side allocation restricts to the
    # validated pool by default (see ``validated_classes`` / allocator) — otherwise a default
    # `flash train` could pick the absolute-cheapest fitting class (e.g. "L4") that the
    # server then refuses, and the run never submits. Exactly the smoke-validated members below
    # are marked True.
    validated: bool = False
    # Lambda Cloud instance-type name for this class (e.g. "gpu_1x_a10"); None -> not on Lambda.
    # Lambda is the instance-based complement to RunPod's serverless substrate: a class with a
    # ``lambda_name`` is provisionable on Lambda (capacity permitting), priced from Lambda's own
    # live ``/instance-types`` rate (NOT the RunPod ``hourly_usd`` snapshot above).
    lambda_name: str | None = None
    # Hyperstack single-GPU flavor name for this class (e.g. "n3-L40x1"); None -> not on Hyperstack.
    # Same instance-based model as Lambda (cloud-init -> Docker); a class with a ``hyperstack_name``
    # is provisionable on Hyperstack when its flavor has stock, priced from Hyperstack's static map.
    hyperstack_name: str | None = None
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# The shared GPU table: one ``GpuClass`` row per managed class. Positional arguments are, in
# order: name, enum_member, vram_gb, short, sm, hourly_usd.
# Static hourly rates are RunPod secure-cloud on-demand snapshots for the RunPod-backed rows.
# Rows with ``enum_member=None`` are not on RunPod, so their ``hourly_usd`` comes from another
# source (the L40 row uses the Hyperstack list price — see its comment).
GPU_CLASSES: tuple[GpuClass, ...] = (
    GpuClass(
        "RTX 4090",
        "NVIDIA_GEFORCE_RTX_4090",
        24,
        "4090",
        "sm89",
        0.69,
        validated=True,
    ),
    GpuClass(
        "RTX 5090",
        "NVIDIA_GEFORCE_RTX_5090",
        32,
        "5090",
        "sm120",
        0.99,
        min_cuda_modern="13.0",
        validated=True,
    ),
    # ---- Ampere/Ada workstation + datacenter cards (cheap capacity pools) ----
    # 24 GB is the floor: the sub-24 GB tiers (16 GB RTX A4000 / RTX 2000 Ada, 20 GB RTX A4500 /
    # RTX 4000 Ada) were dropped — the 24 GB classes below are the smallest managed cards.
    # (RTX 3090 was removed from the catalog — see git history.)
    GpuClass("L4", "NVIDIA_L4", 24, "l4", "sm89", 0.39),
    # Lambda-only 24 GB Ampere datacenter card (RunPod has no A10). Instance-based capacity
    # complement: chosen by the allocator only when the cheaper RunPod 24 GB classes are out of
    # capacity, so it never undercuts RunPod on price.
    GpuClass("A10", None, 24, "a10", "sm86", 1.29, lambda_name="gpu_1x_a10"),
    # Live-validated 2026-06-22: Qwen3.5-0.8B/9B SFT+GRPO train smokes (RunPod). The 48 GB tier that
    # fills the 32->80 GB gap (e.g. 4B GRPO @ 35 GB) ~55% cheaper than the A100.
    GpuClass(
        "RTX A6000", "NVIDIA_RTX_A6000", 48, "a6000", "sm86", 0.49,
        validated=True,
        lambda_name="gpu_1x_a6000",
        hyperstack_name="n3-RTX-A6000x1",
    ),
    GpuClass("A40", "NVIDIA_A40", 48, "a40", "sm86", 0.44),
    GpuClass(
        "RTX 6000 Ada",
        "NVIDIA_RTX_6000_ADA_GENERATION",
        48,
        "6000ada",
        "sm89",
        0.77,
    ),
    # L40 48 GB (Ada, sm89): datacenter card on Hyperstack (+ Nebius/DO/Vultr/Scaleway), NOT on
    # RunPod or Lambda. Hyperstack-only here. hourly_usd is the Hyperstack list price.
    GpuClass("L40", None, 48, "l40", "sm89", 1.00, hyperstack_name="n3-L40x1"),
    # Lambda-only 40 GB A100 (SXM4) — RunPod's A100s are all 80 GB, so this fills the 32->80 GB gap
    # on Lambda (e.g. a 4B GRPO at ~35 GB) as an instance-based capacity complement.
    GpuClass(
        "A100 SXM 40GB", None, 40, "a100sxm40", "sm80", 1.99, lambda_name="gpu_1x_a100_sxm4"
    ),
    # ---- big-VRAM tier (9B bf16 GRPO, future >9B bf16) ----
    # Validated 2026-06-11: 0.6B SFT smoke (phase6).
    GpuClass(
        "A100 PCIe",
        "NVIDIA_A100_80GB_PCIe",
        80,
        "a100pcie",
        "sm80",
        1.39,
        validated=True,
        hyperstack_name="n3-A100x1",
    ),
    # Live-validated 2026-06-22: Qwen3.5 0.8B/MiniCPM/2B/9B SFT+GRPO train smokes (RunPod).
    GpuClass(
        "A100 SXM", "NVIDIA_A100_SXM4_80GB", 80, "a100sxm", "sm80", 1.49,
        validated=True,
    ),
    # Live-validated 2026-06-22: MiniCPM/2B/4B SFT+GRPO train smokes (RunPod).
    GpuClass(
        "H100", "NVIDIA_H100_80GB_HBM3", 80, "h100", "sm90", 3.29,
        validated=True,
        lambda_name="gpu_1x_h100_pcie",
        hyperstack_name="n3-H100x1",
    ),
    # Live-validated 2026-06-22: MiniCPM/2B/4B SFT+GRPO train smokes (RunPod, sm120/CUDA-13).
    GpuClass(
        "RTX Pro 6000",
        "NVIDIA_RTX_PRO_6000_BLACKWELL_SERVER_EDITION",
        96,
        "pro6000",
        "sm120",
        2.09,
        min_cuda_modern="13.0",
        validated=True,
        # NOT mapped to Hyperstack: this Blackwell class needs a CUDA-13 host driver, but Hyperstack
        # only ships up to CUDA-12.8 (R570) images — it would boot then fail at worker setup. Re-add
        # ``hyperstack_name="n3-RTX-PRO6000-SEx1"`` once a CUDA-13 Hyperstack image is available.
    ),
)
|
|
153
|
+
|
|
154
|
+
# Friendly name -> class lookup built from the table, in table order. Two rows sharing a name
# would silently collapse to the later one, so names in ``GPU_CLASSES`` must be unique.
GPU_INFO: dict[str, GpuClass] = {g.name: g for g in GPU_CLASSES}

# Canonical friendly names Flash exposes in configs / the catalog.
KNOWN: tuple[str, ...] = tuple(GPU_INFO)

# The names that have passed Flash's live validation smoke. Client-side allocation restricts to
# these by default because the deployed control plane REJECTS a submit for any class outside the
# pool ("gpu type 'X' has not passed Flash's live validation smoke"); allocating an unvalidated
# (e.g. absolute-cheapest) class would just make the server refuse the run. Derived directly from
# the ``validated=True`` flags, so it cannot drift from the table.
VALIDATED: tuple[str, ...] = tuple(g.name for g in GPU_CLASSES if g.validated)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _alias_keys(name: str) -> set[str]:
|
|
168
|
+
"""All accepted spellings of a friendly name (lowercased)."""
|
|
169
|
+
base = name.lower()
|
|
170
|
+
keys = {base, base.replace(" ", ""), base.replace(" ", "_"), base.replace(" ", "-")}
|
|
171
|
+
if base.startswith("rtx "):
|
|
172
|
+
tail = base[4:]
|
|
173
|
+
keys |= {tail, tail.replace(" ", ""), tail.replace(" ", "_")}
|
|
174
|
+
keys.add(f"nvidia {base}")
|
|
175
|
+
return keys
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
# Lowercased spelling -> canonical friendly name; the lookup table behind ``canonical_gpu``.
_ALIASES: dict[str, str] = {}
# First the generically derived spellings of every registered class...
for _info in GPU_INFO.values():
    for _k in _alias_keys(_info.name):
        _ALIASES[_k] = _info.name
# Spellings that don't fall out of the generic rules: full marketing names (what
# nvidia-smi / the RunPod API print) and historical Flash aliases. Applied after the generated
# keys, so an entry here wins over a generated spelling of the same key.
_ALIASES.update(
    {
        "nvidia geforce rtx 4090": "RTX 4090",
        "nvidia geforce rtx 5090": "RTX 5090",
        "nvidia l4": "L4",
        "nvidia a40": "A40",
        "nvidia rtx 6000 ada generation": "RTX 6000 Ada",
        "rtx 6000 ada generation": "RTX 6000 Ada",
        "nvidia a100 80gb pcie": "A100 PCIe",
        "a100 80gb pcie": "A100 PCIe",
        "a100-80g-pcie": "A100 PCIe",
        "nvidia a100-sxm4-80gb": "A100 SXM",
        "a100-sxm4-80gb": "A100 SXM",
        # A bare "a100" resolves to the PCIe class (not "A100 SXM" / "A100 SXM 40GB").
        "a100": "A100 PCIe",
        "nvidia h100 80gb hbm3": "H100",
        "h100 80gb hbm3": "H100",
        "rtx pro 6000 blackwell": "RTX Pro 6000",
        "nvidia rtx pro 6000 blackwell server edition": "RTX Pro 6000",
    }
)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class UnsupportedGpuError(ValueError):
    """Raised when a GPU name or VRAM requirement matches no managed GPU class."""
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def canonical_gpu(name: str) -> str:
    """Resolve any accepted spelling of a GPU name to its canonical ``KNOWN`` name.

    Matching ignores case and surrounding whitespace; a falsy ``name`` is treated as ``""``.

    Raises:
        UnsupportedGpuError: if the spelling is not a registered alias.
    """
    lookup = (name or "").strip().lower()
    resolved = _ALIASES.get(lookup)
    if resolved is None:
        raise UnsupportedGpuError(
            f'unsupported gpu {name!r}; Flash manages {", ".join(KNOWN)}'
        )
    return resolved
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def get_gpu_info(name: str) -> GpuClass:
    """Return the registry row for ``name`` (any accepted spelling).

    Raises:
        UnsupportedGpuError: propagated from ``canonical_gpu`` for an unknown name.
    """
    friendly = canonical_gpu(name)
    return GPU_INFO[friendly]
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def providers_for(name: str) -> tuple[str, ...]:
    """List the providers able to provision this GPU class.

    A provider is included when the class carries a truthy identity for it. The order is fixed:
    ``"runpod"``, then ``"lambda"``, then ``"hyperstack"``.
    """
    gpu = get_gpu_info(name)
    identities = (
        ("runpod", gpu.enum_member),
        ("lambda", gpu.lambda_name),
        ("hyperstack", gpu.hyperstack_name),
    )
    return tuple(provider for provider, ident in identities if ident)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def gpu_short(name: str) -> str:
    """Return the GPU's short, endpoint-name-safe token (e.g. ``'4090'``)."""
    gpu = get_gpu_info(name)
    return gpu.short
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def min_cuda_modern(name: str) -> str:
    """Return the minimum host CUDA (driver) version for this class on the modern stack.

    Classes that do not pin a version (``min_cuda_modern`` unset or empty) default to ``"12.8"``.
    """
    pinned = get_gpu_info(name).min_cuda_modern
    if pinned:
        return pinned
    return "12.8"
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def cheapest_gpu(min_vram_gb: int) -> str:
    """Return the cheapest validated RunPod GPU class with at least ``min_vram_gb`` GB of VRAM.

    Only classes that have a RunPod identity (``enum_member``) and are smoke-validated are
    considered: RunPod-static by design, so the result is always deployable via Flash and offline
    resolution stays deterministic, and validated-only so the pick matches what the deployed
    control plane will accept. Equal hourly rates are broken in favour of the smaller VRAM size.

    Raises:
        UnsupportedGpuError: if no validated RunPod class is large enough.
    """
    eligible = []
    for gpu in GPU_INFO.values():
        if not gpu.enum_member:
            continue  # not provisionable on RunPod
        if not gpu.vram_gb >= min_vram_gb:
            continue  # too small for the request
        if gpu.validated:
            eligible.append(gpu)

    if not eligible:
        raise UnsupportedGpuError(
            f"no validated RunPod-provisionable GPU class has >= {min_vram_gb} GB VRAM"
        )

    # Imported lazily, and only once a candidate exists, exactly as the lookup is needed.
    from flash.providers.runpod.pricing import hourly_rate

    def _rank(gpu):
        # Cheapest first; smaller card wins a price tie.
        return hourly_rate(gpu.name), gpu.vram_gb

    winner = min(eligible, key=_rank)
    return winner.name
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def provisional_gpu(
    model_id: str,
    algorithm: str = "sft",
    *,
    train=None,
    thinking: bool = False,
) -> str:
    """Pick a parse-time provisional GPU: the cheapest VALIDATED class whose VRAM covers the model.

    There is no GPU pinning; the required VRAM is computed for the model/algorithm and handed to
    ``cheapest_gpu``, which restricts the choice to validated RunPod-provisionable classes so the
    provisional matches what the deployed control plane will accept. The submit-time allocator
    (``flash.providers.allocator``) always re-resolves the cheapest fitting validated class; this
    is the RunPod-static, offline-deterministic equivalent the schema uses for sizing/display.

    Raises:
        UnsupportedGpuError: propagated from ``cheapest_gpu`` when no class is large enough.
    """
    from flash.engine.vram import model_required_vram_gb
    from flash.providers.allocator import vram_headroom

    # Honor FLASH_VRAM_HEADROOM so parse-time sizing matches the submit-time allocator exactly.
    sizing = {
        "train": train,
        "thinking": thinking,
        "headroom": vram_headroom(),
    }
    required_gb = model_required_vram_gb(model_id, algorithm, **sizing)
    return cheapest_gpu(required_gb)
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
# ---------------------------------------------------------------------------
|
|
298
|
+
# Handles + poll outcomes (round-tripped through any provider)
|
|
299
|
+
# ---------------------------------------------------------------------------
|
|
300
|
+
@dataclass
class JobHandle:
    """Provider-tagged, persisted handle: enough to reattach/cancel from any process.

    The provider owns the rest of its handle shape (RunPod: endpoint_id/job_id). ``provider`` is
    the routing key the orchestrator uses to dispatch poll/cancel/destroy through the registry.
    """

    provider: str
    data: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Flatten the handle into one dict: ``"provider"`` first, then the provider's own keys.

        ``self.provider`` is authoritative. A stray ``"provider"`` key inside ``data`` must not
        override the routing key, or the handle would round-trip to a different provider.
        """
        flat = {"provider": self.provider, **self.data}
        flat["provider"] = self.provider  # re-assert the routing key over any ``data`` entry
        return flat

    @classmethod
    def from_dict(cls, d: dict) -> JobHandle:
        """Rebuild a handle from its flattened form without mutating ``d``.

        A dict with no ``"provider"`` key is treated as a RunPod handle.
        """
        payload = dict(d)  # copy: ``pop`` below must not touch the caller's dict
        provider = payload.pop("provider", "runpod")
        return cls(provider=provider, data=payload)
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
@dataclass
class PollResult:
    """Outcome a provider returns from ``submit_run`` / ``poll``.

    ``failure`` carries one of the category strings documented below (``None`` by default) and
    ``detail`` an optional free-form explanation. NOTE(review): ``ok`` presumably marks a
    successful run and ``metrics`` its result payload — confirm against the providers.
    """

    ok: bool
    metrics: dict | None = None
    # "job_failed"    : genuine worker/job code error (NOT retried)
    # "job_preempted" : provider killed the worker (platform termination) -> infra-shaped, retried
    # "no_capacity"   : NEVER scheduled — no provider capacity for the pinned GPU class (job sat
    #                   IN_QUEUE / the only worker stayed THROTTLED) -> infra-shaped, retried on the
    #                   next-best GPU. Distinct from "stalled" (a worker WAS scheduled then stopped
    #                   making progress) so the terminal message points at capacity, not worker health.
    # "stalled"       : a scheduled worker made no progress within the budget -> infra-shaped, retried
    # "poll_error"    : client-side polling / deploy breakdown -> infra-shaped, retried
    failure: str | None = None
    detail: str | None = None
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
# ---------------------------------------------------------------------------
|
|
338
|
+
# Allocation result
|
|
339
|
+
# ---------------------------------------------------------------------------
|
|
340
|
+
@dataclass(frozen=True)
class Candidate:
    """One immutable (provider, GPU class) option, with its hourly rate and VRAM size."""

    provider: str
    gpu: str
    hourly_usd: float
    vram_gb: int
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
@dataclass(frozen=True)
class Allocation:
    """Immutable allocation result: the chosen provider/GPU plus the full ranked candidate list.

    NOTE(review): ``min_vram_gb`` is presumably the VRAM requirement the allocation was sized
    for — confirm against ``flash.providers.allocator``.
    """

    provider: str
    gpu: str
    hourly_usd: float
    min_vram_gb: int
    candidates: tuple[Candidate, ...]  # full ranked list (retry walks this)
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
# ---------------------------------------------------------------------------
|
|
358
|
+
# The provider interface (FIXED method set every provider implements)
|
|
359
|
+
# ---------------------------------------------------------------------------
|
|
360
|
+
@runtime_checkable
class Provider(Protocol):
    """The GPU-substrate interface every provider package implements (RunPod, Lambda, Hyperstack).

    This is a structural ``Protocol``: an implementation does not need to subclass it. Because
    it is ``runtime_checkable``, ``isinstance(obj, Provider)`` only checks that the members
    exist, not their signatures.
    """

    name: str

    def is_configured(self) -> bool:
        """Whether this provider is usable right now (creds present, net reachable)."""
        ...

    def preflight(self, require_hf: bool = True) -> list[str]:
        """Missing-config problems (empty list == ready). The control plane aggregates
        these into one fail-fast error at startup."""
        ...

    def gpu_classes(self) -> list[GpuClass]:
        """The GPU classes this provider can provision (its rows of the shared table)."""
        ...

    def hourly_rate(self, gpu: str) -> float:
        """Static $/hr for one friendly GPU name."""
        ...

    def submit_run(
        self,
        spec: JobSpec,
        seed: int,
        *,
        log: Any = None,
        on_handle: Any = None,
        attempt: int = 0,
        runtime_secrets: dict[str, str] | None = None,
        on_last_gpu: bool = False,
    ) -> PollResult:
        """Deploy/rent -> submit -> persist handle (via ``on_handle``) -> poll.

        ``on_last_gpu`` is True when no further GPU attempt will be made after this one — either the
        candidate list is exhausted or the retry budget is exhausted — so there is no next-best class to fall
        to and capacity backstops should wait longer before giving up.
        """
        ...

    def poll(self, handle: JobHandle, spec: JobSpec, seed: int, *, log: Any = None) -> PollResult:
        """Reattach to a persisted handle and poll it to a terminal state."""
        ...

    def cancel(self, handle: JobHandle) -> None:
        """Stop the exact remote worker for this handle (cross-process)."""
        ...

    def destroy(self, handle: JobHandle) -> None:
        """Tear down the billable resource this handle owns (idempotent)."""
        ...

    def gc(self, spec: JobSpec) -> None:
        """Best-effort: reap any resource this run may have left registered."""
        ...

    def sweep_orphans(
        self, active_labels: set[str] | Callable[[], set[str]] | None = None
    ) -> list[int | str]:
        """Destroy any billable resource this provider owns that no live run claims.

        Crash recovery: run at server startup (and after runs). ``active_labels`` is the set of
        RAW run ids still owned by live runs — each instance provider derives its own instance-label
        prefix from them via ``run_label_prefix`` and reaps anything matching none of them. It may
        instead be a CALLABLE returning that set, which the instance providers resolve AFTER listing
        their resources (the periodic in-lifetime sweep passes one to close the launch race — see the
        instance ``sweep_orphans``). Returns the destroyed resource ids (RunPod uses int ids; the
        instance providers use opaque string ids). Providers without a standing-billing substrate
        (RunPod's serverless endpoints self-reap) implement this as a no-op."""
        ...
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""Hyperstack (NexGen Cloud) provider: single-GPU VMs bootstrapped via cloud-init (a second
|
|
2
|
+
instance-based complement to RunPod, alongside Lambda).
|
|
3
|
+
|
|
4
|
+
Fine-tuning runs on a Hyperstack GPU VM launched by Flash. The VM's cloud-init ``user_data`` runs
|
|
5
|
+
the prebuilt, PUBLIC ``WORKER_IMAGE`` via Docker (Hyperstack's Docker-preinstalled Ubuntu/CUDA boot
|
|
6
|
+
image), which executes ``flash.engine.worker`` on the GPU; completion is detected from the worker's
|
|
7
|
+
HF artifacts. It implements the SAME ``base.Provider`` interface as RunPod/Lambda.
|
|
8
|
+
|
|
9
|
+
``PROVIDER`` is the ``base.Provider`` implementation the registry hands out.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from collections.abc import Callable
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
from flash.providers.base import GpuClass, JobHandle, PollResult, Provider
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class HyperstackProvider:
    """``base.Provider`` for the Hyperstack substrate.

    Every method is a thin adapter that lazily imports the Hyperstack submodule it needs, so
    importing this package stays cheap and does not pull in the API client up front.
    """

    name = "hyperstack"

    def is_configured(self) -> bool:
        """Return True when a Hyperstack API key can be loaded."""
        from flash.providers.hyperstack.auth import load_api_key

        # Opt-in: available only when HYPERSTACK_API_KEY is present (else allocation degrades to the
        # other configured providers).
        return load_api_key() is not None

    def preflight(self, require_hf: bool = True) -> list[str]:
        """Return the missing-credential problems reported by the Hyperstack preflight."""
        from flash.providers.hyperstack.preflight import missing_credentials

        return missing_credentials(require_hf=require_hf)

    def gpu_classes(self) -> list[GpuClass]:
        """Return the GPU classes reported by ``flash.providers.hyperstack.gpus``."""
        from flash.providers.hyperstack.gpus import gpu_classes

        return gpu_classes()

    def hourly_rate(self, gpu: str) -> float:
        """Return the Hyperstack $/hr for one friendly GPU name."""
        from flash.providers.hyperstack.pricing import hourly_rate

        return hourly_rate(gpu)

    def submit_run(
        self,
        spec,
        seed: int,
        *,
        log: Any = None,
        on_handle: Any = None,
        attempt: int = 0,
        runtime_secrets: dict[str, str] | None = None,
        on_last_gpu: bool = False,
    ) -> PollResult:
        """Submit one run on Hyperstack by delegating to ``submit_run_hyperstack``."""
        # ``on_last_gpu`` stretches the setup/no-capacity grace when no further GPU attempt will be
        # made after this one — either the candidate list is exhausted or the retry budget is exhausted.
        from flash.providers.hyperstack.jobs import submit_run_hyperstack

        return submit_run_hyperstack(
            spec,
            seed,
            log=log,
            on_handle=on_handle,
            attempt=attempt,
            runtime_secrets=runtime_secrets,
            on_last_gpu=on_last_gpu,
        )

    def poll(self, handle: JobHandle, spec, seed: int, *, log: Any = None) -> PollResult:
        """Reattach to a persisted VM handle, poll it to a terminal state, then delete the VM.

        The VM is deleted whether polling returns or raises. A failed delete never masks the
        poll outcome; it is reported to ``log`` (when given) instead of being dropped silently.
        """
        import contextlib

        from flash.providers.hyperstack import api as hs_api
        from flash.providers.hyperstack.jobs import (
            PROVISION_GRACE_S,
            HyperstackJobHandle,
            poll_hs_job,
        )
        from flash.providers.runpod.jobs import make_hf_heartbeat_reader

        hf_repo = spec.train.hf_repo
        prefix = f"{spec.phase}/{spec.run_id}/seed{seed}"
        # Without an HF repo there is nothing to read heartbeats from.
        reader = make_hf_heartbeat_reader(hf_repo, prefix) if hf_repo else None
        hh = HyperstackJobHandle.from_dict(handle.to_dict())
        if log is not None:
            print(f"attaching: hyperstack vm={hh.vm_id}", file=log, flush=True)
        # Deadline counts from LAUNCH, not this reattach (no server-side timeout, so a restart must
        # not extend the billable window). The poll loop already anchors its deadline check to
        # ``handle.started_ts`` (start = launch), so we pass the FULL launch-relative budget;
        # pre-subtracting elapsed too would double-count and delete a still-valid VM the moment a
        # recovered run is past half its window.
        deadline = max(60.0, int(spec.gpu.max_wall_seconds) + PROVISION_GRACE_S)
        try:
            return poll_hs_job(hh, spec, seed, log=log, heartbeat_reader=reader, deadline_s=deadline)
        finally:
            # Recovery has no submit_run_hyperstack teardown ``finally``; delete the reattached VM
            # here so a finished/abandoned recovered seed stops billing immediately.
            try:
                hs_api.delete_vm(hh.vm_id)
            except Exception as exc:
                # Best-effort teardown: keep the poll result/exception intact, but leave a trace
                # because a VM that failed to delete may still be billing.
                if log is not None:
                    with contextlib.suppress(Exception):
                        print(
                            f"warning: hyperstack vm={hh.vm_id} delete failed: {exc!r}",
                            file=log,
                            flush=True,
                        )

    @staticmethod
    def _delete_handle_vm(handle: JobHandle) -> None:
        """Delete the VM named by ``handle``; a handle without a ``vm_id`` is a no-op."""
        from flash.providers.hyperstack import api as hs_api

        vm_id = handle.to_dict().get("vm_id")
        if vm_id:
            hs_api.delete_vm(str(vm_id))

    def cancel(self, handle: JobHandle) -> None:
        """Stop the run by deleting its VM (same action as ``destroy``)."""
        self._delete_handle_vm(handle)

    def destroy(self, handle: JobHandle) -> None:
        """Tear down the billable VM this handle owns."""
        self._delete_handle_vm(handle)

    def gc(self, spec) -> None:
        """Terminate any instance still registered for ``spec.run_id``."""
        from flash.providers.hyperstack.jobs import terminate_run_instances

        terminate_run_instances(spec.run_id)

    def sweep_orphans(
        self, active_labels: set[str] | Callable[[], set[str]] | None = None
    ) -> list[str]:
        """Hyperstack VM ids are opaque STRINGS (the ``base.Provider`` protocol widens the return to
        ``list[int | str]`` to cover both substrates); the orchestrator only logs/counts them."""
        from flash.providers.hyperstack.jobs import sweep_orphans

        return sweep_orphans(active_labels=active_labels)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# Module-level singleton the provider registry hands out. The ``Provider`` annotation lets a type
# checker verify that ``HyperstackProvider`` structurally satisfies the protocol.
PROVIDER: Provider = HyperstackProvider()
|