freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,820 @@
|
|
|
1
|
+
"""RunPod Flash endpoint lifecycle: provision/cache/teardown + the worker handler.
|
|
2
|
+
|
|
3
|
+
The live ("ad-hoc") endpoint deploy/cache (``get_train_endpoint``), the worker-side
|
|
4
|
+
training handler it registers (``_train_body``), the per-run endpoint naming/suffix,
|
|
5
|
+
the runpod_flash SDK state-isolation + backoff patches, and cross-process teardown
|
|
6
|
+
(``terminate_endpoint``). Imports the dependency stack + worker env builder from
|
|
7
|
+
``.deps`` (the leaf).
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import asyncio
|
|
13
|
+
import contextlib
|
|
14
|
+
import os
|
|
15
|
+
import threading
|
|
16
|
+
import time
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
from flash.providers.base import canonical_gpu, gpu_short
|
|
20
|
+
from flash.providers.runpod.gpus import flash_gpu
|
|
21
|
+
from flash.providers.runpod.train.deps import (
|
|
22
|
+
DEFAULT_EXECUTION_TIMEOUT_MS,
|
|
23
|
+
WORKER_SYSTEM_DEPS,
|
|
24
|
+
logger,
|
|
25
|
+
resolve_worker_deps,
|
|
26
|
+
worker_image_for_gpu,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
# Serialize every runpod_flash SDK async section (deploy AND undeploy) process-wide.
# Each training run lives on its own control-plane thread and drives the SDK through its own
# asyncio.run() loop, but the SDK shares an asyncio singleton whose Lock binds to the first
# loop that touches it — a second thread's loop then fails with "Lock ... is bound to a
# different event loop". Deploys/teardowns are rare relative to training, so holding this
# one lock around them costs next to nothing.
FLASH_SDK_LOCK = threading.Lock()

# RunPod account quota: max 30 workers across all endpoints. This process caps itself at 28
# in-flight endpoints, leaving 2 slots of headroom for endpoints created by other tools / the
# CLI. Every endpoint created here claims one slot; terminate_endpoint gives it back once the
# remote endpoint is provably torn down. A run that finds the quota full QUEUES (blocks)
# instead of failing with RunPod's "Max workers across all endpoints must not exceed your
# workers quota (30)" error.
#
# Slot bookkeeping lives in one of two stores:
#   * shared (cross-process) — when an operator internal key is configured, a claim is an
#     advisory-locked atomic op in Postgres via the freesolo backend, so several control-plane
#     replicas together can never exceed the cap, and a startup reconcile recovers the true
#     in-use count after a crash;
#   * local — otherwise (local/dev single process), an in-process semaphore.
# A release is routed back to whichever store the slot came from (recorded in ``_ACQUIRED``),
# so a slot claimed against the shared store is always released there.
RUNPOD_ENDPOINT_SLOT_CAP = 28

# Seconds to wait before re-checking a full shared quota (queue, don't crash).
_SLOT_QUEUE_WAIT_S = 10.0
# Consecutive shared-store errors tolerated before degrading to the local semaphore, so a
# single process stays capped instead of deadlocking on a persistently unreachable backend.
_SLOT_STORE_MAX_ERRORS = 6

# The local store (no internal key, or shared store persistently unreachable). It resets on
# restart, which is correct: a fresh process holds no slots.
_LOCAL_SLOTS = threading.Semaphore(RUNPOD_ENDPOINT_SLOT_CAP)

# name -> "shared" | "local": how THIS process acquired each slot it currently holds, guarded
# by _ACQUIRED_LOCK. Release pops the entry, which makes the common release path idempotent:
# terminate_endpoint may run more than once for a name (e.g. a retry after a failed undeploy)
# and only the call that still finds the name releases.
_ACQUIRED: dict[str, str] = {}
_ACQUIRED_LOCK = threading.Lock()

# Endpoint cache used by the live deploy path (get_train_endpoint, not visible here);
# presumably keyed by endpoint name — TODO confirm against get_train_endpoint.
_ENDPOINT_CACHE: dict[str, Any] = {}
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _acquire_local_slot(name: str) -> None:
    """Take one slot from the process-local semaphore for ``name``, blocking while it is full.

    This is the store used when no internal key is configured, and the fallback once the
    shared store is unreachable. The slot is recorded as ``"local"`` in ``_ACQUIRED`` so that
    the matching release returns it to the semaphore.
    """
    fast_path = _LOCAL_SLOTS.acquire(blocking=False)
    if not fast_path:
        # Every slot is taken: announce the wait once, then block until a release frees one.
        cap = RUNPOD_ENDPOINT_SLOT_CAP
        logger.info("Quota full (%d/%d slots) — waiting for a free slot...", cap, cap)
        _LOCAL_SLOTS.acquire()
    with _ACQUIRED_LOCK:
        _ACQUIRED[name] = "local"
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _acquire_endpoint_slot(name: str) -> None:
    """Claim a RunPod quota slot for ``name``, blocking (queueing) until one is available.

    Calling it again for a name this process already holds a slot for is a no-op. With an
    internal key configured the claim goes to the shared cross-process store, polled every
    ``_SLOT_QUEUE_WAIT_S`` seconds while the quota is full. Without a key it uses the
    in-process semaphore. After ``_SLOT_STORE_MAX_ERRORS`` consecutive store errors it
    degrades to the local semaphore, so this process stays capped rather than deadlocking or
    oversubscribing.
    """
    with _ACQUIRED_LOCK:
        already_held = name in _ACQUIRED
    if already_held:
        return  # e.g. _ENDPOINT_CACHE re-entry for the same endpoint name

    from flash.providers.runpod import slots

    if slots.internal_key() is None:
        _acquire_local_slot(name)
        return

    failures = 0
    announced_wait = False
    while True:
        try:
            claimed, in_use = slots.claim(
                name, cap=RUNPOD_ENDPOINT_SLOT_CAP, claimed_by=slots.claimed_by_ident()
            )
        except slots.SlotStoreError as exc:
            failures += 1
            logger.warning(
                "slot-store claim failed for %s (%s) [%d/%d]",
                name,
                exc,
                failures,
                _SLOT_STORE_MAX_ERRORS,
            )
            if failures < _SLOT_STORE_MAX_ERRORS:
                time.sleep(_SLOT_QUEUE_WAIT_S)
                continue
            # Error budget exhausted: stay capped locally instead of spinning forever.
            logger.error(
                "slot store unreachable; falling back to the in-process cap for %s", name
            )
            _acquire_local_slot(name)
            return

        failures = 0  # any successful round-trip resets the error budget

        if not claimed:
            # Quota full: log once on entering the queue, then keep polling.
            if not announced_wait:
                logger.info(
                    "RunPod quota full (%d/%d) — queueing for a free slot...",
                    in_use,
                    RUNPOD_ENDPOINT_SLOT_CAP,
                )
                announced_wait = True
            time.sleep(_SLOT_QUEUE_WAIT_S)
            continue

        if announced_wait:
            logger.info(
                "RunPod endpoint slot acquired for %s after queueing (%d/%d in use)",
                name,
                in_use,
                RUNPOD_ENDPOINT_SLOT_CAP,
            )
        with _ACQUIRED_LOCK:
            _ACQUIRED[name] = "shared"
        return
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _release_endpoint_slot(name: str) -> bool:
    """Give back the quota slot held for ``name`` to the store it was claimed from.

    ``_ACQUIRED`` says how this process claimed the slot (``"local"`` or ``"shared"``). The
    entry is popped here, so only the first release of a name this process claimed takes
    effect.

    When a shared store is configured, a name with no ``_ACQUIRED`` entry is still released
    against the shared store. This covers a teardown that runs on a different control-plane
    replica from the one that claimed the slot (e.g. a ``flash cancel`` routed elsewhere).
    Skipping that release would leak the shared row until the next reconcile and throttle
    the queue even though the endpoint is gone. Callers invoke this only after the remote
    endpoint is provably gone, and the server-side release is idempotent, so this
    best-effort cross-replica release is safe.

    Returns:
        ``True`` if a slot was released, or might have been on the shared path. For a
        cross-replica release it returns the store's own answer, and ``False`` if that
        release errored. Returns ``False`` when nothing is tracked locally and no shared
        store is configured.
    """
    with _ACQUIRED_LOCK:
        how = _ACQUIRED.pop(name, None)

    if how == "local":
        _LOCAL_SLOTS.release()
        return True

    from flash.providers.runpod import slots

    # how == "shared": this process holds the row. how is None: a cross-replica teardown.
    held_here = how is not None
    if not held_here and slots.internal_key() is None:
        return False  # no local record and no shared store: genuinely nothing to release

    try:
        outcome = slots.release(name)
    except slots.SlotStoreError as exc:
        # A failed release can't leak permanently: the endpoint is gone, so the next startup
        # reconcile reclaims its row against the live endpoint list.
        logger.warning(
            "slot-store release failed for %s (%s); reconcile will reclaim it on restart",
            name,
            exc,
        )
        return held_here
    return True if held_here else outcome
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def reconcile_endpoint_slots() -> None:
    """Resync the shared slot store with the endpoints that actually exist on RunPod.

    This is meant for control-plane startup. It reclaims rows for endpoints that no longer
    exist, so the in-use count recovers after a crash. It does nothing when no internal key
    is configured, because the local semaphore starts empty in a fresh process.

    The function is best-effort in two places:

    * If listing the endpoints fails (any exception), it logs and skips the reconcile. It
      never reconciles against a partial or empty list.
    * If the store reports a ``SlotStoreError`` during reconcile, it logs and returns.
    """
    from flash.providers.runpod import slots

    if slots.internal_key() is None:
        return

    live = []
    try:
        from flash.providers.runpod import api as runpod_api

        # Use the canonical sweep predicate so BOTH registered forms match: the bare
        # ``flash-<gpu>-<run>`` name and RunPod Flash's ``live-flash-...`` name. Matching only
        # the bare prefix would leave every live-provisioned endpoint's row unreclaimed.
        from flash.providers.runpod.jobs import _is_flash_endpoint

        for endpoint in runpod_api.list_endpoints():
            endpoint_name = endpoint.get("name") or ""
            if _is_flash_endpoint(endpoint_name):
                live.append(endpoint_name)
    except Exception as exc:  # listing failed: do NOT reconcile against a partial/empty list
        logger.warning("slot reconcile skipped: could not list RunPod endpoints (%s)", exc)
        return

    try:
        result = slots.reconcile(live)
    except slots.SlotStoreError as exc:
        logger.warning("slot reconcile failed (%s)", exc)
    else:
        logger.info(
            "RunPod slot reconcile: %s in use, %s reclaimed",
            result.get("inUse"),
            result.get("reclaimed"),
        )
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _train_body(input_data: dict) -> dict:
    """Runs ON the RunPod GPU worker: fetch code, train (phase), return metrics.

    NOTE: Flash serializes this handler and runs it standalone, so every name it uses
    must be imported INSIDE the function body (module-level imports are not in scope).
    The same source is also extracted into the baked image's rp_handler, where ``flash``
    itself is not importable, so keep this function fully self-contained.

    Two modes, selected by ``input_data.get("mode")``:

    * ``"preload"``: a download-only warm of the weight cache. It caches
      ``input_data["models"]`` onto the per-region network volume mounted at
      ``/runpod-volume`` and returns ``{"preloaded", "already_cached", "failed",
      "hf_home"}``, plus ``"error"`` when the HF_HOME/volume setup is unusable.
    * Anything else runs training. It optionally ``pip install``s
      ``input_data["extra_pip"]``, pulls the run's ``code/`` tree from the
      ``input_data["hf_repo"]`` dataset repo, and runs the ``flash.engine.worker``
      subprocess for ``input_data["phase"]``. It returns the parsed ``/tmp/metrics.json``
      that the worker writes.

    Raises:
        subprocess.CalledProcessError: an ``extra_pip`` install failed.
        RuntimeError: the train phase finished without producing ``/tmp/metrics.json``.
    """
    import contextlib
    import json
    import os
    import subprocess
    import sys

    from huggingface_hub import snapshot_download

    # Weight-cache PRELOAD mode: download-only, no code repo / no training subprocess. Runs on a
    # worker whose `flash-weights` volume for ONE datacenter is mounted at /runpod-volume; with
    # HF_HOME pointed there (passed in env), each snapshot_download warms that region's cache so the
    # first real run in that region is a cache hit. Returns the repos it cached. Kept here (in the
    # baked handler) so preload reuses the existing image/handler — no separate worker image.
    if input_data.get("mode") == "preload":
        overrides = {k: str(v) for k, v in (input_data.get("env") or {}).items()}
        os.environ.update(overrides)
        # NB: HF_HUB_ENABLE_HF_TRANSFER is NOT set here — the worker image already exports it
        # (Dockerfile.worker ENV), same as the training path relies on (see deps.py).
        tok = overrides.get("HF_TOKEN")
        # CRITICAL: huggingface_hub froze HF_HUB_CACHE from HF_HOME at IMPORT time (the
        # `from huggingface_hub import snapshot_download` above, before this branch), so the
        # HF_HOME we just set in os.environ is ignored by snapshot_download. Pass cache_dir
        # EXPLICITLY = <HF_HOME>/hub (HF's own layout) so the download lands on the mounted volume
        # instead of the worker's ephemeral default cache. (The training path is immune: it spawns a
        # subprocess that imports huggingface_hub fresh with HF_HOME already in its env.)
        hf_home = overrides.get("HF_HOME")
        # The whole point of preload is to write onto the per-region network volume mounted at
        # /runpod-volume. If HF_HOME is MISSING or not rooted there, cache_dir would fall back to the
        # worker's EPHEMERAL default cache: snapshot_download would "succeed" and report repos
        # preloaded while persisting NOTHING to the volume — a phantom warm the driver would count as a
        # warmed region. Refuse a misconfigured preload instead (the flash driver always passes a
        # volume-rooted HF_HOME, so this only fires on a handler-shape/env bug).
        # "Rooted" is checked on the NORMALIZED path at a path-component boundary: a bare
        # startswith("/runpod-volume") would also accept siblings like "/runpod-volume2/hf" or
        # escapes like "/runpod-volume/../tmp", both of which live on ephemeral disk.
        volume_root = "/runpod-volume"
        hf_home_norm = os.path.normpath(hf_home) if hf_home else ""
        if not hf_home or (
            hf_home_norm != volume_root and not hf_home_norm.startswith(volume_root + "/")
        ):
            return {
                "preloaded": [], "already_cached": [], "failed": {},
                "error": f"preload requires HF_HOME rooted at /runpod-volume (got HF_HOME={hf_home!r})",
                "hf_home": hf_home,
            }
        # Rooted at the volume but the mount is absent (endpoint deployed without the volume / RunPod
        # didn't mount it) — same phantom-warm risk; fail loudly so the driver records this region failed.
        if not os.path.isdir(volume_root):
            return {
                "preloaded": [], "already_cached": [], "failed": {},
                "error": f"weight-cache volume not mounted at /runpod-volume (HF_HOME={hf_home})",
                "hf_home": hf_home,
            }
        # hf_home is now guaranteed non-empty + volume-rooted (use the validated, normalized form).
        cache_dir = os.path.join(hf_home_norm, "hub")
        # Same exclusions as the worker prefetch (engine/worker.prefetch_model), the image bake, and
        # the instance-provider preload (_instance_bootstrap.run_preload): weights + tokenizer/config
        # only, never the large unused artifacts. Inlined (this handler is baked self-contained, so it
        # can't import a shared constant). Keeps the warmed cache byte-for-byte what workers fetch,
        # so the 100GB sizing holds and the local_files_only `already_cached` probe stays consistent.
        ignore_patterns = ["*.pth", "*.gguf", "original/*", "*.onnx", "*.msgpack", "*.h5"]
        done, already, failed = [], [], {}
        for repo_id in input_data.get("models") or []:
            try:
                # Idempotent: if the volume already has this snapshot, skip the download. The
                # local_files_only probe is also the persistence signal the live test reads —
                # ``already_cached`` proves the weights survived a previous, separate deployment.
                try:
                    snapshot_download(
                        repo_id=repo_id, token=tok, cache_dir=cache_dir,
                        ignore_patterns=ignore_patterns, local_files_only=True,
                    )
                    already.append(repo_id)
                    continue
                except Exception:
                    pass  # not (fully) cached yet: fall through to a real download
                snapshot_download(
                    repo_id=repo_id, token=tok, cache_dir=cache_dir, ignore_patterns=ignore_patterns
                )
                done.append(repo_id)
            except Exception as exc:  # one bad/gated repo must not abort warming the rest
                failed[repo_id] = str(exc)[:300]
        return {
            "preloaded": done,
            "already_cached": already,
            "failed": failed,
            "hf_home": os.environ.get("HF_HOME"),
        }

    # NB: the Hopper fla fast-path setup lives in flash.engine.worker._ensure_fla_fastpath_on_hopper (runs
    # in the worker process AFTER all installs, before any model import) — doing it here would be
    # undone by a later extra_pip / `prime env install`, and depends on a handler redeploy.

    # Extra pip deps for Freesolo environments (installed per-run).
    extra_pip = input_data.get("extra_pip") or []
    if extra_pip:
        # check=True: a deterministic dependency failure should fail fast here,
        # not after model download + worker startup with a less actionable error.
        subprocess.run([sys.executable, "-m", "pip", "install", *extra_pip], check=True)

    # NB: fla is kept on ALL arches. On Hopper (sm90) fla's GDN backward is miscomputed with
    # Triton>=3.4 (#640); the fix is fla's tilelang backend, so flash.engine.worker._ensure_fla_fastpath_on_hopper
    # makes fla+tilelang live at worker startup (instead of dropping fla) for ~4-13x faster + ~2x
    # lighter Hopper GDN training than the pure-PyTorch delta fallback.

    overrides = {k: str(v) for k, v in (input_data.get("env") or {}).items()}
    snapshot_download(
        repo_id=input_data["hf_repo"],
        repo_type="dataset",
        allow_patterns=["code/**"],
        local_dir="/runcode",
        token=overrides.get("HF_TOKEN"),
    )
    code_dir = "/runcode/code"

    env = dict(os.environ)
    env.update(overrides)
    # If the weight-cache volume isn't actually mounted (cold/no-volume run, or the attach degraded
    # to {}), don't leave HF_HOME pointing at a missing /runpod-volume path — fall back to the
    # default ephemeral cache. INLINED (not a flash import): this handler is extracted standalone
    # into the baked image's rp_handler (docker/make_rp_handler.py), where flash is NOT importable,
    # so it must stay self-contained. Mirrors deps.drop_unmounted_cache_env (unit-tested there).
    if not os.path.isdir("/runpod-volume"):
        for _k in [k for k, v in env.items() if str(v).startswith("/runpod-volume")]:
            env.pop(_k, None)
    # Always pass the spec via a file (FLASH_JOB_SPEC_PATH): a large inline spec can blow past the
    # ~128 KiB per-env-string exec limit ("Argument list too long"), and a file is ONE code path for
    # every size (cheap write). load_job_spec_from_env reads it.
    spec_path = "/tmp/job_spec.json"
    with open(spec_path, "w") as sf:
        sf.write(input_data["job_spec_json"])
    env["FLASH_JOB_SPEC_PATH"] = spec_path
    env.pop("FLASH_JOB_SPEC_JSON", None)
    env["PHASE"] = input_data["phase"]
    env["SEED"] = str(input_data["seed"])
    env["PYTHONPATH"] = code_dir + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")

    def _upload_console(mode: str) -> None:
        """Upload the captured console tail for ``mode`` to ``{phase_ns}/{run_id}/seed{n}/
        console_<mode>.txt`` in the run repo. Idempotent and best-effort, so it is safe to call
        from both the subprocess-failure path and the missing-metrics crash path: a worker killed
        without a Python exception (OOM/SIGKILL, segfault, or a silent early exit) writes NO
        ``error_<mode>.txt``, so the captured console is then the only root-cause record — and a
        crash that exits 0 would otherwise skip the upload entirely, leaving the failure opaque."""
        console = f"/tmp/console_{mode}.txt"
        if not os.path.exists(console):
            return
        try:
            from huggingface_hub import HfApi

            spec = json.loads(input_data["job_spec_json"])
            phase_ns = "rl" if spec.get("algorithm") == "grpo" else spec["algorithm"]
            prefix = f"{phase_ns}/{spec['run_id']}/seed{input_data['seed']}"
            # Read only the last 64 KB (seek from the end) — the console can be very large on long
            # runs, so f.read()[-64_000:] would pull the whole file into memory just to slice it.
            tail_bytes = 64_000
            with open(console, "rb") as f:
                f.seek(0, os.SEEK_END)
                f.seek(max(0, f.tell() - tail_bytes))
                tail = f.read().decode("utf-8", "replace")
            # utf-8 + replace so a non-ASCII console tail can't raise UnicodeEncodeError under a
            # minimal-container ASCII locale (LANG=C); the tail itself was decoded utf-8/replace.
            with open(console + ".tail", "w", encoding="utf-8", errors="replace") as f:
                f.write(tail)
            HfApi(token=env.get("HF_TOKEN")).upload_file(
                path_or_fileobj=console + ".tail",
                path_in_repo=f"{prefix}/console_{mode}.txt",
                repo_id=input_data["hf_repo"],
                repo_type="dataset",
            )
        except Exception as up_err:
            print("console upload warn:", up_err)

    def run_mode(mode: str, check: bool) -> int:
        """Run one worker process, tee its console to a file, and upload the tail to HF as
        console_<mode>.txt on failure — the engine-core root cause of crashes like vLLM
        EngineDeadError only ever appears on the subprocess console, never in the Python
        traceback. With FLASH_UPLOAD_CONSOLE=1 (forwarded via build_worker_env) the console
        is also uploaded on SUCCESS, so an operator can verify which optimizations engaged.

        Returns the worker's exit code; raises RuntimeError on a non-zero exit when ``check``."""
        console = f"/tmp/console_{mode}.txt"
        # The pipe AND the capture file are explicitly UTF-8 with replacement. With bare
        # text=True the pipe is decoded strictly with the locale encoding, so one invalid
        # byte on the worker console (native-crash garbage, a non-UTF-8 log line) raised
        # UnicodeDecodeError mid-stream. That aborted this handler, lost the console capture
        # (the one root-cause record) and left the child un-waited on a full pipe.
        with open(console, "w", encoding="utf-8", errors="replace") as cf:
            proc = subprocess.Popen(
                [sys.executable, "-m", "flash.engine.worker"],
                cwd=code_dir,
                env={**env, "RUN_MODE": mode},
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                encoding="utf-8",
                errors="replace",
            )
            for line in proc.stdout:
                print(line, end="")  # keep streaming to the platform console
                cf.write(line)
            proc.wait()
        # Console is uploaded on FAILURE (crash root-cause). FLASH_UPLOAD_CONSOLE=1 also uploads it
        # on SUCCESS so an operator can verify which optimizations engaged — LoRA+/8-bit-AdamW/
        # Liger/PiSSA/rsLoRA/fla/chalk all log their engagement (or fallback) to the console.
        _force_console = env.get("FLASH_UPLOAD_CONSOLE", "").strip().lower() not in (
            "", "0", "false", "no", "off",
        )
        if proc.returncode != 0 or _force_console:
            _upload_console(mode)
        if proc.returncode != 0 and check:
            raise RuntimeError(
                f"worker mode '{mode}' exited {proc.returncode}; see console_{mode}.txt "
                f"and error_{mode}.txt in the HF dataset repo"
            )
        return proc.returncode

    # A warm worker can carry a previous seed's metrics files; a stale metrics.json
    # would let a crashed train phase report the previous run's numbers. Clear before
    # training.
    for stale in ("/tmp/train_meta.json", "/tmp/metrics.json"):
        with contextlib.suppress(FileNotFoundError):
            os.remove(stale)
    # Train. check=False — RL's colocated vLLM can segfault at interpreter exit AFTER
    # the adapter + metrics.json + DONE are saved; don't treat that as a failure.
    run_mode(input_data["phase"], check=False)
    # The train phase writes metrics.json + the DONE sentinel itself (RunPod can also
    # redeliver a completed job, whose worker restores metrics.json from DONE). If it
    # is missing, the train phase crashed before finishing — fail fast with the real
    # cause (full traceback in error_<phase>.txt / console_<phase>.txt in the HF repo).
    if not os.path.exists("/tmp/metrics.json"):
        phase = input_data["phase"]
        # run_mode skips the console upload when the worker exits 0 (and a hard OOM/segfault kill
        # may have raced it), so force it here — otherwise this exact "crashed before finishing"
        # failure is undebuggable: no metrics.json, often no error_<phase>.txt, and the message
        # below points operators at a console_<phase>.txt that was never uploaded.
        _upload_console(phase)
        raise RuntimeError(
            f"train phase '{phase}' produced no /tmp/metrics.json (it crashed before "
            f"finishing); see error_{phase}.txt and console_{phase}.txt in the HF "
            f"dataset repo for the full traceback"
        )
    with open("/tmp/metrics.json") as f:
        return json.load(f)
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def isolate_flash_state(scope: str | None = None) -> None:
    """Give this process a private Flash SDK resource registry.

    By default the SDK persists its whole registry dict to ``./.flash/resources.pkl``. Every
    process in the same CWD shares that file, and the last writer wins. In practice this
    resurrected long-dead endpoints from stale entries on later syncs and let concurrent
    processes clobber each other's bookkeeping.

    This function repoints the SDK's ``FLASH_STATE_DIR`` / ``RESOURCE_STATE_FILE`` at
    ``~/.flash/flash-state/<scope>``. ``scope`` defaults to ``pid<os.getpid()>`` when it is
    falsy. Remote cleanup never relies on this registry (it works by id/name over REST; see
    api.py). Any failure is logged and ignored, so a run is never blocked on this.
    """
    try:
        from pathlib import Path

        import runpod_flash.core.resources.resource_manager as rm

        label = scope if scope else f"pid{os.getpid()}"
        private_dir = Path.home().joinpath(".flash", "flash-state", label)
        private_dir.mkdir(parents=True, exist_ok=True)
        rm.FLASH_STATE_DIR, rm.RESOURCE_STATE_FILE = private_dir, private_dir / "resources.pkl"
    except Exception as exc:  # never block a run on this
        logger.warning("flash state isolation skipped: %s", exc)
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
def _patch_runpod_backoff() -> None:
    """Replace runpod_flash's backoff-delay helper with an overflow-safe version (once).

    The SDK polls a synchronous job with exponential backoff ``base * (2 ** attempt)`` and
    clamps to ``max_seconds`` only AFTER computing it. On a long run ``attempt`` grows
    without bound, so ``2 ** attempt`` turns into a huge int. Multiplying it by the float
    ``base`` then raises ``OverflowError: int too large to convert to float``. This was
    observed about 80 minutes in, killing an otherwise healthy job.

    The replacement caps the exponent at 30 before the power, so the delay still saturates
    at ``max_seconds``. It is installed on the backoff module and on ``serverless``'s own
    imported alias when present. Repeat calls are no-ops, and any failure is logged rather
    than raised, so the patch can never break submission.
    """
    try:
        import math
        import random

        from runpod_flash.core.utils import backoff as _bo

        if getattr(_bo, "_flash_backoff_patched", False):
            return  # already installed in this process

        def _safe_get_backoff_delay(
            attempt,
            base=0.1,
            max_seconds=10.0,
            jitter=0.2,
            strategy=_bo.BackoffStrategy.EXPONENTIAL,
        ):
            # Cap the exponent first: 2**30 already far exceeds any max_seconds / base ratio.
            exponent = min(int(attempt), 30)
            if strategy == _bo.BackoffStrategy.EXPONENTIAL:
                raw = base * (2**exponent)
            elif strategy == _bo.BackoffStrategy.LINEAR:
                raw = base + (attempt * base)
            elif strategy == _bo.BackoffStrategy.LOGARITHMIC:
                raw = base * math.log2(attempt + 2)
            else:
                raise ValueError(f"Unsupported backoff strategy: {strategy}")
            spread = random.uniform(1 - jitter, 1 + jitter)
            return min(raw, max_seconds) * spread

        _bo.get_backoff_delay = _safe_get_backoff_delay
        _bo._flash_backoff_patched = True

        # serverless.py did `from ..utils.backoff import get_backoff_delay`, so its module-level
        # reference must be swapped too. Some SDK versions don't import it there; the primary
        # patch above still applies, so a missing alias is fine to ignore.
        with contextlib.suppress(Exception):
            from runpod_flash.core.resources import serverless as _sl

            _sl.get_backoff_delay = _safe_get_backoff_delay
    except Exception as exc:  # never let the patch break submission
        logger.warning("runpod backoff patch skipped: %s", exc)
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
def min_cuda_for(friendly_gpu: str) -> str:
    """Return the minimum host CUDA (driver) version for ``friendly_gpu`` on the active stack.

    The value is looked up per GPU class via ``min_cuda_modern`` in
    ``flash.providers.base`` (``GPU_INFO``); nothing here overrides it.

    Background for the table's values:
    - Blackwell (sm_120: RTX 5090, RTX Pro 6000) is pinned to >= 13.0. The modern stack's
      pypi wheels (vllm 0.19) ship no Blackwell SASS, so every custom CUDA kernel is
      PTX-JIT'd by the driver. That PTX was built with a newer toolchain than CUDA-12.8-era
      drivers can JIT (observed "the provided PTX was compiled with an unsupported
      toolchain" on driver 570.x), while CUDA-13 drivers handle it.
    - Ampere/Ada/Hopper have SASS in the wheels and run on 12.8.
    """
    from flash.providers.base import min_cuda_modern as _lookup_min_cuda

    required = _lookup_min_cuda(friendly_gpu)
    return required
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
def endpoint_name(friendly_gpu: str, suffix: str | None = None) -> str:
    """Flash endpoint/template name for a GPU class, optionally made unique per run.

    Why the per-run suffix exists:
    - A fixed name (``flash-5090``) collides across back-to-back runs. runpod_flash's
      ``get_or_deploy_resource`` finds the prior run's still-registered resource and tries
      to *update* it.
    - That update fails with ``GraphQL errors: Template name must be unique``; there is no
      endpoint GC/reuse.
    - A per-run ``suffix`` (the run id tail) gives each run its own endpoint, so it deploys
      fresh instead of colliding. RunPod scales each to zero when idle.

    Args:
        friendly_gpu: GPU class name, shortened via ``gpu_short``.
        suffix: Optional per-run discriminator. It is reduced to ASCII letters, digits and
            ``-``, stripped of leading hyphens, truncated to 24 characters, and then stripped
            of any hyphen the cut exposed at the end.

    Returns:
        ``flash-<gpu>`` when ``suffix`` is empty or sanitises to nothing, else
        ``flash-<gpu>-<safe suffix>``.
    """
    base = f"flash-{gpu_short(friendly_gpu)}"
    if not suffix:
        return base
    # str.isalnum() alone also accepts non-ASCII letters/digits; keep names ASCII.
    kept = "".join(
        c for c in str(suffix) if c == "-" or (c.isascii() and c.isalnum())
    )
    # Strip again AFTER truncating: the cut can land right after a hyphen, which would
    # otherwise leave the endpoint name ending in a dangling "-".
    safe = kept.strip("-")[:24].rstrip("-")
    return f"{base}-{safe}" if safe else base
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
def get_train_endpoint(
    friendly_gpu: str,
    execution_timeout_ms: int | None = None,
    name_suffix: str | None = None,
    disk_gb: int | None = None,
    spec=None,
):
    """Build (and cache) the live Flash endpoint handler for a GPU class.

    Steps, in order (the order matters):
    1. Set ``FLASH_IS_LIVE_PROVISIONING=true`` in the process environment before importing
       ``runpod_flash``.
    2. Ensure RunPod auth, patch the SDK backoff, and point the SDK registry at the
       ``name_suffix`` scope via ``isolate_flash_state``.
    3. Compute the endpoint name from ``friendly_gpu`` and ``name_suffix``. If a handler
       for that name is already in ``_ENDPOINT_CACHE``, return it without claiming a
       second quota slot.
    4. Otherwise claim a quota slot and build an ``Endpoint`` with one GPU, 0..1 workers,
       the GPU class's minimum CUDA version, and the execution timeout.
    5. Use a worker image if one is configured (``allow_default=False``). Otherwise fall
       back to boot-time pip/system dependencies.
    6. Merge in weight-cache kwargs for ``spec``, register ``_train_body`` as the handler,
       apply ``disk_gb`` to the endpoint's resource config, and cache the handler under the
       endpoint name.

    Args:
        friendly_gpu: GPU class name, canonicalised with ``canonical_gpu``.
        execution_timeout_ms: Job execution timeout. Falsy values (``None``/``0``) use
            ``DEFAULT_EXECUTION_TIMEOUT_MS``.
        name_suffix: Per-run discriminator for the endpoint name. It is also used verbatim
            as the registry scope for ``isolate_flash_state``.
        disk_gb: Passed to ``apply_disk_gb`` for the resource config. Presumably ``None``
            keeps the default disk size; confirm in jobs.apply_disk_gb.
        spec: Passed to ``weight_cache_endpoint_kwargs``; its type is not visible here.

    Returns:
        The callable returned by registering ``_train_body`` on the ``Endpoint``.

    Raises:
        Whatever the endpoint construction steps raise. The quota slot claimed in step 4 is
        released before re-raising.
    """
    # Live ("ad-hoc") provisioning: provision on call, no separate `flash deploy`.
    # Set before importing runpod_flash, in case the SDK reads it at import time.
    os.environ["FLASH_IS_LIVE_PROVISIONING"] = "true"
    from runpod_flash import Endpoint

    from flash.providers.runpod.auth import ensure_auth

    ensure_auth()
    _patch_runpod_backoff()
    # NOTE(review): name_suffix is passed raw here, but endpoint_name() sanitises it.
    # terminate_endpoint() uses _run_suffix(run_id) for BOTH the scope and the name, so
    # callers presumably pass _run_suffix(run_id) as name_suffix to keep the registry scope
    # identical at teardown. Verify against callers.
    isolate_flash_state(name_suffix)

    friendly = canonical_gpu(friendly_gpu)
    name = endpoint_name(friendly, name_suffix)
    # NOTE(review): this cache check and the slot acquire below are not guarded by a lock
    # visible here. Two threads building the same name concurrently could both miss the
    # cache and each claim a slot; confirm callers serialise this.
    if name in _ENDPOINT_CACHE:
        # Slot was already acquired when the entry was first created; don't re-acquire.
        return _ENDPOINT_CACHE[name]
    # Claim a quota slot before creating the endpoint. This QUEUES (blocks) when the quota is
    # full, making new runs wait here instead of failing with RunPod's worker-quota error. The
    # slot is released by terminate_endpoint once the remote endpoint is provably torn down.
    _acquire_endpoint_slot(name)
    try:
        kwargs = {
            "name": name,
            "gpu": flash_gpu(friendly),
            "gpu_count": 1,
            "min_cuda_version": min_cuda_for(friendly),
            "execution_timeout_ms": execution_timeout_ms or DEFAULT_EXECUTION_TIMEOUT_MS,
            "workers": (0, 1),  # one dedicated worker per run; scale to zero when idle
        }
        # live endpoints keep the boot-install path by default. an operator can still opt into a
        # serverless-compatible image through FLASH_WORKER_IMAGE or the per-sm image template.
        image = worker_image_for_gpu(friendly, allow_default=False)
        if image:
            kwargs["image"] = image
        else:
            # Boot-install path: pip and system dependencies are installed on the worker.
            kwargs["dependencies"] = resolve_worker_deps()
            kwargs["system_dependencies"] = WORKER_SYSTEM_DEPS
        # Parity with the baked-image deploy_train_endpoint path: attach the multi-region weight
        # cache (best-effort {} on no-cache/error). Local import avoids a jobs<->endpoints import
        # cycle (jobs imports this module at load), same as apply_disk_gb below.
        from flash.providers.runpod.jobs import weight_cache_endpoint_kwargs

        kwargs.update(weight_cache_endpoint_kwargs(spec))
        ep = Endpoint(**kwargs)
        handler = ep(_train_body)  # register the queue-based handler; returns the callable
        # The resource config is cached on the Endpoint, so raising the disk on it here
        # carries through to the deploy that the first handler call triggers.
        from flash.providers.runpod.jobs import apply_disk_gb

        cfg = ep._build_resource_config()
        apply_disk_gb(cfg, disk_gb)
        _ENDPOINT_CACHE[name] = handler
        return handler
    except Exception:
        # Endpoint creation failed — release the slot so other runs are not permanently blocked.
        _release_endpoint_slot(name)
        raise
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
def _run_suffix(run_id: str | None) -> str | None:
|
|
635
|
+
"""Short, COLLISION-FREE per-run endpoint suffix.
|
|
636
|
+
|
|
637
|
+
Must be unique per run_id: the endpoint name is ``endpoint_name(friendly, suffix)`` and
|
|
638
|
+
RunPod reuses an endpoint by name -- two runs with the same suffix share one endpoint (and
|
|
639
|
+
its cached image/deps/registry-auth/template), so a later run silently reuses the earlier
|
|
640
|
+
one's config. The old ``run_id.split("-")[-1]`` only worked for hash-tailed default ids; a
|
|
641
|
+
descriptive run_id ending in e.g. the card name (``...-a100``) collided across every run.
|
|
642
|
+
Use a stable short hash of the WHOLE run_id, with a sanitized prefix for readability."""
|
|
643
|
+
if not run_id:
|
|
644
|
+
return None
|
|
645
|
+
import hashlib
|
|
646
|
+
import re
|
|
647
|
+
|
|
648
|
+
h = hashlib.sha1(run_id.encode()).hexdigest()[:8]
|
|
649
|
+
prefix = re.sub(r"[^a-z0-9]", "", run_id.lower())[-12:]
|
|
650
|
+
return f"{prefix}{h}" if prefix else h
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
def stop_endpoint(friendly_gpu: str, name: str | None = None) -> None:
    """Best-effort: scale cached endpoint(s) to zero / drop them.

    With ``name`` only that run's cached endpoint is dropped. Without it, every cached
    endpoint of the GPU class is dropped, so a per-run teardown should pass ``name`` to
    avoid evicting a concurrent run's handler in the same process.

    "Every endpoint of the GPU class" means cache keys equal to ``flash-<gpu>`` or
    starting with ``flash-<gpu>-``, the two shapes ``endpoint_name`` produces. A different
    GPU class whose short name merely begins with the same characters is not evicted.

    For each dropped handler, the underlying endpoint object is taken from the handler's
    ``__self__`` or ``endpoint`` attribute. The first of its ``scale_to_zero`` / ``stop`` /
    ``delete`` methods that exists and runs without raising is called; errors are swallowed.

    NOTE: this only touches THIS process's in-memory cache, so it does nothing in a fresh
    ``flash cancel`` process. Use ``terminate_endpoint`` to actually delete the remote endpoint.
    """
    friendly = canonical_gpu(friendly_gpu)
    prefix = f"flash-{gpu_short(friendly)}"
    if name:
        match = [name] if name in _ENDPOINT_CACHE else []
    else:
        # Require the "-" separator: a bare startswith(prefix) would also match another
        # GPU class whose short name extends this one and evict its handlers.
        match = [
            k for k in _ENDPOINT_CACHE if k == prefix or k.startswith(f"{prefix}-")
        ]
    for key in match:
        handler = _ENDPOINT_CACHE.pop(key, None)
        ep = getattr(handler, "__self__", None) or getattr(handler, "endpoint", None)
        if ep is None:
            continue
        for meth in ("scale_to_zero", "stop", "delete"):
            fn = getattr(ep, meth, None)
            if not callable(fn):
                continue
            try:
                fn()
            except Exception:
                continue  # best-effort: try the next teardown method
            break
|
|
680
|
+
|
|
681
|
+
|
|
682
|
+
def _select_endpoint_resources(resources: dict, target: str) -> list[str]:
    """Return the ids (keys of ``resources``) whose resource ``.name`` contains ``target``.

    The live-provisioned resource is named ``live-<endpoint_name>``, hence a substring test
    rather than equality. ``target`` is the endpoint name (``flash-<gpu>[-<run>]``).

    A missing or ``None`` name is treated as an empty string. Results keep the iteration
    order of ``resources``. An empty ``target`` or a falsy ``resources`` yields ``[]``.
    """
    if not target:
        return []
    return [
        uid
        for uid, res in (resources or {}).items()
        if target in str(getattr(res, "name", "") or "")
    ]
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
def terminate_endpoint(friendly_gpu: str, run_id: str | None = None) -> list[dict]:
    """Reliably tear down the remote Flash endpoint(s) for a run — cross-process.

    Unlike ``stop_endpoint`` (which only touches this process's in-memory cache), this looks
    the endpoint up by name in runpod_flash's *persisted* resource registry and deletes it via
    the RunPod API (``ResourceManager.undeploy_resource`` -> ``delete_endpoint``), which stops
    any running worker. If the registry has no matching entry, it falls back to the RunPod
    REST API, deleting endpoints named exactly like the reconstructed endpoint name.

    With ``run_id`` it targets exactly that run's uniquely-named endpoint. Without it:
    - the registry lookup matches every resource whose name contains the bare
      ``flash-<gpu>``;
    - the REST fallback only deletes an endpoint named exactly ``flash-<gpu>``.

    Args:
        friendly_gpu: GPU class name, canonicalised with ``canonical_gpu``.
        run_id: The run whose endpoint should be torn down, or ``None`` for the bare name.

    Returns:
        Per-resource results: whatever ``undeploy_resource`` returns, plus dicts with
        ``success``/``name``/``message`` keys built here. SDK/auth, registry-lookup,
        undeploy, timeout and REST-delete failures are reported as ``success: False``
        entries rather than raised.

    The endpoint's quota slot is released only when teardown is proven: some undeploy or
    delete succeeded, or the REST lookup positively confirmed no such endpoint exists.
    """
    friendly = canonical_gpu(friendly_gpu)
    suffix = _run_suffix(run_id)
    target = endpoint_name(friendly, suffix)

    def _succeeded(result) -> bool:
        # Undeploy results come from the SDK; tolerate entries without a ``.get``.
        get = getattr(result, "get", None)
        return bool(callable(get) and get("success"))

    # Hold FLASH_SDK_LOCK across the ENTIRE Flash critical section, not just the undeploy.
    # isolate_flash_state() swaps runpod_flash's process-wide registry globals and
    # ResourceManager shares the SDK's asyncio singleton, so a concurrent deploy/undeploy on
    # another thread could swap the registry scope between our lookup and our undeploy and tear
    # down the wrong run's resources. Serialize isolation + lookup + undeploy together.
    with FLASH_SDK_LOCK:
        try:
            from flash.providers.runpod.auth import ensure_auth

            ensure_auth()
            isolate_flash_state(suffix)
            from runpod_flash.core.resources.resource_manager import ResourceManager
        except Exception as exc:  # SDK/auth unavailable
            return [{"success": False, "name": target, "message": f"flash unavailable: {exc}"}]

        try:
            rm = ResourceManager()
            resources = rm.list_all_resources()
            uids = _select_endpoint_resources(resources, target)
        except Exception as exc:
            return [{"success": False, "name": target, "message": f"resource lookup failed: {exc}"}]

        async def _undeploy_all() -> list:
            out = []
            for uid in uids:
                res = resources.get(uid)
                name = getattr(res, "name", None)
                try:
                    out.append(
                        await rm.undeploy_resource(uid, resource_name=name, force_remove=True)
                    )
                except Exception as exc:
                    out.append({"success": False, "name": name, "message": str(exc)})
            return out

        try:
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No running event loop — asyncio.run() works directly.
                results = asyncio.run(_undeploy_all())
            else:
                # Running event loop (e.g. FastAPI lifespan) — asyncio.run() would raise;
                # daemon=True so a hung undeploy cannot prevent process shutdown.
                _out: list = []
                _err: list = []

                def _run_undeploy() -> None:
                    try:
                        _out.append(asyncio.run(_undeploy_all()))
                    except Exception as _e:
                        _err.append(_e)

                _t = threading.Thread(target=_run_undeploy, daemon=True)
                _t.start()
                _t.join(timeout=30)
                if _err:
                    raise _err[0]
                if not _out:
                    raise TimeoutError("undeploy timed out after 30s")
                results = _out[0]
        except Exception as exc:
            results = [{"success": False, "name": target, "message": str(exc)}]

        # Registry-less fallback: isolate_flash_state() keeps the Flash SDK's resource
        # registry per-process under ~/.flash, so a recreated container (or a crash before
        # on_handle() persisted the endpoint id) leaves the live endpoint invisible to the
        # lookup above. Delete it via the RunPod REST API by its reconstructed name so it
        # can't keep a paid worker alive.
        #
        # ``rest_confirmed_absent`` records that the REST lookup actually RAN and found no
        # endpoint of this name, i.e. we positively verified there is nothing to tear down.
        # That is distinct from the API being unreachable, where we cannot tell and must
        # keep holding the slot.
        rest_confirmed_absent = False
        if not uids:
            try:
                from flash.providers.runpod import api as runpod_api

                matches = [
                    e for e in runpod_api.find_endpoints_by_name(target) if e.get("name") == target
                ]
                for ep in matches:
                    if runpod_api.delete_endpoint(ep["id"]):
                        results.append(
                            {"success": True, "name": target, "message": "deleted via REST API"}
                        )
                    else:
                        # Surface the failure: an empty result would look like "nothing to
                        # do" while a live endpoint may still be holding a paid worker.
                        results.append(
                            {
                                "success": False,
                                "name": target,
                                "message": f"REST delete failed for endpoint id {ep['id']}",
                            }
                        )
                # No endpoint of this name exists remotely — positively verified absent.
                rest_confirmed_absent = not matches
            except Exception as exc:
                # REST API unreachable: we cannot prove the endpoint is gone, so do NOT treat this
                # as "absent" — releasing on an unverified absence risks oversubscribing the quota.
                logger.debug("REST endpoint lookup failed for %s: %s", target, exc)

        # Release the quota slot for this run's endpoint. This releases the slot this process
        # acquired or, when a shared store is configured, the row a different replica claimed
        # (a ``flash cancel`` may run on another replica than the one that deployed).
        #
        # Release ONLY when the remote endpoint is provably gone:
        #   (a) at least one undeploy/delete succeeded, or
        #   (b) we positively verified no remote endpoint exists: the registry returned no
        #       uids AND the REST lookup confirmed none (e.g. the endpoint never finished
        #       deploying, so its slot would otherwise leak forever and deadlock the queue).
        # We deliberately do NOT release on an undeploy/delete *failure*. The endpoint may
        # still be live and counting against the RunPod quota, so releasing would
        # oversubscribe it; the slot stays held until a later teardown confirms it is gone.
        if any(_succeeded(r) for r in results) or (not uids and rest_confirmed_absent):
            _release_endpoint_slot(target)

        # also drop the in-process cached handler for THIS run only (a class-wide
        # drop would evict a concurrent run's endpoint on the same GPU class).
        with contextlib.suppress(Exception):
            stop_endpoint(friendly, name=target)
        return results
|