freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,956 @@
1
+ """Durable run primitives: explicit deploy -> submit -> poll with a persisted job handle.
2
+
3
+ Calling `runpod_flash`'s all-in-one blocking handler directly would tie a run's life to
4
+ one client process and one HTTP poll loop: a client crash/network blip orphans an
5
+ otherwise-healthy GPU job (no job id is ever persisted), and any SDK polling bug kills
6
+ the run. This module owns the lifecycle instead:
7
+
8
+ deploy_train_endpoint() -> endpoint_id (Flash SDK deploy, same worker template)
9
+ build_function_input() -> the exact FunctionRequest payload Flash workers expect
10
+ submit + poll_job() -> REST queue API with hardened retries; the job handle
11
+ {endpoint_id, job_id} is persisted by the runner so
12
+ any process can re-attach (`flash status --follow`).
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import asyncio
18
+ import base64
19
+ import contextlib
20
+ import json
21
+ import os
22
+ import threading
23
+ import time
24
+ from dataclasses import dataclass
25
+ from typing import TYPE_CHECKING
26
+
27
+ from flash._logging import get_logger
28
+
29
+ if TYPE_CHECKING:
30
+ from collections.abc import Callable
31
+ from flash.providers._poll import (
32
+ PollErrorTracker,
33
+ make_say,
34
+ surface_forced_heartbeat,
35
+ surface_heartbeat,
36
+ )
37
+ from flash.providers.base import PollResult, canonical_gpu
38
+ from flash.providers.runpod import api as runpod_api
39
+ from flash.providers.runpod.gpus import flash_gpu
40
+ from flash.providers.runpod.train import (
41
+ DEFAULT_EXECUTION_TIMEOUT_MS,
42
+ FLASH_SDK_LOCK,
43
+ WORKER_IMAGE,
44
+ WORKER_SYSTEM_DEPS,
45
+ _patch_runpod_backoff,
46
+ _train_body,
47
+ endpoint_name,
48
+ isolate_flash_state,
49
+ min_cuda_for,
50
+ resolve_worker_deps,
51
+ worker_image_for_gpu,
52
+ )
53
+
54
+ logger = get_logger(__name__)
55
+
56
+ # Re-export so callers/tests that did ``from ...jobs import PollResult`` keep working.
57
+ __all__ = [
58
+ "JobHandle",
59
+ "PollResult",
60
+ "apply_disk_gb",
61
+ "build_function_input",
62
+ "decode_output",
63
+ "deploy_train_endpoint",
64
+ "make_hf_failure_detail_reader",
65
+ "make_hf_heartbeat_reader",
66
+ "make_hf_text_reader",
67
+ "poll_job",
68
+ "submit_run",
69
+ "weight_cache_datacenters",
70
+ "weight_cache_endpoint_kwargs",
71
+ "weight_cache_volume_name",
72
+ "weight_cache_volumes",
73
+ ]
74
+
75
# Queue-job statuses that end a run successfully.
TERMINAL_OK = {"COMPLETED"}
# Statuses where the provider killed the worker (reclaim/preempt/time-cap): infra-shaped, so
# they are retried. A worker "FAILED" is the run dying on its own (real traceback) and fails fast.
PLATFORM_TERMINATIONS = {"CANCELLED", "TIMED_OUT"}
# Every status that ends a run unsuccessfully.
TERMINAL_FAIL = PLATFORM_TERMINATIONS | {"FAILED"}

# Heartbeat stages the worker emits DURING cold start, before the model is loaded and the
# training loop begins (boot -> sft_start/rl_start, then later sft_model_load/rl_train_start).
# One of these proves the worker is alive, NOT that the slow setup (model download + vLLM init)
# has finished, so they must not flip stall detection to the tight training window.
_SETUP_HEARTBEAT_STAGES = frozenset(
    ("boot", "sft_start", "rl_start", "sft_model_load", "rl_train_start")
)
88
+
89
+
90
def stall_kwargs(on_last_gpu: bool = False) -> dict:
    """Return the stall-window keyword arguments for ``poll_job``.

    Shared by the submit and reattach paths so a recovered run is judged with the same tuning
    as the original submit. The original submit's ``on_last_gpu`` is persisted in the run handle
    (the runner's ``on_handle`` writes it into ``remote``), so a cross-process reattach
    (``RunpodProvider.poll``) can read it back and call this with the same value.

    Keys of the returned dict:

    - ``stall_after_s``: window applied after the first training heartbeat.
    - ``setup_grace_s``: the larger cold-start window before the first training heartbeat.
    - ``queue_grace_s`` / ``throttled_grace_s``: the two no-capacity backstops — how long a job
      may sit IN_QUEUE with no worker, or wait on a worker stuck THROTTLED, before the pinned
      GPU class is treated as out of capacity and the run walks to the next-best one.

    Args:
        on_last_gpu: True when no further GPU attempt will follow (candidate list or retry
            budget exhausted). There is nowhere left to walk, so the backstops wait ~15 min
            instead of ~5 min: burning the last attempt on a class with no fallback is worse
            than waiting out a longer queue.

    Both backstops only cover the no-capacity case: once a worker picks the job up, the much
    larger ``setup_grace_s`` governs cold start.
    """
    # ~5 min while a next-best GPU class can still be tried; ~15 min on the final attempt.
    capacity_grace_s = 300.0
    if on_last_gpu:
        capacity_grace_s = 900.0

    windows = dict(stall_after_s=1500.0, setup_grace_s=3000.0)
    windows["queue_grace_s"] = capacity_grace_s
    windows["throttled_grace_s"] = capacity_grace_s
    return windows
121
+
122
+
123
+ # RunPod DCs in the SDK's ``DataCenter.all()`` enum that do NOT support network volumes. LIVE-FOUND:
124
+ # the SDK enum is NOT the network-volume DC set (that assumption was wrong) — creating a volume in one
125
+ # of these 500s the WHOLE deploy ("data center ... does not support network volumes"), so eager runs
126
+ # would always fall back to cold and the cache would never work. The SDK exposes no volume-capability
127
+ # flag, so we maintain the exclusion here. (If RunPod drops volume support in another DC, that deploy
128
+ # 500s -> the lifecycle no_capacity/poll_error cache-drop falls back to a cold cross-region run, so a
129
+ # stale list degrades gracefully rather than wedging — but add the DC here to restore its cache.)
130
+ _VOLUME_INCAPABLE_DATACENTERS = frozenset({"US-MO-1"})
131
+
132
+
133
def weight_cache_datacenters() -> list:
    """List every volume-capable RunPod datacenter.

    This is both the set the endpoint is allowed across and the set a per-DC cache volume is
    attached in (eager: a volume in every region a run can land in, so any landing is warm).
    It is ``DataCenter.all()`` with the members whose ``value`` appears in
    ``_VOLUME_INCAPABLE_DATACENTERS`` filtered out, in the SDK's own order. An SDK upgrade that
    adds a storage region is picked up automatically; one that adds a volume-less region must
    be added to that exclusion set.
    """
    # Imported lazily so merely importing this module does not require the Flash SDK.
    from runpod_flash.core.resources.datacenter import DataCenter

    capable = []
    for dc in DataCenter.all():
        if dc.value in _VOLUME_INCAPABLE_DATACENTERS:
            continue  # the SDK lists it, but volumes cannot be created there
        capable.append(dc)
    return capable
143
+
144
+
145
def weight_cache_volume_name(base: str, dc) -> str:
    """Build the physical volume name for logical cache ``base`` in datacenter ``dc``.

    The cache is one logical volume (``base`` == ``spec.gpu.network_volume``, e.g.
    ``flash-weights``) realized as one physical volume per datacenter, so the datacenter must be
    part of the name. The runpod_flash SDK keys its resource tracking on ``NetworkVolume:{name}``
    without the datacenter; same-named volumes would collide on one key and deploying the second
    would trigger a replace that ends in the SDK's unimplemented ``NetworkVolume.undeploy``.
    The worker is unaffected: every volume mounts at ``/runpod-volume`` regardless of name.

    Args:
        base: Logical cache name.
        dc: Datacenter object exposing a string ``value`` (lower-cased into the suffix).

    Returns:
        ``"<base>-<lower-cased dc.value>"``.
    """
    dc_suffix = dc.value.lower()
    return "{}-{}".format(base, dc_suffix)
157
+
158
+
159
def weight_cache_volumes(spec) -> list:
    """Build one ``NetworkVolume`` per storage datacenter — the eager cache fleet.

    Returns ``[]`` when the cache is off: ``spec`` is ``None``, ``spec.gpu.network_volume`` is
    unset/empty, or no volume-capable datacenter is available. Otherwise one physical volume
    named ``<base>-<dc>`` (see ``weight_cache_volume_name``) is described for every entry of
    ``weight_cache_datacenters()``, so the cache exists in whichever region the endpoint lands
    in. runpod_flash reuses an existing volume of the same name/datacenter, making this
    create-or-attach: the first deploy provisions the fleet and later deploys re-attach.

    Multi-account pools: ``deploy_train_endpoint`` re-runs the whole deploy on quota failover,
    so the (account-scoped) volumes are re-created on whichever account ends up hosting the
    endpoint. Orphans on the failed-over-from account are reclaimed by ``preload --teardown``.
    """
    if spec is None:
        return []
    base = getattr(spec.gpu, "network_volume", None)
    if not base:
        return []
    datacenters = weight_cache_datacenters()  # eager: a volume in every storage DC
    if not datacenters:
        return []

    from runpod_flash import NetworkVolume

    from flash.spec import _volume_gb

    # Go through the spec's tolerant parser: a stale/hand-edited spec with a non-numeric, "0",
    # or negative network_volume_gb falls back to 100 GB instead of raising (which the
    # best-effort caller would turn into a no-cache deploy) or creating a 0-GB volume.
    size = _volume_gb(getattr(spec.gpu, "network_volume_gb", 100))

    fleet = []
    for dc in datacenters:
        physical_name = weight_cache_volume_name(str(base), dc)
        fleet.append(NetworkVolume(name=physical_name, size=size, datacenter=dc))
    return fleet
191
+
192
+
193
def weight_cache_endpoint_kwargs(spec) -> dict:
    """Return the endpoint kwargs that attach the eager weight-cache fleet, or ``{}``.

    With the cache on, the result is ``{"volume": [...], "datacenter": [...]}``: one volume per
    storage datacenter plus the list of all storage datacenters. The endpoint may therefore land
    wherever there is capacity and is warm in whichever datacenter that is. Both lists come from
    the same storage-DC set, which satisfies the SDK rule that every volume's datacenter must
    appear in the endpoint's datacenter list.

    Best-effort: with the cache off (no ``network_volume`` on the spec) or on ANY failure
    (SDK import, validation, ...), ``{}`` is returned so the run deploys without a volume rather
    than failing. Failures are logged at warning level.
    """
    try:
        fleet = weight_cache_volumes(spec)
        if fleet:
            return {"volume": fleet, "datacenter": weight_cache_datacenters()}
        # Empty fleet: cache off -> cold run (no volumes, RunPod picks any region).
    except Exception as exc:
        # Never let the cache break a deploy — fall back to a no-volume run.
        logger.warning("weight cache disabled for this run (%s); deploying with no volume", exc)
    return {}
215
+
216
+
217
def apply_disk_gb(config, disk_gb: int | None) -> None:
    """Raise the worker's container disk on an already-built endpoint config.

    The Flash SDK's ``PodTemplate.containerDiskInGb`` defaults to 64 GB and the ``Endpoint``
    wrapper exposes no disk knob, which blocked models whose checkpoint alone exceeds 64 GB.
    The SDK's validators populate the template when the resource config is built, so mutating
    the field here is the injection point.

    Raise-only: the field is set to the larger of the requested size and its current value.
    Shrinking below the SDK default buys nothing (serverless disk isn't billed separately) and
    would regress runs whose configs carry the historical ``disk_gb = 60`` default.

    Args:
        config: Built endpoint resource config; its ``template`` attribute is mutated in place.
        disk_gb: Requested container disk in GB. Falsy values (``None``/``0``) are a no-op.
    """
    if not disk_gb:
        return

    template = getattr(config, "template", None)
    if template is None:
        logger.warning("disk_gb=%s requested but endpoint config has no template", disk_gb)
        return

    requested_gb = int(disk_gb)
    current_gb = int(template.containerDiskInGb or 0)  # a missing/None field counts as 0
    template.containerDiskInGb = max(requested_gb, current_gb)
235
+
236
+
237
@dataclass
class JobHandle:
    """Persistable reference to a submitted RunPod queue job (endpoint + job id)."""

    endpoint_id: str
    endpoint_name: str
    job_id: str

    def to_dict(self) -> dict:
        """Serialize the handle, tagging it with ``"provider": "runpod"`` for routing."""
        serialized = {"provider": "runpod"}
        serialized["endpoint_id"] = self.endpoint_id
        serialized["endpoint_name"] = self.endpoint_name
        serialized["job_id"] = self.job_id
        return serialized

    @classmethod
    def from_dict(cls, d: dict) -> JobHandle:
        """Rebuild a handle from ``to_dict`` output.

        ``endpoint_id`` and ``job_id`` are required (``KeyError`` if absent);
        ``endpoint_name`` defaults to ``""``. ``provider`` is ignored here: it is routing
        metadata consumed upstream (runner), and handles persisted before it existed default
        to runpod there.
        """
        endpoint_id = d["endpoint_id"]
        name = d.get("endpoint_name", "")
        job_id = d["job_id"]
        return cls(endpoint_id, name, job_id)
256
+
257
+
258
def _is_workers_quota_error(exc: Exception) -> bool:
    """Return True when ``exc``'s message signals the account worker quota is exhausted.

    Detection is a case-insensitive substring match on the exception text.
    """
    quota_marker = "max workers across all endpoints"
    return quota_marker in str(exc).lower()
262
+
263
+
264
+ # Per-endpoint "first observed idle" timestamps, so a candidate must STAY idle across sweeps for
265
+ # ``min_idle_s`` before deletion (a cold-starting / between-jobs endpoint reports a transient zero
266
+ # we must not act on). Pruned each sweep to the still-idle set, so it can't grow unbounded.
267
+ #
268
+ # Two threads can run a sweep at once — the periodic control-plane reaper (via asyncio.to_thread)
269
+ # and a deploy-time quota sweep — so every read/write of ``_idle_since`` is serialized by this lock
270
+ # (a dedicated lock, NOT FLASH_SDK_LOCK, since the sweep uses the REST API, not the Flash SDK).
271
+ # Holding it across the sweep also prevents a concurrent sweep's prune from disturbing this one's
272
+ # grace timers; contention is negligible (the reaper runs every 10 min, deploy sweeps are rare).
273
+ _idle_since: dict[str, float] = {}
274
+ _idle_since_lock = threading.Lock()
275
+
276
+
277
def _is_flash_endpoint(name: str) -> bool:
    """Return True when ``name`` is a flash training endpoint the idle sweep may reap.

    Accepts both the bare ``flash-...`` name and the SDK's ``live-flash-...`` form (a single
    leading ``live-`` is ignored). Serving runs on freesolo's Modal app, not RunPod, so the only
    flash-* RunPod endpoints are training endpoints.
    """
    live_prefix = "live-"
    bare_name = name[len(live_prefix):] if name.startswith(live_prefix) else name
    return bare_name.startswith("flash-")
282
+
283
+
284
def _sweep_idle_flash_endpoints(
    protected: set[str], min_idle_s: float = 0.0, reap_warm: bool = True
) -> int:
    """Delete idle, orphaned flash training endpoints and return how many were deleted.

    Targets endpoints doing nothing that still hold RunPod worker quota (runs that
    finished/crashed without tearing their endpoint down).

    Args:
        protected: Endpoint names tied to a live run, in the bare ``flash-...`` and/or the SDK's
            ``live-flash-...`` form. Never deleted, even if momentarily idle (between seeds).
        min_idle_s: How long the idle reading must PERSIST across sweeps before deletion, so a
            single transient zero (cold start / between jobs) never triggers a delete.
        reap_warm: When True (the run-aware periodic reaper, which protects every live run), a
            merely warm ``idle``/``ready`` worker counts as doing nothing and is reclaimable.
            When False (the deploy-time reactive sweep, which protects only the current run),
            a warm worker counts as busy so only fully scaled-to-zero endpoints are reaped.

    Best-effort throughout: a failed endpoint listing returns 0, and an error on one endpoint is
    logged at debug level and skipped. An endpoint whose health response is incomplete or
    errored is not confirmed idle this sweep, so its grace timer is dropped by the final prune.
    """
    try:
        candidates = runpod_api.list_endpoints()
    except Exception:
        logger.debug("idle-sweep: failed to list endpoints", exc_info=True)
        return 0

    def _count(section: dict, key: str):
        # Health counters may be missing or None; both count as zero.
        return section.get(key) or 0

    reaped = 0
    now = time.time()
    confirmed_idle: set[str] = set()
    # Serialize all _idle_since access: a concurrent sweep must not mutate the dict
    # mid-iteration (the prune below would raise) or disturb these grace timers.
    with _idle_since_lock:
        for ep in candidates:
            ep_name = ep.get("name") or ""
            eid = ep.get("id")
            if not eid or not _is_flash_endpoint(ep_name):
                continue
            # Protect the run's endpoint in either registered form.
            if ep_name in protected or ep_name.removeprefix("live-") in protected:
                continue
            try:
                health = runpod_api.endpoint_health(eid) or {}
                workers = health.get("workers")
                jobs_info = health.get("jobs")
                # A missing/empty workers section (or a missing jobs section) means the
                # health response is incomplete and idleness cannot be confirmed.
                health_complete = (
                    isinstance(workers, dict) and bool(workers) and isinstance(jobs_info, dict)
                )
                if not health_complete:
                    continue

                # "Busy" = a worker working or spinning up, OR a job queued/in progress.
                busy_workers = _count(workers, "running") + _count(workers, "initializing")
                if not reap_warm:
                    # Conservative mode: warm workers count as busy too.
                    busy_workers += _count(workers, "ready") + _count(workers, "idle")
                in_flight = _count(jobs_info, "inQueue") + _count(jobs_info, "inProgress")

                idle_now = busy_workers == 0 and in_flight == 0
                if not idle_now:
                    _idle_since.pop(eid, None)  # busy again -> reset the grace timer
                    continue

                confirmed_idle.add(eid)
                first_idle = _idle_since.setdefault(eid, now)
                if now - first_idle < min_idle_s:
                    continue  # idle, but not for long enough yet — wait for the next sweep
                if runpod_api.delete_endpoint(eid):
                    reaped += 1
                    _idle_since.pop(eid, None)
                    logger.info("idle-sweep: deleted idle endpoint %s (%s)", ep_name, eid)
            except Exception:
                logger.debug(
                    "idle-sweep: error processing endpoint %s (%s)", ep_name, eid, exc_info=True
                )
                continue
        # Drop grace timers for endpoints no longer idle/present (busy, deleted, gone, protected).
        for stale in set(_idle_since) - confirmed_idle:
            _idle_since.pop(stale, None)
    return reaped
359
+
360
+
361
def deploy_train_endpoint(
    friendly_gpu: str,
    execution_timeout_ms: int | None = None,
    name_suffix: str | None = None,
    disk_gb: int | None = None,
    spec=None,
    endpoint_kwargs: dict | Callable[[], dict] | None = None,
) -> tuple[str, str]:
    """Deploy (or reuse) the run's uniquely-named worker endpoint; return ``(id, name)``.

    On a worker-quota error, sweeps idle flash-* endpoints (from crashed/completed runs that
    skipped GC) and retries up to ``_QUOTA_MAX_RETRIES`` times with backoff. If the account's
    quota stays exhausted after sweeping and ``RUNPOD_API_KEY`` configures more than one
    account, fails over to the next account (``keys.advance_key``) and deploys there, trying
    every OTHER account at most once. A single key => single account, no failover.

    Args:
        friendly_gpu: GPU name; canonicalized before use.
        execution_timeout_ms: Job execution cap; falsy -> ``DEFAULT_EXECUTION_TIMEOUT_MS``.
        name_suffix: Run-unique suffix for the endpoint name and the isolated SDK state.
        disk_gb: Optional container-disk floor, applied raise-only via ``apply_disk_gb``.
        spec: Run spec used to derive the default weight-cache attachment.
        endpoint_kwargs: Overrides the volume/datacenter attachment (default: the multi-DC
            weight-cache fleet from ``weight_cache_endpoint_kwargs(spec)``). The preload driver
            passes a SINGLE-DC volume+datacenter so the worker lands in that region. May be a
            dict OR a zero-arg factory: the SDK can stamp an account-scoped id onto a
            NetworkVolume object, so a factory is re-invoked on every deploy attempt to build a
            FRESH volume instead of reusing one stamped by a previous account.

    Returns:
        ``(endpoint_id, endpoint_name)``.

    Raises:
        RuntimeError: the deployed resource carries no endpoint id (including a deploy that
            returned ``None``).
        Exception: the last quota error once retries and failovers are spent, or any
            non-quota deploy error immediately.
    """
    # Set before the SDK imports below (ordering kept from the original implementation).
    os.environ["FLASH_IS_LIVE_PROVISIONING"] = "true"
    from runpod_flash import Endpoint
    from runpod_flash.core.resources.resource_manager import ResourceManager

    from flash.providers.runpod import keys as rp_keys
    from flash.providers.runpod.auth import ensure_auth

    _patch_runpod_backoff()
    friendly = canonical_gpu(friendly_gpu)
    name = endpoint_name(friendly, name_suffix)
    # Deploy a self-contained serverless-worker image directly. By default this is WORKER_IMAGE;
    # when per-sm warmed images are enabled, the selected GPU class picks the matching image tag.
    # FLASH_WORKER_IMAGE remains the absolute hotfix override.
    image = worker_image_for_gpu(friendly, allow_default=True)

    def _deploy_once():
        """One get_or_deploy on the currently-active account (SDK + lock critical section)."""
        # isolate_flash_state mutates runpod_flash's process-wide registry globals for this
        # run's suffix, and ResourceManager + the deploy share the SDK's asyncio singleton.
        # Hold the lock across the whole critical section so a concurrent run can't swap the
        # registry scope or race the event loop mid-deploy.
        with FLASH_SDK_LOCK:
            isolate_flash_state(name_suffix)
            kwargs = {
                "name": name,
                "gpu": flash_gpu(friendly),
                "gpu_count": 1,
                "min_cuda_version": min_cuda_for(friendly),
                "execution_timeout_ms": execution_timeout_ms or DEFAULT_EXECUTION_TIMEOUT_MS,
                "workers": (0, 1),
            }
            if image:
                kwargs["image"] = image
            else:
                kwargs["dependencies"] = resolve_worker_deps()
                kwargs["system_dependencies"] = WORKER_SYSTEM_DEPS
            # Attach the multi-region weight cache (best-effort: {} when no cache / on error);
            # the endpoint is allowed across every cache DC, so it is NOT pinned to one region.
            # A caller (preload) may override with a single-DC volume+datacenter. A factory is
            # resolved FRESH on each attempt so a NetworkVolume stamped with a prior account's
            # id is never reused across a quota failover.
            override = endpoint_kwargs() if callable(endpoint_kwargs) else endpoint_kwargs
            kwargs.update(override if override is not None else weight_cache_endpoint_kwargs(spec))
            ep = Endpoint(**kwargs)
            ep._qb_target = _train_body
            config = ep._build_resource_config()
            apply_disk_gb(config, disk_gb)
            # Worker image is PUBLIC, so no container-registry credential is needed to pull it.
            rm = ResourceManager()
            return asyncio.run(rm.get_or_deploy_resource(config))

    _QUOTA_MAX_RETRIES = 3
    resource = None
    # Success is tracked explicitly instead of being inferred from ``resource is not None``:
    # a deploy that returns None without raising must surface as "no endpoint id" below, not
    # be mistaken for quota exhaustion (which would fail over and redeploy on other accounts).
    deployed = False
    # One pass over the pool: advance_key() WRAPS (always True for a multi-key pool, even after
    # the last account), so without a bound an all-exhausted pool would fail over forever.
    # Cap the failovers at "every OTHER account once" and then raise — the lifecycle retry
    # budget handles waiting for quota to recover and re-enters this with a fresh attempt.
    failovers_left = max(0, rp_keys.key_count() - 1)
    while not deployed:
        ensure_auth()  # collapse RUNPOD_API_KEY to the (possibly failed-over) active account key
        quota_exc: Exception | None = None
        for quota_attempt in range(_QUOTA_MAX_RETRIES):
            if quota_attempt > 0:
                # Under acute quota pressure, sweep idle orphaned flash training endpoints on
                # THIS account NOW (min_idle_s=0) to free a slot. Only THIS run's endpoint is
                # protected, so the sweep stays conservative (reap_warm=False): it reaps only
                # endpoints fully scaled to zero, never another live run's warm endpoint.
                swept = _sweep_idle_flash_endpoints(
                    protected={name, f"live-{name}"}, min_idle_s=0.0, reap_warm=False
                )
                wait_s = 30 * quota_attempt
                logger.warning(
                    "RunPod worker quota hit (attempt %d/%d): swept %d idle flash-* endpoint(s); "
                    "retrying in %ds",
                    quota_attempt + 1, _QUOTA_MAX_RETRIES, swept, wait_s,
                )
                time.sleep(wait_s)
            try:
                resource = _deploy_once()
            except Exception as exc:
                if not _is_workers_quota_error(exc):
                    raise
                quota_exc = exc  # freeable: sweep + retry, then fail over to the next account
            else:
                deployed = True
                break
        if deployed:
            break
        # Quota still exhausted after sweeping this account dry — fail over to the next one, but
        # only until every account has been tried once. advance_key() wraps and always returns
        # True for a multi-key pool, so the count — not its return value — is what stops us.
        if failovers_left > 0 and rp_keys.advance_key():
            failovers_left -= 1
            logger.warning(
                "RunPod worker quota exhausted on this account after sweeping; failing over to "
                "the next RUNPOD_API_KEY account (%d configured)",
                rp_keys.key_count(),
            )
            continue
        raise quota_exc or RuntimeError("deploy_train_endpoint: worker quota exhausted")

    endpoint_id = getattr(resource, "id", None)
    if not endpoint_id:
        raise RuntimeError(f"deploy_train_endpoint: no endpoint id on resource {resource!r}")
    return endpoint_id, name
489
+
490
+
491
def build_function_input(payload: dict) -> dict:
    """Build the job input a Flash queue worker expects for ``_train_body(payload)``.

    Two shapes, chosen by whether a baked worker image is configured (the
    ``FLASH_WORKER_IMAGE`` environment variable or the ``WORKER_IMAGE`` constant):

    - Baked serverless-worker image (client mode): the image's rp_handler reads
      ``job["input"]`` and calls ``_train_body``, so ``payload`` itself is returned unchanged
      (``submit_job`` wraps it in ``{"input": ...}``). No live-function source or deps.
    - Boot-install fallback (Flash default image + live function): a FunctionRequest dict
      carrying ``_train_body``'s source, the serialized args, and the pinned worker
      dependencies to install on first use.
    """
    baked_image = os.environ.get("FLASH_WORKER_IMAGE") or WORKER_IMAGE
    if baked_image:
        return payload

    from runpod_flash.runtime.serialization import serialize_args
    from runpod_flash.stubs.live_serverless import get_function_source

    # get_function_source returns a pair; only the source text is shipped.
    function_code, _unused_hash = get_function_source(_train_body)
    request = {
        "function_name": "_train_body",
        "function_code": function_code,
        "args": serialize_args((payload,)),
        "accelerate_downloads": True,
        "dependencies": resolve_worker_deps(),
        "system_dependencies": WORKER_SYSTEM_DEPS,
    }
    return request
512
+
513
+
514
def decode_output(output) -> dict:
    """Decode a queue-job output into the worker's metrics dict.

    Handles both job shapes:

    - Flash LIVE-function (boot-install path): a FunctionResponse envelope
      ``{"success": True, "result": <base64 cloudpickle of the dict>}``.
    - Client-mode SERVERLESS handler (baked-image path): the baked rp_handler returns
      ``_train_body(...)``'s metrics dict, which RunPod surfaces as ``job["output"]`` directly,
      with no envelope. It has no ``success``/``result`` keys, so it is returned as-is.

    A JSON string is parsed first. Every failure is raised as ``RuntimeError``: unparsable or
    non-dict output, an unsuccessful envelope, an envelope result that is not a dict, or a
    client-mode dict carrying a truthy ``error``. Failure messages include up to the last 1500
    characters of the worker's ``stdout`` so root-cause diagnostics survive.
    """
    decoded = output
    if isinstance(decoded, str):
        raw_text = decoded
        try:
            decoded = json.loads(raw_text)
        except json.JSONDecodeError as exc:
            raise RuntimeError(f"unexpected job output: {raw_text[:200]}") from exc
    if not isinstance(decoded, dict):
        raise RuntimeError(f"unexpected job output type: {type(decoded)}")

    looks_like_envelope = any(key in decoded for key in ("success", "result"))
    if not looks_like_envelope:
        # Client-mode serverless handler: the metrics dict IS the output (baked rp_handler).
        failure = decoded.get("error")
        if not failure:
            return decoded
        # Mirror the Flash path: append the worker stdout tail when present so poll_job's
        # root-cause diagnostics (e.g. a vLLM crash) survive the client-mode failure shape too.
        stdout_tail = (decoded.get("stdout") or "")[-1500:]
        message = f"Remote execution failed: {failure}"
        if stdout_tail:
            message += f"\n--- worker stdout tail ---\n{stdout_tail}"
        raise RuntimeError(message)

    # Flash live-function envelope (success/result/error keys).
    if decoded.get("success") and decoded.get("result") is not None:
        import cloudpickle

        # NOTE(review): unpickling executes arbitrary code; this assumes the job output comes
        # only from our own worker — confirm the endpoint output cannot be attacker-controlled.
        metrics = cloudpickle.loads(base64.b64decode(decoded["result"]))
        if not isinstance(metrics, dict):
            raise RuntimeError(f"flash job returned no metrics: {metrics!r}")
        return metrics

    err = decoded.get("error") or "unknown worker error"
    stdout_tail = (decoded.get("stdout") or "")[-1500:]
    raise RuntimeError(
        f"Remote execution failed: {err}\n--- worker stdout tail ---\n{stdout_tail}"
    )
554
+
555
+
556
+ def _append_failure_artifacts(detail: str, failure_detail_reader) -> str:
557
+ """Append worker-uploaded failure artifacts to a RunPod terminal-status detail."""
558
+ if failure_detail_reader is None:
559
+ return detail
560
+ extra = failure_detail_reader(force=True)
561
+ if not extra:
562
+ return detail
563
+ if detail:
564
+ return f"{detail}\n{extra}"
565
+ return extra
566
+
567
+
568
def poll_job(
    handle: JobHandle,
    log=None,
    interval_s: float = 10.0,
    heartbeat_reader=None,
    failure_detail_reader=None,
    stall_after_s: float = 1200.0,
    setup_grace_s: float = 3000.0,
    unhealthy_grace_s: float = 240.0,
    throttled_grace_s: float = 300.0,
    queue_grace_s: float = 300.0,
    deadline_s: float | None = None,
) -> PollResult:
    """Poll a queue job to completion; resilient to transient API errors.

    Returns a ``PollResult``: success with decoded metrics on a TERMINAL_OK status, or a
    failure whose ``failure`` field is one of ``"stalled"``, ``"poll_error"``,
    ``"job_preempted"``, ``"job_failed"`` or ``"no_capacity"``.

    Two stall windows: the cold-start phase (dep install, per-run env pip, model download,
    vLLM init) is slow and only emits *setup* heartbeats (``_SETUP_HEARTBEAT_STAGES``).
    Until a *training* heartbeat arrives we apply the larger ``setup_grace_s`` budget so a
    slow cold start isn't misread as a stall; after it we use the tight ``stall_after_s``.
    Needs a ``heartbeat_reader`` to tell the phases apart — without one we keep
    ``stall_after_s`` throughout (no regression).

    ``failure_detail_reader`` force-reads worker-uploaded artifacts (``error_<phase>.txt`` and
    ``console_<phase>.txt``) after a worker terminal failure so a generic RunPod handler wrapper
    does not hide the real traceback.

    ``unhealthy_grace_s`` bounds how long a worker may stay UNHEALTHY (with nothing usable or
    initializing) while the job is IN_QUEUE before a retryable ``"stalled"`` is returned.

    ``throttled_grace_s`` bounds how long we wait on a worker stuck THROTTLED (no RunPod
    capacity for the pinned GPU class) before returning a retryable stall so the runner
    walks to the next-best GPU. THROTTLED means there is no capacity for this class right now, so
    once a class with a cheaper fallback has stayed throttled this long, failing over beats
    blocking the run on a host that won't free up. ``stall_kwargs`` sets this to ~5 min while the
    gpu-walk still has a next-best class, and ~15 min on the last candidate (nowhere left to walk).

    ``queue_grace_s`` is the capacity backstop for that same walk when RunPod *doesn't* surface
    a THROTTLED/UNHEALTHY worker: a job can sit IN_QUEUE with zero workers assigned (or one stuck
    INITIALIZING, or while ``endpoint_health`` keeps erroring) and the throttled/unhealthy
    fast-fails never arm — so without this it would burn the full ``setup_grace_s`` (~50 min).
    Keyed off the authoritative job status (robust to a failing health probe), it returns a
    retryable stall once a job has been IN_QUEUE longer than ``queue_grace_s`` (tuned by
    ``stall_kwargs`` like ``throttled_grace_s``: ~5 min normally, ~15 min on the last GPU class).
    All queue-phase timers (queue, unhealthy, throttled) apply only while the job status remains
    IN_QUEUE; once a worker picks the job up (status leaves IN_QUEUE) they reset and
    ``setup_grace_s`` governs cold start.

    ``deadline_s``, when set, is a hard client-side cap on total polling time.
    """
    say = make_say(log)
    poll_errors = PollErrorTracker(say, interval_s)

    start = time.time()
    last_status = None
    last_hb_key = None
    last_progress = time.time()
    seen_heartbeat = False
    last_health_probe = 0.0
    unhealthy_since: float | None = None  # first time the worker was seen stuck UNHEALTHY
    throttled_since: float | None = None  # first time the worker was seen stuck THROTTLED
    queued_since: float | None = None  # first time the job was seen IN_QUEUE with no worker yet
    while True:
        if deadline_s is not None and time.time() - start > deadline_s:
            return PollResult(False, failure="stalled", detail="client-side deadline exceeded")
        try:
            st = runpod_api.job_status(handle.endpoint_id, handle.job_id)
            poll_errors.reset()
        except runpod_api.RunpodApiError as e:
            # The tracker decides when transient errors have become a hard failure.
            if poll_errors.record(e):
                return PollResult(False, failure="poll_error", detail=str(e))
            continue
        status = st.get("status")
        if status != last_status:
            # A status transition counts as progress for stall detection.
            say(f"job {handle.job_id}: {status}")
            last_status = status
            last_progress = time.time()
        if status in TERMINAL_OK:
            try:
                return PollResult(True, metrics=decode_output(st.get("output")))
            except RuntimeError as e:
                # COMPLETED but the output decodes as an error (a handler exception). Consult the
                # worker flag too: an infra failure can surface here and must still retry.
                last_hb_key, _ = surface_forced_heartbeat(heartbeat_reader, last_hb_key, say)
                retriable = worker_flagged_retriable(heartbeat_reader)
                detail = _append_failure_artifacts(str(e), failure_detail_reader)
                return PollResult(
                    False,
                    failure="job_preempted" if retriable else "job_failed",
                    detail=detail,
                )
        if status in TERMINAL_FAIL:
            detail = str(st.get("error") or "")[:1500]
            out = st.get("output")
            if isinstance(out, dict) and out.get("stdout"):
                # Worker stdout tail is the only place the REAL root cause lives for
                # crashes inside subprocesses (e.g. vLLM EngineCore deaths).
                detail += "\n--- worker stdout tail ---\n" + str(out["stdout"])[-2000:]
            elif not detail:
                detail = str(out)[:1500]
            # Structural classification only ([{status}] prefix is for human-readable logs).
            # A platform termination (CANCELLED/TIMED_OUT) is already retryable — skip the worker
            # heartbeat read entirely (no worker error there, and it may not even exist yet).
            if status in PLATFORM_TERMINATIONS:
                return PollResult(False, failure="job_preempted", detail=f"[{status}] {detail}")
            # A worker FAILED: consult the structured worker flag (one forced heartbeat read).
            last_hb_key, _ = surface_forced_heartbeat(heartbeat_reader, last_hb_key, say)
            retriable = worker_flagged_retriable(heartbeat_reader)
            detail = _append_failure_artifacts(detail, failure_detail_reader)
            return PollResult(
                False,
                failure="job_preempted" if retriable else "job_failed",
                detail=f"[{status}] {detail}",
            )
        # Capacity backstop: bound how long the job may sit IN_QUEUE (no worker has accepted it).
        # The throttled/unhealthy fast-fails below only arm when endpoint_health succeeds AND RunPod
        # reports a THROTTLED/UNHEALTHY worker; a queue with zero workers, one stuck INITIALIZING, or
        # a health probe that keeps erroring bypasses them and would otherwise wait the full
        # setup_grace_s (~50 min). Keyed off the authoritative job status so it holds even when the
        # health probe is blind: once IN_QUEUE exceeds queue_grace_s, return a retryable stall so
        # the runner's gpu-walk re-provisions on the next-best (in-capacity) class.
        now = time.time()
        if status == "IN_QUEUE":
            if queued_since is None:
                queued_since = now
            elif now - queued_since > queue_grace_s:
                return PollResult(
                    False,
                    failure="no_capacity",
                    detail=f"never scheduled: job stuck IN_QUEUE for {int(now - queued_since)}s "
                    "(no RunPod capacity for the pinned GPU class); retrying on the next-best GPU",
                )
        else:
            # A worker picked the job up: every queue-phase timer starts over. The unhealthy/
            # throttled timers are only maintained by the IN_QUEUE health probe, so without this
            # reset a job that is later re-queued would fast-fail immediately on a stale
            # timestamp from the earlier queue period.
            queued_since = None
            unhealthy_since = None
            throttled_since = None
        # While queued, surface worker availability (throttled hosts are the common
        # cause of silent multi-minute waits — make them visible in the run log).
        if status == "IN_QUEUE" and now - last_health_probe > 90:
            last_health_probe = now
            try:
                h = runpod_api.endpoint_health(handle.endpoint_id)
                workers = h.get("workers") or {}
                usable = workers.get("running") or workers.get("ready") or workers.get("idle")
                recovering = workers.get("initializing")
                if (
                    any(workers.get(k) for k in ("throttled", "unhealthy", "initializing"))
                    or not usable
                ):
                    say(f"queued; workers: {workers}")
                # Fail fast on a worker stuck UNHEALTHY: a dead worker / failed image pull won't
                # self-recover, so don't burn the full setup_grace_s (~50 min) waiting on it — once
                # it has stayed unhealthy with nothing usable or (re)initializing for
                # unhealthy_grace_s, return a (retryable) stall so the runner re-provisions a FRESH
                # endpoint (fresh image pull, likely a different host). Observed: a mutable image
                # tag republished mid-pull corrupts the worker -> unhealthy, and a fresh pull fixes it.
                if workers.get("unhealthy") and not usable and not recovering:
                    if unhealthy_since is None:
                        unhealthy_since = time.time()
                    elif time.time() - unhealthy_since > unhealthy_grace_s:
                        return PollResult(
                            False,
                            failure="stalled",
                            detail=f"worker stuck unhealthy for "
                            f"{int(time.time() - unhealthy_since)}s while IN_QUEUE (likely a failed "
                            f"image pull); retrying on a fresh endpoint",
                        )
                else:
                    unhealthy_since = None  # recovered / usable worker appeared
                # Fail fast on a worker stuck THROTTLED: RunPod has no capacity for the pinned GPU
                # class/pool and a throttled worker won't self-recover, so don't burn the full
                # setup_grace_s (~50 min) waiting on it. Once it has stayed throttled with nothing
                # usable or (re)initializing for throttled_grace_s, return a (retryable) stall so
                # the runner's gpu-walk re-provisions on the NEXT-BEST GPU class — the cheapest fit
                # often has no capacity while the next-best (a few cents/hr more) does.
                if workers.get("throttled") and not usable and not recovering:
                    if throttled_since is None:
                        throttled_since = time.time()
                    elif time.time() - throttled_since > throttled_grace_s:
                        return PollResult(
                            False,
                            failure="no_capacity",
                            detail=f"never scheduled: worker stuck THROTTLED for "
                            f"{int(time.time() - throttled_since)}s while IN_QUEUE (no RunPod "
                            f"capacity for the pinned GPU class); retrying on the next-best GPU",
                        )
                else:
                    throttled_since = None  # capacity appeared / usable worker
            except Exception as probe_err:
                # Health surfacing is diagnostic only; a probe failure must not stop polling.
                # Report it instead of swallowing it silently, so a blind probe (the case the
                # queue_grace_s backstop exists for) is visible in the run log.
                say(f"queued; health probe failed: {probe_err}")
        # heartbeat progress surfacing + stall detection
        new_key, stage = surface_heartbeat(heartbeat_reader, last_hb_key, say)
        if new_key != last_hb_key:
            last_hb_key = new_key
            last_progress = time.time()
            # Only a training-phase heartbeat means cold-start setup is done and we
            # can switch to the tight window; setup heartbeats keep the grace budget.
            if stage not in _SETUP_HEARTBEAT_STAGES:
                seen_heartbeat = True
        # Cold start (before any training-phase heartbeat) gets the larger setup_grace_s,
        # but only when a heartbeat_reader lets us tell setup from training; without one we
        # can't, so stay on stall_after_s (no regression).
        in_setup = heartbeat_reader is not None and not seen_heartbeat
        stall_limit = setup_grace_s if in_setup else stall_after_s
        if time.time() - last_progress > stall_limit:
            phase = "setup (pre-training)" if in_setup else "training"
            return PollResult(
                False,
                failure="stalled",
                detail=f"no worker progress for {int(time.time() - last_progress)}s "
                f"during {phase} (job status {status}, limit {int(stall_limit)}s)",
            )
        time.sleep(interval_s)
774
+
775
+
776
def submit_run(
    spec,
    seed: int,
    log=None,
    on_handle=None,
    attempt: int = 0,
    runtime_secrets: dict[str, str] | None = None,
    on_last_gpu: bool = False,
) -> PollResult:
    """Durable equivalent of ``submit_train``: deploy, submit, persist handle, poll.

    Steps, in order: resolve worker pip deps and env, deploy a per-attempt endpoint, submit
    the train payload as a queue job, report the handle, then block in ``poll_job``.

    ``on_handle(handle_dict)`` is called as soon as the job is queued so the runner can
    persist {endpoint_id, job_id} for cross-process reattach.

    ``on_last_gpu`` tells the no-capacity backstops that no further GPU attempt follows this
    one (candidate list exhausted OR retry budget exhausted), so there is no next-best class
    to walk to and they wait longer before giving up (see ``stall_kwargs``).

    Any exception from the submit step is re-raised after a best-effort delete of the
    freshly deployed endpoint.
    """
    from flash.envs.registry import worker_pip_for_env
    from flash.providers.runpod.train import _run_suffix, build_worker_env, chalk_extra_pip

    wall_clock_s = max(60, int(spec.gpu.max_wall_seconds))
    # Per-attempt endpoint name: a retry must land on a genuinely fresh endpoint — reusing
    # the name lets the SDK/platform pin the job back onto the same (possibly
    # throttled/sick) host.
    base_suffix = _run_suffix(spec.run_id)
    endpoint_suffix = f"{base_suffix}r{attempt}" if attempt else base_suffix
    # Resolve worker pip deps BEFORE provisioning, so deterministic dependency issues
    # surface before the endpoint exists. extra_pip runs for EVERY job (the durable
    # baked-image path skips resolve_worker_deps in build_function_input, but _train_body
    # always pip-installs extra_pip), so the chalk spec is appended here to reach default
    # runs.
    env_pip = list(spec.environment.pip) or worker_pip_for_env(spec.environment.id)
    extra_pip = env_pip + chalk_extra_pip(spec)
    worker_env = build_worker_env(spec, seed, runtime_secrets=runtime_secrets)
    worker_env["ATTEMPT"] = str(int(attempt))

    endpoint_id, name = deploy_train_endpoint(
        spec.gpu.type,
        execution_timeout_ms=wall_clock_s * 1000,
        name_suffix=endpoint_suffix,
        disk_gb=spec.gpu.disk_gb,
        spec=spec,
    )
    payload = {
        "hf_repo": spec.train.hf_repo,
        "job_spec_json": spec.to_json(),
        "phase": spec.phase,
        "seed": int(seed),
        "env": worker_env,
        "extra_pip": extra_pip,
    }
    try:
        job_id = runpod_api.submit_job(endpoint_id, build_function_input(payload))
    except Exception:
        # The endpoint is registered but no run handle exists yet, and a retry endpoint's
        # rN-suffixed name can't be reconstructed from the run id later — delete it now so
        # a transient submit failure doesn't leak a serverless endpoint against the
        # account quota.
        with contextlib.suppress(Exception):
            runpod_api.delete_endpoint(endpoint_id)
        raise

    handle = JobHandle(endpoint_id, name, job_id)
    if log is not None:
        submitted_line = (
            f"submitted job: endpoint={name} ({endpoint_id}) job={job_id} "
            f"attempt={attempt} gpu={spec.gpu.type} phase={spec.phase} seed={seed}"
        )
        print(submitted_line, file=log, flush=True)
    if on_handle is not None:
        on_handle(handle.to_dict())

    # HF-backed readers are only available when the spec names an HF repo.
    hf_repo = spec.train.hf_repo
    prefix = f"{spec.phase}/{spec.run_id}/seed{seed}"
    if hf_repo:
        heartbeat_reader = make_hf_heartbeat_reader(hf_repo, prefix)
        failure_reader = make_hf_failure_detail_reader(hf_repo, prefix, spec.phase)
    else:
        heartbeat_reader = None
        failure_reader = None
    return poll_job(
        handle,
        log=log,
        heartbeat_reader=heartbeat_reader,
        failure_detail_reader=failure_reader,
        **stall_kwargs(on_last_gpu=on_last_gpu),
    )
862
+
863
+
864
+ def make_hf_text_reader(hf_repo: str, path_in_repo: str, min_interval_s: float = 45.0):
865
+ """Rate-limited reader for one HF artifact's text content (None until it exists).
866
+
867
+ Generic helper for HF-backed worker artifacts and heartbeats. ``read(force=False)``
868
+ re-downloads at most once per
869
+ ``min_interval_s`` (``force=True`` bypasses the gate); it never raises — any HF error
870
+ (artifact absent, network) returns None.
871
+ """
872
+ state = {"last": 0.0}
873
+
874
+ def read(force: bool = False) -> str | None:
875
+ if not hf_repo:
876
+ return None
877
+ if not force and time.time() - state["last"] < min_interval_s:
878
+ return None
879
+ state["last"] = time.time()
880
+ try:
881
+ from huggingface_hub import hf_hub_download
882
+
883
+ p = hf_hub_download(
884
+ hf_repo,
885
+ path_in_repo,
886
+ repo_type="dataset",
887
+ token=os.environ.get("HF_TOKEN"),
888
+ force_download=True,
889
+ )
890
+ with open(p) as f:
891
+ return f.read()
892
+ except Exception:
893
+ return None
894
+
895
+ return read
896
+
897
+
898
def make_hf_heartbeat_reader(hf_repo: str, prefix: str, min_interval_s: float = 30.0):
    """Build a reader for the worker's ``{prefix}/heartbeat.json`` on HF.

    Wraps :func:`make_hf_text_reader` and JSON-decodes what it returns. The resulting
    ``read(force=False)`` yields the parsed heartbeat, or None when the text reader yields
    None or the content is not valid JSON.
    """
    heartbeat_path = f"{prefix}/heartbeat.json"
    fetch_text = make_hf_text_reader(hf_repo, heartbeat_path, min_interval_s)

    def read(force: bool = False) -> dict | None:
        raw_text = fetch_text(force=force)
        if raw_text is not None:
            try:
                return json.loads(raw_text)
            except (ValueError, TypeError):
                # Invalid JSON is treated as "no heartbeat".
                pass
        return None

    return read
915
+
916
+
917
def make_hf_failure_detail_reader(
    hf_repo: str,
    prefix: str,
    phase: str,
    min_interval_s: float = 45.0,
):
    """Build a reader for the worker-uploaded RunPod failure artifacts on HF.

    The RunPod queue often reports only the outer handler error (for example, "produced no
    /tmp/metrics.json"). The worker writes the actual traceback and console tail to HF; this
    reader lets ``poll_job`` force-download those files after a terminal worker failure.

    The returned ``read(force=False)`` fetches ``error_<phase>.txt`` then
    ``console_<phase>.txt`` under ``prefix`` and returns the last 4000 characters of each
    non-empty one under a ``--- <file> ---`` header, newline-joined, or None when neither
    yields text.
    """
    # (artifact file name, its text reader) in the order sections appear in the output.
    artifact_readers = [
        (
            artifact,
            make_hf_text_reader(hf_repo, f"{prefix}/{artifact}", min_interval_s),
        )
        for artifact in (f"error_{phase}.txt", f"console_{phase}.txt")
    ]

    def read(force: bool = False) -> str | None:
        sections: list[str] = []
        for artifact, fetch_text in artifact_readers:
            text = fetch_text(force=force)
            if text:
                sections.append(f"--- {artifact} ---\n{text[-4000:]}")
        if not sections:
            return None
        return "\n".join(sections)

    return read
945
+
946
+
947
def worker_flagged_retriable(heartbeat_reader) -> bool:
    """Report whether the worker's last heartbeat carries a truthy ``retriable`` flag.

    ``retriable`` (stamped for a RetriableInfraError) is the structured worker<->poller
    contract that replaces failure-detail parsing: it means "retry on a fresh worker". The
    heartbeat is read with ``force=True`` to get past the reader's rate limit. Returns False
    when there is no reader or the reader does not yield a dict.
    """
    if heartbeat_reader is None:
        return False
    latest = heartbeat_reader(force=True)
    return isinstance(latest, dict) and bool(latest.get("retriable"))