freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,517 @@
1
+ """Self-contained bootstrap shared by the instance-based providers (Lambda, Hyperstack).
2
+
3
+ Runs INSIDE the worker container on a rented GPU instance. Both providers' cloud-init ``user_data``
4
+ runs the prebuilt, PUBLIC ``WORKER_IMAGE`` via Docker on the host, and this module is the
5
+ container's command: install the run's extra pip deps, fetch the flash package from the HF dataset
6
+ repo, then run the substrate-neutral worker (``flash.engine.worker``) to train, uploading the
7
+ console tail to HF.
8
+
9
+ There is NO return channel from the instance: the worker's HF artifacts
10
+ (DONE/metrics.json/heartbeat.json) are the success signal, and the attempt-scoped
11
+ ``<arm>_attempt<N>.json`` marker (``arm`` = the substrate, e.g. ``lambda``/``hyperstack``) is the
12
+ terminal marker the control plane keys failures on. The full training stack is BAKED into the
13
+ image, so there is no base-stack install here — only the per-run ``extra_pip``.
14
+
15
+ Shipped verbatim inside the container command, so it must stay self-contained: stdlib +
16
+ huggingface_hub (baked into the image) only — never import flash here. It reads its payload from
17
+ ``/root/flash/payload.json``; the substrate name travels in ``payload["flash_arm"]``.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import contextlib
23
+ import json
24
+ import os
25
+ import signal
26
+ import subprocess
27
+ import sys
28
+ import threading
29
+ import time
30
+
31
# JSON payload the control plane drops inside the container for this bootstrap to read.
PAYLOAD_PATH = "/root/flash/payload.json"
# fetch_code snapshots the repo's ``code/**`` into CODE_ROOT, so the package lands in CODE_DIR,
# which run_mode uses as the worker's cwd and build_worker_env prepends to PYTHONPATH.
CODE_ROOT = "/runcode"
CODE_DIR = f"{CODE_ROOT}/code"
34
+
35
+
36
class RetriableBootstrapError(RuntimeError):
    """Infra-shaped bootstrap failure that a retry on a fresh host should clear.

    ``main`` sets ``retriable=True`` on the attempt marker whenever the exception it caught is an
    instance of this class. The control-plane pollers then classify the attempt as
    ``job_preempted`` (retried within the HF infra-retry budget) rather than ``job_failed``
    (fails fast). This module must stay self-contained and cannot import the worker's
    ``RetriableInfraError``, so this local sentinel plays the same role for HF-side failures,
    such as the spilled-spec fetch or a required artifact upload that never landed.
    """
45
+
46
+
47
+ def load_payload() -> dict:
48
+ with open(PAYLOAD_PATH) as f:
49
+ return json.load(f)
50
+
51
+
52
+ def _arm(payload: dict) -> str:
53
+ return str(payload.get("flash_arm") or "instance")
54
+
55
+
56
def hf_upload(payload: dict, local_path: str, repo_subpath: str) -> None:
    """Best-effort upload of ``local_path`` to ``<hf_prefix>/<repo_subpath>`` in the run's HF dataset repo.

    Any ``Exception`` is caught and logged as a warning, never raised. This covers a missing
    huggingface_hub, absent payload keys, and API/network errors.
    """
    try:
        from huggingface_hub import HfApi

        client = HfApi(token=(payload.get("env") or {}).get("HF_TOKEN"))
        destination = f"{payload['hf_prefix']}/{repo_subpath}"
        client.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=destination,
            repo_id=payload["hf_repo"],
            repo_type="dataset",
        )
    except Exception as exc:
        print(f"hf upload warn ({repo_subpath}): {exc}", flush=True)
69
+
70
+
71
def hf_file_exists(payload: dict, repo_subpath: str) -> bool:
    """Report whether ``<hf_prefix>/<repo_subpath>`` is present in the run's HF dataset repo.

    This confirms that a worker's REQUIRED completion artifacts actually reached HF before a
    non-zero worker exit is treated as success. A local /tmp/metrics.json is not proof, because
    the worker writes it before the required (retried) upload, which can still fail.
    Unlike ``hf_upload``, errors are NOT swallowed: API/network failures propagate so the caller
    can decide how conservative to be.
    """
    from huggingface_hub import HfApi

    client = HfApi(token=(payload.get("env") or {}).get("HF_TOKEN"))
    repo_id = payload["hf_repo"]
    target = f"{payload['hf_prefix']}/{repo_subpath}"
    return client.file_exists(repo_id=repo_id, filename=target, repo_type="dataset")
86
+
87
+
88
def remote_completion_confirmed(payload: dict) -> bool:
    """Return True only if both DONE and metrics.json are present on HF for this run.

    The worker uploads metrics.json and then DONE as required artifacts; a failed required
    upload makes it exit non-zero. Seeing the REMOTE copies, not just the local
    /tmp/metrics.json, is the only proof that the run actually finished.
    """
    required = ("DONE", "metrics.json")
    try:
        return all(hf_file_exists(payload, name) for name in required)
    except Exception as exc:
        # A failed existence check is itself infra-shaped. Report "unconfirmed" so a non-zero
        # worker exit propagates and gets retried instead of being masked as success.
        print(f"remote-completion check warn: {exc}", flush=True)
        return False
102
+
103
+
104
def fetch_spec_from_hf(payload: dict) -> str:
    """Download the run's spec spilled out-of-band to HF (``<hf_prefix>/job_spec.json``).

    A large inline job spec (100s of KB of env params) would blow the provider's cloud-init
    ``user_data`` size cap, and the launch would be rejected before any handle is persisted. So
    the control plane (``_instance.build_user_data``) keeps the spec OUT of ``user_data`` and
    uploads it to the run's HF dataset repo instead, leaving only a sentinel in the payload.
    The code is already fetched from the same repo, so this fetch adds no new dependency.

    Returns:
        The spec JSON text.

    Raises:
        Whatever ``hf_hub_download`` raises (network/auth/missing file), or OSError on read.
    """
    from huggingface_hub import hf_hub_download

    local = hf_hub_download(
        repo_id=payload["hf_repo"],
        repo_type="dataset",
        filename=f"{payload['hf_prefix']}/job_spec.json",
        token=(payload.get("env") or {}).get("HF_TOKEN"),
    )
    # JSON is UTF-8 by spec; decode explicitly instead of trusting the container locale.
    with open(local, encoding="utf-8") as f:
        return f.read()
122
+
123
+
124
def build_worker_env(payload: dict) -> dict:
    """Build the environment mapping for the worker subprocess.

    The mapping is built in this order:

    1. Start from this process's environment.
    2. Overlay ``payload["env"]``, with values stringified.
    3. Wire in the job spec.
    4. Set PHASE, SEED and FLASH_ARM.
    5. Prepend CODE_DIR to PYTHONPATH.

    The spec comes from ``payload["job_spec_json"]``. If that is empty and
    ``payload["job_spec_in_hf"]`` is set, the spec is fetched from HF instead. A spec larger
    than the inline limit is written to ``/tmp/job_spec.json`` and passed as
    FLASH_JOB_SPEC_PATH; a smaller one rides inline in FLASH_JOB_SPEC_JSON. The returned env
    carries exactly one of the two variables.

    Raises:
        RetriableBootstrapError: the spilled spec could not be fetched from HF.
        RuntimeError: the payload carries neither an inline spec nor the HF sentinel.
        KeyError: ``phase`` or ``seed`` is missing from the payload.
    """
    env = dict(os.environ)
    env.update({k: str(v) for k, v in (payload.get("env") or {}).items()})
    spec_json = payload.get("job_spec_json")
    if not spec_json and payload.get("job_spec_in_hf"):
        # This HF fetch runs before the worker starts, so the worker can't stamp a retriable
        # heartbeat. Any failure here is infra-shaped. Raising RetriableBootstrapError makes
        # main() mark the attempt retriable, so the poller retries on a fresh host
        # (job_preempted) instead of failing the run fast (job_failed). A permanently missing
        # spec just burns the bounded infra-retry budget, then fails.
        try:
            spec_json = fetch_spec_from_hf(payload)
        except Exception as e:
            raise RetriableBootstrapError(
                f"failed to fetch the spilled job spec from HF: {e}"
            ) from e
    if not spec_json:
        # Neither an inline spec nor the spilled-to-HF sentinel: a malformed payload. Fail
        # loudly with the cause rather than an opaque TypeError further down.
        raise RuntimeError(
            "bootstrap payload carries no job spec: both job_spec_json and the job_spec_in_hf "
            "sentinel are absent/empty — the control plane built an invalid worker payload"
        )
    # Pass a large spec via a file, not the environment. Linux caps each execve argv/env string
    # (MAX_ARG_STRLEN) in BYTES, so measure the UTF-8 encoding, not the character count. A
    # non-ASCII spec under 96k characters can still exceed the cap and trip "Argument list too
    # long". Mirrors runpod/train.py:_train_body.
    if len(spec_json.encode("utf-8")) > 96_000:
        with open("/tmp/job_spec.json", "w", encoding="utf-8") as f:
            f.write(spec_json)
        env["FLASH_JOB_SPEC_PATH"] = "/tmp/job_spec.json"
        env.pop("FLASH_JOB_SPEC_JSON", None)
    else:
        env["FLASH_JOB_SPEC_JSON"] = spec_json
        # Symmetric with the file branch: drop any FLASH_JOB_SPEC_PATH inherited from the
        # container or payload env. That way only the spec chosen here reaches the worker.
        env.pop("FLASH_JOB_SPEC_PATH", None)
    env["PHASE"] = payload["phase"]
    env["SEED"] = str(payload["seed"])
    # Compute substrate for the RunMetrics record (engine.worker reads FLASH_ARM). The payload
    # env may carry a different value from the shared env builder. This bootstrap runs on the
    # rented instance, so override it with the real backend carried in the payload.
    env["FLASH_ARM"] = _arm(payload)
    env["PYTHONPATH"] = CODE_DIR + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
    return env
169
+
170
+
171
def fetch_code(payload: dict) -> None:
    """Snapshot only the ``code/**`` tree of the run's HF dataset repo into CODE_ROOT.

    The package therefore lands under ``CODE_ROOT/code``. huggingface_hub errors propagate to
    the caller unchanged.
    """
    from huggingface_hub import snapshot_download

    repo = payload["hf_repo"]
    hf_token = (payload.get("env") or {}).get("HF_TOKEN")
    snapshot_download(
        repo_id=repo,
        repo_type="dataset",
        allow_patterns=["code/**"],
        local_dir=CODE_ROOT,
        token=hf_token,
    )
181
+
182
+
183
def run_mode(payload: dict, env: dict, mode: str, deadline_ts: float) -> int:
    """Run one ``flash.engine.worker`` process for ``mode`` and return its exit code.

    The worker's combined stdout/stderr is echoed to this container's log and teed to
    ``/tmp/console_<mode>.txt``. The last ~64 kB of that file is uploaded as
    ``console_<mode>.txt``:

    - every FLASH_CONSOLE_UPLOAD_INTERVAL_S seconds (default 30, floor 5) while the worker runs,
      when FLASH_UPLOAD_CONSOLE is truthy;
    - once more at the end when the flag is set, or on a non-zero exit or timeout regardless.

    Args:
        payload: run payload (supplies the HF repo/prefix/token for uploads).
        env: base environment for the worker; ``RUN_MODE=<mode>`` is added on top.
        mode: value for RUN_MODE and the console file suffix.
        deadline_ts: absolute ``time.time()`` wall-clock deadline for the worker.

    Returns:
        The worker's exit code.

    Raises:
        TimeoutError: the worker was still running at ``deadline_ts`` and was killed.
    """
    console = f"/tmp/console_{mode}.txt"
    timed_out = False
    upload_enabled = env.get("FLASH_UPLOAD_CONSOLE", "").strip().lower() not in (
        "", "0", "false", "no", "off",
    )
    try:
        upload_interval = max(5.0, float(env.get("FLASH_CONSOLE_UPLOAD_INTERVAL_S") or 30.0))
    except ValueError:
        # A malformed tuning knob must not fail the whole run before the worker even starts.
        print("console upload warn: invalid FLASH_CONSOLE_UPLOAD_INTERVAL_S; using 30s", flush=True)
        upload_interval = 30.0

    def upload_console_tail(extra: str = "") -> None:
        tail_path = console + ".tail"
        # Seek to the last 64k instead of reading the whole file. On long-running jobs the
        # console grows unbounded and this runs periodically, so an O(n) read each pass would
        # balloon memory/time. Reading binary and decoding with errors="replace" means a seek
        # that lands mid-UTF-8-sequence can't raise.
        with open(console, "rb") as f:
            f.seek(0, os.SEEK_END)
            f.seek(max(0, f.tell() - 64_000))
            tail = f.read().decode("utf-8", "replace")
        if extra:
            tail += extra
        with open(tail_path, "w", encoding="utf-8") as f:
            f.write(tail)
        hf_upload(payload, tail_path, f"console_{mode}.txt")

    stop_upload = threading.Event()

    def upload_loop() -> None:
        while not stop_upload.wait(upload_interval):
            try:
                upload_console_tail()
            except Exception as exc:
                print(f"console upload warn: {exc}", flush=True)

    uploader = None
    # UTF-8 on both ends of the tee: upload_console_tail decodes this file as UTF-8, and the
    # pipe decode must never raise (see pump).
    with open(console, "w", buffering=1, encoding="utf-8") as cf:
        proc = subprocess.Popen(
            [sys.executable, "-m", "flash.engine.worker"],
            cwd=CODE_DIR,
            env={**env, "RUN_MODE": mode},
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",
        )

        def pump() -> None:
            # Keep draining the pipe no matter what. If this thread died on a decode/encode
            # error, or on a write to the console file after it is closed, the worker would
            # block once the OS pipe buffer filled and hang until the wall-clock cap.
            for line in proc.stdout:
                with contextlib.suppress(Exception):
                    print(line, end="", flush=True)
                with contextlib.suppress(Exception):
                    cf.write(line)

        t = threading.Thread(target=pump, daemon=True)
        t.start()
        if upload_enabled:
            uploader = threading.Thread(target=upload_loop, daemon=True)
            uploader.start()
        try:
            # Wait only for the time left. The 1s floor keeps wait() from getting a zero or
            # negative timeout; otherwise it is clamped to the remaining budget, because every
            # second past the cap is paid GPU time.
            proc.wait(timeout=max(1.0, deadline_ts - time.time()))
        except subprocess.TimeoutExpired:
            timed_out = True
            proc.kill()
            proc.wait()
        except BaseException:
            # SIGTERM (SystemExit from main's handler) or KeyboardInterrupt: stop the periodic
            # uploader and kill the worker rather than leave it running unsupervised while the
            # caller writes its terminal marker.
            stop_upload.set()
            proc.kill()
            raise
        t.join(timeout=10)
    if uploader is not None:
        stop_upload.set()
        uploader.join(timeout=10)
    if proc.returncode != 0 or timed_out or upload_enabled:
        try:
            extra = ""
            if timed_out:
                extra = f"\n--- bootstrap: mode '{mode}' hit the wall-clock cap; killed ---\n"
            upload_console_tail(extra)
        except Exception as exc:
            print(f"console upload warn: {exc}", flush=True)
    if timed_out:
        raise TimeoutError(f"worker mode '{mode}' exceeded the wall-clock cap")
    return proc.returncode
268
+
269
+
270
def write_attempt_marker(payload: dict, ok: bool, error: str = "", retriable: bool = False) -> None:
    """Write and upload this attempt's terminal marker, ``<arm>_attempt<N>.json``.

    The marker is attempt-scoped, so the control plane can tell THIS attempt's outcome apart
    from leftovers of earlier attempts under the same HF prefix. ``retriable`` is the flag the
    pollers read: a flagged failure is treated as ``job_preempted`` (retried on a fresh host
    within the HF infra budget) instead of ``job_failed`` (fails fast). Set it for infra-shaped
    bootstrap failures such as HF fetch/upload errors.

    The JSON fields are ``ok``, ``ts`` (epoch seconds), ``attempt`` (from
    ``payload["attempt"]``, default 0), ``retriable`` and ``error`` (truncated to 2000 chars).
    The upload goes through ``hf_upload`` and is therefore best-effort.
    """
    attempt = int(payload.get("attempt") or 0)
    marker = dict(
        ok=bool(ok),
        ts=time.time(),
        attempt=attempt,
        retriable=bool(retriable),
        error=error[:2000],
    )
    marker_path = "/tmp/attempt_marker.json"
    with open(marker_path, "w") as fh:
        json.dump(marker, fh)
    hf_upload(payload, marker_path, f"{_arm(payload)}_attempt{attempt}.json")
289
+
290
+
291
+ def _arm_preload_wall_cap(payload: dict) -> tuple[threading.Timer, threading.Event] | None:
292
+ """Arm the preload path's wall-clock cap. The training path enforces ``max_wall_s`` by
293
+ ``run_mode`` killing the worker SUBPROCESS on its deadline, but ``run_preload`` runs
294
+ ``snapshot_download`` IN-PROCESS — a hung Hub download (or a stalled NIC) has no subprocess to
295
+ time out, so without this the paid Lambda/Hyperstack box can keep running long past ``timeout_s``
296
+ (the control-plane driver's ``terminate_run_instances`` only fires if that driver process is
297
+ still alive, and nothing on the box self-terminates). Mirror the deadline here: a daemon timer
298
+ writes a terminal failure marker (so the warm driver stops polling and the box can be reaped) and
299
+ HARD-exits the process — ``os._exit`` because a blocked C-level socket read in ``snapshot_download``
300
+ can't be unwound by a Python exception/signal. Returns ``(timer, done)``: the caller cancels the
301
+ timer AND sets ``done`` on a clean finish, so a wall expiry racing that finish no-ops in _fire."""
302
+ wall_s = float(payload.get("max_wall_s") or 0)
303
+ if wall_s <= 0:
304
+ return None
305
+ # Set by the caller the instant ``run_preload`` returns cleanly. ``Timer.cancel()`` cannot stop an
306
+ # _fire that is ALREADY RUNNING, so without this guard a wall expiry racing a successful finish
307
+ # would still upload an ok=false marker + ``os._exit(1)`` over a warmed cache (the warm driver then
308
+ # reports failure). _fire checks it first and no-ops when the preload already completed.
309
+ done = threading.Event()
310
+
311
+ def _fire() -> None:
312
+ if done.is_set():
313
+ return
314
+ msg = f"preload exceeded the wall-clock cap ({int(wall_s)}s); self-terminating box"
315
+ print(f"FLASH: {msg}", flush=True)
316
+ # Best-effort terminal marker so the driver/sweeper sees a terminal failure instead of polling
317
+ # to its own timeout. The wall cap often fires BECAUSE the Hub/NIC is hung (the main thread is
318
+ # stuck in snapshot_download), and write_attempt_marker does a blocking HF upload — running it
319
+ # inline here would wedge the timer thread on that same hung network and the paid VM would
320
+ # NEVER self-terminate, defeating the wall cap. So attempt the marker on a SEPARATE daemon
321
+ # thread, join it only briefly, then HARD-exit regardless of whether the upload finished. The
322
+ # marker is best-effort; the driver's own poll-timeout still frees the box if it's lost.
323
+ def _mark() -> None:
324
+ with contextlib.suppress(Exception):
325
+ write_attempt_marker(payload, ok=False, error=msg)
326
+
327
+ marker_thread = threading.Thread(target=_mark, daemon=True)
328
+ marker_thread.start()
329
+ marker_thread.join(timeout=8.0)
330
+ os._exit(1)
331
+
332
+ timer = threading.Timer(wall_s, _fire)
333
+ timer.daemon = True
334
+ timer.start()
335
+ return timer, done
336
+
337
+
338
def run_preload(payload: dict) -> dict:
    """Warm the mounted weight cache with ``payload["models"]`` (download only) and report.

    This is the instance-provider mirror of the RunPod ``preload`` branch
    (runpod/train/endpoints._train_body). There is no training, no env code and no worker
    subprocess: each repo is ``snapshot_download``-ed straight into ``<HF_HOME>/hub`` so that
    the first real run in the region is warm.

    ``cache_dir`` is passed explicitly because huggingface_hub freezes HF_HUB_CACHE at import,
    so setting the env var here would be too late. The function refuses to run unless the
    cache mount (HF_HOME's parent directory) exists and, when ``cache_mount_marker`` is given,
    that sentinel file is present. Otherwise it would warm ephemeral local disk that vanishes
    with the box.

    Returns:
        ``{"preloaded": [...], "already_cached": [...], "failed": {repo_id: error}}``. If the
        mount/sentinel check fails, the three collections are empty and an ``"error"`` string
        is added.
    """
    env = payload.get("env") or {}
    hf_home = env.get("HF_HOME") or ""
    token = env.get("HF_TOKEN")

    def _refuse(reason: str) -> dict:
        return {"preloaded": [], "already_cached": [], "failed": {}, "error": reason}

    # HF_HOME is <mount>/hf-cache, so the cache mount is its parent directory. This check runs
    # before huggingface_hub is imported so a missing mount fails fast.
    mount = os.path.dirname(hf_home.rstrip("/")) if hf_home else ""
    if not (hf_home and mount and os.path.isdir(mount)):
        return _refuse(
            f"weight-cache not mounted (HF_HOME={hf_home!r}); refusing to warm ephemeral disk"
        )
    # The sentinel is dropped onto the mount only after the real cache is verified mounted. A
    # Docker -v bind onto a missing host dir auto-creates an EMPTY directory: isdir passes but
    # the sentinel is absent. Without this check we would warm ephemeral disk and still report
    # success. The sentinel is only checked when the payload names one.
    marker_name = payload.get("cache_mount_marker")
    if marker_name:
        sentinel = os.path.join(mount, marker_name)
        if not os.path.exists(sentinel):
            kind = "block volume" if payload.get("cache_block_device") else "NFS filesystem"
            return _refuse(
                f"weight-cache {kind} not mounted (no sentinel at {sentinel}); "
                "refusing to warm ephemeral disk"
            )
    from huggingface_hub import snapshot_download

    cache_dir = os.path.join(hf_home, "hub")
    # Weights + tokenizer/config only, matching the exclusions used elsewhere, so the warmed
    # cache holds exactly what workers later fetch.
    download_kwargs = {
        "token": token,
        "cache_dir": cache_dir,
        "ignore_patterns": ["*.pth", "*.gguf", "original/*", "*.onnx", "*.msgpack", "*.h5"],
    }
    downloaded, cached, failures = [], [], {}
    for repo_id in payload.get("models") or []:
        try:
            # Probe with local_files_only first (HF's own resolution, not a dir-name guess) so a
            # snapshot already on the volume is not re-downloaded.
            # NOTE(review): depending on the huggingface_hub version, this probe may not verify
            # that every file of a partially-downloaded snapshot is present — confirm.
            try:
                snapshot_download(repo_id=repo_id, local_files_only=True, **download_kwargs)
                cached.append(repo_id)
                print(f"preload: {repo_id} -> {cache_dir} (cached)", flush=True)
                continue
            except Exception:
                pass  # not resolvable locally -> fall through to a network download
            snapshot_download(repo_id=repo_id, **download_kwargs)
            downloaded.append(repo_id)
            print(f"preload: {repo_id} -> {cache_dir} (downloaded)", flush=True)
        except Exception as exc:
            failures[repo_id] = str(exc)
            print(f"preload FAILED {repo_id}: {exc}", flush=True)
    return {"preloaded": downloaded, "already_cached": cached, "failed": failures}
402
+
403
+
404
def _enable_hf_transfer() -> None:
    """Enable hf_transfer when it is importable (baked into the worker image). Best-effort."""
    try:
        import importlib.util

        if importlib.util.find_spec("hf_transfer") is not None:
            os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
    except Exception as exc:
        print("hf_transfer setup skipped:", exc, flush=True)


def _upload_preload_result(payload: dict, result: dict) -> None:
    """Write ``result`` to /tmp/preload_result.json and upload it, confirming that it landed.

    preload_result.json is the completion signal the warm driver polls for, and ``hf_upload``
    swallows errors. A single transient Hub blip would otherwise drop the file and leave the
    driver polling to its timeout. So the upload is retried up to 3 times (2s/4s backoff) until
    ``hf_file_exists`` confirms it. Failure is non-fatal: after the last attempt this only logs
    loudly. Only the local write can raise.
    """
    local = "/tmp/preload_result.json"
    with open(local, "w") as f:
        json.dump(result, f)
    for attempt in range(3):
        hf_upload(payload, local, "preload_result.json")
        try:
            if hf_file_exists(payload, "preload_result.json"):
                return
        except Exception as exc:
            print(f"preload_result.json upload confirm warn: {exc}", flush=True)
        if attempt < 2:
            time.sleep(2.0 * (attempt + 1))
    print("preload_result.json upload FAILED after 3 attempts (completion file may be "
          "missing; driver falls back to the attempt marker)", flush=True)


def main() -> int:
    """Container entry point: run one bootstrap attempt and always leave a terminal marker.

    Preload payloads (``mode == "preload"``) only warm the weight cache. Training payloads
    install ``extra_pip``, fetch the code, and run the worker for ``payload["phase"]``. Once
    the payload is loaded, ``<arm>_attempt<N>.json`` is written with ok/error/retriable whatever
    the outcome.

    Returns:
        0 on success, 1 on any failure; used as the process exit code.
    """
    # Make SIGTERM (docker stop / wall cap) unwind through `finally` so the terminal marker
    # still gets uploaded.
    signal.signal(signal.SIGTERM, lambda *a: sys.exit(1))
    payload = load_payload()
    ok = False
    error = ""
    retriable = False
    try:
        _enable_hf_transfer()
        # Preload (warm) mode: download-only into the mounted cache, then exit. There is no
        # code fetch, no extra_pip and no worker subprocess. The warm driver detects completion
        # via preload_result.json; the attempt marker is still written below, but it is not the
        # preload completion signal.
        if payload.get("mode") == "preload":
            # Preload has no subprocess for run_mode to time out, so arm a watchdog instead.
            wall_cap = _arm_preload_wall_cap(payload)
            try:
                result = run_preload(payload)
            finally:
                if wall_cap is not None:
                    wall_timer, wall_done = wall_cap
                    # Mark done FIRST so a wall expiry racing this clean finish does nothing,
                    # THEN cancel; cancel() can't stop a callback already in flight.
                    wall_done.set()
                    wall_timer.cancel()
            _upload_preload_result(payload, result)
            ok = not result.get("error") and not result.get("failed")
            error = result.get("error") or (
                f"models failed: {sorted(result.get('failed') or {})}" if result.get("failed") else ""
            )
            return 0 if ok else 1
        # The base training stack is baked into WORKER_IMAGE. Only the per-run extras install
        # here: exactly the payload's extra_pip.
        extra_pip = payload.get("extra_pip") or []
        if extra_pip:
            subprocess.run([sys.executable, "-m", "pip", "install", *extra_pip], check=True)
        try:
            fetch_code(payload)
        except Exception as e:
            # The code fetch is an HF round-trip just like the spilled-spec fetch, and it also
            # runs before the worker can stamp a retriable heartbeat. An HF blip here is
            # infra-shaped, so retry on a fresh host instead of failing the run fast. A
            # permanently missing code tree just burns the bounded infra-retry budget.
            raise RetriableBootstrapError(f"failed to fetch the flash code from HF: {e}") from e
        env = build_worker_env(payload)
        deadline = time.time() + float(payload.get("max_wall_s") or 24 * 3600)
        phase = payload["phase"]
        for stale in ("/tmp/train_meta.json", "/tmp/metrics.json"):
            with contextlib.suppress(FileNotFoundError):
                os.remove(stale)
        # A non-zero rc is tolerated ONLY when the run genuinely finished. RL's colocated vLLM
        # can segfault at interpreter exit AFTER the adapter, metrics.json and DONE are saved
        # and uploaded. The local /tmp/metrics.json alone is not proof: it is written before the
        # required (retried) upload, which can still fail.
        rc = run_mode(payload, env, phase, deadline)
        if not os.path.exists("/tmp/metrics.json"):
            raise RuntimeError(
                f"train phase '{phase}' produced no /tmp/metrics.json (it crashed before "
                f"finishing); see error_{phase}.txt and console_{phase}.txt in the HF dataset repo"
            )
        if rc != 0 and not remote_completion_confirmed(payload):
            # The local metrics.json exists, but the REQUIRED uploads never landed on HF. That is
            # an HF-infra failure, not a code error, so surface it as retriable. During a full HF
            # outage the worker's own retriable heartbeat may also be missing; the marker's flag
            # then carries the classification.
            raise RetriableBootstrapError(
                f"train phase '{phase}' exited non-zero ({rc}) and its required completion "
                f"artifacts (DONE/metrics.json) are not on HF — the run did not finish (e.g. a "
                f"failed upload after the local metrics.json was written); see error_{phase}.txt "
                f"and console_{phase}.txt in the HF dataset repo"
            )
        ok = True
    except BaseException as exc:  # incl. SIGTERM's SystemExit / KeyboardInterrupt
        # Catching BaseException records a useful error for SIGTERM too, instead of an
        # ok=false marker with an empty error.
        error = f"{type(exc).__name__}: {exc}"
        # Infra-shaped bootstrap failures are raised as RetriableBootstrapError, so the marker
        # carries retriable=True and the poller retries on a fresh host.
        retriable = isinstance(exc, RetriableBootstrapError)
        print(f"bootstrap failed: {error}", flush=True)
    finally:
        write_attempt_marker(payload, ok, error, retriable=retriable)
    # Returned outside `finally`: a `return` inside `finally` silently swallows exceptions and
    # is flagged by newer Pythons (PEP 765).
    return 0 if ok else 1
514
+
515
+
516
if __name__ == "__main__":
    # main()'s 0/1 result becomes the container's exit status.
    raise SystemExit(main())