freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""Pure, monkeypatch-free building blocks for the Lambda Cloud run lifecycle.
|
|
2
|
+
|
|
3
|
+
The Lambda-specific leaf of ``flash.providers.lambdalabs.jobs``: the normalized dataclasses
|
|
4
|
+
(``LambdaInstance``, ``LambdaJobHandle``) and the image accessor. The cross-provider pieces — the
|
|
5
|
+
run-derived sweep label, the bootstrap payload, and the cloud-init ``user_data`` — are shared with
|
|
6
|
+
Hyperstack in ``flash.providers._instance`` and re-exported here so the import path is unchanged.
|
|
7
|
+
|
|
8
|
+
This module MUST NOT import the ``jobs`` package ``__init__`` (it is imported BY it).
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
|
|
15
|
+
# Shared instance-provider helpers (single source of truth; Lambda binds arm="lambda" + its image).
|
|
16
|
+
from flash.providers._instance import (
|
|
17
|
+
build_payload as _shared_build_payload,
|
|
18
|
+
)
|
|
19
|
+
from flash.providers._instance import (
|
|
20
|
+
build_user_data as _shared_build_user_data,
|
|
21
|
+
)
|
|
22
|
+
from flash.providers._instance import (
|
|
23
|
+
instance_label,
|
|
24
|
+
run_label_prefix,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
# Public API of this leaf module: the Lambda-specific dataclasses and image accessor, the Lambda
# ``build_payload``/``build_user_data`` wrappers, and the shared instance-provider helpers
# (``instance_label``, ``run_label_prefix``) re-exported so the import path is unchanged.
__all__ = [
    "LambdaInstance",
    "LambdaJobHandle",
    "build_payload",
    "build_user_data",
    "instance_label",
    "lambda_image",
    "run_label_prefix",
]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass(eq=True, frozen=True)
class LambdaInstance:
    """One launchable Lambda option for a managed GPU class: where (region), what (instance type),
    and how much ($/hr). This is Lambda's counterpart to a vetted Vast offer; instances are
    immutable value objects."""

    gpu: str  # canonical GPU class name (a GPU_INFO key)
    instance_type: str  # Lambda's instance-type identifier, e.g. "gpu_1x_a10"
    region: str  # Lambda region the instance type is launchable in
    vram_gb: int  # GPU memory in GB
    price_usd_hr: float  # list price, USD per hour
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class LambdaJobHandle:
    """Reattach/cancel handle for a Lambda run, persisted in ``RunStatus.remote``.

    It is serialized with ``to_dict`` and rebuilt with ``from_dict``, so a different process can
    pick the run back up (compare ``base.JobHandle``).
    """

    instance_id: str
    instance_type: str
    region: str
    name: str  # run-derived, sweep-matchable instance name (see ``instance_label``)
    gpu: str
    hourly_usd: float
    attempt: int
    started_ts: float

    def to_dict(self) -> dict:
        """Serialize to a plain dict tagged with ``provider="lambda"``."""
        payload: dict = {"provider": "lambda"}
        payload["instance_id"] = self.instance_id
        payload["instance_type"] = self.instance_type
        payload["region"] = self.region
        payload["name"] = self.name
        payload["gpu"] = self.gpu
        payload["hourly_usd"] = self.hourly_usd
        payload["attempt"] = self.attempt
        payload["started_ts"] = self.started_ts
        return payload

    @classmethod
    def from_dict(cls, d: dict) -> LambdaJobHandle:
        """Rebuild a handle from ``to_dict`` output.

        ``instance_id`` is required (KeyError if absent). Every other field tolerates a missing or
        falsy value and defaults it to ``""`` / ``0``.
        """

        def _text(key: str) -> str:
            return str(d.get(key) or "")

        def _real(key: str) -> float:
            return float(d.get(key) or 0)

        return cls(
            instance_id=str(d["instance_id"]),
            instance_type=_text("instance_type"),
            region=_text("region"),
            name=_text("name"),
            gpu=_text("gpu"),
            hourly_usd=_real("hourly_usd"),
            attempt=int(d.get("attempt") or 0),
            started_ts=_real("started_ts"),
        )
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def lambda_image(gpu: str | None = None) -> str:
    """Choose the Docker image that the Lambda host's cloud-init pulls and runs.

    The base choice is the prebuilt, PUBLIC ``WORKER_IMAGE``, which is the same training stack
    RunPod bakes. ``FLASH_WORKER_IMAGE`` overrides it. When the operator enables per-SM warmed
    images (``FLASH_WORKER_IMAGE_PER_SM`` / ``FLASH_WORKER_IMAGE_TEMPLATE``), ``gpu`` selects the
    matching ``-smXX`` tag, so the worker's baked kernel cache matches the rented GPU's arch.

    The choice is delegated to RunPod's ``worker_image_for_gpu`` (the same selector RunPod uses).
    ``WORKER_IMAGE`` is used only if that selector returns nothing.
    """
    from flash.providers.runpod.train import WORKER_IMAGE, worker_image_for_gpu

    # allow_default=True asks the selector for a concrete image (override / per-sm tag / base).
    chosen = worker_image_for_gpu(gpu, allow_default=True)
    if chosen:
        return chosen
    return WORKER_IMAGE
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def build_payload(
    spec,
    seed: int,
    attempt: int,
    runtime_secrets: dict | None = None,
    cache_host_mount: str | None = None,
    mode: str | None = None,
    models: list | None = None,
) -> dict:
    """Build the Lambda bootstrap payload by running the shared builder with ``arm="lambda"``.

    ``cache_host_mount`` is the host NFS mount of the attached weight-cache filesystem
    (``/lambda/nfs/<name>``); when given, HF_HOME points at it. ``mode="preload"`` together with
    ``models`` produces a download-only warm payload, with no worker.
    """
    lambda_options = {
        "arm": "lambda",
        "runtime_secrets": runtime_secrets,
        "cache_host_mount": cache_host_mount,
        "mode": mode,
        "models": models,
    }
    return _shared_build_payload(spec, seed, attempt, **lambda_options)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def build_user_data(payload: dict, *, gpu: str | None = None) -> str:
    """Render the Lambda cloud-init ``user_data`` for ``payload``.

    This is the shared builder pointed at the image ``lambda_image(gpu)`` selects.
    """
    worker_image = lambda_image(gpu)
    return _shared_build_user_data(payload, image=worker_image)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Fail-fast credential checks for the Lambda Cloud substrate (operator-side).
|
|
2
|
+
|
|
3
|
+
Mirrors ``providers/runpod/preflight.py``. Lambda is OPT-IN (the allocator only reaches for it
|
|
4
|
+
when ``LAMBDA_API_KEY`` is set), so the only Lambda-specific requirement is ``LAMBDA_API_KEY``;
|
|
5
|
+
HF_TOKEN is a shared run requirement checked once centrally by the cross-provider preflight
|
|
6
|
+
(``flash/providers/preflight.py``), which calls each provider-specific check with
|
|
7
|
+
``require_hf=False`` so HF is never double-reported.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from flash.providers.lambdalabs.auth import load_api_key
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def missing_credentials(require_hf: bool = True) -> list[str]:
    """List the Lambda operator config that is missing; an empty list means ready.

    ``require_hf`` exists only so the signature matches the RunPod check, and it is deliberately
    unused. Lambda owns no HF requirement; the shared HF_TOKEN is checked once, centrally, in
    ``providers.preflight``.
    """
    if load_api_key():
        return []
    return [" - LAMBDA_API_KEY: the operator's Lambda Cloud API key (for the lambda provider)"]
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""Lambda Cloud $/hr: live ``/instance-types`` rate per class, static fallback.
|
|
2
|
+
|
|
3
|
+
Lambda prices a fixed instance-type catalog (unlike Vast's live market), so a class's rate is just
|
|
4
|
+
its instance type's ``price_cents_per_hour``. This module gives the provider interface a uniform
|
|
5
|
+
``hourly_rate(gpu)``. Offline-safe: without ``LAMBDA_API_KEY`` (or on any fetch failure) it falls
|
|
6
|
+
back to the static Lambda snapshot below.
|
|
7
|
+
|
|
8
|
+
NB: the static fallback is a Lambda-specific map, NOT ``GpuClass.hourly_usd`` — that field is the
|
|
9
|
+
RunPod secure-cloud snapshot, which differs from Lambda's list price for the shared classes (e.g.
|
|
10
|
+
RTX A6000 is $0.49 on RunPod but $1.09 on Lambda).
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from flash._logging import get_logger
|
|
16
|
+
|
|
17
|
+
# Module logger; ``hourly_rate`` uses it for the debug note when live pricing is unavailable.
logger = get_logger(__name__)

# Lambda list prices (snapshot 2026-06-25, from /instance-types). Live rates override these.
# Keyed by canonical GPU class name (``hourly_rate`` canonicalizes before lookup); values are USD/hr.
_STATIC_RATES: dict[str, float] = {
    "A10": 1.29,
    "RTX A6000": 1.09,
    "A100 SXM 40GB": 1.99,
    "H100": 3.29,
}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _static_rate(name: str) -> float:
    """Offline $/hr for the canonical class ``name``.

    Uses the Lambda snapshot when it has a (truthy) price for the class. Otherwise it falls back
    to the class's nominal ``GPU_INFO`` rate, so ``hourly_rate`` always has an answer.
    """
    from flash.providers.base import GPU_INFO

    snapshot_price = _STATIC_RATES.get(name)
    if snapshot_price:
        return snapshot_price
    return GPU_INFO[name].hourly_usd
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def hourly_rate(gpu_name: str) -> float:
    """$/hr for one friendly GPU name on Lambda.

    The live ``/instance-types`` rate is preferred. The static rate is used when the class has no
    Lambda instance type, when the live lookup fails, or when it returns a falsy rate.
    """
    from flash.providers.base import canonical_gpu, get_gpu_info

    canonical = canonical_gpu(gpu_name)
    lambda_type = get_gpu_info(canonical).lambda_name
    if not lambda_type:
        return _static_rate(canonical)

    try:
        from flash.providers.lambdalabs.api import instance_type_price_usd_hr

        live_rate = instance_type_price_usd_hr(lambda_type)
        if live_rate:
            return live_rate
    except Exception as exc:
        # Best-effort: any failure (no key, network, parse) degrades to the static snapshot.
        logger.debug("live lambda pricing unavailable for %s (%s); using static", canonical, exc)
    return _static_rate(canonical)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Lambda Cloud train submission: build the instance payload + submit a run.
|
|
2
|
+
|
|
3
|
+
The worker stack/env is substrate-neutral, so the per-run worker env and dependency resolution are
|
|
4
|
+
shared with RunPod (``providers/runpod/train/``); this module owns the Lambda-specific submission
|
|
5
|
+
entrypoint and the instance payload shape. Provisioning, polling, and teardown live in
|
|
6
|
+
``providers/lambdalabs/jobs``.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
# Re-exported: the Lambda submission entrypoint/payload builder, plus the shared, substrate-neutral
# worker stack (single source of truth on RunPod's module).
|
|
12
|
+
from flash.providers.lambdalabs.jobs import build_payload, submit_run_lambda
|
|
13
|
+
from flash.providers.runpod.train import (
|
|
14
|
+
WORKER_DEPS,
|
|
15
|
+
WORKER_SYSTEM_DEPS,
|
|
16
|
+
build_worker_env,
|
|
17
|
+
resolve_worker_deps,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
# Re-export surface: the Lambda submission entrypoint and payload builder, plus the shared,
# substrate-neutral worker stack imported from RunPod's train module.
__all__ = [
    "WORKER_DEPS",
    "WORKER_SYSTEM_DEPS",
    "build_payload",
    "build_worker_env",
    "resolve_worker_deps",
    "submit_run_lambda",
]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""RunPod startup preflight.
|
|
2
|
+
|
|
3
|
+
``check_run_preflight`` aggregates RunPod's missing-config problems plus the shared Hugging Face
|
|
4
|
+
dataset-repo requirements, so a single startup error lists everything missing.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
|
|
11
|
+
from flash.providers.runpod.preflight import (
|
|
12
|
+
PreflightError,
|
|
13
|
+
missing_credentials,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
# ``PreflightError`` is re-exported from the RunPod preflight so callers can catch it from here.
__all__ = [
    "PreflightError",
    "check_run_preflight",
]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _missing_hf_credentials() -> list[str]:
    """Problems for the shared run infra that every substrate needs.

    Covers GITHUB_TOKEN for managed environments and HF_TOKEN for the per-run dataset repo. A
    variable that is unset or empty counts as missing. Problems are listed in that order.
    """
    required_env = (
        ("GITHUB_TOKEN", " - GITHUB_TOKEN: server token with access to managed Freesolo environments"),
        (
            "HF_TOKEN",
            " - HF_TOKEN: a token with write access to each run's "
            "`[train] hf_repo`, e.g. `export HF_TOKEN=hf_...`",
        ),
    )
    return [message for var, message in required_env if not os.environ.get(var)]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _preflight_provider_names() -> set[str]:
    """Return the providers whose operator config this control plane must satisfy.

    Currently this is only RunPod. A fresh set is returned on every call.
    """
    required: set[str] = set()
    required.add("runpod")
    return required
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def check_run_preflight(require_hf: bool = True) -> None:
    """Validate the control plane's operator config, listing every gap in one error.

    Raises:
        PreflightError: if any selected provider's credentials are missing, or (when
            ``require_hf``) the shared GITHUB_TOKEN / HF_TOKEN run infra is missing.
    """
    selected_providers = _preflight_provider_names()
    problems: list[str] = []
    # Provider checks run with require_hf=False because the shared HF write token is checked
    # exactly once below, so it is never double-reported. The HF dataset repo itself is per-run
    # (``[train] hf_repo``).
    if "runpod" in selected_providers:
        problems.extend(missing_credentials(require_hf=False))
    if require_hf:
        problems.extend(_missing_hf_credentials())
    if not problems:
        return
    raise PreflightError(
        "the Flash control plane is missing required operator configuration:\n"
        + "\n".join(problems)
        + "\n\nSet these on the control-plane host."
    )
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""Realized provider cost (COGS) for a finished run -- the cost side of estimator accuracy.
|
|
2
|
+
|
|
3
|
+
RunPod's billing API gives the dollars it ACTUALLY charged, which the reconciliation job
compares against the run's charged pre-flight estimate. This module owns the ``RealizedCost``
shape and dispatches on the run's persisted handle (``RunStatus.remote['provider']``).

- RunPod goes to its billing shaper. The HTTP calls live in ``runpod/api.py``; the pure shaping
  lives in ``runpod/cost.py`` so it stays offline-testable.
- The instance providers (Lambda/Hyperstack) have no billing endpoint, so they are priced here
  as wall-clock x flat $/hr.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class RealizedCost:
    """What a finished run actually cost at its provider (realized COGS).

    Attributes:
        provider: provider key the cost is attributed to (e.g. ``"runpod"``, ``"lambda"``).
        realized_usd: total dollars charged.
        by_resource: per-resource USD breakdown, e.g. ``{"gpu": .., "disk": .., "bwd": ..}``.
        wall_seconds: wall-clock seconds the cost covers, when known.
        source: audit trail (resource ids / raw refs the figure was derived from).
    """

    provider: str
    realized_usd: float
    # Each instance gets its own dict (default_factory), never a shared mutable default.
    by_resource: dict[str, float] = field(default_factory=dict)
    wall_seconds: float | None = None
    source: dict = field(default_factory=dict)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def realized_cost_for_remote(
    remote: dict | None, *, start: float, end: float, run_end: float | None = None
) -> RealizedCost | None:
    """Pull realized cost for a run from its persisted provider handle, or None if unattributable.

    ``remote`` is ``RunStatus.remote`` (the last/successful attempt's handle dict). Returns None
    when there is no handle, no resource id, or an unknown provider; the run then stays
    unreconciled (and is retried). A handle without a ``provider`` key is treated as RunPod's.

    The two cost sources need two different time bounds:

    * ``start``/``end`` bound the RunPod BILLING-API query window. The caller pads ``end`` past the
      run's terminal time so the settled invoice row is in range (see reconcile
      ``_SETTLE_SECONDS``).
    * ``run_end`` is the run's ACTUAL terminal time (~teardown). The instance providers
      (Lambda/Hyperstack) have no billing endpoint. An instance bills at a flat $/hr from launch
      to teardown, so its realized COGS is wall x rate over ``started_ts -> run_end``. It must NOT
      use the settle-padded ``end``, or it would over-bill by the padding (up to an hour).
      ``run_end`` defaults to ``end`` for back-compat when the caller doesn't distinguish.
    """
    if not remote:
        return None
    provider = remote.get("provider") or "runpod"
    if provider == "runpod":
        endpoint_id = remote.get("endpoint_id")
        if not endpoint_id:
            # No endpoint to query the billing API for -> unattributable, per the contract above,
            # instead of handing None to the shaper.
            return None
        from flash.providers.runpod.cost import realized_cost as runpod_realized

        return runpod_realized(endpoint_id, start=start, end=end)
    if provider in ("lambda", "hyperstack"):
        instance_end = run_end if run_end is not None else end
        return _instance_realized_cost(remote, start=start, end=instance_end)
    return None
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _instance_realized_cost(
    remote: dict, *, start: float, end: float
) -> RealizedCost | None:
    """Realized COGS for an instance-billed provider: wall-clock x the instance's flat $/hr.

    Billing runs from launch (the handle's ``started_ts``, falling back to ``start`` when that is
    missing or zero) to teardown. Here ``end`` is the run's true terminal time, NOT a
    settle-padded billing-query bound. A negative span is clamped to zero.

    Returns None (unattributable) when the handle has no rate or no resource id
    (``instance_id``, else ``vm_id``). The run then stays unreconciled instead of booking cost
    that can't be tied to a resource.
    """
    hourly = remote.get("hourly_usd")
    resource_id = remote.get("instance_id") or remote.get("vm_id")
    if not (hourly and resource_id):
        return None

    launched_at = remote.get("started_ts") or start
    elapsed_s = max(0.0, float(end) - float(launched_at))
    rate_usd_hr = float(hourly)
    cost_usd = round(elapsed_s / 3600.0 * rate_usd_hr, 6)
    audit = {
        "resource_id": str(resource_id),
        "hourly_usd": rate_usd_hr,
        "started_ts": float(launched_at),
    }
    return RealizedCost(
        provider=str(remote.get("provider")),
        realized_usd=cost_usd,
        by_resource={"gpu": cost_usd},
        wall_seconds=elapsed_s,
        source=audit,
    )
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
"""RunPod Flash provider: managed, serverless GPUs (no Docker) for Flash.
|
|
2
|
+
|
|
3
|
+
Fine-tuning runs on a dedicated RunPod GPU provisioned by Flash. A decorated Python
|
|
4
|
+
handler (``train._train_body``) executes ``flash.engine.worker`` on the GPU; Flash
|
|
5
|
+
handles provisioning, dependency install, execution, and scale-to-zero teardown.
|
|
6
|
+
Serving exposes an OpenAI-compatible endpoint for a trained LoRA adapter.
|
|
7
|
+
|
|
8
|
+
``PROVIDER`` is the ``base.Provider`` implementation the registry hands out; the
|
|
9
|
+
orchestrator/allocator only talk to its interface, never these modules directly.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
from flash.providers.base import GpuClass, JobHandle, PollResult, Provider
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RunpodProvider:
    """``base.Provider`` for the RunPod Flash substrate.

    A thin adapter: each method imports its RunPod helper module (``preflight``, ``gpus``,
    ``pricing``, ``jobs``, ``api``, ``train``) inside the method body and delegates to it.
    """

    name = "runpod"

    def is_configured(self) -> bool:
        """Always True, because RunPod is the ALWAYS-ON default substrate.

        Pricing is static. A missing RUNPOD_API_KEY surfaces at provision time (via ensure_auth /
        the preflight), never as a silently empty candidate list. This matches the historical
        ``available_providers()``, which listed runpod unconditionally.
        """
        return True

    def preflight(self, require_hf: bool = True) -> list[str]:
        """Missing RunPod operator config, as reported by RunPod's ``missing_credentials``."""
        from flash.providers.runpod.preflight import missing_credentials as runpod_missing

        return runpod_missing(require_hf=require_hf)

    def gpu_classes(self) -> list[GpuClass]:
        """The GPU classes RunPod offers (delegates to ``runpod.gpus``)."""
        from flash.providers.runpod.gpus import gpu_classes as runpod_gpu_classes

        return runpod_gpu_classes()

    def hourly_rate(self, gpu: str) -> float:
        """$/hr for ``gpu`` on RunPod (delegates to ``runpod.pricing``)."""
        from flash.providers.runpod.pricing import hourly_rate as runpod_hourly_rate

        return runpod_hourly_rate(gpu)

    def submit_run(
        self,
        spec,
        seed: int,
        *,
        log: Any = None,
        on_handle: Any = None,
        attempt: int = 0,
        runtime_secrets: dict[str, str] | None = None,
        on_last_gpu: bool = False,
    ) -> PollResult:
        """Submit the run for ``seed`` through ``runpod.jobs.submit_run``.

        ``on_last_gpu`` stretches the no-capacity grace when no further GPU attempt will be made
        after this one: either the candidate list or the retry budget is exhausted (see
        ``jobs.stall_kwargs``). Waiting longer can't cost a fallback when there is none.
        ``runtime_secrets`` is forwarded only when non-empty.
        """
        from flash.providers.runpod.jobs import submit_run as runpod_submit_run

        forwarded: dict[str, Any] = {
            "log": log,
            "on_handle": on_handle,
            "attempt": attempt,
            "on_last_gpu": on_last_gpu,
        }
        if runtime_secrets:
            forwarded["runtime_secrets"] = runtime_secrets
        return runpod_submit_run(spec, seed, **forwarded)

    def poll(self, handle: JobHandle, spec, seed: int, *, log: Any = None) -> PollResult:
        """Reattach to a persisted RunPod job and poll it via ``jobs.poll_job``.

        HF heartbeat / failure-detail readers are wired up only when the spec has an
        ``hf_repo``.
        """
        from flash.providers.runpod.jobs import JobHandle as RunpodJobHandle
        from flash.providers.runpod.jobs import (
            make_hf_failure_detail_reader,
            make_hf_heartbeat_reader,
            poll_job,
            stall_kwargs,
        )

        repo = spec.train.hf_repo
        run_prefix = f"{spec.phase}/{spec.run_id}/seed{seed}"
        if repo:
            heartbeat_reader = make_hf_heartbeat_reader(repo, run_prefix)
            failure_detail_reader = make_hf_failure_detail_reader(repo, run_prefix, spec.phase)
        else:
            heartbeat_reader = failure_detail_reader = None

        runpod_handle = RunpodJobHandle.from_dict(handle.to_dict())
        if log is not None:
            print(
                f"attaching: job={runpod_handle.job_id} endpoint={runpod_handle.endpoint_name}",
                file=log,
                flush=True,
            )

        # Reuse the submit path's stall tuning so a reattached run isn't judged differently. The
        # runner's on_handle persists the original submit's ``on_last_gpu`` in the handle, so its
        # no-capacity grace is reproduced here. A handle without it (pre-persist / non-runpod)
        # falls back to False, the historical default.
        last_gpu = bool(handle.to_dict().get("on_last_gpu", False))
        stall = stall_kwargs(on_last_gpu=last_gpu)
        return poll_job(
            runpod_handle,
            log=log,
            heartbeat_reader=heartbeat_reader,
            failure_detail_reader=failure_detail_reader,
            **stall,
        )

    def cancel(self, handle: JobHandle) -> None:
        """Cancel the queued job, provided the handle carries both its endpoint and job ids."""
        from flash.providers.runpod import api as runpod_api

        fields = handle.to_dict()
        endpoint_id = fields.get("endpoint_id")
        job_id = fields.get("job_id")
        if endpoint_id and job_id:
            runpod_api.cancel_job(endpoint_id, job_id)

    def destroy(self, handle: JobHandle) -> None:
        """Delete the run's endpoint, if the handle records one."""
        from flash.providers.runpod import api as runpod_api

        endpoint_id = handle.to_dict().get("endpoint_id")
        if endpoint_id:
            runpod_api.delete_endpoint(endpoint_id)

    def gc(self, spec) -> None:
        """Tear down the endpoint for this spec's GPU type and run id."""
        from flash.providers.runpod.train import terminate_endpoint

        gpu_type, run_id = spec.gpu.type, spec.run_id
        terminate_endpoint(gpu_type, run_id)

    def sweep_orphans(self, active_labels: set[str] | None = None) -> list[int]:
        """No-op; it exists only to satisfy the ``base.Provider`` protocol.

        RunPod serverless endpoints leave no standing per-run billing to reap on crash recovery.
        An endpoint that failed before submit is GC'd by its reconstructed name in recover_runs.
        """
        return []
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# Module-level singleton: the ``base.Provider`` implementation the registry hands out.
PROVIDER: Provider = RunpodProvider()
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
"""Thin RunPod REST client (no SDK state): endpoints, queue jobs, health.
|
|
2
|
+
|
|
3
|
+
Used by the run supervisor and endpoint GC so that a *fresh process* can
|
|
4
|
+
reattach to / clean up after any run using only the persisted ids + RUNPOD_API_KEY —
|
|
5
|
+
independent of the Flash SDK's local resource registry (which is per-directory,
|
|
6
|
+
whole-dict, last-writer-wins and therefore unreliable across processes).
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import urllib.error
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from flash.providers._http import RestClient
|
|
15
|
+
from flash.providers.runpod import keys as _keys
|
|
16
|
+
|
|
17
|
+
# RunPod REST API root; used here for endpoint management (list / delete under ``/endpoints``).
REST_BASE = "https://rest.runpod.io/v1"
# RunPod serverless queue API root; used here for per-endpoint routes such as ``/{id}/health``.
QUEUE_BASE = "https://api.runpod.ai/v2"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class RunpodApiError(RuntimeError):
    """Error type for RunPod API failures.

    It is the shared REST client's ``error_cls`` and is also raised directly by
    ``list_endpoints``.
    """
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Shared urllib client (full-URL form: callers pass absolute REST/QUEUE urls).
# Env-only by design: ~/.flash/config.json holds the *Flash* key (client-side),
# never the RunPod key. The operator sets RUNPOD_API_KEY on the control-plane host.
#
# ``RUNPOD_API_KEY`` may be a comma-separated pool of per-account keys. The client tries
# them active-account-first per call (``keys.ordered_keys``) and fails over to the next
# account on an auth/quota/not-found error (``keys.is_failover_error``). RunPod endpoints
# are account-scoped, so a single-account op (status/cancel/delete) resolves no matter
# which account a failed-over run was provisioned on. A single key => a pool of one.
#
# Client errors are raised as ``RunpodApiError`` (passed as ``error_cls``).
_CLIENT = RestClient(
    env_var="RUNPOD_API_KEY",
    error_cls=RunpodApiError,
    keys_provider=_keys.ordered_keys,
    failover_predicate=_keys.is_failover_error,
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def request_with_retries(
    url: str,
    method: str = "GET",
    body: dict | None = None,
    retries: int = 4,
    base_delay: float = 2.0,
) -> Any:
    """Make a REST call hardened against transient network/5xx blips (jittered backoff).

    The call goes through the shared pooled ``_CLIENT``, so key-pool failover applies.
    """
    call_options = {
        "method": method,
        "body": body,
        "retries": retries,
        "base_delay": base_delay,
    }
    return _CLIENT.request_with_retries(url, **call_options)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# ---------------------------------------------------------------------------
|
|
56
|
+
# Endpoints
|
|
57
|
+
# ---------------------------------------------------------------------------
|
|
58
|
+
def list_endpoints() -> list[dict]:
    """Return every endpoint across every account in the ``RUNPOD_API_KEY`` pool.

    RunPod endpoints are account-scoped. A plain pooled ``request_with_retries`` stops at the
    first key that succeeds, and so sees only that account. Idle-sweep and slot-reconcile need
    the whole fleet, so each key is queried on its own (with per-key retries) and the results
    are concatenated.

    Raises:
        RunpodApiError: if no key is configured, or if any key's call fails or returns a non-list
            body. Callers (teardown, slot-reconcile) read an empty result as "confirmed absent",
            so a partial or unauthenticated view must fail loudly instead of under-reporting.
            ``sweep_idle_endpoints()`` and the slot reconcile already catch and skip on exception.
    """
    account_keys = _keys.keys()
    if not account_keys:
        # Without this guard, an empty pool would return [] without a single authenticated call.
        raise RunpodApiError(
            "RUNPOD_API_KEY is not set; refusing to report an empty endpoint fleet"
        )

    url = f"{REST_BASE}/endpoints"
    fleet: list[dict] = []
    for account_key in account_keys:
        page = _CLIENT.request_with_retries_for_key(account_key, url, retries=2)
        if isinstance(page, list):
            fleet += page
            continue
        # A 200 whose body isn't a list is NOT an empty account; skipping it would hide part of
        # the fleet.
        raise RunpodApiError(
            f"unexpected /endpoints response for a pool key (got {type(page).__name__}, want list)"
        )
    return fleet
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def find_endpoints_by_name(substr: str) -> list[dict]:
    """Return the endpoints from ``list_endpoints()`` whose name contains ``substr``.

    A missing or ``None`` name is treated as the empty string, so such endpoints only
    match when ``substr`` is itself empty. Result order follows ``list_endpoints()``.
    Any exception raised by ``list_endpoints()`` propagates unchanged.
    """
    matches: list[dict] = []
    for endpoint in list_endpoints():
        endpoint_name = endpoint.get("name") or ""
        if substr in endpoint_name:
            matches.append(endpoint)
    return matches
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def delete_endpoint(endpoint_id: str) -> bool:
    """Delete a serverless endpoint and report whether it is now absent.

    Returns True when the DELETE succeeds. If it raises a RunpodApiError, returns
    True only when ``_is_not_found`` classifies the error as a genuine 404; any
    other RunpodApiError yields False. Exceptions of other types propagate.
    """
    delete_url = f"{REST_BASE}/endpoints/{endpoint_id}"
    try:
        request_with_retries(delete_url, method="DELETE", retries=2)
    except RunpodApiError as err:
        # "Already gone" is the desired end state, so it counts as a clean teardown.
        # Returning False for it would make undeploy_adapter report a misleading
        # "may still be running" 502 for an endpoint that provably no longer exists.
        # NOTE(review): RUNPOD_API_KEY may be a multi-account key pool. If the DELETE
        # is issued with a key from a different account than the endpoint's owner, a
        # 404 here may not mean "gone". Confirm how request_with_retries picks a key.
        return _is_not_found(err)
    return True
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _is_not_found(err: RunpodApiError) -> bool:
|
|
109
|
+
"""True only when a RunpodApiError represents a genuine 404 (endpoint already gone).
|
|
110
|
+
|
|
111
|
+
request_with_retries chains the original urllib HTTPError as ``__cause__`` for every
|
|
112
|
+
fast-failed 4xx (``raise ... from e``), so the status code is authoritative when a
|
|
113
|
+
cause is present: a 404 is "already gone", anything else (403/401/5xx) is a real
|
|
114
|
+
failure and must NOT be swallowed — a body that merely *mentions* "does not exist" on a
|
|
115
|
+
403 is still a 403. We only fall back to a text match when there is no HTTPError cause
|
|
116
|
+
(e.g. the "failed after N attempts" path), and even then only on an unambiguous 404.
|
|
117
|
+
"""
|
|
118
|
+
cause = err.__cause__
|
|
119
|
+
if isinstance(cause, urllib.error.HTTPError):
|
|
120
|
+
return cause.code == 404
|
|
121
|
+
return "http 404" in str(err).lower()
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def endpoint_health(endpoint_id: str) -> dict:
    """Fetch the queue health payload for one serverless endpoint.

    Issues GET ``{QUEUE_BASE}/<endpoint_id>/health`` through request_with_retries
    with its default retry settings. Whatever that call returns is passed back
    unchanged.
    """
    health_url = f"{QUEUE_BASE}/{endpoint_id}/health"
    return request_with_retries(health_url)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# ---------------------------------------------------------------------------
|
|
129
|
+
# Queue jobs
|
|
130
|
+
# ---------------------------------------------------------------------------
|
|
131
|
+
def submit_job(endpoint_id: str, input_payload: dict) -> str:
    """Submit an asynchronous queue job to a serverless endpoint and return its job id.

    POSTs ``{"input": input_payload}`` to ``{QUEUE_BASE}/<endpoint_id>/run``.

    Raises:
        RunpodApiError: if the response is not an object carrying a non-empty ``id``.
            Errors raised by request_with_retries itself propagate unchanged.
    """
    # NOTE(review): this uses request_with_retries' default retry count. If that
    # retries a POST after an ambiguous failure (e.g. a timeout), the job could be
    # enqueued twice. Confirm the default for non-idempotent requests.
    out = request_with_retries(
        f"{QUEUE_BASE}/{endpoint_id}/run", method="POST", body={"input": input_payload}
    )
    # Check the shape before calling .get(). A non-object body (list, None, str) would
    # otherwise escape as AttributeError instead of the RunpodApiError callers handle.
    job_id = out.get("id") if isinstance(out, dict) else None
    if not job_id:
        raise RunpodApiError(f"submit_job: no job id in response: {out}")
    return job_id
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def job_status(endpoint_id: str, job_id: str) -> dict:
    """Poll a queued job via GET ``{QUEUE_BASE}/<endpoint_id>/status/<job_id>``.

    The response is returned as-is. It is expected to look like
    ``{status, output?, error?, ...}``.
    """
    status_url = f"{QUEUE_BASE}/{endpoint_id}/status/{job_id}"
    return request_with_retries(status_url)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def cancel_job(endpoint_id: str, job_id: str) -> dict:
    """Request cancellation of a queue job.

    Issues POST ``{QUEUE_BASE}/<endpoint_id>/cancel/<job_id>`` through
    request_with_retries with ``retries=2`` and returns its result unchanged.
    """
    cancel_url = f"{QUEUE_BASE}/{endpoint_id}/cancel/{job_id}"
    return request_with_retries(cancel_url, method="POST", retries=2)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
# ---------------------------------------------------------------------------
|
|
154
|
+
# Realized billing (COGS) -- what RunPod actually charged, for estimator accuracy.
|
|
155
|
+
# ---------------------------------------------------------------------------
|
|
156
|
+
def billing_endpoints(
    *,
    start_time: str,
    end_time: str,
    endpoint_id: str | None = None,
    bucket_size: str = "day",
) -> list[dict]:
    """Fetch realized serverless spend per endpoint for the ISO-8601 window [start_time, end_time].

    Calls GET /v1/billing/endpoints, whose records look like
    ``{endpointId, time, amount (USD), timeBilledMs, ...}``. RunPod exposes no per-job
    cost; the finest realized granularity is one endpoint per time bucket
    (``bucket_size``). Flash provisions one endpoint per run, so passing
    ``endpoint_id`` gives that run's realized cost. Billing history survives endpoint
    deletion, so this works after teardown too.

    Returns the record list. If the response wraps rows in an object, they are read
    from ``data``, ``endpoints`` or ``billing`` (first truthy wins). Any other shape
    yields an empty list.
    """
    from urllib.parse import urlencode

    query: dict[str, str] = {
        "startTime": start_time,
        "endTime": end_time,
        "bucketSize": bucket_size,
    }
    # A falsy endpoint_id (None or "") omits the filter entirely.
    if endpoint_id:
        query["endpointId"] = endpoint_id
    payload = request_with_retries(f"{REST_BASE}/billing/endpoints?{urlencode(query)}")

    if isinstance(payload, dict):
        # Defensive: some RunPod list responses wrap their rows under a key.
        payload = payload.get("data") or payload.get("endpoints") or payload.get("billing")
    return payload if isinstance(payload, list) else []
|