freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
flash/schema/fields.py ADDED
@@ -0,0 +1,331 @@
1
+ """Field-level validators/coercers for Flash TOML config parsing.
2
+
3
+ Leaf helpers split out of ``flash.schema``: the ``ConfigError`` type, the [train] scalar
4
+ validators, the slug/worker-env/wandb validators, and the ``--set`` scalar coercer. None
5
+ reference the rest of the schema package; the package ``__init__`` re-exports them.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import math
11
+ import re
12
+ import urllib.parse
13
+ from typing import Any
14
+
15
+ from flash.spec import WandbSpec
16
+
17
+ _GITHUB_SAFE_PART_RE = re.compile(r"^[A-Za-z0-9._-]+$")
18
+
19
+
20
+ def _train_int(train_raw: dict, key: str, *, minimum: int) -> int | None:
21
+ """Validate an optional integer [train] knob (>= minimum) -> ConfigError (HTTP 400).
22
+
23
+ None stays None (recipe default). Rejects bools, non-numbers, non-integers, and
24
+ out-of-range values at parse time instead of letting them reach a provisioned worker.
25
+ """
26
+ v = train_raw.get(key)
27
+ if v is None:
28
+ return None
29
+ if isinstance(v, bool) or not isinstance(v, (int, float)):
30
+ raise ConfigError(f"train.{key} must be an integer")
31
+ # Check finiteness BEFORE int(v): int(inf) raises OverflowError and int(nan) ValueError
32
+ # (the former would be a 500); reject both as a clean 400.
33
+ if not math.isfinite(v) or float(v) != int(v):
34
+ raise ConfigError(f"train.{key} must be a finite integer")
35
+ v = int(v)
36
+ if v < minimum:
37
+ raise ConfigError(f"train.{key} must be >= {minimum}")
38
+ return v
39
+
40
+
41
+ def _train_float(
42
+ train_raw: dict,
43
+ key: str,
44
+ *,
45
+ minimum: float,
46
+ exclusive: bool = False,
47
+ maximum: float | None = None,
48
+ ) -> float | None:
49
+ """Validate an optional float [train] knob -> ConfigError (HTTP 400). None stays None."""
50
+ v = train_raw.get(key)
51
+ if v is None:
52
+ return None
53
+ if isinstance(v, bool) or not isinstance(v, (int, float)):
54
+ raise ConfigError(f"train.{key} must be a number")
55
+ v = float(v)
56
+ # nan/inf slip past the range checks below (nan compares false, inf passes any minimum)
57
+ # and would reach TRL optimizer/sampling settings; reject them as a 400 here.
58
+ if not math.isfinite(v):
59
+ raise ConfigError(f"train.{key} must be a finite number")
60
+ if exclusive and v <= minimum:
61
+ raise ConfigError(f"train.{key} must be > {minimum}")
62
+ if not exclusive and v < minimum:
63
+ raise ConfigError(f"train.{key} must be >= {minimum}")
64
+ if maximum is not None and v > maximum:
65
+ raise ConfigError(f"train.{key} must be between {minimum} and {maximum}")
66
+ return v
67
+
68
+
69
+ def _train_stops(train_raw: dict) -> tuple[str, ...]:
70
+ """Validate stop_sequences -> ConfigError. A string is ONE stop (never char-split);
71
+ a list must hold strings; empties are dropped; anything else is rejected."""
72
+ v = train_raw.get("stop_sequences")
73
+ if v is None:
74
+ return ()
75
+ if isinstance(v, str):
76
+ return (v,) if v else ()
77
+ if not isinstance(v, (list, tuple)):
78
+ raise ConfigError("train.stop_sequences must be a string or a list of strings")
79
+ for s in v:
80
+ if not isinstance(s, str):
81
+ raise ConfigError("train.stop_sequences entries must be strings")
82
+ return tuple(s for s in v if s)
83
+
84
+
85
+ class ConfigError(ValueError):
86
+ pass
87
+
88
+
89
+ def _require_slug(value: str, message: str) -> None:
90
+ """Require an ``owner/name`` slug."""
91
+ text = (value or "").strip()
92
+ if not text or ":" in text:
93
+ raise ConfigError(message)
94
+ parsed = urllib.parse.urlparse(text)
95
+ if parsed.scheme or parsed.netloc:
96
+ raise ConfigError(message)
97
+ parts = text.split("/")
98
+ if len(parts) != 2 or not _is_safe_github_path_parts(parts):
99
+ raise ConfigError(message)
100
+
101
+
102
+ def _require_environment_ref(value: str, message: str) -> None:
103
+ """Require a Freesolo environment id."""
104
+ try:
105
+ _require_slug(value, message)
106
+ return
107
+ except ConfigError:
108
+ pass
109
+ if value.startswith("github:"):
110
+ body = value[len("github:") :]
111
+ repo_ref, sep, path = body.partition(":")
112
+ repo, at, ref = repo_ref.partition("@")
113
+ if at and not ref:
114
+ raise ConfigError(message)
115
+ owner_repo = repo.split("/")
116
+ if (
117
+ len(owner_repo) == 2
118
+ and _is_safe_github_path_parts(owner_repo)
119
+ and (not at or _is_safe_github_path_parts([ref]))
120
+ and (not sep or _is_safe_environment_path(path))
121
+ ):
122
+ return
123
+ raise ConfigError(message)
124
+ if value.startswith("https://github.com/") or value.startswith("http://github.com/"):
125
+ parsed = urllib.parse.urlparse(value)
126
+ if parsed.scheme in {"http", "https"} and parsed.netloc.lower() == "github.com":
127
+ parts = [
128
+ part for part in urllib.parse.unquote(parsed.path).strip("/").split("/") if part
129
+ ]
130
+ if len(parts) < 2:
131
+ raise ConfigError(message)
132
+ owner, repo = parts[0], parts[1]
133
+ repo = repo[:-4] if repo.endswith(".git") else repo
134
+ if len(parts) == 2:
135
+ if not _is_safe_github_path_parts([owner, repo]):
136
+ raise ConfigError(message)
137
+ elif len(parts) >= 5 and parts[2] in {"blob", "tree"}:
138
+ ref = parts[3]
139
+ if not _is_safe_github_path_parts([ref]):
140
+ raise ConfigError(message)
141
+ raw_path = "/".join(parts[4:])
142
+ if not _is_safe_environment_path(raw_path):
143
+ raise ConfigError(message)
144
+ if not _is_safe_github_path_parts([owner, repo, ref]):
145
+ raise ConfigError(message)
146
+ else:
147
+ raise ConfigError(message)
148
+ return
149
+ raise ConfigError(message)
150
+
151
+
152
+ def _is_safe_environment_path(path: str) -> bool:
153
+ if not path:
154
+ return True
155
+ raw = path.strip().replace("\\", "/")
156
+ if raw.startswith("/"):
157
+ return False
158
+ parts = [part for part in raw.split("/") if part]
159
+ if not parts:
160
+ return True
161
+ return not any(part in {".", ".."} for part in parts)
162
+
163
+
164
+ def _is_safe_github_path_parts(parts: list[str]) -> bool:
165
+ if any(part in {".", "..", ""} for part in parts):
166
+ return False
167
+ return all(_GITHUB_SAFE_PART_RE.fullmatch(part) for part in parts)
168
+
169
+
170
+ def _coerce_scalar(value: str):
171
+ low = value.strip().lower()
172
+ if low in ("true", "false"):
173
+ return low == "true"
174
+ try:
175
+ return int(value)
176
+ except ValueError:
177
+ pass
178
+ try:
179
+ return float(value)
180
+ except ValueError:
181
+ return value
182
+
183
+
184
+ def _validate_env_var_names(names, context: str) -> None:
185
+ bad_names = sorted(repr(k) for k in names if (not k) or any(c in k for c in "=\0 \t\n\r"))
186
+ if bad_names:
187
+ raise ConfigError(
188
+ f"{context} has invalid environment variable name(s): {', '.join(bad_names)}; an "
189
+ "env var name must be non-empty and contain no '=', whitespace, or NUL byte"
190
+ )
191
+
192
+
193
+ _RESERVED_ENVIRONMENT_SECRET_KEYS = frozenset(
194
+ {
195
+ "RUNPOD_API_KEY",
196
+ "HF_TOKEN",
197
+ "HUGGING_FACE_HUB_TOKEN",
198
+ "GITHUB_TOKEN",
199
+ "FREESOLO_API_KEY",
200
+ "FREESOLO_INTERNAL_KEY",
201
+ "RUN_ID",
202
+ "HF_REPO",
203
+ "FLASH_ARM",
204
+ }
205
+ )
206
+
207
+
208
+ def _environment_secrets(raw: Any) -> tuple[str, ...]:
209
+ """Parse [environment].secrets as declared worker env-var secret names."""
210
+ if raw is None:
211
+ return ()
212
+ if isinstance(raw, str) or not isinstance(raw, (list, tuple)):
213
+ raise ConfigError("[environment] secrets must be a list of environment variable names")
214
+ if not all(isinstance(name, str) for name in raw):
215
+ raise ConfigError("[environment] secrets entries must be strings")
216
+ secrets = tuple(dict.fromkeys(raw))
217
+ _validate_env_var_names(secrets, "[environment] secrets")
218
+ reserved = sorted(set(secrets) & _RESERVED_ENVIRONMENT_SECRET_KEYS)
219
+ if reserved:
220
+ raise ConfigError(
221
+ "[environment] secrets must not include platform-managed key(s): "
222
+ f"{', '.join(reserved)}"
223
+ )
224
+ return secrets
225
+
226
+
227
+ def _worker_env(raw: Any) -> dict[str, str]:
228
+ """Parse the optional [worker_env] table: per-run worker env overrides (string-valued)."""
229
+ if raw is None:
230
+ return {}
231
+ if not isinstance(raw, dict):
232
+ raise ConfigError("[worker_env] must be a table of string key/values")
233
+ env = {str(k): str(v) for k, v in raw.items()}
234
+ # Env var NAMES must be usable by subprocess.Popen(env=...) on the worker, which raises
235
+ # ValueError for an empty name or one containing '=' or a NUL byte (and whitespace breaks most
236
+ # shells). Reject these at parse time so a malformed [worker_env] (e.g. a TOML quoted key like
237
+ # "BAD=KEY", or an empty key) fails on config load — not after a worker has been provisioned.
238
+ _validate_env_var_names(env, "[worker_env]")
239
+ # [worker_env] is serialized into job_spec_json (persisted + logged), so it must NOT carry
240
+ # secrets — they would leak into run artifacts. Reject secret-looking keys; operators set
241
+ # those as real process environment variables (forwarded to the worker out-of-band) instead.
242
+ # Detect by `_`-delimited WORD components (not substring): flag a secret WORD, or `KEY`
243
+ # qualified by a credential context. This catches HF_TOKEN, *_API_KEY, SECRET_KEY, INTERNAL_KEY,
244
+ # CREDENTIAL, AWS_SECRET_ACCESS_KEY, GITHUB_PAT (PAT word), and credential keys like SSH_KEY /
245
+ # DEPLOY_KEY / GPG_KEY (KEY qualified by a credential context) — while allowing legit knobs whose
246
+ # names merely contain a marker (RL_VLLM_MAX_BATCHED_TOKENS -> word TOKENS, not TOKEN; a bare
247
+ # SORT_KEY -> KEY without a secret qualifier).
248
+ _secret_words = {
249
+ "TOKEN",
250
+ "SECRET",
251
+ "PASSWORD",
252
+ "PASSWD",
253
+ "PASSPHRASE",
254
+ "CREDENTIAL",
255
+ "CREDENTIALS",
256
+ "APIKEY",
257
+ "PRIVATEKEY",
258
+ "PAT", # personal access token (e.g. GITHUB_PAT, GH_PAT)
259
+ }
260
+ _key_qualifiers = {
261
+ "API",
262
+ "SECRET",
263
+ "PRIVATE",
264
+ "ACCESS",
265
+ "INTERNAL",
266
+ "AUTH",
267
+ "SIGNING",
268
+ "ENCRYPTION",
269
+ # credential-key contexts: SSH_KEY, DEPLOY_KEY, GPG_KEY, RSA_KEY, TLS/SSL/PEM keys, etc.
270
+ "SSH",
271
+ "DEPLOY",
272
+ "GPG",
273
+ "PGP",
274
+ "RSA",
275
+ "PEM",
276
+ "SSL",
277
+ "TLS",
278
+ }
279
+
280
+ def _is_secret_key(name: str) -> bool:
281
+ words = set(name.upper().split("_"))
282
+ return bool(words & _secret_words) or ("KEY" in words and bool(words & _key_qualifiers))
283
+
284
+ secrets = sorted(k for k in env if _is_secret_key(k))
285
+ if secrets:
286
+ raise ConfigError(
287
+ f"[worker_env] must not contain secret-bearing keys ({', '.join(secrets)}); these are "
288
+ "serialized into run artifacts; use provider process env or supported runtime secrets "
289
+ "instead"
290
+ )
291
+ return env
292
+
293
+
294
+ # Allowed [wandb] config keys -> typed JobSpec.wandb fields (first-class spec config, NOT env vars).
295
+ _WANDB_KEYS = ("project", "run_name")
296
+
297
+
298
+ def _wandb_spec(raw: Any) -> WandbSpec:
299
+ """Parse the optional ``[wandb]`` table into a typed ``WandbSpec`` (project / run_name).
300
+
301
+ These are non-secret W&B naming labels carried as first-class spec config (round-tripped in
302
+ the job-spec JSON the worker reads), NOT environment variables. The worker honors them in
303
+ ``engine.worker.wandb_report_to`` / ``wandb_run_name``, so a run can land in its own W&B
304
+ project under its own run name instead of the hardcoded ``flash`` / ``flash-…`` defaults.
305
+ Settable in TOML (``[wandb] project = …``) or via ``flash train cfg.toml --set
306
+ wandb.project=… --set wandb.run_name=…``. The actual W&B credential (WANDB_API_KEY) stays an
307
+ env-var secret — only the naming config lives here."""
308
+ if raw is None:
309
+ return WandbSpec()
310
+ if not isinstance(raw, dict):
311
+ raise ConfigError('[wandb] must be a table (e.g. project = "my-project")')
312
+ unknown = sorted(set(raw) - set(_WANDB_KEYS))
313
+ if unknown:
314
+ raise ConfigError(
315
+ f"[wandb] unknown key(s): {', '.join(unknown)} (allowed: {', '.join(_WANDB_KEYS)})"
316
+ )
317
+ values: dict[str, str] = {}
318
+ for key in _WANDB_KEYS:
319
+ val = raw.get(key)
320
+ # Absent OR null means "unset". A serialized JobSpec round-trips unset wandb fields as
321
+ # null (``asdict`` emits ``{"project": null, "run_name": null}``), so re-parsing a spec —
322
+ # which is exactly what the control plane does on submit, ``spec_from_dict(spec.to_dict())``
323
+ # — must accept null without demanding a value, or every run that omits ``[wandb]`` is
324
+ # rejected. Only an explicitly-set value is validated: a bare ""/whitespace is a real
325
+ # config mistake worth flagging.
326
+ if val is None:
327
+ continue
328
+ if not isinstance(val, str) or not val.strip():
329
+ raise ConfigError(f"[wandb] {key} must be a non-empty string")
330
+ values[key] = val.strip()
331
+ return WandbSpec(**values)
@@ -0,0 +1 @@
1
+ """Adapter serving helpers."""
flash/serve/deploy.py ADDED
@@ -0,0 +1,326 @@
1
+ """Serve a trained LoRA adapter via the freesolo platform's multi-LoRA serving app.
2
+
3
+ Flash no longer runs its own per-run vLLM endpoint. Instead the control plane is a
4
+ thin client of the freesolo serving service (a Modal multi-LoRA app that serves every
5
+ adapter on shared base-model capacity — so there is no flash-side idle billing to
6
+ track). The same CLI commands and control-plane endpoints
7
+ (`deploy`/`undeploy`/`chat`/`deployments`) stay; only what they do under the hood
8
+ changed.
9
+
10
+ The serving service exposes:
11
+
12
+ - ``POST {FREESOLO_SERVING_URL}/adapters`` — register/deploy an adapter (auth header).
13
+ - ``DELETE {FREESOLO_SERVING_URL}/adapters/{adapterId}`` — undeploy (auth header).
14
+ - ``POST {FREESOLO_SERVING_URL}/v1/chat/completions`` — OpenAI-style chat.
15
+ - ``GET {FREESOLO_SERVING_URL}/healthz`` / ``GET .../adapters`` — health / list.
16
+
17
+ The registration/teardown calls carry the shared ``X-Freesolo-Internal-Key`` header
18
+ (the same internal credential flash already holds, ``FREESOLO_INTERNAL_KEY``). The chat
19
+ calls also send it: the control plane is a trusted server-to-server caller (it has already
20
+ authorized the user's key on its own ``/v1/runs/{run_id}/chat`` route), so it uses the
21
+ serving app's internal-key bypass when serving enforces external chat auth.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import json
27
+ import os
28
+ from collections.abc import Iterator
29
+ from dataclasses import asdict, dataclass
30
+
31
+ import httpx
32
+
33
+ from flash._logging import get_logger
34
+ from flash.providers.base import canonical_gpu, gpu_short
35
+
36
+ logger = get_logger(__name__)
37
+
38
+ # Default freesolo serving base URL (the Modal multi-LoRA app). Overridable per-env.
39
+ DEFAULT_FREESOLO_SERVING_URL = "https://clado-ai--freesolo-lora-serving.modal.run"
40
+
41
+
42
class ServingError(RuntimeError):
    """A request to the freesolo serving backend was rejected or could not be delivered.

    ``status_code`` holds the upstream HTTP status when a response was received and is
    ``None`` otherwise, so the API layer can report a clean ``502 Bad Gateway`` with the
    real reason rather than letting an ``httpx`` exception escape as an unhandled ``500``.
    """

    def __init__(self, message: str, *, status_code: int | None = None):
        self.status_code = status_code
        super().__init__(message)
53
+
54
+
55
+ def _post_adapter_or_raise(url: str, body: dict) -> httpx.Response:
56
+ """POST an adapter registration to the serving backend, translating any transport- or
57
+ status-level failure into a ``ServingError`` that carries the upstream detail."""
58
+ try:
59
+ # follow_redirects: Modal answers a slow request with a 303 to an async-result poll URL
60
+ # (?__modal_function_call_id=...); without following it httpx raises on the 303 (see chat).
61
+ resp = httpx.post(
62
+ url,
63
+ json=body,
64
+ headers=_internal_key_header(),
65
+ timeout=60.0,
66
+ follow_redirects=True,
67
+ )
68
+ resp.raise_for_status()
69
+ return resp
70
+ except httpx.HTTPStatusError as exc:
71
+ raise _serving_status_error(url, exc) from exc
72
+ except httpx.RequestError as exc:
73
+ raise ServingError(f"could not reach the serving backend at {url}: {exc}") from exc
74
+
75
+
76
+ def _serving_status_error(url: str, exc: httpx.HTTPStatusError) -> ServingError:
77
+ """Build a ``ServingError`` from an upstream HTTP failure, carrying the status and a
78
+ 4xx-vs-5xx-tailored hint (shared by the deploy POST and the undeploy DELETE)."""
79
+ # raise_for_status() always carries a response, but a hand-built HTTPStatusError may
80
+ # not — guard so error translation can never itself raise.
81
+ resp = exc.response
82
+ status = resp.status_code if resp is not None else None
83
+ detail = ((resp.text if resp is not None else "") or "").strip()[:500]
84
+ msg = f"serving backend error for {url}"
85
+ if status is not None:
86
+ msg += f" (HTTP {status})"
87
+ if detail:
88
+ msg += f": {detail}"
89
+ # Tailor the hint to the upstream status: a 4xx is a client/auth problem with THIS request
90
+ # (e.g. a missing/invalid FREESOLO_INTERNAL_KEY), not a serving outage; a 5xx (or unknown)
91
+ # means the backend itself failed / has no engine for the base model.
92
+ if status is not None and status < 500:
93
+ msg += (
94
+ " — the serving backend rejected the request (4xx); check FREESOLO_INTERNAL_KEY "
95
+ "and the request payload (this is a client/auth error, not a serving outage)"
96
+ )
97
+ else:
98
+ msg += (
99
+ " — the serving backend is unavailable or has no engine for this base model; "
100
+ "an operator must check the freesolo serving deployment"
101
+ )
102
+ return ServingError(msg, status_code=status)
103
+
104
+
105
+ def serving_base_url() -> str:
106
+ """The freesolo serving base URL (env-overridable, trailing slash stripped)."""
107
+ return (os.environ.get("FREESOLO_SERVING_URL") or DEFAULT_FREESOLO_SERVING_URL).rstrip("/")
108
+
109
+
110
+ def _internal_key_header() -> dict[str, str]:
111
+ key = os.environ.get("FREESOLO_INTERNAL_KEY") or ""
112
+ return {"X-Freesolo-Internal-Key": key} if key else {}
113
+
114
+
115
+ @dataclass
116
+ class Deployment:
117
+ run_id: str
118
+ model: str
119
+ adapter_hf_prefix: str
120
+ gpu: str
121
+ openai_model: str
122
+ endpoint_name: str
123
+ state: str = "ready"
124
+
125
+ def to_dict(self) -> dict:
126
+ return asdict(self)
127
+
128
+
129
+ def serve_endpoint_name(friendly_gpu: str, run_id: str) -> str:
130
+ """Cosmetic endpoint label (the freesolo app serves all adapters on one endpoint)."""
131
+ tail = (run_id or "").split("-")[-1][:24]
132
+ base = f"flash-serve-{gpu_short(canonical_gpu(friendly_gpu))}"
133
+ return f"{base}-{tail}" if tail else base
134
+
135
+
136
+ def servable_gpu(gpu_name: str) -> str:
137
+ """Resolve a friendly GPU class for the deployment record.
138
+
139
+ Serving is delegated to freesolo (one GPU per base model, chosen there), so this is
140
+ now informational. We still canonicalize the name and fall back to the cheapest RunPod
141
+ class big enough when the trained class isn't a RunPod class, so the recorded ``gpu`` is
142
+ a sensible, valid class (and junk GPU names still raise)."""
143
+ from flash.providers.base import GPU_INFO, cheapest_gpu
144
+
145
+ friendly = canonical_gpu(gpu_name)
146
+ info = GPU_INFO[friendly]
147
+ if info.enum_member: # a RunPod class — serve it directly
148
+ return friendly
149
+ return cheapest_gpu(info.vram_gb) # else the cheapest RunPod class that fits
150
+
151
+
152
+ def deploy_adapter(
153
+ run_id: str,
154
+ model: str,
155
+ hf_repo: str,
156
+ adapter_prefix: str,
157
+ gpu_name: str = "RTX 5090",
158
+ dry_run: bool = False,
159
+ thinking: bool = False,
160
+ org_id: str | None = None,
161
+ ) -> Deployment:
162
+ """Register the trained adapter with the freesolo serving app.
163
+
164
+ The adapter artifacts already live in the run's HF dataset repo (the trainer
165
+ streamed them there); freesolo serving pulls them from
166
+ ``{hf_repo}:{adapter_prefix}/adapter``. ``dry_run`` validates/shapes the deployment
167
+ without making the network call.
168
+ """
169
+ friendly = servable_gpu(gpu_name)
170
+ subfolder = f"{adapter_prefix}/adapter"
171
+ dep = Deployment(
172
+ run_id=run_id,
173
+ model=model,
174
+ adapter_hf_prefix=subfolder,
175
+ gpu=friendly,
176
+ openai_model=run_id,
177
+ endpoint_name=serving_base_url(),
178
+ state="dry_run" if dry_run else "ready",
179
+ )
180
+ if dry_run:
181
+ return dep
182
+ base = serving_base_url()
183
+ body = {
184
+ "adapterId": run_id,
185
+ "repoId": hf_repo,
186
+ "baseModel": model,
187
+ "subfolder": subfolder,
188
+ # The trainer always streams the adapter into a *dataset* repo (the worker's
189
+ # hf_upload_folder uses repo_type="dataset"), so serving must pull from the dataset
190
+ # namespace. Without this the serving app defaults repoType to "model" and
191
+ # snapshot_download 404s on the model namespace — deploy returns 200 but the engine
192
+ # warmup fails, the adapter is silently disabled, and the first chat 404s.
193
+ "repoType": "dataset",
194
+ "status": "ready",
195
+ }
196
+ # Attribute the adapter to the deploying org so serving can authorize external chat by org:
197
+ # the backend maps adapterId -> org via hosted_lora_adapters.org_id, which serving persists
198
+ # from this field. Normalize (strip) and omit when blank (older callers / whitespace) so the
199
+ # registration shape is unchanged and a stray " org " can't mis-attribute the adapter.
200
+ normalized_org_id = (org_id or "").strip()
201
+ if normalized_org_id:
202
+ body["orgId"] = normalized_org_id
203
+ _post_adapter_or_raise(f"{base}/adapters", body)
204
+ logger.info("registered adapter %s with freesolo serving (%s)", run_id, base)
205
+ return dep
206
+
207
+
208
+ def undeploy_adapter(run_id: str) -> list[str]:
209
+ """Deregister the run's adapter from the freesolo serving app.
210
+
211
+ Returns ``[run_id]`` when the adapter was removed (200), ``[]`` when it was already
212
+ gone (404). Any other failure — a non-404 HTTP status or a transport error — is
213
+ translated into a ``ServingError`` (carrying the upstream status), exactly like
214
+ ``deploy_adapter``, so callers see a stable error surface (the API maps it to a clean
215
+ 502) instead of a raw ``httpx`` exception escaping as an unhandled 500.
216
+ """
217
+ base = serving_base_url()
218
+ url = f"{base}/adapters/{run_id}"
219
+ try:
220
+ resp = httpx.delete(
221
+ url,
222
+ headers=_internal_key_header(),
223
+ timeout=60.0,
224
+ # Modal answers a slow request with a 303 to an async-result poll URL; follow it (see chat).
225
+ follow_redirects=True,
226
+ )
227
+ # Undeploy is idempotent: an already-absent adapter (404) is a no-op success, not an
228
+ # error — handle it before raise_for_status() so it never becomes a ServingError.
229
+ if resp.status_code == 404:
230
+ return []
231
+ resp.raise_for_status()
232
+ except httpx.HTTPStatusError as exc:
233
+ raise _serving_status_error(url, exc) from exc
234
+ except httpx.RequestError as exc:
235
+ raise ServingError(f"could not reach the serving backend at {url}: {exc}") from exc
236
+ logger.info("deregistered adapter %s from freesolo serving (%s)", run_id, base)
237
+ return [run_id]
238
+
239
+
240
+ def chat(
241
+ run_id: str,
242
+ messages: list[dict],
243
+ temperature: float = 0.0,
244
+ max_tokens: int = 512,
245
+ thinking: bool = False,
246
+ ) -> dict:
247
+ """Send an OpenAI-style chat request for the run's adapter to freesolo serving.
248
+
249
+ The adapter is addressed by ``model=run_id`` (its registered ``adapterId``); the
250
+ response is the parsed OpenAI chat-completion dict, so
251
+ ``resp["choices"][0]["message"]["content"]`` keeps working downstream.
252
+ """
253
+ base = serving_base_url()
254
+ body = {
255
+ "model": run_id,
256
+ "messages": messages,
257
+ "max_tokens": int(max_tokens),
258
+ "temperature": float(temperature),
259
+ # Per-run thinking parity: a run trained with thinking must serve with thinking, so
260
+ # forward the flag to the chat template (enable_thinking is the kwarg the renderer and
261
+ # rollout path use, e.g. multiturn_rollout.build_rollout_func). Without this the served
262
+ # completions diverge from training behavior even though the caller passes thinking=.
263
+ "chat_template_kwargs": {"enable_thinking": bool(thinking)},
264
+ }
265
+ # Cold starts (scale-from-zero per base model) can take minutes. Modal serves a slow ASGI
266
+ # request by 303-redirecting to an async-result poll URL (?__modal_function_call_id=...), so
267
+ # the client must follow redirects to retrieve the eventual completion — without this httpx
268
+ # raises on the 303 and the chat fails mid cold-start. max_redirects is raised because a long
269
+ # cold start polls across several redirect cycles before the result is ready.
270
+ with httpx.Client(follow_redirects=True, max_redirects=100, timeout=30 * 60.0) as client:
271
+ # The control plane is a trusted server-to-server caller (it already authorized the user's
272
+ # key on the /v1/runs/{run_id}/chat route), so present the internal key to pass serving's
273
+ # external chat-auth gate. No-op when the gate is off or the key is unset.
274
+ resp = client.post(f"{base}/v1/chat/completions", json=body, headers=_internal_key_header())
275
+ resp.raise_for_status()
276
+ return resp.json()
277
+
278
+
279
+ def _openai_stream_content(lines: Iterator[str]) -> Iterator[str]:
280
+ for line in lines:
281
+ line = line.strip()
282
+ if not line.startswith("data:"):
283
+ continue
284
+ data = line.removeprefix("data:").strip()
285
+ if data == "[DONE]":
286
+ break
287
+ if not data:
288
+ continue
289
+ chunk = json.loads(data)
290
+ for choice in chunk.get("choices") or []:
291
+ content = ((choice.get("delta") or {}).get("content")) or ""
292
+ if content:
293
+ yield str(content)
294
+
295
+
296
+ def chat_stream(
297
+ run_id: str,
298
+ messages: list[dict],
299
+ temperature: float = 0.0,
300
+ max_tokens: int = 512,
301
+ thinking: bool = False,
302
+ ) -> Iterator[str]:
303
+ """Yield text deltas from the freesolo OpenAI-compatible streaming endpoint."""
304
+ base = serving_base_url()
305
+ body = {
306
+ "model": run_id,
307
+ "messages": messages,
308
+ "max_tokens": int(max_tokens),
309
+ "temperature": float(temperature),
310
+ "chat_template_kwargs": {"enable_thinking": bool(thinking)},
311
+ "stream": True,
312
+ }
313
+ with (
314
+ httpx.Client(follow_redirects=True, max_redirects=100, timeout=30 * 60.0) as client,
315
+ client.stream(
316
+ "POST", f"{base}/v1/chat/completions", json=body, headers=_internal_key_header()
317
+ ) as resp,
318
+ ):
319
+ resp.raise_for_status()
320
+ if "application/json" in resp.headers.get("content-type", ""):
321
+ payload = resp.json()
322
+ content = (((payload.get("choices") or [{}])[0].get("message") or {}).get("content"))
323
+ if content:
324
+ yield str(content)
325
+ return
326
+ yield from _openai_stream_content(resp.iter_lines())