freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
flash/client/http.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
1
|
+
"""Stdlib HTTP client for the Flash control plane (no extra dependencies).
|
|
2
|
+
|
|
3
|
+
Every CLI/MCP operation maps to one method here. Server errors (FastAPI's
|
|
4
|
+
``{"detail": ...}``) surface as ``ApiError`` with the server's message; connection
|
|
5
|
+
problems surface as ``ClientError`` with an actionable hint.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import codecs
|
|
11
|
+
import contextlib
|
|
12
|
+
import json
|
|
13
|
+
import os
|
|
14
|
+
import urllib.error
|
|
15
|
+
import urllib.request
|
|
16
|
+
from collections.abc import Callable, Iterator
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
from .config import load_credentials_with_source
|
|
20
|
+
|
|
21
|
+
# Upload-progress hook, invoked as ``progress(bytes_sent, total_bytes)`` while a request body is
# streamed to the server so the CLI can render an upload bar. ``total_bytes`` is the request's
# full Content-Length and stays constant for the whole upload.
ProgressCallback = Callable[[int, int], None]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ClientError(RuntimeError):
    """A failure the CLI reports as a clean one-line message instead of a traceback.

    Used for expected client-side problems such as a missing API key or an unreachable server.
    """
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ApiError(ClientError):
    """An HTTP error status returned by a Flash/freesolo server.

    ``str(err)`` is the message (the server's ``detail`` when available) and ``status`` is the
    HTTP status code.
    """

    def __init__(self, status: int, message: str):
        super().__init__(message)
        self.status = status

    def __reduce__(self):
        # BaseException's default reduce replays ``self.args`` (only ``(message,)``) into
        # ``__init__``, which also requires ``status`` -- so copy.copy()/pickle would raise
        # TypeError. Rebuild from both constructor arguments and restore instance state.
        return (type(self), (self.status, *self.args), self.__dict__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# `flash login` talks to the freesolo backend, not the Flash control plane: it checks the user's
# freesolo API key at FREESOLO_AUTH_VERIFY_PATH, and the Flash control plane accepts that same
# freesolo-issued key afterwards.
DEFAULT_FREESOLO_BASE_URL = "https://api.freesolo.co"  # overridable via FREESOLO_BASE_URL
FREESOLO_AUTH_VERIFY_PATH = "/api/auth/verify"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def freesolo_base_url(override: str | None = None) -> str:
|
|
44
|
+
return (override or os.environ.get("FREESOLO_BASE_URL") or DEFAULT_FREESOLO_BASE_URL).rstrip(
|
|
45
|
+
"/"
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _detail_from_http_error(exc: urllib.error.HTTPError) -> str:
|
|
50
|
+
"""Extract the server's error message from an HTTPError body (FastAPI ``detail``)."""
|
|
51
|
+
body = exc.read()
|
|
52
|
+
try:
|
|
53
|
+
detail = json.loads(body).get("detail") or body.decode()
|
|
54
|
+
except (ValueError, AttributeError):
|
|
55
|
+
detail = body.decode(errors="replace") if body else str(exc)
|
|
56
|
+
return str(detail)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def verify_freesolo_key(api_key: str, base_url: str | None = None) -> None:
    """Verify a freesolo API key against the freesolo backend's ``/api/auth/verify``.

    Returns ``None`` on success. Keys are issued from the freesolo sign-in page.

    Raises:
        ClientError: the key was rejected (401/403), the backend is unreachable, or the
            connection timed out/dropped while waiting for the response.
        ApiError: any other HTTP error status, carrying the server's message.
    """
    base = freesolo_base_url(base_url)
    url = f"{base}{FREESOLO_AUTH_VERIFY_PATH}"
    req = urllib.request.Request(
        url,
        method="GET",
        headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
    )
    try:
        with urllib.request.urlopen(req, timeout=30) as resp:
            resp.read()
    except urllib.error.HTTPError as exc:
        if exc.code in (401, 403):
            raise ClientError(
                "freesolo rejected this API key — create or copy a valid key at "
                "https://freesolo.co/sign-in and pass it with `flash login --api-key` "
                "(or FREESOLO_API_KEY)"
            ) from exc
        raise ApiError(exc.code, _detail_from_http_error(exc)) from exc
    except urllib.error.URLError as exc:
        raise ClientError(
            f"cannot reach the freesolo backend at {base} ({exc.reason}); "
            "check your network connection and FREESOLO_BASE_URL"
        ) from exc
    except OSError as exc:
        # urllib only wraps failures while *sending* the request in URLError. A timeout or
        # reset while awaiting/reading the response arrives as a plain OSError
        # (TimeoutError, ConnectionResetError, RemoteDisconnected, ...).
        raise ClientError(
            f"lost connection to the freesolo backend at {base} ({exc}); "
            "check your network connection and FREESOLO_BASE_URL"
        ) from exc
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class _ProgressReader:
|
|
91
|
+
"""A read()-only file-like over an in-memory payload that reports bytes consumed.
|
|
92
|
+
|
|
93
|
+
``http.client`` sends a body exposing ``read()`` in blocksize chunks; we forward the running
|
|
94
|
+
total to ``progress(sent, total)`` for each chunk so the CLI can draw an upload bar. The
|
|
95
|
+
caller sets Content-Length from ``len(payload)``, so the request is NOT chunked-encoded and
|
|
96
|
+
the server reads it exactly as a plain bytes body."""
|
|
97
|
+
|
|
98
|
+
def __init__(self, data: bytes, progress: ProgressCallback):
|
|
99
|
+
self._data = data
|
|
100
|
+
self._total = len(data)
|
|
101
|
+
self._pos = 0
|
|
102
|
+
self._progress = progress
|
|
103
|
+
|
|
104
|
+
def __len__(self) -> int:
|
|
105
|
+
return self._total
|
|
106
|
+
|
|
107
|
+
def read(self, size: int = -1) -> bytes:
|
|
108
|
+
if size is None or size < 0:
|
|
109
|
+
chunk = self._data[self._pos :]
|
|
110
|
+
else:
|
|
111
|
+
chunk = self._data[self._pos : self._pos + size]
|
|
112
|
+
self._pos += len(chunk)
|
|
113
|
+
# a rendering hiccup must never abort an in-flight upload
|
|
114
|
+
with contextlib.suppress(Exception):
|
|
115
|
+
self._progress(self._pos, self._total)
|
|
116
|
+
return chunk
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class ApiClient:
    """JSON-over-HTTP client for the Flash control plane, built on stdlib ``urllib``.

    Every CLI/MCP operation maps to one public method. Failures are translated uniformly:

    - An HTTP error status becomes :class:`ApiError`, carrying the server's ``detail``.
    - An unreachable host, a timeout, or a connection dropped mid-request becomes
      :class:`ClientError`, with a hint to check the network and ``FLASH_API_URL``.
    """

    def __init__(
        self,
        api_url: str,
        api_key: str | None = None,
        timeout: float = 60.0,
        key_source: str | None = None,
    ):
        """
        Args:
            api_url: Control-plane base URL; trailing slashes are stripped.
            api_key: Sent as ``Authorization: Bearer <key>`` when non-empty.
            timeout: Default per-request timeout in seconds; some endpoints pass their own.
            key_source: Where the key came from. ``"FREESOLO_API_KEY"`` adds an extra hint to
                401/403 messages (see :meth:`_auth_error_detail`).
        """
        self.api_url = api_url.rstrip("/")
        self.api_key = api_key
        self.timeout = timeout
        self.key_source = key_source

    # -- transport ---------------------------------------------------------------------
    def _auth_error_detail(self, status: int, detail: str) -> str:
        """On 401/403, point out that a FREESOLO_API_KEY env var overrides the saved login key."""
        if status not in {401, 403} or self.key_source != "FREESOLO_API_KEY":
            return detail
        return (
            f"{detail}; FREESOLO_API_KEY is set and overrides the key saved by "
            "`flash login`. Unset FREESOLO_API_KEY or update it to a valid freesolo API key."
        )

    def _headers(self, extra: dict[str, str] | None = None) -> dict[str, str]:
        """JSON content type, any ``extra`` headers, and bearer auth when a key is set."""
        headers = {"Content-Type": "application/json", **(extra or {})}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    @contextlib.contextmanager
    def _translate_errors(self) -> Iterator[None]:
        """Map urllib/socket failures raised inside the ``with`` body onto ApiError/ClientError."""
        try:
            yield
        except urllib.error.HTTPError as exc:
            detail = self._auth_error_detail(exc.code, _detail_from_http_error(exc))
            raise ApiError(exc.code, detail) from exc
        except urllib.error.URLError as exc:
            raise ClientError(
                f"cannot reach the Flash service at {self.api_url} ({exc.reason}); "
                "check your network connection and FLASH_API_URL"
            ) from exc
        except OSError as exc:
            # urllib only wraps failures while *sending* the request in URLError. A timeout or
            # reset while awaiting/reading the response (e.g. a 30-minute deploy) arrives as a
            # plain OSError: TimeoutError, ConnectionResetError, RemoteDisconnected, ...
            raise ClientError(
                f"lost connection to the Flash service at {self.api_url} ({exc}); "
                "check your network connection and FLASH_API_URL"
            ) from exc

    def _send(self, req: urllib.request.Request, timeout: float) -> Any:
        """Perform ``req`` and decode its JSON response body (``{}`` when the body is empty)."""
        with self._translate_errors(), urllib.request.urlopen(req, timeout=timeout) as resp:
            raw = resp.read()
        return json.loads(raw) if raw else {}

    def _request(
        self,
        method: str,
        path: str,
        body: dict | None = None,
        timeout: float | None = None,
    ) -> Any:
        """Send ``body`` (if any) as JSON to ``path``; ``timeout`` falls back to ``self.timeout``."""
        req = urllib.request.Request(
            f"{self.api_url}{path}",
            method=method,
            data=json.dumps(body).encode() if body is not None else None,
            headers=self._headers(),
        )
        return self._send(req, timeout or self.timeout)

    def _post_with_progress(
        self,
        path: str,
        body: dict,
        *,
        progress: ProgressCallback,
        timeout: float,
    ) -> Any:
        """POST a JSON body while reporting upload progress (see :class:`_ProgressReader`).

        Error mapping is the same as :meth:`_request`. The body is a streaming reader, so
        Content-Length is set explicitly to keep the request fixed-length rather than chunked.
        """
        payload = json.dumps(body).encode()
        req = urllib.request.Request(
            f"{self.api_url}{path}",
            method="POST",
            data=_ProgressReader(payload, progress),
            headers=self._headers({"Content-Length": str(len(payload))}),
        )
        return self._send(req, timeout)

    # -- identity ----------------------------------------------------------------------
    def me(self) -> dict:
        return self._request("GET", "/v1/me")

    def health(self) -> dict:
        return self._request("GET", "/v1/health", timeout=10.0)

    # -- environments ------------------------------------------------------------------
    def publish_env(
        self,
        *,
        name: str,
        package_b64: str,
        progress: ProgressCallback | None = None,
    ) -> dict:
        """Upload a packaged Freesolo environment to the managed Environments Hub.

        When ``progress`` is given, the body streams to the server in chunks and
        ``progress(bytes_sent, total_bytes)`` fires for each one, so the CLI can render an
        upload bar. Otherwise the body is sent in one shot (the default, used off a TTY).
        """
        body = {"name": name, "package_b64": package_b64}
        if progress is None:
            return self._request("POST", "/v1/envs", body=body, timeout=1800.0)
        return self._post_with_progress("/v1/envs", body, progress=progress, timeout=1800.0)

    # -- runs --------------------------------------------------------------------------
    def create_run(self, spec: dict, runtime_secrets: dict[str, str] | None = None) -> dict:
        """Submit a run; ``runtime_secrets`` are included only when non-empty."""
        body: dict = {"spec": spec}
        if runtime_secrets:
            body["runtime_secrets"] = runtime_secrets
        return self._request("POST", "/v1/runs", body=body)

    def list_runs(self) -> list[dict]:
        return self._request("GET", "/v1/runs")["runs"]

    def get_run(self, run_id: str) -> dict:
        return self._request("GET", f"/v1/runs/{run_id}")

    def get_logs(self, run_id: str, offset: int = 0) -> dict:
        return self._request("GET", f"/v1/runs/{run_id}/logs?offset={int(offset)}")

    def get_worker_output(self, run_id: str) -> dict[str, str]:
        """Fetch the train subprocess's console/traceback files for a run; ``{}`` if none.

        Returns files such as ``console_<phase>.txt`` and ``error_<phase>.txt`` from the run's
        HF artifact repo. The server fetches them with the operator token. This is the real
        worker output that the offset-paged log cannot carry, and it is kept off the hot
        ``get_logs`` poll path.

        A managed server that predates the ``/worker`` route is tolerated, so a CLI upgraded
        ahead of the service rollout does not hard-fail. FastAPI answers an unmatched path
        with a bare 404 "Not Found", and ONLY that is treated as "no worker output". Real 404s
        still surface: an unknown run_id carries detail "unknown run_id: ...".
        """
        try:
            return self._request("GET", f"/v1/runs/{run_id}/worker").get("worker", {})
        except ApiError as exc:
            if exc.status == 404 and str(exc).strip().lower() == "not found":
                return {}
            raise

    def cancel_run(self, run_id: str) -> dict:
        return self._request("POST", f"/v1/runs/{run_id}/cancel")

    def checkpoints(self, run_id: str) -> list[dict]:
        """Deployable per-step RL checkpoints for a run (each `flash deploy --step N`-able)."""
        return self._request("GET", f"/v1/runs/{run_id}/checkpoints")["checkpoints"]

    # -- serving -----------------------------------------------------------------------
    def deploy(
        self,
        run_id: str,
        dry_run: bool = False,
        step: int | None = None,
    ) -> dict:
        """Deploy a run's adapter, or the checkpoint at ``step`` when given.

        Raises:
            ClientError: ``step`` is a bool (see below).
        """
        # Deploy blocks on registration and serving warmup, which can take many minutes.
        deploy_timeout = 30 * 60 if not dry_run else None
        body: dict = {"dry_run": dry_run}
        if step is not None:
            # Deploy a specific intermediate checkpoint instead of the run's final adapter.
            # Reject a bool explicitly: `int(True)`/`int(False)` would silently coerce to step
            # 1/0, but the server guard (_resolve_deploy_step) treats a bool as an invalid step
            # and 400s -- so fail fast here with a clear client-side error instead of sending a
            # bogus 0/1 that the server rejects (or, worse, that hits a real checkpoint 0/1).
            if isinstance(step, bool):
                raise ClientError(f"invalid checkpoint step: {step!r} (must be an integer)")
            body["step"] = int(step)
        return self._request(
            "POST",
            f"/v1/runs/{run_id}/deploy",
            body=body,
            timeout=deploy_timeout,
        )

    def undeploy(self, run_id: str) -> dict:
        return self._request("DELETE", f"/v1/runs/{run_id}/deploy")

    def deployments(self) -> list[dict]:
        return self._request("GET", "/v1/deployments")["deployments"]

    def chat(
        self,
        run_id: str,
        messages: list[dict],
        temperature: float = 0.0,
        max_tokens: int = 512,
    ) -> dict:
        # Serving warmup can take minutes; give inference a generous timeout.
        return self._request(
            "POST",
            f"/v1/runs/{run_id}/chat",
            body={"messages": messages, "temperature": temperature, "max_tokens": max_tokens},
            timeout=30 * 60,
        )

    def chat_stream(
        self,
        run_id: str,
        messages: list[dict],
        temperature: float = 0.0,
        max_tokens: int = 512,
    ) -> Iterator[str]:
        """Stream a chat completion as decoded text fragments.

        If the server replies with ``application/json`` (a non-streamed answer), the first
        choice's message content is yielded once. Otherwise the response bytes are decoded
        incrementally as UTF-8 and yielded as they arrive. A multi-byte character split across
        reads is therefore never torn.
        """
        req = urllib.request.Request(
            f"{self.api_url}/v1/runs/{run_id}/chat",
            method="POST",
            data=json.dumps(
                {
                    "messages": messages,
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "stream": True,
                }
            ).encode(),
            headers=self._headers(),
        )
        decoder = codecs.getincrementaldecoder("utf-8")()
        with self._translate_errors(), urllib.request.urlopen(req, timeout=30 * 60) as resp:
            content_type = resp.headers.get("Content-Type", "")
            if "application/json" in content_type:
                payload = json.loads(resp.read() or b"{}")
                content = (((payload.get("choices") or [{}])[0].get("message") or {}).get("content"))
                if content:
                    yield str(content)
                return
            # Byte-at-a-time reads: presumably so streamed text surfaces without waiting for a
            # full buffer to fill.
            while raw := resp.read(1):
                chunk = decoder.decode(raw)
                if chunk:
                    yield chunk
            tail = decoder.decode(b"", final=True)
            if tail:
                yield tail
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def client_from_config(require_key: bool = True) -> ApiClient:
    """Build an :class:`ApiClient` from the stored credentials.

    The credentials are resolved by ``load_credentials_with_source``; the key source is
    forwarded so auth errors can name it.

    Raises:
        ClientError: ``require_key`` is true and no API key is available (logged out).
    """
    api_url, api_key, key_source = load_credentials_with_source()
    logged_out = not api_key
    if require_key and logged_out:
        raise ClientError(
            "not logged in — run `flash login` with your freesolo API key (or set FREESOLO_API_KEY)"
        )
    return ApiClient(api_url, api_key, key_source=key_source)
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Client-side runtime secrets for managed runs.
|
|
2
|
+
|
|
3
|
+
These values are read on the user's machine and sent only with the submit request. They are not
|
|
4
|
+
part of JobSpec/TOML, and the control plane must not persist them in run status or artifacts.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
# Secrets forwarded from the local environment whenever present, without having to be declared.
DEFAULT_RUNTIME_SECRET_KEYS = frozenset(["WANDB_API_KEY"])
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _runtime_secret_keys(keys: tuple[str, ...] | list[str] | set[str] | None = None) -> set[str]:
    """Return a fresh set: the always-forwarded defaults plus every declared key (as ``str``)."""
    wanted = set(DEFAULT_RUNTIME_SECRET_KEYS)
    wanted.update(str(key) for key in keys or ())
    return wanted
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _read_env_file(path: Path, keys: set[str]) -> dict[str, str]:
|
|
20
|
+
if not path.exists() or not path.is_file():
|
|
21
|
+
return {}
|
|
22
|
+
out: dict[str, str] = {}
|
|
23
|
+
try:
|
|
24
|
+
lines = path.read_text(errors="ignore").splitlines()
|
|
25
|
+
except OSError:
|
|
26
|
+
return out
|
|
27
|
+
for line in lines:
|
|
28
|
+
s = line.strip()
|
|
29
|
+
if not s or s.startswith("#") or "=" not in s:
|
|
30
|
+
continue
|
|
31
|
+
key, value = s.split("=", 1)
|
|
32
|
+
key = key.strip()
|
|
33
|
+
if key not in keys:
|
|
34
|
+
continue
|
|
35
|
+
value = value.strip().strip('"').strip("'")
|
|
36
|
+
if value:
|
|
37
|
+
out[key] = value
|
|
38
|
+
return out
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def runtime_secrets_from_local_env(
    config_path: str | os.PathLike[str] | None = None,
    keys: tuple[str, ...] | list[str] | set[str] | None = None,
) -> dict[str, str]:
    """Collect supported run secrets from the user's local env.

    Sources, first non-empty value wins:

    1. the process environment;
    2. ``.env.local`` then ``.env`` in the current directory;
    3. ``.env.local`` then ``.env`` next to ``config_path`` (if given and not the same directory).

    ``.env.local`` is consulted before ``.env`` so a developer's local overrides win, following
    the usual dotenv convention. Arbitrary parent directories are deliberately not scanned, and
    secrets are never serialized into the run spec.

    The always-forwarded defaults are collected when present; every entry of ``keys`` is
    required.

    Raises:
        ValueError: a declared key in ``keys`` has no non-empty value in any source.
    """
    wanted = _runtime_secret_keys(keys)
    secrets = {key: value for key in wanted if (value := os.environ.get(key))}

    search_dirs = [Path.cwd()]
    if config_path:
        search_dirs.append(Path(config_path).expanduser().resolve().parent)
    visited: set[Path] = set()
    for directory in search_dirs:
        resolved = directory.resolve()
        if resolved in visited:  # config next to cwd: don't read the same files twice
            continue
        visited.add(resolved)
        for filename in (".env.local", ".env"):
            for key, value in _read_env_file(directory / filename, wanted).items():
                secrets.setdefault(key, value)

    required = {str(key) for key in (keys or ())}
    missing = sorted(required - set(secrets))
    if missing:
        raise ValueError(
            "missing declared environment secret(s): "
            f"{', '.join(missing)}. Set them in your shell or local .env file before submitting; "
            "do not put secret values in TOML."
        )
    return secrets
|
flash/client/specs.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Turn a locally validated JobSpec into the payload sent to the control plane.
|
|
2
|
+
|
|
3
|
+
The client fills default pip requirements for Freesolo environments unless the
|
|
4
|
+
config provided an explicit ``[environment] pip`` escape hatch.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from flash.spec import JobSpec
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def spec_payload(spec: JobSpec) -> dict:
    """Serialize ``spec`` for submission, filling in default worker pip requirements.

    If the config gave no explicit ``[environment] pip``, the requirements that
    ``worker_pip_for_env`` returns for ``spec.environment.id`` are injected, when there are any.
    """
    payload = spec.to_dict()
    if spec.environment.pip:
        return payload
    from flash.envs.registry import worker_pip_for_env

    default_pip = worker_pip_for_env(spec.environment.id)
    if default_pip:
        payload["environment"]["pip"] = default_pip
    return payload
|
flash/cost/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Flash training-cost estimator: a deterministic, equation-based pre-flight estimate
|
|
2
|
+
(``estimate_cost``) of cost = wall-clock hours x market $/hr. No output multiplier."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from .analytical import estimate_cost
|
|
7
|
+
from .spec import estimate_for_spec, runconfig_from_spec
|
|
8
|
+
from .types import CostEstimate, RunConfig
|
|
9
|
+
|
|
10
|
+
__all__ = [
    # result / input types
    "CostEstimate",
    "RunConfig",
    # estimators
    "estimate_cost",
    "estimate_for_spec",
    "runconfig_from_spec",
]
|
flash/cost/analytical.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""The analytical cost model: total = wall-clock hours x GPU $/hr, where wall = cold-start
|
|
2
|
+
setup + steps x per-step time (a FLOPs/MFU estimate). GRPO splits each step into a vLLM
|
|
3
|
+
rollout + reward grading + policy/reference update."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
|
|
9
|
+
from flash.providers.allocator import required_vram_gb, vram_headroom
|
|
10
|
+
|
|
11
|
+
from .facts import (
|
|
12
|
+
download_weight_gb,
|
|
13
|
+
gpu_hourly_usd,
|
|
14
|
+
gpu_tflops,
|
|
15
|
+
gpu_vram_gb,
|
|
16
|
+
model_quant,
|
|
17
|
+
pick_gpu,
|
|
18
|
+
reward_seconds_per_completion,
|
|
19
|
+
total_params_b,
|
|
20
|
+
)
|
|
21
|
+
from .types import CostEstimate, RunConfig
|
|
22
|
+
|
|
23
|
+
# --- FLOP accounting: FLOPs per token per active parameter ------------------------------------
SFT_FLOPS_PER_TOKEN_PER_PARAM = 6.0  # forward (2) + backward (4)
GRPO_GEN_FLOPS_PER_TOKEN_PER_PARAM = 2.0  # autoregressive rollout: forward only
GRPO_UPDATE_FLOPS_PER_TOKEN_PER_PARAM = 8.0  # policy fwd+bwd (6) + frozen reference fwd (2)

# --- Model-FLOPs utilization: sustained fraction of the GPU's peak throughput ------------------
# Calibrated against real RunPod wall clock; LoRA with small batches runs well below the MFU of
# dense pretraining.
MFU_TRAIN = 0.35  # GRPO policy/reference update
MFU_SFT_TRAIN = 0.25  # SFT fwd/bwd (smaller effective batch, long sequences)
MFU_DECODE = 0.12  # batched vLLM rollout; decode is memory-bandwidth-bound

# --- Reward grading ------------------------------------------------------------------------------
# Completions are scored concurrently in this many parallel slots. A step's reward time is
# therefore ceil(completions / slots) waves x per-completion latency, not completions x latency.
REWARD_CONCURRENCY = 16.0

# --- Cold start (seconds), billed before the first optimizer step -----------------------------
# Calibration run: a real fresh worker (0.8B SFT on an RTX 3090 @ $0.239/hr) billed ~708 s of
# wall for only ~26 priced steps, so cold start, not training, dominated. That worker reportedly
# spent ~12.5 min in `sft_model_load` (download + checkpoint deserialize + GPU placement +
# framework/CUDA init). This is why model load, not boot/deps, is modelled as the dominant
# cost of a short job.
# NOTE(review): ~12.5 min (~750 s) in model load alone exceeds the ~708 s total billed wall
# quoted for the same run; one of the two figures looks off -- confirm against the run's logs.
WORKER_BOOT_S = 120.0  # container pull + start
DEPS_INSTALL_S = 90.0  # pip/uv resolve + install
MODEL_LOAD_BASE_S = 235.0  # size-independent: deserialize + GPU placement + framework/CUDA init
VLLM_INIT_S = 120.0  # GRPO only (see setup_seconds)
# Effective HF snapshot download rate (hf_transfer) in GB/s. setup_seconds divides the download
# size in GB by this to get the size-dependent part of model load, on top of MODEL_LOAD_BASE_S.
DOWNLOAD_RATE_GBPS = 0.4

DEFAULT_WALL_CAP_S = 24 * 60 * 60  # default for the spec's gpu.max_wall_seconds (24 h)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _fmt_duration(seconds: float) -> str:
|
|
56
|
+
"""Human duration for notes: seconds < 1m, minutes < 1h, else whole/1-decimal hours."""
|
|
57
|
+
if seconds < 60:
|
|
58
|
+
return f"{seconds:.0f}s"
|
|
59
|
+
if seconds < 3600:
|
|
60
|
+
return f"{seconds / 60:.0f}m"
|
|
61
|
+
hours = seconds / 3600
|
|
62
|
+
return f"{hours:.0f}h" if abs(hours - round(hours)) < 1e-9 else f"{hours:.1f}h"
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def setup_seconds(config: RunConfig) -> float:
    """Cold-start wall time, in seconds, billed before the first optimizer step.

    Sum of:

    - container boot;
    - dependency install;
    - model load: a fixed base plus the checkpoint's download size divided by
      DOWNLOAD_RATE_GBPS;
    - for GRPO runs only, vLLM initialization.
    """
    download_s = download_weight_gb(config.model_id) / DOWNLOAD_RATE_GBPS
    vllm_s = VLLM_INIT_S if config.is_grpo else 0.0
    return WORKER_BOOT_S + DEPS_INSTALL_S + (MODEL_LOAD_BASE_S + download_s) + vllm_s
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def seconds_per_step(config: RunConfig, gpu: str) -> float:
    """Steady-state wall time, in seconds, of one optimizer step on ``gpu``.

    SFT: ``SFT_FLOPS * params * (batch_size * seq_len)`` FLOPs at ``MFU_SFT_TRAIN`` of peak.

    GRPO is the sum of three phases, where ``completions = batch_size * group_size`` and
    ``gen_tokens = completions * completion_len``:

    - rollout: ``GEN_FLOPS * params * gen_tokens`` at ``MFU_DECODE``;
    - reward grading: ``ceil(completions / REWARD_CONCURRENCY)`` waves x per-completion latency;
    - policy + reference update: ``UPDATE_FLOPS * params * gen_tokens`` at ``MFU_TRAIN``.

    NOTE(review): the FLOP constants are per *active* parameter, but ``params`` comes from
    ``total_params_b``; for a mixture-of-experts model this presumably overestimates -- confirm.
    """
    cfg = config.normalized()
    n_params = total_params_b(cfg.model_id) * 1e9
    peak_flops = gpu_tflops(gpu) * 1e12  # FLOP/s

    if cfg.is_grpo:
        completions = cfg.batch_size * cfg.group_size
        gen_tokens = completions * cfg.completion_len
        rollout_s = (GRPO_GEN_FLOPS_PER_TOKEN_PER_PARAM * n_params * gen_tokens) / (peak_flops * MFU_DECODE)
        update_s = (GRPO_UPDATE_FLOPS_PER_TOKEN_PER_PARAM * n_params * gen_tokens) / (peak_flops * MFU_TRAIN)
        latency = reward_seconds_per_completion(cfg.reward_seconds_per_completion)
        waves = math.ceil(completions / REWARD_CONCURRENCY)  # a partial wave still costs one latency
        reward_s = waves * latency
        return rollout_s + reward_s + update_s

    sft_flops = SFT_FLOPS_PER_TOKEN_PER_PARAM * n_params * (cfg.batch_size * cfg.seq_len)
    return sft_flops / (peak_flops * MFU_SFT_TRAIN)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def select_gpu(config: RunConfig) -> tuple[str, int]:
    """Return ``(gpu_class, required_vram_gb)``: the cheapest GPU class that fits ``config``.

    Sizing comes from the offline, deterministic catalog. Unlike the submit-time allocator,
    ``pick_gpu`` applies no validation gate, so every fitting class is eligible. The estimate
    therefore reflects the cheapest card that *could* run the job. The live allocator only draws
    from the validated pool, so the class actually provisioned may be pricier.
    """
    # Called for its check alone: rejects a non-catalog model before any (HF) sizing happens.
    total_params_b(config.model_id)
    required = required_vram_gb(
        config.model_id,
        config.method,
        train=config.train_knobs(),
        thinking=config.thinking,
    )
    return pick_gpu(required, provider=config.provider), required
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _notes(config: RunConfig, raw_train_s: float, wall_capped: bool, cap_s: float) -> tuple[str, ...]:
    """Human-readable caveats attached to a CostEstimate, in a fixed order.

    The notes are:

    1. a quantization note (when the model is not bf16);
    2. the GRPO step breakdown (GRPO only);
    3. the VRAM-headroom/pricing note (always);
    4. the wall-cap clamp note (only when ``wall_capped``).
    """
    cfg = config.normalized()
    out: list[str] = []

    quant = model_quant(cfg.model_id)
    if quant != "bf16":
        out.append(f"{quant}: smaller VRAM footprint -> cheaper GPU class fits")

    if cfg.is_grpo:
        completions = cfg.batch_size * cfg.group_size
        reward_s = reward_seconds_per_completion(cfg.reward_seconds_per_completion)
        env_suffix = f", env {cfg.environment}" if cfg.environment else ""
        out.append(
            f"GRPO step = vLLM rollout of {cfg.batch_size}x{cfg.group_size}={completions} completions "
            f"@ {cfg.completion_len} tok + reward ({reward_s:.2f}s/completion{env_suffix}) "
            "+ policy+reference update"
        )

    out.append(f"GPU sized with {vram_headroom() - 1:.0%} VRAM headroom; static GPU $/hr")

    if wall_capped:
        scope = "per-seed " if config.setup_repeats != 1 else ""
        out.append(
            f"training clamped to fit the {_fmt_duration(cap_s)} {scope}wall cap "
            f"(after setup; uncapped: {_fmt_duration(raw_train_s)})"
        )
    return tuple(out)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def estimate_cost(config: RunConfig, *, wall_cap_s: float = DEFAULT_WALL_CAP_S) -> CostEstimate:
    """Deterministic pre-flight cost estimate: wall-clock hours x the chosen GPU's $/hr.

    Each seed (``setup_repeats``) is priced as its own job, with its own cold start and its own
    wall cap. The steps are split evenly across seeds. Each seed's training time is clamped so
    that setup + training fits the cap, and the per-seed figures are then multiplied by the seed
    count.

    The cap is ``config.max_wall_seconds`` floored at 60 s (mirroring the runner) when set,
    otherwise ``wall_cap_s``.
    """
    gpu, need = select_gpu(config)
    hourly = gpu_hourly_usd(gpu, provider=config.provider)
    if config.max_wall_seconds is None:
        cap_s = wall_cap_s
    else:
        # The runner enforces max(60, max_wall_seconds); don't underprice a sub-60 s cap.
        cap_s = max(60.0, float(config.max_wall_seconds))

    seeds = config.setup_repeats
    raw_setup_per_seed = setup_seconds(config)
    sps = seconds_per_step(config, gpu)
    raw_train_per_seed = (config.steps / seeds) * sps

    # Setup is billed against the same cap, so training only gets what is left over.
    wall_capped = (raw_setup_per_seed + raw_train_per_seed) > cap_s
    setup_per_seed = min(raw_setup_per_seed, cap_s)
    if wall_capped:
        train_per_seed = max(0.0, cap_s - setup_per_seed)
    else:
        train_per_seed = raw_train_per_seed

    setup = setup_per_seed * seeds
    train = train_per_seed * seeds
    wall = setup + train

    return CostEstimate(
        model_id=config.model_id,
        method=config.method,
        steps=config.steps,
        gpu=gpu,
        provider=config.provider,
        gpu_vram_gb=gpu_vram_gb(gpu),
        required_vram_gb=need,
        gpu_hourly_usd=hourly,
        setup_seconds=setup,
        seconds_per_step=sps,
        train_seconds=train,
        wall_clock_seconds=wall,
        wall_capped=wall_capped,
        total_usd=wall / 3600.0 * hourly,
        notes=_notes(config, raw_train_per_seed, wall_capped, cap_s),
    )
|