freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,431 @@
1
+ """Shared GPU-provider interface + GPU registry.
2
+
3
+ RunPod is the managed serverless GPU substrate; Lambda and Hyperstack are instance-based complements.
4
+ This module owns the parts that are not specific to any one provider's transport:
5
+
6
+ * ``GpuClass`` — one managed GPU class with its RunPod Flash identity.
7
+ * ``JobHandle`` / ``PollResult`` — the persisted-handle + poll-outcome shapes the
8
+ orchestrator round-trips through the provider.
9
+ * ``Candidate`` / ``Allocation`` — the allocation result.
10
+ * The canonicalization / alias / policy helpers every call site already used.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from dataclasses import dataclass, field
16
+ from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
17
+
18
+ if TYPE_CHECKING:
19
+ from collections.abc import Callable
20
+
21
+ from flash.spec import JobSpec
22
+
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # GPU class registry (provider-agnostic identity)
26
+ # ---------------------------------------------------------------------------
27
@dataclass(frozen=True)
class GpuClass:
    """One managed GPU class: a friendly name plus its per-provider identity and metadata.

    ``enum_member``, ``lambda_name`` and ``hyperstack_name`` are the RunPod, Lambda and Hyperstack
    identifiers respectively; each is ``None`` when the class is not offered on that provider.
    The first six fields have no default, so registry entries may pass them positionally in
    declaration order.
    """

    name: str  # canonical friendly name used in configs / the catalog
    enum_member: str | None  # runpod_flash GpuType member name; None -> not on RunPod
    vram_gb: int
    short: str  # endpoint-name-safe token (e.g. "4090", "a5000")
    sm: str  # CUDA arch (informational; sm80+ only)
    hourly_usd: float  # static rate used by pricing, cost projection, and ranking
    # Min host CUDA (driver) on the modern stack. None -> 12.8. Blackwell (sm120/sm100)
    # needs CUDA-13 drivers to JIT the wheels' PTX (no SASS shipped).
    min_cuda_modern: str | None = None
    # Whether this class has passed Flash's LIVE validation smoke (a real train+eval run on the
    # card). The deployed control plane REJECTS a submit for a non-validated class ("gpu type 'X'
    # has not passed Flash's live validation smoke"), so client-side allocation restricts to the
    # validated pool by default (see ``validated_classes`` / allocator) — otherwise a default
    # `flash train` could pick the absolute-cheapest fitting class (e.g. "L4") that the
    # server then refuses, and the run never submits. Exactly the smoke-validated members below
    # are marked True.
    validated: bool = False
    # Lambda Cloud instance-type name for this class (e.g. "gpu_1x_a10"); None -> not on Lambda.
    # Lambda is the instance-based complement to RunPod's serverless substrate: a class with a
    # ``lambda_name`` is provisionable on Lambda (capacity permitting), priced from Lambda's own
    # live ``/instance-types`` rate (NOT the RunPod ``hourly_usd`` snapshot above).
    lambda_name: str | None = None
    # Hyperstack single-GPU flavor name for this class (e.g. "n3-L40x1"); None -> not on Hyperstack.
    # Same instance-based model as Lambda (cloud-init -> Docker); a class with a ``hyperstack_name``
    # is provisionable on Hyperstack when its flavor has stock, priced from Hyperstack's static map.
    hyperstack_name: str | None = None
57
+
58
+
59
# Static hourly rates are RunPod secure-cloud on-demand snapshots, except where an entry's own
# comment names another source (e.g. the Hyperstack-only L40 carries the Hyperstack list price).
# Positional arguments follow the ``GpuClass`` field order:
#   name, enum_member (RunPod identity or None), vram_gb, short, sm, hourly_usd.
GPU_CLASSES: tuple[GpuClass, ...] = (
    GpuClass(
        "RTX 4090",
        "NVIDIA_GEFORCE_RTX_4090",
        24,
        "4090",
        "sm89",
        0.69,
        validated=True,
    ),
    GpuClass(
        "RTX 5090",
        "NVIDIA_GEFORCE_RTX_5090",
        32,
        "5090",
        "sm120",
        0.99,
        min_cuda_modern="13.0",
        validated=True,
    ),
    # ---- Ampere/Ada workstation + datacenter cards (cheap capacity pools) ----
    # 24 GB is the floor: the sub-24 GB tiers (16 GB RTX A4000 / RTX 2000 Ada, 20 GB RTX A4500 /
    # RTX 4000 Ada) were dropped — the 24 GB classes below are the smallest managed cards.
    # (RTX 3090 was removed from the catalog — see git history.)
    GpuClass("L4", "NVIDIA_L4", 24, "l4", "sm89", 0.39),
    # Lambda-only 24 GB Ampere datacenter card (RunPod has no A10). Instance-based capacity
    # complement: chosen by the allocator only when the cheaper RunPod 24 GB classes are out of
    # capacity, so it never undercuts RunPod on price.
    GpuClass("A10", None, 24, "a10", "sm86", 1.29, lambda_name="gpu_1x_a10"),
    # Live-validated 2026-06-22: Qwen3.5-0.8B/9B SFT+GRPO train smokes (RunPod). The 48 GB tier that
    # fills the 32->80 GB gap (e.g. 4B GRPO @ 35 GB) ~55% cheaper than the A100.
    GpuClass(
        "RTX A6000", "NVIDIA_RTX_A6000", 48, "a6000", "sm86", 0.49,
        validated=True,
        lambda_name="gpu_1x_a6000",
        hyperstack_name="n3-RTX-A6000x1",
    ),
    GpuClass("A40", "NVIDIA_A40", 48, "a40", "sm86", 0.44),
    GpuClass(
        "RTX 6000 Ada",
        "NVIDIA_RTX_6000_ADA_GENERATION",
        48,
        "6000ada",
        "sm89",
        0.77,
    ),
    # L40 48 GB (Ada, sm89): datacenter card on Hyperstack (+ Nebius/DO/Vultr/Scaleway), NOT on
    # RunPod or Lambda. Hyperstack-only here. hourly_usd is the Hyperstack list price.
    GpuClass("L40", None, 48, "l40", "sm89", 1.00, hyperstack_name="n3-L40x1"),
    # Lambda-only 40 GB A100 (SXM4) — RunPod's A100s are all 80 GB, so this fills the 32->80 GB gap
    # on Lambda (e.g. a 4B GRPO at ~35 GB) as an instance-based capacity complement.
    GpuClass(
        "A100 SXM 40GB", None, 40, "a100sxm40", "sm80", 1.99, lambda_name="gpu_1x_a100_sxm4"
    ),
    # ---- big-VRAM tier (9B bf16 GRPO, future >9B bf16) ----
    # Validated 2026-06-11: 0.6B SFT smoke (phase6).
    GpuClass(
        "A100 PCIe",
        "NVIDIA_A100_80GB_PCIe",
        80,
        "a100pcie",
        "sm80",
        1.39,
        validated=True,
        hyperstack_name="n3-A100x1",
    ),
    # Live-validated 2026-06-22: Qwen3.5 0.8B/MiniCPM/2B/9B SFT+GRPO train smokes (RunPod).
    GpuClass(
        "A100 SXM", "NVIDIA_A100_SXM4_80GB", 80, "a100sxm", "sm80", 1.49,
        validated=True,
    ),
    # Live-validated 2026-06-22: MiniCPM/2B/4B SFT+GRPO train smokes (RunPod).
    GpuClass(
        "H100", "NVIDIA_H100_80GB_HBM3", 80, "h100", "sm90", 3.29,
        validated=True,
        lambda_name="gpu_1x_h100_pcie",
        hyperstack_name="n3-H100x1",
    ),
    # Live-validated 2026-06-22: MiniCPM/2B/4B SFT+GRPO train smokes (RunPod, sm120/CUDA-13).
    GpuClass(
        "RTX Pro 6000",
        "NVIDIA_RTX_PRO_6000_BLACKWELL_SERVER_EDITION",
        96,
        "pro6000",
        "sm120",
        2.09,
        min_cuda_modern="13.0",
        validated=True,
        # NOT mapped to Hyperstack: this Blackwell class needs a CUDA-13 host driver, but Hyperstack
        # only ships up to CUDA-12.8 (R570) images — it would boot then fail at worker setup. Re-add
        # ``hyperstack_name="n3-RTX-PRO6000-SEx1"`` once a CUDA-13 Hyperstack image is available.
    ),
)

# Lookup by canonical friendly name; dict order follows the declaration order of GPU_CLASSES.
GPU_INFO: dict[str, GpuClass] = {g.name: g for g in GPU_CLASSES}

# Canonical friendly names Flash exposes in configs / the catalog.
KNOWN = tuple(GPU_INFO)

# The names that have passed Flash's live validation smoke. Client-side allocation restricts to
# these by default because the deployed control plane REJECTS a submit for any class outside the
# pool ("gpu type 'X' has not passed Flash's live validation smoke"); allocating an unvalidated
# (e.g. absolute-cheapest) class would just make the server refuse the run. Derived from the
# ``validated=True`` flags above, so the two cannot drift apart.
VALIDATED = tuple(g.name for g in GPU_CLASSES if g.validated)
165
+
166
+
167
+ def _alias_keys(name: str) -> set[str]:
168
+ """All accepted spellings of a friendly name (lowercased)."""
169
+ base = name.lower()
170
+ keys = {base, base.replace(" ", ""), base.replace(" ", "_"), base.replace(" ", "-")}
171
+ if base.startswith("rtx "):
172
+ tail = base[4:]
173
+ keys |= {tail, tail.replace(" ", ""), tail.replace(" ", "_")}
174
+ keys.add(f"nvidia {base}")
175
+ return keys
176
+
177
+
178
# Lowercased alias -> canonical friendly name. Built in two passes, and the order matters:
# first the generated spellings of every registered class, then the hand-written entries, which
# therefore win if a key appears in both.
_ALIASES: dict[str, str] = {}
for _info in GPU_INFO.values():
    for _k in _alias_keys(_info.name):
        _ALIASES[_k] = _info.name
# Spellings that don't fall out of the generic rules: full marketing names (what
# nvidia-smi / the RunPod API print) and historical Flash aliases.
_ALIASES.update(
    {
        "nvidia geforce rtx 4090": "RTX 4090",
        "nvidia geforce rtx 5090": "RTX 5090",
        "nvidia l4": "L4",
        "nvidia a40": "A40",
        "nvidia rtx 6000 ada generation": "RTX 6000 Ada",
        "rtx 6000 ada generation": "RTX 6000 Ada",
        "nvidia a100 80gb pcie": "A100 PCIe",
        "a100 80gb pcie": "A100 PCIe",
        "a100-80g-pcie": "A100 PCIe",
        "nvidia a100-sxm4-80gb": "A100 SXM",
        "a100-sxm4-80gb": "A100 SXM",
        # A bare "a100" resolves to the PCIe class, not the SXM variants.
        "a100": "A100 PCIe",
        "nvidia h100 80gb hbm3": "H100",
        "h100 80gb hbm3": "H100",
        "rtx pro 6000 blackwell": "RTX Pro 6000",
        "nvidia rtx pro 6000 blackwell server edition": "RTX Pro 6000",
    }
)
204
+
205
+
206
class UnsupportedGpuError(ValueError):
    """A GPU name could not be resolved, or no registered GPU class satisfies a request."""
208
+
209
+
210
def canonical_gpu(name: str) -> str:
    """Resolve a user-supplied GPU spelling to its canonical entry in ``KNOWN``.

    Matching ignores case and surrounding whitespace; a falsy ``name`` is treated as the empty
    string.

    Raises:
        UnsupportedGpuError: if the normalized spelling is not a registered alias.
    """
    normalized = (name or "").strip().lower()
    resolved = _ALIASES.get(normalized)
    if resolved is None:
        raise UnsupportedGpuError(
            f'unsupported gpu {name!r}; Flash manages {", ".join(KNOWN)}'
        )
    return resolved
218
+
219
+
220
def get_gpu_info(name: str) -> GpuClass:
    """Return the registry entry for ``name`` after alias canonicalization.

    Raises:
        UnsupportedGpuError: propagated from ``canonical_gpu`` for an unknown name.
    """
    canonical = canonical_gpu(name)
    return GPU_INFO[canonical]
222
+
223
+
224
def providers_for(name: str) -> tuple[str, ...]:
    """List the providers able to provision this GPU class.

    A provider is included when the class carries a truthy identity for it. The result is always
    ordered ``runpod``, ``lambda``, ``hyperstack``.
    """
    info = get_gpu_info(name)
    identities = (
        ("runpod", info.enum_member),
        ("lambda", info.lambda_name),
        ("hyperstack", info.hyperstack_name),
    )
    return tuple(provider for provider, identity in identities if identity)
235
+
236
+
237
def gpu_short(name: str) -> str:
    """Return the GPU's short, endpoint-name-safe token (e.g. '4090')."""
    info = get_gpu_info(name)
    return info.short
240
+
241
+
242
def min_cuda_modern(name: str) -> str:
    """Return the minimum host CUDA (driver) version for this GPU class on the modern stack.

    Classes that do not pin a version fall back to "12.8".
    """
    pinned = get_gpu_info(name).min_cuda_modern
    return pinned if pinned else "12.8"
245
+
246
+
247
def cheapest_gpu(min_vram_gb: int) -> str:
    """Pick the cheapest validated, RunPod-provisionable class with >= ``min_vram_gb`` VRAM.

    Only classes that have a RunPod identity and are flagged ``validated`` are considered, so the
    answer is deployable via Flash, deterministic offline, and matches what the deployed control
    plane accepts. Classes with equal hourly rates are ordered by ascending VRAM.

    Raises:
        UnsupportedGpuError: if no registered class meets all three conditions.
    """
    eligible = []
    for gpu in GPU_INFO.values():
        if not gpu.enum_member:
            continue
        if not gpu.vram_gb >= min_vram_gb:
            continue
        if not gpu.validated:
            continue
        eligible.append(gpu)

    if not eligible:
        raise UnsupportedGpuError(
            f"no validated RunPod-provisionable GPU class has >= {min_vram_gb} GB VRAM"
        )

    # Imported only after the pool check, as in the rest of this module's lazy imports.
    from flash.providers.runpod.pricing import hourly_rate

    def _rank(gpu):
        return (hourly_rate(gpu.name), gpu.vram_gb)

    return min(eligible, key=_rank).name
266
+
267
+
268
def provisional_gpu(
    model_id: str,
    algorithm: str = "sft",
    *,
    train=None,
    thinking: bool = False,
) -> str:
    """Return a parse-time provisional GPU: the cheapest VALIDATED class that fits the model.

    The required VRAM comes from ``model_required_vram_gb`` and the class from ``cheapest_gpu``,
    i.e. the RunPod-static, offline-deterministic choice the schema uses for sizing/display. The
    submit-time allocator (``flash.providers.allocator``) re-resolves the class later.
    """
    from flash.engine.vram import model_required_vram_gb
    from flash.providers.allocator import vram_headroom

    # Honor FLASH_VRAM_HEADROOM so parse-time sizing matches the submit-time allocator exactly.
    headroom = vram_headroom()
    required_gb = model_required_vram_gb(
        model_id,
        algorithm,
        train=train,
        thinking=thinking,
        headroom=headroom,
    )
    return cheapest_gpu(required_gb)
295
+
296
+
297
+ # ---------------------------------------------------------------------------
298
+ # Handles + poll outcomes (round-tripped through any provider)
299
+ # ---------------------------------------------------------------------------
300
@dataclass
class JobHandle:
    """Provider-tagged, persisted handle: enough to reattach/cancel from any process.

    The provider owns the rest of its handle shape (RunPod: endpoint_id/job_id). ``provider`` is
    the routing key the orchestrator uses to dispatch poll/cancel/destroy through the registry.

    Attributes:
        provider: Name of the provider that owns the handle.
        data: Provider-specific payload, stored flat next to ``provider`` when serialized.
    """

    provider: str
    data: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Flatten the handle into one dict with ``provider`` as the first key.

        ``self.provider`` always wins: a stray ``"provider"`` entry inside ``data`` can no longer
        overwrite the routing key and send poll/cancel/destroy to the wrong provider.
        """
        flat = {"provider": self.provider}
        flat.update(self.data)
        # Re-assert the routing key after merging; assigning to an existing key keeps its position.
        flat["provider"] = self.provider
        return flat

    @classmethod
    def from_dict(cls, d: dict) -> JobHandle:
        """Rebuild a handle from ``to_dict`` output without mutating ``d``.

        A dict with no ``"provider"`` key is treated as a RunPod handle.
        """
        payload = dict(d)
        provider = payload.pop("provider", "runpod")
        return cls(provider=provider, data=payload)
319
+
320
+
321
@dataclass
class PollResult:
    """Outcome returned by a provider's ``submit_run`` / ``poll``.

    ``failure`` carries one of the category strings listed below and ``detail`` an optional
    free-form message; presumably both stay ``None`` when ``ok`` is true — TODO confirm against
    the provider implementations.
    """

    ok: bool
    metrics: dict | None = None
    # "job_failed"    : genuine worker/job code error (NOT retried)
    # "job_preempted" : provider killed the worker (platform termination) -> infra-shaped, retried
    # "no_capacity"   : NEVER scheduled — no provider capacity for the pinned GPU class (job sat
    #                   IN_QUEUE / the only worker stayed THROTTLED) -> infra-shaped, retried on the
    #                   next-best GPU. Distinct from "stalled" (a worker WAS scheduled then stopped
    #                   making progress) so the terminal message points at capacity, not worker health.
    # "stalled"       : a scheduled worker made no progress within the budget -> infra-shaped, retried
    # "poll_error"    : client-side polling / deploy breakdown -> infra-shaped, retried
    failure: str | None = None
    detail: str | None = None
335
+
336
+
337
+ # ---------------------------------------------------------------------------
338
+ # Allocation result
339
+ # ---------------------------------------------------------------------------
340
@dataclass(frozen=True)
class Candidate:
    """One provider + GPU-class option in an allocation's candidate list (immutable)."""

    provider: str  # provider name
    gpu: str  # GPU name — presumably the canonical friendly name; verify against the allocator
    hourly_usd: float  # $/hr for this option
    vram_gb: int  # VRAM of the GPU class, in GB
347
+
348
@dataclass(frozen=True)
class Allocation:
    """The allocation result: the chosen provider/GPU plus the ranked list it came from (immutable)."""

    provider: str  # chosen provider name
    gpu: str  # chosen GPU name
    hourly_usd: float  # $/hr of the chosen option
    min_vram_gb: int  # VRAM requirement the allocation was sized for, in GB
    candidates: tuple[Candidate, ...]  # full ranked list (retry walks this)
355
+
356
+
357
+ # ---------------------------------------------------------------------------
358
+ # The provider interface (FIXED method set every provider implements)
359
+ # ---------------------------------------------------------------------------
360
@runtime_checkable
class Provider(Protocol):
    """The GPU-substrate interface every provider implements (RunPod, Lambda, Hyperstack).

    Structural protocol: a provider only needs to expose this attribute and these methods.
    All bodies here are ``...`` stubs.
    """

    # Provider name; ``JobHandle.provider`` carries the same kind of key for dispatch.
    name: str

    def is_configured(self) -> bool:
        """Whether this provider is usable right now (creds present, net reachable)."""
        ...

    def preflight(self, require_hf: bool = True) -> list[str]:
        """Missing-config problems (empty list == ready). The control plane aggregates
        these into one fail-fast error at startup."""
        ...

    def gpu_classes(self) -> list[GpuClass]:
        """The GPU classes this provider can provision (its rows of the shared table)."""
        ...

    def hourly_rate(self, gpu: str) -> float:
        """Static $/hr for one friendly GPU name."""
        ...

    def submit_run(
        self,
        spec: JobSpec,
        seed: int,
        *,
        log: Any = None,
        on_handle: Any = None,
        attempt: int = 0,
        runtime_secrets: dict[str, str] | None = None,
        on_last_gpu: bool = False,
    ) -> PollResult:
        """Deploy/rent -> submit -> persist handle (via ``on_handle``) -> poll.

        ``on_last_gpu`` is True when no further GPU attempt will be made after this one — either the
        candidate list is exhausted or the retry budget is exhausted — so there is no next-best class to fall
        to and capacity backstops should wait longer before giving up.
        """
        ...

    def poll(self, handle: JobHandle, spec: JobSpec, seed: int, *, log: Any = None) -> PollResult:
        """Reattach to a persisted handle and poll it to a terminal state."""
        ...

    def cancel(self, handle: JobHandle) -> None:
        """Stop the exact remote worker for this handle (cross-process)."""
        ...

    def destroy(self, handle: JobHandle) -> None:
        """Tear down the billable resource this handle owns (idempotent)."""
        ...

    def gc(self, spec: JobSpec) -> None:
        """Best-effort: reap any resource this run may have left registered."""
        ...

    def sweep_orphans(
        self, active_labels: set[str] | Callable[[], set[str]] | None = None
    ) -> list[int | str]:
        """Destroy any billable resource this provider owns that no live run claims.

        Crash recovery: run at server startup (and after runs). ``active_labels`` is the set of
        RAW run ids still owned by live runs — each instance provider derives its own instance-label
        prefix from them via ``run_label_prefix`` and reaps anything matching none of them. It may
        instead be a CALLABLE returning that set, which the instance providers resolve AFTER listing
        their resources (the periodic in-lifetime sweep passes one to close the launch race — see the
        instance ``sweep_orphans``). Returns the destroyed resource ids (RunPod uses int ids; the
        instance providers use opaque string ids). Providers without a standing-billing substrate
        (RunPod's serverless endpoints self-reap) implement this as a no-op."""
        ...
@@ -0,0 +1,127 @@
1
+ """Hyperstack (NexGen Cloud) provider: single-GPU VMs bootstrapped via cloud-init (a second
2
+ instance-based complement to RunPod, alongside Lambda).
3
+
4
+ Fine-tuning runs on a Hyperstack GPU VM launched by Flash. The VM's cloud-init ``user_data`` runs
5
+ the prebuilt, PUBLIC ``WORKER_IMAGE`` via Docker (Hyperstack's Docker-preinstalled Ubuntu/CUDA boot
6
+ image), which executes ``flash.engine.worker`` on the GPU; completion is detected from the worker's
7
+ HF artifacts. It implements the SAME ``base.Provider`` interface as RunPod/Lambda.
8
+
9
+ ``PROVIDER`` is the ``base.Provider`` implementation the registry hands out.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from collections.abc import Callable
15
+ from typing import Any
16
+
17
+ from flash.providers.base import GpuClass, JobHandle, PollResult, Provider
18
+
19
+
20
class HyperstackProvider:
    """``base.Provider`` for the Hyperstack substrate.

    Every method imports the Hyperstack helper it needs at call time rather than at module import,
    and delegates to it.
    """

    name = "hyperstack"

    def is_configured(self) -> bool:
        """Return True when a Hyperstack API key can be loaded."""
        from flash.providers.hyperstack.auth import load_api_key

        # Opt-in: available only when HYPERSTACK_API_KEY is present (else allocation degrades to the
        # other configured providers).
        return load_api_key() is not None

    def preflight(self, require_hf: bool = True) -> list[str]:
        """Return the missing-credential problems (empty list == ready)."""
        from flash.providers.hyperstack.preflight import missing_credentials

        return missing_credentials(require_hf=require_hf)

    def gpu_classes(self) -> list[GpuClass]:
        """Return the GPU classes this provider can provision."""
        from flash.providers.hyperstack.gpus import gpu_classes

        return gpu_classes()

    def hourly_rate(self, gpu: str) -> float:
        """Return the $/hr for one friendly GPU name."""
        from flash.providers.hyperstack.pricing import hourly_rate

        return hourly_rate(gpu)

    def submit_run(
        self,
        spec,
        seed: int,
        *,
        log: Any = None,
        on_handle: Any = None,
        attempt: int = 0,
        runtime_secrets: dict[str, str] | None = None,
        on_last_gpu: bool = False,
    ) -> PollResult:
        """Launch and run one seed on Hyperstack via ``submit_run_hyperstack``."""
        # ``on_last_gpu`` stretches the setup/no-capacity grace when no further GPU attempt will be
        # made after this one — either the candidate list is exhausted or the retry budget is exhausted.
        from flash.providers.hyperstack.jobs import submit_run_hyperstack

        return submit_run_hyperstack(
            spec, seed, log=log, on_handle=on_handle, attempt=attempt,
            runtime_secrets=runtime_secrets, on_last_gpu=on_last_gpu,
        )

    def poll(self, handle: JobHandle, spec, seed: int, *, log: Any = None) -> PollResult:
        """Reattach to a persisted VM handle, poll it, then always delete the VM.

        Args:
            handle: Persisted handle; converted to a ``HyperstackJobHandle``.
            spec: Job spec; supplies the HF repo, phase, run id and wall-clock budget.
            seed: Seed whose artifacts/heartbeat are polled.
            log: Optional writable stream for progress and teardown warnings.

        Returns:
            Whatever ``poll_hs_job`` returns.
        """
        import contextlib

        from flash.providers.hyperstack import api as hs_api
        from flash.providers.hyperstack.jobs import (
            PROVISION_GRACE_S,
            HyperstackJobHandle,
            poll_hs_job,
        )
        from flash.providers.runpod.jobs import make_hf_heartbeat_reader

        hf_repo = spec.train.hf_repo
        prefix = f"{spec.phase}/{spec.run_id}/seed{seed}"
        reader = make_hf_heartbeat_reader(hf_repo, prefix) if hf_repo else None
        hh = HyperstackJobHandle.from_dict(handle.to_dict())
        if log is not None:
            print(f"attaching: hyperstack vm={hh.vm_id}", file=log, flush=True)
        # Deadline counts from LAUNCH, not this reattach (no server-side timeout, so a restart must
        # not extend the billable window). The poll loop already anchors its deadline check to
        # ``handle.started_ts`` (start = launch), so we pass the FULL launch-relative budget;
        # pre-subtracting elapsed too would double-count and delete a still-valid VM the moment a
        # recovered run is past half its window.
        deadline = max(60.0, int(spec.gpu.max_wall_seconds) + PROVISION_GRACE_S)
        try:
            return poll_hs_job(hh, spec, seed, log=log, heartbeat_reader=reader, deadline_s=deadline)
        finally:
            # Recovery has no submit_run_hyperstack teardown ``finally``; delete the reattached VM
            # here so a finished/abandoned recovered seed stops billing immediately.
            try:
                hs_api.delete_vm(hh.vm_id)
            except Exception as exc:
                # Still best-effort (never masks the poll outcome), but a failed delete can leave a
                # billable VM behind, so surface it in the run log instead of dropping it silently.
                if log is not None:
                    with contextlib.suppress(Exception):
                        print(
                            f"warning: hyperstack vm={hh.vm_id} teardown failed: {exc!r}",
                            file=log,
                            flush=True,
                        )

    @staticmethod
    def _delete_handle_vm(handle: JobHandle) -> None:
        """Delete the VM named by ``handle``; a handle without a ``vm_id`` is a no-op."""
        vm_id = handle.to_dict().get("vm_id")
        if not vm_id:
            return
        # Imported only when there is something to delete.
        from flash.providers.hyperstack import api as hs_api

        hs_api.delete_vm(str(vm_id))

    def cancel(self, handle: JobHandle) -> None:
        """Stop the remote worker for this handle by deleting its VM."""
        self._delete_handle_vm(handle)

    def destroy(self, handle: JobHandle) -> None:
        """Tear down the billable VM this handle owns."""
        self._delete_handle_vm(handle)

    def gc(self, spec) -> None:
        """Terminate any instances still registered for ``spec.run_id``."""
        from flash.providers.hyperstack.jobs import terminate_run_instances

        terminate_run_instances(spec.run_id)

    def sweep_orphans(
        self, active_labels: set[str] | Callable[[], set[str]] | None = None
    ) -> list[str]:
        """Hyperstack VM ids are opaque STRINGS (the ``base.Provider`` protocol widens the return to
        ``list[int | str]`` to cover both substrates); the orchestrator only logs/counts them."""
        from flash.providers.hyperstack.jobs import sweep_orphans

        return sweep_orphans(active_labels=active_labels)
125
+
126
+
127
# Module-level singleton: the ``base.Provider`` implementation the registry hands out.
PROVIDER: Provider = HyperstackProvider()