freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,82 @@
1
+ """List a run's deployable per-step RL checkpoints from its HF artifact repo.
2
+
3
+ The GPU worker publishes each trainer save's LoRA adapter to a stable, NON-pruned path
4
+ (``<adapter_prefix>/checkpoints/step-<N>/adapter``; see
5
+ ``flash.engine.worker.publish_deployable_checkpoint``). This module is the control-plane
6
+ reader: it enumerates those snapshots so ``flash checkpoints`` can list them and
7
+ ``flash deploy --step N`` can serve a specific one — including for a run that was cancelled or
8
+ failed mid-RL and so never sealed a final adapter. HF (not the backend DB) is the source of
9
+ truth for what's deployable; backend persistence is a mirror (see
10
+ ``flash.server.checkpoints``)."""
11
+
12
+ from __future__ import annotations
13
+
14
+ import os
15
+ import re
16
+
17
+ from flash.runner import adapter_prefix
18
+ from flash.spec import JobSpec
19
+
20
+ # The PEFT weights file a step must carry (alongside adapter_config.json) to be servable.
21
+ _ADAPTER_WEIGHT_FILES = frozenset({"adapter_model.safetensors", "adapter_model.bin"})
22
+
23
+
24
def checkpoint_adapter_prefix(spec: JobSpec, step: int, seed: int | None = None) -> str:
    """Return the ``adapter_prefix`` under which checkpoint ``step`` is served.

    ``deploy_adapter`` appends ``/adapter`` to whatever prefix it receives, so the value
    returned here is the per-step root ``<run prefix>/checkpoints/step-<N>``. That mirrors
    the worker's upload layout, which lets the ordinary deploy path serve a checkpoint
    without any checkpoint-specific handling.
    """
    run_root = adapter_prefix(spec, seed)
    return "/".join((run_root, "checkpoints", f"step-{step}"))
31
+
32
+
33
+ def _adapter_file_re(base: str) -> re.Pattern[str]:
34
+ """Matches ``<base>/checkpoints/step-<N>/adapter/<filename>`` and captures (step, filename)."""
35
+ return re.compile(re.escape(base) + r"/checkpoints/step-(\d+)/adapter/([^/]+)$")
36
+
37
+
38
def list_checkpoints(spec: JobSpec, seed: int | None = None) -> list[dict]:
    """Return the deployable per-step adapter snapshots for ``spec``, sorted by step.

    Only steps whose ``adapter/`` folder holds ``adapter_config.json`` together with at
    least one weights file from ``_ADAPTER_WEIGHT_FILES`` are reported. This ensures
    ``/deploy --step`` can never be pointed at a half-uploaded step that would fail to load.

    Each entry is a dict with keys:

    * ``step``: the step number.
    * ``adapter_prefix``: the value to pass to ``deploy_adapter`` to serve that step.
    * ``subfolder``: the adapter folder's full path in the repo.
    * ``repo_id`` and ``repo_type``.

    Returns ``[]`` in any of these cases:

    * the spec has no HF repo;
    * no complete snapshot is published;
    * listing the repo fails for any reason. The failure is printed and never raised,
      because this runs inside runs and routes.
    """
    repo_id = spec.train.hf_repo
    if not repo_id:
        return []
    matcher = _adapter_file_re(adapter_prefix(spec, seed))
    try:
        from huggingface_hub import HfApi

        client = HfApi(token=os.environ.get("HF_TOKEN"))
        repo_files = client.list_repo_files(repo_id, repo_type="dataset")
    except Exception as exc:  # best-effort listing: never raise into a run/route
        print(f"[ckpt] list warn for {spec.run_id}: {exc}")
        return []

    # step number -> names of the files found directly inside that step's adapter folder
    files_per_step: dict[int, set[str]] = {}
    for hit in filter(None, map(matcher.search, repo_files)):
        step_no, filename = hit.groups()
        files_per_step.setdefault(int(step_no), set()).add(filename)

    def _servable(names: set[str]) -> bool:
        # Servable only with the adapter config plus at least one weights file.
        return "adapter_config.json" in names and bool(names & _ADAPTER_WEIGHT_FILES)

    entries: list[dict] = []
    for step_no in sorted(files_per_step):
        if not _servable(files_per_step[step_no]):
            continue
        step_prefix = checkpoint_adapter_prefix(spec, step_no, seed)
        entries.append(
            {
                "step": step_no,
                "adapter_prefix": step_prefix,
                "subfolder": step_prefix + "/adapter",
                "repo_id": repo_id,
                "repo_type": "dataset",
            }
        )
    return entries
flash/runner/deploy.py ADDED
@@ -0,0 +1,422 @@
1
+ """Deploy / cancel / recover state transitions for a run.
2
+
3
+ Store helpers and the lifecycle functions (``_run_seed_loop`` / ``_gc_run_endpoints``) are
4
+ pulled in via FUNCTION-LOCAL lazy ``from flash.runner import ...`` imports — never at module
5
+ level — for the same two reasons as ``lifecycle.py``: avoid a partially-initialized-package
6
+ import cycle, and keep the test monkeypatches (e.g. ``flash.runner._gc_run_endpoints``)
7
+ reachable through the package global rather than a statically-bound copy.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import contextlib
13
+ import time
14
+ from typing import TYPE_CHECKING
15
+
16
+ from flash.spec import JobSpec
17
+
18
+ if TYPE_CHECKING:
19
+ # RunStatus lives in flash.runner.__init__ and is only referenced here as a return
20
+ # annotation (stringized by `from __future__ import annotations`), so a TYPE_CHECKING
21
+ # import keeps it resolvable for tooling without a runtime import cycle.
22
+ from flash.runner import RunStatus
23
+
24
+
25
def cancel_run(run_id: str) -> RunStatus:
    """Cancel a run: tear down its remote worker (and any live serving deployment), then
    mark it ``cancelled``.

    Teardown goes through the persisted job handle's provider: ``provider.cancel`` and then
    ``provider.destroy``, followed by ``_gc_run_endpoints``. Because of this, a cancel issued
    from a fresh process (e.g. ``flash cancel``) still stops the GPU worker instead of
    leaving it running. A run that is live-``deployed`` at entry is also deregistered from
    serving (``undeploy_adapter``), and its deployment record is marked ``undeployed``.

    Best-effort: exceptions from the serving teardown and from the provider cancel/destroy
    are swallowed (not recorded), so the run still flips to ``cancelled``.
    ``_gc_run_endpoints`` is NOT wrapped, so an exception from it propagates before the
    state transition. Checkpoint mirroring afterwards is fully suppressed.

    Args:
        run_id: the run to cancel.

    Returns:
        The run's status re-read after the cancel. A run already terminal at entry is
        returned unchanged. A genuine training completion that races in during teardown is
        preserved; see the final ``_update`` call.
    """
    from flash.runner import (
        TERMINAL_STATES,
        _gc_run_endpoints,
        _update,
        get_status,
        mark_deployment_undeployed,
    )

    status = get_status(run_id)
    if status.state in TERMINAL_STATES:
        return status
    # Whether the run was a live `deployed` serving run at cancel entry. This scopes the
    # final `cancelled` transition's terminal override below: only a `deployed` run can have
    # a concurrent undeploy (`mark_undeployed`) race this teardown and write a non-completion
    # terminal `done`. A non-deployed run (running/provisioning/queued) has an in-flight
    # TRAINING thread whose only terminal `done` is a GENUINE completion — which cancel must
    # never clobber. See the final _update call for how this gates the override.
    entered_deployed = status.state == "deployed"
    spec = JobSpec.from_dict(status.spec)
    remote = status.remote or {}
    # A deployed run also owns a serving registration with the freesolo serving
    # app that the training-endpoint GC below does not touch; deregister it too so
    # a cancelled run can't leave a deployment registered as active.
    if status.state == "deployed":
        try:
            from flash.serve.deploy import undeploy_adapter

            undeploy_adapter(run_id)
            # Mark the deployment inactive so /v1/deployments and /chat stop treating the
            # cancelled run as active. Only reached when undeploy_adapter returned without
            # raising; if it raises, the except below skips this and the record is left
            # as-is.
            if status.deployment:
                # Mark the deployment inactive through the lock-guarded path so this write
                # participates in the same _STATUS_LOCK as the rest of the runner. A bare
                # _save_status here would persist a stale pre-teardown snapshot OUTSIDE the
                # lock, bypassing serialization and potentially clobbering a concurrent field
                # update. We mark ONLY the deployment field and leave the run's state alone
                # (no state re-assert): a concurrent mark_undeployed can move the run to
                # terminal `done` between our get_status read and this write, and _update's
                # compare-and-set rejects ANY transition off a terminal state (even a
                # same-field re-assert of the stale `deployed`), which would silently leave
                # the deployment advertised as `ready`. mark_deployment_undeployed flips the
                # deployment regardless of (and without disturbing) the current state.
                mark_deployment_undeployed(run_id)
        except Exception:
            # Best-effort serving teardown: a failure here must not block the cancel
            # below (the run still flips to cancelled and the training endpoint is GC'd).
            pass
    # Durable path first: stop the exact remote worker via the handle's provider
    # (works from any process); endpoint/instance teardown is shared with the GC.
    # Dispatched generically through the registry — never a hardcoded per-provider branch.
    if remote:
        try:
            from flash.providers import get_provider
            from flash.providers.base import JobHandle

            handle = JobHandle.from_dict(remote)
            provider = get_provider(handle.provider)
            provider.cancel(handle)
            # Belt-and-suspenders destroy after cancel; RunPod endpoint GC follows.
            provider.destroy(handle)
        except Exception:
            # Best-effort remote stop; _gc_run_endpoints below still tears the endpoint down.
            pass
    # Not wrapped: an exception here propagates and skips the `cancelled` transition below.
    _gc_run_endpoints(spec)
    # Final transition to `cancelled`. The run was NON-terminal at entry, but teardown takes
    # time and a terminal state can race in mid-teardown. We must distinguish two cases:
    #
    # - A concurrent mark_undeployed() (an external `DELETE /v1/runs/{id}/deploy`) flipped a
    #   `deployed` run to terminal `done`. That `done` is NOT a fresh result — it just
    #   restored the run's pre-deploy completion marker while retiring serving. The user
    #   explicitly asked to cancel, so this must be OVERRIDDEN to `cancelled`.
    # - A genuine training-COMPLETION `done` from the run's own training thread
    #   (_run_job_inner / attach_run), which persisted real metrics+cost+artifacts. Cancel
    #   must NEVER clobber that — the run finished, so the real result is preserved.
    #
    # These two races are mutually exclusive on the entry state: only a `deployed` run owns a
    # deployment that mark_undeployed can race, and only a non-deployed (running/provisioning/
    # queued) run has an in-flight training thread that can complete mid-teardown. So scope the
    # terminal override to runs that were `deployed` at entry — there a racing `done` is always
    # an undeploy artifact (cancel wins); elsewhere a racing `done` is a genuine completion that
    # _update's CAS correctly protects (cancel loses to a real finish).
    _update(run_id, "cancelled", allow_from_terminal=entered_deployed)
    # A run cancelled mid-RL keeps whatever per-step adapters the worker already streamed to
    # HF; mirror them to the backend store now so the cancelled run is immediately listable +
    # deployable (`flash checkpoints` / `flash deploy --step N`). Best-effort: never let
    # checkpoint bookkeeping fail a cancel.
    with contextlib.suppress(Exception):
        from flash.server.checkpoints import register_checkpoints_best_effort

        register_checkpoints_best_effort(get_status(run_id))
    return get_status(run_id)
126
+
127
+
128
def attach_run(run_id: str, log_stream=None) -> RunStatus:
    """Re-attach to a run's in-flight remote job from ANY process (e.g. after a client
    crash or control-plane restart) and drive it to an outcome.

    The persisted ``remote`` handle is rebuilt into a ``JobHandle`` and polled through its
    provider. The ``seed`` and ``allocated_gpu`` fields stored alongside the handle are
    popped off first. After the poll:

    * If the run was cancelled meanwhile, it is left ``cancelled``.
    * If the job ended not-ok, the dead endpoint is GC'd, the stale handle is cleared, and
      ``_run_seed_loop`` is restarted at that seed's index.
    * Otherwise, the seed's metrics are persisted and added to ``cost_usd``. Then either the
      remaining seeds of a multi-seed run are started, or the run is marked ``done``.

    Any other exception raised in that handling marks the run ``failed`` (unless it was
    cancelled), and ``_gc_run_endpoints`` always runs afterwards. The poll itself happens
    BEFORE that try/finally, so an exception raised by ``poll`` propagates to the caller
    without a state change or endpoint GC.

    Args:
        run_id: the run to reattach to.
        log_stream: where progress lines are printed; defaults to ``sys.stderr``.

    Returns:
        The run's status re-read at the end (or, unchanged, if it was terminal at entry).

    Raises:
        ValueError: the run is non-terminal but has no persisted job handle (no worker was
            ever recorded for it).
    """
    import sys

    from flash.runner import (
        TERMINAL_STATES,
        _gc_run_endpoints,
        _persist_metrics,
        _run_seed_loop,
        _RunCancelled,
        _update,
        artifacts_dir,
        get_status,
    )

    status = get_status(run_id)
    if status.state in TERMINAL_STATES:
        return status
    if not status.remote:
        raise ValueError(f"run {run_id} has no persisted job handle; cannot reattach")

    spec = JobSpec.from_dict(status.spec)
    # Copy so the pops below don't mutate the status object's handle dict.
    remote = dict(status.remote)
    seed = int(remote.pop("seed", spec.train.seeds[0]))
    # The class the run actually provisioned (a policy retry may have walked past the
    # provisional spec.gpu.type). The in-process success path stamps this into metrics;
    # on recovery the worker output carries no such field, so recover it from the handle
    # to cost the right card.
    allocated_gpu = remote.pop("allocated_gpu", None)
    log = log_stream or sys.stderr
    # Dispatch the poll generically via the handle's provider (the provider owns its
    # heartbeat reader + poll loop); the orchestrator stays provider-agnostic.
    from flash.providers import get_provider
    from flash.providers.base import JobHandle

    handle = JobHandle.from_dict(remote)
    print(f"attaching to {run_id}: provider={handle.provider} {handle.data}", file=log)
    # NOTE(review): the poll runs outside the try/finally below, so a poll exception
    # propagates without marking the run failed or GC'ing its endpoint (leaving a possibly
    # still-live job alone for a later reattach) — confirm this is intended.
    res = get_provider(handle.provider).poll(handle, spec, seed, log=log)
    try:
        # A best-effort cancel deletes the job/instance, which the poller reports as a
        # failure (or a late worker may still succeed) — either way, re-read the state
        # first so a recovery thread can't overwrite the user's terminal `cancelled`.
        if get_status(run_id).state == "cancelled":
            return get_status(run_id)
        if not res.ok:
            # Job ended not-ok — usually because it was abandoned during the redeploy. Resume the
            # in-flight seed from its last HF checkpoint instead of failing; the seed loop
            # (unchanged) still terminates a genuinely broken run when it re-fails.
            try:
                seed_index = list(spec.train.seeds).index(seed)
            except ValueError:
                # Seed not in the spec's list: fall back to the first seed.
                seed_index = 0
            print(
                f"attach: {run_id} seed {seed} ended ({res.failure}); resuming from checkpoint",
                file=log,
            )
            # GC the dead endpoint, then clear the stale handle and record the seed so a second
            # restart mid-allocation resumes the right one.
            with contextlib.suppress(Exception):
                _gc_run_endpoints(spec)
            # Bail if the run was raced to terminal during the long poll above: _update's CAS
            # returns False, and resuming would submit paid work for a dead run.
            if not _update(run_id, "running", remote=None, resume_seed_index=seed_index):
                print(f"attach: {run_id} went terminal during recovery; not resuming", file=log)
                return get_status(run_id)
            _run_seed_loop(
                spec, log, start_index=seed_index, prior_cost=float(status.cost_usd or 0.0)
            )
            return get_status(run_id)
        # Carry the provisioned class into metrics so _persist_metrics costs the card the
        # run actually used (the in-process path stamps this; recovery must restore it).
        if allocated_gpu and isinstance(res.metrics, dict):
            res.metrics.setdefault("allocated_gpu", allocated_gpu)
        # Earlier seeds of a multi-seed run already persisted their cost into
        # status.cost_usd; add this seed's so recovery doesn't underreport spend.
        total = float(status.cost_usd or 0.0) + _persist_metrics(spec, seed, res.metrics)
        # A cancel can land while this thread persists the recovered seed's metrics
        # (after the late-cancel check above). Re-read before the post-seed writes so
        # the "running" update and the terminal "done" below can't resurrect a
        # user-cancelled run (mirrors the fresh seed loop). _RunCancelled is caught
        # below, leaving the cancellation intact.
        if get_status(run_id).state == "cancelled":
            raise _RunCancelled(f"run {run_id} was cancelled")
        # The remote handle only identifies the seed that was in flight. For a
        # multi-seed run, resume the remaining seeds instead of terminally
        # completing the whole run after just this one.
        try:
            resumed_index = list(spec.train.seeds).index(seed) + 1
        except ValueError:
            # Unknown seed: treat it as the last one so the run completes below.
            resumed_index = len(spec.train.seeds)
        more_seeds = resumed_index < len(spec.train.seeds)
        # Clear the now-stale completed handle before resuming. In the
        # allocation/provisioning gap before the next seed's on_handle() persists a
        # fresh handle, a server restart must not reattach recovery to this finished
        # job — that would double-count its cost and replay the wrong seed. Record the
        # next seed index so a restart in that gap resumes the remaining seeds rather
        # than failing the run. (The last seed keeps its handle for post-run
        # observability, mirroring the fresh-submit seed loop.)
        applied = _update(
            run_id,
            "running",
            cost_usd=total,
            artifacts_dir=artifacts_dir(spec),
            **({"remote": None, "resume_seed_index": resumed_index} if more_seeds else {}),
        )
        # Same TOCTOU guard as the not-ok recovery path: a concurrent thread can flip this
        # run terminal (e.g. failed/done from another recovery) between the cancel re-check
        # above and here. The sticky CAS rejects the `running` write (applied is False) — so
        # don't resume the remaining seeds and submit paid GPU work for an already-terminal
        # run. (The non-multi-seed arm writes the terminal `done`; the CAS protects a racing
        # terminal there too, so no extra guard is needed.)
        if more_seeds:
            if not applied:
                print(
                    f"attach: {run_id} went terminal during recovery; "
                    "not resuming the remaining seeds",
                    file=log,
                )
                return get_status(run_id)
            _run_seed_loop(spec, log, start_index=resumed_index, prior_cost=total)
        else:
            _update(run_id, "done", cost_usd=total, artifacts_dir=artifacts_dir(spec))
    except _RunCancelled:
        # Intentional: cancel_run already wrote the terminal `cancelled` state; leave it.
        pass
    except Exception as exc:
        # Any other failure marks the run failed, unless the user cancelled it meanwhile.
        if get_status(run_id).state != "cancelled":
            _update(run_id, "failed", error=str(exc))
    finally:
        # Runs on every exit from the try (including the early returns above).
        _gc_run_endpoints(spec)
    return get_status(run_id)
265
+
266
+
267
def resume_run(run_id: str, log_stream=None) -> RunStatus:
    """Continue a multi-seed run from its recorded ``resume_seed_index`` after a restart.

    ``_run_seed_loop`` clears a finished seed's handle and records ``resume_seed_index``
    before it starts the next seed. If the control plane restarts inside that handle-less
    window, there is no live job for ``attach_run`` to poll. This function therefore starts
    a fresh seed loop at the recorded index, instead of failing the run and discarding the
    seeds already finished. Nothing is re-uploaded here: the worker is expected to fetch the
    flash package that the original submit pushed to HF.

    Args:
        run_id: the run to resume.
        log_stream: where progress lines are printed; defaults to ``sys.stderr``.

    Returns:
        The run's status re-read after the loop ends. A run that is already terminal is
        returned immediately, unchanged.

    Raises:
        ValueError: the run is non-terminal but has no ``resume_seed_index``.
    """
    import sys

    from flash.runner import (
        TERMINAL_STATES,
        _gc_run_endpoints,
        _run_seed_loop,
        _RunCancelled,
        _update,
        get_status,
    )

    current = get_status(run_id)
    if current.state in TERMINAL_STATES:
        return current
    start = current.resume_seed_index
    if start is None:
        raise ValueError(f"run {run_id} has no resume_seed_index; cannot resume")
    spec = JobSpec.from_dict(current.spec)
    out = log_stream or sys.stderr
    print(f"resuming {run_id}: remaining seeds from index {start}", file=out)
    try:
        _run_seed_loop(
            spec,
            out,
            start_index=start,
            prior_cost=float(current.cost_usd or 0.0),
        )
    except _RunCancelled:
        pass  # cancel_run has already written the terminal `cancelled` state
    except Exception as exc:
        # Never overwrite a user cancel that landed while the loop was running.
        if get_status(run_id).state != "cancelled":
            _update(run_id, "failed", error=str(exc))
    finally:
        # Like _run_job: sweep up any endpoint a transient destroy failure left behind so a
        # billable RunPod endpoint is not leaked.
        _gc_run_endpoints(spec)
    return get_status(run_id)
314
+
315
+
316
def mark_deployed(run_id: str, deployment: dict, expect_state: str | None = None) -> RunStatus:
    """Atomically attach ``deployment`` to a run and move it to ``deployed``.

    The read-check-write happens under ``_STATUS_LOCK``. Two guards apply; when either
    trips, the current status is returned unchanged:

    * The run's state is in ``_UNDEPLOYABLE_STATES``. For example, a ``/cancel`` that wrote
      ``cancelled`` mid-deploy must not be resurrected as an active deployment. ``done``
      stays deployable, because deploying a finished run is the common case.
    * ``expect_state`` is given and differs from the current state. This is a
      compare-and-set: the deploy flow passes the pre-deploy snapshot state, or
      ``"deployed"`` after its provisional mark. An undeploy that raced finalization
      (deleting the endpoint and writing ``done`` with ``deployment.state="undeployed"``)
      is therefore not undone by re-advertising that endpoint.

    On the ``done`` -> ``deployed`` transition of a legacy run (no ``finished_at`` and not
    yet reconciled), ``finished_at`` is frozen to the current ``updated_at`` before the
    deploy bumps it, which preserves the real training-teardown time. The freeze is skipped
    on an already-``deployed`` run, where ``updated_at`` is the deploy time. It is also
    skipped on a reconciled run, where ``updated_at`` is the reconcile time and the run is
    never re-billed.

    Args:
        run_id: the run being deployed.
        deployment: the deployment record to store on the run.
        expect_state: optional compare-and-set guard on the run's current state.

    Returns:
        The saved status on success, otherwise the unchanged current status.
    """
    from flash.runner import _STATUS_LOCK, _UNDEPLOYABLE_STATES, _save_status, get_status

    with _STATUS_LOCK:
        status = get_status(run_id)
        if status.state in _UNDEPLOYABLE_STATES or (
            expect_state is not None and status.state != expect_state
        ):
            return status
        freeze_teardown_time = (
            status.state == "done" and status.finished_at is None and not status.reconciled_at
        )
        if freeze_teardown_time:
            status.finished_at = status.updated_at
        status.deployment = deployment
        status.state = "deployed"
        status.updated_at = time.time()
        _save_status(status)
    return status
353
+
354
+
355
def attach_checkpoint_deployment(run_id: str, deployment: dict) -> RunStatus:
    """Store a serving ``deployment`` on a run without touching its ``state``.

    Intended for serving an intermediate checkpoint of a run that never reached ``done``,
    e.g. one cancelled or failed mid-RL. The checkpoint adapter exists on HF and can be
    served, but the run's terminal training outcome must survive. Moving such a run to
    ``deployed`` would erase that outcome, and a later undeploy could then report it as
    ``done``.

    The deployment is kept in the same ``deployment`` field a normal deploy uses, so
    ``/v1/deployments`` lists it and an undeploy clears it. The write happens under
    ``_STATUS_LOCK`` so it serializes with a concurrent deploy or undeploy of the same run.

    Args:
        run_id: the run whose checkpoint is being served.
        deployment: the deployment record to store.

    Returns:
        The saved status.
    """
    from flash.runner import _STATUS_LOCK, _save_status, get_status

    with _STATUS_LOCK:
        run = get_status(run_id)
        run.deployment = deployment
        run.updated_at = time.time()
        _save_status(run)
    return run
375
+
376
+
377
def mark_undeployed(run_id: str) -> RunStatus:
    """Record an explicit undeploy: retire the deployment record and, for a live
    ``deployed`` run, move it back to ``done``.

    The write is guarded by ``_STATUS_LOCK`` so it serializes with a racing deploy
    finalization. An unlocked read + ``_save_status`` could interleave with
    ``mark_deployed`` and be clobbered. Under the same lock, ``mark_deployed``'s
    ``expect_state`` compare-and-set sees this ``done``/undeployed write and will not
    re-advertise the deleted endpoint.

    Only the ``deployed`` state is transitioned. A terminal run (``done``, ``cancelled``,
    ``failed``) keeps its state. So does any other non-terminal state, e.g. a run still
    training that had a checkpoint deployment attached. An undeploy must not flip such a run
    to terminal ``done``: the terminal compare-and-set in ``_update`` would then reject the
    training thread's genuine completion later.

    Args:
        run_id: the run whose deployment was torn down.

    Returns:
        The saved status.
    """
    from flash.runner import _STATUS_LOCK, _save_status, get_status

    with _STATUS_LOCK:
        status = get_status(run_id)
        if status.deployment:
            status.deployment = {**status.deployment, "state": "undeployed"}
        # Only a live `deployed` run returns to `done`; every other state is left as-is.
        if status.state == "deployed":
            status.state = "done"
        status.updated_at = time.time()
        _save_status(status)
    return status
399
+
400
+
401
def mark_deployment_undeployed(run_id: str) -> RunStatus:
    """Set only the deployment record's ``state`` to ``undeployed``; the run state is untouched.

    ``cancel_run`` uses this to retire a deployed run's serving record. Unlike
    ``mark_undeployed``, which also moves a live ``deployed`` run back to ``done``, this
    function never reads or writes the run state. That distinction matters under the cancel
    race. A concurrent ``mark_undeployed`` may already have moved the run to terminal
    ``done``, and ``_update``'s compare-and-set rejects any transition off a terminal state.
    Re-asserting ``deployed`` just to carry the deployment field would be rejected, leaving
    the deployment advertised as ``ready``.

    Writing the field directly under ``_STATUS_LOCK`` avoids that CAS, so the deployment
    reliably ends ``undeployed``. The subsequent ``cancelled`` transition is left to
    ``_update``. The status is saved with a fresh ``updated_at`` even when there is no
    deployment record.

    Args:
        run_id: the run whose deployment record is retired.

    Returns:
        The saved status.
    """
    from flash.runner import _STATUS_LOCK, _save_status, get_status

    with _STATUS_LOCK:
        run = get_status(run_id)
        previous = run.deployment
        if previous:
            run.deployment = {**previous, "state": "undeployed"}
        run.updated_at = time.time()
        _save_status(run)
    return run