freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,915 @@
1
+ """Preload (warm) the shared weight-cache volumes with the catalog's base-model weights.
2
+
3
+ The weight cache is the eager ``flash-weights-<dc>`` network-volume fleet (one volume per storage
4
+ datacenter, created-or-attached on every endpoint deploy). This module WARMS that fleet — downloads
5
+ the catalog's base-model weights onto each region's volume up front — so the very first run in any
6
+ region is already warm. An operator/setup action, not a user one (the cache is fully managed, so
7
+ there is no user-facing knob).
8
+
9
+ Mechanism: for each datacenter, deploy a short-lived worker with ONLY that region's volume attached
10
+ (pinned to that single DC, so the worker provably lands there), run the baked handler in ``preload``
11
+ mode (download-only, ``HF_HOME`` -> the mounted volume), then tear the endpoint down. Reuses the
12
+ existing baked worker image + deploy/submit/quota machinery; the only new worker code is the
13
+ ``preload`` branch in ``train.endpoints._train_body``.
14
+
15
+ COST / GC NOTE: the fleet is permanent, billed standing storage. Eager provisioning means a
16
+ ``flash-weights-<dc>`` volume exists in EVERY storage datacenter (one per DataCenter.all() entry —
17
+ currently ~11 x 100 GB ~= 1.1 TB, ~$77/mo; grows by one volume if the SDK adds a storage region),
18
+ created by the first endpoint deploy (or ``--provision`` / a full preload), and RunPod network
19
+ volumes are NOT auto-deleted — there is no GC. Reclaim them with ``--teardown`` (deletes every
20
+ per-DC weight-cache volume across ALL pool accounts via the RunPod REST API).
21
+
22
+ Run it::
23
+
24
+ python -m flash.providers.runpod.preload # all catalog models, all DCs
25
+ python -m flash.providers.runpod.preload --datacenters US-CA-2,EU-RO-1 --models Qwen/Qwen3.5-4B
26
+ python -m flash.providers.runpod.preload --dry-run # print the plan, provision nothing
27
+ python -m flash.providers.runpod.preload --teardown # DELETE the cache volumes (reclaim $)
28
+ """
29
+
30
+ from __future__ import annotations
31
+
32
+ import argparse
33
+ import contextlib
34
+ import json
35
+ import os
36
+ import time
37
+ import uuid
38
+ from concurrent.futures import ThreadPoolExecutor, as_completed
39
+
40
+ from flash._logging import get_logger
41
+ from flash.providers._poll import preload_instance_run_id
42
+ from flash.providers.runpod import api as runpod_api
43
+ from flash.providers.runpod.jobs import (
44
+ build_function_input,
45
+ decode_output,
46
+ deploy_train_endpoint,
47
+ make_hf_text_reader,
48
+ weight_cache_datacenters,
49
+ weight_cache_volume_name,
50
+ )
51
+
52
# Module-level logger obtained from the project's logging helper (flash._logging.get_logger).
logger = get_logger(__name__)
53
+
54
+
55
def _run_async(coro):
    """Drive ``coro`` to completion from synchronous code.

    The teardown entrypoints are normally called from a plain sync CLI, where ``asyncio.run`` is the
    right tool. They may also be reached from code that already has a running event loop (a
    notebook, an async web handler), where ``asyncio.run`` refuses to start ("cannot be called from
    a running event loop"). In that case the coroutine runs on a fresh loop inside a single-use
    worker thread, and this call blocks until it finishes.

    Returns whatever the coroutine returns; an exception raised by the coroutine propagates.
    """
    import asyncio as _asyncio
    import concurrent.futures

    try:
        _asyncio.get_running_loop()
    except RuntimeError:
        inside_loop = False
    else:
        inside_loop = True

    if not inside_loop:
        # Ordinary sync caller: no loop is active on this thread.
        return _asyncio.run(coro)

    # A loop is already running on this thread, so hand the coroutine to a separate thread, which
    # gets its own event loop via asyncio.run.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as runner:
        pending = runner.submit(_asyncio.run, coro)
        return pending.result()
72
+
73
+
74
# Where the worker's HF_HOME points (the per-region weight-cache volume mount); sent to the preload
# job in its ``env`` so downloads land on the volume.
_HF_HOME: str = "/runpod-volume/hf-cache"
# GPU class for preload endpoints. Preload only downloads (no compute), so the GPU is incidental:
# pick the cheapest broadly available class; the job is short, costing a few cents per region.
_PRELOAD_GPU: str = "RTX 4090"
# RunPod job statuses that end polling — success vs. failure.
_TERMINAL_OK: set[str] = {"COMPLETED"}
_TERMINAL_FAIL: set[str] = {"FAILED", "CANCELLED", "TIMED_OUT"}
80
+
81
+
82
def catalog_model_ids() -> list[str]:
    """Return the base-model ids to pre-warm: every entry of the curated catalog.

    Only public catalog weights are warmed globally. Runs under the open model policy (``allow``)
    may use arbitrary or private models, which are not worth (or safe) to warm everywhere. Those
    download cold on first use and are cached like any other model from then on.
    """
    from flash.catalog import MODELS

    return [*MODELS]
91
+
92
+
93
def _preload_one_dc(
    dc_id: str,
    models: list[str],
    token: str | None,
    gpu: str,
    timeout_s: int,
    poll_interval_s: float,
) -> dict:
    """Warm one datacenter's volume: deploy (pinned to that DC) -> preload job -> teardown.

    Args:
        dc_id: datacenter id, parsed with ``DataCenter.from_string``.
        models: model ids the worker downloads onto the volume.
        token: optional Hugging Face token, forwarded to the worker as ``HF_TOKEN``.
        gpu: GPU class for the short-lived preload endpoint.
        timeout_s: endpoint execution timeout and job-polling budget, in seconds.
        poll_interval_s: delay between job-status polls, in seconds.

    Returns:
        ``{"datacenter": dc_id, "status": "ok" | "partial" | "error", ...}``. ``result`` holds the
        handler's decoded output when the job completed, and ``error`` holds the failure message.
        Deploy/submit/poll failures are reported, not raised, so one region never aborts the others.

    The preload endpoint is always deleted in ``finally``. A failed delete is logged rather than
    silently swallowed, because an undeleted endpoint needs manual cleanup.
    """
    from runpod_flash import NetworkVolume
    from runpod_flash.core.resources.datacenter import DataCenter

    from flash.runner import WEIGHT_CACHE_VOLUME_GB, WEIGHT_CACHE_VOLUME_NAME

    dc = DataCenter.from_string(dc_id)
    # SAME per-DC physical name the training path uses (weight_cache_volume_name), so preload warms
    # exactly the volume a later run in this DC will mount.
    vol_name = weight_cache_volume_name(WEIGHT_CACHE_VOLUME_NAME, dc)

    # Pass a FACTORY (not a prebuilt dict). deploy_train_endpoint may fail over across accounts
    # under a multi-key pool, and the SDK stamps an account-scoped id onto a NetworkVolume. Each
    # account attempt must therefore build a fresh volume; otherwise the next account reuses the
    # first's stale id and the single-DC preload fails.
    def _endpoint_kwargs():
        return {
            "volume": [NetworkVolume(name=vol_name, size=WEIGHT_CACHE_VOLUME_GB, datacenter=dc)],
            "datacenter": [dc],
        }

    endpoint_id = None
    try:
        endpoint_id, _name = deploy_train_endpoint(
            gpu,
            execution_timeout_ms=timeout_s * 1000,
            # Unique per invocation. RunPod reuses an endpoint by name, so on a long-lived control
            # plane a stable suffix could resolve a stale (deleted) endpoint id from a prior
            # preload's persisted SDK state. A fresh suffix each run sidesteps that.
            name_suffix=f"preload-{dc_id.lower()}-{uuid.uuid4().hex[:6]}",
            spec=None,
            endpoint_kwargs=_endpoint_kwargs,
        )
        # HF_HUB_ENABLE_HF_TRANSFER is exported by the worker image (Dockerfile.worker ENV), so it is
        # not passed here. Only HF_HOME (the per-region mount) and the token need overriding.
        payload = {
            "mode": "preload",
            "models": models,
            "env": {"HF_HOME": _HF_HOME, **({"HF_TOKEN": token} if token else {})},
        }
        job_id = runpod_api.submit_job(endpoint_id, build_function_input(payload))
        logger.info("preload %s: job %s submitted (%d models)", dc_id, job_id, len(models))
        result = _poll_until_done(endpoint_id, job_id, timeout_s, poll_interval_s)
        if not isinstance(result, dict):
            # Guard the .get() calls below. A non-dict result would otherwise surface only as an
            # opaque "'str' object has no attribute 'get'" message instead of a clear error.
            return {
                "datacenter": dc_id,
                "status": "error",
                "error": f"preload job {job_id} returned unexpected result type "
                         f"{type(result).__name__}",
                "result": result,
            }
        # A COMPLETED job is NOT necessarily a warmed region. The handler reports per-model
        # failures (and a hard error if the volume wasn't mounted) inside its result. Surface
        # those so the driver/CLI don't count a no-op (or partial) warm as success.
        if result.get("error"):
            return {"datacenter": dc_id, "status": "error", "error": result["error"], "result": result}
        if result.get("failed"):
            # Some models were reported as failed: the volume is only partly warm.
            return {"datacenter": dc_id, "status": "partial", "result": result}
        # Warmed this DC's volume. Nothing else to record: the training path attaches the eager
        # fleet (a volume in every storage DC) regardless, so the warm weights are picked up.
        return {"datacenter": dc_id, "status": "ok", "result": result}
    except Exception as exc:  # one region failing must not abort the others
        logger.warning("preload %s FAILED: %s", dc_id, exc)
        return {"datacenter": dc_id, "status": "error", "error": str(exc)}
    finally:
        if endpoint_id:
            try:
                runpod_api.delete_endpoint(endpoint_id)
            except Exception as exc:
                # Still best-effort: a cleanup error must not mask the region's result. It is no
                # longer silent, though, because the endpoint now needs manual deletion.
                logger.warning(
                    "preload %s: failed to delete preload endpoint %s (delete it manually): %s",
                    dc_id, endpoint_id, exc,
                )
160
+
161
+
162
def _poll_until_done(
    endpoint_id: str, job_id: str, timeout_s: int, poll_interval_s: float
) -> dict:
    """Poll a RunPod job until it reaches a terminal status; return the handler's decoded result.

    Args:
        endpoint_id: endpoint the job was submitted to.
        job_id: job to watch.
        timeout_s: polling budget in seconds.
        poll_interval_s: delay between status polls in seconds (the last wait is clamped to the
            remaining budget).

    Returns:
        The decoded handler output (a dict) for a COMPLETED job. A completed job whose output is
        missing, undecodable, empty, or not a dict yields ``{"error": ...}`` instead, so the caller
        reports the region as failed rather than warmed.

    Raises:
        RuntimeError: the job ended FAILED / CANCELLED / TIMED_OUT.
        TimeoutError: no terminal status within ``timeout_s`` seconds.
        Exceptions from ``runpod_api.job_status`` propagate unchanged.
    """
    # Monotonic clock: a wall-clock step (NTP sync, manual change) must not stretch or cut the budget.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        st = runpod_api.job_status(endpoint_id, job_id) or {}
        status = st.get("status")
        if status in _TERMINAL_OK:
            return _decode_completed_output(job_id, st.get("output"))
        if status in _TERMINAL_FAIL:
            raise RuntimeError(f"preload job {job_id} ended {status}: {st.get('error')}")
        # Don't oversleep past the deadline: the final wait is clamped to the budget that is left.
        time.sleep(max(0.0, min(poll_interval_s, deadline - time.monotonic())))
    raise TimeoutError(f"preload job {job_id} did not finish within {timeout_s}s")


def _decode_completed_output(job_id: str, output) -> dict:
    """Normalize a COMPLETED job's raw ``output`` into the handler's result dict.

    Decode failures are returned as ``{"error": ...}`` rather than raised, so ``_preload_one_dc``
    applies its structured ``result.get("error")`` classification instead of seeing a bare crash.
    """
    if not output:
        # A COMPLETED job with no ``output`` is NOT evidence of a warmed region. An API/handler-shape
        # mismatch or a broken worker image can finish the job yet return nothing, so there is no
        # ``preloaded``/``already_cached`` record for any model.
        return {"error": f"preload job {job_id} completed with no output"}
    try:
        # RunPod may surface the output as a JSON STRING or as the flash live-function envelope.
        # decode_output normalizes both, and RAISES on an error / unexpected-shape envelope.
        decoded = decode_output(output)
    except Exception as exc:
        return {"error": str(exc)}
    if not decoded:
        # Same reasoning as the no-output case. An empty decoded result carries no per-model
        # records; falling back to ``{}`` would be classified ``status: ok``, a false warm.
        return {"error": f"preload job {job_id} completed with an empty result"}
    if not isinstance(decoded, dict):
        return {
            "error": f"preload job {job_id} completed with unexpected output type "
                     f"{type(decoded).__name__}"
        }
    return decoded
193
+
194
+
195
def warm_weight_cache(
    models: list[str] | None = None,
    datacenters: list[str] | None = None,
    gpu: str = _PRELOAD_GPU,
    timeout_s: int = 1800,
    max_workers: int = 4,
    poll_interval_s: float = 10.0,
    token: str | None = None,
) -> list[dict]:
    """Warm each datacenter's weight-cache volume with ``models``. Returns one result dict per DC.

    Args:
        models: model ids to download. ``None`` means every catalog model (``catalog_model_ids()``).
        datacenters: datacenter ids to warm. ``None`` means the whole weight-cache fleet
            (``weight_cache_datacenters()``). Ids are canonicalized and de-duplicated.
        gpu: GPU class for the per-DC preload endpoints.
        timeout_s: per-DC endpoint execution timeout and polling budget, in seconds.
        max_workers: number of datacenters warmed concurrently.
        poll_interval_s: delay between job-status polls, in seconds.
        token: Hugging Face token for the workers. Falls back to the ``HF_TOKEN`` env variable.

    An EXPLICIT empty ``models`` or ``datacenters`` list is a no-op (returns ``[]``). It does NOT
    widen to the full catalog or fleet, matching ``teardown_weight_cache``: a caller that resolved
    a scope down to nothing must not launch a paid preload endpoint in every region.

    Datacenters are warmed concurrently (bounded by ``max_workers``). Each concurrent warm deploys a
    preload endpoint, so ``max_workers`` MUST stay under the RunPod endpoint/worker quota
    (documented default 5); the default of 4 leaves a buffer so extra deploys don't fail on quota.
    A region that errors is reported in its result dict and does not abort the others. Results come
    back in completion order.

    Raises:
        Whatever ``DataCenter.from_string`` raises for an invalid datacenter id. The whole scope is
        checked BEFORE any endpoint is deployed.
    """
    from runpod_flash.core.resources.datacenter import DataCenter

    if models is None:
        models = catalog_model_ids()
    if datacenters is None:
        datacenters = [dc.value for dc in weight_cache_datacenters()]
    # An EXPLICIT empty scope is a no-op, NOT "all": never widen zero models/DCs to everything.
    if not models or not datacenters:
        logger.info("warm: empty model/datacenter scope — nothing to warm (refusing to widen to all)")
        return []
    # Validate the WHOLE scope to concrete DataCenter values BEFORE submitting any futures.
    # Otherwise a single bad id would only raise from a worker thread after the valid DCs had
    # already deployed paid preload endpoints. DataCenter.from_string names the bad id and lists the
    # valid ones. Canonicalizing to ``.value`` (the same form the default fleet uses) and
    # de-duplicating stops a repeated DC from launching two concurrent preloads on the same volume.
    dc_ids = list(dict.fromkeys(DataCenter.from_string(d).value for d in datacenters))
    token = token or os.environ.get("HF_TOKEN")
    logger.info("warming %d datacenter(s) with %d model(s)", len(dc_ids), len(models))
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futs = {
            pool.submit(_preload_one_dc, dc, models, token, gpu, timeout_s, poll_interval_s): dc
            for dc in dc_ids
        }
        results: list[dict] = [fut.result() for fut in as_completed(futs)]
    ok = sum(1 for r in results if r.get("status") == "ok")
    logger.info("preload complete: %d/%d datacenters warmed", ok, len(results))
    return results
234
+
235
+
236
def teardown_weight_cache(datacenters: list[str] | None = None) -> list[str]:
    """Delete the per-DC ``flash-weights-<dc>`` cache volumes to reclaim the standing storage.

    RunPod network volumes are never auto-GC'd, so this is the only way to stop the monthly bill
    short of the console. Returns one entry per volume deleted (``account:name`` when a
    multi-account pool is configured). Targets ONLY this fleet's per-DC names (built from
    ``WEIGHT_CACHE_VOLUME_NAME``), never other volumes.

    Sweeps EVERY account in the ``RUNPOD_API_KEY`` pool. ``deploy_train_endpoint`` fails over to
    another account on a quota error, so a cache volume may have been created under any pool key;
    a single-account teardown would leak the volumes the failover created elsewhere.

    Volumes are matched by name but deleted and confirmed by ID. Nothing here guarantees names are
    unique per account, and a name-keyed map would keep only one id per name, leaving any duplicate
    undeleted and misreporting it as a failed delete.

    ``datacenters`` semantics:
        - ``None`` (the default): the WHOLE storage-DC fleet.
        - A non-empty list: just those DCs.
        - An EXPLICIT empty list ``[]``: nothing (returns ``[]``).

    The empty-list case must NOT widen to the full fleet. A caller that resolved a scope down to
    zero DCs intends a no-op, and silently nuking every cache there would be a destructive footgun.
    """
    # An EXPLICIT empty scope ([]) is a no-op, NOT "all": never widen zero DCs to the whole fleet.
    # Checked before any import so the no-op needs nothing installed or configured.
    if datacenters is not None and not datacenters:
        logger.info("teardown: empty datacenter scope — nothing to reclaim (refusing to widen to all)")
        return []

    from flash.providers.runpod import keys as rp_keys
    from flash.runner import WEIGHT_CACHE_VOLUME_NAME

    pool = rp_keys.keys()
    if not pool:
        # No RunPod key configured (e.g. an instance-only control plane): a best-effort no-op, NOT an
        # error. RunpodRestClient() would raise on a missing key and, under a chained `--teardown`,
        # could abort the Lambda/Hyperstack reclaim. Mirror the instance providers' missing-key
        # behavior: log and return nothing reclaimed.
        logger.info("teardown: RUNPOD_API_KEY not configured — skipping RunPod cache teardown")
        return []
    # Import the runpod_flash SDK only AFTER the early returns. On an instance-only control plane
    # the SDK may be unavailable, and a missing-key teardown must not raise on an absent SDK.
    from runpod_flash.core.api.runpod import RunpodRestClient
    from runpod_flash.core.resources.datacenter import DataCenter
    from runpod_flash.core.urls import RUNPOD_REST_API_URL

    dc_ids = datacenters if datacenters else [dc.value for dc in weight_cache_datacenters()]
    targets = {
        weight_cache_volume_name(WEIGHT_CACHE_VOLUME_NAME, DataCenter.from_string(d)) for d in dc_ids
    }

    async def _list_volumes(client) -> list:
        """List the account's network volumes; the SDK returns a list or a ``networkVolumes`` dict."""
        res = await client.list_network_volumes()
        return res if isinstance(res, list) else res.get("networkVolumes", [])

    async def _go_one(api_key) -> list[str]:
        client = RunpodRestClient(api_key=api_key) if api_key else RunpodRestClient()
        vols = await _list_volumes(client)
        # id -> name for every target-named volume, so same-named duplicates are all deleted.
        to_delete = {v["id"]: v["name"] for v in vols if v.get("name") in targets and v.get("id")}
        for vid in to_delete:
            # RunPod's DELETE /networkvolumes/{id} returns 204 No Content, which the SDK's
            # _execute_rest chokes on (it always awaits response.json()). Swallow that; the actual
            # outcome is confirmed by RE-LISTING below, not by trusting the parsed response.
            with contextlib.suppress(Exception):
                await client._execute_rest("DELETE", f"{RUNPOD_REST_API_URL}/networkvolumes/{vid}")
        remaining_ids = {v.get("id") for v in await _list_volumes(client)}
        gone = [name for vid, name in to_delete.items() if vid not in remaining_ids]  # confirmed gone
        # A target id still present after its delete means a REAL failure (auth/permission/5xx/
        # network) that the 204-tolerant suppress() above hid. Surface it so a failed reclaim isn't
        # silent.
        still = [name for vid, name in to_delete.items() if vid in remaining_ids]
        if still:
            logger.warning("teardown: %d cache volume(s) FAILED to delete (still present): %s",
                           len(still), ", ".join(sorted(still)))
        return gone

    multi = len(pool) > 1
    deleted: list[str] = []
    failed_accounts: list[str] = []
    for i, key in enumerate(pool):
        # One bad key (expired / revoked / network) must NOT abort the sweep: a cache volume that a
        # failover created under a LATER account would otherwise stay billed forever. Catch, record,
        # and keep going so every other account is still reclaimed.
        try:
            names = _run_async(_go_one(key))
        except Exception as exc:
            failed_accounts.append(f"acct{i}")
            logger.warning("teardown: RunPod account %d sweep FAILED (continuing): %s", i, exc)
            continue
        deleted.extend((f"acct{i}:{n}" if multi else n) for n in names)
    if failed_accounts:
        # Surface the aggregate so a fully- or partially-failed sweep is observable: the returned
        # `deleted` list alone would hide that some accounts never ran.
        logger.warning(
            "teardown: %d of %d RunPod account(s) failed to sweep (%s) — their cache volumes may "
            "still be billed; re-run teardown once the key(s) are valid",
            len(failed_accounts), len(pool), ", ".join(failed_accounts),
        )
    return deleted
330
+
331
+
332
def teardown_lambda_filesystems(name: str | None = None) -> list[str]:
    """Delete the Lambda persistent filesystems named ``name`` across ALL regions.

    ``name`` defaults to ``flash-weights`` (``WEIGHT_CACHE_VOLUME_NAME``). Deleting these
    filesystems reclaims the standing NFS cache storage.

    Best-effort and idempotent. Lambda refuses to delete a filesystem that is still in use (an
    instance is mounting it), so a live run keeps its cache; re-run teardown once the run finishes.
    A delete that raises is logged and skipped, so one filesystem never aborts the sweep of the
    others.

    Returns ``lambda:<region>/<name>`` per filesystem deleted. A missing/empty Lambda key is not an
    error (nothing to reclaim): the listing failure is logged and ``[]`` is returned.
    """
    from flash.providers.lambdalabs import api as lambda_api
    from flash.runner import WEIGHT_CACHE_VOLUME_NAME

    target = name or WEIGHT_CACHE_VOLUME_NAME
    deleted: list[str] = []
    try:
        fses = lambda_api.list_filesystems()
    except Exception as exc:
        logger.warning("teardown: lambda list_filesystems failed (skipping): %s", exc)
        return deleted
    for fs in fses:
        if fs.get("name") != target or not fs.get("id"):
            continue
        # Tolerate the region arriving as a dict ({"name": ...}) or as a bare value.
        region_info = fs.get("region")
        region = (region_info.get("name") if isinstance(region_info, dict) else region_info) or "?"
        try:
            ok = lambda_api.delete_filesystem(fs["id"])
        except Exception as exc:
            # One failing delete must not abort the sweep of the remaining regions.
            logger.warning("teardown: lambda delete_filesystem %s/%s failed (continuing): %s",
                           region, target, exc)
            continue
        if ok:
            deleted.append(f"lambda:{region}/{target}")
        else:
            logger.info("teardown: lambda filesystem %s/%s not deleted (possibly still in use) — "
                        "re-run teardown once its run finishes", region, target)
    return deleted
356
+
357
+
358
def teardown_hyperstack_volumes(name: str | None = None) -> list[str]:
    """Delete the Hyperstack cache volumes for base name ``name`` across ALL environments.

    ``name`` defaults to ``flash-weights`` (``WEIGHT_CACHE_VOLUME_NAME``). Deleting these volumes
    reclaims the standing block storage.

    Best-effort and idempotent. A volume attached to a live VM won't delete; re-run once the run
    finishes. A delete that raises is logged and skipped, so one volume never aborts the sweep of
    the others.

    Returns ``hyperstack:<env>/<name>`` per volume deleted. A missing Hyperstack key is not an
    error: the listing failure is logged and ``[]`` is returned.
    """
    from flash.providers.hyperstack import api as hs_api
    from flash.runner import WEIGHT_CACHE_VOLUME_NAME

    base = name or WEIGHT_CACHE_VOLUME_NAME
    deleted: list[str] = []
    try:
        vols = hs_api.list_volumes()
    except Exception as exc:
        logger.warning("teardown: hyperstack list_volumes failed (skipping): %s", exc)
        return deleted
    # Allowlist of EXACT deterministic cache-fleet names. It holds the per-region
    # ``flash-weights-<region>`` names this code provisions, PLUS the legacy bare ``flash-weights``
    # from before per-region naming. A broad ``startswith(base + "-")`` prefix would also nuke
    # unrelated user volumes like ``flash-weights-backup``, so only exact names match.
    fleet = {base}
    try:
        fleet |= {hs_api.cache_volume_name(base, r) for r in hs_api.cache_regions()}
    except Exception as exc:
        # cache_regions() failed (API down / auth), so the canonical region set can't be
        # enumerated. Without it, a fleet volume ``flash-weights-us-1`` can't be told apart from a
        # user volume ``flash-weights-backup-1``: both are region-shaped.
        # FAVOR DATA SAFETY. A missed cache volume is only recoverable leftover billing, while
        # deleting a user's volume is unrecoverable data loss. So delete ONLY the unambiguous
        # legacy bare ``base`` name and warn loudly that per-region volumes were left intact.
        logger.warning(
            "teardown: hyperstack cache_regions failed (%s) — could NOT enumerate per-region cache "
            "volumes; deleting only the legacy bare %r and LEAVING any per-region "
            "flash-weights-<region> volumes INTACT (re-run teardown once regions are reachable, or "
            "delete them manually). Refusing to pattern-match region-shaped names to avoid deleting "
            "unrelated user volumes.",
            exc, base,
        )
    for v in vols:
        vname = v.get("name") or ""
        if vname not in fleet or not v.get("id"):
            continue
        # Tolerate the environment arriving as a dict ({"name": ...}) or as a bare value.
        env_info = v.get("environment")
        env = (env_info.get("name") if isinstance(env_info, dict) else env_info) or "?"
        try:
            ok = hs_api.delete_volume(v["id"])
        except Exception as exc:
            # One failing delete must not abort the sweep of the remaining volumes.
            logger.warning("teardown: hyperstack delete_volume %s/%s failed (continuing): %s",
                           env, vname, exc)
            continue
        if ok:
            deleted.append(f"hyperstack:{env}/{vname}")
        else:
            logger.info("teardown: hyperstack volume %s/%s not deleted (possibly attached to a live "
                        "VM) — re-run teardown once its run finishes", env, vname)
    return deleted
407
+
408
+
409
# Warming the instance providers (Lambda, Hyperstack). RunPod is warmed through the serverless
# preload above, but Lambda and Hyperstack have no serverless API. Warming there means a real
# (cheap, short) GPU launch in download-only mode: the bootstrap pulls the catalog into the mounted
# cache and exits without starting a worker. The box reports completion by uploading
# ``preload_result.json`` to a shared status repo that the driver polls, and the instance is ALWAYS
# terminated in a ``finally``. The work is a download, not compute, so a cheap class is the default.
#
# Default warm GPU per provider: each must be a class that provider actually offers. A10 exists
# only on Lambda (no hyperstack_name). Asking Hyperstack for it would leave usable_instances("A10")
# empty and Hyperstack would silently never be warmed, so L40 is used there. An explicit --gpu /
# FLASH_PRELOAD_INSTANCE_GPU overrides both.
# NOTE(review): _PRELOAD_INSTANCE_GPU falls back to "A10" when the env var is unset, so on its own it
# cannot distinguish an explicit override from the default. Confirm callers prefer it over
# _PRELOAD_GPU_BY_PROVIDER only when FLASH_PRELOAD_INSTANCE_GPU is actually set.
_PRELOAD_INSTANCE_GPU = os.getenv("FLASH_PRELOAD_INSTANCE_GPU") or "A10"
_PRELOAD_GPU_BY_PROVIDER = {"lambda": "A10", "hyperstack": "L40"}
# Shared dataset repo that preload boxes upload their status marker to (the driver polls it). It
# holds only tiny status JSON; the warmed weights go to the per-region cache volume.
_PRELOAD_STATUS_REPO = os.getenv("FLASH_PRELOAD_STATUS_REPO") or "Freesolo-Co/flash-weight-preload"
424
+
425
+
426
def _ensure_status_repo(token: str | None) -> None:
    """Make sure the private preload status dataset repo exists, creating it if needed.

    Preload boxes upload their completion marker to this repo, and it is the ONLY completion signal
    the driver has. Any failure (missing/invalid HF token, no access) is therefore allowed to
    RAISE. Swallowing it would let every launched box run until ``timeout_s`` without ever
    producing ``preload_result.json``, burning paid GPUs and reporting nothing. Failing here,
    before anything launches, is the cheap outcome.
    """
    from huggingface_hub import HfApi

    hub = HfApi(token=token)
    hub.create_repo(
        _PRELOAD_STATUS_REPO,
        repo_type="dataset",
        exist_ok=True,
        private=True,
    )
436
+
437
+
438
def _preload_instance_spec(gpu: str, run_id: str, wall_s: int = 1800):
    """Build the minimal download-only ``JobSpec`` for an instance-provider preload.

    The spec does three things:
        - attaches the weight-cache volume;
        - points ``train.hf_repo`` at the status-marker repo;
        - names a placeholder model, since the bootstrap warms ``payload['models']``, not
          ``spec.model``.

    ``wall_s`` becomes the worker's wall-clock cap, floored at 60 s. Passing the warm timeout
    through means a long catalog warm isn't killed at the fixed 30-minute default while the driver
    is still polling.
    """
    from flash.runner import WEIGHT_CACHE_VOLUME_GB, WEIGHT_CACHE_VOLUME_NAME
    from flash.spec import JobSpec

    train_section = {"hf_repo": _PRELOAD_STATUS_REPO, "seeds": [0]}
    gpu_section = {
        "type": gpu,
        "max_wall_seconds": max(60, int(wall_s)),
        "network_volume": WEIGHT_CACHE_VOLUME_NAME,
        "network_volume_gb": WEIGHT_CACHE_VOLUME_GB,
    }
    raw_spec = {
        "model": "Qwen/Qwen3.5-0.8B",
        "algorithm": "sft",
        "run_id": run_id,
        "train": train_section,
        "gpu": gpu_section,
    }
    return JobSpec.from_dict(raw_spec)
452
+
453
+
454
+ def _warm_one_instance(provider: str, jobs_mod, candidate, models: list, gpu: str,
455
+ token: str | None, timeout_s: int, poll_interval_s: float) -> dict:
456
+ """Launch a download-only preload instance pinned to ``candidate``'s region, poll its status
457
+ marker, then ALWAYS terminate. One region failing never aborts the others."""
458
+ region = getattr(candidate, "region", "?")
459
+ # ONE effective budget shared by the worker wall cap AND the driver poll, so the two can't disagree.
460
+ # The worker spec floors the wall cap at 60s (a sub-minute cap can't even boot+download), so the
461
+ # driver must poll for that SAME floored budget — otherwise a `--timeout-s` under 60 would have the
462
+ # driver report timeout + terminate the box at e.g. 30s while the worker still had ~60s to finish,
463
+ # aborting an in-progress preload.
464
+ effective_s = max(60, int(timeout_s))
465
+ # Embed the wall-clock reap deadline in the name so an orphan sweep can free this box if THIS driver
466
+ # process dies before its ``finally`` (terminate_run_instances) — instance providers self-terminate
467
+ # nothing, so a lost driver would otherwise leak a billing box forever (see preload_box_reap_due).
468
+ reap_deadline = int(time.time()) + effective_s
469
+ run_id = preload_instance_run_id(provider, region, reap_deadline, uuid.uuid4().hex[:6])
470
+ spec = _preload_instance_spec(gpu, run_id, wall_s=effective_s)
471
+ prefix = f"{spec.phase}/{run_id}/seed0"
472
+ reader = make_hf_text_reader(_PRELOAD_STATUS_REPO, f"{prefix}/preload_result.json",
473
+ min_interval_s=max(5.0, poll_interval_s))
474
+ # ALSO watch the attempt-failure marker (<arm>_attempt0.json): if the box dies BEFORE run_preload
475
+ # uploads preload_result.json (docker/GPU never ready, image pull fails, the bootstrap crashes
476
+ # early), the worker/host failmark uploader still writes this terminal ok=false marker. Without
477
+ # watching it the driver would poll to the full effective_s on an already-dead box, burning paid
478
+ # GPU. The completion file is authoritative when present (success or partial), so check it FIRST.
479
+ fail_reader = make_hf_text_reader(_PRELOAD_STATUS_REPO, f"{prefix}/{provider}_attempt0.json",
480
+ min_interval_s=max(5.0, poll_interval_s))
481
+ try:
482
+ try:
483
+ jobs_mod.launch_and_submit(spec, seed=0, instances=[candidate], attempt=0,
484
+ mode="preload", models=models)
485
+ except Exception as exc: # no capacity / launch reject — skip this region (warm-on-first-run covers it)
486
+ return {"provider": provider, "region": region, "status": "error", "error": f"launch: {exc}"}
487
+ logger.info("warm %s/%s: launched preload (%d models)", provider, region, len(models))
488
+ deadline = time.time() + effective_s
489
+ text = None
490
+ while time.time() < deadline:
491
+ text = reader(force=True)
492
+ if text:
493
+ break
494
+ # No completion file yet — the terminal attempt marker is the backstop: ok=false means the
495
+ # box already died (stop polling, free it now), ok=true means the download SUCCEEDED but
496
+ # only the preload_result.json upload had a transient Hub blip (the worker still wrote a
497
+ # terminal ok=true marker), so the box is ALREADY warmed — short-circuit the wait instead
498
+ # of polling to the full budget then terminating a warmed box and reporting it timed out.
499
+ fail_text = fail_reader(force=True)
500
+ if fail_text:
501
+ try:
502
+ fail = json.loads(fail_text)
503
+ except Exception:
504
+ fail = {}
505
+ if fail.get("ok") is True:
506
+ # Terminal SUCCESS marker, completion file lost to a transient upload failure. Treat
507
+ # the marker itself as the result (the completion file is still authoritative when
508
+ # present, but it never landed here). "partial" if the marker carries an
509
+ # error/failed field, else "ok".
510
+ bad = fail.get("error") or fail.get("failed")
511
+ return {"provider": provider, "region": region,
512
+ "status": "partial" if bad else "ok", "result": fail}
513
+ if not fail.get("ok", True):
514
+ # The completion file (preload_result.json) is authoritative when present: a
515
+ # partial/failed-download run uploads it AND THEN writes the ok=false fail marker,
516
+ # so the marker can be visible an iteration before the completion file. Re-check
517
+ # the completion file ONE more time; if it's now there, fall through to the normal
518
+ # completion handling (-> "partial"/"ok") instead of mislabeling a completed
519
+ # (partial) preload as an early box death. Only the genuinely-still-absent case
520
+ # returns the early-death error.
521
+ text = reader(force=True)
522
+ if text:
523
+ break
524
+ return {"provider": provider, "region": region, "status": "error",
525
+ "error": f"box failed early: {fail.get('error') or 'see boot log'}"}
526
+ time.sleep(max(5.0, poll_interval_s))
527
+ if not text:
528
+ return {"provider": provider, "region": region, "status": "timeout"}
529
+ result = json.loads(text)
530
+ bad = result.get("error") or result.get("failed")
531
+ return {"provider": provider, "region": region,
532
+ "status": "partial" if bad else "ok", "result": result}
533
+ except Exception as exc:
534
+ return {"provider": provider, "region": region, "status": "error", "error": str(exc)}
535
+ finally:
536
+ with contextlib.suppress(Exception):
537
+ jobs_mod.terminate_run_instances(run_id)
538
+
539
+
540
def warm_instances(models: list | None = None, gpu: str | None = None,
                   providers: list | None = None, timeout_s: int = 1800,
                   poll_interval_s: float = 20.0, max_workers: int = 4) -> list[dict]:
    """WARM the Lambda + Hyperstack caches with one download-only launch per region.

    Only regions that currently have capacity are attempted (regions with no capacity right now are
    skipped — warm-on-first-run covers them). Each launch is pinned to its region, polled to
    completion, and terminated by ``_warm_one_instance``. Best-effort: a provider with no key / no
    capacity contributes nothing, and an unknown provider name is logged and skipped.

    NB: the preload logic itself ships in the cloud-init ``user_data`` — ``_instance.build_user_data``
    reads the current ``_instance_bootstrap.py`` from the repo and embeds it, so every launch runs the
    latest bootstrap (no image rebuild needed for the preload code). The only image requirement is the
    HF download deps (huggingface_hub + hf_transfer), which the worker image already carries.

    Args:
        models: HF model ids to preload; falsy -> the whole catalog (``catalog_model_ids()``).
        gpu: explicit GPU override applied to every provider; None -> each provider's default.
        providers: provider names to warm; falsy -> ``["lambda", "hyperstack"]``.
        timeout_s: per-region budget handed to ``_warm_one_instance``.
        poll_interval_s: status-marker poll interval handed to ``_warm_one_instance``.
        max_workers: number of regions warmed concurrently.

    Returns:
        One status dict per region attempted, in launch-target order (provider order, then the order
        ``usable_instances`` listed the regions). ``[]`` when there is nothing to warm.

    Raises:
        RuntimeError: the preload status repo can't be created/accessed (e.g. missing/invalid
            HF_TOKEN) while there is at least one region to warm.
    """
    models = models or catalog_model_ids()
    providers = providers or ["lambda", "hyperstack"]
    token = os.environ.get("HF_TOKEN")

    from flash.providers.hyperstack import api as hs_api
    from flash.providers.hyperstack import jobs as hs_jobs
    from flash.providers.lambdalabs import jobs as lambda_jobs

    mods = {"lambda": lambda_jobs, "hyperstack": hs_jobs}
    # Per-provider "can this region host the cache?" predicate. Skipping a cache-incapable region (e.g.
    # Hyperstack CANADA-2, excluded from cache_regions()) BEFORE launching avoids burning a paid GPU
    # whose preload just reports "weight cache not supported in region" — which main() then counts as a
    # failed warm, so the default --warm-instances would fail even when every cache-capable region
    # succeeded. Lambda exposes no such filter (every region hosts filesystems), so it stays unfiltered.
    region_ok = {"hyperstack": hs_api.region_supports_cache}
    # One launch per region (dedupe so two candidates in a region don't double-launch — block volumes
    # are single-attach anyway). Each entry carries its provider's resolved GPU (an explicit override
    # applies to all; otherwise the per-provider default — so A10 doesn't silently skip Hyperstack).
    targets: list = []
    for provider in providers:
        jobs_mod = mods.get(provider)
        if jobs_mod is None:
            # Previously dropped silently; surface it so a typo'd provider name isn't a mystery no-op.
            logger.warning("warm: unknown provider %r (skipping; known: %s)", provider, ", ".join(mods))
            continue
        provider_gpu = gpu or _PRELOAD_GPU_BY_PROVIDER.get(provider, _PRELOAD_INSTANCE_GPU)
        cache_capable = region_ok.get(provider)
        seen_regions: set = set()
        try:
            candidates = jobs_mod.usable_instances(provider_gpu)
        except Exception as exc:
            logger.warning("warm %s: usable_instances(%s) failed (skipping): %s", provider, provider_gpu, exc)
            continue
        for c in candidates:
            if c.region in seen_regions:
                continue
            # Mark the region seen either way, so a cache-incapable region is skipped (and logged) once.
            seen_regions.add(c.region)
            # Skip regions that can't host the cache for this provider — the preload would just report
            # "weight cache not supported in region" and be counted as a failed warm.
            if cache_capable is not None and not cache_capable(c.region):
                logger.info("warm %s: skipping cache-incapable region %s", provider, c.region)
                continue
            targets.append((provider, jobs_mod, c, provider_gpu))
    if not targets:
        logger.warning("warm: no Lambda/Hyperstack capacity right now (nothing to warm)")
        return []
    # Fail fast BEFORE launching any paid GPU: the status repo is the only completion signal, so if it
    # can't be created/accessed (missing/invalid HF_TOKEN) every box would just run to timeout warming
    # nothing observable. Surface a clear error instead of silently burning instances. Done only AFTER
    # the target list is built and the no-targets early-return above, so an empty warm (no capacity /
    # provider not configured) stays a harmless no-op and doesn't hard-fail on a missing HF_TOKEN.
    try:
        _ensure_status_repo(token)
    except Exception as exc:
        raise RuntimeError(
            f"preload status repo {_PRELOAD_STATUS_REPO!r} unavailable ({exc}); set a valid HF_TOKEN "
            "with write access before warming (refusing to launch paid GPUs that can't report)."
        ) from exc
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futs = [
            ex.submit(_warm_one_instance, provider, jobs_mod, c, models, provider_gpu, token, timeout_s, poll_interval_s)
            for (provider, jobs_mod, c, provider_gpu) in targets
        ]
        # Collect in submission (target) order. Every future is awaited before returning either way,
        # so as_completed() bought nothing but a nondeterministic result/printout order.
        return [f.result() for f in futs]
616
+
617
+
618
def provision_lambda_filesystems(name: str | None = None) -> list[str]:
    """Create-if-absent the weight-cache filesystem in each Lambda region enumerable right now.

    Pure control-plane API, no GPU: the cache storage exists before any run lands.

    Lambda has no standalone region list, so ``all_regions()`` is the UNION of regions currently
    advertising capacity across instance types — a region advertising ZERO capacity right now is not
    covered here. The launch-time ``ensure_filesystem`` backstop creates the filesystem the moment a
    run actually lands in such a region, so this is a best-effort eager warm, not a guarantee of
    coverage in every region Lambda might ever expose.

    Args:
        name: filesystem name; None -> ``WEIGHT_CACHE_VOLUME_NAME``.

    Returns:
        ``lambda:<region>`` for each region where ``ensure_filesystem`` succeeded (it reuses an
        existing same-name filesystem, so repeat calls are safe). If region enumeration raises (e.g.
        no Lambda key configured) the error is logged and ``[]`` is returned. A per-region failure is
        logged and skipped, so one bad region never aborts the rest.
    """
    from flash.providers.lambdalabs import api as lambda_api
    from flash.runner import WEIGHT_CACHE_VOLUME_NAME

    fs_name = name or WEIGHT_CACHE_VOLUME_NAME
    provisioned: list[str] = []

    try:
        lambda_regions = lambda_api.all_regions()
    except Exception as err:
        logger.warning("provision: lambda all_regions failed (skipping): %s", err)
        return provisioned

    for lambda_region in lambda_regions:
        try:
            lambda_api.ensure_filesystem(fs_name, lambda_region)
        except Exception as err:
            logger.warning("provision: lambda ensure_filesystem(%s, %s) failed: %s", fs_name, lambda_region, err)
            continue
        provisioned.append(f"lambda:{lambda_region}")
    return provisioned
650
+
651
+
652
def provision_hyperstack_volumes(name: str | None = None, size_gb: int | None = None) -> list[str]:
    """Eagerly create the per-region weight-cache block volume in every cache-capable Hyperstack region.

    Create-if-absent, so the cache storage exists before any run lands — pure control-plane API, no
    GPU. Idempotent (``ensure_volume`` reuses an existing same-name volume in the env).

    Args:
        name: base volume name; None -> ``WEIGHT_CACHE_VOLUME_NAME``. The real per-region name comes
            from ``hs_api.cache_volume_name(base, region)``.
        size_gb: volume size; falsy -> ``WEIGHT_CACHE_VOLUME_GB``.

    Returns:
        ``hyperstack:<region>`` for each region whose volume was created/confirmed with a real id.
        If ``cache_regions()`` raises (e.g. no Hyperstack key) the error is logged and ``[]`` is
        returned. A per-region failure is logged (naming the actual per-region volume) and skipped.
    """
    from flash.providers.hyperstack import api as hs_api
    from flash.runner import WEIGHT_CACHE_VOLUME_GB, WEIGHT_CACHE_VOLUME_NAME

    base = name or WEIGHT_CACHE_VOLUME_NAME
    gb = int(size_gb or WEIGHT_CACHE_VOLUME_GB)
    done: list[str] = []
    try:
        # cache_regions() drops volume-incapable regions (e.g. CANADA-2) so we don't burn a
        # guaranteed-400 create on a region that can't host the cache anyway.
        regions = hs_api.cache_regions()
    except Exception as exc:
        logger.warning("provision: hyperstack cache_regions failed (skipping): %s", exc)
        return done
    # One PER-REGION volume (Hyperstack names are globally unique — see cache_volume_name), created in
    # that region's default environment.
    for region in regions:
        # Start from the base name; once the per-region name is computed, a failure log reports THAT
        # (the volume actually being created), matching the "returned no id" warning below.
        vol_name = base
        try:
            env = hs_api.environment_for_region(region)
            vol_name = hs_api.cache_volume_name(base, region)
            vol_id = hs_api.ensure_volume(vol_name, env, gb)
        except Exception as exc:
            logger.warning("provision: hyperstack ensure_volume(%s, %s) failed: %s", vol_name, region, exc)
            continue
        # ensure_volume returns the volume id; a falsy id means create-or-confirm did NOT yield a
        # real volume (e.g. the API responded without an id). Don't record that region as
        # provisioned — otherwise --provision reports success and the launch path treats a
        # never-created region as warm.
        if not vol_id:
            logger.warning("provision: hyperstack ensure_volume(%s, %s) returned no id — region not "
                           "provisioned", vol_name, region)
            continue
        done.append(f"hyperstack:{region}")
    return done
693
+
694
+
695
def provision_all() -> list[str]:
    """Eagerly create the cache storage on every instance provider, in every region/environment.

    Pure control-plane API, no GPU. RunPod's per-DC network volumes are NOT provisioned here: they
    are create-or-attached automatically by the eager endpoint deploy (jobs.weight_cache_volumes
    covers every storage DC) and warmed by ``warm_weight_cache`` — there is no GPU-free RunPod
    volume-create in the SDK.

    Each provider is guarded INDEPENDENTLY, mirroring the ``--teardown`` path. An exception escaping
    one provider's helper (e.g. its SDK/runner import, which runs outside that helper's per-region
    ``try``) is logged, and the other provider is still provisioned.

    Returns:
        ``provider:<region/env>`` per storage created/confirmed, Lambda entries first.
    """
    provisioned: list[str] = []
    for label, provision in (
        ("Lambda", provision_lambda_filesystems),
        ("Hyperstack", provision_hyperstack_volumes),
    ):
        try:
            provisioned += provision()
        except Exception as exc:
            logger.warning("provision: %s cache provisioning failed (continuing): %s", label, exc)
    return provisioned
704
+
705
+
706
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: provision, warm, or tear down the flash weight-cache storage.

    Exactly one mode runs per invocation:
      * ``--provision``       create the Lambda/Hyperstack cache storage (GPU-free) and exit;
      * ``--warm-instances``  one download-only GPU launch per Lambda/Hyperstack region with capacity;
      * ``--teardown``        delete the cache storage (RunPod-only when ``--datacenters`` is given);
      * (no mode flag)        warm the RunPod serverless fleet, per datacenter.
    ``--dry-run`` prints the plan for any mode instead of touching storage or launching anything.

    Args:
        argv: arguments to parse; None -> argparse's default (the process command line).

    Returns:
        The process exit status:
          * 0 on success, including a best-effort no-op;
          * 1 when any warm target did not finish ``ok``;
          * 2 on invalid ``--models`` / ``--datacenters`` input.
        Conflicting mode flags exit through ``ArgumentParser.error`` instead.
    """
    ap = argparse.ArgumentParser(description="Preload the flash weight-cache volumes.")
    ap.add_argument("--models", help="comma-separated HF model ids (default: whole catalog)")
    ap.add_argument("--datacenters", help="comma-separated DC ids (default: all storage DCs)")
    ap.add_argument(
        "--gpu", default=None,
        help="GPU class for the preload worker. Defaults are per-mode (RunPod warm -> "
             f"{_PRELOAD_GPU!r}; --warm-instances -> {_PRELOAD_INSTANCE_GPU!r}); pass this to override "
             "either. Defaulting to None (not a sentinel string) lets you explicitly pick even the "
             "per-mode default GPU without it being mistaken for 'no override'.",
    )
    ap.add_argument("--timeout-s", type=int, default=1800,
                    help="per-DC job timeout in seconds (per-region with --warm-instances)")
    ap.add_argument(
        "--max-workers", type=int, default=4,
        help="datacenters warmed concurrently. Each one deploys a preload endpoint, so this MUST stay "
             "under your RunPod endpoint/worker quota (the documented default is 5); the default of 4 "
             "leaves a 1-slot buffer. Raise it only if your account quota is higher.",
    )
    ap.add_argument("--dry-run", action="store_true", help="print the plan, provision nothing")
    ap.add_argument(
        "--provision", action="store_true",
        help="CREATE the Lambda/Hyperstack cache storage in every region/env (pure API, no GPU) and "
             "exit; RunPod volumes are auto-created by the eager deploy/warm. The inverse of "
             "--teardown: run it to set up all storage up front.",
    )
    ap.add_argument(
        "--warm-instances", action="store_true",
        help="WARM the Lambda + Hyperstack caches: one download-only GPU launch per region with "
             "capacity now (needs the merged worker image carrying the bootstrap preload branch).",
    )
    ap.add_argument(
        "--teardown", action="store_true",
        help="DELETE the weight-cache storage on every provider (reclaim standing storage) and exit. "
             "With --datacenters it is SCOPED to that RunPod-DC subset only (Lambda/Hyperstack caches "
             "are left intact, since DC ids don't map to their region/env namespace).",
    )
    args = ap.parse_args(argv)

    # The mode flags are MUTUALLY EXCLUSIVE: each selects a different exit-early branch below, and the
    # branch order (provision -> warm-instances -> teardown -> default RunPod warm) silently picks ONE
    # when several are set — e.g. `--teardown --warm-instances` would launch paid warm jobs (the warm
    # branch runs first) instead of deleting caches, AND bypass the off-catalog --models check (the
    # teardown exemption short-circuits it). Reject the conflict up front so the off-catalog gate always
    # applies to whichever warm branch actually executes. The default RunPod warm has no flag, so it's
    # only reachable when NONE of these are set — it can't conflict.
    selected_modes = [
        name for name, on in (
            ("--provision", args.provision),
            ("--warm-instances", args.warm_instances),
            ("--teardown", args.teardown),
        ) if on
    ]
    if len(selected_modes) > 1:
        ap.error(f"{', '.join(selected_modes)} are mutually exclusive — pass exactly one mode")

    catalog = catalog_model_ids()
    catalog_ids = set(catalog)  # built once; previously rebuilt for every requested id
    # Same "given vs. parsed" split as --datacenters below: argparse yields None when --models is
    # OMITTED, while an empty / all-comma string is still "given" but parses to zero ids.
    models_given = args.models is not None
    parsed_models = (
        [m.strip() for m in args.models.split(",") if m.strip()] if models_given else []
    )
    models = parsed_models or catalog
    # Confidentiality gate: an explicit --models override may ONLY name public catalog ids on the paths
    # that actually DOWNLOAD weights into the shared cache (the default RunPod warm + --warm-instances).
    # Warming an arbitrary (private/gated) repo with the operator HF_TOKEN would leave those weights on
    # the platform-wide WRITABLE shared cache for every other tenant — bypassing the same catalog gate
    # the normal run path enforces. Reject any off-catalog id BEFORE launching any preload worker.
    # --teardown (only deletes) and --provision (only CREATES empty storage — downloads NOTHING) are
    # both exempt: neither reaches a download path, so an off-catalog id there is harmless.
    if models_given and not args.teardown and not args.provision:
        if not parsed_models:
            # An empty list downstream means "whole catalog" (warm_instances' `models or catalog`), so
            # a malformed narrowing would silently warm EVERYTHING on paid GPUs. Refuse instead.
            print("--models was given but parsed to no model ids — refusing to warm (an empty list "
                  "would fall back to the WHOLE catalog); drop --models to warm the full catalog, or "
                  "pass real model ids.")
            return 2
        off_catalog = [m for m in parsed_models if m not in catalog_ids]
        if off_catalog:
            print("--models: refusing to preload off-catalog model id(s) into the shared cache: "
                  f"{', '.join(off_catalog)} — only public catalog models may be warmed (private/gated "
                  "repos would leak onto the platform-wide shared volume). They download cold on first "
                  "use instead.")
            return 2
    # Parse --datacenters ONCE. `scoped` means the operator actually narrowed to >=1 real DC id — NOT
    # merely that the flag was present. A flag that parses to NOTHING (e.g. `--datacenters ""`, all
    # whitespace/commas, or an all-invalid list) must be an ERROR, never a silent full teardown: it
    # would otherwise (a) hit teardown_weight_cache's `datacenters or <all>` fallback and delete EVERY
    # RunPod cache, while (b) the present-but-empty flag skipped the instance-provider cleanup.
    # argparse default is None when --datacenters is OMITTED; an empty/whitespace/all-comma STRING is
    # still "provided" (`is not None`) but parses to zero ids. Use `is not None` — NOT truthiness — so
    # `--datacenters ""` is caught too (bool("") is False).
    dcs_given = args.datacenters is not None
    parsed_dcs = (
        [d.strip() for d in args.datacenters.split(",") if d.strip()] if dcs_given else []
    )
    if dcs_given and not parsed_dcs:
        print("--datacenters was given but parsed to no datacenter ids — refusing to run "
              "(an empty scope would delete the WHOLE RunPod fleet); drop --datacenters for a full "
              "teardown, or pass real DC ids.")
        return 2
    scoped = bool(parsed_dcs)  # a real RunPod-DC subset -> RunPod-only scope

    # LAZY default RunPod-DC list. weight_cache_datacenters() imports runpod_flash, so resolving it
    # eagerly here would crash --provision / --warm-instances / --teardown --dry-run on an instance-only
    # control plane (no/broken RunPod SDK) — modes that never touch a RunPod DC, or are non-destructive.
    # Resolve it ONLY inside the branches that actually warm or tear down RunPod without an explicit
    # --datacenters scope. When `scoped`, `parsed_dcs` is used directly and this is never called.
    def _default_dcs() -> list[str]:
        return [dc.value for dc in weight_cache_datacenters()]

    if args.provision:
        # Eagerly create the instance-provider cache storage in every region/env (GPU-free). RunPod's
        # per-DC fleet materializes on the next eager endpoint deploy / warm, so it's not created here.
        if args.dry_run:
            print("would provision Lambda filesystems + Hyperstack volumes in every region/env")
            return 0
        provisioned = provision_all()
        print(f"provisioned {len(provisioned)} instance-provider cache store(s): "
              f"{', '.join(provisioned) or '(none — no Lambda/Hyperstack key, or no regions)'}")
        return 0
    if args.warm_instances:
        if args.dry_run:
            print("would warm Lambda + Hyperstack caches (one download-only launch per region with capacity)")
            return 0
        # gpu=None lets warm_instances apply its own per-mode default (_PRELOAD_INSTANCE_GPU). Passing
        # args.gpu directly (no sentinel comparison) means an explicit --gpu, even RTX 4090, overrides.
        results = warm_instances(models=models, gpu=args.gpu,
                                 timeout_s=args.timeout_s, max_workers=args.max_workers)
        if not results:
            # NOT the same as "warmed everything": zero launch targets means no Lambda/Hyperstack
            # region had capacity to warm right now (or every candidate region is cache-incapable —
            # each such skip is logged above). This is a best-effort no-op, not a failure: those
            # regions' weights simply download cold on first run. Make it explicit so "0/0" isn't read
            # as success. (See the per-region "skipping cache-incapable region" / "no capacity" logs.)
            print("0 regions warmed — no Lambda/Hyperstack region had capacity to warm right now "
                  "(weights download cold on first run). Nothing launched.")
            return 0
        # "partial", "timeout" and "error" all count as failures.
        failed = [r for r in results if r.get("status") != "ok"]
        for r in results:
            print(f"  {r['provider']}/{r['region']}: {r['status']}"
                  + (f" ({r.get('error')})" if r.get("error") else ""))
        print(f"{len(results) - len(failed)}/{len(results)} regions warmed")
        return 1 if failed else 0
    if args.teardown:
        # Validate any scoped DC ids BEFORE deleting anything: an invalid id (typo in teardown
        # automation) must fail loudly with a non-zero exit, NOT get swallowed by the best-effort
        # catch below and report success while deleting nothing / leaving the billed fleet in place.
        if scoped:
            from runpod_flash.core.resources.datacenter import DataCenter
            bad = []
            for d in parsed_dcs:
                try:
                    DataCenter.from_string(d)
                except Exception:
                    bad.append(d)
            if bad:
                print(f"--teardown --datacenters: invalid datacenter id(s): {', '.join(bad)} "
                      "— refusing to run (nothing deleted)")
                return 2
        if args.dry_run:
            # `--teardown --dry-run` must only PRINT the plan — never call the destructive helpers AND
            # never resolve the full RunPod DC list (weight_cache_datacenters imports runpod_flash):
            # describe the scope abstractly when unscoped so this stays usable on an instance-only host.
            scope_desc = (f"{len(parsed_dcs)} datacenter(s): {', '.join(parsed_dcs)}"
                          if scoped else "every RunPod storage datacenter")
            print(f"would delete the RunPod weight-cache volumes in {scope_desc}"
                  + ("" if scoped else " + every Lambda filesystem + Hyperstack volume named flash-weights"))
            return 0
        # Reclaim the cache storage on EVERY provider: RunPod network volumes, Lambda filesystems,
        # and Hyperstack block volumes. Each provider is guarded INDEPENDENTLY so one provider's
        # failure (e.g. RunPod auth absent/broken on an instance-only control plane, or a RunPod
        # outage) never aborts the others' best-effort cleanup — otherwise their billed caches would
        # leak behind a single RunPod error. A provider with no configured key is already a no-op.
        deleted: list[str] = []
        try:
            # Pass the scoped list, or None for a full teardown — teardown_weight_cache resolves the
            # default DC fleet itself (lazily, and only AFTER it confirms a RunPod key is configured),
            # so an instance-only control plane never imports the RunPod SDK here.
            deleted += teardown_weight_cache(parsed_dcs or None)
        except Exception as exc:
            logger.warning("teardown: RunPod cache teardown failed (continuing): %s", exc)
        # `--datacenters` is a RunPod-DC subset and has no meaning for the instance providers (Lambda
        # regions / Hyperstack envs are a different namespace), so a SCOPED teardown stays RunPod-only
        # rather than unexpectedly deleting every Lambda/Hyperstack `flash-weights` cache too. Only a
        # FULL teardown (no --datacenters) reclaims the instance-provider caches.
        if not scoped:
            try:
                deleted += teardown_lambda_filesystems()
            except Exception as exc:
                logger.warning("teardown: Lambda cache teardown failed (continuing): %s", exc)
            try:
                deleted += teardown_hyperstack_volumes()
            except Exception as exc:
                logger.warning("teardown: Hyperstack cache teardown failed (continuing): %s", exc)
        else:
            print("scoped teardown (--datacenters): RunPod-only; Lambda/Hyperstack caches left intact")
        print(f"deleted {len(deleted)} weight-cache volume(s): {', '.join(deleted) or '(none)'}")
        return 0
    # Default mode = warm the RunPod serverless fleet. This is the ONE path that genuinely needs the
    # RunPod DC list (and the SDK), so resolve the lazy default here rather than eagerly above.
    dcs = parsed_dcs or _default_dcs()
    if args.dry_run:
        print(f"would warm {len(dcs)} datacenter(s): {', '.join(dcs)}")
        print(f"with {len(models)} model(s): {', '.join(models)}")
        return 0

    results = warm_weight_cache(
        # args.gpu defaults to None -> fall back to the RunPod warm default here so None never reaches
        # _preload_one_dc / deploy_train_endpoint; an explicit --gpu (incl. RTX 4090) still overrides.
        models=models, datacenters=dcs, gpu=args.gpu or _PRELOAD_GPU,
        timeout_s=args.timeout_s, max_workers=args.max_workers,
    )
    failed = [r for r in results if r.get("status") != "ok"]
    for r in results:
        print(f"  {r['datacenter']}: {r['status']}" + (f" ({r.get('error')})" if r.get("error") else ""))
    print(f"{len(results) - len(failed)}/{len(results)} datacenters warmed")
    return 1 if failed else 0
912
+
913
+
914
if __name__ == "__main__":
    # main() returns the exit status (0 success, 1 a warm target failed, 2 invalid arguments);
    # hand it to the interpreter as the process exit code.
    exit_status = main()
    raise SystemExit(exit_status)