freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
flash/server/app.py
ADDED
|
@@ -0,0 +1,961 @@
|
|
|
1
|
+
"""FastAPI control plane for the managed Flash service.
|
|
2
|
+
|
|
3
|
+
This is the operator-side component. It holds the provider credentials
|
|
4
|
+
(``RUNPOD_API_KEY``, ``HF_TOKEN``, and environment source tokens) and exposes the
|
|
5
|
+
full run lifecycle to clients that authenticate with their freesolo API key
|
|
6
|
+
(verified against the freesolo backend) — clients never see provider credentials.
|
|
7
|
+
|
|
8
|
+
Run state truth stays in the runner's JSON files; SQLite (server/db.py) holds
|
|
9
|
+
keys and run ownership. Runs the server owns are recovered on startup by re-attaching
|
|
10
|
+
to their persisted RunPod job handles.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import contextlib
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
import threading
|
|
20
|
+
import weakref
|
|
21
|
+
|
|
22
|
+
from flash import __version__
|
|
23
|
+
from flash.catalog import public_model_rows
|
|
24
|
+
from flash.client.runtime_secrets import DEFAULT_RUNTIME_SECRET_KEYS
|
|
25
|
+
from flash.runner import (
|
|
26
|
+
adapter_prefix,
|
|
27
|
+
attach_checkpoint_deployment,
|
|
28
|
+
cancel_run,
|
|
29
|
+
get_status,
|
|
30
|
+
mark_deployed,
|
|
31
|
+
mark_undeployed,
|
|
32
|
+
new_run_id,
|
|
33
|
+
runs_file_path,
|
|
34
|
+
submit_job,
|
|
35
|
+
)
|
|
36
|
+
from flash.runner.checkpoints import checkpoint_adapter_prefix, list_checkpoints
|
|
37
|
+
from flash.schema import ConfigError, spec_from_dict
|
|
38
|
+
from flash.serve.deploy import ServingError, deploy_adapter, undeploy_adapter
|
|
39
|
+
from flash.serve.deploy import chat as serve_chat
|
|
40
|
+
from flash.serve.deploy import chat_stream as serve_chat_stream
|
|
41
|
+
from flash.spec import JobSpec
|
|
42
|
+
|
|
43
|
+
from . import auth, db
|
|
44
|
+
|
|
45
|
+
# Runtime secret key names, re-exported from the client package under a private alias.
_RUNTIME_SECRET_KEYS = DEFAULT_RUNTIME_SECRET_KEYS

# States in which a run is still in flight (used by startup recovery below).
_RECOVERABLE = {"queued", "provisioning", "running"}

# States in which a run has produced a downloadable adapter artifact.
_DEPLOYABLE_STATES = {"done", "deployed"}

# Deploying one specific intermediate checkpoint is additionally allowed from a run that
# stopped mid-RL (cancelled/failed): the per-step adapter was already streamed to HF, so it
# can be served even though no final adapter was sealed. `dry_run` is left out on purpose —
# such a run never trained.
_CHECKPOINT_DEPLOYABLE_STATES = {*_DEPLOYABLE_STATES, "cancelled", "failed"}

# Message raised when the optional server dependencies are not installed.
_SERVER_EXTRAS_HINT = "the control plane needs the server extras: pip install 'flash[server]'"

_log = logging.getLogger("flash.server")
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _resolve_deploy_step(run_id: str, spec, raw_step) -> int | None:
|
|
59
|
+
"""Validate an optional deploy ``step`` against the run's published checkpoints.
|
|
60
|
+
|
|
61
|
+
Returns the integer step to deploy, or ``None`` when no step was requested (deploy the
|
|
62
|
+
final adapter). Raises ``HTTPException(400)`` for a malformed step and ``HTTPException(404)``
|
|
63
|
+
— listing the available steps — when the run has no deployable checkpoint at that step."""
|
|
64
|
+
if raw_step is None:
|
|
65
|
+
return None
|
|
66
|
+
from fastapi import HTTPException
|
|
67
|
+
|
|
68
|
+
# Accept only an actual integer step — NOT a bool (True would coerce to step 1) and not a
|
|
69
|
+
# non-integer float/string (40.9 / "40.9" must not silently round to a different checkpoint).
|
|
70
|
+
want: int | None = None
|
|
71
|
+
if isinstance(raw_step, bool):
|
|
72
|
+
want = None
|
|
73
|
+
elif isinstance(raw_step, int):
|
|
74
|
+
want = raw_step
|
|
75
|
+
elif isinstance(raw_step, float):
|
|
76
|
+
want = int(raw_step) if raw_step.is_integer() else None
|
|
77
|
+
elif isinstance(raw_step, str) and raw_step.strip().lstrip("-").isdigit():
|
|
78
|
+
want = int(raw_step.strip())
|
|
79
|
+
if want is None:
|
|
80
|
+
raise HTTPException(status_code=400, detail=f"invalid checkpoint step: {raw_step!r}")
|
|
81
|
+
checkpoints = list_checkpoints(spec)
|
|
82
|
+
if any(c["step"] == want for c in checkpoints):
|
|
83
|
+
return want
|
|
84
|
+
available = ", ".join(str(c["step"]) for c in checkpoints) or "none"
|
|
85
|
+
raise HTTPException(
|
|
86
|
+
status_code=404,
|
|
87
|
+
detail=f"run {run_id} has no deployable checkpoint at step {want} (available: {available})",
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
async def _reconcile_cost_loop() -> None:
    """Background task: run the realized-cost (COGS) reconcile sweep forever, once per hour.

    Each cycle sleeps first, then runs ``reconcile_once`` in a worker thread via
    ``asyncio.to_thread`` so the blocking provider billing calls stay off the event loop.
    A truthy result is logged as the number of runs reconciled. Cancellation is re-raised so
    shutdown can stop the task; any other exception is logged at debug level and the sweep is
    attempted again on the next cycle.
    """
    from flash.server.reconcile import reconcile_once

    sweep_period_s = 3600.0  # fixed COGS reconcile interval (flash is fully managed)
    while True:
        await asyncio.sleep(sweep_period_s)
        # CancelledError is re-raised explicitly and only ordinary Exceptions are swallowed,
        # mirroring the other background loops in this module, so the shutdown path is
        # obvious and a failed sweep leaves a trace in the log.
        try:
            run_count = await asyncio.to_thread(reconcile_once)
            if run_count:
                _log.info("reconciled realized cost for %d run(s)", run_count)
        except asyncio.CancelledError:
            # Shutdown: the lifespan cancelled this task — propagate, never swallow.
            raise
        except Exception:
            _log.debug("realized-cost reconcile sweep failed; retrying next cycle", exc_info=True)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _protected_train_endpoint_names() -> set[str]:
    """Collect the training-endpoint names the idle reaper must never delete.

    For every run in the registry whose status file exists and whose state is not terminal,
    two candidate names are considered: the ``endpoint_name`` stored in the run's persisted
    remote handle, and the name re-derived from the spec's GPU type plus the run id. The
    second covers the window between submit and the handle being persisted. Each non-empty
    candidate is added both bare and with a ``live-`` prefix.

    Returns:
        The set of protected endpoint names (possibly empty).
    """
    from flash.providers.base import canonical_gpu
    from flash.providers.runpod.train import _run_suffix, endpoint_name
    from flash.runner import TERMINAL_STATES

    protected: set[str] = set()

    for record in db.all_runs():
        try:
            run = get_status(record["run_id"])
        except FileNotFoundError:
            # Registry row without a status file: nothing to protect.
            continue
        if run.state in TERMINAL_STATES:
            continue

        candidates = [(run.remote or {}).get("endpoint_name")]
        gpu_type = ((run.spec or {}).get("gpu") or {}).get("type")
        if gpu_type:
            # Best-effort: a GPU type that cannot be canonicalised just yields no derived name.
            with contextlib.suppress(Exception):
                candidates.append(endpoint_name(canonical_gpu(gpu_type), _run_suffix(run.run_id)))

        for candidate in candidates:
            if candidate:
                protected.update((candidate, f"live-{candidate}"))

    return protected
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _reap_idle_endpoints_once(min_idle_s: float) -> int:
    """Run a single idle-endpoint sweep and return whatever the sweep helper reports.

    Args:
        min_idle_s: Minimum idle time, in seconds, forwarded to the sweep helper.

    Returns:
        The result of ``_sweep_idle_flash_endpoints`` (the number of endpoints deleted).
    """
    from flash.providers.runpod.jobs import _sweep_idle_flash_endpoints

    keep = _protected_train_endpoint_names()
    return _sweep_idle_flash_endpoints(keep, min_idle_s=min_idle_s)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
async def _reap_idle_endpoints_loop() -> None:
    """Background task: every 10 minutes, reap idle RunPod training endpoints.

    Each cycle sleeps first, then runs ``_reap_idle_endpoints_once`` in a worker thread
    (the RunPod calls block) with a 15-minute idle threshold. A non-zero deletion count is
    logged at info level. Cancellation is re-raised so shutdown can stop the task; any other
    failure is logged at debug level and retried on the next cycle.
    """
    sweep_period_s = 600.0  # one sweep every 10 min
    idle_threshold_s = 900.0  # reap only after >= 15 min idle (well past any cold start)
    while True:
        await asyncio.sleep(sweep_period_s)
        try:
            reaped = await asyncio.to_thread(_reap_idle_endpoints_once, idle_threshold_s)
            if reaped:
                _log.info("reaped %d idle RunPod endpoint(s) doing nothing", reaped)
        except asyncio.CancelledError:
            # Shutdown: the lifespan cancelled this task — propagate, never swallow.
            raise
        except Exception:
            _log.debug("idle-endpoint reaper sweep failed; retrying next cycle", exc_info=True)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
# States in which a run may still OWN a live, billing training instance; instances belonging
# to runs in these states must be PROTECTED from the orphan sweep.
#
# ``deployed`` is deliberately absent: a run only becomes ``deployed`` after it went ``done``,
# by which point the seed loop's ``finally`` has already torn every training instance down.
# A deployed run therefore owns no training worker, and protecting it would only SHIELD a
# genuinely leaked instance under its prefix — the very thing the sweep exists to reap.
# Terminal states are absent for the same reason.
#
# That is exactly ``_RECOVERABLE`` (a run is recoverable on restart iff it may still have an
# in-flight worker), so this is an ALIAS of it — one source of truth, so the two protection
# sets cannot silently drift apart.
_INSTANCE_OWNING_STATES = _RECOVERABLE
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _active_run_ids() -> set[str]:
    """Return the ids of all runs currently in an instance-owning state.

    Walks the run registry, skips rows whose status file is missing, and keeps every run whose
    state is in ``_INSTANCE_OWNING_STATES``. The result is the protection set for the periodic
    instance orphan sweep, which wants raw run ids (it derives instance-label prefixes from
    them), unlike ``_protected_train_endpoint_names``, which yields RunPod endpoint *names*.

    Why no idle grace is needed (per the run lifecycle described in this module): a run is
    moved into an instance-owning state BEFORE its first instance is launched, and the
    instance is torn down BEFORE the run leaves these states, so a billed instance should
    only exist while its run is in this set. The sweep passes this function itself (as a
    callable) so the set is read AFTER the provider has listed its instances — see
    ``_sweep_orphan_instances_once``. Startup recovery (``recover_runs``) intentionally uses a
    NARROWER set, because it is resubmitting handle-less runs at the same time and must reap
    their stale half-rented instances.
    """

    def _owning_ids():
        for record in db.all_runs():
            try:
                run = get_status(record["run_id"])
            except FileNotFoundError:
                continue
            if run.state in _INSTANCE_OWNING_STATES:
                yield run.run_id

    return set(_owning_ids())
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _sweep_orphan_instances_once() -> int:
    """Run one orphan sweep across every configured provider.

    Targets instance-provider workers (Lambda/Hyperstack VMs) whose run finished or crashed
    without the per-run ``finally`` tearing them down. Every provider returned by
    ``configured_providers`` is asked to sweep; per the original notes RunPod's
    ``sweep_orphans`` is a no-op (its serverless endpoints are handled by the idle reaper).

    ``_active_run_ids`` is passed as a CALLABLE, not a precomputed set, so each provider can
    resolve the live-run protection set AFTER it has listed its instances. That ordering
    closes the launch race: an instance already present in the listing had its run's status
    committed before the instance was launched, so the run is in the set read post-listing
    and cannot be mis-reaped as a phantom orphan.

    Returns:
        The total number of workers torn down, summed over all providers that succeeded.
    """
    from flash.providers import configured_providers

    torn = 0
    for prov in configured_providers():
        try:
            deleted = prov.sweep_orphans(active_labels=_active_run_ids)
        except Exception:
            # One provider's API blip / outage must not skip the others — and must NOT be
            # silent, so a persistent failure (bad creds, signature mismatch) is visible
            # instead of looking like a healthy sweep that reaped nothing.
            _log.warning(
                "instance orphan sweep failed for provider %r; retrying next cycle",
                getattr(prov, "name", prov),
                exc_info=True,
            )
            continue
        # A no-op sweep may return None rather than an empty collection; count that as zero
        # so the TypeError from len(None) — raised outside the try above — cannot abort the
        # sweep of the remaining providers.
        torn += len(deleted or ())
    return torn
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
async def _sweep_orphan_instances_loop() -> None:
    """Background task: every 10 minutes, tear down orphaned Lambda/Hyperstack instances.

    This is the in-lifetime counterpart of the startup sweep done in ``recover_runs``, and
    the instance-provider analogue of ``_reap_idle_endpoints_loop``. Each cycle sleeps first,
    then runs ``_sweep_orphan_instances_once`` in a worker thread (provider calls block). A
    non-zero count is logged at info level. Cancellation is re-raised so shutdown can stop
    the task; any other failure is logged at debug level and retried on the next cycle.
    """
    sweep_period_s = 600.0  # one sweep every 10 min (same cadence as the RunPod idle reaper)
    while True:
        await asyncio.sleep(sweep_period_s)
        try:
            swept = await asyncio.to_thread(_sweep_orphan_instances_once)
            if swept:
                _log.info("swept %d orphaned instance-provider worker(s)", swept)
        except asyncio.CancelledError:
            # Shutdown: the lifespan cancelled this task — propagate, never swallow.
            raise
        except Exception:
            _log.debug("instance orphan sweep failed; retrying next cycle", exc_info=True)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def _instance_providers_configured() -> bool:
    """Report whether an instance-based provider (Lambda / Hyperstack) is available.

    Used to decide whether the periodic instance orphan sweep is worth starting; a plane on
    which neither ``"lambda"`` nor ``"hyperstack"`` is listed by ``available_providers``
    skips it.

    Returns:
        ``True`` as soon as one of the two provider names is found, else ``False``.
    """
    from flash.providers import available_providers

    for provider_name in available_providers():
        if provider_name in ("lambda", "hyperstack"):
            return True
    return False
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
class _RunLock:
    """Context-manager mutex that can be the target of a weak reference.

    The object returned by ``threading.Lock()`` does not support weak references, so it
    cannot be stored in a ``WeakValueDictionary`` directly. This thin wrapper declares a
    ``__weakref__`` slot and delegates acquire/release to an inner lock through ``with``.
    """

    __slots__ = ("__weakref__", "_lock")

    def __init__(self) -> None:
        # The real mutex; the wrapper exists only to make it weak-referenceable.
        self._lock = threading.Lock()

    def __enter__(self) -> "_RunLock":
        """Block until the inner lock is acquired, then return this wrapper."""
        self._lock.acquire()
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Release the inner lock; exceptions from the ``with`` body are not suppressed."""
        self._lock.release()
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
# One lock per run, serializing deploy against undeploy. Registering with the freesolo
# serving app is slow and happens OUTSIDE the status lock, so without this the two operations
# could interleave: a racing undeploy could leave a stale deployment record (registered with
# freesolo but unrecorded here, or the reverse), or one deploy's cleanup of a raced finalize
# could clobber another. Serving is delegated to freesolo (scales to zero per base model), so
# no billable flash-side endpoint is at stake — only the consistency of the deployment record.
#
# The map is a WeakValueDictionary so an entry disappears once no request holds its lock;
# it therefore cannot grow by one entry per distinct run_id over the server's lifetime.
_DEPLOY_LOCKS: weakref.WeakValueDictionary[str, _RunLock] = weakref.WeakValueDictionary()
# Guards lookups/insertions in _DEPLOY_LOCKS.
_DEPLOY_LOCKS_GUARD = threading.Lock()
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def _deploy_lock(run_id: str) -> _RunLock:
    """Return the deploy/undeploy lock for ``run_id``, creating it on first use.

    The registry holds locks weakly, so the caller must keep the returned object referenced
    (normally by using it in a ``with`` block); once it is released and unreferenced the
    registry entry is garbage-collected.
    """
    with _DEPLOY_LOCKS_GUARD:
        existing = _DEPLOY_LOCKS.get(run_id)
        if existing is not None:
            return existing
        created = _RunLock()
        _DEPLOY_LOCKS[run_id] = created
        return created
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def _append_run_log(run_id: str, message: str) -> None:
    """Append one ``[HH:MM:SS] message`` line to the run's ``.log`` file.

    The file is located with ``runs_file_path(run_id, ".log")`` and opened in append mode;
    the timestamp is the local wall-clock time from ``time.strftime``. Intended to surface
    in `flash status --logs`.
    """
    import time

    log_path = runs_file_path(run_id, ".log")
    with open(log_path, "a") as log_file:
        stamp = time.strftime("%H:%M:%S")
        log_file.write(f"[{stamp}] {message}\n")
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def _worker_artifacts(spec) -> dict[str, str]:
    """Fetch the run's worker console output and traceback from its HF artifact repo.

    Downloads ``console_<phase>.txt`` and ``error_<phase>.txt`` from under the run's adapter
    prefix in the dataset repo named by ``spec.train.hf_repo``, authenticating with the
    ``HF_TOKEN`` environment variable (the operator token — per the original notes the repo
    is private, so a user's own token would not work). This lets ``flash status --logs`` show
    the real worker output rather than only the orchestrator lines in the ``.log`` file.

    Best-effort throughout: no repo configured, ``huggingface_hub`` not importable, or a file
    that cannot be downloaded/read simply leaves that entry out.

    Returns:
        Mapping of file name to decoded text for each file that could be fetched.
    """
    train_cfg = getattr(spec, "train", None)
    hf_repo = getattr(train_cfg, "hf_repo", None)
    if not hf_repo:
        return {}
    try:
        from huggingface_hub import hf_hub_download
    except Exception:
        return {}

    remote_dir = adapter_prefix(spec)
    wanted = [f"{kind}_{spec.phase}.txt" for kind in ("console", "error")]
    artifacts: dict[str, str] = {}
    for filename in wanted:
        try:
            local_path = hf_hub_download(
                repo_id=hf_repo,
                repo_type="dataset",
                filename=f"{remote_dir}/{filename}",
                token=os.environ.get("HF_TOKEN"),
                # The worker keeps appending to these files during the run, so a cached copy
                # would go stale — always pull a fresh one.
                force_download=True,
            )
            # Decode leniently: worker stdout may contain non-UTF-8 bytes (tracebacks,
            # progress bars) and one bad byte must not drop the whole log.
            with open(local_path, encoding="utf-8", errors="replace") as handle:
                artifacts[filename] = handle.read()
        except Exception:
            # Not uploaded yet, or never produced for this phase.
            continue
    return artifacts
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def recover_runs() -> None:
    """Bring every in-flight run back under supervision after a control-plane restart.

    For each registry row whose status file exists and whose state is in ``_RECOVERABLE``:

    * a run with a persisted remote handle is re-attached (``attach_run``) in a daemon thread;
    * a run with a ``resume_seed_index`` is resumed (``resume_run``) in a daemon thread;
    * otherwise the run never got a handle: its spec is re-parsed and, if that works, its
      endpoints are garbage-collected and the run is queued for resubmission. If the spec
      cannot be parsed the run is marked ``failed`` and its RunPod endpoint is terminated
      by name, best-effort.

    After the scan every configured provider sweeps its orphans, shielding only the
    re-attached/resumed runs, and only then are the queued specs resubmitted in daemon
    threads — so a half-rented instance from a crashed attempt is reaped without racing the
    fresh allocation.
    """
    from flash.runner import (
        _gc_run_endpoints,
        _run_job_background,
        _update,
        attach_run,
        resume_run,
    )

    def _spawn(target, *args) -> None:
        # Fire-and-forget daemon thread, so recovery never blocks startup.
        threading.Thread(target=target, args=args, daemon=True).start()

    shielded: set[str] = set()  # run ids the orphan sweep must leave alone
    pending: list[JobSpec] = []  # specs to resubmit once the sweep is done

    for record in db.all_runs():
        run_id = record["run_id"]
        try:
            run = get_status(run_id)
        except FileNotFoundError:
            continue
        if run.state not in _RECOVERABLE:
            continue

        if run.remote:
            # Only handle-backed runs are shielded from the sweep; a handle-less run is
            # being resubmitted, so its stale half-rented instance must NOT be protected.
            shielded.add(run.run_id)
            _spawn(attach_run, run_id)
            continue

        if run.resume_seed_index is not None:
            # Restarted between seeds: run the remaining seeds, keeping the finished ones.
            shielded.add(run.run_id)
            _spawn(resume_run, run_id)
            continue

        # No handle yet: the restart landed in the submit -> provisioning window, so no
        # worker exists for this run.
        try:
            spec = JobSpec.from_dict(run.spec)
        except Exception as exc:
            # An unparseable spec can never be resubmitted: mark the run terminally failed
            # so it is operator-visible and is not reconsidered on every restart.
            _log.warning(
                "marking run %s failed: persisted spec could not be parsed",
                run.run_id,
                exc_info=True,
            )
            detail = f"unrecoverable: persisted spec is malformed: {exc}"
            with contextlib.suppress(Exception):
                _update(run.run_id, "failed", error=detail)
            with contextlib.suppress(Exception):
                _append_run_log(run.run_id, detail)
            # The aborted attempt may still have registered its uniquely-named RunPod
            # endpoint. `_gc_run_endpoints` needs a parsed JobSpec, which is unavailable
            # here, so terminate by GPU type + run id read from the RAW persisted spec.
            # Suppressed so cleanup can never abort recovery.
            with contextlib.suppress(Exception):
                gpu_type = (run.spec.get("gpu") or {}).get("type")
                if gpu_type:
                    from flash.providers.runpod.train import terminate_endpoint

                    terminate_endpoint(gpu_type, run.run_id)
            continue

        # Parseable spec: drop any half-made endpoint, then resubmit from scratch later.
        with contextlib.suppress(Exception):
            _gc_run_endpoints(spec)
        pending.append(spec)

    # Reap orphaned per-run provider resources; each provider sweeps its own.
    from flash.providers import configured_providers

    for prov in configured_providers():
        with contextlib.suppress(Exception):
            prov.sweep_orphans(active_labels=shielded)

    for spec in pending:
        _log.info("resubmitting run %s after control-plane restart", spec.run_id)
        with contextlib.suppress(Exception):
            _append_run_log(
                spec.run_id, "control plane restarted before provisioning; resubmitting"
            )
        _spawn(_run_job_background, spec)
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def create_app():
|
|
459
|
+
try:
|
|
460
|
+
from fastapi import Depends, FastAPI, Header, HTTPException
|
|
461
|
+
from fastapi.responses import StreamingResponse
|
|
462
|
+
except ImportError as exc:
|
|
463
|
+
raise RuntimeError(_SERVER_EXTRAS_HINT) from exc
|
|
464
|
+
from contextlib import asynccontextmanager
|
|
465
|
+
|
|
466
|
+
@asynccontextmanager
async def lifespan(app):
    """Run startup checks, start the periodic maintenance loops, stop them on shutdown.

    Startup, before the app serves any request: operator-credential preflight,
    ``recover_runs()``, then a best-effort endpoint-slot reconciliation. Each
    periodic loop is started only when its gating condition holds. On shutdown
    (or if startup fails after some loops were started) every started loop is
    cancelled and awaited.
    """
    from flash.providers.preflight import check_run_preflight
    from flash.server.reconcile import reconcile_enabled

    check_run_preflight()  # operator credentials: fail fast, before serving anyone
    recover_runs()
    # Reconcile the shared RunPod endpoint-slot quota against the live endpoint list so a
    # crash can't leak slots permanently (no-op without an internal key). Best-effort.
    with contextlib.suppress(Exception):
        from flash.providers.runpod.train.endpoints import reconcile_endpoint_slots

        reconcile_endpoint_slots()

    tasks: list[asyncio.Task] = []
    try:
        # Tasks are created inside the try so that a failure while starting a later
        # loop still cancels the loops that were already started.
        #
        # Periodic realized-cost reconciliation (estimator accuracy), only when the
        # operator internal key is configured.
        if reconcile_enabled():
            tasks.append(asyncio.create_task(_reconcile_cost_loop()))
        # Periodic idle-endpoint reaper: proactively delete RunPod training endpoints doing
        # nothing (orphans from finished/crashed runs) so workers don't linger holding quota.
        # Only when this plane manages RunPod (its API key is configured).
        if os.environ.get("RUNPOD_API_KEY"):
            tasks.append(asyncio.create_task(_reap_idle_endpoints_loop()))
        # Periodic instance orphan sweep: proactively tear down Lambda/Hyperstack VMs left
        # billing by finished/crashed runs (the in-lifetime counterpart of their startup
        # sweep_orphans). Only when an instance provider is configured — RunPod-only planes
        # have nothing standing to reap.
        if _instance_providers_configured():
            tasks.append(asyncio.create_task(_sweep_orphan_instances_loop()))
        yield
    finally:
        # Cancel everything first, then wait. Awaiting tasks one at a time would re-raise
        # the exception of a loop that had already crashed and skip cancelling the rest.
        for task in tasks:
            task.cancel()
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        for outcome in outcomes:
            # CancelledError is a BaseException, so only real failures are logged here.
            if isinstance(outcome, Exception):
                _log.warning("background maintenance task exited with an error: %r", outcome)
|
|
506
|
+
|
|
507
|
+
app = FastAPI(title="Flash Control Plane", version=__version__, lifespan=lifespan)
|
|
508
|
+
|
|
509
|
+
def require_key(authorization: str | None = Header(default=None)) -> dict:
    """FastAPI dependency: resolve the ``Authorization`` header to a key record.

    Returns whatever ``auth.authenticate`` yields for the header value and raises
    a 401 ``HTTPException`` when that is ``None``.
    """
    authenticated = auth.authenticate(authorization)
    if authenticated is not None:
        return authenticated
    raise HTTPException(
        status_code=401,
        detail="invalid or missing API key; log in with `flash login` using your "
        "freesolo API key",
    )
|
|
518
|
+
|
|
519
|
+
def owned_run(run_id: str, key: dict):
    """Return the status of `run_id` when `key` owns it; raise a 404 otherwise.

    A run recorded under a different owner is reported as unknown, so the
    response does not reveal that the run exists. A missing status file is
    also surfaced as a 404.
    """
    owner = db.run_owner(run_id)
    if owner == key["id"]:
        try:
            return get_status(run_id)
        except FileNotFoundError as exc:
            raise HTTPException(status_code=404, detail=str(exc)) from exc
    raise HTTPException(status_code=404, detail=f"unknown run_id: {run_id}")
|
|
527
|
+
|
|
528
|
+
@app.get("/v1/health")
def health():
    # Liveness probe: takes no auth dependency and reports the service name and version.
    return dict(ok=True, service="flash", version=__version__)
|
|
531
|
+
|
|
532
|
+
@app.get("/v1/me")
def me(key: dict = Depends(require_key)):
    # Describe the authenticated caller: the auth kind and key prefix are always
    # present; each identity field is included only when the key record has a
    # truthy value for it.
    kind = "internal" if key.get("auth_kind") == "internal" else "freesolo_api_key"
    identity = {"kind": kind, "key_prefix": key["key_prefix"]}
    optional_fields = (
        "email",
        "user_id",
        "org_id",
        "api_key_id",
        "training_agent_job_id",
        "project_id",
    )
    identity.update({name: key[name] for name in optional_fields if key.get(name)})
    return identity
|
|
549
|
+
|
|
550
|
+
@app.get("/v1/models")
def models(_: dict = Depends(require_key)):
    # Authenticated listing of the public model rows; the key itself is unused.
    rows = public_model_rows()
    return {"models": rows}
|
|
553
|
+
|
|
554
|
+
@app.post("/v1/envs")
def publish_env(payload: dict, key: dict = Depends(require_key)):
    # Publish a client-built Freesolo environment package to the managed
    # environment repository, so users never need direct repository credentials.
    # An EnvPublishError is translated to an HTTP error carrying its own status.
    from flash.server import envs

    raw_package = payload.get("package_b64")
    raw_name = payload.get("name")
    # Only a missing/None field falls back to "". A present-but-falsy non-string
    # (0, False, []) is forwarded unchanged so publish_package's own type checks
    # reject it with the right 400 rather than it being coerced to a
    # valid-looking empty string.
    package_arg = raw_package if raw_package is not None else ""
    name_arg = raw_name if raw_name is not None else ""
    try:
        slug = envs.publish_package(package_b64=package_arg, name=name_arg, key=key)
    except envs.EnvPublishError as exc:
        raise HTTPException(status_code=exc.status, detail=str(exc)) from exc
    from flash.server.environment_registry import record_published_environment

    record_published_environment(slug=slug, name=str(raw_name), key=key)
    return {"id": slug}
|
|
577
|
+
|
|
578
|
+
def _parse_spec(payload: dict, run_id: str) -> JobSpec:
    """Build the ``JobSpec`` for `run_id` from the request payload's ``"spec"`` entry.

    Raises:
        HTTPException(400): when ``spec`` or ``spec.environment`` is present but
            not a JSON object, when the environment carries a local ``path``,
            or when ``spec_from_dict`` rejects the spec with ``ConfigError`` /
            ``ValueError``.
    """
    spec_raw = payload.get("spec") or {}
    # A truthy non-object (string, list, number) would otherwise fail on `.get`
    # below with an AttributeError and surface as a 500 instead of a client error.
    if not isinstance(spec_raw, dict):
        raise HTTPException(status_code=400, detail="spec must be a JSON object")
    env_raw = spec_raw.get("environment") or {}
    if not isinstance(env_raw, dict):
        raise HTTPException(status_code=400, detail="spec.environment must be a JSON object")
    if env_raw.get("path"):
        raise HTTPException(
            status_code=400,
            detail="local environment paths are not supported on the managed service; "
            "publish the environment with `flash env push --name <name>`, then reference it "
            "by the returned environment id",
        )
    try:
        return spec_from_dict(spec_raw, run_id=run_id)
    except (ConfigError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
|
|
592
|
+
|
|
593
|
+
def _runtime_secrets(
    payload: dict, spec: JobSpec, *, require_environment_secrets: bool
) -> dict[str, str]:
    """Validate and normalize the payload's ``runtime_secrets`` mapping.

    Accepted names are ``_RUNTIME_SECRET_KEYS`` plus the names listed in
    ``spec.environment.secrets``. ``None`` values and values that are blank
    after stripping are dropped; the remaining values are returned stripped.

    Raises:
        HTTPException(400): when ``runtime_secrets`` is not a JSON object,
            contains a name outside the accepted set, holds a non-string
            value, or (with ``require_environment_secrets``) lacks a secret
            the environment declares.
    """
    supplied = payload.get("runtime_secrets") or {}
    if not isinstance(supplied, dict):
        raise HTTPException(status_code=400, detail="runtime_secrets must be a JSON object")

    permitted = set(_RUNTIME_SECRET_KEYS) | set(spec.environment.secrets)
    rejected = sorted(set(supplied) - permitted)
    if rejected:
        raise HTTPException(
            status_code=400,
            detail=(
                "unsupported runtime secret(s): "
                f"{', '.join(rejected)} (allowed: {', '.join(sorted(permitted))})"
            ),
        )

    cleaned: dict[str, str] = {}
    for name, secret in supplied.items():
        if secret is None:
            continue
        if not isinstance(secret, str):
            raise HTTPException(
                status_code=400, detail=f"runtime_secrets.{name} must be a string"
            )
        stripped = secret.strip()
        if stripped:
            cleaned[name] = stripped

    if not require_environment_secrets:
        return cleaned
    # A declared secret counts as missing when absent, None, or blank.
    absent = sorted(set(spec.environment.secrets) - set(cleaned))
    if absent:
        raise HTTPException(
            status_code=400,
            detail=(
                "missing runtime secret(s) required by [environment] secrets: "
                f"{', '.join(absent)}"
            ),
        )
    return cleaned
|
|
631
|
+
|
|
632
|
+
@app.post("/v1/runs")
def create_run(
    payload: dict,
    key: dict = Depends(require_key),
):
    # Create a run: parse and validate the spec and runtime secrets, record the
    # caller as owner, hand the job to submit_job with background=True, and return
    # the resulting status as a dict. Validation failures are raised as 400s
    # before anything is recorded.
    spec = _parse_spec(payload, run_id=new_run_id())
    dry_run = bool(payload.get("dry_run", False))
    # A dry run does not need the environment's declared secrets to be supplied.
    runtime_secrets = _runtime_secrets(
        payload, spec, require_environment_secrets=not dry_run
    )
    # External user-key runs are charged only after training succeeds. Persist the org id
    # (non-secret) so the background runner can bill with the operator internal key at
    # completion; never persist the submitting user's API key.
    bill_on_completion = not dry_run and key.get("auth_kind") != "internal"
    billing_context = None
    if bill_on_completion:
        org_id = str(key.get("org_id") or "").strip()
        if not org_id:
            raise HTTPException(
                status_code=400,
                detail="org id is required to bill a completed training run",
            )
        billing_context = {"org_id": org_id}
    try:
        db.record_run(spec.run_id, key["id"])
        # Optional kwargs are only passed when they carry a value.
        submit_kwargs = {"dry_run": dry_run, "background": True}
        if runtime_secrets:
            submit_kwargs["runtime_secrets"] = runtime_secrets
        if billing_context:
            submit_kwargs["billing_context"] = billing_context
        # Identity fields from the key, keeping only the truthy ones.
        platform_context = {
            field: value
            for field, value in {
                "org_id": key.get("org_id"),
                "user_id": key.get("user_id"),
                "api_key_id": key.get("api_key_id"),
            }.items()
            if value
        }
        if platform_context:
            submit_kwargs["platform_context"] = platform_context
        status = submit_job(spec, **submit_kwargs)
        # submit_job already reports the freshly-created status to the backend via
        # _report_status -> record_training_run, and the status carries platform_context
        # (org_id/user_id/api_key_id derived from `key`), so a second explicit
        # record_training_run(status, key) here would just re-POST the same creation record.
        # Don't duplicate it.
        from flash.envs.adapter import is_managed_environment_slug
        from flash.server.environment_registry import record_environment_use

        if is_managed_environment_slug(spec.environment.id):
            record_environment_use(slug=spec.environment.id, run_id=spec.run_id, key=key)
    except Exception as exc:
        # Any failure inside the try removes the ownership row and is reported as an
        # HTTP error: HTTPExceptions pass through unchanged, everything else becomes a 400.
        # NOTE(review): this path also covers a failure in record_environment_use, which
        # runs after submit_job(background=True) has returned — confirm whether a
        # background job can already be running when its ownership row is deleted here.
        db.delete_run(spec.run_id)  # idempotent: a no-op if record_run never landed
        if isinstance(exc, HTTPException):
            raise
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return status.to_dict()
|
|
690
|
+
|
|
691
|
+
@app.get("/v1/runs")
def list_runs(key: dict = Depends(require_key)):
    # Status dicts for every run recorded under the caller's key. Rows whose
    # status lookup raises FileNotFoundError are left out of the listing.
    statuses = []
    for record in db.runs_for_key(key["id"]):
        with contextlib.suppress(FileNotFoundError):
            statuses.append(get_status(record["run_id"]).to_dict())
    return {"runs": statuses}
|
|
700
|
+
|
|
701
|
+
@app.get("/v1/runs/{run_id}")
def run_status(run_id: str, key: dict = Depends(require_key)):
    # Status of a single run; owned_run raises the 404 for unknown or foreign runs.
    return owned_run(run_id, key).to_dict()
|
|
705
|
+
|
|
706
|
+
@app.get("/v1/runs/{run_id}/logs")
def run_logs(run_id: str, offset: int = 0, key: dict = Depends(require_key)):
    # Offset-paged log tail: return everything in the run's .log file from `offset`
    # (negative values are clamped to 0) to its current end, plus the new offset for
    # the caller to pass back on the next poll. A run with no log file yet yields an
    # empty chunk and echoes the clamped offset.
    status = owned_run(run_id, key)
    log_path = runs_file_path(run_id, ".log")
    chunk, end = "", max(0, offset)
    # Open directly instead of testing os.path.exists first: the file can disappear
    # between the check and the open, which would surface as an unhandled 500.
    with contextlib.suppress(FileNotFoundError):
        # errors="replace": a client-supplied offset that lands inside a multi-byte
        # sequence, or a character the writer has only partially flushed, must not
        # raise UnicodeDecodeError out of a poll endpoint.
        with open(log_path, errors="replace") as f:
            f.seek(end)
            chunk = f.read()
            end = f.tell()
    return {
        "run_id": run_id,
        "logs": chunk,
        "offset": end,
        "state": status.state,
        "last_heartbeat": status.last_heartbeat,
        "gpu_status": status.gpu_status,
    }
|
|
724
|
+
|
|
725
|
+
@app.get("/v1/runs/{run_id}/worker")
def run_worker_output(run_id: str, key: dict = Depends(require_key)):
    # Returns the complete train-subprocess stdout/traceback, fetched from the run's
    # HF artifact repo with the operator token — i.e. the real worker output that the
    # offset-paged .log cannot carry. It lives on its own route because it hits HF:
    # the hot /logs poll path stays fast for streaming `--follow`, while `--logs`
    # calls this once. Best-effort: {} when nothing's been uploaded yet.
    status = owned_run(run_id, key)
    spec = JobSpec.from_dict(status.spec)
    return {"run_id": run_id, "worker": _worker_artifacts(spec)}
|
|
733
|
+
|
|
734
|
+
@app.post("/v1/runs/{run_id}/cancel")
def cancel(run_id: str, key: dict = Depends(require_key)):
    # The owned_run call is purely an ownership gate (it 404s for foreign or
    # unknown runs); its return value is not needed.
    owned_run(run_id, key)
    cancelled = cancel_run(run_id)
    return cancelled.to_dict()
|
|
738
|
+
|
|
739
|
+
@app.post("/v1/runs/{run_id}/deploy")
def deploy(run_id: str, payload: dict | None = None, key: dict = Depends(require_key)):
    # Register a run's adapter (final, or a specific checkpoint when the payload
    # carries `step`) with the serving backend and record the deployment on the run.
    # Returns the deployment as a dict. With `dry_run` the state and hf_repo guards
    # and the status update are skipped.
    payload = payload or {}
    # Serialize deploy vs undeploy (and a second deploy) for this run: registration
    # with the freesolo serving app runs outside the status lock, so without this they
    # could interleave and leave the serving record and the control plane inconsistent.
    with _deploy_lock(run_id):
        status = owned_run(run_id, key)
        spec = JobSpec.from_dict(status.spec)
        dry_run = bool(payload.get("dry_run", False))
        # Optional `step`: deploy a specific intermediate checkpoint instead of the run's
        # final adapter. We resolve it against what's actually on HF (the source of truth),
        # so a missing step 404s with the available list rather than 500ing at serve time.
        checkpoint_step = _resolve_deploy_step(run_id, spec, payload.get("step"))
        is_checkpoint = checkpoint_step is not None
        # Checkpoint deploys are checked against their own set of allowed run states.
        allowed_states = (
            _CHECKPOINT_DEPLOYABLE_STATES if is_checkpoint else _DEPLOYABLE_STATES
        )
        if not dry_run and status.state not in allowed_states:
            detail = (
                f"run {run_id} is {status.state!r}; deploy a checkpoint only once the run "
                "has finished or been cancelled"
                if is_checkpoint
                else f"run {run_id} is {status.state!r}; only finished runs with "
                "trained adapter artifacts can be deployed"
            )
            raise HTTPException(status_code=409, detail=detail)
        # Legacy runs persisted before [train].hf_repo was mandatory rehydrate with an
        # empty hf_repo; without this guard freesolo serving cannot locate the adapter
        # artifacts (the per-run HF dataset repo). Reject early with a clear 409.
        if not dry_run and not spec.train.hf_repo:
            raise HTTPException(
                status_code=409,
                detail=(
                    f"run {run_id} has no [train].hf_repo (legacy run); its adapter artifacts "
                    "cannot be located, so it cannot be deployed"
                ),
            )
        # A checkpoint deploy serves the per-step adapter; otherwise the run's final adapter.
        deploy_prefix = (
            checkpoint_adapter_prefix(spec, checkpoint_step)
            if is_checkpoint
            else adapter_prefix(spec)
        )
        # The state the run must still be in for this deploy to finalize — a CAS guard so
        # a /cancel (NOT serialized by the deploy lock) that terminalized the run can't be
        # silently overwritten by the deployment record.
        prev_state = status.state
        # Attribute the adapter to the RUN's owning org so serving can authorize external chat
        # by org. Prefer the org persisted WITH the run — billing_context for user runs,
        # platform_context for internal/operator runs (see submit path) — over the caller's key,
        # so an operator deploy still lands on the run's owner. Each context is isinstance-guarded
        # against a non-dict legacy value (mirrors flash/server/billing.py / checkpoints.py).
        def _run_org(*contexts) -> str:
            # First non-blank org_id among the given contexts, or "" when none has one.
            for ctx in contexts:
                if isinstance(ctx, dict):
                    org = str(ctx.get("org_id") or "").strip()
                    if org:
                        return org
            return ""

        # Falls back to the caller's org, then to None when no org is known at all.
        deploy_org_id = (
            _run_org(
                getattr(status, "billing_context", None),
                getattr(status, "platform_context", None),
            )
            or str(key.get("org_id") or "").strip()
            or None
        )
        try:
            dep = deploy_adapter(
                run_id=run_id,
                model=spec.model,
                hf_repo=spec.train.hf_repo,
                adapter_prefix=deploy_prefix,
                gpu_name=spec.gpu.type,
                dry_run=dry_run,
                # a run trained with thinking serves with thinking (per-run parity)
                thinking=spec.thinking,
                org_id=deploy_org_id,
            )
        except ServingError as exc:
            # The serving backend rejected the registration or was unreachable. This is an
            # upstream/gateway failure, not a flash bug, so surface a clean 502 with the
            # real reason instead of letting httpx escape as an unhandled 500 + traceback.
            raise HTTPException(status_code=502, detail=str(exc)) from exc
        except Exception as exc:
            # A ValueError is reported as a 400; any other exception propagates unchanged.
            if isinstance(exc, ValueError):
                raise HTTPException(status_code=400, detail=str(exc)) from exc
            raise
        dep_dict = dep.to_dict()
        if is_checkpoint:
            dep_dict["checkpoint_step"] = checkpoint_step
        if not dry_run:
            if is_checkpoint and status.state not in _DEPLOYABLE_STATES:
                # Deploying a checkpoint of a run that stopped mid-RL (cancelled/failed):
                # attach the serving deployment but KEEP the run's terminal training state
                # — flipping it to `deployed` would erase the outcome and make undeploy
                # wrongly restore it to `done`.
                attach_checkpoint_deployment(run_id, dep_dict)
            else:
                # Record the deployment. The CAS no-ops only if a /cancel raced finalization
                # — then the adapter we just registered is orphaned, so deregister it and
                # report the conflict instead of a bogus 200.
                marked = mark_deployed(run_id, dep_dict, expect_state=prev_state)
                if marked.state != "deployed":
                    # Cleanup is best-effort; the 409 is raised whether or not it succeeds.
                    with contextlib.suppress(Exception):
                        undeploy_adapter(run_id)
                    raise HTTPException(
                        status_code=409,
                        detail=f"run {run_id} became {marked.state!r} during deploy; aborted",
                    )
        return dep_dict
|
|
852
|
+
|
|
853
|
+
@app.get("/v1/runs/{run_id}/checkpoints")
def run_checkpoints(run_id: str, key: dict = Depends(require_key)):
    """List a run's deployable per-step RL checkpoints (each `flash deploy --step N`-able).

    Reads the snapshots the worker streamed to HF, and best-effort mirrors them to the
    backend store so a listing also persists them."""
    status = owned_run(run_id, key)
    available = list_checkpoints(JobSpec.from_dict(status.spec))
    # The mirror step, including its import, must never fail the listing itself.
    with contextlib.suppress(Exception):
        from flash.server.checkpoints import register_checkpoints_best_effort

        register_checkpoints_best_effort(status)
    return {"run_id": run_id, "checkpoints": available}
|
|
867
|
+
|
|
868
|
+
@app.delete("/v1/runs/{run_id}/deploy")
def undeploy(run_id: str, key: dict = Depends(require_key)):
    # Takes the same per-run lock as deploy so that an undeploy cannot interleave
    # with an in-flight deploy's provisioning/finalization.
    with _deploy_lock(run_id):
        status = owned_run(run_id, key)
        try:
            removed = undeploy_adapter(run_id)
        except ServingError as exc:
            # Serving-backend failures (unreachable / non-404 error) are upstream/gateway
            # problems rather than flash bugs: report a clean 502 with the real reason,
            # as the deploy handler does, instead of an unhandled 500.
            raise HTTPException(status_code=502, detail=str(exc)) from exc
        # Delete is idempotent: even when the serving side had no adapter, a local
        # deployment record is still cleared.
        had_record = bool(status.deployment)
        if had_record:
            mark_undeployed(run_id)
        return {"run_id": run_id, "deleted_endpoints": removed}
|
|
886
|
+
|
|
887
|
+
@app.get("/v1/deployments")
def deployments(key: dict = Depends(require_key)):
    # Status dicts of the caller's runs that hold a deployment record whose state
    # is neither "undeployed" nor "dry_run". Runs without a status file are skipped.
    inactive_states = ("undeployed", "dry_run")
    active = []
    for record in db.runs_for_key(key["id"]):
        try:
            status = get_status(record["run_id"])
        except FileNotFoundError:
            continue
        if not status.deployment:
            continue
        if status.deployment.get("state") in inactive_states:
            continue
        active.append(status.to_dict())
    return {"deployments": active}
|
|
901
|
+
|
|
902
|
+
@app.post("/v1/runs/{run_id}/chat")
def chat(run_id: str, payload: dict, key: dict = Depends(require_key)):
    # Run a chat completion against the run's active deployment. With
    # `"stream": true` the reply is a text/plain StreamingResponse; otherwise the
    # serve_chat result is returned directly. Responds 409 when the run cannot be
    # chatted with, 400 for malformed sampling parameters, 502 when inference fails.
    status = owned_run(run_id, key)
    spec = JobSpec.from_dict(status.spec)
    deployment = status.deployment or {}
    # A cancelled run's serve endpoint was torn down at cancel time; never let a
    # chat recreate it (closes the window before cancel marks the deployment
    # inactive, and covers a teardown that deleted nothing).
    if status.state == "cancelled":
        raise HTTPException(
            status_code=409, detail=f"run {run_id} was cancelled; redeploy is not allowed"
        )
    # Chat must ride an explicit deployment (with its cost controls), not
    # implicitly provision a serving endpoint that /v1/deployments cannot see.
    if deployment.get("state") in (None, "undeployed", "dry_run"):
        raise HTTPException(
            status_code=409,
            detail=f"run {run_id} has no active deployment; `flash deploy {run_id}` first",
        )
    # Legacy run with no artifact repo (mirrors the /deploy guard): a run that never had a
    # [train].hf_repo was never registered with freesolo serving, so reject early with a
    # clear 409 instead of an opaque downstream inference error.
    if not spec.train.hf_repo:
        raise HTTPException(
            status_code=409,
            detail=f"run {run_id} has no [train].hf_repo (legacy run); its adapter cannot be served",
        )
    # Coerce the sampling parameters before touching the serving backend, so a bad
    # client value is a 400 rather than being misreported as a 502 inference failure.
    try:
        temperature = float(payload.get("temperature") or 0.0)
        max_tokens = int(payload.get("max_tokens") or 512)
    except (TypeError, ValueError, OverflowError) as exc:
        raise HTTPException(
            status_code=400,
            detail=f"temperature and max_tokens must be numeric: {exc}",
        ) from exc
    # Both serving paths take the same arguments; build them once.
    serve_kwargs = {
        "run_id": run_id,
        "messages": payload.get("messages") or [],
        "temperature": temperature,
        "max_tokens": max_tokens,
        # a run trained with thinking serves with thinking (per-run parity)
        "thinking": spec.thinking,
    }
    try:
        if payload.get("stream") is True:
            return StreamingResponse(
                serve_chat_stream(**serve_kwargs),
                media_type="text/plain; charset=utf-8",
            )
        return serve_chat(**serve_kwargs)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"inference failure: {exc}") from exc
|
|
952
|
+
|
|
953
|
+
return app
|
|
954
|
+
|
|
955
|
+
|
|
956
|
+
def run_server(host: str = "127.0.0.1", port: int = 8080) -> None:
    """Build the control-plane app and serve it with uvicorn on `host`:`port`.

    Raises:
        RuntimeError: with the server-extras hint when uvicorn cannot be imported.
    """
    try:
        import uvicorn
    except ImportError as exc:
        raise RuntimeError(_SERVER_EXTRAS_HINT) from exc
    else:
        uvicorn.run(create_app(), host=host, port=port)
|