freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,686 @@
1
+ """Platform runner: drives managed RunPod GPUs, one allocation per seed."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import contextlib
6
+ import json
7
+ import os
8
+ import re
9
+ import tempfile
10
+ import threading
11
+ import time
12
+ import uuid
13
+ from dataclasses import asdict, dataclass, field
14
+
15
+ from flash.catalog import ModelInfo, resolve_model
16
+ from flash.spec import JobSpec
17
+
18
# All local control-plane state lives under one fixed root, ~/.flash (the same root as
# server/db.py's DB_PATH): run-state JSON in RUNS_DIR and result artifacts in RESULTS_DIR.
# The locations are not operator-configurable; mounting a single volume at ~/.flash
# persists everything. Tests redirect them via
# monkeypatch.setattr(runner, "RUNS_DIR"/"RESULTS_DIR").
_STATE_DIR = os.path.join(os.path.expanduser("~"), ".flash")
RUNS_DIR = os.path.join(_STATE_DIR, "runs")
RESULTS_DIR = os.path.join(_STATE_DIR, "results")
TERMINAL_STATES = frozenset(("done", "failed", "cancelled", "dry_run"))
# Terminal states that a deploy must never overwrite. `done` is terminal yet deployable
# (deploying a finished run is the whole point), so it is the one state left out;
# cancelled/failed/dry_run must never be flipped to `deployed`.
_UNDEPLOYABLE_STATES = frozenset(state for state in TERMINAL_STATES if state != "done")
# Guards the read-check-write in _update so a status transition behaves as an atomic
# compare-and-set (the control plane is single-instance with per-run threads).
_STATUS_LOCK = threading.Lock()
33
+
34
+
35
def artifacts_dir(spec: JobSpec) -> str:
    """Return the run-scoped artifact root: results/runpod/<phase>/<run_id>."""
    segments = ("runpod", spec.phase, spec.run_id)
    return os.path.join(RESULTS_DIR, *segments)
38
+
39
+
40
def adapter_prefix(spec: JobSpec, seed: int | None = None) -> str:
    """Return where a run's adapter lives on the HF artifact store.

    The prefix is ``<phase>/<run_id>/seed<N>``. When ``seed`` is omitted, the
    first entry of ``spec.train.seeds`` is used.
    """
    if seed is None:
        seed = spec.train.seeds[0]
    return f"{spec.phase}/{spec.run_id}/seed{seed}"
44
+
45
+
46
def adapter_ref(spec: JobSpec, seed: int | None = None) -> str | None:
    """Build the full init_from_adapter reference for a run's trained adapter.

    Returns ``"<hf_repo>:<adapter prefix>"``, or None when the spec has no
    ``train.hf_repo`` set.
    """
    if spec.train.hf_repo:
        location = adapter_prefix(spec, seed=seed)
        return f"{spec.train.hf_repo}:{location}"
    return None
51
+
52
+
53
def _adapter_ref_from_status_spec(raw: dict) -> str | None:
    """Derive the adapter reference from a persisted spec dict, or None on any failure.

    Any exception raised while rebuilding the JobSpec or computing the reference is
    swallowed and reported as None.
    """
    with contextlib.suppress(Exception):
        return adapter_ref(JobSpec.from_dict(raw))
    return None
58
+
59
+
60
def _gpu_rate(gpu_type: str) -> float:
    """Return a static representative $/hr for ``gpu_type``, used for cost projection.

    The worker also records wall time, so cost = wall_hours * rate. If the RunPod
    pricing table cannot be imported or the lookup raises, 0.80 is returned.
    """
    try:
        from flash.providers.runpod.pricing import hourly_rate

        rate = hourly_rate(gpu_type)
    except Exception:
        return 0.80
    else:
        return rate
69
+
70
+
71
+ @dataclass
72
+ class RunStatus:
73
+ run_id: str
74
+ state: str
75
+ spec: dict
76
+ created_at: float = field(default_factory=time.time)
77
+ updated_at: float = field(default_factory=time.time)
78
+ cost_usd: float = 0.0
79
+ error: str | None = None
80
+ artifacts_dir: str | None = None
81
+ adapter_ref: str | None = None
82
+ deployment: dict | None = None
83
+ # Durable job handle {endpoint_id, endpoint_name, job_id} — lets any process
84
+ # reattach to / cancel the remote job (see `flash status --follow`).
85
+ remote: dict | None = None
86
+ # Index of the next seed to run for a multi-seed job, set while the remote handle
87
+ # is cleared in the gap between seeds. Lets recover_runs resume the remaining seeds
88
+ # after an inter-seed restart instead of failing the run (losing completed work).
89
+ resume_seed_index: int | None = None
90
+ # Realized provider cost (COGS), pulled from the provider's billing API after the run
91
+ # finishes by the reconciliation job (flash/server/reconcile.py) and reported to the
92
+ # freesolo backend for estimator accuracy. Distinct from ``cost_usd`` (the wall x $/hr
93
+ # PROJECTION); ``reconciled_at`` marks that the realized pull has happened so it isn't
94
+ # re-pulled. Both stay None for un-reconciled / pre-instrumentation runs.
95
+ realized_cost_usd: float | None = None
96
+ reconciled_at: float | None = None
97
+ # Wall-clock the run first went terminal (~training teardown). Stamped ONCE on the first
98
+ # terminal transition and never moved, so it survives later ``updated_at`` bumps from
99
+ # deploy / heartbeat / reconcile. Reconciliation uses it as the instance-billing ``run_end``:
100
+ # a run deployed after completion has ``updated_at`` = deploy time, which would over-bill the
101
+ # flat $/hr from launch until deployment instead of until training teardown. None pre-feature.
102
+ finished_at: float | None = None
103
+ # Non-secret customer billing context, set for externally-submitted runs. Completion-time
104
+ # billing uses this org id with the operator internal key; user API keys are not persisted.
105
+ billing_context: dict | None = None
106
+ billing_state: str | None = None
107
+ billing_error: str | None = None
108
+ billing_charge: dict | None = None
109
+ # Non-secret Freesolo identity used to mirror run status to the platform UI.
110
+ platform_context: dict | None = None
111
+ # Last worker heartbeat observed by the provider poller. This is intentionally
112
+ # duplicated from the HF artifact channel into local run status so `flash status`
113
+ # can show live worker/GPU state without doing a fresh HF read.
114
+ last_heartbeat: dict | None = None
115
+ gpu_status: dict | None = None
116
+
117
+ def to_dict(self) -> dict:
118
+ data = asdict(self)
119
+ data["adapter_ref"] = (
120
+ _adapter_ref_from_status_spec(self.spec)
121
+ if self.state in {"done", "deployed"}
122
+ else None
123
+ )
124
+ return data
125
+
126
+
127
class _RunCancelled(RuntimeError):
    """Raised when a user cancellation is seen mid-run.

    The cancellation is terminal: it is never retried and never overwritten.
    """
129
+
130
+
131
def new_run_id() -> str:
    """Mint a fresh run id of the form ``flash-<unix seconds>-<8 hex chars>``."""
    stamp = int(time.time())
    token = uuid.uuid4().hex[:8]
    return f"flash-{stamp}-{token}"
133
+
134
+
135
_RUN_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")


def require_safe_run_id(run_id: str) -> str:
    """Reject run ids that could traverse outside the runs directory.

    Run ids flow from API path params into filesystem paths (status json,
    log files); restrict them to a conservative filename alphabet: one
    alphanumeric character followed by up to 127 characters from
    ``[A-Za-z0-9._-]``.

    Returns:
        ``run_id`` unchanged when it is valid.

    Raises:
        ValueError: if ``run_id`` is not a string or does not match the alphabet.
    """
    # fullmatch(), not match(): with match() the pattern's trailing `$` also matches
    # just before a final newline, so "abc\n" would pass and put a newline into a
    # filename. The isinstance guard keeps non-string ids on the documented
    # ValueError path instead of leaking a TypeError from the regex engine.
    if not isinstance(run_id, str) or not _RUN_ID_RE.fullmatch(run_id):
        raise ValueError(f"invalid run_id: {run_id!r}")
    return run_id
147
+
148
+
149
def runs_file_path(run_id: str, suffix: str) -> str:
    """Build the path of a run's file under RUNS_DIR, with a containment check.

    This is a second line of defence behind require_safe_run_id: the normalized
    path must remain inside the runs directory even if the alphabet check ever
    regresses.

    Raises:
        ValueError: if the run id is rejected or the path escapes the runs directory.
    """
    runs_root = os.path.abspath(RUNS_DIR)
    filename = f"{require_safe_run_id(run_id)}{suffix}"
    candidate = os.path.normpath(os.path.join(runs_root, filename))
    if candidate.startswith(runs_root + os.sep):
        return candidate
    raise ValueError(f"invalid run_id: {run_id!r}")
160
+
161
+
162
def _with_model_disk(spec: JobSpec, info: ModelInfo) -> dict:
    """Return the spec as a dict with gpu.disk_gb raised to the model's min_disk_gb.

    Big-checkpoint models (whose weights alone exceed the default) need more
    container disk than the platform's 64 GB default; bumping it here means users
    don't have to know the right ``gpu.disk_gb``. The disk size is only ever
    raised, never lowered.
    """
    spec_dict = spec.to_dict()
    required_gb = int(getattr(info, "min_disk_gb", 0) or 0)
    current_gb = int(spec_dict["gpu"].get("disk_gb") or 0)
    if required_gb <= current_gb:
        return spec_dict
    spec_dict["gpu"] = {**spec_dict["gpu"], "disk_gb": required_gb}
    return spec_dict
174
+
175
+
176
+ # The HF namespace the control plane creates per-run artifact repos under: the operator org whose
177
+ # HF_TOKEN the control plane runs with. An operator-infra constant, not a user/env knob.
178
_ARTIFACT_NAMESPACE = "Freesolo-Co"


def _assign_managed_hf_repo(spec: JobSpec) -> JobSpec:
    """Set the run's HF artifact repo server-side; it is platform-managed, never user-set.

    Every run gets its own private dataset repo ``Freesolo-Co/flashrun-<run_id>``,
    created and written (code, adapters, checkpoints, telemetry) with the control-plane
    HF_TOKEN. A user-chosen namespace would 403 that token at ``upload_code``, so any
    inbound ``train.hf_repo`` is overwritten.

    Raises:
        ValueError: if the run_id is empty or still the placeholder ``"local"``. A
            per-run repo keyed on the placeholder would collide across runs and
            overwrite each other's code/adapters/state.
    """
    run_id = spec.run_id
    if not (run_id and run_id != "local"):
        raise ValueError("run_id must be finalized before assigning the per-run artifact repo")
    managed_repo = f"{_ARTIFACT_NAMESPACE}/flashrun-{run_id}"
    payload = spec.to_dict()
    payload["train"] = {**payload["train"], "hf_repo": managed_repo}
    return JobSpec.from_dict(payload)
196
+
197
+
198
def _assign_resolved_env_sha(spec: JobSpec) -> JobSpec:
    """Pin the environment's GitHub ref to a commit SHA once, before the worker fan-out.

    Resolving here means every worker boots from an immutable pinned sha rather than
    each one re-resolving the symbolic ref (e.g. "main") against the GitHub commits
    API. Otherwise a cold spawn wave of N workers fires N concurrent commit-API calls
    and trips GitHub's secondary rate limit; a worker-side in-process cache cannot
    help because each worker is a separate process.

    Best-effort: any failure (no network/token, transient limit, or a non-GitHub env)
    leaves ``resolved_sha`` empty, and the worker resolves the ref itself via the
    in-worker jittered retry + retriable-reschedule path, so submission never blocks
    on GitHub. The spec is returned unchanged when it has no environment id or is
    already pinned.
    """
    import logging

    environment_id = spec.environment.id
    if not environment_id or spec.environment.resolved_sha:
        return spec
    try:
        from flash.envs.adapter import (
            _parse_github_environment_ref,
            _resolve_ref_sha,
            is_managed_environment_slug,
            managed_slug_to_github_ref,
        )

        if is_managed_environment_slug(environment_id):
            github_ref = managed_slug_to_github_ref(environment_id)
        else:
            github_ref = environment_id
        parsed_ref = _parse_github_environment_ref(github_ref)
        if parsed_ref is None:
            # Local/path or non-GitHub environment: nothing to pin.
            return spec
        # Fail fast with a single short request and no rate-limit sleeps. This pin must
        # never delay or block run creation (esp. submit_job(background=True)); when
        # GitHub is slow or limiting we fall through and the worker resolves the ref
        # itself with the full retry budget.
        pinned_sha = _resolve_ref_sha(parsed_ref, timeout=10.0, max_rate_limit_retries=0)
    except Exception as exc:
        # A control-plane resolve failure never blocks submission; the worker falls
        # back to resolving the ref. Logged for visibility only.
        logging.getLogger(__name__).warning(
            "resolve-once: could not pin env ref->sha for %r (%s); worker will resolve",
            environment_id,
            exc,
        )
        return spec
    if pinned_sha:
        payload = spec.to_dict()
        payload["environment"] = {**payload["environment"], "resolved_sha": pinned_sha}
        return JobSpec.from_dict(payload)
    return spec
242
+
243
+
244
+ # Shared, platform-wide model-weight cache (NOT per-org). The cache holds downloaded base-model
245
+ # weights — a run's trained adapters/checkpoints upload to the per-run managed HF repo
246
+ # (_assign_managed_hf_repo), never here — so one global volume reused by every run is both safe and
247
+ # the highest-hit-rate option: a popular base model (e.g. the 4B) is downloaded once per region,
248
+ # ever, instead of once per run. The RunPod provider attaches a per-DC volume in EVERY cache
249
+ # datacenter and allows the endpoint across all of them, so there is NO single-DC pin (the failure
250
+ # mode that sank the earlier per-org EU-RO-1 attempt — runs wedged IN_QUEUE on one full region).
251
+ # Fully managed: a fixed name + size, no env knobs. 100 GB holds the whole curated catalog (the
252
+ # largest, the 9B, is ~19 GB of weights) with ample headroom; the preload step warms it.
253
+ # COST/GC: provisioned EAGERLY — one ``flash-weights-<dc>`` volume in EVERY storage datacenter, so the
254
+ # cache exists in every region a run could land in (no first-run-cold-then-warm-next-time gap). The
255
+ # endpoint deploy creates-or-attaches all of them (jobs.weight_cache_volumes over the full storage-DC
256
+ # set), and ``preload`` warms them with the catalog weights. Standing storage is therefore the whole
257
+ # fleet: (#storage DCs) x 100 GB of PERMANENT billed storage (~11 x 100 GB ~= 1.1 TB ~= $77/mo today;
258
+ # grows by one volume if the SDK adds a storage region). RunPod never auto-deletes network volumes;
259
+ # reclaim the fleet with ``python -m flash.providers.runpod.preload --teardown`` (also reclaims the
260
+ # Lambda/Hyperstack caches). Lambda filesystems + Hyperstack volumes are likewise pre-created in every
261
+ # region/environment by ``preload --provision`` (pure control-plane API, no GPU).
262
+ #
263
+ # TRUST MODEL (shared multi-tenant cache). The catalog gate makes the run's SPEC model public:
264
+ # ``_assign_weight_cache_volume`` attaches the cache only for ``model_policy == "catalog"`` runs
265
+ # (always public; resolve_model validates catalog membership) and leaves open-model ("allow") runs
266
+ # cache-less, so a spec that NAMES a private/gated model never persists it onto the shared mount.
267
+ # CONFIDENTIALITY CAVEAT (not fully closed): the redirect is process-global (``weight_cache_env`` sets
268
+ # ``HF_HOME`` onto the mount), so it scopes the SPEC model but NOT additional HF repos the run's
269
+ # environment/reward code may fetch at execution time with the forwarded platform HF_TOKEN — those
270
+ # would also land on the shared mount and be readable by a later tenant in the region. The residual is
271
+ # bounded by (a) the catalog gate on the base model, (b) the scope of the platform HF_TOKEN, and (c)
272
+ # flash environments being published/reviewed Hub/GitHub artifacts (not anonymous code) — but it is a
273
+ # real limitation. The proper hardening (scope the mount to the trusted base-model prefetch via an
274
+ # explicit ``cache_dir`` while env/reward code uses an ephemeral HF cache, or a READ-ONLY mount
275
+ # populated only by preload) is worker-side and tracked as a follow-up.
276
+ # A second residual is INTEGRITY: the mount is read-WRITE on every run and a run executes
277
+ # its Freesolo environment code on the worker, so a hostile/buggy environment COULD overwrite a cached
278
+ # public model's content-addressed blobs and poison a later run loading that same model in the region.
279
+ # That is the accepted flip side of "one shared cache for everything" — flash environments are
280
+ # published/reviewed Hub artifacts, not anonymous code, and the data at risk is public weights, not
281
+ # secrets. The clean isolation — mounting the volume READ-ONLY for the run and writing only via the
282
+ # trusted preload — is NOT yet expressible through the runpod_flash SDK (NetworkVolume has no
283
+ # mount-mode field; ``extra="forbid"``). When the SDK gains a read-only mount, switch runs to RO +
284
+ # populate exclusively via preload. Until then this is a documented integrity tradeoff; flip to per-org
285
+ # volumes (keyed off platform_context.org_id) if strict tenant isolation is required.
286
WEIGHT_CACHE_VOLUME_NAME = "flash-weights"
WEIGHT_CACHE_VOLUME_GB = 100


def _assign_weight_cache_volume(spec: JobSpec) -> JobSpec:
    """Attach the shared, platform-managed weight-cache volume, for catalog models ONLY.

    The volume is platform-managed (never user config), like the managed HF repo: it
    is assigned here and not surfaced in the config schema. The provider builds the
    per-region volume fleet and the cross-DC endpoint at deploy time
    (jobs.weight_cache_endpoint_kwargs) off this name, and the worker env redirects
    HF_HOME onto the mount whenever the volume is attached.

    CONFIDENTIALITY GATE: the cache is shared across tenants, and attaching it
    persists a model's downloaded weights on the mount for every later run in the
    region. That is only safe for PUBLIC weights. ``submit_job`` runs
    ``resolve_model`` before this function, so a ``catalog``-policy spec is already a
    curated public catalog model. A ``model_policy="allow"`` spec may name a
    private/gated HF repo that would be downloaded with the forwarded platform
    HF_TOKEN; persisting those weights to the shared cache would leak them
    cross-tenant. The gate takes precedence over honoring a pre-set volume.

    Outcomes:
        (a) open-model ("allow") run: never on the SHARED cache. A pre-set shared
            volume name is stripped (forced cache-less); any other (per-org / custom)
            volume is the escape hatch and is kept.
        (b) catalog run with a volume already set: left as-is (explicit/test
            assignment honored).
        (c) catalog run with no volume: the shared cache is attached.

    Integrity tradeoff: a run's environment code has write access to the shared
    mount, because a read-only mount is not yet expressible through the SDK.
    """
    policy = getattr(spec, "model_policy", "catalog")
    current_volume = getattr(spec.gpu, "network_volume", None)
    if policy == "catalog":
        if current_volume:
            # Outcome (b): an explicit/test volume is already assigned; honor it.
            return spec
        # Outcome (c): attach the shared cache.
        gpu_patch = {
            "network_volume": WEIGHT_CACHE_VOLUME_NAME,
            "network_volume_gb": WEIGHT_CACHE_VOLUME_GB,
        }
    elif current_volume == WEIGHT_CACHE_VOLUME_NAME:
        # Outcome (a), pre-set shared cache: strip it so possibly-private weights never
        # reach the shared mount. Decided before any "honor an existing volume" logic,
        # so a pre-set flash-weights cannot bypass the gate.
        gpu_patch = {"network_volume": None}
    else:
        # Outcome (a), nothing to strip: cache-less already, or a non-shared volume.
        return spec
    payload = spec.to_dict()
    payload["gpu"] = {**payload["gpu"], **gpu_patch}
    return JobSpec.from_dict(payload)
344
+
345
+
346
def _run_job_background(
    spec: JobSpec,
    runtime_secrets: dict[str, str] | None = None,
    *,
    resolve_env_sha: bool = False,
) -> None:
    """Daemon-thread entrypoint for background runs; never lets an exception escape.

    ``_run_job`` -> ``_run_job_inner`` persists the terminal state (failed/cancelled)
    BEFORE the inner ``raise`` that the synchronous ``submit_job(background=False)``
    contract relies on (its callers, e.g. ``test_supervisor_fail_fast``, expect the
    exception). In a daemon thread that re-raise has no caller, so Python would print
    a full ``Exception in thread`` traceback for every failed/cancelled run: log noise
    that buries real errors and trips monitoring. This wrapper swallows the exception
    and logs a one-line note instead, leaving the synchronous raise path untouched. It
    lives in this module (not lifecycle) so it dispatches through the package-level
    ``_run_job`` that tests monkeypatch.

    ``resolve_env_sha`` moves the (network) env ref->sha pin into THIS background
    thread, off the run-creation critical path: ``submit_job(background=True)`` saves
    and reports the run status first and returns, so a slow or rate-limited GitHub
    commits API cannot delay run creation. The pin is best-effort and only a boot
    optimization, so the pinned spec is passed to the fan-out without being
    re-persisted into the run status JSON.
    """
    import logging

    try:
        if resolve_env_sha:
            # Best-effort pin: on any failure the spec stays unpinned and each worker
            # resolves the ref itself with its full retry budget.
            with contextlib.suppress(Exception):
                spec = _assign_resolved_env_sha(spec)
        run_kwargs = {"runtime_secrets": runtime_secrets} if runtime_secrets else {}
        _run_job(spec, **run_kwargs)
    except Exception as exc:
        # Safety net for a crash BEFORE _run_job_inner persisted anything (e.g. an
        # import/model-resolve error), which would leave the run stuck non-terminal.
        # `failed` is written ONLY when the run is not already terminal: _update
        # permits same-state writes, so an unconditional write would clobber an
        # already-persisted failure detail with this wrapper's less specific
        # exception. The whole net is suppressed so a missing/unwritable status
        # cannot re-raise out of the daemon thread.
        with contextlib.suppress(Exception):
            if get_status(spec.run_id).state not in TERMINAL_STATES:
                _update(spec.run_id, "failed", error=str(exc))
        logging.getLogger(__name__).warning(
            "background run %s ended in error: %s", spec.run_id, exc
        )
396
+
397
+
398
def submit_job(
    spec: JobSpec,
    dry_run: bool = False,
    background: bool = False,
    runtime_secrets: dict[str, str] | None = None,
    billing_context: dict | None = None,
    platform_context: dict | None = None,
) -> RunStatus:
    """Submit a job and return its run status.

    In real mode this allocates and provisions the cheapest validated GPU class that
    fits the run; a dry run only records state. With ``background=True`` the run is
    started on a daemon thread and the status is returned immediately; otherwise the
    call blocks until ``_run_job`` returns (or raises).
    """
    model_info = resolve_model(
        spec.model, spec.algorithm, policy=spec.model_policy, gpu=spec.gpu.type
    )
    # The run_id is finalized BEFORE the per-run artifact repo is assigned. JobSpec's
    # default run_id is the truthy placeholder "local", so it is treated as unset:
    # programmatic/test callers also get a unique id and per-run repos never collide.
    if spec.run_id and spec.run_id != "local":
        final_run_id = spec.run_id
    else:
        final_run_id = new_run_id()
    spec = JobSpec.from_dict({**_with_model_disk(spec, model_info), "run_id": final_run_id})
    # Per-run, operator-owned artifact repo; requires the finalized run_id.
    spec = _assign_managed_hf_repo(spec)
    # Shared model-weight cache (platform-managed). Assigned before the RunStatus is
    # built so a dry-run spec carries it too, keeping the assignment testable without a
    # real provision and visible in `flash status`.
    spec = _assign_weight_cache_volume(spec)
    # The env ref->sha pin (_assign_resolved_env_sha) calls the GitHub commits API, so
    # it is deliberately kept off the run-creation critical path: the status is saved
    # and reported first, and the pin happens later, either inside the background
    # thread or just before the synchronous fan-out.
    status = RunStatus(
        run_id=spec.run_id,
        state="queued",
        spec=spec.to_dict(),
        billing_context=billing_context,
        billing_state="pending" if billing_context else None,
        platform_context=platform_context,
    )
    _save_status(status)
    _report_status(status)

    if dry_run:
        status.state = "dry_run"
        _save_status(status)
        _report_status(status)
        return status

    if background:
        # The run record already exists; the GitHub env-sha pin runs INSIDE the thread
        # (resolve_env_sha=True) so the API response is never blocked by GitHub retries.
        worker = threading.Thread(
            target=_run_job_background,
            args=(spec, runtime_secrets or {}),
            kwargs={"resolve_env_sha": True},
            daemon=True,
        )
        worker.start()
        return get_status(spec.run_id)

    # Synchronous path: the pin no longer delays creation of the run record, only this
    # in-process caller's own wait. Resolve once before the fan-out so workers boot
    # from the pin and skip the GitHub commits API. Best-effort.
    spec = _assign_resolved_env_sha(spec)
    run_kwargs = {"runtime_secrets": runtime_secrets} if runtime_secrets else {}
    _run_job(spec, **run_kwargs)
    return get_status(spec.run_id)
460
+
461
+
462
def get_status(run_id: str) -> RunStatus:
    """Load the persisted status of ``run_id`` from its JSON file.

    Raises:
        ValueError: if ``run_id`` fails the path-safety checks.
        FileNotFoundError: if no status file exists for ``run_id``.
    """
    status_path = runs_file_path(run_id, ".json")
    if os.path.exists(status_path):
        with open(status_path) as handle:
            return RunStatus(**json.load(handle))
    raise FileNotFoundError(f"unknown run_id: {run_id}")
468
+
469
+
470
def list_runs() -> list[RunStatus]:
    """Return every run recorded under RUNS_DIR, ordered by status file name.

    The runs directory is created if it does not exist yet; only ``*.json`` files
    are read.
    """
    os.makedirs(RUNS_DIR, exist_ok=True)
    status_files = [name for name in sorted(os.listdir(RUNS_DIR)) if name.endswith(".json")]
    statuses = []
    for filename in status_files:
        with open(os.path.join(RUNS_DIR, filename)) as handle:
            statuses.append(RunStatus(**json.load(handle)))
    return statuses
478
+
479
+
480
def get_logs(run_id: str) -> str:
    """Return the full text of a run's log file, or ``""`` when there is none.

    Raises:
        ValueError: if ``run_id`` fails the path-safety checks.
    """
    log_path = runs_file_path(run_id, ".log")
    try:
        # EAFP rather than exists()-then-open(): a log removed between the two calls
        # must still yield "" instead of an unexpected FileNotFoundError.
        # errors="replace" keeps one undecodable byte in worker output from turning a
        # log read into a UnicodeDecodeError; the encoding itself is left at the
        # platform default so it keeps matching however the log was written.
        with open(log_path, errors="replace") as f:
            return f.read()
    except FileNotFoundError:
        return ""
486
+
487
+
488
def _sanitize_status_value(value, *, depth: int = 0):
    """Bound a heartbeat payload before persisting it in run status JSON.

    Limits applied:
        * None/bool/int/float pass through unchanged at any depth;
        * strings are cut to 1000 characters;
        * lists/tuples keep their first 16 items (returned as a list);
        * dicts keep their first 64 entries, with keys stringified and cut to 120
          characters, and gain ``"truncated": True`` when entries were dropped;
        * anything nested deeper than 5 levels is flattened to ``str(value)[:200]``;
        * any other object becomes ``str(value)[:500]``.

    Args:
        value: The payload, or a nested part of it.
        depth: Current nesting level; callers normally leave it at 0.

    Returns:
        The bounded value, built only from None, bool, int, float, str, list and dict.
    """
    # Scalars come first: they are already bounded, and checking the depth guard ahead
    # of them would turn a deeply nested number or null into a string ("7", "None").
    if value is None or isinstance(value, (bool, int, float)):
        return value
    if depth > 5:
        return str(value)[:200]
    if isinstance(value, str):
        return value[:1000]
    # Tuples are handled like lists (JSON has no tuple type) rather than stringified.
    if isinstance(value, (list, tuple)):
        return [_sanitize_status_value(v, depth=depth + 1) for v in value[:16]]
    if isinstance(value, dict):
        out = {}
        for i, (k, v) in enumerate(value.items()):
            if i >= 64:
                out["truncated"] = True
                break
            out[str(k)[:120]] = _sanitize_status_value(v, depth=depth + 1)
        return out
    return str(value)[:500]
507
+
508
+
509
def record_heartbeat(run_id: str, heartbeat: dict) -> None:
    """Store the most recent worker heartbeat on the run's status record.

    The payload is bounded with ``_sanitize_status_value`` first. Its ``gpu`` entry
    (falling back to ``diag``) becomes ``gpu_status`` when it is a dict, otherwise
    ``gpu_status`` is set to ``None``. The run's ``state`` is left as it is. Does
    nothing for an empty ``run_id``, a non-dict ``heartbeat``, or a run that has no
    status file.
    """
    if not (run_id and isinstance(heartbeat, dict)):
        return
    if not os.path.exists(runs_file_path(run_id, ".json")):
        return
    snapshot = _sanitize_status_value(heartbeat)
    gpu_info = None
    if isinstance(snapshot, dict):
        gpu_info = snapshot.get("gpu") or snapshot.get("diag")
    if not isinstance(gpu_info, dict):
        gpu_info = None
    with _STATUS_LOCK:
        # Re-read under the lock so the write is based on the current record.
        try:
            record = get_status(run_id)
        except FileNotFoundError:
            return
        record.last_heartbeat = snapshot
        record.gpu_status = gpu_info
        record.updated_at = time.time()
        _save_status(record)
    _report_status(record)
527
+
528
+
529
def _persist_metrics(spec: JobSpec, seed: int, metrics: dict) -> float:
    """Write metrics to results/runpod/<phase>/<run_id>/seedN and return the cost.

    The run id keeps concurrent/sequential runs of the same phase+seed from
    overwriting each other's artifacts.

    Args:
        spec: Job spec; supplies the artifacts directory and the fallback GPU type.
        seed: Seed number used for the ``seed{seed}`` sub-directory.
        metrics: Worker metrics. Neither this dict nor its ``notes`` dict is mutated;
            a stamped copy is what gets written and recorded.

    Returns:
        The cost in USD: the worker-reported ``cost_usd`` when truthy, otherwise
        ``wall_seconds`` priced at the GPU's hourly rate.
    """
    dest = os.path.join(artifacts_dir(spec), f"seed{seed}")
    os.makedirs(dest, exist_ok=True)
    # Rate the actually-allocated class, not the parse-time provisional spec.gpu.type:
    # a policy GPU can be re-allocated to a different RunPod class at submit time, so
    # the worker stamps "allocated_gpu" into metrics for the cost fallback below.
    gpu_type = metrics.get("allocated_gpu") or spec.gpu.type
    rate = _gpu_rate(gpu_type)
    reported = metrics.get("cost_usd")
    if reported:
        cost = float(reported)
    else:
        # No (or zero) reported cost: estimate from wall time at the class rate.
        wall = float(metrics.get("wall_seconds") or 0.0)
        cost = wall / 3600.0 * rate
    metrics = {**metrics, "cost_usd": cost}
    notes = metrics.setdefault("notes", {})
    if isinstance(notes, dict):
        # The copy above is shallow, so stamp a COPY of notes; writing into the
        # original nested dict would leak the provider fields back to the caller.
        metrics["notes"] = {
            **notes,
            "provider": "runpod",
            "runpod_rate_usd_hr": rate,
            "runpod_gpu": gpu_type,
        }
    with open(os.path.join(dest, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)
    # Best-effort registry update; the lazy import sits inside the suppression so a
    # missing/broken registry never fails metric persistence.
    with contextlib.suppress(Exception):
        from flash.server.run_registry import record_training_checkpoint

        record_training_checkpoint(spec=spec, seed=seed, metrics=metrics, artifact_path=dest)
    return float(cost)
560
+
561
+
562
def _update(run_id: str, state: str, *, allow_from_terminal: bool = False, **updates) -> bool:
    """Atomically transition a run's status, honoring terminal-stickiness.

    Returns ``True`` if the transition was applied, ``False`` if it was rejected because
    the run was already in a terminal state (the sticky compare-and-set below). Callers
    that gate PAID work on a transition (e.g. the recovery path resuming ``_run_seed_loop``)
    must check this return so a run concurrently flipped terminal does not get resumed.

    Args:
        run_id: Run whose persisted status is loaded with ``get_status``.
        state: State to write onto the status record.
        allow_from_terminal: When true, the terminal-stickiness rejection is skipped.
        **updates: Extra status attributes, applied with ``setattr`` after the state and
            timestamps have been set.

    Raises:
        FileNotFoundError: propagated from ``get_status`` when the run has no status file.
    """
    # The read-check-write below must be atomic: a concurrent `flash cancel` (also via
    # _update) landing between the get_status read and the _save_status write could
    # otherwise be clobbered by this stale background update, resurrecting a cancelled
    # run. The control plane is single-instance with per-run threads, so a process-wide
    # lock serializes all status transitions into a compare-and-set.
    report_status: RunStatus | None = None
    with _STATUS_LOCK:
        status = get_status(run_id)
        # Terminal states are STICKY: once a run is done/failed/cancelled/dry_run, no
        # other state may overwrite it. This closes the whole cancel-race class at the
        # source — a cancel landing between a caller's check and a later write
        # (provisioning/running, or even a late terminal done/failed from a worker that
        # finished as the cancel arrived) can no longer resurrect the run. Same-state
        # writes still pass so terminal field updates (cost_usd, error, artifacts_dir)
        # are preserved.
        #
        # allow_from_terminal is the NARROW escape hatch used ONLY by cancel_run's final
        # `cancelled` transition, and ONLY when the run was `deployed` at cancel entry (see
        # cancel_run). In that case an explicit user cancel must WIN over a racing
        # mark_undeployed() that flipped the `deployed` run to terminal `done` mid-teardown —
        # that `done` is an undeploy artifact (restoring the pre-deploy completion marker while
        # retiring serving), not a fresh result. Without the override the `cancelled` write
        # no-ops against the freshly-written `done` and the run wrongly ends `done` despite the
        # user asking to cancel. cancel_run passes allow_from_terminal=False for a non-deployed
        # run, so a GENUINE training-completion `done` racing in from the run's own training
        # thread is protected by the CAS below — cancel correctly loses to a real finish.
        if status.state in TERMINAL_STATES and state != status.state and not allow_from_terminal:
            return False
        was_terminal = status.state in TERMINAL_STATES  # before this write overwrites updated_at
        prev_updated_at = status.updated_at
        status.state = state
        status.updated_at = time.time()
        # Freeze the training-teardown time on the FIRST terminal transition (and only then) so
        # reconciliation has an immutable run-end even after deploy/heartbeat/reconcile later bump
        # updated_at. A same-state terminal re-write (terminal field updates) keeps the original.
        if state in TERMINAL_STATES and status.finished_at is None:
            # A genuine non-terminal -> terminal transition: the just-set updated_at == teardown.
            # But a LEGACY run (finished_at never stamped) that is ALREADY terminal and gets a
            # same-state field-only touch (e.g. billing_state via _update(run_id, current_state,...))
            # must backfill from the PRE-update updated_at -- the prior persisted terminal time --
            # not the freshly-set now, which would skew run_end / the reconcile window.
            status.finished_at = prev_updated_at if was_terminal else status.updated_at
        # Applied last, so a key in `updates` overrides any value assigned above.
        for key, value in updates.items():
            setattr(status, key, value)
        _save_status(status)
        report_status = status
    # Report only after the lock has been released.
    if report_status is not None:
        _report_status(report_status)
    return True
619
+
620
+
621
def record_realized_cost(run_id: str, *, realized_cost_usd: float, reconciled_at: float) -> None:
    """Save reconciliation output on a run while leaving its ``state`` alone.

    The status is re-read under ``_STATUS_LOCK`` and only ``realized_cost_usd``,
    ``reconciled_at`` and ``updated_at`` are assigned before saving. A run that moved
    on (e.g. to ``deployed``) after the reconcile snapshot was taken therefore keeps
    whatever state it has now, unlike ``_update``, which writes the caller's ``state``.
    Returns silently when the run has no status file.
    """
    with _STATUS_LOCK:
        try:
            current = get_status(run_id)
        except FileNotFoundError:
            # The run vanished; nothing to record.
            return
        current.realized_cost_usd = realized_cost_usd
        current.reconciled_at = reconciled_at
        current.updated_at = time.time()
        _save_status(current)
    _report_status(current)
638
+
639
+
640
+ def _report_status(status: RunStatus) -> None:
641
+ with contextlib.suppress(Exception):
642
+ from flash.server.run_registry import record_training_run
643
+
644
+ record_training_run(status=status)
645
+
646
+
647
def _save_status(status: RunStatus) -> None:
    """Write ``status`` to its ``.json`` path via a temp file and ``os.replace``."""
    os.makedirs(RUNS_DIR, exist_ok=True)
    final_path = runs_file_path(status.run_id, ".json")
    # Write-then-rename keeps a concurrent reader (poll on /v1/runs or /logs) from
    # ever seeing a half-written file and failing with JSONDecodeError. mkstemp
    # gives every write its own temp name, so two threads saving the same run
    # cannot trample each other's temp file mid-dump.
    handle, scratch = tempfile.mkstemp(suffix=".tmp", prefix=f"{status.run_id}.", dir=RUNS_DIR)
    try:
        with os.fdopen(handle, "w") as out:
            json.dump(status.to_dict(), out, indent=2, sort_keys=True)
        os.replace(scratch, final_path)
    finally:
        # After a successful replace the temp name no longer exists; it is only
        # left behind when the dump or the replace failed.
        try:
            os.unlink(scratch)
        except FileNotFoundError:
            pass
663
+
664
+
665
+ # Re-export the run-execution and deploy/recover transitions as package-level attributes
666
+ # so external `from flash.runner import X` keeps working AND the test monkeypatches
667
+ # (flash.runner._run_job / ._gc_run_endpoints / .cancel_run ...) resolve here. These imports
668
+ # run AFTER the store layer above is fully defined; lifecycle/deploy import the store via
669
+ # FUNCTION-LOCAL lazy `from flash.runner import ...` to avoid a partially-initialized cycle.
670
+ from flash.runner.deploy import ( # noqa: E402,F401
671
+ attach_checkpoint_deployment,
672
+ attach_run,
673
+ cancel_run,
674
+ mark_deployed,
675
+ mark_deployment_undeployed,
676
+ mark_undeployed,
677
+ resume_run,
678
+ )
679
+ from flash.runner.lifecycle import ( # noqa: E402,F401
680
+ _gc_run_endpoints,
681
+ _run_job,
682
+ _run_job_inner,
683
+ _run_seed_loop,
684
+ _spec_with_gpu,
685
+ _submit_seed_supervised,
686
+ )