freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,956 @@
|
|
|
1
|
+
"""Durable run primitives: explicit deploy -> submit -> poll with a persisted job handle.
|
|
2
|
+
|
|
3
|
+
Calling `runpod_flash`'s all-in-one blocking handler directly would tie a run's life to
|
|
4
|
+
one client process and one HTTP poll loop: a client crash/network blip orphans an
|
|
5
|
+
otherwise-healthy GPU job (no job id is ever persisted), and any SDK polling bug kills
|
|
6
|
+
the run. This module owns the lifecycle instead:
|
|
7
|
+
|
|
8
|
+
deploy_train_endpoint() -> endpoint_id (Flash SDK deploy, same worker template)
|
|
9
|
+
build_function_input() -> the exact FunctionRequest payload Flash workers expect
|
|
10
|
+
submit + poll_job() -> REST queue API with hardened retries; the job handle
|
|
11
|
+
{endpoint_id, job_id} is persisted by the runner so
|
|
12
|
+
any process can re-attach (`flash status --follow`).
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import base64
|
|
19
|
+
import contextlib
|
|
20
|
+
import json
|
|
21
|
+
import os
|
|
22
|
+
import threading
|
|
23
|
+
import time
|
|
24
|
+
from dataclasses import dataclass
|
|
25
|
+
from typing import TYPE_CHECKING
|
|
26
|
+
|
|
27
|
+
from flash._logging import get_logger
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from collections.abc import Callable
|
|
31
|
+
from flash.providers._poll import (
|
|
32
|
+
PollErrorTracker,
|
|
33
|
+
make_say,
|
|
34
|
+
surface_forced_heartbeat,
|
|
35
|
+
surface_heartbeat,
|
|
36
|
+
)
|
|
37
|
+
from flash.providers.base import PollResult, canonical_gpu
|
|
38
|
+
from flash.providers.runpod import api as runpod_api
|
|
39
|
+
from flash.providers.runpod.gpus import flash_gpu
|
|
40
|
+
from flash.providers.runpod.train import (
|
|
41
|
+
DEFAULT_EXECUTION_TIMEOUT_MS,
|
|
42
|
+
FLASH_SDK_LOCK,
|
|
43
|
+
WORKER_IMAGE,
|
|
44
|
+
WORKER_SYSTEM_DEPS,
|
|
45
|
+
_patch_runpod_backoff,
|
|
46
|
+
_train_body,
|
|
47
|
+
endpoint_name,
|
|
48
|
+
isolate_flash_state,
|
|
49
|
+
min_cuda_for,
|
|
50
|
+
resolve_worker_deps,
|
|
51
|
+
worker_image_for_gpu,
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
# Module-level logger from the project's flash._logging.get_logger helper.
logger = get_logger(__name__)
|
|
55
|
+
|
|
56
|
+
# Re-export so callers/tests that did ``from ...jobs import PollResult`` keep working.
# (PollResult itself is imported from flash.providers.base; every other name below is
# expected to be defined in this module.)
__all__ = [
    "JobHandle",
    "PollResult",
    "apply_disk_gb",
    "build_function_input",
    "decode_output",
    "deploy_train_endpoint",
    "make_hf_failure_detail_reader",
    "make_hf_heartbeat_reader",
    "make_hf_text_reader",
    "poll_job",
    "submit_run",
    "weight_cache_datacenters",
    "weight_cache_endpoint_kwargs",
    "weight_cache_volume_name",
    "weight_cache_volumes",
]
|
|
74
|
+
|
|
75
|
+
# Terminal queue-job status treated as success.
TERMINAL_OK = {"COMPLETED"}

# Statuses meaning the PROVIDER ended the worker (reclaim / preempt / time cap): infra-shaped,
# so the run is retried. A worker reporting "FAILED" is the run dying on its own (a real
# traceback), which fails fast instead.
PLATFORM_TERMINATIONS = {"CANCELLED", "TIMED_OUT"}
TERMINAL_FAIL = {"FAILED", *PLATFORM_TERMINATIONS}

# Heartbeat stages the worker emits while still COLD-STARTING, i.e. before the model is loaded
# and the training loop runs (boot -> sft_start/rl_start, later sft_model_load/rl_train_start).
# Seeing one proves the worker is alive, but not that the slow setup (model download + vLLM
# init) is done — so these must not switch stall detection over to the tight training window.
_SETUP_HEARTBEAT_STAGES = frozenset(
    ("boot", "sft_start", "rl_start", "sft_model_load", "rl_train_start")
)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def stall_kwargs(on_last_gpu: bool = False) -> dict:
    """Return the stall-detection keyword arguments passed to ``poll_job``.

    Both the submit path and the reattach path call this, so a recovered run is judged with
    exactly the same windows as the original submit. The submit's ``on_last_gpu`` flag is
    persisted in the run handle (the runner's ``on_handle`` stores it under ``remote``), and a
    cross-process reattach (``RunpodProvider.poll``) reads it back and passes it here. A
    last-candidate run therefore keeps its longer no-capacity grace across a control-plane
    restart instead of being judged on the shorter non-last window.

    Returned keys:
        stall_after_s: silence tolerated once a training heartbeat has been seen.
        setup_grace_s: the larger cold-start window before the first training heartbeat.
        queue_grace_s: how long a job may sit IN_QUEUE with no worker before the pinned GPU
            class is treated as out of capacity (and the run walks to the next-best class).
        throttled_grace_s: how long a worker may stay THROTTLED before the same verdict.

    The two capacity backstops depend on whether another GPU attempt will follow:

    * ``on_last_gpu`` False (a fallback class is still available): ~5 min — long enough to ride
      out a brief capacity blip, short enough that a starved class hands off promptly.
    * ``on_last_gpu`` True (candidate list OR retry budget exhausted): ~15 min — there is nowhere
      left to walk, so burning the final attempt is worse than waiting out a longer queue.

    Both apply to the no-capacity case only: once a worker picks the job up (it leaves
    IN_QUEUE), ``setup_grace_s`` governs cold start and an IN_PROGRESS job is never walked off
    at the capacity grace.
    """
    capacity_grace_s = 300.0
    if on_last_gpu:
        capacity_grace_s = 900.0
    return dict(
        stall_after_s=1500.0,
        setup_grace_s=3000.0,
        queue_grace_s=capacity_grace_s,
        throttled_grace_s=capacity_grace_s,
    )
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
# Datacenters that appear in the runpod_flash SDK's ``DataCenter.all()`` enum but do NOT
# support network volumes. Found live: the SDK enum is not the network-volume DC set. Creating a
# volume in one of these fails the ENTIRE deploy with a 500 ("data center ... does not support
# network volumes"), so eager runs would always fall back to cold and the cache would never work.
# The SDK exposes no volume-capability flag, so the exclusion list lives here.
#
# If RunPod drops volume support in another DC, that deploy 500s and the lifecycle's
# no_capacity/poll_error cache-drop falls back to a cold cross-region run — a stale list
# degrades gracefully rather than wedging. Add the DC here to restore its cache.
_VOLUME_INCAPABLE_DATACENTERS = frozenset(("US-MO-1",))
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def weight_cache_datacenters() -> list:
    """List every RunPod datacenter that can host a network volume.

    This one list serves two purposes. It is the set of DCs the endpoint may be scheduled
    across, AND the set of DCs that each get their own weight-cache volume. The cache is eager:
    there is a volume in every region a run can land in, so any landing starts warm.

    Computed as ``DataCenter.all()`` minus ``_VOLUME_INCAPABLE_DATACENTERS``; the SDK enum
    still contains DCs without network-volume backing. An SDK upgrade that adds a storage
    region is picked up automatically. One that adds a volume-less region must be added to the
    exclusion set.
    """
    from runpod_flash.core.resources.datacenter import DataCenter

    capable = []
    for datacenter in DataCenter.all():
        if datacenter.value in _VOLUME_INCAPABLE_DATACENTERS:
            continue
        capable.append(datacenter)
    return capable
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def weight_cache_volume_name(base: str, dc) -> str:
    """Return the physical, per-datacenter volume name ``<base>-<dc lowercased>``.

    The weight cache is one logical volume (``base`` is ``spec.gpu.network_volume``, e.g.
    ``flash-weights``). It is realized as one physical volume per datacenter, and the DC has
    to be part of the name.

    The runpod_flash SDK tracks resources (in memory and persisted) under the key
    ``NetworkVolume:{name}``, WITHOUT the datacenter (resources/base.py
    ``get_resource_key``). N same-named volumes would therefore share a key. Deploying the
    second would trigger a replace, which hits the SDK's unimplemented
    ``NetworkVolume.undeploy`` and crashes.

    Workers are unaffected by the naming: every volume mounts at ``/runpod-volume``.
    """
    dc_suffix = dc.value.lower()
    return f"{base}-{dc_suffix}"
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def weight_cache_volumes(spec) -> list:
    """Build the EAGER weight-cache fleet: one ``NetworkVolume`` per storage datacenter.

    Returns ``[]`` when the cache is off. That is the case when ``spec`` is None,
    ``spec.gpu.network_volume`` is unset or empty (the runner assigns the logical base name only
    for eligible runs), or no storage DC exists.

    Otherwise every entry of ``weight_cache_datacenters()`` gets a physical volume named
    ``<base>-<dc>`` (see ``weight_cache_volume_name``), so the cache exists wherever the
    endpoint lands. runpod_flash reuses an existing volume with the same (name, datacenter), so
    this is create-or-attach: the first deploy provisions the whole fleet and later deploys just
    re-attach.

    Multi-account pools: ``deploy_train_endpoint`` re-runs the WHOLE deploy on quota failover,
    so volumes are (re)created on whichever account ends up hosting the endpoint. Volumes are
    account-scoped. Orphans left on the account we failed over FROM are reclaimed by
    ``preload --teardown``, which sweeps every pool account.
    """
    if spec is None:
        return []
    base = getattr(spec.gpu, "network_volume", None)
    if not base:
        return []
    datacenters = weight_cache_datacenters()  # EAGER: one volume in every storage DC
    if not datacenters:
        return []

    from runpod_flash import NetworkVolume

    from flash.spec import _volume_gb

    # Delegate to the spec's tolerant size parser. A stale or hand-edited spec with a
    # non-numeric, "0", or negative network_volume_gb then gets 100 GB. Raising would make the
    # best-effort caller swallow it into a no-cache deploy, and a 0-GB volume is nonsensical.
    size_gb = _volume_gb(getattr(spec.gpu, "network_volume_gb", 100))
    volumes = []
    for dc in datacenters:
        volume_name = weight_cache_volume_name(str(base), dc)
        volumes.append(NetworkVolume(name=volume_name, size=size_gb, datacenter=dc))
    return volumes
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def weight_cache_endpoint_kwargs(spec) -> dict:
    """Return endpoint kwargs that attach the eager weight-cache fleet, or ``{}`` (best-effort).

    With the cache on, the result is
    ``{"volume": [one volume per storage DC], "datacenter": [every storage DC]}``.

    The endpoint may be scheduled in ANY of those DCs, so it lands wherever there is capacity.
    It also carries a volume in each of them, so any landing is warm. The SDK rule "every volume
    DC must appear in the endpoint's datacenter list" holds exactly, because both lists are the
    same storage-DC set. The first deploy create-or-attaches the whole fleet; later deploys
    re-attach.

    Returns ``{}`` when the cache is off (no ``network_volume`` on the spec). Any failure
    (SDK import, validation, ...) is logged as a warning and also yields ``{}``, so the run
    deploys without a volume instead of failing. The lifecycle still drops the volume on a
    no_capacity retry to widen onto the non-storage DC pool.
    """
    try:
        volumes = weight_cache_volumes(spec)
        # No volumes -> cache off -> cold run; RunPod may pick any region.
        attach = (
            {"volume": volumes, "datacenter": weight_cache_datacenters()} if volumes else {}
        )
    except Exception as err:
        # Never let the cache break a deploy: degrade to a no-volume run.
        logger.warning("weight cache disabled for this run (%s); deploying with no volume", err)
        return {}
    return attach
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def apply_disk_gb(config, disk_gb: int | None) -> None:
    """Raise (never lower) the worker container disk on an already-built endpoint config.

    The Flash SDK's ``PodTemplate.containerDiskInGb`` defaults to 64 GB, and the ``Endpoint``
    wrapper has no disk option. That blocked models whose checkpoint alone exceeds 64 GB. The
    SDK validators have already populated ``config.template`` when the resource config is built,
    so mutating the field here is the supported injection point.

    The update is raise-only. Shrinking below the SDK default buys nothing (serverless disk is
    not billed separately) and would regress configs that still carry the historical
    ``disk_gb = 60`` default.

    Args:
        config: built endpoint resource config; its ``template`` attribute is updated in place.
        disk_gb: requested container disk in GB. ``None`` or ``0`` means no change.
    """
    if disk_gb:
        template = getattr(config, "template", None)
        if template is not None:
            requested_gb = int(disk_gb)
            current_gb = int(template.containerDiskInGb or 0)
            template.containerDiskInGb = max(requested_gb, current_gb)
        else:
            logger.warning("disk_gb=%s requested but endpoint config has no template", disk_gb)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
@dataclass
class JobHandle:
    """Identifies one submitted RunPod queue job: the endpoint it runs on plus its job id.

    The runner persists this (via ``to_dict``) so any process can re-attach to the job.
    """

    endpoint_id: str
    endpoint_name: str
    job_id: str

    def to_dict(self) -> dict:
        """Serialize to a plain dict, tagged with ``provider="runpod"`` for upstream routing."""
        return dict(
            provider="runpod",
            endpoint_id=self.endpoint_id,
            endpoint_name=self.endpoint_name,
            job_id=self.job_id,
        )

    @classmethod
    def from_dict(cls, d: dict) -> JobHandle:
        """Rebuild a handle from ``to_dict`` output.

        ``endpoint_name`` is optional and defaults to ``""``. ``provider`` is ignored here: it is
        routing metadata the runner consumes upstream, and handles persisted before that key
        existed are defaulted to runpod there.

        Raises:
            KeyError: if ``endpoint_id`` or ``job_id`` is missing.
        """
        return cls(
            endpoint_id=d["endpoint_id"],
            endpoint_name=d.get("endpoint_name", ""),
            job_id=d["job_id"],
        )
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def _is_workers_quota_error(exc: Exception) -> bool:
    """Report whether a RunPod error means the account-wide worker quota is used up.

    Matches RunPod's "max workers across all endpoints" message case-insensitively.
    """
    marker = "max workers across all endpoints"
    lowered = str(exc).lower()
    return marker in lowered
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
# Guards every read/write of ``_idle_since``. Two threads can sweep at once: the periodic
# control-plane reaper (run via asyncio.to_thread) and a deploy-time quota sweep. This is a
# dedicated lock, NOT FLASH_SDK_LOCK, because the sweep talks to the REST API rather than the
# Flash SDK. Holding it for a whole sweep also keeps a concurrent sweep's prune from resetting
# this sweep's grace timers. Contention is negligible: the reaper runs every 10 min and deploy
# sweeps are rare.
_idle_since_lock = threading.Lock()

# endpoint id -> time.time() at which the endpoint was FIRST seen idle. A candidate must stay
# idle across sweeps for ``min_idle_s`` before it is deleted, because a cold-starting or
# between-jobs endpoint can report a transient zero that must not be acted on. Each sweep prunes
# this down to the still-idle set, so it cannot grow without bound.
_idle_since: dict[str, float] = dict()
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def _is_flash_endpoint(name: str) -> bool:
    """Tell whether ``name`` is a flash training endpoint the idle sweep may reap.

    Accepts both the bare ``flash-...`` form and the SDK's ``live-flash-...`` form; only one
    ``live-`` prefix is stripped. Serving runs on freesolo's Modal app rather than RunPod, so
    every flash-* RunPod endpoint is a training endpoint.
    """
    live_prefix = "live-"
    bare = name[len(live_prefix):] if name.startswith(live_prefix) else name
    return bare.startswith("flash-")
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def _sweep_idle_flash_endpoints(
    protected: set[str], min_idle_s: float = 0.0, reap_warm: bool = True
) -> int:
    """Delete idle, ORPHANED flash training endpoints — workers doing nothing that still hold
    RunPod worker quota (runs that finished/crashed without tearing their endpoint down). Returns
    the count deleted.

    Safe by construction:

    - ``protected`` — endpoint names tied to a LIVE run (both the bare ``flash-...`` and the SDK's
      ``live-flash-...`` form). Never deleted, even if momentarily idle (e.g. between seeds).
    - ``reap_warm`` — when True (the run-aware periodic reaper, which protects EVERY live run),
      a merely *warm* ``idle``/``ready`` worker left over after a job counts as doing nothing and
      is reclaimable; that warm-idle state is the dominant leak, since RunPod keeps a worker warm
      after each job. When False (the deploy-time reactive sweep, which only protects the current
      run), a warm worker is treated as busy so the sweep reaps only endpoints that have FULLY
      scaled to zero — it must not delete another live run's between-seeds warm endpoint.
    - ``min_idle_s`` requires the idle reading to PERSIST across sweeps, so a single transient
      zero (cold start / between jobs) never triggers a delete.

    Args:
        protected: endpoint names that must never be deleted.
        min_idle_s: seconds an endpoint must have been continuously observed idle (across
            sweeps, via ``_idle_since``) before deletion; 0 deletes on the first idle reading.
        reap_warm: whether warm ``ready``/``idle`` workers count as reclaimable (see above).

    Returns:
        Number of endpoints for which ``runpod_api.delete_endpoint`` reported success. Returns 0
        if listing endpoints fails.

    Errors are never raised: a failed listing, or a failure while checking/deleting one
    endpoint, is logged at DEBUG and the sweep moves on.
    """
    deleted = 0
    try:
        endpoints = runpod_api.list_endpoints()
    except Exception:
        logger.debug("idle-sweep: failed to list endpoints", exc_info=True)
        return 0
    # One timestamp for the whole sweep, taken BEFORE the lock/health calls.
    now = time.time()
    # Ids observed idle in THIS sweep; any grace timer not in here is pruned at the end.
    still_idle: set[str] = set()
    # Serialize all _idle_since access (see the lock's definition): a concurrent sweep must not
    # mutate the dict mid-iteration (the prune below would raise) or disturb these grace timers.
    # Note the lock is held across the per-endpoint health/delete network calls.
    with _idle_since_lock:
        for ep in endpoints:
            ep_name = ep.get("name") or ""
            eid = ep.get("id")
            if not (eid and _is_flash_endpoint(ep_name)):
                continue
            # Protect the run's endpoint in either registered form.
            if ep_name in protected or ep_name.removeprefix("live-") in protected:
                continue
            try:
                health = runpod_api.endpoint_health(eid) or {}
                workers = health.get("workers")
                jobs_info = health.get("jobs")
                # Require non-empty dicts: a missing/empty workers section means the health
                # response is incomplete and we can't confirm the endpoint is idle.
                if not isinstance(workers, dict) or not workers or not isinstance(jobs_info, dict):
                    continue
                # "Busy" = a worker actually working or spinning up, OR a job queued/in progress.
                # With reap_warm, a warm idle/ready worker with no pending work is NOT busy — it is
                # the leftover we reclaim (the protected set + grace keep it safe). Without it, a
                # warm worker counts as busy so only fully-scaled-to-zero endpoints are reaped.
                busy_workers = (workers.get("running") or 0) + (workers.get("initializing") or 0)
                if not reap_warm:
                    busy_workers += (workers.get("ready") or 0) + (workers.get("idle") or 0)
                in_flight = (jobs_info.get("inQueue") or 0) + (jobs_info.get("inProgress") or 0)
                if busy_workers != 0 or in_flight != 0:
                    _idle_since.pop(eid, None)  # busy again -> reset the grace timer
                    continue
                still_idle.add(eid)
                # Keeps the FIRST idle timestamp if one exists; otherwise starts the timer now.
                first_idle = _idle_since.setdefault(eid, now)
                if now - first_idle < min_idle_s:
                    continue  # idle, but not for long enough yet — wait for the next sweep
                # A falsy delete result leaves the timer in place, so the next sweep retries.
                if runpod_api.delete_endpoint(eid):
                    deleted += 1
                    _idle_since.pop(eid, None)
                    logger.info("idle-sweep: deleted idle endpoint %s (%s)", ep_name, eid)
            except Exception:
                # If this fires before still_idle.add (e.g. the health call failed), the
                # endpoint's grace timer is pruned below, so its idle clock restarts next sweep.
                logger.debug(
                    "idle-sweep: error processing endpoint %s (%s)", ep_name, eid, exc_info=True
                )
                continue
        # Drop grace timers for endpoints no longer idle/present (busy, deleted, gone, protected).
        for stale in set(_idle_since) - still_idle:
            _idle_since.pop(stale, None)
    return deleted
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def deploy_train_endpoint(
    friendly_gpu: str,
    execution_timeout_ms: int | None = None,
    name_suffix: str | None = None,
    disk_gb: int | None = None,
    spec=None,
    endpoint_kwargs: dict | Callable[[], dict] | None = None,
) -> tuple[str, str]:
    """Deploy (or reuse) the run's uniquely-named worker endpoint; return (id, name).

    On a worker-quota error, sweeps idle flash-* endpoints (from crashed/completed runs
    that skipped GC) and retries up to ``_QUOTA_MAX_RETRIES`` times with backoff. If the
    account's quota stays exhausted after sweeping and ``RUNPOD_API_KEY`` configures more
    than one account, fails over to the next account (``keys.advance_key``) and deploys
    there. A single key => single account, no failover (unchanged behavior).

    ``endpoint_kwargs`` overrides the volume/datacenter attachment (default: the full multi-DC
    weight-cache fleet from ``weight_cache_endpoint_kwargs(spec)``). The preload driver passes a
    SINGLE-DC volume+datacenter so the worker provably lands in that region and warms its volume. It
    may be a dict OR a zero-arg FACTORY: under a multi-key pool the deploy retries on the next account
    after a quota failover, and the SDK can stamp an account-scoped id onto a NetworkVolume object —
    so a callable is re-invoked per account to build a FRESH volume (else the next account reuses the
    first account's stale volume id and the single-DC preload fails).

    Args:
        friendly_gpu: GPU name; normalized with ``canonical_gpu`` before use.
        execution_timeout_ms: per-job execution timeout; falsy -> DEFAULT_EXECUTION_TIMEOUT_MS.
        name_suffix: run-specific suffix used for the endpoint name and SDK state isolation.
        disk_gb: requested container disk (raise-only, see ``apply_disk_gb``).
        spec: run spec consulted for the default weight-cache attachment.
        endpoint_kwargs: optional volume/datacenter override (dict or zero-arg factory).

    Returns:
        ``(endpoint_id, endpoint_name)``.

    Raises:
        RuntimeError: if the deployed resource has no ``id``.
        Exception: any non-quota deploy error is re-raised immediately; the last quota error is
            re-raised once every account has been tried.

    Side effect: sets ``FLASH_IS_LIVE_PROVISIONING=true`` in ``os.environ`` for this process.
    """
    os.environ["FLASH_IS_LIVE_PROVISIONING"] = "true"
    # Imported lazily: the SDK and the key/auth helpers are only needed when actually deploying.
    from runpod_flash import Endpoint
    from runpod_flash.core.resources.resource_manager import ResourceManager

    from flash.providers.runpod import keys as rp_keys
    from flash.providers.runpod.auth import ensure_auth

    _patch_runpod_backoff()
    friendly = canonical_gpu(friendly_gpu)
    name = endpoint_name(friendly, name_suffix)
    # deploy a self-contained serverless-worker image directly. by default this is WORKER_IMAGE;
    # when per-sm warmed images are enabled, the selected GPU class picks the matching image tag.
    # FLASH_WORKER_IMAGE remains the absolute hotfix override.
    image = worker_image_for_gpu(friendly, allow_default=True)

    def _deploy_once():
        """One get_or_deploy on the currently-active account (SDK + lock critical section)."""
        # isolate_flash_state mutates runpod_flash's process-wide registry globals for this run's
        # suffix, and ResourceManager + the deploy share the SDK's asyncio singleton. Hold the
        # lock across the whole critical section so a concurrent run can't swap the registry
        # scope or race the event loop mid-deploy.
        with FLASH_SDK_LOCK:
            isolate_flash_state(name_suffix)
            kwargs = {
                "name": name,
                "gpu": flash_gpu(friendly),
                "gpu_count": 1,
                "min_cuda_version": min_cuda_for(friendly),
                "execution_timeout_ms": execution_timeout_ms or DEFAULT_EXECUTION_TIMEOUT_MS,
                "workers": (0, 1),
            }
            # Baked image -> deploy it as-is; no image -> boot-install deps on the default image.
            if image:
                kwargs["image"] = image
            else:
                kwargs["dependencies"] = resolve_worker_deps()
                kwargs["system_dependencies"] = WORKER_SYSTEM_DEPS
            # Attach the multi-region weight cache (best-effort: {} when no cache / on any error).
            # The endpoint is allowed across every cache DC, so it is NOT pinned to one region.
            # A caller (preload) may override with a single-DC volume+datacenter.
            # Resolve a factory FRESH on each account attempt (see docstring: avoids reusing a
            # NetworkVolume the SDK stamped with the prior account's id across a quota failover).
            override = endpoint_kwargs() if callable(endpoint_kwargs) else endpoint_kwargs
            kwargs.update(override if override is not None else weight_cache_endpoint_kwargs(spec))
            ep = Endpoint(**kwargs)
            # The worker entry point this endpoint runs (set on a private SDK attribute).
            ep._qb_target = _train_body
            config = ep._build_resource_config()
            apply_disk_gb(config, disk_gb)
            # Worker image is PUBLIC, so no container-registry credential is needed to pull it.
            rm = ResourceManager()
            return asyncio.run(rm.get_or_deploy_resource(config))

    _QUOTA_MAX_RETRIES = 3
    resource = None
    # One pass over the pool: advance_key() WRAPS (always True for a multi-key pool, even after the
    # last account), so without a bound an all-exhausted pool would fail over forever here. Cap the
    # failovers at "every OTHER account once" and then raise — the lifecycle retry budget handles
    # waiting for quota to recover and re-enters this with a fresh attempt.
    failovers_left = max(0, rp_keys.key_count() - 1)
    while resource is None:
        ensure_auth()  # collapse RUNPOD_API_KEY to the (possibly failed-over) active account key
        quota_exc: Exception | None = None
        # Attempt 0 deploys immediately; attempts 1..N-1 first sweep, then sleep 30s * attempt.
        for quota_attempt in range(_QUOTA_MAX_RETRIES):
            if quota_attempt > 0:
                # Under acute quota pressure, sweep idle orphaned flash training endpoints on THIS
                # account NOW (min_idle_s=0) to free a slot. This only protects THIS run's endpoint,
                # so it stays conservative (reap_warm=False): it reaps only endpoints fully scaled
                # to zero, never another live run's between-seeds WARM endpoint. The control-plane
                # periodic reaper does the run-aware, graced warm-idle sweep across all live runs.
                swept = _sweep_idle_flash_endpoints(
                    protected={name, f"live-{name}"}, min_idle_s=0.0, reap_warm=False
                )
                wait_s = 30 * quota_attempt
                logger.warning(
                    "RunPod worker quota hit (attempt %d/%d): swept %d idle flash-* endpoint(s); "
                    "retrying in %ds",
                    quota_attempt + 1, _QUOTA_MAX_RETRIES, swept, wait_s,
                )
                time.sleep(wait_s)
            try:
                resource = _deploy_once()
                break  # success
            except Exception as exc:
                if not _is_workers_quota_error(exc):
                    raise
                quota_exc = exc  # freeable: sweep + retry, then fail over to the next account
        if resource is not None:
            break
        # NOTE(review): if _deploy_once() ever returned None (no exception), the loop lands here
        # too and reports "worker quota exhausted" even though no quota error occurred — confirm
        # get_or_deploy_resource never returns None.
        # Quota still exhausted after sweeping this account dry — fail over to the next one, but only
        # until every account has been tried once (failovers_left). advance_key() wraps and always
        # returns True for a multi-key pool, so the count — not its return value — is what stops us.
        if failovers_left > 0 and rp_keys.advance_key():
            failovers_left -= 1
            logger.warning(
                "RunPod worker quota exhausted on this account after sweeping; failing over to "
                "the next RUNPOD_API_KEY account (%d configured)",
                rp_keys.key_count(),
            )
            continue
        raise quota_exc or RuntimeError("deploy_train_endpoint: worker quota exhausted")

    endpoint_id = getattr(resource, "id", None)
    if not endpoint_id:
        raise RuntimeError(f"deploy_train_endpoint: no endpoint id on resource {resource!r}")
    return endpoint_id, name
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def build_function_input(payload: dict) -> dict:
    """Build the job input a Flash queue worker needs to run ``_train_body(payload)``.

    The shape depends on whether a baked worker image is configured (the
    ``FLASH_WORKER_IMAGE`` env var, falling back to the module-level ``WORKER_IMAGE``):

    * baked serverless-worker image (client mode): ``payload`` itself is returned
      unchanged. The image's rp_handler reads ``job["input"]`` and calls ``_train_body``;
      submit_job wraps it in ``{"input": ...}``. No live-function source and no
      boot-install deps are shipped.
    * otherwise (boot-install fallback on the default Flash image with a live
      function): a FunctionRequest dict is returned. It carries ``_train_body``'s
      source, the serialized ``(payload,)`` args, and the pinned worker pip/system
      deps to install on first use.
    """
    baked_image = os.environ.get("FLASH_WORKER_IMAGE") or WORKER_IMAGE
    if baked_image:
        return payload

    # Imported lazily: only the boot-install path needs these runpod_flash internals.
    from runpod_flash.runtime.serialization import serialize_args
    from runpod_flash.stubs.live_serverless import get_function_source

    function_code, _ = get_function_source(_train_body)
    request = {
        "function_name": "_train_body",
        "function_code": function_code,
    }
    request["args"] = serialize_args((payload,))
    request["accelerate_downloads"] = True
    request["dependencies"] = resolve_worker_deps()
    request["system_dependencies"] = WORKER_SYSTEM_DEPS
    return request
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
def decode_output(output) -> dict:
    """Decode a queue-job output into the worker's metrics dict. Handles BOTH job shapes:

    - Flash LIVE-function (boot-install path): a FunctionResponse envelope
      ``{"success": True, "result": <base64 cloudpickle of the dict>}``.
    - Client-mode SERVERLESS handler (baked-image path): our baked rp_handler returns
      ``_train_body(...)``'s metrics dict, which RunPod surfaces as ``job["output"]`` directly —
      no envelope. The metrics dict has no ``success``/``result`` keys, so we return it as-is.

    A ``str`` output is first parsed as JSON.

    Raises:
        RuntimeError: the output is not JSON / not a dict, the envelope's result is not a dict,
            or the worker reported a failure. Failures from both shapes produce the same message:
            ``"Remote execution failed: <error>"``, followed by
            ``"\\n--- worker stdout tail ---\\n<last 1500 chars>"`` only when the worker sent
            stdout. This keeps root-cause diagnostics (e.g. a vLLM crash) for ``poll_job``.
    """
    if isinstance(output, str):
        try:
            output = json.loads(output)
        except json.JSONDecodeError as exc:
            raise RuntimeError(f"unexpected job output: {output[:200]}") from exc
    if not isinstance(output, dict):
        raise RuntimeError(f"unexpected job output type: {type(output)}")

    def _remote_failure(err) -> RuntimeError:
        # One formatter for both shapes. The stdout header is added only when there is a tail,
        # so an empty stdout doesn't leave a dangling "--- worker stdout tail ---" line.
        stdout_tail = str(output.get("stdout") or "")[-1500:]
        msg = f"Remote execution failed: {err}"
        if stdout_tail:
            msg += f"\n--- worker stdout tail ---\n{stdout_tail}"
        return RuntimeError(msg)

    # Flash live-function envelope (has success/result/error keys).
    if "success" in output or "result" in output:
        if output.get("success") and output.get("result") is not None:
            import cloudpickle

            # NOTE(security): cloudpickle.loads can execute arbitrary code. This is only
            # acceptable while job output is trusted, i.e. produced by the worker this client
            # deployed. Never route third-party data through here.
            result = cloudpickle.loads(base64.b64decode(output["result"]))
            if not isinstance(result, dict):
                raise RuntimeError(f"flash job returned no metrics: {result!r}")
            return result
        raise _remote_failure(output.get("error") or "unknown worker error")
    # Client-mode serverless handler: the metrics dict IS the output (baked rp_handler).
    if output.get("error"):
        raise _remote_failure(output["error"])
    return output
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
def _append_failure_artifacts(detail: str, failure_detail_reader) -> str:
    """Return ``detail`` extended with any worker-uploaded failure artifacts.

    ``failure_detail_reader`` (optional) is invoked once with ``force=True``. A missing reader,
    or one returning an empty/None value, leaves ``detail`` untouched. Otherwise the artifacts
    are joined to a non-empty ``detail`` with a newline, or returned on their own when
    ``detail`` is empty.
    """
    artifacts = None if failure_detail_reader is None else failure_detail_reader(force=True)
    if not artifacts:
        return detail
    return f"{detail}\n{artifacts}" if detail else artifacts
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
def poll_job(
    handle: JobHandle,
    log=None,
    interval_s: float = 10.0,
    heartbeat_reader=None,
    failure_detail_reader=None,
    stall_after_s: float = 1200.0,
    setup_grace_s: float = 3000.0,
    unhealthy_grace_s: float = 240.0,
    throttled_grace_s: float = 300.0,
    queue_grace_s: float = 300.0,
    deadline_s: float | None = None,
) -> PollResult:
    """Poll a queue job to completion; resilient to transient API errors.

    Two stall windows: the cold-start phase (dep install, per-run env pip, model download,
    vLLM init) is slow and only emits *setup* heartbeats (``_SETUP_HEARTBEAT_STAGES``).
    Until a *training* heartbeat arrives we apply the larger ``setup_grace_s`` budget so a
    slow cold start isn't misread as a stall; after it we use the tight ``stall_after_s``.
    Needs a ``heartbeat_reader`` to tell the phases apart — without one we keep
    ``stall_after_s`` throughout (no regression).

    ``failure_detail_reader`` force-reads worker-uploaded artifacts (``error_<phase>.txt`` and
    ``console_<phase>.txt``) after a worker terminal failure so a generic RunPod handler wrapper
    does not hide the real traceback.

    ``unhealthy_grace_s`` bounds how long a worker may stay UNHEALTHY (with nothing usable or
    initializing) while the job is IN_QUEUE before a ``"stalled"`` result is returned, so the
    runner re-provisions a fresh endpoint.

    ``throttled_grace_s`` bounds how long we wait on a worker stuck THROTTLED (no RunPod
    capacity for the pinned GPU class) before returning a ``"no_capacity"`` result so the runner
    walks to the next-best GPU. THROTTLED means there is no capacity for this class right now, so
    once a class with a cheaper fallback has stayed throttled this long, failing over beats
    blocking the run on a host that won't free up. ``stall_kwargs`` sets this to ~5 min while the
    gpu-walk still has a next-best class, and ~15 min on the last candidate (nowhere left to walk).

    ``queue_grace_s`` is the capacity backstop for that same walk when RunPod *doesn't* surface
    a THROTTLED/UNHEALTHY worker: a job can sit IN_QUEUE with zero workers assigned (or one stuck
    INITIALIZING, or while ``endpoint_health`` errors are swallowed below) and the throttled/
    unhealthy fast-fails never arm — so without this it would burn the full ``setup_grace_s``
    (~50 min). Keyed off the authoritative job status (robust to a failing health probe), it
    returns a ``"no_capacity"`` result once a job has been IN_QUEUE longer than ``queue_grace_s``
    (tuned by ``stall_kwargs`` like ``throttled_grace_s``: ~5 min normally, ~15 min on the last
    GPU class). The queue timer applies only while the job status remains IN_QUEUE; once a worker
    picks the job up (status leaves IN_QUEUE), it resets and ``setup_grace_s`` governs cold start.

    ``deadline_s`` (optional) is a hard client-side cap on total polling time.

    Returns a ``PollResult``:

    - success with the decoded metrics when the status is in ``TERMINAL_OK``;
    - ``"job_failed"`` or ``"job_preempted"`` on a terminal failure. ``"job_preempted"`` is used
      for ``PLATFORM_TERMINATIONS`` or when the worker's heartbeat flags ``retriable``;
    - ``"no_capacity"`` when the job is stuck IN_QUEUE or its worker is stuck THROTTLED;
    - ``"stalled"`` when there is no progress, a worker is stuck UNHEALTHY, or ``deadline_s``
      is exceeded;
    - ``"poll_error"`` when ``PollErrorTracker.record`` says to stop. The tracker is reset after
      every successful status read.
    """

    say = make_say(log)
    poll_errors = PollErrorTracker(say, interval_s)

    start = time.time()  # reference point for the optional client-side deadline_s
    last_status = None  # last job status announced via say()
    last_hb_key = None  # identity of the last heartbeat surfaced (opaque key from surface_heartbeat)
    last_progress = time.time()  # last status change or new heartbeat; drives stall detection
    seen_heartbeat = False  # set once a heartbeat outside _SETUP_HEARTBEAT_STAGES arrives
    last_health_probe = 0.0  # endpoint_health is probed at most once per 90s while queued
    unhealthy_since: float | None = None  # first time the worker was seen stuck UNHEALTHY
    throttled_since: float | None = None  # first time the worker was seen stuck THROTTLED
    queued_since: float | None = None  # first time the job was seen IN_QUEUE with no worker yet
    while True:
        if deadline_s is not None and time.time() - start > deadline_s:
            return PollResult(False, failure="stalled", detail="client-side deadline exceeded")
        try:
            st = runpod_api.job_status(handle.endpoint_id, handle.job_id)
            poll_errors.reset()
        except runpod_api.RunpodApiError as e:
            if poll_errors.record(e):
                return PollResult(False, failure="poll_error", detail=str(e))
            # NOTE(review): this `continue` skips the time.sleep(interval_s) at the bottom of the
            # loop; presumably PollErrorTracker.record (constructed with interval_s) does the
            # backoff itself — confirm.
            continue
        status = st.get("status")
        if status != last_status:
            say(f"job {handle.job_id}: {status}")
            last_status = status
            # Any status transition counts as progress for stall detection.
            last_progress = time.time()
        if status in TERMINAL_OK:
            try:
                return PollResult(True, metrics=decode_output(st.get("output")))
            except RuntimeError as e:
                # COMPLETED but the output decodes as an error (a handler exception). Consult the
                # worker flag too: an infra failure can surface here and must still retry.
                last_hb_key, _ = surface_forced_heartbeat(heartbeat_reader, last_hb_key, say)
                retriable = worker_flagged_retriable(heartbeat_reader)
                detail = _append_failure_artifacts(str(e), failure_detail_reader)
                return PollResult(
                    False,
                    failure="job_preempted" if retriable else "job_failed",
                    detail=detail,
                )
        if status in TERMINAL_FAIL:
            detail = str(st.get("error") or "")[:1500]
            out = st.get("output")
            if isinstance(out, dict) and out.get("stdout"):
                # Worker stdout tail is the only place the REAL root cause lives for
                # crashes inside subprocesses (e.g. vLLM EngineCore deaths).
                detail += "\n--- worker stdout tail ---\n" + str(out["stdout"])[-2000:]
            elif not detail:
                # No error text: fall back to (a prefix of) whatever output the job had.
                detail = str(out)[:1500]
            # Structural classification only ([{status}] prefix is for human-readable logs).
            # A platform termination (CANCELLED/TIMED_OUT) is already retryable — skip the worker
            # heartbeat read entirely (no worker error there, and it may not even exist yet).
            if status in PLATFORM_TERMINATIONS:
                return PollResult(False, failure="job_preempted", detail=f"[{status}] {detail}")
            # A worker FAILED: consult the structured worker flag (one forced heartbeat read).
            last_hb_key, _ = surface_forced_heartbeat(heartbeat_reader, last_hb_key, say)
            retriable = worker_flagged_retriable(heartbeat_reader)
            detail = _append_failure_artifacts(detail, failure_detail_reader)
            return PollResult(
                False,
                failure="job_preempted" if retriable else "job_failed",
                detail=f"[{status}] {detail}",
            )
        # Capacity backstop: bound how long the job may sit IN_QUEUE (no worker has accepted it).
        # The throttled/unhealthy fast-fails below only arm when endpoint_health succeeds AND RunPod
        # reports a THROTTLED/UNHEALTHY worker; a queue with zero workers, one stuck INITIALIZING, or
        # a health probe that keeps erroring (its block is wrapped in `except: pass`) bypasses them and
        # would otherwise wait the full setup_grace_s (~50 min). Keyed off the authoritative job status
        # so it holds even when the health probe is blind: once IN_QUEUE exceeds queue_grace_s, return a
        # retryable stall so the runner's gpu-walk re-provisions on the next-best (in-capacity) class.
        now = time.time()
        if status == "IN_QUEUE":
            if queued_since is None:
                queued_since = now
            elif now - queued_since > queue_grace_s:
                return PollResult(
                    False,
                    failure="no_capacity",
                    detail=f"never scheduled: job stuck IN_QUEUE for {int(now - queued_since)}s "
                    "(no RunPod capacity for the pinned GPU class); retrying on the next-best GPU",
                )
        else:
            # Left the queue (picked up or terminal): restart the queue timer from scratch.
            queued_since = None
        # While queued, surface worker availability (throttled hosts are the common
        # cause of silent multi-minute waits — make them visible in the run log).
        if status == "IN_QUEUE" and now - last_health_probe > 90:
            last_health_probe = now
            try:
                h = runpod_api.endpoint_health(handle.endpoint_id)
                workers = h.get("workers") or {}
                usable = workers.get("running") or workers.get("ready") or workers.get("idle")
                recovering = workers.get("initializing")
                if (
                    any(workers.get(k) for k in ("throttled", "unhealthy", "initializing"))
                    or not usable
                ):
                    say(f"queued; workers: {workers}")
                # Fail fast on a worker stuck UNHEALTHY: a dead worker / failed image pull won't
                # self-recover, so don't burn the full setup_grace_s (~50 min) waiting on it — once
                # it has stayed unhealthy with nothing usable or (re)initializing for
                # unhealthy_grace_s, return a (retryable) stall so the runner re-provisions a FRESH
                # endpoint (fresh image pull, likely a different host). Observed: a mutable image
                # tag republished mid-pull corrupts the worker -> unhealthy, and a fresh pull fixes it.
                if workers.get("unhealthy") and not usable and not recovering:
                    if unhealthy_since is None:
                        unhealthy_since = time.time()
                    elif time.time() - unhealthy_since > unhealthy_grace_s:
                        return PollResult(
                            False,
                            failure="stalled",
                            detail=f"worker stuck unhealthy for "
                            f"{int(time.time() - unhealthy_since)}s while IN_QUEUE (likely a failed "
                            f"image pull); retrying on a fresh endpoint",
                        )
                else:
                    unhealthy_since = None  # recovered / usable worker appeared
                # Fail fast on a worker stuck THROTTLED: RunPod has no capacity for the pinned GPU
                # class/pool and a throttled worker won't self-recover, so don't burn the full
                # setup_grace_s (~50 min) waiting on it. Once it has stayed throttled with nothing
                # usable or (re)initializing for throttled_grace_s, return a (retryable) stall so
                # the runner's gpu-walk re-provisions on the NEXT-BEST GPU class — the cheapest fit
                # often has no capacity while the next-best (a few cents/hr more) does.
                if workers.get("throttled") and not usable and not recovering:
                    if throttled_since is None:
                        throttled_since = time.time()
                    elif time.time() - throttled_since > throttled_grace_s:
                        return PollResult(
                            False,
                            failure="no_capacity",
                            detail=f"never scheduled: worker stuck THROTTLED for "
                            f"{int(time.time() - throttled_since)}s while IN_QUEUE (no RunPod "
                            f"capacity for the pinned GPU class); retrying on the next-best GPU",
                        )
                else:
                    throttled_since = None  # capacity appeared / usable worker
            except Exception:
                # Health surfacing is diagnostic only; a probe failure must not stop polling.
                pass
        # heartbeat progress surfacing + stall detection
        new_key, stage = surface_heartbeat(heartbeat_reader, last_hb_key, say)
        if new_key != last_hb_key:
            last_hb_key = new_key
            last_progress = time.time()
            # Only a training-phase heartbeat means cold-start setup is done and we
            # can switch to the tight window; setup heartbeats keep the grace budget.
            if stage not in _SETUP_HEARTBEAT_STAGES:
                seen_heartbeat = True
        # Cold start (before any training-phase heartbeat) gets the larger setup_grace_s,
        # but only when a heartbeat_reader lets us tell setup from training; without one we
        # can't, so stay on stall_after_s (no regression).
        in_setup = heartbeat_reader is not None and not seen_heartbeat
        stall_limit = setup_grace_s if in_setup else stall_after_s
        if time.time() - last_progress > stall_limit:
            phase = "setup (pre-training)" if in_setup else "training"
            return PollResult(
                False,
                failure="stalled",
                detail=f"no worker progress for {int(time.time() - last_progress)}s "
                f"during {phase} (job status {status}, limit {int(stall_limit)}s)",
            )
        time.sleep(interval_s)
|
|
774
|
+
|
|
775
|
+
|
|
776
|
+
def submit_run(
    spec,
    seed: int,
    log=None,
    on_handle=None,
    attempt: int = 0,
    runtime_secrets: dict[str, str] | None = None,
    on_last_gpu: bool = False,
) -> PollResult:
    """Deploy a training endpoint, submit one job to it, and poll that job to completion.

    This is the durable counterpart of ``submit_train``. As soon as the job is queued,
    ``on_handle`` (when given) receives ``handle.to_dict()`` so the runner can persist
    ``{endpoint_id, job_id}`` and reattach from another process.

    A non-zero ``attempt`` gives the endpoint an ``r<attempt>``-suffixed name. ``attempt`` is
    also exported to the worker as ``ATTEMPT``. ``on_last_gpu`` tells the no-capacity backstops
    that no further GPU attempt follows this one (candidate list exhausted OR retry budget
    exhausted). With no next-best class to walk to, they wait longer before giving up
    (see ``stall_kwargs``).

    If submitting the job raises, the freshly deployed endpoint is deleted (best effort) and
    the exception propagates.
    """
    from flash.envs.registry import worker_pip_for_env
    from flash.providers.runpod.train import _run_suffix, build_worker_env, chalk_extra_pip

    timeout_s = max(60, int(spec.gpu.max_wall_seconds))

    # Retries get their own endpoint name: reusing a name lets the SDK/platform pin the job
    # back onto the same (possibly throttled/sick) host.
    run_suffix = _run_suffix(spec.run_id)
    endpoint_suffix = f"{run_suffix}r{attempt}" if attempt else run_suffix

    # Worker pip deps are resolved BEFORE provisioning, so deterministic dependency problems
    # surface before the endpoint exists. _train_body pip-installs extra_pip on every job
    # (the baked-image path skips resolve_worker_deps), so the chalk spec is appended here
    # to reach default runs.
    env_pip = list(spec.environment.pip) or worker_pip_for_env(spec.environment.id)
    extra_pip = env_pip + chalk_extra_pip(spec)

    env_vars = build_worker_env(spec, seed, runtime_secrets=runtime_secrets)
    env_vars["ATTEMPT"] = str(int(attempt))

    endpoint_id, endpoint_name = deploy_train_endpoint(
        spec.gpu.type,
        execution_timeout_ms=timeout_s * 1000,
        name_suffix=endpoint_suffix,
        disk_gb=spec.gpu.disk_gb,
        spec=spec,
    )
    train_payload = {
        "hf_repo": spec.train.hf_repo,
        "job_spec_json": spec.to_json(),
        "phase": spec.phase,
        "seed": int(seed),
        "env": env_vars,
        "extra_pip": extra_pip,
    }
    try:
        job_id = runpod_api.submit_job(endpoint_id, build_function_input(train_payload))
    except Exception:
        # No run handle exists yet, and an rN-suffixed retry name can't be rebuilt from the
        # run id later, so delete the endpoint now rather than leak it against the quota.
        with contextlib.suppress(Exception):
            runpod_api.delete_endpoint(endpoint_id)
        raise

    job_handle = JobHandle(endpoint_id, endpoint_name, job_id)
    if log is not None:
        print(
            f"submitted job: endpoint={endpoint_name} ({endpoint_id}) job={job_id} "
            f"attempt={attempt} gpu={spec.gpu.type} phase={spec.phase} seed={seed}",
            file=log,
            flush=True,
        )
    if on_handle is not None:
        on_handle(job_handle.to_dict())

    hf_repo = spec.train.hf_repo
    artifact_prefix = f"{spec.phase}/{spec.run_id}/seed{seed}"
    heartbeat_reader = None
    failure_detail_reader = None
    if hf_repo:
        heartbeat_reader = make_hf_heartbeat_reader(hf_repo, artifact_prefix)
        failure_detail_reader = make_hf_failure_detail_reader(
            hf_repo, artifact_prefix, spec.phase
        )

    return poll_job(
        job_handle,
        log=log,
        heartbeat_reader=heartbeat_reader,
        failure_detail_reader=failure_detail_reader,
        **stall_kwargs(on_last_gpu=on_last_gpu),
    )
|
|
862
|
+
|
|
863
|
+
|
|
864
|
+
def make_hf_text_reader(hf_repo: str, path_in_repo: str, min_interval_s: float = 45.0):
|
|
865
|
+
"""Rate-limited reader for one HF artifact's text content (None until it exists).
|
|
866
|
+
|
|
867
|
+
Generic helper for HF-backed worker artifacts and heartbeats. ``read(force=False)``
|
|
868
|
+
re-downloads at most once per
|
|
869
|
+
``min_interval_s`` (``force=True`` bypasses the gate); it never raises — any HF error
|
|
870
|
+
(artifact absent, network) returns None.
|
|
871
|
+
"""
|
|
872
|
+
state = {"last": 0.0}
|
|
873
|
+
|
|
874
|
+
def read(force: bool = False) -> str | None:
|
|
875
|
+
if not hf_repo:
|
|
876
|
+
return None
|
|
877
|
+
if not force and time.time() - state["last"] < min_interval_s:
|
|
878
|
+
return None
|
|
879
|
+
state["last"] = time.time()
|
|
880
|
+
try:
|
|
881
|
+
from huggingface_hub import hf_hub_download
|
|
882
|
+
|
|
883
|
+
p = hf_hub_download(
|
|
884
|
+
hf_repo,
|
|
885
|
+
path_in_repo,
|
|
886
|
+
repo_type="dataset",
|
|
887
|
+
token=os.environ.get("HF_TOKEN"),
|
|
888
|
+
force_download=True,
|
|
889
|
+
)
|
|
890
|
+
with open(p) as f:
|
|
891
|
+
return f.read()
|
|
892
|
+
except Exception:
|
|
893
|
+
return None
|
|
894
|
+
|
|
895
|
+
return read
|
|
896
|
+
|
|
897
|
+
|
|
898
|
+
def make_hf_heartbeat_reader(hf_repo: str, prefix: str, min_interval_s: float = 30.0):
    """Reader for the worker's heartbeat.json on HF (rate-limited, never raises).

    Thin JSON-parsing wrapper over :func:`make_hf_text_reader` bound to
    ``{prefix}/heartbeat.json``. ``read(force=False)`` returns the parsed heartbeat ``dict``.
    It returns None when the text reader yields nothing (throttled, absent, HF error), when the
    content is not valid JSON, or when the JSON document is not an object.
    """
    text_reader = make_hf_text_reader(hf_repo, f"{prefix}/heartbeat.json", min_interval_s)

    def read(force: bool = False) -> dict | None:
        raw = text_reader(force=force)
        if raw is None:
            return None
        try:
            parsed = json.loads(raw)
        except (ValueError, TypeError):
            # Unparseable heartbeat content is treated as "no heartbeat".
            return None
        # Honor the dict | None contract: a non-object JSON document (list, number, string)
        # must not reach callers that expect a mapping.
        return parsed if isinstance(parsed, dict) else None

    return read
|
|
915
|
+
|
|
916
|
+
|
|
917
|
+
def make_hf_failure_detail_reader(
    hf_repo: str,
    prefix: str,
    phase: str,
    min_interval_s: float = 45.0,
):
    """Reader for worker-uploaded RunPod failure artifacts on HF.

    The RunPod queue often reports only the outer handler error (for example, "produced no
    /tmp/metrics.json"), while the worker writes the actual traceback and console tail to HF.
    The returned ``read(force=False)`` fetches ``{prefix}/error_{phase}.txt`` and then
    ``{prefix}/console_{phase}.txt``. Each one found contributes a ``--- <filename> ---``
    header followed by its last 4000 characters. The sections are joined with newlines, and
    None is returned when neither file yields text. ``poll_job`` calls it with ``force=True``
    after a terminal worker failure.
    """
    artifact_names = (f"error_{phase}.txt", f"console_{phase}.txt")
    readers = [
        (artifact, make_hf_text_reader(hf_repo, f"{prefix}/{artifact}", min_interval_s))
        for artifact in artifact_names
    ]

    def read(force: bool = False) -> str | None:
        sections = []
        for artifact, reader in readers:
            text = reader(force=force)
            if text:
                sections.append(f"--- {artifact} ---\n{text[-4000:]}")
        # Each section is non-empty, so an empty join means nothing was found.
        return "\n".join(sections) or None

    return read
|
|
945
|
+
|
|
946
|
+
|
|
947
|
+
def worker_flagged_retriable(heartbeat_reader) -> bool:
    """Report whether the worker's latest heartbeat carries a truthy ``retriable`` flag.

    This is the structured worker<->poller contract that replaces parsing failure details.
    A worker that hit a RetriableInfraError stamps ``retriable`` so the job is retried on a
    fresh worker. The heartbeat is read with ``force=True`` to bypass the reader's rate limit.
    A missing reader or a non-dict heartbeat counts as not retriable.
    """
    heartbeat = None if heartbeat_reader is None else heartbeat_reader(force=True)
    return isinstance(heartbeat, dict) and bool(heartbeat.get("retriable"))
|