freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
flash/client/http.py ADDED
@@ -0,0 +1,372 @@
1
+ """Stdlib HTTP client for the Flash control plane (no extra dependencies).
2
+
3
+ Every CLI/MCP operation maps to one method here. Server errors (FastAPI's
4
+ ``{"detail": ...}``) surface as ``ApiError`` with the server's message; connection
5
+ problems surface as ``ClientError`` with an actionable hint.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import codecs
11
+ import contextlib
12
+ import json
13
+ import os
14
+ import urllib.error
15
+ import urllib.request
16
+ from collections.abc import Callable, Iterator
17
+ from typing import Any
18
+
19
+ from .config import load_credentials_with_source
20
+
21
# Upload-progress hook, invoked as ``progress(bytes_sent, total_bytes)`` while a request body is
# streamed to the server so the CLI can draw a progress bar. ``total_bytes`` is the request's
# Content-Length, known before the upload starts and constant throughout it.
ProgressCallback = Callable[[int, int], None]
24
+
25
+
26
class ClientError(RuntimeError):
    """An anticipated client-side failure (no API key, unreachable server) that the CLI prints
    as a plain message rather than a traceback."""


class ApiError(ClientError):
    """An HTTP error response; ``status`` is the HTTP status code and the message is the
    server-supplied detail."""

    def __init__(self, status: int, message: str):
        self.status = status
        super().__init__(message)
34
+
35
+
36
+ # Login is handled by the freesolo backend (not the flash control plane): `flash login`
37
+ # verifies the user's freesolo API key here. The same key authenticates the flash
38
+ # control plane, which accepts freesolo-issued keys.
39
+ DEFAULT_FREESOLO_BASE_URL = "https://api.freesolo.co"
40
+ FREESOLO_AUTH_VERIFY_PATH = "/api/auth/verify"
41
+
42
+
43
+ def freesolo_base_url(override: str | None = None) -> str:
44
+ return (override or os.environ.get("FREESOLO_BASE_URL") or DEFAULT_FREESOLO_BASE_URL).rstrip(
45
+ "/"
46
+ )
47
+
48
+
49
+ def _detail_from_http_error(exc: urllib.error.HTTPError) -> str:
50
+ """Extract the server's error message from an HTTPError body (FastAPI ``detail``)."""
51
+ body = exc.read()
52
+ try:
53
+ detail = json.loads(body).get("detail") or body.decode()
54
+ except (ValueError, AttributeError):
55
+ detail = body.decode(errors="replace") if body else str(exc)
56
+ return str(detail)
57
+
58
+
59
def verify_freesolo_key(api_key: str, base_url: str | None = None) -> None:
    """Verify a freesolo API key against the freesolo backend's ``/api/auth/verify``.

    Returns ``None`` on success. Keys are issued from the freesolo sign-in page.

    Raises:
        ClientError: the key was rejected (401/403), or the backend could not be reached --
            including connect failures, timeouts and connections dropped mid-response.
        ApiError: any other HTTP error status, carrying the server's message.
    """
    base = freesolo_base_url(base_url)
    req = urllib.request.Request(
        f"{base}{FREESOLO_AUTH_VERIFY_PATH}",
        method="GET",
        headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
    )
    try:
        with urllib.request.urlopen(req, timeout=30) as resp:
            resp.read()
    except urllib.error.HTTPError as exc:
        if exc.code in (401, 403):
            raise ClientError(
                "freesolo rejected this API key — create or copy a valid key at "
                "https://freesolo.co/sign-in and pass it with `flash login --api-key` "
                "(or FREESOLO_API_KEY)"
            ) from exc
        raise ApiError(exc.code, _detail_from_http_error(exc)) from exc
    except OSError as exc:
        # urllib wraps only failures while *sending* the request in URLError (which has
        # ``reason``). A timeout or disconnect while waiting for or reading the response
        # arrives as a bare OSError subclass (TimeoutError, ConnectionResetError), and is
        # reported the same way.
        reason = getattr(exc, "reason", exc)
        raise ClientError(
            f"cannot reach the freesolo backend at {base} ({reason}); "
            "check your network connection and FREESOLO_BASE_URL"
        ) from exc
88
+
89
+
90
+ class _ProgressReader:
91
+ """A read()-only file-like over an in-memory payload that reports bytes consumed.
92
+
93
+ ``http.client`` sends a body exposing ``read()`` in blocksize chunks; we forward the running
94
+ total to ``progress(sent, total)`` for each chunk so the CLI can draw an upload bar. The
95
+ caller sets Content-Length from ``len(payload)``, so the request is NOT chunked-encoded and
96
+ the server reads it exactly as a plain bytes body."""
97
+
98
+ def __init__(self, data: bytes, progress: ProgressCallback):
99
+ self._data = data
100
+ self._total = len(data)
101
+ self._pos = 0
102
+ self._progress = progress
103
+
104
+ def __len__(self) -> int:
105
+ return self._total
106
+
107
+ def read(self, size: int = -1) -> bytes:
108
+ if size is None or size < 0:
109
+ chunk = self._data[self._pos :]
110
+ else:
111
+ chunk = self._data[self._pos : self._pos + size]
112
+ self._pos += len(chunk)
113
+ # a rendering hiccup must never abort an in-flight upload
114
+ with contextlib.suppress(Exception):
115
+ self._progress(self._pos, self._total)
116
+ return chunk
117
+
118
+
119
class ApiClient:
    """Client for the Flash control plane's ``/v1`` HTTP API.

    Each public method maps to one endpoint and returns the decoded JSON response. Failures are
    normalized onto this module's exceptions so the CLI can print them cleanly:

    * an HTTP error status becomes :class:`ApiError` carrying the server's ``detail``. On
      401/403 it adds a hint when the key came from ``FREESOLO_API_KEY``.
    * a transport failure (DNS/connect error, timeout, dropped connection) becomes
      :class:`ClientError` pointing at ``FLASH_API_URL``.
    * a success response whose body is not JSON becomes :class:`ClientError`.
    """

    def __init__(
        self,
        api_url: str,
        api_key: str | None = None,
        timeout: float = 60.0,
        key_source: str | None = None,
    ):
        # ``key_source`` records where ``api_key`` came from; only the value
        # "FREESOLO_API_KEY" changes behaviour (see ``_auth_error_detail``).
        self.api_url = api_url.rstrip("/")
        self.api_key = api_key
        self.timeout = timeout
        self.key_source = key_source

    # -- transport ---------------------------------------------------------------------
    def _auth_error_detail(self, status: int, detail: str) -> str:
        """Return ``detail``, extended with an env-var hint on 401/403.

        The hint is added only when the key came from ``FREESOLO_API_KEY``, which overrides
        the key saved by ``flash login``.
        """
        if status not in {401, 403} or self.key_source != "FREESOLO_API_KEY":
            return detail
        return (
            f"{detail}; FREESOLO_API_KEY is set and overrides the key saved by "
            "`flash login`. Unset FREESOLO_API_KEY or update it to a valid freesolo API key."
        )

    def _headers(self, extra: dict[str, str] | None = None) -> dict[str, str]:
        """Return the JSON request headers, plus ``extra`` and the bearer token when a key is set."""
        headers = {"Content-Type": "application/json", **(extra or {})}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    @contextlib.contextmanager
    def _translate_errors(self) -> Iterator[None]:
        """Re-raise urllib/socket failures from the ``with`` body as ApiError/ClientError."""
        try:
            yield
        except urllib.error.HTTPError as exc:
            detail = self._auth_error_detail(exc.code, _detail_from_http_error(exc))
            raise ApiError(exc.code, detail) from exc
        except OSError as exc:
            # urllib wraps only failures while *sending* the request in URLError (which has
            # ``reason``). A timeout or disconnect while waiting for or reading the response
            # -- e.g. during a long deploy/chat call -- surfaces as a bare OSError subclass
            # (TimeoutError, ConnectionResetError, http.client.RemoteDisconnected).
            reason = getattr(exc, "reason", exc)
            raise ClientError(
                f"cannot reach the Flash service at {self.api_url} ({reason}); "
                "check your network connection and FLASH_API_URL"
            ) from exc

    def _parse_json(self, raw: bytes) -> Any:
        """Decode a success-response body; an empty body decodes to ``{}``.

        Raises :class:`ClientError` when the body is not JSON (e.g. an HTML page from a proxy).
        """
        if not raw:
            return {}
        try:
            return json.loads(raw)
        except ValueError as exc:
            preview = raw[:200].decode(errors="replace")
            raise ClientError(
                f"unexpected non-JSON response from the Flash service at {self.api_url}: "
                f"{preview!r}"
            ) from exc

    def _send(self, req: urllib.request.Request, timeout: float) -> Any:
        """Perform ``req`` and return its decoded JSON body, with errors normalized."""
        with self._translate_errors(), urllib.request.urlopen(req, timeout=timeout) as resp:
            raw = resp.read()
        return self._parse_json(raw)

    def _request(
        self,
        method: str,
        path: str,
        body: dict | None = None,
        timeout: float | None = None,
    ) -> Any:
        """Send ``body`` (if any) as JSON to ``path``; a falsy ``timeout`` uses ``self.timeout``."""
        req = urllib.request.Request(
            f"{self.api_url}{path}",
            method=method,
            data=json.dumps(body).encode() if body is not None else None,
            headers=self._headers(),
        )
        return self._send(req, timeout or self.timeout)

    def _post_with_progress(
        self,
        path: str,
        body: dict,
        *,
        progress: ProgressCallback,
        timeout: float,
    ) -> Any:
        """POST a JSON body while reporting upload progress via ``progress(sent, total)``.

        The body is wrapped in :class:`_ProgressReader` and sent with an explicit
        Content-Length, so it goes out as a plain (not chunked) body. Errors are mapped
        exactly as in :meth:`_request`.
        """
        payload = json.dumps(body).encode()
        req = urllib.request.Request(
            f"{self.api_url}{path}",
            method="POST",
            data=_ProgressReader(payload, progress),
            headers=self._headers({"Content-Length": str(len(payload))}),
        )
        return self._send(req, timeout)

    # -- identity ----------------------------------------------------------------------
    def me(self) -> dict:
        return self._request("GET", "/v1/me")

    def health(self) -> dict:
        return self._request("GET", "/v1/health", timeout=10.0)

    # -- environments ------------------------------------------------------------------
    def publish_env(
        self,
        *,
        name: str,
        package_b64: str,
        progress: ProgressCallback | None = None,
    ) -> dict:
        """Upload a packaged Freesolo environment to the managed Environments Hub.

        When ``progress`` is given, the body streams to the server in chunks and
        ``progress(bytes_sent, total_bytes)`` fires for each one, so the CLI can render an
        upload bar. Otherwise the body is sent in one shot (the default, used off a TTY)."""
        body = {"name": name, "package_b64": package_b64}
        if progress is None:
            return self._request("POST", "/v1/envs", body=body, timeout=1800.0)
        return self._post_with_progress("/v1/envs", body, progress=progress, timeout=1800.0)

    # -- runs --------------------------------------------------------------------------
    def create_run(self, spec: dict, runtime_secrets: dict[str, str] | None = None) -> dict:
        body = {"spec": spec}
        if runtime_secrets:
            body["runtime_secrets"] = runtime_secrets
        return self._request("POST", "/v1/runs", body=body)

    def list_runs(self) -> list[dict]:
        return self._request("GET", "/v1/runs")["runs"]

    def get_run(self, run_id: str) -> dict:
        return self._request("GET", f"/v1/runs/{run_id}")

    def get_logs(self, run_id: str, offset: int = 0) -> dict:
        return self._request("GET", f"/v1/runs/{run_id}/logs?offset={int(offset)}")

    def get_worker_output(self, run_id: str) -> dict[str, str]:
        # The train-subprocess console/traceback ({console_<phase>.txt, error_<phase>.txt}) from the
        # run's HF artifact repo, fetched server-side with the operator token — the real worker
        # output the offset-paged log can't carry. Kept off the hot get_logs poll path. {} if none.
        #
        # Tolerate a managed server that predates the /worker route: a CLI upgraded ahead of the
        # service rollout would otherwise hard-fail. FastAPI returns a bare 404 "Not Found" for an
        # unmatched path -> treat ONLY that as "no worker output" ({}); real 404s still surface (an
        # unknown run_id carries detail "unknown run_id: ...", not "Not Found").
        try:
            return self._request("GET", f"/v1/runs/{run_id}/worker").get("worker", {})
        except ApiError as exc:
            if exc.status == 404 and str(exc).strip().lower() == "not found":
                return {}
            raise

    def cancel_run(self, run_id: str) -> dict:
        return self._request("POST", f"/v1/runs/{run_id}/cancel")

    def checkpoints(self, run_id: str) -> list[dict]:
        """Deployable per-step RL checkpoints for a run (each `flash deploy --step N`-able)."""
        return self._request("GET", f"/v1/runs/{run_id}/checkpoints")["checkpoints"]

    # -- serving -----------------------------------------------------------------------
    def deploy(
        self,
        run_id: str,
        dry_run: bool = False,
        step: int | None = None,
    ) -> dict:
        # Deploy blocks on registration and serving warmup, which can take many minutes.
        deploy_timeout = 30 * 60 if not dry_run else None
        body: dict = {"dry_run": dry_run}
        if step is not None:
            # Deploy a specific intermediate checkpoint instead of the run's final adapter.
            # Reject a bool explicitly: `int(True)`/`int(False)` would silently coerce to step
            # 1/0, but the server guard (_resolve_deploy_step) treats a bool as an invalid step
            # and returns 400 — so fail fast here with a clear client-side error instead of
            # sending a bogus 0/1 that the server rejects (or, worse, that hits a real
            # checkpoint 0/1).
            if isinstance(step, bool):
                raise ClientError(f"invalid checkpoint step: {step!r} (must be an integer)")
            body["step"] = int(step)
        return self._request(
            "POST",
            f"/v1/runs/{run_id}/deploy",
            body=body,
            timeout=deploy_timeout,
        )

    def undeploy(self, run_id: str) -> dict:
        return self._request("DELETE", f"/v1/runs/{run_id}/deploy")

    def deployments(self) -> list[dict]:
        return self._request("GET", "/v1/deployments")["deployments"]

    def chat(
        self,
        run_id: str,
        messages: list[dict],
        temperature: float = 0.0,
        max_tokens: int = 512,
    ) -> dict:
        # Serving warmup can take minutes; give inference a generous timeout.
        return self._request(
            "POST",
            f"/v1/runs/{run_id}/chat",
            body={"messages": messages, "temperature": temperature, "max_tokens": max_tokens},
            timeout=30 * 60,
        )

    def chat_stream(
        self,
        run_id: str,
        messages: list[dict],
        temperature: float = 0.0,
        max_tokens: int = 512,
    ) -> Iterator[str]:
        """Yield the completion text for a chat request as it streams from the server.

        If the server answers with a whole JSON completion instead of a stream, the message
        content (if any) is yielded once. Errors are mapped as in :meth:`_request`.
        """
        req = urllib.request.Request(
            f"{self.api_url}/v1/runs/{run_id}/chat",
            method="POST",
            data=json.dumps(
                {
                    "messages": messages,
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "stream": True,
                }
            ).encode(),
            headers=self._headers(),
        )
        # Incremental decoding holds back a multi-byte UTF-8 sequence until it is complete;
        # errors="replace" keeps a malformed byte from aborting the whole stream.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        with self._translate_errors(), urllib.request.urlopen(req, timeout=30 * 60) as resp:
            if "application/json" in resp.headers.get("Content-Type", ""):
                payload = self._parse_json(resp.read())
                content = ((payload.get("choices") or [{}])[0].get("message") or {}).get("content")
                if content:
                    yield str(content)
                return
            # One byte at a time so each token reaches the caller as soon as it arrives.
            while raw := resp.read(1):
                if chunk := decoder.decode(raw):
                    yield chunk
            if tail := decoder.decode(b"", final=True):
                yield tail
363
+
364
+
365
def client_from_config(require_key: bool = True) -> ApiClient:
    """Construct an :class:`ApiClient` from the stored credentials.

    Raises :class:`ClientError` with a login hint when ``require_key`` is true and no API key
    is configured.
    """
    api_url, api_key, key_source = load_credentials_with_source()
    logged_out = not api_key
    if require_key and logged_out:
        raise ClientError(
            "not logged in — run `flash login` with your freesolo API key (or set FREESOLO_API_KEY)"
        )
    return ApiClient(api_url, api_key, key_source=key_source)
@@ -0,0 +1,69 @@
1
+ """Client-side runtime secrets for managed runs.
2
+
3
+ These values are read on the user's machine and sent only with the submit request. They are not
4
+ part of JobSpec/TOML, and the control plane must not persist them in run status or artifacts.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ from pathlib import Path
11
+
12
+ DEFAULT_RUNTIME_SECRET_KEYS = frozenset({"WANDB_API_KEY"})
13
+
14
+
15
+ def _runtime_secret_keys(keys: tuple[str, ...] | list[str] | set[str] | None = None) -> set[str]:
16
+ return set(DEFAULT_RUNTIME_SECRET_KEYS) | {str(key) for key in (keys or ())}
17
+
18
+
19
+ def _read_env_file(path: Path, keys: set[str]) -> dict[str, str]:
20
+ if not path.exists() or not path.is_file():
21
+ return {}
22
+ out: dict[str, str] = {}
23
+ try:
24
+ lines = path.read_text(errors="ignore").splitlines()
25
+ except OSError:
26
+ return out
27
+ for line in lines:
28
+ s = line.strip()
29
+ if not s or s.startswith("#") or "=" not in s:
30
+ continue
31
+ key, value = s.split("=", 1)
32
+ key = key.strip()
33
+ if key not in keys:
34
+ continue
35
+ value = value.strip().strip('"').strip("'")
36
+ if value:
37
+ out[key] = value
38
+ return out
39
+
40
+
41
def runtime_secrets_from_local_env(
    config_path: str | os.PathLike[str] | None = None,
    keys: tuple[str, ...] | list[str] | set[str] | None = None,
) -> dict[str, str]:
    """Collect supported run secrets from the user's local environment.

    Each secret name (the defaults plus ``keys``) takes its value from the first source that
    has a non-empty one, in this order:

    1. the process environment;
    2. ``.env.local`` then ``.env`` in the current working directory;
    3. ``.env.local`` then ``.env`` next to ``config_path`` (when given).

    ``.env.local`` is read before ``.env`` in each directory, so machine-local overrides win
    over the shared file (the usual dotenv convention). Arbitrary parent directories are
    deliberately not scanned, and no secret is serialized into the run spec.

    Raises:
        ValueError: if a name explicitly listed in ``keys`` is not found in any source.
    """

    wanted = _runtime_secret_keys(keys)
    found = {key: value for key in wanted if (value := os.environ.get(key))}

    directories = [Path.cwd()]
    if config_path:
        directories.append(Path(config_path).expanduser().resolve().parent)
    for directory in directories:
        for filename in (".env.local", ".env"):
            for key, value in _read_env_file(directory / filename, wanted).items():
                # setdefault: a higher-precedence source read earlier is never overwritten.
                found.setdefault(key, value)

    required = {str(key) for key in (keys or ())}
    missing = sorted(required - set(found))
    if missing:
        raise ValueError(
            "missing declared environment secret(s): "
            f"{', '.join(missing)}. Set them in your shell or local .env file before submitting; "
            "do not put secret values in TOML."
        )
    return found
flash/client/specs.py ADDED
@@ -0,0 +1,20 @@
1
+ """Turn a locally validated JobSpec into the payload sent to the control plane.
2
+
3
+ The client fills default pip requirements for Freesolo environments unless the
4
+ config provided an explicit ``[environment] pip`` escape hatch.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from flash.spec import JobSpec
10
+
11
+
12
def spec_payload(spec: JobSpec) -> dict:
    """Return ``spec.to_dict()`` ready to submit to the control plane.

    When the spec leaves ``[environment] pip`` empty, the environment registry's default
    worker requirements for ``spec.environment.id`` are filled in (if it has any). An explicit
    pip list is an escape hatch and is sent untouched.
    """
    payload = spec.to_dict()
    if spec.environment.pip:
        return payload
    from flash.envs.registry import worker_pip_for_env

    default_pip = worker_pip_for_env(spec.environment.id)
    if default_pip:
        payload["environment"]["pip"] = default_pip
    return payload
flash/cost/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ """Flash training-cost estimator: a deterministic, equation-based pre-flight estimate
2
+ (``estimate_cost``) of cost = wall-clock hours x market $/hr. No output multiplier."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from .analytical import estimate_cost
7
+ from .spec import estimate_for_spec, runconfig_from_spec
8
+ from .types import CostEstimate, RunConfig
9
+
10
# Public API of ``flash.cost``: the estimator entry points plus the config/result types,
# all re-exported from the submodules imported above.
__all__ = [
    "CostEstimate",
    "RunConfig",
    "estimate_cost",
    "estimate_for_spec",
    "runconfig_from_spec",
]
@@ -0,0 +1,175 @@
1
+ """The analytical cost model: total = wall-clock hours x GPU $/hr, where wall = cold-start
2
+ setup + steps x per-step time (a FLOPs/MFU estimate). GRPO splits each step into a vLLM
3
+ rollout + reward grading + policy/reference update."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import math
8
+
9
+ from flash.providers.allocator import required_vram_gb, vram_headroom
10
+
11
+ from .facts import (
12
+ download_weight_gb,
13
+ gpu_hourly_usd,
14
+ gpu_tflops,
15
+ gpu_vram_gb,
16
+ model_quant,
17
+ pick_gpu,
18
+ reward_seconds_per_completion,
19
+ total_params_b,
20
+ )
21
+ from .types import CostEstimate, RunConfig
22
+
23
# Compute cost in FLOPs per token, per active model parameter.
SFT_FLOPS_PER_TOKEN_PER_PARAM = 6.0  # forward (2) + backward (4)
GRPO_GEN_FLOPS_PER_TOKEN_PER_PARAM = 2.0  # forward-only autoregressive rollout
GRPO_UPDATE_FLOPS_PER_TOKEN_PER_PARAM = 8.0  # policy forward+backward (6) + frozen reference forward (2)

# Model-FLOPs utilization: the fraction of peak throughput actually sustained. Calibrated on real
# RunPod wall clock; LoRA with small batches runs well below dense-pretraining MFU.
MFU_TRAIN = 0.35  # GRPO policy/reference update
MFU_SFT_TRAIN = 0.25  # SFT forward/backward (smaller effective batch, long sequences)
MFU_DECODE = 0.12  # batched vLLM rollout; decoding is memory-bandwidth-bound

# Completions are graded in parallel slots, so one step's reward time is
# ceil(completions / REWARD_CONCURRENCY) waves x per-completion latency, not a serial sum.
REWARD_CONCURRENCY = 16.0

# Cold-start overhead in seconds: container boot, dependency install and model load, plus vLLM
# init for GRPO.
#
# Calibration run: a fresh-worker 0.8B SFT job on an RTX 3090 ($0.239/hr) billed ~708s of wall
# clock for only ~26 priced steps, so cold start, not training, dominated the bill. A fresh worker
# spent ~12.5 min in `sft_model_load` alone (download, checkpoint deserialize, GPU placement,
# framework/CUDA init). That makes model load, rather than boot or deps, the main cost of a short
# job.
# NOTE(review): 12.5 min (~750s) exceeds the ~708s billed wall quoted above. One of the two
# figures needs confirming.
#
# MODEL_LOAD_BASE_S is the size-independent part of the load. A download term proportional to
# checkpoint size is added on top of it, so larger models pay a longer cold start.
WORKER_BOOT_S = 120.0  # container pull + start
DEPS_INSTALL_S = 90.0  # pip/uv resolve + install
MODEL_LOAD_BASE_S = 235.0  # fixed checkpoint deserialize + GPU placement + framework/CUDA init
VLLM_INIT_S = 120.0
DOWNLOAD_RATE_GBPS = 0.4  # effective HF snapshot download rate (hf_transfer), in GB/s, added on top of the base load

DEFAULT_WALL_CAP_S = 24 * 3600  # default for the spec's gpu.max_wall_seconds
53
+
54
+
55
+ def _fmt_duration(seconds: float) -> str:
56
+ """Human duration for notes: seconds < 1m, minutes < 1h, else whole/1-decimal hours."""
57
+ if seconds < 60:
58
+ return f"{seconds:.0f}s"
59
+ if seconds < 3600:
60
+ return f"{seconds / 60:.0f}m"
61
+ hours = seconds / 3600
62
+ return f"{hours:.0f}h" if abs(hours - round(hours)) < 1e-9 else f"{hours:.1f}h"
63
+
64
+
65
def setup_seconds(config: RunConfig) -> float:
    """Cold-start wall time billed before the first optimizer step.

    The total is container boot, plus dependency install, plus model load, plus vLLM init for
    GRPO runs only. Model load is a fixed deserialize/placement/init base plus a download term
    that scales with the checkpoint size. For a short job, model load dominates the bill (see
    the calibration constants above).
    """
    download_s = download_weight_gb(config.model_id) / DOWNLOAD_RATE_GBPS
    model_load_s = MODEL_LOAD_BASE_S + download_s
    vllm_s = VLLM_INIT_S if config.is_grpo else 0.0
    return WORKER_BOOT_S + DEPS_INSTALL_S + model_load_s + vllm_s
74
+
75
+
76
def seconds_per_step(config: RunConfig, gpu: str) -> float:
    """Steady-state wall time of one optimizer step on ``gpu``.

    SFT: forward/backward FLOPs over ``batch_size * seq_len`` tokens at ``MFU_SFT_TRAIN``.
    GRPO: rollout decode + concurrent reward grading + policy/reference update, where the
    rollout and the update both cover ``batch_size * group_size * completion_len`` tokens.
    """
    cfg = config.normalized()
    param_count = total_params_b(cfg.model_id) * 1e9
    peak_flops = gpu_tflops(gpu) * 1e12  # FLOP/s

    if cfg.is_grpo:
        n_completions = cfg.batch_size * cfg.group_size
        rollout_tokens = n_completions * cfg.completion_len
        rollout_s = (GRPO_GEN_FLOPS_PER_TOKEN_PER_PARAM * param_count * rollout_tokens) / (
            peak_flops * MFU_DECODE
        )
        update_s = (GRPO_UPDATE_FLOPS_PER_TOKEN_PER_PARAM * param_count * rollout_tokens) / (
            peak_flops * MFU_TRAIN
        )
        per_completion_s = reward_seconds_per_completion(cfg.reward_seconds_per_completion)
        # A partially filled grading wave still costs one full latency, hence ceil.
        waves = math.ceil(n_completions / REWARD_CONCURRENCY)
        grading_s = waves * per_completion_s
        return rollout_s + grading_s + update_s

    sft_flops = SFT_FLOPS_PER_TOKEN_PER_PARAM * param_count * (cfg.batch_size * cfg.seq_len)
    return sft_flops / (peak_flops * MFU_SFT_TRAIN)
94
+
95
+
96
def select_gpu(config: RunConfig) -> tuple[str, int]:
    """Return ``(GPU class, required VRAM GB)``, using the cheapest class that fits.

    Sizing goes through ``pick_gpu``, which (unlike the submit-time allocator) is deliberately
    gate-free. It considers every fitting class, validated or not, so the estimate reflects
    the cheapest card that *could* run the job. The live allocator restricts itself to the
    validated pool, so the GPU actually provisioned may cost more. Catalog sizing is
    offline and deterministic.
    """
    # Catalog-only lookup first: rejects a non-catalog model before any (HF) sizing happens.
    total_params_b(config.model_id)
    required = required_vram_gb(
        config.model_id,
        config.method,
        train=config.train_knobs(),
        thinking=config.thinking,
    )
    chosen = pick_gpu(required, provider=config.provider)
    return chosen, required
112
+
113
+
114
def _notes(config: RunConfig, raw_train_s: float, wall_capped: bool, cap_s: float) -> tuple[str, ...]:
    """Human-readable caveats attached to a cost estimate.

    The notes cover, in order:
    - the quantization effect (non-bf16 models only);
    - the GRPO step breakdown (GRPO only);
    - the VRAM-headroom/pricing basis (always);
    - the wall-cap clamp (only when ``wall_capped``; ``raw_train_s`` is the per-seed
      uncapped training time).
    """
    cfg = config.normalized()
    out: list[str] = []

    quant = model_quant(cfg.model_id)
    if quant != "bf16":
        out.append(f"{quant}: smaller VRAM footprint -> cheaper GPU class fits")

    if cfg.is_grpo:
        total = cfg.batch_size * cfg.group_size
        per_completion = reward_seconds_per_completion(cfg.reward_seconds_per_completion)
        env_part = f", env {cfg.environment}" if cfg.environment else ""
        out.append(
            f"GRPO step = vLLM rollout of {cfg.batch_size}x{cfg.group_size}={total} completions "
            f"@ {cfg.completion_len} tok + reward ({per_completion:.2f}s/completion"
            f"{env_part}) + policy+reference update"
        )

    out.append(f"GPU sized with {vram_headroom() - 1:.0%} VRAM headroom; static GPU $/hr")

    if wall_capped:
        scope = "per-seed " if config.setup_repeats != 1 else ""
        out.append(
            f"training clamped to fit the {_fmt_duration(cap_s)} {scope}wall cap "
            f"(after setup; uncapped: {_fmt_duration(raw_train_s)})"
        )
    return tuple(out)
136
+
137
+
138
def estimate_cost(config: RunConfig, *, wall_cap_s: float = DEFAULT_WALL_CAP_S) -> CostEstimate:
    """Deterministic pre-flight cost estimate for ``config``.

    Picks the cheapest fitting GPU class and prices it at its static hourly rate.
    Billed wall = (setup + training) per seed x number of seeds. Each seed is capped at the
    per-seed wall cap: setup is billed first, and training is clamped to whatever remains.
    """
    gpu, need = select_gpu(config)
    hourly = gpu_hourly_usd(gpu, provider=config.provider)
    if config.max_wall_seconds is None:
        cap_s = wall_cap_s
    else:
        # Mirror the runner's max(60, max_wall_seconds) floor so a sub-60s cap isn't underpriced.
        cap_s = max(60.0, float(config.max_wall_seconds))

    # Every seed is its own job with its own cold start and wall cap: price one seed, then scale.
    seeds = config.setup_repeats
    uncapped_setup = setup_seconds(config)
    step_s = seconds_per_step(config, gpu)
    uncapped_train = (config.steps / seeds) * step_s

    wall_capped = (uncapped_setup + uncapped_train) > cap_s
    setup_one = min(uncapped_setup, cap_s)
    if wall_capped:
        train_one = max(0.0, cap_s - setup_one)
    else:
        train_one = uncapped_train

    setup_total = setup_one * seeds
    train_total = train_one * seeds
    wall = setup_total + train_total

    return CostEstimate(
        model_id=config.model_id,
        method=config.method,
        steps=config.steps,
        gpu=gpu,
        provider=config.provider,
        gpu_vram_gb=gpu_vram_gb(gpu),
        required_vram_gb=need,
        gpu_hourly_usd=hourly,
        setup_seconds=setup_total,
        seconds_per_step=step_s,
        train_seconds=train_total,
        wall_clock_seconds=wall,
        wall_capped=wall_capped,
        total_usd=wall / 3600.0 * hourly,
        notes=_notes(config, uncapped_train, wall_capped, cap_s),
    )
+ )