freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,672 @@
1
+ """Run-execution machinery: the submit -> seed-loop -> per-seed supervised job -> GC flow.
2
+
3
+ Store helpers (get_status/_update/_save_status/artifacts_dir/_persist_metrics/RUNS_DIR/...)
4
+ and sibling lifecycle functions are pulled in via FUNCTION-LOCAL lazy
5
+ ``from flash.runner import ...`` imports — never at module level. That avoids a
6
+ partially-initialized-package import cycle (``flash.runner.__init__`` imports this module
7
+ while still being defined) AND keeps the test monkeypatches reachable: a reader that resolves
8
+ ``RUNS_DIR`` / ``_gc_run_endpoints`` / ``_run_job`` through the package global picks up
9
+ ``monkeypatch.setattr(runner, ...)`` instead of a statically-bound copy.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import contextlib
15
+ import os
16
+ import time
17
+
18
+ from flash.spec import JobSpec
19
+
20
+
21
def _run_job(spec: JobSpec, runtime_secrets: dict[str, str] | None = None) -> None:
    """Background entry point for one submitted run.

    Moves the run to ``provisioning``, hands it to ``_run_job_inner`` with the run's
    log path, and always finishes with an endpoint GC pass.

    Args:
        spec: The job spec of the run to execute.
        runtime_secrets: Optional secrets forwarded unchanged to ``_run_job_inner``.
    """
    # Imported lazily (function scope) so dry-run / unit tests never construct a
    # Flash endpoint, and so monkeypatches on ``flash.runner`` are picked up.
    from flash.providers.runpod.train import upload_code
    from flash.runner import (
        RUNS_DIR,
        TERMINAL_STATES,
        _gc_run_endpoints,
        _run_job_inner,
        _update,
        get_status,
    )

    run_id = spec.run_id

    # The run may already be terminal (e.g. cancelled between the queued status being
    # returned to the client and this thread starting). Never overwrite that state
    # with "provisioning" and launch a paid seed as if the cancel had not happened.
    already_terminal = get_status(run_id).state in TERMINAL_STATES
    if already_terminal:
        return

    _update(run_id, "provisioning")
    run_log_path = os.path.join(RUNS_DIR, f"{run_id}.log")
    try:
        _run_job_inner(spec, run_log_path, upload_code, runtime_secrets=runtime_secrets)
    finally:
        # Every run leaves its uniquely-named endpoint registered, and registered
        # endpoints count against the account-wide max-workers quota (5 by default);
        # without this teardown new submissions start failing after a handful of
        # runs. Best-effort: runs on any exit path.
        _gc_run_endpoints(spec)
49
+
50
+
51
def _spec_with_gpu(spec: JobSpec, gpu_type: str) -> JobSpec:
    """Return the spec workers/loggers should see for this attempt's allocated GPU class.

    When ``spec`` already targets ``gpu_type`` the same object is returned; otherwise
    a new ``JobSpec`` is rebuilt from the dict form with only ``gpu.type`` replaced.
    """
    if spec.gpu.type != gpu_type:
        payload = spec.to_dict()
        # Copy the gpu section so the dict produced by to_dict() is not edited in place.
        gpu_section = dict(payload["gpu"])
        gpu_section["type"] = gpu_type
        payload["gpu"] = gpu_section
        return JobSpec.from_dict(payload)
    return spec
58
+
59
+
60
def _drop_weight_cache(spec: JobSpec) -> JobSpec:
    """Return ``spec`` without the SHARED weight-cache volume (cold, fully cross-region run).

    Applied after a no-capacity attempt. Attaching the cache restricts the endpoint to
    the cache's datacenter set, so when that whole set is momentarily starved the next
    attempt should fall back to the unrestricted all-DC pool. Clearing
    ``network_volume`` makes weight_cache_endpoint_kwargs return ``{}`` (no volume, no
    datacenter list) and turns off the worker's HF_HOME redirect. The worst case for
    the cache is therefore one capacity-grace wait, never a permanent IN_QUEUE block.

    Only the platform-managed shared cache (``WEIGHT_CACHE_VOLUME_NAME``) is removed.
    Any other per-org/custom ``network_volume`` is a deliberate opt-in isolation (see
    runner._assign_weight_cache_volume) and is returned untouched.
    """
    from flash.runner import WEIGHT_CACHE_VOLUME_NAME

    attached_volume = getattr(spec.gpu, "network_volume", None)
    if attached_volume == WEIGHT_CACHE_VOLUME_NAME:
        payload = spec.to_dict()
        gpu_section = dict(payload["gpu"])
        gpu_section["network_volume"] = None
        payload["gpu"] = gpu_section
        return JobSpec.from_dict(payload)
    # No volume, or a user-chosen one: nothing to strip.
    return spec
82
+
83
+
84
+ def _select_candidate(candidates, failed_providers: set[str], tried_classes: set[tuple[str, str]]):
85
+ """Pick the next (provider, class) to try from the cross-provider ranked candidate list.
86
+
87
+ ``candidates`` is already price-sorted (cheapest first). On the FIRST attempt — nothing failed
88
+ yet — this returns the cheapest overall, unchanged. On an infra-shaped RETRY it ESCAPES the
89
+ failed substrate *cross-provider* before walking classes within it:
90
+
91
+ * a congested provider (RunPod queue timeout / no warm workers) is left for a DIFFERENT
92
+ provider (Hyperstack / Lambda) on retry instead of hopping to its next-cheapest class —
93
+ which, when the whole provider is busy, is just as likely to time out (issue: A6000 queue
94
+ timeout retried onto another RunPod class while Hyperstack A6000 sat available); and
95
+ * a provider handing out a broken GPU (a Hyperstack VM whose CUDA never comes up ->
96
+ ``job_preempted``) is likewise escaped to another provider rather than re-rolling the same
97
+ broken region.
98
+
99
+ When every provider has already burned a retry (or only one provider is configured) it falls
100
+ back to the cheapest class NOT yet tried, preserving the within-provider class walk.
101
+
102
+ Keyed on (provider, gpu) IDENTITY, never a list index, so it stays correct even though each
103
+ attempt re-allocates and the live-capacity ordering can shift between attempts.
104
+ """
105
+ return min(
106
+ candidates,
107
+ key=lambda c: (
108
+ c.provider in failed_providers, # 1) escape providers that already failed this run
109
+ (c.provider, c.gpu) in tried_classes, # 2) then prefer a class not yet tried
110
+ c.hourly_usd, # 3) then cheapest
111
+ c.vram_gb, # 4) then the smaller card (don't burn a big GPU on a small job)
112
+ ),
113
+ )
114
+
115
+
116
def _submit_seed_supervised(
    spec: JobSpec,
    seed: int,
    log,
    runtime_secrets: dict[str, str] | None = None,
) -> dict:
    """Run one seed with the job submit/poll path + bounded auto-retry.

    Each attempt first ALLOCATES the GPU: the cheapest fitting class across every active
    provider (RunPod's validated pool + any Lambda/Hyperstack class with live capacity),
    price-ranked. There is no GPU pin — the cheapest fitting class wins the first attempt.

    Retries (fresh job on a fresh host; worker resumes from the latest HF checkpoint) when
    the failure looks infra-shaped: a stall (heartbeat frozen), no capacity, a client
    polling breakdown, or a platform TIMED_OUT/preemption/worker-loss. Each infra retry
    ESCAPES the provider that just failed cross-provider before walking classes within it
    (see ``_select_candidate``), so a congested provider or one handing out a broken GPU is
    left for a healthy substrate rather than re-rolling the same failure. Genuine worker
    errors (the run's code crashed; traceback persisted to HF) fail immediately.

    Args:
        spec: The run's job spec (``spec.gpu.max_retries`` bounds the GPU-walk retries).
        seed: The seed to train.
        log: Writable text stream; attempt/retry notes are ``print``-ed to it.
        runtime_secrets: Optional secrets; forwarded to ``provider.submit_run`` only
            when truthy.

    Returns:
        ``res.metrics`` of the first successful attempt. When it is a dict,
        ``"allocated_gpu"`` is defaulted to the class this attempt actually used.

    Raises:
        _RunCancelled: A status re-read found the run in state ``"cancelled"``.
        UnsupportedGpuError: Re-raised from ``allocate`` (config-shaped, never retried).
        RuntimeError: The failure was not infra-shaped or the retry budget ran out.

    Every RunPod endpoint id registered by any attempt is garbage-collected on EVERY exit
    path (success, failure, cancel, unexpected error), best-effort.
    """
    from flash.providers import get_provider
    from flash.providers.allocator import allocate, allocation_summary
    from flash.providers.base import PollResult, UnsupportedGpuError
    from flash.runner import (
        TERMINAL_STATES,
        WEIGHT_CACHE_VOLUME_NAME,
        _RunCancelled,
        _spec_with_gpu,
        _update,
        get_status,
    )

    last_handle: dict = {}
    # The friendly GPU class the CURRENT attempt provisioned (set right before each submit),
    # so on_handle persists it into the run handle and a recovery via attach_run costs the
    # class actually used rather than the parse-time provisional spec.gpu.type.
    current_gpu: dict = {}
    # Whether the CURRENT attempt's class is the last gpu-walk candidate (set right before
    # each submit). Persisted into the run handle so a recovery via attach_run polls with the
    # SAME no-capacity stall tuning the original submit used (see jobs.stall_kwargs /
    # RunpodProvider.poll) — otherwise a reattached last-candidate run would be judged on the
    # shorter non-last grace.
    current_on_last_gpu: dict = {"value": False}
    # Every RunPod endpoint id this run registered across attempts. Retries run on
    # rN-suffixed endpoints whose names _gc_run_endpoints cannot reconstruct, and a failed
    # delete during the next attempt's teardown would otherwise lose the id; the whole set
    # is GC'd at exit so no retry endpoint leaks against the worker quota.
    seen_endpoints: set[str] = set()

    def on_handle(handle: dict):
        """Record the attempt's remote handle in memory and persist it to the run status."""
        last_handle.clear()
        last_handle.update(handle)
        if handle.get("endpoint_id"):
            seen_endpoints.add(handle["endpoint_id"])
        _update(
            spec.run_id,
            "running",
            remote={
                **handle,
                "seed": int(seed),
                "allocated_gpu": current_gpu.get("name"),
                "on_last_gpu": bool(current_on_last_gpu["value"]),
            },
        )

    def _gc_seen_endpoints() -> None:
        """Best-effort delete of every RunPod endpoint this seed registered; never raises."""
        if not seen_endpoints:
            return
        from flash.providers.runpod import api as runpod_api

        for eid in seen_endpoints:
            with contextlib.suppress(Exception):
                runpod_api.delete_endpoint(eid)

    def _raise_if_cancelled() -> None:
        """Raise ``_RunCancelled`` when the persisted state is ``cancelled``.

        A missing status file (early race) is treated as not-cancelled.
        """
        with contextlib.suppress(FileNotFoundError):
            if get_status(spec.run_id).state == "cancelled":
                raise _RunCancelled(f"run {spec.run_id} was cancelled")

    def _teardown_previous_attempt(attempt: int) -> None:
        """Tear down the previous attempt's remote resources and clear its persisted handle.

        A stalled/timed-out attempt often means the worker is pinned to a throttled/sick
        host; tearing it down lets the fresh deploy land elsewhere. Entirely best-effort.
        """
        endpoint_id = last_handle.get("endpoint_id")
        if endpoint_id:
            with contextlib.suppress(Exception):
                from flash.providers.runpod import api as runpod_api

                # Cancel is suppressed on its own: a failed cancel (job already terminal,
                # or no job_id in the handle) must NOT skip the delete below, otherwise the
                # stale endpoint keeps counting against the worker quota during the retry.
                with contextlib.suppress(Exception):
                    runpod_api.cancel_job(endpoint_id, last_handle["job_id"])
                runpod_api.delete_endpoint(endpoint_id)
                # Only reached when the delete succeeded, so the note is accurate.
                print(
                    f"retry {attempt}: deleted endpoint {endpoint_id} "
                    "(escaping throttled/sick host)",
                    file=log,
                    flush=True,
                )
        elif last_handle.get("provider") in ("lambda", "hyperstack"):
            # An instance-based provider bills until terminated: tear the previous
            # attempt's instance down so the retry lands on a fresh host (and we stop
            # paying for the sick one). Dispatched generically through the handle's
            # provider (destroy() knows the provider's own id field — instance_id for
            # Lambda, vm_id for Hyperstack).
            with contextlib.suppress(Exception):
                from flash.providers.base import JobHandle

                prov_name = last_handle["provider"]
                get_provider(prov_name).destroy(JobHandle.from_dict(last_handle))
                instance_id = last_handle.get("instance_id") or last_handle.get("vm_id")
                print(
                    f"retry {attempt}: terminated {prov_name} instance {instance_id} "
                    "(escaping sick host)",
                    file=log,
                    flush=True,
                )
        # Clear the persisted handle so a cancel or control-plane restart during the fresh
        # deploy doesn't operate on (or get shielded by) the dead handle. The next
        # on_handle() records the new one.
        with contextlib.suppress(FileNotFoundError):
            st = get_status(spec.run_id)
            if st.state not in TERMINAL_STATES and st.remote is not None:
                _update(spec.run_id, st.state, remote=None)

    max_retries = int(spec.gpu.max_retries)
    last_detail = None
    # Sticky: once a no-capacity failure shows the weight-cache datacenter set is starved,
    # drop the cache (volume) for every remaining attempt so they run on the unrestricted
    # all-DC pool.
    drop_weight_cache = False
    # The platform auto-attaches the SHARED weight cache (runner._assign_weight_cache_volume),
    # so its endpoint-pinning DC-set restriction must not cost the USER a GPU-walk retry.
    # Grant ONE extra, cache-less fallback attempt — consumed ONLY by the cache-drop
    # transition below (the stop check gates the bonus on ``first_cache_drop``, never on a
    # plain GPU walk) — so a no_capacity/poll_error the cache's datacenter set could have
    # caused always earns one unrestricted cross-region retry, even at ``max_retries == 0``.
    # A non-shared per-org/custom volume is the user's own choice and earns no bonus.
    started_with_shared_cache = (
        getattr(spec.gpu, "network_volume", None) == WEIGHT_CACHE_VOLUME_NAME
    )
    cache_fallback_attempts = 1 if started_with_shared_cache else 0
    # Cross-provider retry memory. ``failed_providers`` are the providers that consumed an
    # infra-shaped attempt; ``tried_classes`` the exact (provider, gpu) pairs already
    # attempted. Both grow only when an attempt that ACTUALLY provisioned a class lost it to
    # an infra failure (see the retry tail) — a failed allocation never tried a card, so it
    # can't poison the next pick.
    failed_providers: set[str] = set()
    tried_classes: set[tuple[str, str]] = set()
    # Attempts spent on the cache-drop fallback, EXCLUDED from the GPU-walk budget. The bonus
    # slot widens the loop range, but the budget checks use ``walk_attempt`` (attempt index
    # with the cache-drop attempt(s) removed) so the GPU walk keeps its full ``max_retries``
    # budget AFTER a cache drop.
    cache_drop_consumed = 0
    try:
        for attempt in range(max_retries + 1 + cache_fallback_attempts):
            walk_attempt = attempt - cache_drop_consumed
            if attempt > 0 and last_handle:
                _teardown_previous_attempt(attempt)
            res = None
            alloc = None
            chosen = None
            # A cancel can land after _run_seed_loop's pre-submit check but while
            # allocation/pricing runs, when no handle exists yet for cancel_run() to delete.
            # Re-read state right before paid provisioning so a cancelled run never launches
            # a worker.
            _raise_if_cancelled()
            try:
                alloc = allocate(
                    spec.model,
                    spec.algorithm,
                    # Pass the run's train knobs + thinking so the VRAM estimate reflects
                    # THIS job's max_length / group_size / batch_size / lora_rank (and the
                    # seq escalation) instead of the generic defaults — else a long-context /
                    # big-group run is sized at seq=1024 and OOMs the card it picks.
                    train=spec.train,
                    thinking=spec.thinking,
                )
            except UnsupportedGpuError:
                raise  # config-shaped: no GPU anywhere can run this job
            except Exception as exc:
                res = PollResult(False, failure="poll_error", detail=f"allocation: {exc}")
            if alloc is not None:
                # Re-check cancellation right before provisioning so a cancel during
                # allocation doesn't still launch a paid worker.
                _raise_if_cancelled()
                # Pick this attempt's (provider, class) from the cross-provider ranked list:
                # the first attempt takes the cheapest; each retry that provisioned a class
                # and lost it to an infra failure ESCAPES that provider before walking
                # classes within it, so a congested/sick provider can't burn the whole budget.
                chosen = _select_candidate(alloc.candidates, failed_providers, tried_classes)
                # ``on_last_gpu`` == NO further GPU attempt will be made after this one —
                # either the candidate list is exhausted (``len(untried) <= 1``) OR the retry
                # budget is exhausted. Tell the provider so its no-capacity backstops wait
                # longer before giving up rather than failing fast into a retry that will
                # never happen. A pinned/single-candidate run is "last" from attempt 0.
                untried = [
                    c for c in alloc.candidates if (c.provider, c.gpu) not in tried_classes
                ]
                # The cache-drop fallback is a reserved attempt PAST the retry budget, so
                # while it is still available a cache-attached RunPod attempt is not "last"
                # by BUDGET: a no_capacity should fail fast into that fallback (notably at
                # max_retries == 0). This only gates the BUDGET clause; genuine class
                # exhaustion still marks last-GPU.
                cache_fallback_available = (
                    started_with_shared_cache
                    and not drop_weight_cache
                    and chosen is not None
                    and chosen.provider == "runpod"
                )
                on_last_gpu = len(untried) <= 1 or (
                    walk_attempt >= max_retries and not cache_fallback_available
                )
                # Mirror into the closure cell so on_handle persists THIS attempt's value.
                current_on_last_gpu["value"] = on_last_gpu
                print(allocation_summary(alloc), file=log, flush=True)
                if (chosen.provider, chosen.gpu) != (alloc.provider, alloc.gpu):
                    print(
                        f"retry {attempt}: walking past the cheapest class to {chosen.gpu} "
                        f"@ {chosen.provider} ${chosen.hourly_usd:.2f}/hr",
                        file=log,
                        flush=True,
                    )
                run_spec = _spec_with_gpu(spec, chosen.gpu)
                # After a no-capacity attempt, fall back to a cache-less cross-region run:
                # the attached cache pins the endpoint to its DC set, so the fallback must
                # run on the unrestricted pool.
                if drop_weight_cache:
                    run_spec = _drop_weight_cache(run_spec)
                current_gpu["name"] = chosen.gpu
                provider = get_provider(chosen.provider)
                try:
                    submit_kwargs = {
                        "log": log,
                        "on_handle": on_handle,
                        "attempt": attempt,
                        "on_last_gpu": on_last_gpu,
                    }
                    if runtime_secrets:
                        submit_kwargs["runtime_secrets"] = runtime_secrets
                    res = provider.submit_run(run_spec, seed, **submit_kwargs)
                except Exception as exc:
                    # Deploy/submit themselves can fail transiently (observed: RunPod GraphQL
                    # "Something went wrong" x3 during a retry deploy). That must consume a
                    # retry, not kill the run — the budget exists precisely for flakes.
                    res = PollResult(False, failure="poll_error", detail=f"deploy/submit: {exc}")
                    if attempt < max_retries:
                        time.sleep(10 * (attempt + 1))  # let the transient clear
            if res.ok:
                # A best-effort cancel may fail to stop the worker, which then completes
                # successfully after cancel_run() persisted `cancelled`. Don't let a late
                # worker success resurrect the run into running/done.
                _raise_if_cancelled()
                # Record the class actually allocated so _persist_metrics rates the right
                # card when a policy GPU was re-allocated away from the provisional.
                if chosen is not None and isinstance(res.metrics, dict):
                    res.metrics.setdefault("allocated_gpu", chosen.gpu)
                return res.metrics
            last_detail = f"{res.failure}: {res.detail}"
            # Retry only on a structured failure category the provider already classified; a
            # real job failure fails fast. No detail-string parsing.
            infra_shaped = res.failure in ("stalled", "no_capacity", "poll_error", "job_preempted")
            # A cancel deletes the endpoint, which the poller sees as an infra-shaped
            # failure; retrying would resurrect the run and keep billing. The user's cancel
            # wins over the retry budget.
            _raise_if_cancelled()
            # Best-effort cache-drop fallback — computed BEFORE the log + budget stop so both
            # reflect it. If a VOLUME-BACKED RunPod attempt failed in a way the cache could
            # have caused — no_capacity (the cache restricts the endpoint to its DC set) or a
            # deploy/submit poll_error — drop the cache so the run degrades to a cold,
            # unrestricted cross-region attempt instead of looping on the same volume-backed
            # spec. A non-volume flake (stall/preempt) keeps the cache. Gated to RunPod and
            # to the exact SHARED cache name; a per-org/custom volume is never stripped.
            run_had_cache = bool(
                chosen is not None
                and chosen.provider == "runpod"
                and getattr(run_spec.gpu, "network_volume", None) == WEIGHT_CACHE_VOLUME_NAME
            )
            first_cache_drop = (
                run_had_cache
                and not drop_weight_cache
                and res.failure in ("no_capacity", "poll_error")
            )
            # "retrying" is true when the GPU-walk budget remains OR a cache-drop fallback
            # will retry even past it — else the log would say "not retrying" while the loop
            # actually continues with the reserved cache-less attempt.
            will_retry = infra_shaped and (walk_attempt < max_retries or first_cache_drop)
            retry_note = "retrying (resume from last checkpoint)" if will_retry else "not retrying"
            print(
                f"seed={seed} attempt={attempt} failed ({res.failure}); "
                f"{retry_note}"
                f"\n--- failure detail ---\n{(res.detail or '')[:2000]}\n---",
                file=log,
                flush=True,
            )
            if not infra_shaped:
                break
            # Stop when the GPU-walk retry budget is exhausted — UNLESS a cache-drop fallback
            # is still available. Once the cache is dropped (sticky), ``first_cache_drop`` is
            # False so the loop cannot spin past its one extra cache-less attempt.
            if walk_attempt >= max_retries and not first_cache_drop:
                break
            if first_cache_drop:
                drop_weight_cache = True
                # The cache-drop transition is free: exclude it from the GPU-walk budget.
                cache_drop_consumed += 1
                # Do NOT advance the GPU walk here: the next attempt retries the SAME
                # cheapest GPU without the volume on the wider all-DC pool first.
            elif chosen is not None:
                # Record what THIS attempt burned so the next pick escapes it. An
                # allocation failure (chosen is None) never tried a card, so it must not
                # poison the next pick.
                failed_providers.add(chosen.provider)
                tried_classes.add((chosen.provider, chosen.gpu))
        raise RuntimeError(f"seed {seed} failed after retries: {last_detail}")
    finally:
        # GC every endpoint this seed registered on ANY exit — success, exhausted retries,
        # cancel, or an unexpected error. Intermediate rN endpoints are only known here
        # (the final attempt's is in status.remote for _gc_run_endpoints).
        _gc_seen_endpoints()
447
+
448
+
449
def _run_job_inner(
    spec: JobSpec,
    log_path: str,
    upload_code,
    runtime_secrets: dict[str, str] | None = None,
) -> None:
    """Upload the run's code, then execute every seed while appending to the run log.

    Args:
        spec: The job spec being executed.
        log_path: File the seed loop's log output is appended to.
        upload_code: Callable invoked with ``spec.train.hf_repo`` before any seed runs.
        runtime_secrets: Optional secrets forwarded unchanged to ``_run_seed_loop``.

    Raises:
        Exception: Any non-cancel failure is re-raised after the run is marked
            ``failed`` (unless its state is already ``cancelled``).
    """
    from flash.runner import _run_seed_loop, _RunCancelled, _update, get_status

    try:
        # The GPU worker fetches code/** from the per-run [train] hf_repo, so the flash
        # package has to be shipped there before the first seed is submitted.
        upload_code(spec.train.hf_repo)
        with open(log_path, "a") as log_file:
            _run_seed_loop(
                spec,
                log_file,
                start_index=0,
                prior_cost=0.0,
                runtime_secrets=runtime_secrets,
            )
    except _RunCancelled:
        # cancel_run already set the terminal state; nothing left to record.
        return
    except Exception as err:
        was_cancelled = get_status(spec.run_id).state == "cancelled"
        if not was_cancelled:
            _update(spec.run_id, "failed", error=str(err))
        raise
475
+
476
+
477
def _run_seed_loop(
    spec: JobSpec,
    log,
    *,
    start_index: int,
    prior_cost: float,
    runtime_secrets: dict[str, str] | None = None,
) -> None:
    """Supervise seeds ``spec.train.seeds[start_index:]`` and finalize the run.

    Used both for a fresh submit (``start_index=0``) and for post-restart
    recovery, which picks up the remaining seeds once the in-flight one is done.

    Args:
        spec: The job whose seeds are being run.
        log: Writable text stream that receives progress lines.
        start_index: Index into ``spec.train.seeds`` of the first seed to run.
        prior_cost: Cost already accumulated by earlier seeds; new seed costs
            are added on top of it.
        runtime_secrets: Optional secrets forwarded to each supervised submit.

    Raises:
        _RunCancelled: If the run is found in a terminal state before a seed is
            submitted, or is found ``cancelled`` after a seed / before the final
            ``done`` write.
    """
    from flash.runner import (
        TERMINAL_STATES,
        _persist_metrics,
        _RunCancelled,
        _submit_seed_supervised,
        _update,
        artifacts_dir,
        get_status,
    )

    def _abort_if_cancelled() -> None:
        # A missing status file is tolerated here; only an explicit "cancelled"
        # state stops the loop.
        with contextlib.suppress(FileNotFoundError):
            if get_status(spec.run_id).state == "cancelled":
                raise _RunCancelled(f"run {spec.run_id} was cancelled")

    running_cost = prior_cost
    seeds = spec.train.seeds
    seed_count = len(seeds)
    for idx in range(start_index, seed_count):
        seed = seeds[idx]
        # Defense in depth against the recovery TOCTOU (see attach_run): a concurrent
        # thread/process (e.g. another recovery marking the run failed/done) can flip the
        # run into ANY terminal state — not only `cancelled` — between the resume decision
        # and this point. Bail out before _update + _submit_seed_supervised so PAID GPU
        # work is never submitted for an already-terminal run. (The `running` _update
        # below would be CAS-rejected anyway, but the supervised submit would already have
        # spent.) _RunCancelled is the loop's terminal signal; callers swallow it and
        # leave the existing terminal state untouched.
        if get_status(spec.run_id).state in TERMINAL_STATES:
            raise _RunCancelled(f"run {spec.run_id} is already terminal; not submitting seed")
        _update(spec.run_id, "running")
        print(
            f"starting seed={seed} phase={spec.phase} model={spec.model} gpu={spec.gpu.type}",
            file=log,
            flush=True,
        )
        metrics = _submit_seed_supervised(spec, seed, log, runtime_secrets=runtime_secrets)
        running_cost += _persist_metrics(spec, seed, metrics)
        # A cancel may arrive while metrics are being written, i.e. after the supervised
        # late-cancel check. Re-read the state before the post-seed status writes so a
        # late worker success cannot resurrect a user-cancelled run through the
        # "running" update here (or the final "done" further down).
        _abort_if_cancelled()
        # When more seeds follow, this seed's endpoint/instance is already torn down, so
        # the stale remote handle is cleared: a restart in the gap before the next seed's
        # on_handle must not make recover_runs reattach to a deleted handle and fail the
        # run. The next seed index is recorded so that a restart in that handle-less gap
        # RESUMES the remaining seeds (recover_runs) rather than discarding the completed
        # ones. The last seed keeps its handle for post-run observability (the run is
        # about to go terminal, which recover_runs never reattaches).
        next_index = idx + 1
        update_fields = {"cost_usd": running_cost}
        if next_index < seed_count:
            update_fields["remote"] = None
            update_fields["resume_seed_index"] = next_index
        _update(spec.run_id, "running", **update_fields)
        print(
            f"seed={seed} done: train_wall={metrics.get('wall_seconds')} cost_usd={running_cost:.4f}",
            file=log,
            flush=True,
        )
    # Final guard: a cancel that lands after the last seed's check must not be
    # overwritten by the terminal "done".
    _abort_if_cancelled()
    _update(
        spec.run_id,
        "done",
        cost_usd=running_cost,
        artifacts_dir=artifacts_dir(spec),
        resume_seed_index=None,
    )
    _charge_completed_run_best_effort(spec, log)
    _register_checkpoints_best_effort(spec, log)
560
+
561
+
562
def _register_checkpoints_best_effort(spec: JobSpec, log) -> None:
    """Mirror a finished run's deployable per-step checkpoints to the backend store.

    Kept best-effort and separate from billing: the checkpoints live on HF
    regardless, so a persistence miss never changes the run's outcome. Any
    failure is written to ``log`` as a warning and otherwise ignored.

    Args:
        spec: The finished job; its ``run_id`` selects the status to register.
        log: Writable text stream that receives the warning line on failure.
    """
    from flash.runner import get_status

    try:
        from flash.server.checkpoints import register_checkpoints_best_effort

        run_status = get_status(spec.run_id)
        register_checkpoints_best_effort(run_status, log=log)
    except Exception as err:
        # Checkpoint bookkeeping must never disturb a run, so only warn.
        print(f"[ckpt] register warn ({spec.run_id}): {err}", file=log, flush=True)
575
+
576
+
577
def _charge_completed_run_best_effort(spec: JobSpec, log) -> None:
    """Bill a successfully completed external run without changing its training result.

    Does nothing when the run has no billing context or is already marked
    ``charged``. Otherwise the billing outcome (``charging`` -> ``charged`` or
    ``failed``) is recorded on the run's status while its current state is
    written back unchanged, and a summary line is printed to ``log``.

    Args:
        spec: The completed job; its ``run_id`` selects the status to bill.
        log: Writable text stream that receives the billing outcome line.
    """
    from flash.runner import _update, get_status
    from flash.server.auth import INTERNAL_KEY_ENV
    from flash.server.billing import BillingError, charge_completed_run

    def _record_billing(**billing_fields) -> None:
        # Re-read the state on every write so only billing fields change.
        _update(spec.run_id, get_status(spec.run_id).state, **billing_fields)

    def _log_line(message: str) -> None:
        print(message, file=log, flush=True)

    status = get_status(spec.run_id)
    if not status.billing_context or status.billing_state == "charged":
        return

    internal_key = os.environ.get(INTERNAL_KEY_ENV, "").strip()
    if not internal_key:
        # Without the internal key the charge cannot be made; record why.
        detail = f"{INTERNAL_KEY_ENV} is not configured; completed run was not billed"
        _record_billing(billing_state="failed", billing_error=detail)
        _log_line(f"billing failed: {detail}")
        return

    _record_billing(billing_state="charging", billing_error=None)
    # Charge against a fresh snapshot taken after the "charging" marker is written.
    status = get_status(spec.run_id)
    try:
        charge = charge_completed_run(internal_key=internal_key, status=status)
    except BillingError as err:
        _record_billing(billing_state="failed", billing_error=err.detail)
        _log_line(f"billing failed: {err.detail}")
        return

    _record_billing(billing_state="charged", billing_error=None, billing_charge=charge)
    _log_line(
        f"billing charged: amount_cents={charge.get('amountCents')} "
        f"replay={bool(charge.get('replay'))}"
    )
631
+
632
+
633
def _gc_run_endpoints(spec: JobSpec) -> None:
    """Best-effort teardown of every endpoint a run may have registered.

    Retried attempts run on rN-suffixed endpoints whose runpod_flash state is
    isolated per-suffix, so the name-based terminate_endpoint cannot see them;
    the endpoint id in the persisted remote handle covers whichever attempt
    ran last via the plain REST API.

    Args:
        spec: The job whose remote resources should be reaped.
    """
    from flash.runner import get_status

    try:
        status = get_status(spec.run_id)
    except Exception:
        status = None

    if status is not None and status.remote:
        # Best-effort GC; the name-reconstructed RunPod gc below is the backstop.
        with contextlib.suppress(Exception):
            from flash.providers import get_provider
            from flash.providers.base import JobHandle

            handle = JobHandle.from_dict(status.remote)
            get_provider(handle.provider).destroy(handle)

    # RunPod's gc reaps rN-suffixed endpoints the persisted handle can't name.
    # Best-effort: an undeleted endpoint only holds worker quota, never blocks the run.
    with contextlib.suppress(Exception):
        from flash.providers import get_provider

        get_provider("runpod").gc(spec)

    # Instance-based providers (Lambda, Hyperstack) bill until terminated: the runner's
    # per-attempt `finally` already tears them down, but a crashed supervisor thread can
    # leave one behind. Reap any instance still named for this run via each configured
    # provider's gc (best-effort).
    from flash.providers import available_providers, get_provider

    configured = available_providers()
    for provider_name in ("lambda", "hyperstack"):
        if provider_name not in configured:
            continue
        try:
            get_provider(provider_name).gc(spec)
        except Exception:
            pass