freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,122 @@
1
+ """Pure, monkeypatch-free building blocks for the Hyperstack run lifecycle.
2
+
3
+ The Hyperstack-specific leaf of ``flash.providers.hyperstack.jobs``: the normalized dataclasses
4
+ (``HyperstackInstance``, ``HyperstackJobHandle``) and the image accessor. The cross-provider
5
+ pieces (sweep label, bootstrap payload, cloud-init ``user_data``) are shared with Lambda in
6
+ ``flash.providers._instance`` and re-exported here.
7
+
8
+ This module MUST NOT import the ``jobs`` package ``__init__`` (it is imported BY it).
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from dataclasses import dataclass
14
+
15
+ from flash.providers._instance import (
16
+ build_payload as _shared_build_payload,
17
+ )
18
+ from flash.providers._instance import (
19
+ build_user_data as _shared_build_user_data,
20
+ )
21
+ from flash.providers._instance import (
22
+ instance_label,
23
+ run_label_prefix,
24
+ )
25
+
26
# Public surface: the Hyperstack dataclasses + image accessor, plus the shared builders re-exported.
__all__ = [
    "HyperstackInstance", "HyperstackJobHandle",
    "build_payload", "build_user_data",
    "hyperstack_image",
    "instance_label", "run_label_prefix",
]
35
+
36
+
37
@dataclass(frozen=True)
class HyperstackInstance:
    """One launchable Hyperstack offer for a managed GPU class.

    Pairs a (region, flavor) with its $/hr price; it is the Hyperstack counterpart of a Lambda
    instance candidate. Frozen, so values are immutable and hashable.
    """

    gpu: str  # canonical GPU class name (a GPU_INFO key)
    flavor: str  # Hyperstack flavor name, e.g. "n3-L40x1"
    region: str
    environment: str  # "default-<region>"
    vram_gb: int
    price_usd_hr: float
48
+
49
+
50
@dataclass
class HyperstackJobHandle:
    """Reattachable record of one Hyperstack VM launch, persisted in ``RunStatus.remote``.

    ``to_dict`` serializes it so that any process can later rebuild it with ``from_dict`` and
    reattach to / cancel the VM (compare ``base.JobHandle``).
    """

    vm_id: str
    flavor: str
    region: str
    name: str  # run-derived, sweep-matchable VM name (see ``instance_label``)
    gpu: str
    hourly_usd: float
    attempt: int
    started_ts: float

    def to_dict(self) -> dict:
        """Serialize to a plain dict tagged with ``provider="hyperstack"``."""
        keys = (
            "vm_id", "flavor", "region", "name", "gpu", "hourly_usd", "attempt", "started_ts",
        )
        payload: dict = {"provider": "hyperstack"}
        payload.update((key, getattr(self, key)) for key in keys)
        return payload

    @classmethod
    def from_dict(cls, d: dict) -> HyperstackJobHandle:
        """Rebuild a handle from ``to_dict`` output.

        ``vm_id`` is required (``KeyError`` when absent). Every other field is optional: a missing
        or falsy value becomes ``""`` for text fields and ``0`` for numeric ones.
        """

        def text(key: str) -> str:
            return str(d.get(key) or "")

        def number(key: str, kind):
            return kind(d.get(key) or 0)

        return cls(
            vm_id=str(d["vm_id"]),
            flavor=text("flavor"),
            region=text("region"),
            name=text("name"),
            gpu=text("gpu"),
            hourly_usd=number("hourly_usd", float),
            attempt=number("attempt", int),
            started_ts=number("started_ts", float),
        )
88
+
89
+
90
def hyperstack_image(gpu: str | None = None) -> str:
    """Return the container image the Hyperstack cloud-init pulls and runs.

    This is the prebuilt, PUBLIC ``WORKER_IMAGE`` (the same training stack RunPod bakes).
    ``FLASH_WORKER_IMAGE`` overrides it. When the operator opts into per-SM warmed images
    (``FLASH_WORKER_IMAGE_PER_SM`` / ``FLASH_WORKER_IMAGE_TEMPLATE``), ``gpu`` selects the matching
    ``-smXX`` tag, so the baked kernel cache matches the rented GPU's architecture. The selector
    is the same one RunPod uses.

    NB: this is the *container* image only. The Hyperstack VM *boot* image (Docker-preinstalled
    Ubuntu/CUDA) is chosen separately in ``api.docker_image_for_region``.
    """
    from flash.providers.runpod.train import WORKER_IMAGE, worker_image_for_gpu

    # allow_default=True asks the selector for a concrete image (override / per-SM tag / base);
    # WORKER_IMAGE remains the last-resort fallback should it still come back empty.
    selected = worker_image_for_gpu(gpu, allow_default=True)
    return selected if selected else WORKER_IMAGE
102
+
103
+
104
def build_payload(
    spec, seed: int, attempt: int, runtime_secrets: dict | None = None,
    cache_host_mount: str | None = None, cache_block_device: bool = False,
    mode: str | None = None, models: list | None = None,
) -> dict:
    """Build the Hyperstack bootstrap payload via the shared builder with ``arm="hyperstack"``.

    ``cache_host_mount`` is the host path where the attached block volume is formatted and
    mounted; HF_HOME is pointed at it. ``cache_block_device`` enables the cloud-init
    wait-for-device/format/mount preamble. ``mode="preload"`` together with ``models`` produces
    a download-only warm payload (no worker).
    """
    options = {
        "runtime_secrets": runtime_secrets,
        "cache_host_mount": cache_host_mount,
        "cache_block_device": cache_block_device,
        "mode": mode,
        "models": models,
    }
    return _shared_build_payload(spec, seed, attempt, arm="hyperstack", **options)
118
+
119
+
120
def build_user_data(payload: dict, *, gpu: str | None = None) -> str:
    """Render the Hyperstack cloud-init ``user_data`` for ``payload``.

    Delegates to the shared builder, with the container image fixed to ``hyperstack_image(gpu)``
    (the worker ``WORKER_IMAGE`` or its per-GPU variant).
    """
    image = hyperstack_image(gpu)
    return _shared_build_user_data(payload, image=image)
@@ -0,0 +1,23 @@
1
+ """Fail-fast credential checks for the Hyperstack substrate (operator-side).
2
+
3
+ Mirrors ``providers/lambdalabs/preflight.py``. Hyperstack is OPT-IN (the allocator only reaches for
4
+ it when ``HYPERSTACK_API_KEY`` is set), so the only Hyperstack-specific requirement is the key;
5
+ HF_TOKEN is the shared run requirement checked once centrally.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from flash.providers.hyperstack.auth import load_api_key
11
+
12
+
13
def missing_credentials(require_hf: bool = True) -> list[str]:
    """List missing Hyperstack operator configuration; an empty list means ready.

    ``require_hf`` is ignored. It exists only so the signature matches the RunPod check, because
    HF_TOKEN is validated once centrally.
    """
    if load_api_key():
        return []
    return [" - HYPERSTACK_API_KEY: the operator's Hyperstack API key (for the hyperstack provider)"]
@@ -0,0 +1,26 @@
1
+ """Hyperstack $/hr: static list price per class (the flavors API carries no price field).
2
+
3
+ Unlike Lambda's ``/instance-types`` (which carries live prices), Hyperstack's ``/core/flavors``
4
+ exposes capacity (``stock_available``) but not price, so rates come from the published list-price
5
+ snapshot below. Offline-safe and stable (Hyperstack prices are fixed list rates, not an auction).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
# Hyperstack list prices (snapshot 2026-06-25, hyperstack.cloud/gpu-pricing), per single GPU.
_STATIC_RATES: dict[str, float] = {
    "RTX A6000": 0.50,
    "L40": 1.00,
    "A100 PCIe": 1.35,
    "H100": 1.90,
    "RTX Pro 6000": 1.80,
}


def hourly_rate(gpu_name: str) -> float:
    """Return the static Hyperstack list price ($/hr) for a friendly GPU name.

    The name is canonicalized first. Classes without a (truthy) snapshot entry fall back to the
    class's nominal ``GPU_INFO`` rate, so the function always returns a number for known classes.
    """
    from flash.providers.base import GPU_INFO, canonical_gpu

    canonical = canonical_gpu(gpu_name)
    listed = _STATIC_RATES.get(canonical)
    if listed:
        return listed
    return GPU_INFO[canonical].hourly_usd
@@ -0,0 +1,25 @@
1
+ """Hyperstack train submission: build the VM payload + submit a run.
2
+
3
+ The worker stack/env is substrate-neutral, so the per-run worker env and dependency resolution are
4
+ shared with RunPod; this module owns the Hyperstack-specific submission entrypoint. Provisioning,
5
+ polling, and teardown live in ``providers/hyperstack/jobs``.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from flash.providers.hyperstack.jobs import build_payload, submit_run_hyperstack
11
+ from flash.providers.runpod.train import (
12
+ WORKER_DEPS,
13
+ WORKER_SYSTEM_DEPS,
14
+ build_worker_env,
15
+ resolve_worker_deps,
16
+ )
17
+
18
# Re-exported submission surface: Hyperstack's entrypoint plus the substrate-neutral RunPod pieces.
__all__ = [
    "WORKER_DEPS", "WORKER_SYSTEM_DEPS",
    "build_payload", "build_worker_env", "resolve_worker_deps",
    "submit_run_hyperstack",
]
+ ]
@@ -0,0 +1,139 @@
1
+ """Lambda Cloud provider: single-GPU instances bootstrapped via cloud-init (the instance-based
2
+ complement to RunPod's serverless Flash endpoints).
3
+
4
+ Fine-tuning runs on a Lambda Cloud GPU instance launched by Flash. The instance's cloud-init
5
+ ``user_data`` runs the prebuilt, PUBLIC ``WORKER_IMAGE`` via Docker (the byte-identical training
6
+ stack RunPod bakes), which executes ``flash.engine.worker`` on the GPU; completion is detected
7
+ purely from the worker's HF artifacts (no inbound network, no serverless queue). It implements the
8
+ SAME ``base.Provider`` interface as RunPod, so the orchestrator/allocator treat the two
9
+ interchangeably.
10
+
11
+ ``PROVIDER`` is the ``base.Provider`` implementation the registry hands out.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from collections.abc import Callable
17
+ from typing import Any
18
+
19
+ from flash.providers.base import GpuClass, JobHandle, PollResult, Provider
20
+
21
+
22
class LambdaProvider:
    """``base.Provider`` implementation for the Lambda Cloud substrate.

    Every Lambda helper is imported inside the method that needs it. As a result, importing
    this module and constructing the provider do not load the Lambda API/job modules.
    """

    name = "lambda"

    def is_configured(self) -> bool:
        """True when the operator's Lambda API key is present."""
        from flash.providers.lambdalabs.auth import load_api_key

        # Lambda is an opt-in instance substrate: it is available only when its operator key is
        # present. Without LAMBDA_API_KEY (tests / CI / RunPod-only operators) allocation degrades
        # deterministically to RunPod's catalog — exactly the prior RunPod-only behavior.
        return load_api_key() is not None

    def preflight(self, require_hf: bool = True) -> list[str]:
        """Missing Lambda operator configuration (an empty list means ready)."""
        from flash.providers.lambdalabs.preflight import missing_credentials

        return missing_credentials(require_hf=require_hf)

    def gpu_classes(self) -> list[GpuClass]:
        """The GPU classes Lambda can provision."""
        from flash.providers.lambdalabs.gpus import gpu_classes

        return gpu_classes()

    def hourly_rate(self, gpu: str) -> float:
        """$/hr for GPU class ``gpu`` on Lambda."""
        from flash.providers.lambdalabs.pricing import hourly_rate

        return hourly_rate(gpu)

    def submit_run(
        self,
        spec,
        seed: int,
        *,
        log: Any = None,
        on_handle: Any = None,
        attempt: int = 0,
        runtime_secrets: dict[str, str] | None = None,
        on_last_gpu: bool = False,
    ) -> PollResult:
        """Launch and drive one seed of ``spec`` on Lambda (delegates to ``submit_run_lambda``)."""
        from flash.providers.lambdalabs.jobs import submit_run_lambda

        return submit_run_lambda(
            spec,
            seed,
            log=log,
            on_handle=on_handle,
            attempt=attempt,
            runtime_secrets=runtime_secrets,
            on_last_gpu=on_last_gpu,
        )

    def poll(self, handle: JobHandle, spec, seed: int, *, log: Any = None) -> PollResult:
        """Reattach to a launched Lambda instance and poll it via ``poll_lambda_job``.

        The instance is terminated (best-effort, errors suppressed) once polling returns or
        raises, because recovery has no other teardown path.
        """
        import contextlib

        from flash.providers.lambdalabs import api as lambda_api
        from flash.providers.lambdalabs.jobs import (
            PROVISION_GRACE_S,
            LambdaJobHandle,
            poll_lambda_job,
        )
        from flash.providers.runpod.jobs import make_hf_heartbeat_reader

        hf_repo = spec.train.hf_repo
        prefix = f"{spec.phase}/{spec.run_id}/seed{seed}"
        reader = make_hf_heartbeat_reader(hf_repo, prefix) if hf_repo else None
        lh = LambdaJobHandle.from_dict(handle.to_dict())
        if log is not None:
            print(f"attaching: lambda instance={lh.instance_id}", file=log, flush=True)
        # The wall-cap deadline counts from the instance's LAUNCH, not from this reattach — Lambda
        # has no server-side execution timeout, so resetting it on every recovery would let a
        # control-plane restart extend the billable window unbounded. The poll loop already anchors
        # its deadline check to ``handle.started_ts`` (start = launch), so we pass the FULL
        # launch-relative budget here; pre-subtracting elapsed too would double-count and tear down
        # a still-valid instance the moment a recovered run is past half its window.
        deadline = max(60.0, int(spec.gpu.max_wall_seconds) + PROVISION_GRACE_S)
        try:
            return poll_lambda_job(lh, spec, seed, log=log, heartbeat_reader=reader, deadline_s=deadline)
        finally:
            # Recovery (attach_run) has no submit_run_lambda teardown ``finally``; terminate the
            # reattached instance here so a finished/abandoned recovered seed stops billing
            # immediately instead of idling until the whole run ends.
            with contextlib.suppress(Exception):
                lambda_api.terminate_instances([lh.instance_id])

    @staticmethod
    def _terminate_handle(handle: JobHandle) -> None:
        """Terminate the instance recorded in ``handle``; no-op when it has no ``instance_id``."""
        instance_id = handle.to_dict().get("instance_id")
        if not instance_id:
            return
        # Imported only when there is something to terminate.
        from flash.providers.lambdalabs import api as lambda_api

        lambda_api.terminate_instances([str(instance_id)])

    def cancel(self, handle: JobHandle) -> None:
        """Stop the job by terminating its instance.

        Lambda has no separate "cancel job" vs "destroy resource": terminating the instance both
        stops the job and tears down the (only) billable resource.
        """
        self._terminate_handle(handle)

    def destroy(self, handle: JobHandle) -> None:
        """Release the handle's billable resource (its instance); same teardown as ``cancel``."""
        self._terminate_handle(handle)

    def gc(self, spec) -> None:
        """Terminate every instance belonging to ``spec.run_id``."""
        from flash.providers.lambdalabs.jobs import terminate_run_instances

        terminate_run_instances(spec.run_id)

    def sweep_orphans(
        self, active_labels: set[str] | Callable[[], set[str]] | None = None
    ) -> list[str]:
        """Lambda crash-recovery sweep (called via the provider object at startup).

        Lambda instance ids are opaque hex STRINGS (the ``base.Provider`` protocol widens the return
        to ``list[int | str]`` to cover both substrates); the orchestrator only logs/counts them."""
        from flash.providers.lambdalabs.jobs import sweep_orphans

        return sweep_orphans(active_labels=active_labels)


PROVIDER: Provider = LambdaProvider()
@@ -0,0 +1,261 @@
1
+ """Thin Lambda Cloud REST client (no SDK state): instance-types + instance lifecycle.
2
+
3
+ Mirrors ``providers/runpod/api.py`` / the historical Vast client: stdlib urllib only (via the
4
+ shared ``RestClient``), hardened retries, and nothing persisted locally — a fresh process can
5
+ list/terminate any instance using only the persisted ids + ``LAMBDA_API_KEY``.
6
+
7
+ Two Lambda-specific quirks the rest of the provider relies on:
8
+
9
+ * **Cloudflare WAF.** Lambda's API sits behind Cloudflare, which 403s the stdlib default
10
+ ``Python-urllib/<v>`` User-Agent. The client therefore sends a real UA (``extra_headers``);
11
+ without it EVERY call fails 403 (verified live).
12
+ * **Non-idempotent launch.** ``POST /instance-operations/launch`` provisions a NEW (billed)
13
+ instance every time it succeeds, so it is NEVER retried — a blind retry on a timeout where
14
+ Lambda actually accepted the first request would double-provision. Idempotent calls
15
+ (instance-types, list, detail, terminate) keep their retries.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import time
21
+ from typing import Any
22
+
23
+ from flash._logging import get_logger
24
+ from flash.providers._http import RestClient, is_not_found
25
+
26
+ logger = get_logger(__name__)
27
+
28
LAMBDA_BASE = "https://cloud.lambdalabs.com/api/v1"
# Lambda's Cloudflare edge answers 403 to the stdlib ``Python-urllib`` User-Agent (verified live),
# so every request carries this real one instead.
_USER_AGENT = "flash-lambda/1.0 (+https://freesolo.co)"


class LambdaApiError(RuntimeError):
    """Lambda Cloud API failure.

    This is the shared RestClient's ``error_cls`` and is also raised here for unexpected
    response shapes.
    """
35
+
36
+
37
# The single module-level REST client every call below goes through. It authenticates from
# LAMBDA_API_KEY and always sends the Cloudflare-safe User-Agent.
_CLIENT = RestClient(
    base_url=LAMBDA_BASE,
    env_var="LAMBDA_API_KEY",
    missing_key_message="LAMBDA_API_KEY not configured on the control-plane host",
    error_cls=LambdaApiError,
    extra_headers={"User-Agent": _USER_AGENT},
)
44
+
45
+
46
def request_with_retries(
    path: str,
    method: str = "GET",
    body: dict | None = None,
    retries: int = 4,
    base_delay: float = 2.0,
) -> Any:
    """Issue a Lambda API call through the shared client.

    This is a thin pass-through to ``RestClient.request_with_retries``, which retries transient
    network/5xx blips with jittered backoff. Non-idempotent calls (launch) pass ``retries=0``.
    """
    client = _CLIENT
    return client.request_with_retries(
        path, method=method, body=body, retries=retries, base_delay=base_delay,
    )
57
+
58
+
59
def _data(out: Any) -> Any:
    """Strip Lambda's ``{"data": ...}`` response envelope; any other value passes through as-is."""
    is_envelope = isinstance(out, dict) and "data" in out
    return out["data"] if is_envelope else out
64
+
65
+
66
+ # ---------------------------------------------------------------------------
67
+ # Instance types + capacity (cached: pricing, the allocator, and the launcher all read this)
68
+ # ---------------------------------------------------------------------------
69
_TYPES_TTL_S = 45.0
# ``ts`` holds a ``time.monotonic()`` reading (not wall-clock time); ``data`` is the last payload.
_types_cache: dict[str, Any] = {"ts": 0.0, "data": None}


def list_instance_types(force: bool = False) -> dict[str, dict]:
    """Map of ``instance_type_name -> {instance_type, regions_with_capacity_available}``.

    Cached for ``_TYPES_TTL_S`` so pricing, allocation and the launch path share one fetch within
    an allocation pass. ``force`` bypasses the cache.

    Raises:
        LambdaApiError: on a hard failure or a non-dict payload. Callers that must degrade
            gracefully (pricing) catch it.
    """
    # Monotonic clock: a backwards wall-clock step (e.g. NTP) would make ``now - ts`` negative
    # and keep serving stale capacity until the clock caught back up.
    now = time.monotonic()
    cached = _types_cache["data"]
    if not force and cached is not None and now - _types_cache["ts"] < _TYPES_TTL_S:
        return cached
    out = _data(request_with_retries("/instance-types"))
    if not isinstance(out, dict):
        raise LambdaApiError(f"unexpected /instance-types response: {out!r}")
    _types_cache.update(ts=now, data=out)
    return out
88
+
89
+
90
def regions_with_capacity(instance_type: str, force: bool = False) -> list[str]:
    """Region names that currently have capacity for ``instance_type``.

    This is the cheapest source of truth for whether a launch can succeed at all. An unknown
    type, or a missing/null capacity list, yields ``[]``.
    """
    info = list_instance_types(force=force).get(instance_type) or {}
    # ``or []``: a present-but-null capacity list must not crash iteration.
    regions = info.get("regions_with_capacity_available") or []
    return [r["name"] for r in regions if r.get("name")]
99
+
100
+
101
def all_regions(force: bool = False) -> list[str]:
    """Every Lambda region with at least one capacity-available instance type.

    This is the UNION of the ``regions_with_capacity_available`` lists across all instance
    types. The API has no standalone region list, so this is the only way to enumerate
    reachable regions. The eager weight-cache provision step uses it to create the
    ``flash-weights`` filesystem in those regions.

    The result is therefore capacity-DEPENDENT: a region that currently advertises ZERO capacity
    for every instance type won't appear (Lambda only surfaces regions through the per-type
    capacity list). The launch-time ``ensure_filesystem`` backstop covers any such region the
    moment a run lands there. The list is sorted for a stable provision order.
    """
    regions = {
        r["name"]
        for info in list_instance_types(force=force).values()
        # ``or []``: tolerate a present-but-null capacity list instead of crashing on iteration.
        for r in ((info or {}).get("regions_with_capacity_available") or [])
        if r.get("name")
    }
    return sorted(regions)
118
+
119
+
120
def instance_type_price_usd_hr(instance_type: str) -> float | None:
    """Live $/hr for a Lambda instance type (``price_cents_per_hour`` / 100).

    Returns None when the type or its price is missing, or when the price is zero.
    """
    entry = list_instance_types().get(instance_type) or {}
    details = entry.get("instance_type") or {}
    cents = details.get("price_cents_per_hour")
    if not cents:
        return None
    return float(cents) / 100.0
125
+
126
+
127
+ # ---------------------------------------------------------------------------
128
+ # SSH keys (launch requires exactly one; the box is bootstrapped via user_data, not SSH)
129
+ # ---------------------------------------------------------------------------
130
def list_ssh_keys() -> list[dict]:
    """SSH keys registered on the account (``[]`` when the payload is not a list)."""
    keys = _data(request_with_retries("/ssh-keys"))
    if isinstance(keys, list):
        return keys
    return []
133
+
134
+
135
+ # ---------------------------------------------------------------------------
136
+ # Instances
137
+ # ---------------------------------------------------------------------------
138
def launch_instance(
    *,
    region_name: str,
    instance_type_name: str,
    ssh_key_names: list[str],
    name: str,
    user_data: str,
    file_system_names: list[str] | None = None,
) -> str:
    """Launch exactly one instance and return its id.

    NON-IDEMPOTENT: the POST is sent with ``retries=0``. A transient failure surfaces to the
    launcher, which walks on to the next region/class rather than risking a double-provision.

    ``file_system_names`` attaches persistent filesystems (the weight cache) AT LAUNCH. Lambda
    can only attach at launch, and each filesystem must already exist in ``region_name``; it is
    auto-mounted on the host at ``/lambda/nfs/<name>``.

    Raises:
        LambdaApiError: when Lambda rejects the launch (no capacity, etc.) or returns no id.
    """
    request_body = {
        "region_name": region_name,
        "instance_type_name": instance_type_name,
        "ssh_key_names": list(ssh_key_names),
        # ``name`` is bounded <=60 upstream by ``_instance.run_label_prefix`` and deliberately NOT
        # truncated here, so the stored name always equals the prefix ``sweep_orphans`` matches.
        "name": name,
        "quantity": 1,
        "user_data": user_data,
    }
    if file_system_names:
        request_body["file_system_names"] = list(file_system_names)

    response = request_with_retries(
        "/instance-operations/launch", method="POST", body=request_body, retries=0
    )
    out = _data(response)
    instance_ids = out.get("instance_ids") if isinstance(out, dict) else None
    if not instance_ids:
        raise LambdaApiError(f"launch({instance_type_name}@{region_name}) returned no instance id: {out}")
    return str(instance_ids[0])
173
+
174
+
175
+ # ---------------------------------------------------------------------------
176
+ # Persistent filesystems (the weight cache). Region-scoped, NFS, multi-attach; auto-mounted on the
177
+ # host at /lambda/nfs/<name>. Created via the Cloud API, attached at launch via file_system_names.
178
+ # ---------------------------------------------------------------------------
179
+ # NB: Lambda's filesystem API paths are ASYMMETRIC and this is intentional/correct, not a typo —
180
+ # verified LIVE against cloud.lambdalabs.com/api/v1 with a real create->ensure->delete probe (the FS
181
+ # was created, reused idempotently, then confirmed deleted, no stranded resources). LIST is the
182
+ # hyphenated GET /file-systems; CREATE/DELETE are the un-hyphenated POST /filesystems and
183
+ # DELETE /filesystems/{id}. Lambda's own surface differs from its other (hyphenated) resources here,
184
+ # so DO NOT "unify" these to /file-systems — that 404s the working create/delete endpoints and
185
+ # silently disables the cache. (Reviewers keep flagging the inconsistency; it's the real API.)
186
def list_filesystems() -> list[dict]:
    """All filesystems on the account: ``[{id, name, mount_point, region:{name}, is_in_use}, ...]``.

    Returns ``[]`` if the payload is not a list.
    """
    # LIST is the HYPHENATED ``/file-systems`` path (verified live); create/delete differ.
    filesystems = _data(request_with_retries("/file-systems"))
    if isinstance(filesystems, list):
        return filesystems
    return []
190
+
191
+
192
def create_filesystem(name: str, region_name: str) -> dict:
    """Create filesystem ``name`` in ``region_name`` and return its object (incl. ``mount_point``).

    Returns ``{}`` if the payload is not a dict.
    """
    payload = {"name": name, "region": region_name}
    # CREATE uses the un-hyphenated ``/filesystems`` path (NOT ``/file-systems``) — verified live.
    created = _data(request_with_retries("/filesystems", method="POST", body=payload, retries=2))
    if isinstance(created, dict):
        return created
    return {}
201
+
202
+
203
def delete_filesystem(filesystem_id: str) -> bool:
    """Best-effort delete of a filesystem by id.

    Returns True when the request did not raise; otherwise logs a warning and returns False.
    """
    try:
        # DELETE uses the un-hyphenated ``/filesystems/{id}`` path (NOT ``/file-systems/{id}``).
        request_with_retries(f"/filesystems/{filesystem_id}", method="DELETE", retries=2)
    except Exception as err:
        logger.warning("lambda delete_filesystem(%s) failed: %s", filesystem_id, err)
        return False
    else:
        return True
212
+
213
+
214
def ensure_filesystem(name: str, region_name: str) -> str:
    """Create the cache filesystem ``name`` in ``region_name`` if absent; return its mount point.

    Idempotent: an existing same-name filesystem in that region is reused. When the API omits
    ``mount_point``, ``/lambda/nfs/<name>`` is returned instead.
    """
    default_mount = f"/lambda/nfs/{name}"
    existing = next(
        (
            fs
            for fs in list_filesystems()
            if fs.get("name") == name and (fs.get("region") or {}).get("name") == region_name
        ),
        None,
    )
    chosen = existing if existing is not None else create_filesystem(name, region_name)
    return chosen.get("mount_point") or default_mount
222
+
223
+
224
def get_instance(instance_id: str) -> dict | None:
    """Return the instance detail dict, or None once the instance no longer exists.

    A not-found API error maps to None; any other ``LambdaApiError`` propagates. A non-dict
    payload also yields None.
    """
    try:
        response = request_with_retries(f"/instances/{instance_id}")
    except LambdaApiError as err:
        if not is_not_found(err):
            raise
        return None
    detail = _data(response)
    if isinstance(detail, dict):
        return detail
    return None
234
+
235
+
236
def list_instances() -> list[dict]:
    """All instances on the account (``[]`` when the payload is not a list)."""
    instances = _data(request_with_retries("/instances"))
    if isinstance(instances, list):
        return instances
    return []
239
+
240
+
241
def terminate_instances(instance_ids: list[str]) -> list[str]:
    """Terminate (and stop billing for) instances; return the ids that ACTUALLY terminated.

    PER-ID ISOLATED (one POST per id), so a single stale, invalid or race-deleted id can't abort
    teardown of the rest. This is the crash-backstop path: ``sweep_orphans`` and
    ``terminate_run_instances`` pass many ids at once, where stale ids are common. Lambda's
    terminate endpoint validates the whole request, so a batch of N ids with one bad id returns
    a 4xx and terminates NONE; isolating per id removes that money-leak.

    Ids are stringified, falsy ids are dropped, and duplicates are collapsed (order preserved).
    Best-effort: failures are logged and the function never raises.
    """
    deleted: list[str] = []
    # De-duplicate so an id passed twice is POSTed once and listed at most once in the result.
    unique_ids = dict.fromkeys(str(i) for i in instance_ids if i)
    for iid in unique_ids:
        try:
            request_with_retries(
                "/instance-operations/terminate",
                method="POST",
                body={"instance_ids": [iid]},
                retries=2,
            )
        except Exception as exc:
            logger.warning("lambda terminate(%s) failed: %s", iid, exc)
        else:
            deleted.append(iid)
    return deleted
@@ -0,0 +1,18 @@
1
+ """Lambda Cloud credential handling (operator-side), mirroring the RunPod auth module.
2
+
3
+ The Lambda REST client authenticates via the ``LAMBDA_API_KEY`` environment variable, set by
4
+ the **operator** on the control-plane host. Env-only by design, exactly like ``RUNPOD_API_KEY``:
5
+ it is never written to config files. (Unlike Vast/RunPod, Lambda has no instance-scoped key, so
6
+ the operator key is also NOT shipped to the box — teardown is control-plane-side only.)
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from .._auth import load_provider_key
12
+
13
_ENV_VAR = "LAMBDA_API_KEY"


def load_api_key() -> str | None:
    """Return the operator's Lambda API key from the ``LAMBDA_API_KEY`` environment variable.

    Lookup is delegated to the shared ``load_provider_key``; ``None`` presumably means unset.
    """
    key = load_provider_key(_ENV_VAR)
    return key
@@ -0,0 +1,29 @@
1
+ """Lambda Cloud's GPU classes (its rows of the shared GPU table).
2
+
3
+ The class table is provider-agnostic and lives in ``providers/base.py``. This module carves out
4
+ Lambda's rows (``gpu_classes()`` == every class with a ``lambda_name``) and owns the
5
+ friendly-name -> Lambda instance-type translation.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from flash.providers.base import GpuClass, UnsupportedGpuError, get_gpu_info, providers_for
11
+
12
__all__ = [
    "gpu_classes",
    "instance_type_for",
]
13
+
14
+
15
def gpu_classes() -> list[GpuClass]:
    """Return the GPU classes Lambda can provision: every ``GPU_INFO`` entry with a ``lambda_name``."""
    from flash.providers.base import GPU_INFO

    return list(filter(lambda cls_info: cls_info.lambda_name, GPU_INFO.values()))
20
+
21
+
22
def instance_type_for(name: str) -> str:
    """Translate a friendly GPU class name to its Lambda instance type (e.g. ``'gpu_1x_a10'``).

    Raises:
        UnsupportedGpuError: when the class has no Lambda instance type. The message lists the
            providers that do offer it.
    """
    info = get_gpu_info(name)
    lambda_type = info.lambda_name
    if lambda_type:
        return lambda_type
    raise UnsupportedGpuError(
        f"{info.name} is not available on Lambda (providers: {', '.join(providers_for(name))})"
    )