freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
flash/server/auth.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
"""Bearer auth for the managed control plane.
|
|
2
|
+
|
|
3
|
+
User authentication is freesolo API keys only — there is no native key system. A bearer
|
|
4
|
+
token equal to the operator's shared ``FREESOLO_INTERNAL_KEY`` resolves to the service
|
|
5
|
+
identity; any other token is verified against the freesolo backend and (on success)
|
|
6
|
+
resolved to a per-token user identity. A failed/unreachable verify returns False (the key
|
|
7
|
+
is treated as unverified), so a backend outage never admits an unverified key.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import json
|
|
13
|
+
import os
|
|
14
|
+
import threading
|
|
15
|
+
import time
|
|
16
|
+
import urllib.error
|
|
17
|
+
import urllib.request
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
from . import db
|
|
21
|
+
|
|
22
|
+
# Operators set this to the shared freesolo internal key; a bearer token equal to it
# authenticates as the service identity (see db.ensure_internal_key).
INTERNAL_KEY_ENV = "FREESOLO_INTERNAL_KEY"

# Freesolo USER-key acceptance: a user who `flash login`s with a freesolo API key sends it as
# the bearer to this control plane. Any non-internal token is verified against the freesolo
# backend and (on success) resolved to a per-token identity.
FREESOLO_BASE_URL_ENV = "FREESOLO_BASE_URL"
DEFAULT_FREESOLO_BASE_URL = "https://api.freesolo.co"
# Timeout (seconds) passed to urlopen for a single verify request.
_VERIFY_TIMEOUT_S = 5.0
_VERIFY_CACHE_TTL_S = 300.0  # short TTL so it isn't a backend round-trip per request
# Negative verdicts get a much SHORTER TTL than positives. The freesolo verify endpoint
# returns 401 not only for a genuinely-bad key but also when the backend converts an
# auth-LOOKUP infra exception (authenticate_api_key failure) into a 401 — a transient outage.
# Caching such a negative for the full 300s would lock out an otherwise-valid key for 5
# minutes after the backend recovers. A short negative TTL keeps persistent bad tokens
# rate-limited (~30s, so they don't hammer the backend) while letting a transient 401 clear
# quickly. Positives keep the long TTL.
_VERIFY_CACHE_NEG_TTL_S = 30.0
# Upper bound on a bearer token we'll cache/verify. Real freesolo API keys are short; an
# arbitrarily long bearer is rejected up front so it can't bloat _verify_cache (keyed by the
# raw token) or produce an oversized outbound Authorization header.
_MAX_TOKEN_LEN = 256

# In-process verify cache: token -> (verified_bool, expires_at). Caches positives AND
# negatives so a burst of requests for the same token hits the backend at most once per TTL.
# Bounded: pruned of expired entries on every write and capped at _VERIFY_CACHE_MAX so a
# stream of unique bearer tokens can't grow it without bound (each token is a distinct key).
_verify_cache: dict[str, tuple[bool, float]] = {}
# Identity fields parsed from a successful verify response: token -> (identity, expires_at).
# Only positive verdicts get an entry; entries are written and evicted together with the
# matching _verify_cache entry.
_identity_cache: dict[str, tuple[dict[str, Any], float]] = {}
# Single lock guarding BOTH _verify_cache and _identity_cache.
_verify_cache_lock = threading.Lock()
# Size cap for _verify_cache, enforced by _prune_verify_cache_locked before each insert.
_VERIFY_CACHE_MAX = 1024
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _prune_verify_cache_locked(now: float) -> None:
    """Evict expired cache entries, then make room for one more verdict.

    The caller must already hold ``_verify_cache_lock``. After expired entries are
    removed from both caches, a verdict cache that is still at (or above)
    ``_VERIFY_CACHE_MAX`` loses its soonest-to-expire entries until exactly one free
    slot remains below the cap. Identity entries are dropped with their verdicts.
    """
    expired_verdicts = [tok for tok, entry in _verify_cache.items() if entry[1] <= now]
    for tok in expired_verdicts:
        del _verify_cache[tok]
        _identity_cache.pop(tok, None)

    expired_identities = [tok for tok, entry in _identity_cache.items() if entry[1] <= now]
    for tok in expired_identities:
        del _identity_cache[tok]

    # Number of entries that must go so one insert still fits under the cap.
    surplus = len(_verify_cache) - _VERIFY_CACHE_MAX + 1
    if surplus <= 0:
        return

    # Stable sort by expiry time: the entries closest to expiring are evicted first.
    by_expiry = sorted(_verify_cache, key=lambda tok: _verify_cache[tok][1])
    for tok in by_expiry[:surplus]:
        del _verify_cache[tok]
        _identity_cache.pop(tok, None)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _freesolo_key_prefix(token: str) -> str:
    """Return a non-secret preview of a Freesolo API key.

    For a key shaped like ``fslo_<id>_...`` the preview is ``fslo_<id>``, matching the
    dashboard-style public prefix. Any other token gets ``fslo_`` followed by the first
    12 hex characters of its hash, so the raw token is never echoed.
    """
    scheme, separator, remainder = token.partition("_")
    if separator and scheme == "fslo":
        public_id = remainder.partition("_")[0]
        if public_id:
            return f"fslo_{public_id}"
    return f"fslo_{db.hash_key(token)[:12]}"
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _external_key_prefix(token: str, identity: dict[str, Any]) -> str:
    """Choose the display prefix for a verified external key.

    Preference order: a backend-reported ``key_prefix`` that starts with ``fslo_``; the
    literal ``"freesolo"`` for an ``fslo-user-`` token with no identity fields at all;
    otherwise the prefix derived from the token itself.
    """
    reported = _str_field(identity.get("key_prefix"))
    if reported is not None and reported.startswith("fslo_"):
        return reported
    identity_is_empty = not identity
    if identity_is_empty and token.startswith("fslo-user-"):
        return "freesolo"
    return _freesolo_key_prefix(token)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _str_field(value: Any) -> str | None:
|
|
95
|
+
if isinstance(value, str) and value.strip():
|
|
96
|
+
return value.strip()
|
|
97
|
+
return None
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _identity_from_verify_body(raw: bytes) -> dict[str, Any]:
    """Pull the optional identity fields out of a freesolo verify response body.

    The backend historically answered with just ``{"ok": true}``, so parsing is lenient:
    an empty body, invalid JSON, or a non-object payload all yield ``{}``. Fields may
    appear at the top level or nested under ``user`` / ``org`` / ``api_key``; only fields
    that resolve to a non-blank string are included in the result.
    """
    if not raw:
        return {}
    try:
        body = json.loads(raw)
    except (TypeError, ValueError):
        return {}
    if not isinstance(body, dict):
        return {}

    def nested(name: str) -> dict:
        # A nested section is only trusted when it is a JSON object.
        section = body.get(name)
        return section if isinstance(section, dict) else {}

    user = nested("user")
    org = nested("org")
    api_key = nested("api_key")

    email = _str_field(body.get("email")) or _str_field(user.get("email"))
    user_id = (
        _str_field(body.get("user_id"))
        or _str_field(body.get("created_by"))
        or _str_field(user.get("id"))
    )
    org_id = _str_field(body.get("org_id")) or _str_field(org.get("id"))
    api_key_id = (
        _str_field(body.get("api_key_id"))
        or _str_field(body.get("key_id"))
        or _str_field(api_key.get("id"))
    )
    key_prefix = _str_field(body.get("key_prefix")) or _str_field(api_key.get("key_prefix"))

    candidates = {
        "email": email,
        "user_id": user_id,
        "org_id": org_id,
        "api_key_id": api_key_id,
        "key_prefix": key_prefix,
        "training_agent_job_id": _str_field(body.get("training_agent_job_id")),
        "project_id": _str_field(body.get("project_id")),
    }
    # Keep only the fields the backend actually supplied.
    return {name: value for name, value in candidates.items() if value}
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _response_body(resp: Any) -> bytes:
    """Read a response-like object's body as bytes.

    Returns ``b""`` when ``resp`` has no callable ``read`` or when ``read()`` yields
    something other than ``bytes`` or ``str``. A ``str`` result is encoded with the
    default encoding.
    """
    reader = getattr(resp, "read", None)
    if not callable(reader):
        return b""
    payload = reader()
    if isinstance(payload, str):
        return payload.encode()
    return payload if isinstance(payload, bytes) else b""
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _cached_identity(token: str) -> dict[str, Any]:
    """Return a copy of the cached identity for ``token``.

    Yields ``{}`` when no identity is cached or the cached entry has expired. The
    result is a shallow copy, so callers cannot mutate the cache entry's top level.
    """
    now = time.time()
    with _verify_cache_lock:
        entry = _identity_cache.get(token)
        if entry is None:
            return {}
        identity, expires_at = entry
        if expires_at <= now:
            return {}
        return dict(identity)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def _external_row(row: dict, token: str, identity: dict[str, Any]) -> dict:
    """Build the caller-facing row for a verified freesolo user key.

    Copies ``row``, tags it with ``auth_kind = "freesolo_api_key"`` and a display
    ``key_prefix``, then overlays every identity field that has a truthy value.
    ``row`` itself is not modified.
    """
    out = dict(row)
    out.update(
        auth_kind="freesolo_api_key",
        key_prefix=_external_key_prefix(token, identity),
    )
    overlay_keys = (
        "email",
        "user_id",
        "org_id",
        "api_key_id",
        "training_agent_job_id",
        "project_id",
    )
    out.update({key: identity[key] for key in overlay_keys if identity.get(key)})
    return out
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _identity_email(identity: dict[str, Any]) -> str:
    """Return the identity's trimmed email, or ``""`` if it lacks an ``@``.

    A missing or falsy ``email`` value also yields ``""``.
    """
    raw = identity.get("email")
    candidate = str(raw).strip() if raw else ""
    if "@" not in candidate:
        return ""
    return candidate
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def freesolo_base_url() -> str:
    """Return the freesolo backend base URL without trailing slashes.

    Uses the ``FREESOLO_BASE_URL`` environment variable when it is set and non-empty,
    otherwise the built-in default. Shared by auth verify and the billing client.
    """
    configured = os.environ.get(FREESOLO_BASE_URL_ENV)
    base = configured if configured else DEFAULT_FREESOLO_BASE_URL
    return base.rstrip("/")
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _freesolo_verify(token: str) -> bool:
    """Verify a token against the freesolo backend (cached, short TTL, network errors = False).

    Returns ``True`` only when the verify endpoint answers 200. Definitive verdicts are
    cached: positives for ``_VERIFY_CACHE_TTL_S``, negatives for the shorter
    ``_VERIFY_CACHE_NEG_TTL_S``. On a positive verdict the identity fields parsed from
    the response body are cached alongside it.

    Never raises — a swallowed network/HTTP error is treated as "not authenticated"
    (returns False), never a 500. Transient failures (network errors, 5xx, 429, a
    malformed base URL, a truncated or garbled HTTP response) return False WITHOUT
    being cached.
    """
    # Function-scope import so this block stays self-contained; urllib.request already
    # loads http.client, so this costs nothing extra.
    import http.client

    # Reject obviously-invalid oversized tokens before they touch the cache or the network.
    if not token or len(token) > _MAX_TOKEN_LEN:
        return False
    now = time.time()
    with _verify_cache_lock:
        cached = _verify_cache.get(token)
        if cached is not None and cached[1] > now:
            return cached[0]
    identity: dict[str, Any] = {}
    try:
        # Building the Request lives INSIDE the try: urllib raises ValueError from the
        # constructor for a malformed URL (e.g. a FREESOLO_BASE_URL with no scheme), and
        # that must not escape as an unhandled exception.
        url = f"{freesolo_base_url()}/api/auth/verify"
        req = urllib.request.Request(url, headers={"Authorization": f"Bearer {token}"})
        with urllib.request.urlopen(req, timeout=_VERIFY_TIMEOUT_S) as resp:
            verified = resp.status == 200
            if verified:
                identity = _identity_from_verify_body(_response_body(resp))
    except urllib.error.HTTPError as exc:
        # Only a DEFINITIVE rejection (4xx other than 429) is a verdict worth caching as a bad
        # key. A 5xx or 429 is a transient backend hiccup — treat it like a network error
        # (return False WITHOUT caching) so a valid key isn't locked out for the whole TTL
        # while the backend is briefly unhealthy.
        if exc.code >= 500 or exc.code == 429:
            return False
        verified = False
    except (urllib.error.URLError, http.client.HTTPException, OSError, ValueError):
        # A TRANSIENT network/connection error is NOT a verdict: don't cache it, so a valid
        # key isn't locked out for the whole TTL after the backend recovers.
        # http.client.HTTPException covers protocol-level failures (bad status line,
        # incomplete read) that urllib does not wrap in URLError.
        return False
    with _verify_cache_lock:
        # Prune expired entries and cap the size before inserting so unbounded distinct
        # tokens can't grow the cache.
        _prune_verify_cache_locked(now)
        # Pick the TTL by verdict: positives last the full TTL; a negative (which may be a
        # transient backend 401 rather than a real rejection) expires quickly so a valid key
        # isn't locked out for 5 minutes after the backend recovers.
        ttl = _VERIFY_CACHE_TTL_S if verified else _VERIFY_CACHE_NEG_TTL_S
        _verify_cache[token] = (verified, now + ttl)
        if verified:
            _identity_cache[token] = (identity, now + ttl)
        else:
            _identity_cache.pop(token, None)
    return verified
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def authenticate(authorization: str | None) -> dict | None:
|
|
233
|
+
"""Resolve an ``Authorization: Bearer ...`` header to a key row.
|
|
234
|
+
|
|
235
|
+
Freesolo keys are the only user auth. When the operator has configured
|
|
236
|
+
``FREESOLO_INTERNAL_KEY``, that shared internal key resolves to a single service
|
|
237
|
+
identity. Any other token is verified against the freesolo backend and (on success)
|
|
238
|
+
resolved to a per-token user identity so a user who ``flash login``s with their freesolo
|
|
239
|
+
key can drive the control plane. A token that can't be verified (bad key, or the backend
|
|
240
|
+
is unreachable) is treated as unverified -> authenticate returns None."""
|
|
241
|
+
if not authorization or not authorization.startswith("Bearer "):
|
|
242
|
+
return None
|
|
243
|
+
token = authorization.removeprefix("Bearer ").strip()
|
|
244
|
+
internal = os.environ.get(INTERNAL_KEY_ENV)
|
|
245
|
+
if internal and token == internal:
|
|
246
|
+
row = db.lookup_key(token) or db.ensure_internal_key(token)
|
|
247
|
+
out = dict(row)
|
|
248
|
+
out["auth_kind"] = "internal"
|
|
249
|
+
return out
|
|
250
|
+
# Any non-internal token is a freesolo USER key: verify it against the freesolo backend.
|
|
251
|
+
if _freesolo_verify(token):
|
|
252
|
+
# A verified freesolo key gets its own per-token run-ownership identity.
|
|
253
|
+
identity = _cached_identity(token)
|
|
254
|
+
email = _identity_email(identity)
|
|
255
|
+
if not email:
|
|
256
|
+
return None
|
|
257
|
+
row = db.lookup_key(token) or db.ensure_external_key(
|
|
258
|
+
token,
|
|
259
|
+
key_prefix=_external_key_prefix(token, identity),
|
|
260
|
+
email=email,
|
|
261
|
+
)
|
|
262
|
+
return _external_row(row, token, identity) if row is not None else None
|
|
263
|
+
return None
|
flash/server/billing.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""Charge completed Flash training runs to the freesolo backend.
|
|
2
|
+
|
|
3
|
+
The control plane records the submitting org id when a run is accepted, then POSTs the final
|
|
4
|
+
``RunStatus.cost_usd`` after the run reaches ``done``. The call is authenticated with the
|
|
5
|
+
operator internal key, so Flash never persists a user's freesolo API key while training runs.
|
|
6
|
+
Tests stub the network boundary directly."""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import urllib.error
|
|
12
|
+
import urllib.request
|
|
13
|
+
from decimal import ROUND_HALF_UP, Decimal
|
|
14
|
+
|
|
15
|
+
from .auth import freesolo_base_url
|
|
16
|
+
|
|
17
|
+
# A charge does more than a verify; allow a bit more than auth's 5s but stay bounded.
# Passed as the urlopen timeout (seconds) for each billing POST.
_CHARGE_TIMEOUT_S = 10.0
# Backend route, appended to the freesolo base URL, that records a completed run's charge.
_COMPLETION_CHARGE_PATH = "/api/billing/training-usage/internal"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class BillingError(Exception):
    """Raised when charging a run did not succeed.

    ``status_code`` is the HTTP-style status to surface to the client (for example 402
    for insufficient balance or 503 when the backend is unreachable) and ``detail`` is
    the human-readable message, which is also the exception's string form.
    """

    def __init__(self, status_code: int, detail: str) -> None:
        self.status_code = status_code
        self.detail = detail
        super().__init__(detail)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _cents(usd: float) -> int:
    """Convert a USD amount to whole cents, rounding half-cent ties UP.

    Python's built-in rounding is banker's rounding, which would undercharge a
    half-cent tie, so the conversion goes through ``Decimal`` with ``ROUND_HALF_UP``.
    Negative amounts are clamped to 0.
    """
    amount = Decimal(str(usd))
    # Shift the decimal point two places (dollars -> cents) without losing digits.
    shifted = amount.scaleb(2)
    whole = int(shifted.quantize(Decimal("1"), rounding=ROUND_HALF_UP))
    return whole if whole > 0 else 0
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _http_reason(exc: urllib.error.HTTPError) -> str:
    """Return the error's reason phrase (``.reason``, then ``.msg``), else its status code.

    The first truthy attribute wins and is returned stripped; when neither is set the
    numeric code is returned as a string.
    """
    for attribute in ("reason", "msg"):
        phrase = getattr(exc, attribute, None)
        if phrase:
            return str(phrase).strip()
    return str(exc.code)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _http_error_detail(exc: urllib.error.HTTPError) -> str:
    """Extract a clean message from a backend error response.

    The backend's error body looks like ``{"detail": {"error", "code"}}``. The message
    is the ``error`` (else ``code``) of a dict ``detail``, or a non-empty string
    ``detail`` as-is. Anything else, including an unreadable or non-JSON body, falls
    back to ``billing failed (<code> <reason>)``.
    """
    try:
        payload = json.loads(exc.read() or b"{}")
    except (ValueError, OSError):
        payload = None

    detail = payload.get("detail") if isinstance(payload, dict) else None
    message = None
    if isinstance(detail, dict):
        message = detail.get("error") or detail.get("code")
    elif isinstance(detail, str):
        message = detail

    if message:
        return str(message)
    # Nothing usable in the body: report the status code and reason phrase instead.
    return f"billing failed ({exc.code} {_http_reason(exc)})"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _post_billing(*, token: str, path: str, body: dict) -> dict:
    """POST ``body`` as JSON to the backend billing ``path`` and return the parsed reply.

    The request carries ``token`` as a Bearer credential. Failures are translated to
    ``BillingError``: the backend's own status plus a clean detail for an HTTP error
    response, 503 when the service cannot be reached, and 502 when it answers with a
    body that is not JSON. An empty response body parses as ``{}``.
    """
    endpoint = f"{freesolo_base_url()}{path}"
    payload = json.dumps(body).encode("utf-8")
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    req = urllib.request.Request(
        endpoint, data=payload, method="POST", headers=request_headers
    )
    try:
        with urllib.request.urlopen(req, timeout=_CHARGE_TIMEOUT_S) as resp:
            raw = resp.read()
    except urllib.error.HTTPError as exc:
        # HTTPError must be handled before URLError/OSError, which it subclasses.
        raise BillingError(exc.code, _http_error_detail(exc)) from exc
    except (urllib.error.URLError, OSError) as exc:
        raise BillingError(503, f"billing service unavailable: {exc}") from exc

    try:
        parsed = json.loads(raw or b"{}")
    except ValueError as exc:
        # The backend responded but the body isn't JSON -- a bad gateway, not an outage.
        raise BillingError(502, f"billing service returned an invalid response: {exc}") from exc
    return parsed
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def charge_completed_run(*, internal_key: str, status) -> dict:
    """Charge one completed external run from its persisted, non-secret billing context.

    Reads the org id from ``status.billing_context`` and the final cost from
    ``status.cost_usd``, then POSTs the charge authenticated with ``internal_key``.
    The backend route is idempotent by ``runId``.

    Raises ``BillingError`` (400) when the run has no billing org id, and whatever
    ``BillingError`` the POST produces on a non-2xx or unreachable backend; callers
    should record the billing failure without changing the run's terminal training
    state.
    """
    billing_context = status.billing_context
    if not isinstance(billing_context, dict):
        billing_context = {}
    org_id = str(billing_context.get("org_id") or "").strip()
    if not org_id:
        raise BillingError(400, "missing billing org id for completed training run")

    spec = status.spec or {}
    remote = status.remote or {}

    # Prefer the GPU actually allocated; fall back to the one the spec requested.
    gpu = remote.get("allocated_gpu")
    if not gpu:
        gpu = (spec.get("gpu") or {}).get("type")

    total_usd = float(status.cost_usd or 0.0)
    estimate = {
        "totalUsd": total_usd,
        "costBasis": "final",
        "costSource": "run_status.cost_usd",
    }
    body = {
        "orgId": org_id,
        "runId": status.run_id,
        "costCents": _cents(total_usd),
        "gpu": gpu,
        "provider": remote.get("provider"),
        "method": spec.get("algorithm"),
        "model": spec.get("model"),
        "estimate": estimate,
    }
    return _post_billing(token=internal_key, path=_COMPLETION_CHARGE_PATH, body=body)
|
|
flash/server/checkpoints.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""Mirror a run's deployable RL checkpoints to the freesolo backend.
|
|
2
|
+
|
|
3
|
+
The worker streams each step's LoRA adapter to the run's HF repo; HF is the source of truth
|
|
4
|
+
for what's deployable. This module persists that list to the backend's ``run_checkpoints``
|
|
5
|
+
store so the dashboard/SDK can enumerate a run's checkpoints without crawling HF, and so a
|
|
6
|
+
cancelled run's checkpoints survive in one queryable place.
|
|
7
|
+
|
|
8
|
+
Like ``flash.server.billing``, the POST is authenticated with the operator INTERNAL key (the
|
|
9
|
+
control plane never persists a user's freesolo key) and carries the org id from the run's
|
|
10
|
+
non-secret billing context. Unlike billing, checkpoint persistence is STRICTLY best-effort:
|
|
11
|
+
a failure here must never disturb a run or a deploy, so the public entry point swallows
|
|
12
|
+
everything."""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import os
|
|
18
|
+
import urllib.error
|
|
19
|
+
import urllib.request
|
|
20
|
+
|
|
21
|
+
from flash.runner.checkpoints import list_checkpoints
|
|
22
|
+
from flash.spec import JobSpec
|
|
23
|
+
|
|
24
|
+
from .auth import INTERNAL_KEY_ENV, freesolo_base_url
|
|
25
|
+
|
|
26
|
+
# Timeout (seconds) passed to urlopen for the checkpoint-registration POST.
_TIMEOUT_S = 10.0
# Backend route, appended to the freesolo base URL, that stores a run's checkpoints.
_RECORD_PATH = "/api/runs/internal/checkpoints"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _post_checkpoints(*, token: str, body: dict) -> dict:
    """POST the checkpoint batch to the backend and return its parsed JSON reply.

    ``urllib`` errors for a non-2xx or unreachable backend propagate to the caller;
    callers in this module always wrap this in a best-effort guard, and the raise
    keeps the one network boundary easy for tests to stub/assert. A reply that is
    empty or not valid JSON yields ``{}``.
    """
    endpoint = f"{freesolo_base_url()}{_RECORD_PATH}"
    payload = json.dumps(body).encode("utf-8")
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    req = urllib.request.Request(
        endpoint, data=payload, method="POST", headers=request_headers
    )
    with urllib.request.urlopen(req, timeout=_TIMEOUT_S) as resp:
        raw = resp.read()
    try:
        parsed = json.loads(raw or b"{}")
    except ValueError:
        # A non-JSON success body carries no information the caller needs.
        return {}
    return parsed
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def register_run_checkpoints(*, internal_key: str, status, checkpoints: list[dict]) -> dict:
    """Upsert ``checkpoints`` for one run into the backend store (idempotent by run_id+step).

    The org id comes from the run's persisted billing context (same source as billing).
    The repo id and type are taken from the first checkpoint, with the type defaulting
    to ``"dataset"``.

    Raises ``ValueError`` when ``checkpoints`` is empty or the run has no org id, and
    lets ``urllib`` errors through on a backend failure;
    ``register_checkpoints_best_effort`` is the guarded wrapper most callers use.
    """
    if not checkpoints:
        raise ValueError("no checkpoints to record")

    billing_context = status.billing_context
    if not isinstance(billing_context, dict):
        billing_context = {}
    org_id = str(billing_context.get("org_id") or "").strip()
    if not org_id:
        raise ValueError("missing org id for run checkpoints")

    spec = status.spec or {}
    head = checkpoints[0]
    entries = []
    for ckpt in checkpoints:
        entries.append({"step": ckpt["step"], "subfolder": ckpt["subfolder"]})

    body = {
        "orgId": org_id,
        "runId": status.run_id,
        "baseModel": spec.get("model"),
        "repoId": head.get("repo_id"),
        "repoType": head.get("repo_type", "dataset"),
        "checkpoints": entries,
    }
    return _post_checkpoints(token=internal_key, body=body)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def register_checkpoints_best_effort(status, *, log=None) -> int:
    """List ``status``'s deployable checkpoints from HF and mirror them to the backend.

    Returns the number of checkpoints submitted (0 if none, or if persistence was skipped /
    failed). Never raises: the HF copy remains the source of truth, so a persistence miss only
    costs the convenience of a DB-backed listing — not correctness.

    ``log`` is an optional writable text stream for progress messages; when omitted,
    messages go to standard output.
    """

    def _log(msg: str) -> None:
        # Flush only when writing to a caller-supplied stream.
        if log is not None:
            print(msg, file=log, flush=True)
        else:
            print(msg)

    internal_key = os.environ.get(INTERNAL_KEY_ENV, "").strip()
    if not internal_key:
        return 0  # local/dev control plane: HF still has the checkpoints
    try:
        spec = JobSpec.from_dict(status.spec)
    except Exception as exc:
        _log(f"[ckpt] register skipped ({status.run_id}): bad spec: {exc}")
        return 0
    # Listing the checkpoints can fail too, so it is guarded like every other step.
    # This entry point is a deliberate best-effort boundary, hence the broad catches.
    try:
        checkpoints = list_checkpoints(spec)
    except Exception as exc:
        _log(f"[ckpt] list warn ({status.run_id}): {exc}")
        return 0
    if not checkpoints:
        return 0
    try:
        register_run_checkpoints(
            internal_key=internal_key, status=status, checkpoints=checkpoints
        )
    except Exception as exc:
        # Broad on purpose: besides ValueError/urllib/OS errors, a malformed checkpoint
        # entry (KeyError/TypeError) or an http.client protocol error must not escape.
        _log(f"[ckpt] backend register warn ({status.run_id}): {exc}")
        return 0
    _log(f"[ckpt] registered {len(checkpoints)} checkpoint(s) for {status.run_id}")
    return len(checkpoints)
|
flash/server/db.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""SQLite store for the managed control plane: API keys + run ownership.
|
|
2
|
+
|
|
3
|
+
Run *state* stays in the runner's JSON files (``runner.RUNS_DIR``) — the
|
|
4
|
+
battle-tested submit/attach/cancel paths all read those. This database is only the
|
|
5
|
+
key registry and the run -> key ownership index that makes the server multi-tenant.
|
|
6
|
+
|
|
7
|
+
Connections are opened per operation (cheap for SQLite, avoids cross-thread state;
|
|
8
|
+
the runner runs jobs in daemon threads inside the same process).
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import hashlib
|
|
14
|
+
import sqlite3
|
|
15
|
+
import time
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
# Schema applied (idempotently, via IF NOT EXISTS) each time a connection is opened.
#   api_keys: one row per known key, identified by key_hash (UNIQUE); key_prefix is the
#             display prefix stored for the key.
#   runs:     run ownership index; key_id references api_keys(id).
_SCHEMA = """
CREATE TABLE IF NOT EXISTS api_keys (
id INTEGER PRIMARY KEY AUTOINCREMENT,
key_hash TEXT NOT NULL UNIQUE,
key_prefix TEXT NOT NULL,
email TEXT,
created_at REAL NOT NULL,
last_used_at REAL,
disabled INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS runs (
run_id TEXT PRIMARY KEY,
key_id INTEGER NOT NULL REFERENCES api_keys(id),
kind TEXT NOT NULL DEFAULT 'train',
created_at REAL NOT NULL
);
CREATE INDEX IF NOT EXISTS runs_key_idx ON runs(key_id);
"""


# Fixed location for the keys/run-ownership SQLite DB (not operator-configurable). Tests
# point it elsewhere with monkeypatch.setattr(db, "DB_PATH", tmp).
DB_PATH = str(Path.home() / ".flash" / "server.db")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def db_path() -> str:
    """Return the path of the SQLite database file.

    The module-level ``DB_PATH`` is read at call time, so rebinding that attribute
    (as tests do) changes where the next connection is opened.
    """
    return DB_PATH
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _connect() -> sqlite3.Connection:
    """Open a new connection to the control-plane database, ready for use.

    Creates the parent directory if needed, then configures the connection: rows are
    returned as ``sqlite3.Row``, foreign keys are enforced, WAL journaling is
    requested, and the schema is applied (every statement is ``IF NOT EXISTS``).

    The caller owns the returned connection and must close it. If configuration
    fails, the half-initialised connection is closed before the error propagates so
    the file handle is not leaked.
    """
    path = db_path()
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(path, timeout=30.0)
    try:
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA foreign_keys=ON")
        conn.execute("PRAGMA journal_mode=WAL")
        conn.executescript(_SCHEMA)
    except BaseException:
        # Cleanup-and-reraise: never hand back, or leak, a connection that failed setup.
        conn.close()
        raise
    return conn
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def hash_key(api_key: str) -> str:
    """Return the hex SHA-256 digest of ``api_key`` (the at-rest form of a key)."""
    # API keys are 192-bit random tokens (secrets.token_hex(24)), not passwords:
    # brute-forcing the keyspace is infeasible, so an unsalted fast hash is the
    # standard at-rest form and keeps O(1) lookup by hash. (CodeQL's
    # password-hashing rule does not apply to high-entropy machine tokens.)
    hasher = hashlib.sha256()
    hasher.update(api_key.encode())
    return hasher.hexdigest()
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def ensure_internal_key(api_key: str) -> dict:
    """Provision a row for the shared freesolo internal/service key (idempotent).

    The freesolo platform/SDK authenticate to the control plane with the same
    ``FREESOLO_INTERNAL_KEY`` they already hold. Backing it with a real row
    (inserted once, by hash) means run ownership and the runs.key_id foreign key
    work exactly as for a normal key — all internal-key runs share this single
    service identity (no per-user isolation; the platform scopes users upstream).

    Args:
        api_key: The internal key as presented. Only its sha256 is stored.

    Returns:
        The key's ``api_keys`` row as a dict, as resolved by ``lookup_key``.

    Raises:
        RuntimeError: if ``lookup_key`` cannot resolve the row after
            provisioning. ``lookup_key`` ignores disabled rows, so this also
            fires when the existing row is disabled.
    """
    now = time.time()
    internal_email = "internal@freesolo.co"
    key_hash = hash_key(api_key)
    conn = _connect()
    try:
        # ``with conn`` commits (or rolls back on error) but does not close the
        # connection, so close it explicitly in ``finally``.
        with conn:
            conn.execute(
                "INSERT OR IGNORE INTO api_keys (key_hash, key_prefix, email, created_at) "
                "VALUES (?, ?, ?, ?)",
                (key_hash, "internal", internal_email, now),
            )
            # INSERT OR IGNORE leaves an already-existing row untouched, so the
            # e-mail is (re)applied separately to the internal row.
            conn.execute(
                "UPDATE api_keys SET email = ? WHERE key_hash = ? AND key_prefix = ?",
                (internal_email, key_hash, "internal"),
            )
    finally:
        conn.close()
    # Resolve only after the write above is committed: lookup_key opens its own
    # connection and would not see an uncommitted row.
    row = lookup_key(api_key)
    if row is None:  # pragma: no cover - the row was just inserted
        raise RuntimeError("failed to provision the internal service key")
    return row
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def ensure_external_key(
    api_key: str, *, key_prefix: str | None = None, email: str | None = None
) -> dict | None:
    """Provision a per-token row for a verified external (freesolo USER) key (idempotent).

    Unlike :func:`ensure_internal_key` (one shared service identity), this keys a distinct
    row by the presented token's hash, so each freesolo user key gets its OWN run-ownership
    identity (the runs.key_id foreign key then scopes runs per user). The full token is never
    stored — only its sha256.

    Args:
        api_key: The presented token.
        key_prefix: Label stored with a newly inserted row; ``None`` or an empty
            string falls back to ``"freesolo"``.
        email: Optional e-mail stored with a newly inserted row.

    Returns:
        The token's ``api_keys`` row as a dict, or ``None`` (not a row) when the
        token's row already exists but is DISABLED: ``INSERT OR IGNORE`` won't revive
        it and ``lookup_key`` filters disabled rows, so a revoked key is rejected (401)
        by the caller instead of surfacing as a 500.
    """
    now = time.time()
    conn = _connect()
    try:
        # ``with conn`` commits (or rolls back on error) but does not close the
        # connection, so close it explicitly in ``finally``.
        with conn:
            conn.execute(
                "INSERT OR IGNORE INTO api_keys (key_hash, key_prefix, email, created_at) "
                "VALUES (?, ?, ?, ?)",
                (hash_key(api_key), key_prefix or "freesolo", email, now),
            )
    finally:
        conn.close()
    # lookup_key uses its own connection, so call it only once the insert is committed.
    return lookup_key(api_key)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def lookup_key(api_key: str) -> dict | None:
    """Resolve a presented key to its row (and touch last_used_at); None if unknown/disabled.

    Args:
        api_key: The key as presented; it is matched by its ``hash_key`` digest.

    Returns:
        The matching enabled ``api_keys`` row as a dict, or ``None`` when no row
        has that hash or the row is disabled. The dict is a snapshot taken by the
        SELECT, so its ``last_used_at`` is the value from *before* this call's update.
    """
    conn = _connect()
    try:
        # ``with conn`` commits (or rolls back on error) but does not close the
        # connection, so close it explicitly in ``finally``.
        with conn:
            row = conn.execute(
                "SELECT * FROM api_keys WHERE key_hash = ? AND disabled = 0",
                (hash_key(api_key),),
            ).fetchone()
            if row is None:
                return None
            conn.execute(
                "UPDATE api_keys SET last_used_at = ? WHERE id = ?",
                (time.time(), row["id"]),
            )
            # Copy into a plain dict while the connection is still open.
            found = dict(row)
    finally:
        conn.close()
    return found
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def record_run(run_id: str, key_id: int) -> None:
    """Record that the run ``run_id`` is owned by the API key row ``key_id``.

    The row is stored with kind ``"train"`` and the current time as ``created_at``.

    Args:
        run_id: Primary key of the new ``runs`` row.
        key_id: ``api_keys.id`` of the owning key.

    Raises:
        sqlite3.IntegrityError: if ``run_id`` is already recorded, or if
            ``key_id`` does not reference an existing ``api_keys`` row
            (foreign keys are enforced by ``_connect``).
    """
    conn = _connect()
    try:
        # ``with conn`` commits (or rolls back on error) but does not close the
        # connection, so close it explicitly in ``finally``.
        with conn:
            conn.execute(
                "INSERT INTO runs (run_id, key_id, kind, created_at) VALUES (?, ?, ?, ?)",
                (run_id, key_id, "train", time.time()),
            )
    finally:
        conn.close()
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def delete_run(run_id: str) -> None:
    """Remove the ownership record for ``run_id``.

    Deleting a ``run_id`` that is not recorded is a no-op (the DELETE simply
    matches no rows).
    """
    conn = _connect()
    try:
        # ``with conn`` commits (or rolls back on error) but does not close the
        # connection, so close it explicitly in ``finally``.
        with conn:
            conn.execute("DELETE FROM runs WHERE run_id = ?", (run_id,))
    finally:
        conn.close()
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def run_owner(run_id: str) -> int | None:
    """Return the ``api_keys.id`` that owns ``run_id``, or ``None`` if the run is not recorded."""
    conn = _connect()
    try:
        row = conn.execute("SELECT key_id FROM runs WHERE run_id = ?", (run_id,)).fetchone()
        if row is None:
            return None
        return row["key_id"]
    finally:
        # Read-only query: nothing to commit, but the connection still has to be
        # closed explicitly (the connection context manager never closes it).
        conn.close()
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def runs_for_key(key_id: int) -> list[dict]:
    """List the runs owned by the API key row ``key_id``.

    Returns:
        One dict per run with the keys ``run_id``, ``kind`` and ``created_at``,
        ordered by ``created_at`` ascending; an empty list if the key owns no runs.
    """
    conn = _connect()
    try:
        rows = conn.execute(
            "SELECT run_id, kind, created_at FROM runs WHERE key_id = ? ORDER BY created_at",
            (key_id,),
        ).fetchall()
        return [dict(r) for r in rows]
    finally:
        # Read-only query: nothing to commit, but the connection still has to be
        # closed explicitly (the connection context manager never closes it).
        conn.close()
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def all_runs() -> list[dict]:
    """List every recorded run, regardless of owner.

    Returns:
        One dict per run with the keys ``run_id``, ``key_id``, ``kind`` and
        ``created_at``. The query has no ORDER BY, so callers should not rely
        on the order of the result.
    """
    conn = _connect()
    try:
        rows = conn.execute("SELECT run_id, key_id, kind, created_at FROM runs").fetchall()
        return [dict(r) for r in rows]
    finally:
        # Read-only query: nothing to commit, but the connection still has to be
        # closed explicitly (the connection context manager never closes it).
        conn.close()
|