freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""List a run's deployable per-step RL checkpoints from its HF artifact repo.
|
|
2
|
+
|
|
3
|
+
The GPU worker publishes each trainer save's LoRA adapter to a stable, NON-pruned path
|
|
4
|
+
(``<adapter_prefix>/checkpoints/step-<N>/adapter``; see
|
|
5
|
+
``flash.engine.worker.publish_deployable_checkpoint``). This module is the control-plane
|
|
6
|
+
reader: it enumerates those snapshots so ``flash checkpoints`` can list them and
|
|
7
|
+
``flash deploy --step N`` can serve a specific one — including for a run that was cancelled or
|
|
8
|
+
failed mid-RL and so never sealed a final adapter. HF (not the backend DB) is the source of
|
|
9
|
+
truth for what's deployable; backend persistence is a mirror (see
|
|
10
|
+
``flash.server.checkpoints``)."""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import os
|
|
15
|
+
import re
|
|
16
|
+
|
|
17
|
+
from flash.runner import adapter_prefix
|
|
18
|
+
from flash.spec import JobSpec
|
|
19
|
+
|
|
20
|
+
# The PEFT weights file a step must carry (alongside adapter_config.json) to be servable.
|
|
21
|
+
_ADAPTER_WEIGHT_FILES = frozenset({"adapter_model.safetensors", "adapter_model.bin"})
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def checkpoint_adapter_prefix(spec: JobSpec, step: int, seed: int | None = None) -> str:
    """Return the ``adapter_prefix`` under which checkpoint ``step`` is served.

    ``deploy_adapter`` appends ``/adapter`` to any prefix it is handed. This therefore returns
    the per-step root, ``<run prefix>/checkpoints/step-<N>``, which is the same layout the
    worker uploads to. That lets the ordinary deploy path serve a checkpoint with no
    special-casing.
    """
    run_root = adapter_prefix(spec, seed)
    return f"{run_root}/checkpoints/step-{step}"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _adapter_file_re(base: str) -> re.Pattern[str]:
|
|
34
|
+
"""Matches ``<base>/checkpoints/step-<N>/adapter/<filename>`` and captures (step, filename)."""
|
|
35
|
+
return re.compile(re.escape(base) + r"/checkpoints/step-(\d+)/adapter/([^/]+)$")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def list_checkpoints(spec: JobSpec, seed: int | None = None) -> list[dict]:
    """Return the deployable per-step adapter snapshots of ``spec``, sorted by ascending step.

    A step qualifies only when its adapter folder contains ``adapter_config.json`` AND at least
    one weights file. This keeps ``/deploy --step`` from targeting a half-uploaded, unloadable
    step.

    Each entry is a dict with keys ``step``, ``adapter_prefix``, ``subfolder``, ``repo_id`` and
    ``repo_type``:
        - ``adapter_prefix`` is what to pass ``deploy_adapter`` to serve exactly that step.
        - ``subfolder`` is the adapter folder's full path inside the repo.

    Returns ``[]`` when the run has no HF repo or nothing has been published yet (older runs,
    or no save so far). Listing failures are printed and also yield ``[]``.
    """
    repo_id = spec.train.hf_repo
    if not repo_id:
        return []
    file_re = _adapter_file_re(adapter_prefix(spec, seed))
    try:
        from huggingface_hub import HfApi

        api = HfApi(token=os.environ.get("HF_TOKEN"))
        repo_files = api.list_repo_files(repo_id, repo_type="dataset")
    except Exception as exc:  # best-effort listing: never propagate into a run or route
        print(f"[ckpt] list warn for {spec.run_id}: {exc}")
        return []

    # step number -> filenames found directly under that step's adapter/ folder
    files_per_step: dict[int, set[str]] = {}
    for repo_path in repo_files:
        hit = file_re.search(repo_path)
        if hit is None:
            continue
        files_per_step.setdefault(int(hit.group(1)), set()).add(hit.group(2))

    # PEFT needs the config plus at least one weights file to load an adapter.
    servable_steps = [
        step_no
        for step_no in sorted(files_per_step)
        if "adapter_config.json" in files_per_step[step_no]
        and not files_per_step[step_no].isdisjoint(_ADAPTER_WEIGHT_FILES)
    ]

    checkpoints: list[dict] = []
    for step_no in servable_steps:
        step_prefix = checkpoint_adapter_prefix(spec, step_no, seed)
        checkpoints.append(
            {
                "step": step_no,
                "adapter_prefix": step_prefix,
                "subfolder": f"{step_prefix}/adapter",
                "repo_id": repo_id,
                "repo_type": "dataset",
            }
        )
    return checkpoints
|
flash/runner/deploy.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
1
|
+
"""Deploy / cancel / recover state transitions for a run.
|
|
2
|
+
|
|
3
|
+
Store helpers and the lifecycle functions (``_run_seed_loop`` / ``_gc_run_endpoints``) are
|
|
4
|
+
pulled in via FUNCTION-LOCAL lazy ``from flash.runner import ...`` imports — never at module
|
|
5
|
+
level — for the same two reasons as ``lifecycle.py``: avoid a partially-initialized-package
|
|
6
|
+
import cycle, and keep the test monkeypatches (e.g. ``flash.runner._gc_run_endpoints``)
|
|
7
|
+
reachable through the package global rather than a statically-bound copy.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import contextlib
|
|
13
|
+
import time
|
|
14
|
+
from typing import TYPE_CHECKING
|
|
15
|
+
|
|
16
|
+
from flash.spec import JobSpec
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
# RunStatus lives in flash.runner.__init__ and is only referenced here as a return
|
|
20
|
+
# annotation (stringized by `from __future__ import annotations`), so a TYPE_CHECKING
|
|
21
|
+
# import keeps it resolvable for tooling without a runtime import cycle.
|
|
22
|
+
from flash.runner import RunStatus
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def cancel_run(run_id: str) -> RunStatus:
    """Cancel a run: stop its remote worker/endpoint, then mark it ``cancelled``.

    Teardown uses the run's persisted job handle: ``provider.cancel`` then ``provider.destroy``,
    dispatched through ``get_provider``. It is followed by ``_gc_run_endpoints``. Because of
    this, a cancel issued from a fresh process (e.g. ``flash cancel``) still stops the GPU
    worker instead of leaving it running until the wall cap. A run that is currently
    ``deployed`` also has its serving adapter undeployed first.

    Best-effort: an error in any teardown step is logged to stderr and the run is STILL moved
    to ``cancelled``. Already-terminal runs are returned unchanged.

    Args:
        run_id: id of the run to cancel.

    Returns:
        The run's status after the cancel attempt.
    """
    import sys

    from flash.runner import (
        TERMINAL_STATES,
        _gc_run_endpoints,
        _update,
        get_status,
        mark_deployment_undeployed,
    )

    status = get_status(run_id)
    if status.state in TERMINAL_STATES:
        return status
    # Whether the run was a live `deployed` serving run at cancel entry. This scopes the final
    # `cancelled` transition's terminal override below:
    #   - Only a `deployed` run can have a concurrent undeploy (`mark_undeployed`) race this
    #     teardown and write a non-completion terminal `done`.
    #   - A non-deployed run (running/provisioning/queued) has an in-flight TRAINING thread
    #     whose only terminal `done` is a GENUINE completion, which cancel must never clobber.
    entered_deployed = status.state == "deployed"
    spec = JobSpec.from_dict(status.spec)
    remote = status.remote or {}
    # A deployed run also owns a serving registration with the freesolo serving app. The
    # training-endpoint GC below does not touch it, so deregister it here too; otherwise a
    # cancelled run could leave a deployment registered as active.
    if entered_deployed:
        try:
            from flash.serve.deploy import undeploy_adapter

            undeploy_adapter(run_id)
            # Mark the deployment inactive so /v1/deployments and /chat stop treating the
            # cancelled run as active. Delete is idempotent: an already-absent adapter still
            # means the local deployment record can be cleared.
            if status.deployment:
                # Use the lock-guarded helper, which flips ONLY the deployment field and never
                # re-asserts the (possibly stale) `deployed` state. A concurrent
                # mark_undeployed may already have moved the run to terminal `done`, and
                # _update's CAS would reject any write off a terminal state. That would
                # silently leave the deployment advertised as `ready`.
                mark_deployment_undeployed(run_id)
        except Exception as exc:
            # Best-effort serving teardown: never block the cancel below (the run still flips
            # to cancelled and the training endpoint is GC'd). Log it rather than drop it.
            print(f"cancel: {run_id} serving teardown warn: {exc}", file=sys.stderr)
    # Durable path first: stop the exact remote worker via the handle's provider. This works
    # from any process, and endpoint/instance teardown is shared with the GC. It is dispatched
    # generically through the registry, never via a hardcoded per-provider branch.
    if remote:
        try:
            from flash.providers import get_provider
            from flash.providers.base import JobHandle

            handle = JobHandle.from_dict(remote)
            provider = get_provider(handle.provider)
            provider.cancel(handle)
            # Belt-and-suspenders destroy after cancel; the endpoint GC follows.
            provider.destroy(handle)
        except Exception as exc:
            # Best-effort remote stop; _gc_run_endpoints below still tears the endpoint down.
            print(f"cancel: {run_id} remote stop warn: {exc}", file=sys.stderr)
    try:
        _gc_run_endpoints(spec)
    except Exception as exc:
        # A GC failure must not abort the cancel before the state flip below. Otherwise the
        # user's cancel would raise and leave the run in its pre-cancel state.
        print(f"cancel: {run_id} endpoint GC warn: {exc}", file=sys.stderr)
    # Final transition to `cancelled`. The run was NON-terminal at entry, but teardown takes
    # time and a terminal state can race in mid-teardown. Two cases:
    #   - A concurrent mark_undeployed() (an external `DELETE /v1/runs/{id}/deploy`) flipped a
    #     `deployed` run to `done`. That only restored the pre-deploy completion marker while
    #     retiring serving; the user asked to cancel, so it is OVERRIDDEN to `cancelled`.
    #   - A genuine training-COMPLETION `done` from the run's own training thread persisted
    #     real metrics, cost and artifacts. Cancel must NEVER clobber that.
    # Only a run that was `deployed` at entry can hit the first case, and only a non-deployed
    # run can hit the second. So the override is scoped to `entered_deployed`; elsewhere
    # _update's CAS protects a real finish.
    _update(run_id, "cancelled", allow_from_terminal=entered_deployed)
    # A run cancelled mid-RL keeps whatever per-step adapters the worker already streamed to
    # HF. Mirror them to the backend store now so the cancelled run is immediately listable
    # and deployable (`flash checkpoints` / `flash deploy --step N`). Best-effort: never let
    # checkpoint bookkeeping fail a cancel.
    with contextlib.suppress(Exception):
        from flash.server.checkpoints import register_checkpoints_best_effort

        register_checkpoints_best_effort(get_status(run_id))
    return get_status(run_id)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def attach_run(run_id: str, log_stream=None) -> RunStatus:
    """Re-attach to a run's remote job from ANY process (after a client crash/restart).

    Uses the persisted job handle (``status.remote``) to resume polling through the handle's
    provider. The outcome is handled as follows:

    * Job succeeded: this seed's metrics are persisted exactly as the original client would
      have. Then either the remaining seeds of a multi-seed run are resumed, or the run is
      marked ``done``.
    * Job ended not-ok: the in-flight seed is resumed from its last HF checkpoint via
      ``_run_seed_loop``.
    * A user ``cancelled`` state is never overwritten.
    * Any other exception after polling marks the run ``failed``.

    The run's endpoints are GC'd in a ``finally`` once polling has returned.

    Args:
        run_id: id of the run to re-attach to.
        log_stream: where progress lines are printed; defaults to ``sys.stderr``.

    Returns:
        The run's status after recovery (unchanged if it was already terminal).

    Raises:
        ValueError: the run has no persisted handle (it failed or was cancelled before a
            worker was provisioned).
    """
    import sys

    from flash.runner import (
        TERMINAL_STATES,
        _gc_run_endpoints,
        _persist_metrics,
        _run_seed_loop,
        _RunCancelled,
        _update,
        artifacts_dir,
        get_status,
    )

    status = get_status(run_id)
    if status.state in TERMINAL_STATES:
        return status
    if not status.remote:
        raise ValueError(f"run {run_id} has no persisted job handle; cannot reattach")

    spec = JobSpec.from_dict(status.spec)
    # Copy so the pops below don't mutate the persisted status object.
    remote = dict(status.remote)
    # NOTE(review): the pop default `spec.train.seeds[0]` is evaluated eagerly, so an empty
    # seeds list raises IndexError here even when the handle carries a "seed".
    seed = int(remote.pop("seed", spec.train.seeds[0]))
    # The class the run actually provisioned (a policy retry may have walked past the
    # provisional spec.gpu.type). The in-process success path stamps this into metrics. On
    # recovery the worker output carries no such field, so recover it from the handle to cost
    # the right card.
    allocated_gpu = remote.pop("allocated_gpu", None)
    log = log_stream or sys.stderr
    # Dispatch the poll generically via the handle's provider (the provider owns its
    # heartbeat reader + poll loop); the orchestrator stays provider-agnostic.
    from flash.providers import get_provider
    from flash.providers.base import JobHandle

    handle = JobHandle.from_dict(remote)
    print(f"attaching to {run_id}: provider={handle.provider} {handle.data}", file=log)
    # NOTE(review): poll() runs OUTSIDE the try below. If it raises, the exception propagates
    # to the caller without the finally-GC and without a `failed` write. Confirm intended.
    res = get_provider(handle.provider).poll(handle, spec, seed, log=log)
    try:
        # A best-effort cancel deletes the job/instance, which the poller reports as a
        # failure (or a late worker may still succeed). Either way, re-read the state first so
        # a recovery thread can't overwrite the user's terminal `cancelled`.
        if get_status(run_id).state == "cancelled":
            return get_status(run_id)
        if not res.ok:
            # Job ended not-ok, usually because it was abandoned during the redeploy. Resume
            # the in-flight seed from its last HF checkpoint instead of failing. The seed loop
            # (unchanged) still terminates a genuinely broken run when it re-fails.
            try:
                seed_index = list(spec.train.seeds).index(seed)
            except ValueError:
                # Handle's seed isn't in the spec: restart from the first seed.
                seed_index = 0
            print(
                f"attach: {run_id} seed {seed} ended ({res.failure}); resuming from checkpoint",
                file=log,
            )
            # GC the dead endpoint, then clear the stale handle and record the seed, so a
            # second restart mid-allocation resumes the right one.
            with contextlib.suppress(Exception):
                _gc_run_endpoints(spec)
            # Bail if the run was raced to terminal during the long poll above: _update's CAS
            # returns False, and resuming would submit paid work for a dead run.
            # NOTE(review): prior_cost below is the pre-attach cost_usd; the failed seed's
            # own spend is not added here. Confirm that's intended.
            if not _update(run_id, "running", remote=None, resume_seed_index=seed_index):
                print(f"attach: {run_id} went terminal during recovery; not resuming", file=log)
                return get_status(run_id)
            _run_seed_loop(
                spec, log, start_index=seed_index, prior_cost=float(status.cost_usd or 0.0)
            )
            return get_status(run_id)
        # Carry the provisioned class into metrics so _persist_metrics costs the card the run
        # actually used (the in-process path stamps this; recovery must restore it).
        if allocated_gpu and isinstance(res.metrics, dict):
            res.metrics.setdefault("allocated_gpu", allocated_gpu)
        # Earlier seeds of a multi-seed run already persisted their cost into
        # status.cost_usd. Add this seed's so recovery doesn't underreport spend.
        total = float(status.cost_usd or 0.0) + _persist_metrics(spec, seed, res.metrics)
        # A cancel can land while this thread persists the recovered seed's metrics (after
        # the late-cancel check above). Re-read before the post-seed writes so the "running"
        # update and the terminal "done" below can't resurrect a user-cancelled run (mirrors
        # the fresh seed loop). _RunCancelled is caught below, leaving the cancellation intact.
        if get_status(run_id).state == "cancelled":
            raise _RunCancelled(f"run {run_id} was cancelled")
        # The remote handle only identifies the seed that was in flight. For a multi-seed
        # run, resume the remaining seeds instead of terminally completing the whole run after
        # just this one.
        try:
            resumed_index = list(spec.train.seeds).index(seed) + 1
        except ValueError:
            # Unknown seed: treat it as the last one, i.e. no more seeds to run.
            resumed_index = len(spec.train.seeds)
        more_seeds = resumed_index < len(spec.train.seeds)
        # Clear the now-stale completed handle before resuming. In the allocation/provisioning
        # gap before the next seed's on_handle() persists a fresh handle, a server restart must
        # not reattach recovery to this finished job; that would double-count its cost and
        # replay the wrong seed. Record the next seed index so a restart in that gap resumes
        # the remaining seeds rather than failing the run. (The last seed keeps its handle for
        # post-run observability, mirroring the fresh-submit seed loop.)
        applied = _update(
            run_id,
            "running",
            cost_usd=total,
            artifacts_dir=artifacts_dir(spec),
            **({"remote": None, "resume_seed_index": resumed_index} if more_seeds else {}),
        )
        # Same TOCTOU guard as the not-ok recovery path. A concurrent thread can flip this run
        # terminal (e.g. failed/done from another recovery) between the cancel re-check above
        # and here. The sticky CAS rejects the `running` write (applied is False), so don't
        # resume the remaining seeds and submit paid GPU work for an already-terminal run.
        # (The single/last-seed arm writes the terminal `done`; the CAS protects a racing
        # terminal there too, so no extra guard is needed.)
        if more_seeds:
            if not applied:
                print(
                    f"attach: {run_id} went terminal during recovery; "
                    "not resuming the remaining seeds",
                    file=log,
                )
                return get_status(run_id)
            _run_seed_loop(spec, log, start_index=resumed_index, prior_cost=total)
        else:
            _update(run_id, "done", cost_usd=total, artifacts_dir=artifacts_dir(spec))
    except _RunCancelled:
        # Intentional: cancel_run already wrote the terminal `cancelled` state; leave it.
        pass
    except Exception as exc:
        # Any other failure after polling fails the run, unless the user cancelled meanwhile.
        if get_status(run_id).state != "cancelled":
            _update(run_id, "failed", error=str(exc))
    finally:
        # Runs on every exit from the try, including the early returns above.
        _gc_run_endpoints(spec)
    return get_status(run_id)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def resume_run(run_id: str, log_stream=None) -> RunStatus:
    """Continue a multi-seed run from its recorded ``resume_seed_index`` after a restart.

    ``_run_seed_loop`` clears a finished seed's handle and records ``resume_seed_index``
    between seeds. If the control plane restarts inside that handle-less window, the run must
    pick up from that index rather than fail and discard the seeds it already finished.

    There is no live job to poll here (unlike ``attach_run``): the previous process already
    tore the seed's endpoint down. A fresh seed loop is started from the recorded index. The
    flash package was uploaded to HF on the original submit, so the worker can still fetch it
    without a re-upload.

    Args:
        run_id: id of the run to resume.
        log_stream: where progress lines are printed; defaults to ``sys.stderr``.

    Returns:
        The run's status afterwards (unchanged if it was already terminal).

    Raises:
        ValueError: the run has no ``resume_seed_index`` recorded.
    """
    import sys

    from flash.runner import (
        TERMINAL_STATES,
        _gc_run_endpoints,
        _run_seed_loop,
        _RunCancelled,
        _update,
        get_status,
    )

    current = get_status(run_id)
    if current.state in TERMINAL_STATES:
        return current
    start = current.resume_seed_index
    if start is None:
        raise ValueError(f"run {run_id} has no resume_seed_index; cannot resume")
    spec = JobSpec.from_dict(current.spec)
    out = log_stream or sys.stderr
    print(f"resuming {run_id}: remaining seeds from index {start}", file=out)
    try:
        spent_so_far = float(current.cost_usd or 0.0)
        _run_seed_loop(spec, out, start_index=start, prior_cost=spent_so_far)
    except _RunCancelled:
        # cancel_run has already persisted the terminal `cancelled` state.
        pass
    except Exception as exc:
        if get_status(run_id).state != "cancelled":
            _update(run_id, "failed", error=str(exc))
    finally:
        # Same as _run_job: GC any endpoint a transient destroy left behind so a billable
        # RunPod endpoint is not leaked.
        _gc_run_endpoints(spec)
    return get_status(run_id)
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def mark_deployed(run_id: str, deployment: dict, expect_state: str | None = None) -> RunStatus:
    """Atomically attach ``deployment`` to a run and move it to ``deployed``.

    The write is refused, and the current status returned unsaved, in two cases:
    * the run is in one of ``_UNDEPLOYABLE_STATES``;
    * ``expect_state`` is given and no longer matches the run's state.

    Args:
        run_id: id of the run being deployed.
        deployment: the deployment record to store on the run.
        expect_state: optional compare-and-set guard on the run's current state.

    Returns:
        The run's status after the call.
    """
    from flash.runner import _STATUS_LOCK, _UNDEPLOYABLE_STATES, _save_status, get_status

    # Guard rationale:
    # * Terminal-respecting (same idea as _update). A /cancel landing mid-deploy writes
    #   `cancelled`, and this must NOT overwrite it with `deployed`, which would resurrect
    #   the run as an active deployment.
    # * `done` is still deployable: deploying a finished run is the common case. So only
    #   the non-`done` terminal states (_UNDEPLOYABLE_STATES) block here.
    # * expect_state is a compare-and-set. The deploy flow passes the state it expects the
    #   run to still be in: the pre-deploy snapshot, or "deployed" after the provisional
    #   mark. If an undeploy raced finalization, deleting the endpoint and writing `done`
    #   with deployment.state="undeployed" mid-warmup, the state no longer matches and we
    #   refuse to re-advertise the just-deleted endpoint.
    with _STATUS_LOCK:
        current = get_status(run_id)
        blocked = current.state in _UNDEPLOYABLE_STATES
        stale = expect_state is not None and current.state != expect_state
        if blocked or stale:
            return current
        # Freeze the training-teardown time before the deploy bumps updated_at.
        # * New terminal runs already stamp finished_at on their first terminal transition.
        # * A LEGACY run that went `done` before that field existed has finished_at=None,
        #   while its updated_at still holds the real teardown time.
        # Capture it only on the `done` -> `deployed` transition: on the CAS re-mark of an
        # already-`deployed` run, updated_at is the DEPLOY time. Also skip runs that were
        # already reconciled, because record_realized_cost moved updated_at to the reconcile
        # time. Both guards avoid stamping past the real teardown (which would over-bill).
        # A reconciled run is never re-billed, so leaving finished_at unset there is harmless.
        legacy_done = (
            current.state == "done"
            and current.finished_at is None
            and not current.reconciled_at
        )
        if legacy_done:
            current.finished_at = current.updated_at
        current.deployment = deployment
        current.state = "deployed"
        current.updated_at = time.time()
        _save_status(current)
    return current
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def attach_checkpoint_deployment(run_id: str, deployment: dict) -> RunStatus:
    """Record a serving deployment on a run WITHOUT touching its training state.

    This is for serving an intermediate checkpoint of a run that never reached ``done``,
    e.g. one cancelled or failed mid-RL. The checkpoint adapter exists on HF, so it can be
    served, but the run's terminal outcome (``cancelled``/``failed``) must survive.
    Switching the state to ``deployed`` would erase that outcome, and a later undeploy would
    then wrongly land the run on ``done``.

    The deployment lives in the ``deployment`` field just like a normal deploy, so
    ``/v1/deployments`` lists it and undeploy clears it. The update runs under
    ``_STATUS_LOCK`` so it serializes with a racing deploy/undeploy of the same run.

    Args:
        run_id: id of the run the checkpoint belongs to.
        deployment: the deployment record to store.

    Returns:
        The updated run status.
    """
    from flash.runner import _STATUS_LOCK, _save_status, get_status

    with _STATUS_LOCK:
        record = get_status(run_id)
        # Only the deployment and its timestamp change; record.state is left as-is on purpose.
        record.deployment = deployment
        record.updated_at = time.time()
        _save_status(record)
    return record
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def mark_undeployed(run_id: str) -> RunStatus:
    """Record an explicit undeploy: retire the deployment, and send a live `deployed` run back to `done`.

    Lock-guarded so it serializes with a racing deploy finalization. The raw read +
    _save_status the endpoint used to do could interleave with mark_deployed and be clobbered.
    Under the same lock, mark_deployed's expect_state CAS sees the `done`/undeployed write and
    won't re-advertise the deleted endpoint.

    Only a run currently in `deployed` changes state (to `done`):
    * Terminal runs (cancelled/failed/done) keep their state.
    * Other non-terminal runs (e.g. a still-training run carrying a checkpoint deployment from
      attach_checkpoint_deployment) also keep their state. Flipping such a run to terminal
      `done` would both misreport it and make _update's CAS reject its genuine completion
      later.

    Args:
        run_id: id of the run being undeployed.

    Returns:
        The updated run status.
    """
    from flash.runner import _STATUS_LOCK, _save_status, get_status

    with _STATUS_LOCK:
        status = get_status(run_id)
        if status.deployment:
            status.deployment = {**status.deployment, "state": "undeployed"}
        # Only a live serving run returns to `done`; every other state is preserved.
        if status.state == "deployed":
            status.state = "done"
        status.updated_at = time.time()
        _save_status(status)
        return status
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def mark_deployment_undeployed(run_id: str) -> RunStatus:
    """Set the run's deployment record to ``undeployed`` while leaving ``state`` as-is.

    ``cancel_run`` uses this to retire a deployed run's serving record. ``mark_undeployed``,
    by contrast, is a state transition (a live `deployed` run returns to `done`); this helper
    never reads or writes the run state.

    That distinction matters during a cancel race. A concurrent ``mark_undeployed`` may
    already have moved the run to terminal `done`. ``_update``'s compare-and-set refuses every
    transition off a terminal state, including re-asserting `deployed` just to carry the
    deployment field, which would leave the deployment advertised as `ready`. Writing the field
    directly under ``_STATUS_LOCK`` avoids that CAS, so the deployment reliably ends up
    `undeployed`. The trailing ``cancelled`` transition is left to ``_update``.

    Args:
        run_id: id of the run whose deployment is retired.

    Returns:
        The updated run status.
    """
    from flash.runner import _STATUS_LOCK, _save_status, get_status

    with _STATUS_LOCK:
        snapshot = get_status(run_id)
        existing = snapshot.deployment
        if existing:
            snapshot.deployment = {**existing, "state": "undeployed"}
        snapshot.updated_at = time.time()
        _save_status(snapshot)
    return snapshot
|