freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
flash/schema/fields.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
"""Field-level validators/coercers for Flash TOML config parsing.
|
|
2
|
+
|
|
3
|
+
Leaf helpers split out of ``flash.schema``: the ``ConfigError`` type, the [train] scalar
|
|
4
|
+
validators, the slug/worker-env/wandb validators, and the ``--set`` scalar coercer. None
|
|
5
|
+
reference the rest of the schema package; the package ``__init__`` re-exports them.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import math
|
|
11
|
+
import re
|
|
12
|
+
import urllib.parse
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
from flash.spec import WandbSpec
|
|
16
|
+
|
|
17
|
+
# One GitHub path component (owner, repo, or ref): ASCII letters, digits, '.', '_' and '-' only.
_GITHUB_SAFE_PART_RE = re.compile(r"^[A-Za-z0-9._-]+$")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _train_int(train_raw: dict, key: str, *, minimum: int) -> int | None:
|
|
21
|
+
"""Validate an optional integer [train] knob (>= minimum) -> ConfigError (HTTP 400).
|
|
22
|
+
|
|
23
|
+
None stays None (recipe default). Rejects bools, non-numbers, non-integers, and
|
|
24
|
+
out-of-range values at parse time instead of letting them reach a provisioned worker.
|
|
25
|
+
"""
|
|
26
|
+
v = train_raw.get(key)
|
|
27
|
+
if v is None:
|
|
28
|
+
return None
|
|
29
|
+
if isinstance(v, bool) or not isinstance(v, (int, float)):
|
|
30
|
+
raise ConfigError(f"train.{key} must be an integer")
|
|
31
|
+
# Check finiteness BEFORE int(v): int(inf) raises OverflowError and int(nan) ValueError
|
|
32
|
+
# (the former would be a 500); reject both as a clean 400.
|
|
33
|
+
if not math.isfinite(v) or float(v) != int(v):
|
|
34
|
+
raise ConfigError(f"train.{key} must be a finite integer")
|
|
35
|
+
v = int(v)
|
|
36
|
+
if v < minimum:
|
|
37
|
+
raise ConfigError(f"train.{key} must be >= {minimum}")
|
|
38
|
+
return v
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _train_float(
|
|
42
|
+
train_raw: dict,
|
|
43
|
+
key: str,
|
|
44
|
+
*,
|
|
45
|
+
minimum: float,
|
|
46
|
+
exclusive: bool = False,
|
|
47
|
+
maximum: float | None = None,
|
|
48
|
+
) -> float | None:
|
|
49
|
+
"""Validate an optional float [train] knob -> ConfigError (HTTP 400). None stays None."""
|
|
50
|
+
v = train_raw.get(key)
|
|
51
|
+
if v is None:
|
|
52
|
+
return None
|
|
53
|
+
if isinstance(v, bool) or not isinstance(v, (int, float)):
|
|
54
|
+
raise ConfigError(f"train.{key} must be a number")
|
|
55
|
+
v = float(v)
|
|
56
|
+
# nan/inf slip past the range checks below (nan compares false, inf passes any minimum)
|
|
57
|
+
# and would reach TRL optimizer/sampling settings; reject them as a 400 here.
|
|
58
|
+
if not math.isfinite(v):
|
|
59
|
+
raise ConfigError(f"train.{key} must be a finite number")
|
|
60
|
+
if exclusive and v <= minimum:
|
|
61
|
+
raise ConfigError(f"train.{key} must be > {minimum}")
|
|
62
|
+
if not exclusive and v < minimum:
|
|
63
|
+
raise ConfigError(f"train.{key} must be >= {minimum}")
|
|
64
|
+
if maximum is not None and v > maximum:
|
|
65
|
+
raise ConfigError(f"train.{key} must be between {minimum} and {maximum}")
|
|
66
|
+
return v
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _train_stops(train_raw: dict) -> tuple[str, ...]:
|
|
70
|
+
"""Validate stop_sequences -> ConfigError. A string is ONE stop (never char-split);
|
|
71
|
+
a list must hold strings; empties are dropped; anything else is rejected."""
|
|
72
|
+
v = train_raw.get("stop_sequences")
|
|
73
|
+
if v is None:
|
|
74
|
+
return ()
|
|
75
|
+
if isinstance(v, str):
|
|
76
|
+
return (v,) if v else ()
|
|
77
|
+
if not isinstance(v, (list, tuple)):
|
|
78
|
+
raise ConfigError("train.stop_sequences must be a string or a list of strings")
|
|
79
|
+
for s in v:
|
|
80
|
+
if not isinstance(s, str):
|
|
81
|
+
raise ConfigError("train.stop_sequences entries must be strings")
|
|
82
|
+
return tuple(s for s in v if s)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class ConfigError(ValueError):
    """Raised by the Flash config validators when a value is rejected.

    It is a ``ValueError`` subclass, so handlers written for ``ValueError`` also catch it.
    """
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _require_slug(value: str, message: str) -> None:
    """Require an ``owner/name`` slug.

    Args:
        value: Candidate slug; surrounding whitespace is ignored.
        message: Text used for the ``ConfigError`` raised on rejection.

    Raises:
        ConfigError: when ``value`` is empty, not a string, contains ``:``, parses as a
            URL with a scheme or network location, or is not exactly two safe path
            components separated by ``/``.
    """
    # A non-string (e.g. a TOML integer) can never be a slug; reject it as a config error
    # instead of failing on ``.strip()``.
    if value is not None and not isinstance(value, str):
        raise ConfigError(message)
    text = (value or "").strip()
    if not text or ":" in text:
        raise ConfigError(message)
    try:
        parsed = urllib.parse.urlparse(text)
    except ValueError:
        # urlparse raises a plain ValueError for malformed input such as an unmatched
        # '[' in the network location; surface it as the caller's ConfigError.
        raise ConfigError(message) from None
    if parsed.scheme or parsed.netloc:
        raise ConfigError(message)
    parts = text.split("/")
    if len(parts) != 2 or not _is_safe_github_path_parts(parts):
        raise ConfigError(message)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _require_environment_ref(value: str, message: str) -> None:
    """Require a Freesolo environment id.

    Three forms are accepted:

    * an ``owner/name`` slug (whatever ``_require_slug`` accepts);
    * ``github:owner/repo[@ref][:path]``;
    * ``http(s)://github.com/owner/repo[.git]`` or
      ``http(s)://github.com/owner/repo/(blob|tree)/ref/path...``.

    Args:
        value: Candidate environment reference.
        message: Text used for the ``ConfigError`` raised on rejection.

    Raises:
        ConfigError: when ``value`` matches none of the accepted forms, or a component
            fails the safe-name / safe-path checks.
    """
    # Anything that is not a string cannot be a reference; reject it up front rather
    # than crashing on ``.startswith`` below.
    if not isinstance(value, str):
        raise ConfigError(message)
    try:
        _require_slug(value, message)
        return
    except ValueError:
        # Not a plain slug -- fall through to the other forms. ConfigError is a
        # ValueError subclass; catching the base class also covers a raw ValueError
        # raised by URL parsing inside the slug check.
        pass
    if value.startswith("github:"):
        body = value[len("github:") :]
        repo_ref, sep, path = body.partition(":")
        repo, at, ref = repo_ref.partition("@")
        if at and not ref:
            # "owner/repo@" -- an '@' with nothing after it.
            raise ConfigError(message)
        owner_repo = repo.split("/")
        if (
            len(owner_repo) == 2
            and _is_safe_github_path_parts(owner_repo)
            and (not at or _is_safe_github_path_parts([ref]))
            and (not sep or _is_safe_environment_path(path))
        ):
            return
        raise ConfigError(message)
    if value.startswith("https://github.com/") or value.startswith("http://github.com/"):
        parsed = urllib.parse.urlparse(value)
        if parsed.scheme in {"http", "https"} and parsed.netloc.lower() == "github.com":
            # Percent-decode before splitting so encoded separators or dot segments
            # cannot slip past the component checks.
            parts = [
                part for part in urllib.parse.unquote(parsed.path).strip("/").split("/") if part
            ]
            if len(parts) < 2:
                raise ConfigError(message)
            owner, repo = parts[0], parts[1]
            repo = repo[:-4] if repo.endswith(".git") else repo
            if len(parts) == 2:
                if not _is_safe_github_path_parts([owner, repo]):
                    raise ConfigError(message)
            elif len(parts) >= 5 and parts[2] in {"blob", "tree"}:
                ref = parts[3]
                if not _is_safe_github_path_parts([ref]):
                    raise ConfigError(message)
                raw_path = "/".join(parts[4:])
                if not _is_safe_environment_path(raw_path):
                    raise ConfigError(message)
                if not _is_safe_github_path_parts([owner, repo, ref]):
                    raise ConfigError(message)
            else:
                # Any other URL shape (e.g. /owner/repo/issues) is not an environment ref.
                raise ConfigError(message)
            return
    raise ConfigError(message)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _is_safe_environment_path(path: str) -> bool:
|
|
153
|
+
if not path:
|
|
154
|
+
return True
|
|
155
|
+
raw = path.strip().replace("\\", "/")
|
|
156
|
+
if raw.startswith("/"):
|
|
157
|
+
return False
|
|
158
|
+
parts = [part for part in raw.split("/") if part]
|
|
159
|
+
if not parts:
|
|
160
|
+
return True
|
|
161
|
+
return not any(part in {".", ".."} for part in parts)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _is_safe_github_path_parts(parts: list[str]) -> bool:
|
|
165
|
+
if any(part in {".", "..", ""} for part in parts):
|
|
166
|
+
return False
|
|
167
|
+
return all(_GITHUB_SAFE_PART_RE.fullmatch(part) for part in parts)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _coerce_scalar(value: str):
|
|
171
|
+
low = value.strip().lower()
|
|
172
|
+
if low in ("true", "false"):
|
|
173
|
+
return low == "true"
|
|
174
|
+
try:
|
|
175
|
+
return int(value)
|
|
176
|
+
except ValueError:
|
|
177
|
+
pass
|
|
178
|
+
try:
|
|
179
|
+
return float(value)
|
|
180
|
+
except ValueError:
|
|
181
|
+
return value
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _validate_env_var_names(names, context: str) -> None:
|
|
185
|
+
bad_names = sorted(repr(k) for k in names if (not k) or any(c in k for c in "=\0 \t\n\r"))
|
|
186
|
+
if bad_names:
|
|
187
|
+
raise ConfigError(
|
|
188
|
+
f"{context} has invalid environment variable name(s): {', '.join(bad_names)}; an "
|
|
189
|
+
"env var name must be non-empty and contain no '=', whitespace, or NUL byte"
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
_RESERVED_ENVIRONMENT_SECRET_KEYS = frozenset(
|
|
194
|
+
{
|
|
195
|
+
"RUNPOD_API_KEY",
|
|
196
|
+
"HF_TOKEN",
|
|
197
|
+
"HUGGING_FACE_HUB_TOKEN",
|
|
198
|
+
"GITHUB_TOKEN",
|
|
199
|
+
"FREESOLO_API_KEY",
|
|
200
|
+
"FREESOLO_INTERNAL_KEY",
|
|
201
|
+
"RUN_ID",
|
|
202
|
+
"HF_REPO",
|
|
203
|
+
"FLASH_ARM",
|
|
204
|
+
}
|
|
205
|
+
)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _environment_secrets(raw: Any) -> tuple[str, ...]:
|
|
209
|
+
"""Parse [environment].secrets as declared worker env-var secret names."""
|
|
210
|
+
if raw is None:
|
|
211
|
+
return ()
|
|
212
|
+
if isinstance(raw, str) or not isinstance(raw, (list, tuple)):
|
|
213
|
+
raise ConfigError("[environment] secrets must be a list of environment variable names")
|
|
214
|
+
if not all(isinstance(name, str) for name in raw):
|
|
215
|
+
raise ConfigError("[environment] secrets entries must be strings")
|
|
216
|
+
secrets = tuple(dict.fromkeys(raw))
|
|
217
|
+
_validate_env_var_names(secrets, "[environment] secrets")
|
|
218
|
+
reserved = sorted(set(secrets) & _RESERVED_ENVIRONMENT_SECRET_KEYS)
|
|
219
|
+
if reserved:
|
|
220
|
+
raise ConfigError(
|
|
221
|
+
"[environment] secrets must not include platform-managed key(s): "
|
|
222
|
+
f"{', '.join(reserved)}"
|
|
223
|
+
)
|
|
224
|
+
return secrets
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _worker_env(raw: Any) -> dict[str, str]:
|
|
228
|
+
"""Parse the optional [worker_env] table: per-run worker env overrides (string-valued)."""
|
|
229
|
+
if raw is None:
|
|
230
|
+
return {}
|
|
231
|
+
if not isinstance(raw, dict):
|
|
232
|
+
raise ConfigError("[worker_env] must be a table of string key/values")
|
|
233
|
+
env = {str(k): str(v) for k, v in raw.items()}
|
|
234
|
+
# Env var NAMES must be usable by subprocess.Popen(env=...) on the worker, which raises
|
|
235
|
+
# ValueError for an empty name or one containing '=' or a NUL byte (and whitespace breaks most
|
|
236
|
+
# shells). Reject these at parse time so a malformed [worker_env] (e.g. a TOML quoted key like
|
|
237
|
+
# "BAD=KEY", or an empty key) fails on config load — not after a worker has been provisioned.
|
|
238
|
+
_validate_env_var_names(env, "[worker_env]")
|
|
239
|
+
# [worker_env] is serialized into job_spec_json (persisted + logged), so it must NOT carry
|
|
240
|
+
# secrets — they would leak into run artifacts. Reject secret-looking keys; operators set
|
|
241
|
+
# those as real process environment variables (forwarded to the worker out-of-band) instead.
|
|
242
|
+
# Detect by `_`-delimited WORD components (not substring): flag a secret WORD, or `KEY`
|
|
243
|
+
# qualified by a credential context. This catches HF_TOKEN, *_API_KEY, SECRET_KEY, INTERNAL_KEY,
|
|
244
|
+
# CREDENTIAL, AWS_SECRET_ACCESS_KEY, GITHUB_PAT (PAT word), and credential keys like SSH_KEY /
|
|
245
|
+
# DEPLOY_KEY / GPG_KEY (KEY qualified by a credential context) — while allowing legit knobs whose
|
|
246
|
+
# names merely contain a marker (RL_VLLM_MAX_BATCHED_TOKENS -> word TOKENS, not TOKEN; a bare
|
|
247
|
+
# SORT_KEY -> KEY without a secret qualifier).
|
|
248
|
+
_secret_words = {
|
|
249
|
+
"TOKEN",
|
|
250
|
+
"SECRET",
|
|
251
|
+
"PASSWORD",
|
|
252
|
+
"PASSWD",
|
|
253
|
+
"PASSPHRASE",
|
|
254
|
+
"CREDENTIAL",
|
|
255
|
+
"CREDENTIALS",
|
|
256
|
+
"APIKEY",
|
|
257
|
+
"PRIVATEKEY",
|
|
258
|
+
"PAT", # personal access token (e.g. GITHUB_PAT, GH_PAT)
|
|
259
|
+
}
|
|
260
|
+
_key_qualifiers = {
|
|
261
|
+
"API",
|
|
262
|
+
"SECRET",
|
|
263
|
+
"PRIVATE",
|
|
264
|
+
"ACCESS",
|
|
265
|
+
"INTERNAL",
|
|
266
|
+
"AUTH",
|
|
267
|
+
"SIGNING",
|
|
268
|
+
"ENCRYPTION",
|
|
269
|
+
# credential-key contexts: SSH_KEY, DEPLOY_KEY, GPG_KEY, RSA_KEY, TLS/SSL/PEM keys, etc.
|
|
270
|
+
"SSH",
|
|
271
|
+
"DEPLOY",
|
|
272
|
+
"GPG",
|
|
273
|
+
"PGP",
|
|
274
|
+
"RSA",
|
|
275
|
+
"PEM",
|
|
276
|
+
"SSL",
|
|
277
|
+
"TLS",
|
|
278
|
+
}
|
|
279
|
+
|
|
280
|
+
def _is_secret_key(name: str) -> bool:
|
|
281
|
+
words = set(name.upper().split("_"))
|
|
282
|
+
return bool(words & _secret_words) or ("KEY" in words and bool(words & _key_qualifiers))
|
|
283
|
+
|
|
284
|
+
secrets = sorted(k for k in env if _is_secret_key(k))
|
|
285
|
+
if secrets:
|
|
286
|
+
raise ConfigError(
|
|
287
|
+
f"[worker_env] must not contain secret-bearing keys ({', '.join(secrets)}); these are "
|
|
288
|
+
"serialized into run artifacts; use provider process env or supported runtime secrets "
|
|
289
|
+
"instead"
|
|
290
|
+
)
|
|
291
|
+
return env
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
# Allowed [wandb] config keys -> typed JobSpec.wandb fields (first-class spec config, NOT env vars).
|
|
295
|
+
# `_wandb_spec` validates the keys in this order and rejects any [wandb] key not listed here.
_WANDB_KEYS = ("project", "run_name")
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def _wandb_spec(raw: Any) -> WandbSpec:
    """Build a typed ``WandbSpec`` (project / run_name) from the optional ``[wandb]`` table.

    These are non-secret W&B naming labels carried as first-class spec config (round-tripped
    in the job-spec JSON the worker reads), NOT environment variables. The worker honors them
    in ``engine.worker.wandb_report_to`` / ``wandb_run_name``, so a run can land in its own
    W&B project under its own run name instead of the hardcoded ``flash`` / ``flash-…``
    defaults. They can be set in TOML (``[wandb] project = …``) or via ``flash train cfg.toml
    --set wandb.project=… --set wandb.run_name=…``. The W&B credential itself
    (WANDB_API_KEY) stays an env-var secret; only the naming config lives here.

    Raises:
        ConfigError: when *raw* is not a table, has a key outside ``_WANDB_KEYS``, or sets
            a key to anything other than a non-blank string.
    """
    if raw is None:
        return WandbSpec()
    if not isinstance(raw, dict):
        raise ConfigError('[wandb] must be a table (e.g. project = "my-project")')
    extras = sorted(set(raw).difference(_WANDB_KEYS))
    if extras:
        raise ConfigError(
            f"[wandb] unknown key(s): {', '.join(extras)} (allowed: {', '.join(_WANDB_KEYS)})"
        )
    parsed: dict[str, str] = {}
    for field_name in _WANDB_KEYS:
        candidate = raw.get(field_name)
        # Absent OR null means "unset". A serialized JobSpec round-trips unset wandb fields
        # as null (``asdict`` emits ``{"project": null, "run_name": null}``), and the control
        # plane re-parses specs on submit (``spec_from_dict(spec.to_dict())``). Null must
        # therefore be accepted, or every run that omits ``[wandb]`` would be rejected.
        if candidate is None:
            continue
        # Only an explicitly set value is validated: a bare ""/whitespace is a real config
        # mistake worth flagging.
        cleaned = candidate.strip() if isinstance(candidate, str) else ""
        if not cleaned:
            raise ConfigError(f"[wandb] {field_name} must be a non-empty string")
        parsed[field_name] = cleaned
    return WandbSpec(**parsed)
|
flash/serve/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Adapter serving helpers."""
|
flash/serve/deploy.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
"""Serve a trained LoRA adapter via the freesolo platform's multi-LoRA serving app.
|
|
2
|
+
|
|
3
|
+
Flash no longer runs its own per-run vLLM endpoint. Instead the control plane is a
|
|
4
|
+
thin client of the freesolo serving service (a Modal multi-LoRA app that serves every
|
|
5
|
+
adapter on shared base-model capacity — so there is no flash-side idle billing to
|
|
6
|
+
track). The same CLI commands and control-plane endpoints
|
|
7
|
+
(`deploy`/`undeploy`/`chat`/`deployments`) stay; only what they do under the hood
|
|
8
|
+
changed.
|
|
9
|
+
|
|
10
|
+
The serving service exposes:
|
|
11
|
+
|
|
12
|
+
- ``POST {FREESOLO_SERVING_URL}/adapters`` — register/deploy an adapter (auth header).
|
|
13
|
+
- ``DELETE {FREESOLO_SERVING_URL}/adapters/{adapterId}`` — undeploy (auth header).
|
|
14
|
+
- ``POST {FREESOLO_SERVING_URL}/v1/chat/completions`` — OpenAI-style chat.
|
|
15
|
+
- ``GET {FREESOLO_SERVING_URL}/healthz`` / ``GET .../adapters`` — health / list.
|
|
16
|
+
|
|
17
|
+
The registration/teardown calls carry the shared ``X-Freesolo-Internal-Key`` header
|
|
18
|
+
(the same internal credential flash already holds, ``FREESOLO_INTERNAL_KEY``). The chat
|
|
19
|
+
calls also send it: the control plane is a trusted server-to-server caller (it has already
|
|
20
|
+
authorized the user's key on its own ``/v1/runs/{run_id}/chat`` route), so it uses the
|
|
21
|
+
serving app's internal-key bypass when serving enforces external chat auth.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
import json
|
|
27
|
+
import os
|
|
28
|
+
from collections.abc import Iterator
|
|
29
|
+
from dataclasses import asdict, dataclass
|
|
30
|
+
|
|
31
|
+
import httpx
|
|
32
|
+
|
|
33
|
+
from flash._logging import get_logger
|
|
34
|
+
from flash.providers.base import canonical_gpu, gpu_short
|
|
35
|
+
|
|
36
|
+
# Module-level logger, keyed by this module's import name.
logger = get_logger(__name__)
|
|
37
|
+
|
|
38
|
+
# Default freesolo serving base URL (the Modal multi-LoRA app). Overridable per-env.
|
|
39
|
+
# Fallback used by `serving_base_url` when FREESOLO_SERVING_URL is unset or empty.
DEFAULT_FREESOLO_SERVING_URL = "https://clado-ai--freesolo-lora-serving.modal.run"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ServingError(RuntimeError):
    """The freesolo serving backend (Modal LoRA app) rejected a request or was unreachable.

    ``status_code`` holds the upstream HTTP status when there was a response and ``None``
    otherwise. The API layer can use it to answer with a clean ``502 Bad Gateway`` and the
    real reason, instead of letting an ``httpx`` exception escape as an unhandled ``500``
    plus traceback.
    """

    def __init__(self, message: str, *, status_code: int | None = None):
        self.status_code = status_code
        super().__init__(message)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _post_adapter_or_raise(url: str, body: dict) -> httpx.Response:
    """POST an adapter registration to the serving backend.

    Returns the successful response. A non-success status becomes the ``ServingError``
    built by ``_serving_status_error``; a transport-level failure becomes a
    ``ServingError`` naming the unreachable URL.
    """
    try:
        # follow_redirects: Modal answers a slow request with a 303 to an async-result poll
        # URL (?__modal_function_call_id=...); without following it httpx raises on the 303
        # (see chat).
        response = httpx.post(
            url,
            json=body,
            headers=_internal_key_header(),
            timeout=60.0,
            follow_redirects=True,
        )
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        raise _serving_status_error(url, exc) from exc
    except httpx.RequestError as exc:
        raise ServingError(f"could not reach the serving backend at {url}: {exc}") from exc
    return response
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _serving_status_error(url: str, exc: httpx.HTTPStatusError) -> ServingError:
    """Translate an upstream HTTP failure into a ``ServingError``.

    The message names the URL, the status (when known) and up to 500 characters of the
    response body, followed by a hint that depends on the status: below 500 it points at
    the request/credentials, otherwise at the serving deployment. Shared by the deploy
    POST and the undeploy DELETE.
    """
    # raise_for_status() always carries a response, but a hand-built HTTPStatusError may
    # not; guard so error translation can never itself raise.
    response = exc.response
    if response is None:
        status, detail = None, ""
    else:
        status = response.status_code
        detail = (response.text or "").strip()[:500]

    pieces = [f"serving backend error for {url}"]
    if status is not None:
        pieces.append(f" (HTTP {status})")
    if detail:
        pieces.append(f": {detail}")
    # Tailor the hint to the upstream status: a 4xx is a client/auth problem with THIS
    # request (e.g. a missing/invalid FREESOLO_INTERNAL_KEY), not a serving outage; a 5xx
    # (or unknown) means the backend itself failed / has no engine for the base model.
    client_side = status is not None and status < 500
    if client_side:
        pieces.append(
            " — the serving backend rejected the request (4xx); check FREESOLO_INTERNAL_KEY "
            "and the request payload (this is a client/auth error, not a serving outage)"
        )
    else:
        pieces.append(
            " — the serving backend is unavailable or has no engine for this base model; "
            "an operator must check the freesolo serving deployment"
        )
    return ServingError("".join(pieces), status_code=status)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def serving_base_url() -> str:
    """Return the freesolo serving base URL with trailing slashes removed.

    ``FREESOLO_SERVING_URL`` wins when it is set to a non-empty value; otherwise
    ``DEFAULT_FREESOLO_SERVING_URL`` is used.
    """
    configured = os.environ.get("FREESOLO_SERVING_URL")
    chosen = configured if configured else DEFAULT_FREESOLO_SERVING_URL
    return chosen.rstrip("/")
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _internal_key_header() -> dict[str, str]:
|
|
111
|
+
key = os.environ.get("FREESOLO_INTERNAL_KEY") or ""
|
|
112
|
+
return {"X-Freesolo-Internal-Key": key} if key else {}
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@dataclass
class Deployment:
    """Record describing one adapter deployment, as built by ``deploy_adapter``."""

    # Run whose adapter is deployed.
    run_id: str
    # Base model the adapter applies to.
    model: str
    # Adapter location inside the HF repo; deploy_adapter sets "<adapter_prefix>/adapter".
    adapter_hf_prefix: str
    # GPU class recorded for the deployment; deploy_adapter stores servable_gpu()'s result.
    gpu: str
    # Model name used in chat requests; deploy_adapter sets it to the run id.
    openai_model: str
    # deploy_adapter stores the serving base URL here.
    endpoint_name: str
    # deploy_adapter sets "dry_run" for a dry run and "ready" otherwise.
    state: str = "ready"

    def to_dict(self) -> dict:
        """Return the record as a plain ``dict`` (via ``dataclasses.asdict``)."""
        return asdict(self)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def serve_endpoint_name(friendly_gpu: str, run_id: str) -> str:
    """Cosmetic endpoint label (the freesolo app serves all adapters on one endpoint).

    The label is ``flash-serve-<gpu short name>``. When the run id has a non-empty final
    ``-``-separated segment, up to 24 characters of it are appended after another ``-``.
    """
    suffix = (run_id or "").rsplit("-", 1)[-1][:24]
    label = f"flash-serve-{gpu_short(canonical_gpu(friendly_gpu))}"
    if not suffix:
        return label
    return f"{label}-{suffix}"
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def servable_gpu(gpu_name: str) -> str:
    """Resolve a friendly GPU class for the deployment record.

    Serving is delegated to freesolo (one GPU per base model, chosen there), so the result
    is informational. The name is still canonicalized, and when the trained class is not a
    RunPod class the cheapest RunPod class with enough VRAM is substituted, so the recorded
    ``gpu`` is a sensible, valid class (and junk GPU names still raise).
    """
    from flash.providers.base import GPU_INFO, cheapest_gpu

    canonical = canonical_gpu(gpu_name)
    spec = GPU_INFO[canonical]
    # A truthy enum_member marks a RunPod class, which is recorded as-is.
    if spec.enum_member:
        return canonical
    # Otherwise fall back to the cheapest RunPod class that fits the same VRAM.
    return cheapest_gpu(spec.vram_gb)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def deploy_adapter(
    run_id: str,
    model: str,
    hf_repo: str,
    adapter_prefix: str,
    gpu_name: str = "RTX 5090",
    dry_run: bool = False,
    thinking: bool = False,
    org_id: str | None = None,
) -> Deployment:
    """Register the trained adapter with the freesolo serving app.

    The adapter artifacts already live in the run's HF dataset repo (the trainer streamed
    them there); freesolo serving pulls them from ``{hf_repo}:{adapter_prefix}/adapter``.
    With ``dry_run`` the deployment record is built and returned without any network call.
    ``thinking`` is accepted but not used by this function.

    Raises:
        ServingError: when the registration request fails (see ``_post_adapter_or_raise``).
    """
    resolved_gpu = servable_gpu(gpu_name)
    adapter_subfolder = f"{adapter_prefix}/adapter"
    deployment = Deployment(
        run_id=run_id,
        model=model,
        adapter_hf_prefix=adapter_subfolder,
        gpu=resolved_gpu,
        openai_model=run_id,
        endpoint_name=serving_base_url(),
        state="dry_run" if dry_run else "ready",
    )
    if dry_run:
        return deployment

    serving_url = serving_base_url()
    registration = {
        "adapterId": run_id,
        "repoId": hf_repo,
        "baseModel": model,
        "subfolder": adapter_subfolder,
        # The trainer always streams the adapter into a *dataset* repo (the worker's
        # hf_upload_folder uses repo_type="dataset"), so serving must pull from the dataset
        # namespace. Without this the serving app defaults repoType to "model" and
        # snapshot_download 404s on the model namespace: deploy returns 200 but the engine
        # warmup fails, the adapter is silently disabled, and the first chat 404s.
        "repoType": "dataset",
        "status": "ready",
    }
    # Attribute the adapter to the deploying org so serving can authorize external chat by
    # org: the backend maps adapterId -> org via hosted_lora_adapters.org_id, which serving
    # persists from this field. The value is stripped and omitted when blank (older callers
    # / whitespace), so the registration shape is unchanged and a stray " org " cannot
    # mis-attribute the adapter.
    org = (org_id or "").strip()
    if org:
        registration["orgId"] = org
    _post_adapter_or_raise(f"{serving_url}/adapters", registration)
    logger.info("registered adapter %s with freesolo serving (%s)", run_id, serving_url)
    return deployment
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def undeploy_adapter(run_id: str) -> list[str]:
    """Deregister the run's adapter from the freesolo serving app.

    Returns ``[run_id]`` when the adapter was removed (200), ``[]`` when it was already
    gone (404). Any other failure, whether a non-404 HTTP status or a transport error, is
    translated into a ``ServingError`` (carrying the upstream status), exactly like
    ``deploy_adapter``. Callers therefore see a stable error surface (the API maps it to a
    clean 502) instead of a raw ``httpx`` exception escaping as an unhandled 500.

    Raises:
        ServingError: on any failure other than a 404.
    """
    from urllib.parse import quote

    base = serving_base_url()
    # Percent-encode the id so it always stays a single path segment: an id containing
    # '/', '?', '#' or '..' must not be able to retarget this authenticated DELETE. Ids
    # made only of unreserved characters are left unchanged.
    url = f"{base}/adapters/{quote(str(run_id), safe='')}"
    try:
        resp = httpx.delete(
            url,
            headers=_internal_key_header(),
            timeout=60.0,
            # Modal answers a slow request with a 303 to an async-result poll URL; follow it
            # (see chat).
            follow_redirects=True,
        )
        # Undeploy is idempotent: an already-absent adapter (404) is a no-op success, not
        # an error. It is handled before raise_for_status() so it never becomes a
        # ServingError.
        if resp.status_code == 404:
            return []
        resp.raise_for_status()
    except httpx.HTTPStatusError as exc:
        raise _serving_status_error(url, exc) from exc
    except httpx.RequestError as exc:
        raise ServingError(f"could not reach the serving backend at {url}: {exc}") from exc
    logger.info("deregistered adapter %s from freesolo serving (%s)", run_id, base)
    return [run_id]
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def chat(
    run_id: str,
    messages: list[dict],
    temperature: float = 0.0,
    max_tokens: int = 512,
    thinking: bool = False,
) -> dict:
    """Send an OpenAI-style chat request for the run's adapter to freesolo serving.

    The adapter is addressed by ``model=run_id`` (its registered ``adapterId``). The return
    value is the parsed OpenAI chat-completion dict, so
    ``resp["choices"][0]["message"]["content"]`` keeps working downstream. ``httpx``
    errors (transport failures, non-success statuses) propagate unchanged.
    """
    request_body = {
        "model": run_id,
        "messages": messages,
        "max_tokens": int(max_tokens),
        "temperature": float(temperature),
        # Per-run thinking parity: a run trained with thinking must serve with thinking, so
        # the flag is forwarded to the chat template (enable_thinking is the kwarg the
        # renderer and rollout path use, e.g. multiturn_rollout.build_rollout_func).
        # Without it the served completions diverge from training behavior even though the
        # caller passes thinking=.
        "chat_template_kwargs": {"enable_thinking": bool(thinking)},
    }
    endpoint = f"{serving_base_url()}/v1/chat/completions"
    # Cold starts (scale-from-zero per base model) can take minutes. Modal serves a slow
    # ASGI request by 303-redirecting to an async-result poll URL
    # (?__modal_function_call_id=...), so the client must follow redirects to retrieve the
    # eventual completion; otherwise httpx raises on the 303 and the chat fails mid
    # cold-start. max_redirects is raised because a long cold start polls across several
    # redirect cycles before the result is ready.
    with httpx.Client(follow_redirects=True, max_redirects=100, timeout=30 * 60.0) as client:
        # The control plane is a trusted server-to-server caller (it already authorized the
        # user's key on the /v1/runs/{run_id}/chat route), so it presents the internal key
        # to pass serving's external chat-auth gate. No-op when the gate is off or the key
        # is unset.
        response = client.post(endpoint, json=request_body, headers=_internal_key_header())
        response.raise_for_status()
        return response.json()
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def _openai_stream_content(lines: Iterator[str]) -> Iterator[str]:
|
|
280
|
+
for line in lines:
|
|
281
|
+
line = line.strip()
|
|
282
|
+
if not line.startswith("data:"):
|
|
283
|
+
continue
|
|
284
|
+
data = line.removeprefix("data:").strip()
|
|
285
|
+
if data == "[DONE]":
|
|
286
|
+
break
|
|
287
|
+
if not data:
|
|
288
|
+
continue
|
|
289
|
+
chunk = json.loads(data)
|
|
290
|
+
for choice in chunk.get("choices") or []:
|
|
291
|
+
content = ((choice.get("delta") or {}).get("content")) or ""
|
|
292
|
+
if content:
|
|
293
|
+
yield str(content)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def chat_stream(
    run_id: str,
    messages: list[dict],
    temperature: float = 0.0,
    max_tokens: int = 512,
    thinking: bool = False,
) -> Iterator[str]:
    """Yield text deltas from the freesolo OpenAI-compatible streaming endpoint.

    The request matches ``chat`` plus ``"stream": True``. When the backend answers with a
    plain JSON body instead of an event stream, the single message content is yielded once
    (if non-empty). Otherwise the deltas of the event stream are yielded as they arrive.
    ``httpx`` errors propagate unchanged.
    """
    base = serving_base_url()
    body = {
        "model": run_id,
        "messages": messages,
        "max_tokens": int(max_tokens),
        "temperature": float(temperature),
        "chat_template_kwargs": {"enable_thinking": bool(thinking)},
        "stream": True,
    }
    with (
        httpx.Client(follow_redirects=True, max_redirects=100, timeout=30 * 60.0) as client,
        client.stream(
            "POST", f"{base}/v1/chat/completions", json=body, headers=_internal_key_header()
        ) as resp,
    ):
        resp.raise_for_status()
        if "application/json" in resp.headers.get("content-type", ""):
            # A streamed response has no body loaded yet; without read() the json() call
            # below raises httpx.ResponseNotRead, so this fallback could never succeed.
            resp.read()
            payload = resp.json()
            content = (((payload.get("choices") or [{}])[0].get("message") or {}).get("content"))
            if content:
                yield str(content)
            return
        yield from _openai_stream_content(resp.iter_lines())
|