freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,118 @@
1
+ """Pure, monkeypatch-free building blocks for the Lambda Cloud run lifecycle.
2
+
3
+ The Lambda-specific leaf of ``flash.providers.lambdalabs.jobs``: the normalized dataclasses
4
+ (``LambdaInstance``, ``LambdaJobHandle``) and the image accessor. The cross-provider pieces — the
5
+ run-derived sweep label, the bootstrap payload, and the cloud-init ``user_data`` — are shared with
6
+ Hyperstack in ``flash.providers._instance`` and re-exported here so the import path is unchanged.
7
+
8
+ This module MUST NOT import the ``jobs`` package ``__init__`` (it is imported BY it).
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from dataclasses import dataclass
14
+
15
+ # Shared instance-provider helpers (single source of truth; Lambda binds arm="lambda" + its image).
16
+ from flash.providers._instance import (
17
+ build_payload as _shared_build_payload,
18
+ )
19
+ from flash.providers._instance import (
20
+ build_user_data as _shared_build_user_data,
21
+ )
22
+ from flash.providers._instance import (
23
+ instance_label,
24
+ run_label_prefix,
25
+ )
26
+
27
+ __all__ = [
28
+ "LambdaInstance",
29
+ "LambdaJobHandle",
30
+ "build_payload",
31
+ "build_user_data",
32
+ "instance_label",
33
+ "lambda_image",
34
+ "run_label_prefix",
35
+ ]
36
+
37
+
38
@dataclass(frozen=True)
class LambdaInstance:
    """One launchable Lambda offering for a managed GPU class.

    Bundles the (region, instance type, $/hr) triple that can be launched -- the Lambda
    counterpart of a vetted Vast offer. Instances are immutable (``frozen=True``).
    """

    # Canonical GPU class name (a ``GPU_INFO`` key).
    gpu: str
    # Lambda's own instance-type name, e.g. "gpu_1x_a10".
    instance_type: str
    region: str
    vram_gb: int
    price_usd_hr: float
48
+
49
+
50
@dataclass
class LambdaJobHandle:
    """Provider handle for one Lambda instance attempt.

    Persisted in ``RunStatus.remote`` so that any process can reattach to or cancel the run
    (compare ``base.JobHandle``).
    """

    instance_id: str
    instance_type: str
    region: str
    # Sweep-matchable instance name, derived from the run (see ``instance_label``).
    name: str
    gpu: str
    hourly_usd: float
    attempt: int
    started_ts: float

    def to_dict(self) -> dict:
        """Serialize to a plain dict, tagged with ``"provider": "lambda"`` as the first key."""
        payload: dict = {"provider": "lambda"}
        payload.update(
            instance_id=self.instance_id,
            instance_type=self.instance_type,
            region=self.region,
            name=self.name,
            gpu=self.gpu,
            hourly_usd=self.hourly_usd,
            attempt=self.attempt,
            started_ts=self.started_ts,
        )
        return payload

    @classmethod
    def from_dict(cls, d: dict) -> LambdaJobHandle:
        """Rebuild a handle from a ``to_dict``-style mapping.

        ``instance_id`` is mandatory (``KeyError`` when absent). Every other key is optional:
        a missing or falsy value becomes ``""`` for text fields and ``0`` for numeric ones.
        """

        def text(key: str) -> str:
            return str(d.get(key) or "")

        def number(key: str) -> float:
            return float(d.get(key) or 0)

        # Keyword arguments are evaluated in order, so the required key is read first.
        return cls(
            instance_id=str(d["instance_id"]),
            instance_type=text("instance_type"),
            region=text("region"),
            name=text("name"),
            gpu=text("gpu"),
            hourly_usd=number("hourly_usd"),
            attempt=int(d.get("attempt") or 0),
            started_ts=number("started_ts"),
        )
88
+
89
+
90
def lambda_image(gpu: str | None = None) -> str:
    """Return the Docker image the cloud-init runs on the Lambda host.

    By default this is the prebuilt, PUBLIC ``WORKER_IMAGE`` (the byte-identical training stack
    RunPod bakes). ``FLASH_WORKER_IMAGE`` overrides it. When the operator opts into per-SM warmed
    images (``FLASH_WORKER_IMAGE_PER_SM`` / ``FLASH_WORKER_IMAGE_TEMPLATE``), the GPU class picks
    the matching ``-smXX`` tag so the worker's baked kernel cache matches the rented GPU's arch
    (the same selector RunPod uses).
    """
    from flash.providers.runpod.train import WORKER_IMAGE, worker_image_for_gpu

    # allow_default=True -> always a concrete image to docker-pull (override / per-sm tag / base).
    selected = worker_image_for_gpu(gpu, allow_default=True)
    if selected:
        return selected
    # Final safety net if the selector still comes back empty.
    return WORKER_IMAGE
100
+
101
+
102
def build_payload(
    spec,
    seed: int,
    attempt: int,
    runtime_secrets: dict | None = None,
    cache_host_mount: str | None = None,
    mode: str | None = None,
    models: list | None = None,
) -> dict:
    """Build the Lambda bootstrap payload via the shared builder, bound to ``arm="lambda"``.

    ``cache_host_mount`` is the host NFS mount of the attached weight-cache filesystem
    (/lambda/nfs/<name>); it points HF_HOME at that mount. ``mode='preload'`` together with
    ``models`` makes this a download-only warm payload (no worker).
    """
    lambda_options = {
        "arm": "lambda",
        "runtime_secrets": runtime_secrets,
        "cache_host_mount": cache_host_mount,
        "mode": mode,
        "models": models,
    }
    return _shared_build_payload(spec, seed, attempt, **lambda_options)
114
+
115
+
116
def build_user_data(payload: dict, *, gpu: str | None = None) -> str:
    """Render the Lambda cloud-init ``user_data`` through the shared builder.

    The image baked into the script is whatever ``lambda_image(gpu)`` resolves to.
    """
    image = lambda_image(gpu)
    return _shared_build_user_data(payload, image=image)
@@ -0,0 +1,27 @@
1
+ """Fail-fast credential checks for the Lambda Cloud substrate (operator-side).
2
+
3
+ Mirrors ``providers/runpod/preflight.py``. Lambda is OPT-IN (the allocator only reaches for it
4
+ when ``LAMBDA_API_KEY`` is set), so the only Lambda-specific requirement is ``LAMBDA_API_KEY``;
5
+ HF_TOKEN is a shared run requirement checked once centrally by the cross-provider preflight
6
+ (``flash/providers/preflight.py``), which calls each provider-specific check with
7
+ ``require_hf=False`` so HF is never double-reported.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from flash.providers.lambdalabs.auth import load_api_key
13
+
14
+
15
def missing_credentials(require_hf: bool = True) -> list[str]:
    """List the Lambda operator config that is missing (an empty list means ready).

    ``require_hf`` exists only so the signature matches the RunPod check; it is deliberately
    unused. Lambda has no provider-owned HF requirement -- the shared HF_TOKEN is checked once,
    centrally, in ``providers.preflight``.
    """
    if load_api_key():
        return []
    return [" - LAMBDA_API_KEY: the operator's Lambda Cloud API key (for the lambda provider)"]
@@ -0,0 +1,51 @@
1
+ """Lambda Cloud $/hr: live ``/instance-types`` rate per class, static fallback.
2
+
3
+ Lambda prices a fixed instance-type catalog (unlike Vast's live market), so a class's rate is just
4
+ its instance type's ``price_cents_per_hour``. This module gives the provider interface a uniform
5
+ ``hourly_rate(gpu)``. Offline-safe: without ``LAMBDA_API_KEY`` (or on any fetch failure) it falls
6
+ back to the static Lambda snapshot below.
7
+
8
+ NB: the static fallback is a Lambda-specific map, NOT ``GpuClass.hourly_usd`` — that field is the
9
+ RunPod secure-cloud snapshot, which differs from Lambda's list price for the shared classes (e.g.
10
+ RTX A6000 is $0.49 on RunPod but $1.09 on Lambda).
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from flash._logging import get_logger
16
+
17
+ logger = get_logger(__name__)
18
+
19
+ # Lambda list prices (snapshot 2026-06-25, from /instance-types). Live rates override these.
20
+ _STATIC_RATES: dict[str, float] = {
21
+ "A10": 1.29,
22
+ "RTX A6000": 1.09,
23
+ "A100 SXM 40GB": 1.99,
24
+ "H100": 3.29,
25
+ }
26
+
27
+
28
def _static_rate(name: str) -> float:
    """Return the offline $/hr for canonical GPU class ``name``.

    The Lambda snapshot in ``_STATIC_RATES`` wins. A class with no (or a zero) snapshot entry
    falls back to its nominal ``GPU_INFO`` rate so ``hourly_rate`` still has an answer.
    """
    from flash.providers.base import GPU_INFO

    snapshot_rate = _STATIC_RATES.get(name)
    if snapshot_rate:
        return snapshot_rate
    return GPU_INFO[name].hourly_usd
34
+
35
+
36
def hourly_rate(gpu_name: str) -> float:
    """Return $/hr for one friendly GPU name on Lambda.

    Uses the live ``/instance-types`` price when the class maps to a Lambda instance type and
    the lookup yields a non-zero value; otherwise falls back to the static rate.
    """
    from flash.providers.base import canonical_gpu, get_gpu_info

    name = canonical_gpu(gpu_name)
    lambda_type = get_gpu_info(name).lambda_name
    if not lambda_type:
        # No Lambda instance type for this class: nothing to look up live.
        return _static_rate(name)
    try:
        from flash.providers.lambdalabs.api import instance_type_price_usd_hr

        live = instance_type_price_usd_hr(lambda_type)
    except Exception as exc:
        # Best-effort by design: any failure (no key, network, bad payload) means static pricing.
        logger.debug("live lambda pricing unavailable for %s (%s); using static", name, exc)
    else:
        if live:
            return live
    return _static_rate(name)
@@ -0,0 +1,27 @@
1
+ """Lambda Cloud train submission: build the instance payload + submit a run.
2
+
3
+ The worker stack/env is substrate-neutral, so the per-run worker env and dependency resolution are
4
+ shared with RunPod (the ``providers/runpod/train`` package); this module owns the Lambda-specific submission
5
+ entrypoint and the instance payload shape. Provisioning, polling, and teardown live in
6
+ ``providers/lambdalabs/jobs``.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ # Shared, substrate-neutral worker stack (single source of truth on RunPod's module).
12
+ from flash.providers.lambdalabs.jobs import build_payload, submit_run_lambda
13
+ from flash.providers.runpod.train import (
14
+ WORKER_DEPS,
15
+ WORKER_SYSTEM_DEPS,
16
+ build_worker_env,
17
+ resolve_worker_deps,
18
+ )
19
+
20
+ __all__ = [
21
+ "WORKER_DEPS",
22
+ "WORKER_SYSTEM_DEPS",
23
+ "build_payload",
24
+ "build_worker_env",
25
+ "resolve_worker_deps",
26
+ "submit_run_lambda",
27
+ ]
@@ -0,0 +1,55 @@
1
+ """RunPod startup preflight.
2
+
3
+ ``check_run_preflight`` aggregates RunPod's missing-config problems plus the shared Hugging Face
4
+ dataset-repo requirements, so a single startup error lists everything missing.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+
11
+ from flash.providers.runpod.preflight import (
12
+ PreflightError,
13
+ missing_credentials,
14
+ )
15
+
16
+ __all__ = [
17
+ "PreflightError",
18
+ "check_run_preflight",
19
+ ]
20
+
21
+
22
+ def _missing_hf_credentials() -> list[str]:
23
+ """Shared run infra every substrate needs."""
24
+ problems: list[str] = []
25
+ if not os.environ.get("GITHUB_TOKEN"):
26
+ problems.append(" - GITHUB_TOKEN: server token with access to managed Freesolo environments")
27
+ if not os.environ.get("HF_TOKEN"):
28
+ problems.append(
29
+ " - HF_TOKEN: a token with write access to each run's "
30
+ "`[train] hf_repo`, e.g. `export HF_TOKEN=hf_...`"
31
+ )
32
+ return problems
33
+
34
+
35
+ def _preflight_provider_names() -> set[str]:
36
+ """The providers whose operator config this control plane must satisfy."""
37
+ return {"runpod"}
38
+
39
+
40
def check_run_preflight(require_hf: bool = True) -> None:
    """Validate the operator config and raise ``PreflightError`` listing everything missing.

    RunPod's own check is always called with ``require_hf=False``: the HF write token is shared
    run infra and is checked once here so it isn't double-reported. The HF dataset repo itself
    is per-run (``[train] hf_repo``). With ``require_hf=False`` the shared-token check is skipped.
    """
    problems: list[str] = []
    if "runpod" in _preflight_provider_names():
        problems.extend(missing_credentials(require_hf=False))
    if require_hf:
        problems.extend(_missing_hf_credentials())
    if not problems:
        return
    header = "the Flash control plane is missing required operator configuration:\n"
    footer = "\n\nSet these on the control-plane host."
    raise PreflightError(header + "\n".join(problems) + footer)
@@ -0,0 +1,80 @@
1
+ """Realized provider cost (COGS) for a finished run -- the cost side of estimator accuracy.
2
+
3
+ RunPod's billing API gives the dollars it ACTUALLY charged, which the reconciliation job
4
+ compares against the run's charged pre-flight estimate. This module owns the ``RealizedCost``
5
+ shape and dispatches to the RunPod shaper by the run's persisted handle
6
+ (``RunStatus.remote['provider']``). The HTTP calls live in the provider's ``api.py``; the pure
7
+ shaping lives in its ``cost.py`` so it stays offline-testable.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from dataclasses import dataclass, field
13
+
14
+
15
@dataclass
class RealizedCost:
    """Realized provider cost (COGS) attributed to one finished run."""

    provider: str
    realized_usd: float
    # Per-resource split of the total: {"gpu": .., "disk": .., "bwd": ..}
    by_resource: dict[str, float] = field(default_factory=dict)
    wall_seconds: float | None = None
    # Audit trail: resource ids / raw refs.
    source: dict = field(default_factory=dict)
24
+
25
+
26
def realized_cost_for_remote(
    remote: dict | None, *, start: float, end: float, run_end: float | None = None
) -> RealizedCost | None:
    """Look up a run's realized cost from its persisted provider handle.

    ``remote`` is ``RunStatus.remote`` (the last/successful attempt's handle dict). The result
    is None when there is no handle, no resource id, or the provider is unknown; the run then
    stays unreconciled (and is retried). A handle without a ``provider`` key is treated as
    ``"runpod"``.

    The two cost sources use different time bounds:
      * ``start``/``end`` delimit the RunPod BILLING-API query window. Callers pad ``end`` past
        the run's terminal time so the settled invoice row is in range (see reconcile
        ``_SETTLE_SECONDS``).
      * ``run_end`` is the run's ACTUAL terminal time (~teardown). Instance providers
        (Lambda/Hyperstack) have no billing endpoint: an instance bills at a flat $/hr from
        launch to teardown, so their COGS is wall x rate over ``started_ts -> run_end``. Using
        the settle-padded ``end`` there would over-bill by the padding (up to an hour).
        ``run_end`` falls back to ``end`` for callers that don't distinguish the two.
    """
    if not remote:
        return None
    provider = remote.get("provider") or "runpod"
    if provider in ("lambda", "hyperstack"):
        terminal = end if run_end is None else run_end
        return _instance_realized_cost(remote, start=start, end=terminal)
    if provider != "runpod":
        return None
    from flash.providers.runpod.cost import realized_cost as runpod_realized

    return runpod_realized(remote.get("endpoint_id"), start=start, end=end)
54
+
55
+
56
+ def _instance_realized_cost(
57
+ remote: dict, *, start: float, end: float
58
+ ) -> RealizedCost | None:
59
+ """Realized COGS for an instance-billed provider: wall-clock x the instance's flat $/hr.
60
+
61
+ The instance billed from its launch (``started_ts`` on the handle) until teardown (``end``, the
62
+ run's true terminal time — NOT a settle-padded billing-query bound). Unattributable -> None (no
63
+ rate persisted) so the run stays unreconciled rather than booking $0.
64
+ """
65
+ rate = remote.get("hourly_usd")
66
+ rid = remote.get("instance_id") or remote.get("vm_id")
67
+ # Honor the module contract: no rate OR no auditable resource id -> unattributable (None), so the
68
+ # run stays unreconciled rather than booking instance cost we can't tie to a resource.
69
+ if not rate or not rid:
70
+ return None
71
+ launch = remote.get("started_ts") or start
72
+ wall = max(0.0, float(end) - float(launch))
73
+ usd = round(wall / 3600.0 * float(rate), 6)
74
+ return RealizedCost(
75
+ provider=str(remote.get("provider")),
76
+ realized_usd=usd,
77
+ by_resource={"gpu": usd},
78
+ wall_seconds=wall,
79
+ source={"resource_id": str(rid), "hourly_usd": float(rate), "started_ts": float(launch)},
80
+ )
@@ -0,0 +1,130 @@
1
+ """RunPod Flash provider: managed, serverless GPUs (no Docker) for Flash.
2
+
3
+ Fine-tuning runs on a dedicated RunPod GPU provisioned by Flash. A decorated Python
4
+ handler (``train._train_body``) executes ``flash.engine.worker`` on the GPU; Flash
5
+ handles provisioning, dependency install, execution, and scale-to-zero teardown.
6
+ Serving exposes an OpenAI-compatible endpoint for a trained LoRA adapter.
7
+
8
+ ``PROVIDER`` is the ``base.Provider`` implementation the registry hands out; the
9
+ orchestrator/allocator only talk to its interface, never these modules directly.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import Any
15
+
16
+ from flash.providers.base import GpuClass, JobHandle, PollResult, Provider
17
+
18
+
19
class RunpodProvider:
    """``base.Provider`` implementation for the RunPod Flash substrate.

    Every method imports its RunPod helper module lazily, inside the method body.
    """

    name = "runpod"

    def is_configured(self) -> bool:
        """Report RunPod as always available for allocation.

        RunPod is the ALWAYS-ON default substrate. Pricing is static, and a missing
        RUNPOD_API_KEY surfaces at provision time via ensure_auth / the preflight rather than as
        a silent empty candidate list. This matches the historical ``available_providers()``,
        which listed runpod unconditionally.
        """
        return True

    def preflight(self, require_hf: bool = True) -> list[str]:
        """Return RunPod's missing-config problem lines (delegates to ``missing_credentials``)."""
        from flash.providers.runpod.preflight import missing_credentials as _missing

        return _missing(require_hf=require_hf)

    def gpu_classes(self) -> list[GpuClass]:
        """Return the GPU classes RunPod offers (delegates to ``runpod.gpus``)."""
        from flash.providers.runpod.gpus import gpu_classes as _gpu_classes

        return _gpu_classes()

    def hourly_rate(self, gpu: str) -> float:
        """Return the $/hr for ``gpu`` (delegates to ``runpod.pricing``)."""
        from flash.providers.runpod.pricing import hourly_rate as _hourly_rate

        return _hourly_rate(gpu)

    def submit_run(
        self,
        spec,
        seed: int,
        *,
        log: Any = None,
        on_handle: Any = None,
        attempt: int = 0,
        runtime_secrets: dict[str, str] | None = None,
        on_last_gpu: bool = False,
    ) -> PollResult:
        """Submit one run attempt through ``runpod.jobs.submit_run``.

        ``on_last_gpu`` stretches the no-capacity grace when no further GPU attempt will follow
        this one -- either the candidate list or the retry budget is exhausted (see
        ``jobs.stall_kwargs``); waiting longer can't cost a fallback when there is none.
        ``runtime_secrets`` is forwarded only when non-empty.
        """
        from flash.providers.runpod.jobs import submit_run as _submit_run

        optional = {"runtime_secrets": runtime_secrets} if runtime_secrets else {}
        return _submit_run(
            spec,
            seed,
            log=log,
            on_handle=on_handle,
            attempt=attempt,
            on_last_gpu=on_last_gpu,
            **optional,
        )

    def poll(self, handle: JobHandle, spec, seed: int, *, log: Any = None) -> PollResult:
        """Reattach to an already-submitted job and poll it.

        The HF heartbeat and failure-detail readers are built only when the spec has an
        ``hf_repo``; otherwise both are None.
        """
        from flash.providers.runpod.jobs import JobHandle as RunpodJobHandle
        from flash.providers.runpod.jobs import (
            make_hf_failure_detail_reader,
            make_hf_heartbeat_reader,
            poll_job,
            stall_kwargs,
        )

        hf_repo = spec.train.hf_repo
        prefix = f"{spec.phase}/{spec.run_id}/seed{seed}"
        reader = None
        failure_reader = None
        if hf_repo:
            reader = make_hf_heartbeat_reader(hf_repo, prefix)
            failure_reader = make_hf_failure_detail_reader(hf_repo, prefix, spec.phase)

        rh = RunpodJobHandle.from_dict(handle.to_dict())
        if log is not None:
            print(f"attaching: job={rh.job_id} endpoint={rh.endpoint_name}", file=log, flush=True)

        # Judge a reattached run exactly like the submit path did. The original submit's
        # ``on_last_gpu`` is persisted in the handle (by the runner's on_handle), so its
        # no-capacity grace is reproduced here instead of the shorter non-last window. A handle
        # without the key (pre-persist / non-runpod) gets False, the historical default.
        on_last_gpu = bool(handle.to_dict().get("on_last_gpu", False))
        return poll_job(
            rh,
            log=log,
            heartbeat_reader=reader,
            failure_detail_reader=failure_reader,
            **stall_kwargs(on_last_gpu=on_last_gpu),
        )

    def cancel(self, handle: JobHandle) -> None:
        """Cancel the queued job; a no-op unless the handle has both endpoint and job ids."""
        from flash.providers.runpod import api as runpod_api

        d = handle.to_dict()
        endpoint_id = d.get("endpoint_id")
        job_id = d.get("job_id")
        if endpoint_id and job_id:
            runpod_api.cancel_job(endpoint_id, job_id)

    def destroy(self, handle: JobHandle) -> None:
        """Delete the run's endpoint; a no-op when the handle carries no endpoint id."""
        from flash.providers.runpod import api as runpod_api

        endpoint_id = handle.to_dict().get("endpoint_id")
        if endpoint_id:
            runpod_api.delete_endpoint(endpoint_id)

    def gc(self, spec) -> None:
        """Terminate the endpoint identified by the spec's GPU type and run id."""
        from flash.providers.runpod.train import terminate_endpoint

        terminate_endpoint(spec.gpu.type, spec.run_id)

    def sweep_orphans(self, active_labels: set[str] | None = None) -> list[int]:
        """Return an empty list; present only to satisfy the ``base.Provider`` protocol.

        RunPod serverless endpoints have no standing per-run billing to reap on crash recovery
        (a failed-before-submit endpoint is GC'd by reconstructed name in recover_runs).
        """
        return []
128
+
129
+
130
+ PROVIDER: Provider = RunpodProvider()
@@ -0,0 +1,186 @@
1
+ """Thin RunPod REST client (no SDK state): endpoints, queue jobs, health.
2
+
3
+ Used by the run supervisor and endpoint GC so that a *fresh process* can
4
+ reattach to / clean up after any run using only the persisted ids + RUNPOD_API_KEY —
5
+ independent of the Flash SDK's local resource registry (which is per-directory,
6
+ whole-dict, last-writer-wins and therefore unreliable across processes).
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import urllib.error
12
+ from typing import Any
13
+
14
+ from flash.providers._http import RestClient
15
+ from flash.providers.runpod import keys as _keys
16
+
17
+ REST_BASE = "https://rest.runpod.io/v1"
18
+ QUEUE_BASE = "https://api.runpod.ai/v2"
19
+
20
+
21
class RunpodApiError(RuntimeError):
    """Error type for RunPod REST/queue calls that fail or return an unusable response."""
23
+
24
+
25
+ # Shared urllib client (full-URL form: callers pass absolute REST/QUEUE urls).
26
+ # Env-only by design: ~/.flash/config.json holds the *Flash* key (client-side),
27
+ # never the RunPod key — the operator sets RUNPOD_API_KEY on the control-plane host.
28
+ #
29
+ # ``RUNPOD_API_KEY`` may be a comma-separated pool of per-account keys: the client tries
30
+ # them active-account-first per call (``keys.ordered_keys``) and fails over to the next
31
+ # account on an auth/quota/not-found error (``keys.is_failover_error``). RunPod endpoints
32
+ # are account-scoped, so a single-account op (status/cancel/delete) resolves no matter
33
+ # which account a failed-over run was provisioned on. A single key => a pool of one.
34
+ _CLIENT = RestClient(
35
+ env_var="RUNPOD_API_KEY",
36
+ error_cls=RunpodApiError,
37
+ keys_provider=_keys.ordered_keys,
38
+ failover_predicate=_keys.is_failover_error,
39
+ )
40
+
41
+
42
def request_with_retries(
    url: str,
    method: str = "GET",
    body: dict | None = None,
    retries: int = 4,
    base_delay: float = 2.0,
) -> Any:
    """Make a REST call hardened against transient network/5xx blips (jittered backoff).

    A thin module-level wrapper: every argument is forwarded unchanged to the shared
    ``_CLIENT``; ``url`` must be absolute.
    """
    options = {
        "method": method,
        "body": body,
        "retries": retries,
        "base_delay": base_delay,
    }
    return _CLIENT.request_with_retries(url, **options)
53
+
54
+
55
+ # ---------------------------------------------------------------------------
56
+ # Endpoints
57
+ # ---------------------------------------------------------------------------
58
def list_endpoints() -> list[dict]:
    """Return the endpoints of EVERY account in the ``RUNPOD_API_KEY`` pool, concatenated.

    ``RUNPOD_API_KEY`` may be a comma-separated pool of per-account keys, and RunPod endpoints
    are account-scoped. A plain ``request_with_retries()`` stops at the first key that succeeds
    and so sees only that account. Idle-sweep and slot-reconcile need the whole fleet, so each
    key is queried on its own (with per-key retries) and the results are aggregated.

    Raises:
        RunpodApiError: when no key is configured, or when any key's response is not a list.
            Callers that read an empty result as "confirmed absent" (teardown, slot-reconcile)
            must never act on an incomplete view; both sweep_idle_endpoints() and the slot
            reconcile already catch and skip on exception.
    """
    pool = _keys.keys()
    if not pool:
        # With no key at all the loop below would return [] without one authenticated call,
        # and callers would read that as "the fleet is empty". Fail loud instead, matching the
        # old single-call request_with_retries() behavior, which raised on a missing key.
        raise RunpodApiError(
            "RUNPOD_API_KEY is not set; refusing to report an empty endpoint fleet"
        )
    endpoints_url = f"{REST_BASE}/endpoints"
    fleet: list[dict] = []
    for key in pool:
        out = _CLIENT.request_with_retries_for_key(key, endpoints_url, retries=2)
        if isinstance(out, list):
            fleet.extend(out)
            continue
        # A 200 whose body isn't a list is NOT an empty account. Skipping it would give a
        # partial fleet that callers trust as complete, so surface it as a per-key failure.
        raise RunpodApiError(
            f"unexpected /endpoints response for a pool key (got {type(out).__name__}, want list)"
        )
    return fleet
90
+
91
+
92
def find_endpoints_by_name(substr: str) -> list[dict]:
    """Return every listed endpoint whose ``name`` contains ``substr``.

    A missing or None name is treated as the empty string.
    """
    matches: list[dict] = []
    for endpoint in list_endpoints():
        endpoint_name = endpoint.get("name") or ""
        if substr in endpoint_name:
            matches.append(endpoint)
    return matches
94
+
95
+
96
def delete_endpoint(endpoint_id: str) -> bool:
    """Delete an endpoint; return True when it is gone afterwards.

    An already-gone endpoint counts as a clean teardown: a genuine 404 means the desired end
    state already holds, so it reports True. Returning False there would make undeploy_adapter
    surface a misleading "may still be running" 502 for something that is provably gone.
    Any other ``RunpodApiError`` reports False.
    """
    url = f"{REST_BASE}/endpoints/{endpoint_id}"
    try:
        request_with_retries(url, method="DELETE", retries=2)
    except RunpodApiError as e:
        return _is_not_found(e)
    return True
106
+
107
+
108
+ def _is_not_found(err: RunpodApiError) -> bool:
109
+ """True only when a RunpodApiError represents a genuine 404 (endpoint already gone).
110
+
111
+ request_with_retries chains the original urllib HTTPError as ``__cause__`` for every
112
+ fast-failed 4xx (``raise ... from e``), so the status code is authoritative when a
113
+ cause is present: a 404 is "already gone", anything else (403/401/5xx) is a real
114
+ failure and must NOT be swallowed — a body that merely *mentions* "does not exist" on a
115
+ 403 is still a 403. We only fall back to a text match when there is no HTTPError cause
116
+ (e.g. the "failed after N attempts" path), and even then only on an unambiguous 404.
117
+ """
118
+ cause = err.__cause__
119
+ if isinstance(cause, urllib.error.HTTPError):
120
+ return cause.code == 404
121
+ return "http 404" in str(err).lower()
122
+
123
+
124
def endpoint_health(endpoint_id: str) -> dict:
    """GET the queue API's ``/health`` resource for ``endpoint_id`` and return the response."""
    health_url = f"{QUEUE_BASE}/{endpoint_id}/health"
    return request_with_retries(health_url)
126
+
127
+
128
+ # ---------------------------------------------------------------------------
129
+ # Queue jobs
130
+ # ---------------------------------------------------------------------------
131
def submit_job(endpoint_id: str, input_payload: dict) -> str:
    """POST /run -> job id (async queue submission).

    Args:
        endpoint_id: the RunPod endpoint to enqueue on.
        input_payload: sent as the ``input`` field of the request body.

    Returns:
        The ``id`` field of the response.

    Raises:
        RunpodApiError: when the response is not a JSON object, or has no (or an empty) ``id``.
    """
    out = request_with_retries(
        f"{QUEUE_BASE}/{endpoint_id}/run", method="POST", body={"input": input_payload}
    )
    # Validate the shape before reading it, as list_endpoints/billing_endpoints do. Without
    # this a non-object body (list, None, string) escapes as a bare AttributeError instead of
    # the RunpodApiError that callers handle.
    if not isinstance(out, dict):
        raise RunpodApiError(
            f"submit_job: unexpected response (got {type(out).__name__}, want object): {out!r}"
        )
    job_id = out.get("id")
    if not job_id:
        raise RunpodApiError(f"submit_job: no job id in response: {out}")
    return job_id
140
+
141
+
142
def job_status(endpoint_id: str, job_id: str) -> dict:
    """GET /status/<job_id> -> {status, output?, error?, ...}."""
    status_url = f"{QUEUE_BASE}/{endpoint_id}/status/{job_id}"
    return request_with_retries(status_url)
145
+
146
+
147
def cancel_job(endpoint_id: str, job_id: str) -> dict:
    """POST /cancel/<job_id> for a queued job (at most 2 retries) and return the response."""
    cancel_url = f"{QUEUE_BASE}/{endpoint_id}/cancel/{job_id}"
    return request_with_retries(cancel_url, method="POST", retries=2)
151
+
152
+
153
+ # ---------------------------------------------------------------------------
154
+ # Realized billing (COGS) -- what RunPod actually charged, for estimator accuracy.
155
+ # ---------------------------------------------------------------------------
156
def billing_endpoints(
    *,
    start_time: str,
    end_time: str,
    endpoint_id: str | None = None,
    bucket_size: str = "day",
) -> list[dict]:
    """Fetch realized serverless spend per endpoint over [start_time, end_time] (ISO-8601).

    GET /v1/billing/endpoints -> records of {endpointId, time, amount (USD), timeBilledMs, ...}.
    RunPod has no per-job cost; the finest realized granularity is per-endpoint per time
    bucket. Flash provisions one endpoint per run, so filtering by ``endpoint_id`` yields that
    run's realized cost even after teardown (billing history survives deletion).

    Returns the rows as a list. A dict response is unwrapped from its first truthy
    ``data`` / ``endpoints`` / ``billing`` value; any other shape yields ``[]``.
    """
    from urllib.parse import urlencode

    query: dict[str, str] = {
        "startTime": start_time,
        "endTime": end_time,
        "bucketSize": bucket_size,
    }
    if endpoint_id:
        query["endpointId"] = endpoint_id
    billing_url = f"{REST_BASE}/billing/endpoints?{urlencode(query)}"
    out = request_with_retries(billing_url)
    if isinstance(out, list):
        return out
    if not isinstance(out, dict):
        return []
    # Defensive: some RunPod list responses wrap rows under a key.
    rows = out.get("data") or out.get("endpoints") or out.get("billing")
    if isinstance(rows, list):
        return rows
    return []