freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,724 @@
1
+ """Lambda Cloud run lifecycle: capacity walk -> launch -> HF-artifact poll -> guaranteed terminate.
2
+
3
+ The Lambda equivalent of ``providers/runpod/jobs.py``. Lambda has no serverless queue: we launch a
4
+ single-GPU instance from a region with capacity, ship a self-contained cloud-init ``user_data``
5
+ (``builders.build_user_data``) that runs the prebuilt ``WORKER_IMAGE`` via Docker, and detect
6
+ completion purely via the worker's HF artifacts (DONE/metrics.json/heartbeat.json) + the instance's
7
+ status — no inbound network to the box is ever needed.
8
+
9
+ Cost-safety invariant: a launched instance is ALWAYS terminated — the runner's ``finally``, the
10
+ poll deadline, the cancel path, and ``sweep_orphans`` (server startup / post-run) each independently
11
+ guarantee it. Lambda has no instance-scoped key, so (unlike Vast) there is no in-box self-destruct;
12
+ ``sweep_orphans`` at control-plane startup is the crash backstop.
13
+
14
+ The pure dataclasses + builders live in ``.builders`` and are re-exported here so the import path
15
+ ``flash.providers.lambdalabs.jobs`` is unchanged. The lifecycle functions and the constants tests
16
+ monkeypatch stay in this ``__init__`` so a ``monkeypatch.setattr(jobs, …)`` still takes effect.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import contextlib
22
+ import json
23
+ import time
24
+ from collections.abc import Callable
25
+
26
+ from flash._logging import get_logger
27
+ from flash.providers._poll import (
28
+ PollErrorTracker,
29
+ heartbeat_progress_ts,
30
+ make_say,
31
+ preload_box_reap_due,
32
+ surface_heartbeat,
33
+ )
34
+ from flash.providers.base import GPU_INFO, PollResult
35
+ from flash.providers.lambdalabs import api as lambda_api
36
+ from flash.providers.lambdalabs.jobs.builders import (
37
+ LambdaInstance,
38
+ LambdaJobHandle,
39
+ build_payload,
40
+ build_user_data,
41
+ instance_label,
42
+ run_label_prefix,
43
+ )
44
+ from flash.providers.runpod.jobs import make_hf_heartbeat_reader, make_hf_text_reader
45
+
46
logger = get_logger(__name__)

# All durations below are in seconds (compared against ``time.time()`` differences).

# How long an instance may sit in a non-active state (provisioning) before we give up and retry.
LOAD_TIMEOUT_S = 900.0
# No-progress window once the instance is active. The cold start on Lambda is dominated by the
# Docker image pull on a fresh host + per-run pip install + model download, none of which emits a
# heartbeat — so until a *training* heartbeat arrives we apply the larger ``SETUP_GRACE_S`` budget;
# after it we use the tight ``STALL_AFTER_S``.
SETUP_GRACE_S = 3000.0
STALL_AFTER_S = 1500.0
# Provision + cold-start grace added on top of the run's wall cap for the client-side poll deadline
# (Lambda has no server-side execution timeout, so the client deadline + the bootstrap's own cap
# bound spend). Larger than RunPod's because of the on-host Docker pull.
PROVISION_GRACE_S = 3000.0

# Heartbeat stages emitted DURING cold start, before the training loop begins. Receiving one proves
# the worker is alive but NOT that setup finished, so they keep the larger setup grace (cf. RunPod).
_SETUP_HEARTBEAT_STAGES = frozenset(
    {"boot", "sft_start", "rl_start", "sft_model_load", "rl_train_start"}
)

# Lambda instance statuses that mean "the box is gone / will not progress". Immutable (like
# _SETUP_HEARTBEAT_STAGES) so no caller can accidentally widen/narrow the dead-state check.
_DEAD_STATES = frozenset({"terminated", "terminating", "preempted", "unhealthy"})
69
+
70
+
71
def resolve_ssh_key_names() -> list[str]:
    """Return a one-element list naming the SSH key to attach to a Lambda launch.

    Every Lambda launch must carry exactly one SSH key, even though this box is bootstrapped
    entirely through cloud-init ``user_data`` and is never SSH'd into. A non-empty
    ``LAMBDA_SSH_KEY_NAME`` environment variable wins; otherwise the first named key returned by
    ``lambda_api.list_ssh_keys()`` is used.

    Raises:
        lambda_api.LambdaApiError: the env var is unset/empty and the account lists no named key.
    """
    import os

    override = os.environ.get("LAMBDA_SSH_KEY_NAME")
    if override:
        return [override]

    # Keep only entries that actually carry a (truthy) name.
    registered = [key.get("name") for key in lambda_api.list_ssh_keys()]
    usable = [key_name for key_name in registered if key_name]
    if usable:
        return usable[:1]

    raise lambda_api.LambdaApiError(
        "Lambda launch requires an SSH key on the account, but none are registered and "
        "LAMBDA_SSH_KEY_NAME is unset; add one in the Lambda console (the box is bootstrapped "
        "via user_data, so the key is unused — any key works)."
    )
92
+
93
+
94
def usable_instances(gpu_class: str, force: bool = False) -> list[LambdaInstance]:
    """List launchable per-region candidates for ``gpu_class`` where capacity is advertised now.

    Lambda's price is per instance type, not per region, so all returned candidates share one
    $/hr; what varies is the region. An empty list means the class currently has no Lambda
    capacity (the allocator skips it; capacity vanishing mid-run is handled by the region walk +
    the runner's retry).

    Args:
        gpu_class: managed GPU class key (looked up in ``GPU_INFO``).
        force: bypass the ``/instance-types`` cache; the in-launch refresh sets this so it can see
            regions freed up since the allocation cache was populated.
    """
    from flash.providers.lambdalabs.gpus import instance_type_for
    from flash.providers.lambdalabs.pricing import hourly_rate

    gpu_spec = GPU_INFO[gpu_class]
    type_name = instance_type_for(gpu_class)
    usd_per_hour = hourly_rate(gpu_class)
    open_regions = lambda_api.regions_with_capacity(type_name, force=force)
    return [
        LambdaInstance(
            gpu=gpu_class,
            instance_type=type_name,
            region=region_name,
            vram_gb=gpu_spec.vram_gb,
            price_usd_hr=usd_per_hour,
        )
        for region_name in open_regions
    ]
120
+
121
+
122
+ def _launch_rejection_is_clean(err: Exception) -> bool:
123
+ """True when a launch error is a DEFINITIVE rejection that created NO instance (safe to walk to
124
+ the next region). The shared RestClient fast-fails a non-429 4xx as ``... -> HTTP 4xx: ...``
125
+ (the provider rejected the request outright, e.g. no capacity). Anything else — a 429
126
+ (rate-limited), a 5xx / timeout (``failed after N attempts``), or a 2xx whose response lacked an
127
+ id (``returned no instance id``) — is AMBIGUOUS: the provider may have created a billed instance,
128
+ so we must NOT issue another launch."""
129
+ s = str(err)
130
+ return "-> HTTP 4" in s and "HTTP 429" not in s
131
+
132
+
133
def launch_and_submit(
    spec,
    seed: int,
    instances: list[LambdaInstance],
    attempt: int = 0,
    log=None,
    runtime_secrets: dict | None = None,
    mode: str | None = None,
    models: list | None = None,
) -> LambdaJobHandle:
    """Launch the first region that accepts the job; walk regions on a capacity rejection.

    Capacity is a live market — between the allocator's capacity check and the launch the only
    region with capacity is often taken. We walk every advertised region, then (except for a
    preload) refresh the capacity list once with a forced, cache-bypassing fetch.

    ``mode="preload"`` + ``models`` launches a download-only warm (the bootstrap pulls the models into
    the mounted cache and exits — no worker); the cache user_data carries the preload payload. A
    preload therefore requires ``spec.gpu.network_volume``, never falls back to a cold launch, and
    never refreshes to a region outside ``instances``.

    Args:
        spec: run spec; this function reads ``spec.gpu.type``, ``spec.gpu.network_volume``
            (optional) and ``spec.run_id``.
        seed: seed baked into the payload and the instance name.
        instances: candidate regions, walked in order (same shape ``usable_instances`` returns).
        attempt: attempt number baked into the payload and the instance name.
        log: optional log sink, wrapped by ``make_say``.
        runtime_secrets: forwarded to ``build_payload``.
        mode: ``"preload"`` for a download-only cache warm; otherwise a normal run.
        models: models to warm when ``mode == "preload"``.

    Returns:
        A ``LambdaJobHandle`` for the launched instance.

    Raises:
        lambda_api.LambdaApiError: no candidates; a preload without a cache filesystem; no SSH
            key available; an ambiguous launch failure (after a best-effort reap by run name);
            or every region rejected the launch.
    """
    say = make_say(log)
    if not instances:
        raise lambda_api.LambdaApiError(
            f"no Lambda capacity for {spec.gpu.type} (no region advertises the instance type)"
        )
    # Weight cache: when the run wants it (runner-assigned network_volume), HF_HOME points at the
    # Lambda filesystem bind-mounted at /lambda/nfs/<name> (region-independent path -> one user_data
    # serves every region; we just ensure the FS exists per region in the walk). If the FS can't be
    # ensured in a region, fall back to the cold user_data there. ``gpu=`` selects the per-GPU worker
    # image (dev: worker_image_for_gpu via lambda_image).
    cache_name = getattr(spec.gpu, "network_volume", None)
    is_preload = mode == "preload"
    if is_preload and not cache_name:
        # The cold user_data carries no mode/models, so a cache-less preload would run a full
        # training bootstrap (GPU billing, timeout) and warm nothing. Refuse before spending.
        raise lambda_api.LambdaApiError(
            f"Lambda preload for {spec.gpu.type} needs a weight-cache filesystem "
            "(spec.gpu.network_volume is unset); refusing to launch a cold preload"
        )
    cold_user_data = build_user_data(
        build_payload(spec, seed, attempt, runtime_secrets=runtime_secrets), gpu=spec.gpu.type
    )

    def _cache_user_data_for(mount_point: str) -> str:
        """Cache user_data whose bind-mount targets THIS region's actual NFS host path."""
        return build_user_data(
            build_payload(
                spec,
                seed,
                attempt,
                runtime_secrets=runtime_secrets,
                cache_host_mount=mount_point,
                mode=mode,
                models=models,
            ),
            gpu=spec.gpu.type,
        )

    def _ambiguous_failure(err: Exception) -> Exception:
        """Best-effort reap of any phantom under this run name; return the error to raise.

        Used after an AMBIGUOUS launch failure (timeout / 5xx / 429 / accepted-but-no-id): Lambda
        may have created a billed instance whose id we never got, so no further launch is issued
        in this attempt — the runner's retry (+ gc / sweep_orphans) re-provisions cleanly. This is
        the non-idempotent-launch cost-safety the region walk would otherwise violate.
        """
        with contextlib.suppress(Exception):
            terminate_run_instances(spec.run_id)
        return lambda_api.LambdaApiError(
            f"ambiguous Lambda launch failure (possible phantom reaped): {err}"
        )

    def _handle_for(inst: LambdaInstance, instance_id) -> LambdaJobHandle:
        """Job handle for a successful launch, stamped with the current time as its start."""
        return LambdaJobHandle(
            instance_id=instance_id,
            instance_type=inst.instance_type,
            region=inst.region,
            name=name,
            gpu=inst.gpu,
            hourly_usd=inst.price_usd_hr,
            attempt=attempt,
            started_ts=time.time(),
        )

    # Prebuild the cache user_data for the DEFAULT mount path (/lambda/nfs/<name>) once — the common
    # case, so the walk reuses it without re-rendering. A region whose ensure_filesystem returns a
    # DIFFERENT mount_point rebuilds with that real path (see the walk below), so the bootstrap
    # bind-mount never points at a stale host path.
    default_cache_mount = f"/lambda/nfs/{cache_name}" if cache_name else ""
    cache_user_data = _cache_user_data_for(default_cache_mount) if cache_name else None
    name = instance_label(spec.run_id, seed, attempt)
    ssh_keys = resolve_ssh_key_names()

    tried_regions: set[str] = set()
    candidates = list(instances)
    refreshed = False
    # GPU class used for the one-shot capacity refresh: that of the last region actually tried.
    refresh_gpu = instances[0].gpu
    last_err: Exception | None = None
    while True:
        if not candidates:
            # Walk exhausted: refresh the capacity list ONCE. Checked here (not only right after a
            # rejection) so a walk ending on already-tried duplicate regions still gets its refresh.
            # NOT in preload mode: warm_instances pins each preload launch to ONE specific target
            # region and reports that exact region as warmed. Refreshing to a DIFFERENT region
            # would warm region B while the caller reports the target region A as warmed (cache
            # still cold). A preload that can't run in its target region must FAIL (raise below).
            if is_preload or refreshed:
                break
            refreshed = True
            # Force a fresh capacity fetch (the allocation cache is ~45s stale) so the refresh can
            # discover regions that freed up since the walk started.
            candidates = [
                c
                for c in usable_instances(refresh_gpu, force=True)
                if c.region not in tried_regions
            ]
            if not candidates:
                break
        inst = candidates.pop(0)
        if inst.region in tried_regions:
            continue
        tried_regions.add(inst.region)
        refresh_gpu = inst.gpu
        # Ensure the cache filesystem exists in THIS region (create-if-absent) and attach it at
        # launch; on any failure, launch cold here (best-effort cache, never blocks the run).
        user_data, fs_names = cold_user_data, None
        if cache_name:
            try:
                mount_point = lambda_api.ensure_filesystem(cache_name, inst.region)
                # Use the FS's ACTUAL mount_point: Lambda auto-mounts the NFS filesystem on the
                # host there, and the bootstrap bind-mounts that exact path into the container. If
                # it's the default /lambda/nfs/<name> (the usual case) reuse the prebuilt
                # user_data; otherwise rebuild for this region so the bind-mount doesn't point at a
                # stale/wrong host path (which would silently run cold / fail the preload check).
                region_user_data = (
                    cache_user_data
                    if mount_point == default_cache_mount
                    else _cache_user_data_for(mount_point)
                )
                user_data, fs_names = region_user_data, [cache_name]
            except Exception as e:
                # A preload run's WHOLE purpose is to warm the cache; a cold fallback would run a
                # full training bootstrap and warm nothing. SKIP this region instead and fail the
                # walk if no region can host the cache. Normal runs still degrade to a cold run.
                if is_preload:
                    say(f"weight cache unavailable in {inst.region} ({e}); skipping (preload needs it)")
                    last_err = e
                    continue
                say(f"weight cache unavailable in {inst.region} ({e}); launching cold")
        try:
            instance_id = lambda_api.launch_instance(
                region_name=inst.region,
                instance_type_name=inst.instance_type,
                ssh_key_names=ssh_keys,
                name=name,
                user_data=user_data,
                file_system_names=fs_names,
            )
        except lambda_api.LambdaApiError as e:
            last_err = e
            if not _launch_rejection_is_clean(e):
                say(f"ambiguous launch failure in {inst.region}: {e}; reconciling + retrying fresh")
                raise _ambiguous_failure(e) from e
            say(f"region {inst.region} ({inst.gpu} {inst.instance_type}) rejected: {e}")
            # A CLEAN reject of a CACHE-backed launch whose error mentions the FILESYSTEM was likely
            # caused by the attach itself (a just-created FS not yet attachable, an attach quota,
            # an unsupported pairing) — not the GPU class. Best-effort cache must never make a
            # region the cold path could have served fail outright, so retry THIS region once
            # WITHOUT the cache before walking. Gated to filesystem-shaped errors so a plain
            # capacity reject still walks. Skipped in preload mode (a cache-less preload warms
            # nothing). The reject was clean -> no billed instance -> a 2nd launch is safe.
            fs_attach_reject = bool(fs_names) and any(
                tok in str(e).lower() for tok in ("file_system", "filesystem", "file-system")
            )
            if not is_preload and fs_attach_reject:
                say(f"retrying {inst.region} WITHOUT the weight cache (attach may have caused the reject)")
                try:
                    instance_id = lambda_api.launch_instance(
                        region_name=inst.region,
                        instance_type_name=inst.instance_type,
                        ssh_key_names=ssh_keys,
                        name=name,
                        user_data=cold_user_data,
                        file_system_names=None,
                    )
                except lambda_api.LambdaApiError as e2:
                    last_err = e2
                    if not _launch_rejection_is_clean(e2):
                        raise _ambiguous_failure(e2) from e2
                    say(f"region {inst.region} also rejected cold: {e2}")
                else:
                    say(
                        f"launched lambda instance {instance_id} (cold, cache-less): {inst.gpu} "
                        f"{inst.instance_type} in {inst.region} attempt={attempt} seed={seed}"
                    )
                    return _handle_for(inst, instance_id)
            continue
        say(
            f"launched lambda instance {instance_id}: {inst.gpu} {inst.instance_type} "
            f"${inst.price_usd_hr:.2f}/hr in {inst.region} attempt={attempt} seed={seed}"
        )
        return _handle_for(inst, instance_id)
    # Backstop reap: ambiguous launch failures were already reaped + raised above, so every
    # rejection reaching here was classified clean (no instance created) and this should find
    # nothing. Kept best-effort in case that classification is ever wrong and a billed instance
    # exists under our run name with no handle owning it.
    with contextlib.suppress(Exception):
        terminate_run_instances(spec.run_id)
    raise lambda_api.LambdaApiError(
        f"all {len(tried_regions)} Lambda region(s) rejected the {spec.gpu.type} launch "
        f"(no capacity): {last_err}"
    )
318
+
319
+
320
# Rate-limited reader for one HF artifact's text content (None until it exists). Shared with
# runpod's poller via make_hf_text_reader; kept under this module-local name because tests
# monkeypatch ``flash.providers.lambdalabs.jobs._make_hf_file_reader``, and the poll/failure paths
# look it up as a module global at call time (so a monkeypatch still takes effect).
_make_hf_file_reader = make_hf_text_reader
325
+
326
+
327
def _failure_detail(hf_repo: str, prefix: str, phase: str, marker: dict | None) -> str:
    """Assemble the best available root-cause text for a failed attempt from its HF artifacts.

    Lambda has no instance console/log API, so the host-side ``lambda_boot.log`` (uploaded to HF by
    the cloud-init uploader) stands in for Vast's ``instance_logs``; it is the only record of
    early-bootstrap failures (docker/GPU not ready, image-pull failure).

    Sections, in order and each only when present: the marker's ``error`` field, the tail (last
    2000 chars) of ``error_<phase>.txt``, and the tail (last 3000 chars) of ``lambda_boot.log``.
    Falls back to a generic message when none exist.
    """
    sections: list[str] = []
    if marker and marker.get("error"):
        sections.append(str(marker["error"]))

    # (artifact file name, section header, number of trailing characters kept)
    artifacts = (
        (f"error_{phase}.txt", f"error_{phase}.txt", 2000),
        ("lambda_boot.log", "lambda_boot.log (host)", 3000),
    )
    for filename, header, tail_chars in artifacts:
        content = _make_hf_file_reader(hf_repo, f"{prefix}/{filename}")(force=True)
        if content:
            sections.append(f"--- {header} ---\n{content[-tail_chars:]}")

    return "\n".join(sections) or "lambda worker terminated without a DONE sentinel"
344
+
345
+
346
+ def poll_lambda_job(
347
+ handle: LambdaJobHandle,
348
+ spec,
349
+ seed: int,
350
+ log=None,
351
+ interval_s: float = 15.0,
352
+ heartbeat_reader=None,
353
+ setup_grace_s: float = SETUP_GRACE_S,
354
+ stall_after_s: float = STALL_AFTER_S,
355
+ deadline_s: float | None = None,
356
+ ) -> PollResult:
357
+ """Poll instance status + HF artifacts to a terminal state (cf. runpod.jobs.poll_job).
358
+
359
+ COMPLETED fresh DONE sentinel on HF -> metrics.json (cost stamped from the instance's $/hr).
360
+ job_failed attempt marker with ok=false (a real worker error; fails fast unless the worker
361
+ flagged it retriable).
362
+ job_preempted instance died without DONE/marker (host loss) -> infra-shaped, retried.
363
+ stalled never became active within LOAD_TIMEOUT_S, heartbeat frozen, or deadline passed.
364
+ """
365
+ say = make_say(log)
366
+
367
+ # Single source of truth for "when did this instance launch". started_ts is a non-Optional float
368
+ # that LambdaJobHandle.from_dict coerces to 0.0 when MISSING (old/corrupt handle), so 0.0 means
369
+ # "unknown launch" (a real launch is a large epoch ts, never 0.0). Fall back to now so EVERY use
370
+ # below -- the load/stall clocks AND done_is_fresh / finish_ok's wall+cost stamping -- treats a
371
+ # recovered corrupt handle consistently, instead of billing/comparing from the 1970 epoch.
372
+ launch_ts = handle.started_ts or time.time()
373
+
374
+ hf_repo = spec.train.hf_repo
375
+ prefix = f"{spec.phase}/{spec.run_id}/seed{seed}"
376
+ done_reader = _make_hf_file_reader(hf_repo, f"{prefix}/DONE")
377
+ marker_reader = _make_hf_file_reader(
378
+ hf_repo, f"{prefix}/lambda_attempt{handle.attempt}.json", min_interval_s=60.0
379
+ )
380
+ metrics_reader = _make_hf_file_reader(hf_repo, f"{prefix}/metrics.json")
381
+
382
+ def finish_ok(done_content: str | None = None) -> PollResult:
383
+ raw = metrics_reader(force=True)
384
+ if raw is None:
385
+ return PollResult(False, failure="job_failed", detail="DONE without metrics.json")
386
+ metrics = json.loads(raw)
387
+ # Prefer the worker's DONE timestamp when present and sane; fall back to now. On delayed
388
+ # recovery the control plane may poll hours after the box wrote DONE, so billing to now
389
+ # would over-bill by the downtime.
390
+ end_ts = time.time()
391
+ if done_content:
392
+ try:
393
+ done_ts = float(done_content.strip())
394
+ if launch_ts <= done_ts <= end_ts:
395
+ end_ts = done_ts
396
+ except ValueError:
397
+ pass
398
+ wall_h = (end_ts - launch_ts) / 3600.0
399
+ metrics["cost_usd"] = round(wall_h * handle.hourly_usd, 6)
400
+ notes = metrics.get("notes") if isinstance(metrics.get("notes"), dict) else {}
401
+ notes.update(
402
+ {
403
+ "provider": "lambda",
404
+ "lambda_rate_usd_hr": handle.hourly_usd,
405
+ "lambda_gpu": handle.gpu,
406
+ "lambda_instance_type": handle.instance_type,
407
+ "lambda_region": handle.region,
408
+ }
409
+ )
410
+ metrics["notes"] = notes
411
+ return PollResult(True, metrics=metrics)
412
+
413
+ def done_is_fresh(content: str) -> bool:
414
+ # DONE carries the worker's time.time(); 120 s of clock-skew grace. Anything older predates
415
+ # this attempt (leftover from a prior attempt's resume). Uses launch_ts (not handle.started_ts)
416
+ # so an unknown-launch (0.0) handle doesn't accept every leftover DONE as fresh.
417
+ try:
418
+ return float(content.strip()) > launch_ts - 120.0
419
+ except ValueError:
420
+ return False
421
+
422
+ def finish_from_ok_marker() -> PollResult:
423
+ # An ok marker means the worker finished (it wrote metrics.json before the marker), even if
424
+ # the DONE sentinel is STALE — a retry that hit the worker's already-complete path restores
425
+ # the prior attempt's metrics but leaves DONE at the old timestamp. Treat ok-marker + metrics
426
+ # as terminal success; pass the DONE only when it's genuinely fresh (so cost bills to it).
427
+ d = done_reader(force=True)
428
+ return finish_ok(d if (d is not None and done_is_fresh(d)) else None)
429
+
430
+ def fail_from_marker(marker: dict | None) -> PollResult:
431
+ # A real worker error fails fast UNLESS it is flagged retriable — the host failure marker
432
+ # (docker/GPU never ready) sets retriable=True, and the worker stamps it in heartbeat for a
433
+ # RetriableInfraError; either retries on a fresh host like a platform termination.
434
+ from flash.providers.runpod.jobs import worker_flagged_retriable
435
+
436
+ retriable = bool(marker and marker.get("retriable")) or worker_flagged_retriable(heartbeat_reader)
437
+ return PollResult(
438
+ False,
439
+ failure="job_preempted" if retriable else "job_failed",
440
+ detail=_failure_detail(hf_repo, prefix, spec.phase, marker),
441
+ )
442
+
443
+ def terminal_artifact_result() -> PollResult | None:
444
+ # One forced read of the worker's terminal HF artifacts (DONE / attempt ok-marker). Returns a
445
+ # terminal PollResult when the worker definitively finished or errored, else None. Used both
446
+ # when the host is dead AND before returning a recovered client-side-deadline `stalled`: a
447
+ # control-plane outage longer than max_wall+grace must not discard a seed the worker actually
448
+ # completed during the downtime (the deadline check would otherwise fire before any DONE read).
449
+ d = done_reader(force=True)
450
+ if d is not None and done_is_fresh(d):
451
+ return finish_ok(d)
452
+ raw = marker_reader(force=True)
453
+ if raw:
454
+ with contextlib.suppress(ValueError):
455
+ m = json.loads(raw)
456
+ if m.get("ok"):
457
+ return finish_from_ok_marker() # finished (stale DONE ok)
458
+ return fail_from_marker(m)
459
+ return None
460
+
461
+ poll_errors = PollErrorTracker(say, interval_s)
462
+ # Seed the load/stall clocks from the instance's LAUNCH (launch_ts), not this poll's start: on a
463
+ # delayed reattach after a control-plane restart the box has been billing since launch, so a
464
+ # still-booting instance that already blew LOAD_TIMEOUT_S must fail over NOW instead of getting
465
+ # another full window. launch_ts already maps an unknown-launch (0.0) handle to now (see above),
466
+ # so a fresh launch is a no-op and a corrupt handle won't peg the clocks to the epoch.
467
+ start = launch_ts
468
+ last_status = None
469
+ last_hb_key = None
470
+ last_progress = start
471
+ became_active = False
472
+ seen_training_hb = False
473
+ missing_streak = 0
474
+ while True:
475
+ if deadline_s is not None and time.time() - start > deadline_s:
476
+ # A recovered run can blow a launch-anchored deadline on the FIRST reattach tick (the
477
+ # outage lasted past max_wall+grace). Read terminal artifacts once before giving up: if
478
+ # the worker finished/errored during the downtime, persist that instead of retrying.
479
+ terminal = terminal_artifact_result()
480
+ if terminal is not None:
481
+ return terminal
482
+ return PollResult(False, failure="stalled", detail="client-side deadline exceeded")
483
+ try:
484
+ inst = lambda_api.get_instance(handle.instance_id)
485
+ poll_errors.reset()
486
+ except lambda_api.LambdaApiError as e:
487
+ if poll_errors.record(e):
488
+ return PollResult(False, failure="poll_error", detail=str(e))
489
+ continue
490
+ missing_streak = missing_streak + 1 if inst is None else 0
491
+ status = (inst or {}).get("status") or ("missing" if inst is None else "unknown")
492
+ if status != last_status:
493
+ say(f"instance {handle.instance_id}: {status}")
494
+ # Treat a status TRANSITION as progress, but NOT the first observation: last_status
495
+ # starts None, so on a reattach the very first read always "changes" — counting it as
496
+ # progress would overwrite the launch-anchored last_progress and hand a silent-since-
497
+ # launch worker a fresh full setup grace after every control-plane restart.
498
+ if last_status is not None:
499
+ last_progress = time.time()
500
+ last_status = status
501
+ if status == "active":
502
+ became_active = True
503
+
504
+ done = done_reader()
505
+ if done is not None and done_is_fresh(done):
506
+ return finish_ok(done)
507
+
508
+ dead = missing_streak >= 3 or status in _DEAD_STATES
509
+ if dead:
510
+ # One forced final read: the worker may have finished right before the box was torn
511
+ # down (the normal success order on this substrate).
512
+ terminal = terminal_artifact_result()
513
+ if terminal is not None:
514
+ return terminal
515
+ # Dead host with no ok-marker/DONE. Distinguish a genuine host LOSS (retry on a fresh
516
+ # host/class) from a worker that actually RAN and CRASHED early -- before it could write
517
+ # the attempt marker terminal_artifact_result() reads -- but DID leave error_{phase}.txt
518
+ # (a bad env id, a config/code error, an OOM). That is a DETERMINISTIC worker error, so
519
+ # fail FAST: classifying it job_preempted burns fresh GPUs re-running a crash that will
520
+ # repeat. A crash the worker flagged retriable (RetriableInfraError, stamped in the
521
+ # heartbeat) still retries, exactly like fail_from_marker. error_{phase}.txt is not
522
+ # attempt-scoped, but this can't flip a genuine preemption to job_failed: a prior
523
+ # attempt's NON-retriable crash already ended the run via this same branch, and a prior
524
+ # retriable crash leaves a retriable heartbeat that keeps this path on job_preempted.
525
+ from flash.providers.runpod.jobs import worker_flagged_retriable
526
+
527
+ err = _make_hf_file_reader(hf_repo, f"{prefix}/error_{spec.phase}.txt")(force=True)
528
+ worker_crashed = bool(err and err.strip()) and not worker_flagged_retriable(heartbeat_reader)
529
+ return PollResult(
530
+ False,
531
+ failure="job_failed" if worker_crashed else "job_preempted",
532
+ detail=_failure_detail(hf_repo, prefix, spec.phase, None),
533
+ )
534
+
535
+ raw_marker = marker_reader()
536
+ if raw_marker:
537
+ try:
538
+ marker = json.loads(raw_marker)
539
+ except ValueError:
540
+ marker = None
541
+ if marker and not marker.get("ok"):
542
+ return fail_from_marker(marker)
543
+ if marker and marker.get("ok"):
544
+ return finish_from_ok_marker() # ok marker + metrics == success (DONE may be stale)
545
+
546
+ if not became_active and time.time() - start > LOAD_TIMEOUT_S:
547
+ return PollResult(
548
+ False,
549
+ failure="stalled",
550
+ detail=f"instance stuck in '{status}' for {int(time.time() - start)}s "
551
+ f"(never became active; provisioning / host issue)",
552
+ )
553
+
554
+ new_key, stage = surface_heartbeat(heartbeat_reader, last_hb_key, say)
555
+ if new_key != last_hb_key:
556
+ last_hb_key = new_key
557
+ # Credit the heartbeat's OWN timestamp, not the poll time: a heartbeat that was
558
+ # already stale before a control-plane restart must not reset the stall clock to now
559
+ # on the first reattach read (last_hb_key starts None, so even an old heartbeat looks
560
+ # "new"). Clamped to [launch, now]. Healthy workers heartbeat well inside the stall
561
+ # window, so their ts ~= now (no behavior change on the normal path). ``fresh`` is False
562
+ # for a LEFTOVER heartbeat from a prior attempt (ts < launch); we then neither advance
563
+ # last_progress nor mark training seen, so a stale training heartbeat can't arm the
564
+ # tighter training stall window before this attempt overwrites the file. Dates against
565
+ # ``launch_ts`` (NOT the raw handle.started_ts) so an unknown-launch (0.0) handle is
566
+ # anchored to the SAME ``now`` reference as done_is_fresh / the load+stall clocks: a
567
+ # leftover heartbeat predating this reattach is then consistently rejected instead of
568
+ # blanket-trusted (which could otherwise arm the tighter training window off a prior
569
+ # attempt's training heartbeat). On a real launch this is exactly handle.started_ts.
570
+ hb_ts, fresh = heartbeat_progress_ts(new_key, launch_ts)
571
+ if fresh:
572
+ last_progress = hb_ts
573
+ if stage not in _SETUP_HEARTBEAT_STAGES:
574
+ seen_training_hb = True
575
+ # Before the first TRAINING heartbeat the box is still in the long cold start (Docker pull +
576
+ # pip + model download), so use the larger setup grace; tighten only once training begins.
577
+ if became_active:
578
+ limit = stall_after_s if seen_training_hb else setup_grace_s
579
+ if time.time() - last_progress > limit:
580
+ phase = "training" if seen_training_hb else "setup (pre-training)"
581
+ return PollResult(
582
+ False,
583
+ failure="stalled",
584
+ detail=f"no worker progress for {int(time.time() - last_progress)}s "
585
+ f"during {phase} (instance status {status}, limit {int(limit)}s)",
586
+ )
587
+ time.sleep(interval_s)
588
+
589
+
590
def submit_run_lambda(
    spec,
    seed: int,
    log=None,
    on_handle=None,
    attempt: int = 0,
    runtime_secrets: dict | None = None,
    on_last_gpu: bool = False,
) -> PollResult:
    """Run one seed of ``spec`` on a Lambda instance end to end and return the poll outcome.

    Lambda counterpart of ``runpod.jobs.submit_run``. The steps are:

    1. Launch an instance.
    2. Pass the handle (as a dict) to ``on_handle`` so the caller can persist it.
    3. Poll the job via ``poll_lambda_job``.
    4. Terminate the instance.

    Raises ``lambda_api.LambdaApiError`` before anything is launched when ``spec.gpu.type`` is
    not a key of ``GPU_INFO`` (i.e. not a concrete GPU class).

    Cost safety: once the instance exists, a ``finally`` clause terminates it on every exit
    path. That covers a normal return, a failure result, a stall, an exception, and
    KeyboardInterrupt.
    """
    gpu_class = spec.gpu.type
    if gpu_class not in GPU_INFO:
        raise lambda_api.LambdaApiError(
            f"submit_run_lambda needs a concrete gpu class, got {gpu_class!r}"
        )
    candidates = usable_instances(gpu_class)
    handle = launch_and_submit(
        spec,
        seed,
        candidates,
        attempt=attempt,
        log=log,
        runtime_secrets=runtime_secrets,
    )
    # Billing starts the instant launch_and_submit returns. The teardown therefore has to wrap
    # everything after it, including on_handle: persisting the handle can itself raise, and the
    # paid box must still be torn down in that case.
    try:
        if on_handle is not None:
            on_handle(handle.to_dict())
        repo = spec.train.hf_repo
        run_prefix = f"{spec.phase}/{spec.run_id}/seed{seed}"
        heartbeat_reader = None
        if repo:
            heartbeat_reader = make_hf_heartbeat_reader(repo, run_prefix)
        # On the last GPU class there is no fallback class left to walk to, so allow 50% more
        # setup time before declaring a stall.
        grace_factor = 1.5 if on_last_gpu else 1.0
        wall_budget = max(60, int(spec.gpu.max_wall_seconds)) + PROVISION_GRACE_S
        return poll_lambda_job(
            handle,
            spec,
            seed,
            log=log,
            heartbeat_reader=heartbeat_reader,
            setup_grace_s=SETUP_GRACE_S * grace_factor,
            deadline_s=wall_budget,
        )
    finally:
        lambda_api.terminate_instances([handle.instance_id])
635
+
636
+
637
def terminate_run_instances(run_id: str) -> list[str]:
    """Terminate every Lambda instance belonging to ONE run; return the terminated ids.

    An instance belongs to the run when its name meets either condition:

    - it equals ``run_label_prefix(run_id)``, or
    - it starts with that prefix followed by the ``-s`` seed boundary.

    Matching on the boundary rather than a raw string prefix means ``flash-100`` never matches
    ``flash-1000-...``.

    This is the cancel/GC path. Unlike ``sweep_orphans`` it never looks at other runs, so it is
    safe to call while they are in flight.

    Best-effort: never raises. The return values are:

    - ``[]`` for an empty ``run_id`` or when no instances match;
    - ``[]`` (with a logged warning) when listing fails or the terminate call fails;
    - otherwise the ids returned by ``lambda_api.terminate_instances``.
    """
    if not run_id:
        return []
    try:
        instances = lambda_api.list_instances()
    except Exception as exc:
        logger.warning("lambda run teardown for %s skipped: could not list instances: %s", run_id, exc)
        return []
    prefix = run_label_prefix(run_id)
    seed_boundary = prefix + "-s"
    ids: list[str] = []
    for inst in instances:
        iid = inst.get("id")
        if not iid:
            continue
        name = str(inst.get("name") or "")
        if name == prefix or name.startswith(seed_boundary):
            ids.append(str(iid))
    if not ids:
        return []
    try:
        return lambda_api.terminate_instances(ids)
    except Exception as exc:
        # Honor the "never raises" contract: the caller is a cancel/GC path that must not crash.
        # NOTE(review): some instances may have been terminated before the failure; the
        # sweep/recovery paths are expected to catch any leftovers.
        logger.warning("lambda run teardown for %s failed to terminate %s: %s", run_id, ids, exc)
        return []
657
+
658
+
659
def sweep_orphans(
    active_labels: set[str] | Callable[[], set[str]] | None = None,
) -> list[str]:
    """Terminate Flash-named instances that no live run owns; return terminated ids.

    Run at server startup (crash recovery) and after runs. Only names carrying the ``flash-``
    run prefix are ever touched; nothing else on the account is ours to terminate.

    ``active_labels`` may be RAW run ids. Each is passed through ``run_label_prefix`` so it
    matches the same forced prefix the instance names carry.

    ``active_labels`` may also be a CALLABLE returning that set. It is then resolved AFTER the
    instance list is fetched. The periodic in-lifetime sweep passes one so the protection set is
    read post-listing. Any instance present in the list had its run's status row committed
    before the instance was launched, and hence before this list call, so resolving the live set
    now is guaranteed to include it. This closes the launch race where a run started after a
    pre-captured set could have its fresh worker reaped as a phantom orphan.

    Best-effort: never raises. The sweep is skipped (returning ``[]``) with a logged warning in
    three cases:

    - listing the instances fails;
    - resolving or normalising the active set fails;
    - the terminate call fails.
    """
    try:
        instances = lambda_api.list_instances()
    except Exception as exc:
        logger.warning("lambda orphan sweep skipped: %s", exc)
        return []
    try:
        labels = active_labels() if callable(active_labels) else active_labels
        # Normalise inside the guard too. If any label cannot be mapped to its prefix, the
        # protection set is unreliable, so we SKIP rather than sweep against a partial set.
        active = {run_label_prefix(a) for a in (labels or set())}
    except Exception as exc:
        # Resolving the protection set failed (e.g. a db/status read error in the callable).
        # SKIP the sweep. Never fall through to an empty set: that would treat every live
        # run's instance as an orphan and reap it.
        logger.warning("lambda orphan sweep skipped: could not resolve active set: %s", exc)
        return []
    now = time.time()
    orphans: list[str] = []
    for inst in instances:
        name = str(inst.get("name") or "")
        if not name.startswith("flash-"):
            continue
        # Warm/preload boxes (``flash-preload-...``) are driver-owned:
        # - they are launched by preload.warm_instances (mode="preload");
        # - they are NEVER persisted in the run DB, so they are never in ``active``;
        # - they are self-terminated in _warm_one_instance's ``finally`` and by startup
        #   recover_runs.
        # A catalog warm can outlast this ~10-min sweep. Reaping them by the bare ``flash-``
        # prefix would kill an in-progress preload mid-download, so they are normally exempt.
        # EXCEPTION: a box still alive past its embedded wall deadline + grace has lost its
        # driver. The driver is the only thing that terminates instance providers (nothing on
        # the box self-terminates the VM), so reap it to bound the leak rather than exempt it
        # forever (see preload_box_reap_due).
        if name.startswith("flash-preload-"):
            if preload_box_reap_due(name, now):
                iid = inst.get("id")
                if iid:
                    orphans.append(str(iid))
                    logger.warning(
                        "reaping orphaned lambda preload box %s (outlived its wall deadline + grace; "
                        "driver lost)",
                        name,
                    )
            continue
        # Match on the name boundary, not a raw string prefix. A live run's prefix must EQUAL
        # the name or be followed by the ``-s`` seed boundary, so ``flash-100`` can't shield
        # ``flash-1000-...`` (or vice versa).
        if any(name == a or name.startswith(a + "-s") for a in active):
            continue
        iid = inst.get("id")
        if iid:
            orphans.append(str(iid))
    if not orphans:
        return []
    try:
        deleted = lambda_api.terminate_instances(orphans)
    except Exception as exc:
        # Honor the "never raises" contract; the next sweep will retry the leftovers.
        logger.warning("lambda orphan sweep: terminating %s failed: %s", orphans, exc)
        return []
    for iid in deleted:
        logger.warning("terminated orphaned lambda instance %s", iid)
    return deleted