freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
flash/server/auth.py ADDED
@@ -0,0 +1,263 @@
1
+ """Bearer auth for the managed control plane.
2
+
3
+ User authentication is freesolo API keys only — there is no native key system. A bearer
4
+ token equal to the operator's shared ``FREESOLO_INTERNAL_KEY`` resolves to the service
5
+ identity; any other token is verified against the freesolo backend and (on success)
6
+ resolved to a per-token user identity. A failed/unreachable verify returns False (the key
7
+ is treated as unverified), so a backend outage never admits an unverified key.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ import os
14
+ import threading
15
+ import time
16
+ import urllib.error
17
+ import urllib.request
18
+ from typing import Any
19
+
20
+ from . import db
21
+
22
+ # Operators set this to the shared freesolo internal key; a bearer token equal to it
23
+ # authenticates as the service identity (see db.ensure_internal_key).
24
+ INTERNAL_KEY_ENV = "FREESOLO_INTERNAL_KEY"
25
+
26
+ # Freesolo USER-key acceptance: a user who `flash login`s with a freesolo API key sends it as
27
+ # the bearer to this control plane. Any non-internal token is verified against the freesolo
28
+ # backend and (on success) resolved to a per-token identity.
29
+ FREESOLO_BASE_URL_ENV = "FREESOLO_BASE_URL"
30
+ DEFAULT_FREESOLO_BASE_URL = "https://api.freesolo.co"
31
+ _VERIFY_TIMEOUT_S = 5.0
32
+ _VERIFY_CACHE_TTL_S = 300.0 # short TTL so it isn't a backend round-trip per request
33
+ # Negative verdicts get a much SHORTER TTL than positives. The freesolo verify endpoint
34
+ # returns 401 not only for a genuinely-bad key but also when the backend converts an
35
+ # auth-LOOKUP infra exception (authenticate_api_key failure) into a 401 — a transient outage.
36
+ # Caching such a negative for the full 300s would lock out an otherwise-valid key for 5
37
+ # minutes after the backend recovers. A short negative TTL keeps persistent bad tokens
38
+ # rate-limited (~30s, so they don't hammer the backend) while letting a transient 401 clear
39
+ # quickly. Positives keep the long TTL.
40
+ _VERIFY_CACHE_NEG_TTL_S = 30.0
41
+ # Upper bound on a bearer token we'll cache/verify. Real freesolo API keys are short; an
42
+ # arbitrarily long bearer is rejected up front so it can't bloat _verify_cache (keyed by the
43
+ # raw token) or produce an oversized outbound Authorization header.
44
+ _MAX_TOKEN_LEN = 256
45
+
46
+ # In-process verify cache: token -> (verified_bool, expires_at). Caches positives AND
47
+ # negatives so a burst of requests for the same token hits the backend at most once per TTL.
48
+ # Bounded: pruned of expired entries on every write and capped at _VERIFY_CACHE_MAX so a
49
+ # stream of unique bearer tokens can't grow it without bound (each token is a distinct key).
50
+ _verify_cache: dict[str, tuple[bool, float]] = {}
51
+ _identity_cache: dict[str, tuple[dict[str, Any], float]] = {}
52
+ _verify_cache_lock = threading.Lock()
53
+ _VERIFY_CACHE_MAX = 1024
54
+
55
+
56
def _prune_verify_cache_locked(now: float) -> None:
    """Evict expired cache entries, then enforce the size cap on the verify cache.

    Caller must already hold ``_verify_cache_lock``. Every token removed from
    ``_verify_cache`` is removed from ``_identity_cache`` as well, so the two stay in step.

    Args:
        now: Current wall-clock time; entries whose expiry is ``<= now`` are dropped.
    """
    expired = [token for token, entry in _verify_cache.items() if entry[1] <= now]
    for token in expired:
        _verify_cache.pop(token)
        _identity_cache.pop(token, None)

    stale = [token for token, entry in _identity_cache.items() if entry[1] <= now]
    for token in stale:
        _identity_cache.pop(token)

    if len(_verify_cache) < _VERIFY_CACHE_MAX:
        return

    # At or over the cap even after expiry pruning: drop the soonest-to-expire entries,
    # leaving room for the one insertion the caller is about to make.
    overflow = len(_verify_cache) - _VERIFY_CACHE_MAX + 1
    by_expiry = sorted(_verify_cache, key=lambda token: _verify_cache[token][1])
    for token in by_expiry[:overflow]:
        del _verify_cache[token]
        _identity_cache.pop(token, None)
75
+
76
+
77
+ def _freesolo_key_prefix(token: str) -> str:
78
+ """Non-secret preview of a Freesolo API key, matching the dashboard-style public prefix."""
79
+ parts = token.split("_", 2)
80
+ if len(parts) >= 2 and parts[0] == "fslo" and parts[1]:
81
+ return f"fslo_{parts[1]}"
82
+ return f"fslo_{db.hash_key(token)[:12]}"
83
+
84
+
85
def _external_key_prefix(token: str, identity: dict[str, Any]) -> str:
    """Choose the display prefix for a verified external key.

    Order of preference: an ``fslo_``-prefixed ``key_prefix`` from ``identity``; the
    literal ``"freesolo"`` when ``identity`` is empty and the token starts with
    ``fslo-user-``; otherwise whatever ``_freesolo_key_prefix`` derives from the token.
    """
    reported = _str_field(identity.get("key_prefix"))
    if reported is not None and reported.startswith("fslo_"):
        return reported
    dash_style_token = token.startswith("fslo-user-")
    if dash_style_token and not identity:
        return "freesolo"
    return _freesolo_key_prefix(token)
92
+
93
+
94
+ def _str_field(value: Any) -> str | None:
95
+ if isinstance(value, str) and value.strip():
96
+ return value.strip()
97
+ return None
98
+
99
+
100
+ def _identity_from_verify_body(raw: bytes) -> dict[str, Any]:
101
+ """Extract optional identity fields from the freesolo verify response.
102
+
103
+ The backend historically returned only ``{"ok": true}``. This parser is deliberately
104
+ tolerant so Flash surfaces real fields when the backend includes them.
105
+ """
106
+ if not raw:
107
+ return {}
108
+ try:
109
+ body = json.loads(raw)
110
+ except (TypeError, ValueError):
111
+ return {}
112
+ if not isinstance(body, dict):
113
+ return {}
114
+
115
+ user = body.get("user") if isinstance(body.get("user"), dict) else {}
116
+ org = body.get("org") if isinstance(body.get("org"), dict) else {}
117
+ api_key = body.get("api_key") if isinstance(body.get("api_key"), dict) else {}
118
+
119
+ fields = {
120
+ "email": _str_field(body.get("email")) or _str_field(user.get("email")),
121
+ "user_id": (
122
+ _str_field(body.get("user_id"))
123
+ or _str_field(body.get("created_by"))
124
+ or _str_field(user.get("id"))
125
+ ),
126
+ "org_id": _str_field(body.get("org_id")) or _str_field(org.get("id")),
127
+ "api_key_id": (
128
+ _str_field(body.get("api_key_id"))
129
+ or _str_field(body.get("key_id"))
130
+ or _str_field(api_key.get("id"))
131
+ ),
132
+ "key_prefix": _str_field(body.get("key_prefix")) or _str_field(api_key.get("key_prefix")),
133
+ "training_agent_job_id": _str_field(body.get("training_agent_job_id")),
134
+ "project_id": _str_field(body.get("project_id")),
135
+ }
136
+ return {k: v for k, v in fields.items() if v}
137
+
138
+
139
+ def _response_body(resp: Any) -> bytes:
140
+ read = getattr(resp, "read", None)
141
+ if not callable(read):
142
+ return b""
143
+ data = read()
144
+ if isinstance(data, bytes):
145
+ return data
146
+ if isinstance(data, str):
147
+ return data.encode()
148
+ return b""
149
+
150
+
151
def _cached_identity(token: str) -> dict[str, Any]:
    """Return a copy of the unexpired cached identity for ``token``, else ``{}``.

    The lookup happens under ``_verify_cache_lock``. A copy is returned so callers cannot
    mutate the cached dict.
    """
    moment = time.time()
    with _verify_cache_lock:
        entry = _identity_cache.get(token)
        if entry is None:
            return {}
        identity, expires_at = entry
        return dict(identity) if expires_at > moment else {}
158
+
159
+
160
def _external_row(row: dict, token: str, identity: dict[str, Any]) -> dict:
    """Build the auth result for a verified freesolo user key.

    Copies ``row``, sets ``auth_kind`` and a display ``key_prefix``, then overlays each
    truthy identity field (email, user/org/API-key ids, training-agent job id, project id).
    ``row`` itself is not modified.
    """
    enriched = dict(row)
    enriched.update(
        auth_kind="freesolo_api_key",
        key_prefix=_external_key_prefix(token, identity),
    )
    overlay_fields = (
        "email",
        "user_id",
        "org_id",
        "api_key_id",
        "training_agent_job_id",
        "project_id",
    )
    for field in overlay_fields:
        value = identity.get(field)
        if value:
            enriched[field] = value
    return enriched
170
+
171
+
172
+ def _identity_email(identity: dict[str, Any]) -> str:
173
+ email = str(identity.get("email") or "").strip()
174
+ return email if "@" in email else ""
175
+
176
+
177
def freesolo_base_url() -> str:
    """Return the freesolo backend base URL with trailing slashes removed.

    Reads the environment variable named by ``FREESOLO_BASE_URL_ENV``; an unset or empty
    value falls back to ``DEFAULT_FREESOLO_BASE_URL``.
    """
    configured = os.environ.get(FREESOLO_BASE_URL_ENV)
    base = configured if configured else DEFAULT_FREESOLO_BASE_URL
    return base.rstrip("/")
181
+
182
+
183
+ def _freesolo_verify(token: str) -> bool:
184
+ """Verify a token against the freesolo backend (cached, short TTL, network errors = False).
185
+
186
+ Never raises — a swallowed network/HTTP error is treated as "not authenticated" (returns
187
+ False), never a 500."""
188
+ # Reject obviously-invalid oversized tokens before they touch the cache or the network.
189
+ if not token or len(token) > _MAX_TOKEN_LEN:
190
+ return False
191
+ now = time.time()
192
+ with _verify_cache_lock:
193
+ cached = _verify_cache.get(token)
194
+ if cached is not None and cached[1] > now:
195
+ return cached[0]
196
+ url = f"{freesolo_base_url()}/api/auth/verify"
197
+ req = urllib.request.Request(url, headers={"Authorization": f"Bearer {token}"})
198
+ identity: dict[str, Any] = {}
199
+ try:
200
+ with urllib.request.urlopen(req, timeout=_VERIFY_TIMEOUT_S) as resp:
201
+ verified = resp.status == 200
202
+ if verified:
203
+ identity = _identity_from_verify_body(_response_body(resp))
204
+ except urllib.error.HTTPError as exc:
205
+ # Only a DEFINITIVE rejection (4xx other than 429) is a verdict worth caching as a bad
206
+ # key. A 5xx or 429 is a transient backend hiccup — treat it like a network error
207
+ # (return False WITHOUT caching) so a valid key isn't locked out for the whole TTL
208
+ # while the backend is briefly unhealthy.
209
+ if exc.code >= 500 or exc.code == 429:
210
+ return False
211
+ verified = False
212
+ except (urllib.error.URLError, OSError, ValueError):
213
+ # A TRANSIENT network/connection error is NOT a verdict: don't cache it, so a valid
214
+ # key isn't locked out for the whole TTL after the backend recovers.
215
+ return False
216
+ with _verify_cache_lock:
217
+ # Prune expired entries and cap the size before inserting so unbounded distinct
218
+ # tokens can't grow the cache.
219
+ _prune_verify_cache_locked(now)
220
+ # Pick the TTL by verdict: positives last the full TTL; a negative (which may be a
221
+ # transient backend 401 rather than a real rejection) expires quickly so a valid key
222
+ # isn't locked out for 5 minutes after the backend recovers.
223
+ ttl = _VERIFY_CACHE_TTL_S if verified else _VERIFY_CACHE_NEG_TTL_S
224
+ _verify_cache[token] = (verified, now + ttl)
225
+ if verified:
226
+ _identity_cache[token] = (identity, now + ttl)
227
+ else:
228
+ _identity_cache.pop(token, None)
229
+ return verified
230
+
231
+
232
+ def authenticate(authorization: str | None) -> dict | None:
233
+ """Resolve an ``Authorization: Bearer ...`` header to a key row.
234
+
235
+ Freesolo keys are the only user auth. When the operator has configured
236
+ ``FREESOLO_INTERNAL_KEY``, that shared internal key resolves to a single service
237
+ identity. Any other token is verified against the freesolo backend and (on success)
238
+ resolved to a per-token user identity so a user who ``flash login``s with their freesolo
239
+ key can drive the control plane. A token that can't be verified (bad key, or the backend
240
+ is unreachable) is treated as unverified -> authenticate returns None."""
241
+ if not authorization or not authorization.startswith("Bearer "):
242
+ return None
243
+ token = authorization.removeprefix("Bearer ").strip()
244
+ internal = os.environ.get(INTERNAL_KEY_ENV)
245
+ if internal and token == internal:
246
+ row = db.lookup_key(token) or db.ensure_internal_key(token)
247
+ out = dict(row)
248
+ out["auth_kind"] = "internal"
249
+ return out
250
+ # Any non-internal token is a freesolo USER key: verify it against the freesolo backend.
251
+ if _freesolo_verify(token):
252
+ # A verified freesolo key gets its own per-token run-ownership identity.
253
+ identity = _cached_identity(token)
254
+ email = _identity_email(identity)
255
+ if not email:
256
+ return None
257
+ row = db.lookup_key(token) or db.ensure_external_key(
258
+ token,
259
+ key_prefix=_external_key_prefix(token, identity),
260
+ email=email,
261
+ )
262
+ return _external_row(row, token, identity) if row is not None else None
263
+ return None
flash/server/billing.py ADDED
@@ -0,0 +1,124 @@
1
+ """Charge completed Flash training runs to the freesolo backend.
2
+
3
+ The control plane records the submitting org id when a run is accepted, then POSTs the final
4
+ ``RunStatus.cost_usd`` after the run reaches ``done``. The call is authenticated with the
5
+ operator internal key, so Flash never persists a user's freesolo API key while training runs.
6
+ Tests stub the network boundary directly."""
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import urllib.error
12
+ import urllib.request
13
+ from decimal import ROUND_HALF_UP, Decimal
14
+
15
+ from .auth import freesolo_base_url
16
+
17
+ # A charge does more than a verify; allow a bit more than auth's 5s but stay bounded.
18
+ _CHARGE_TIMEOUT_S = 10.0
19
+ _COMPLETION_CHARGE_PATH = "/api/billing/training-usage/internal"
20
+
21
+
22
class BillingError(Exception):
    """Raised when a run charge does not succeed.

    Carries the HTTP-style ``status_code`` to report and a human-readable ``detail``,
    which is also the exception message.
    """

    def __init__(self, status_code: int, detail: str) -> None:
        self.status_code = status_code
        self.detail = detail
        super().__init__(detail)
30
+
31
+
32
+ def _cents(usd: float) -> int:
33
+ """Whole cents for a USD amount, round-HALF-UP (not Python's banker's rounding, which would
34
+ undercharge a half-cent tie), never negative."""
35
+ cents = Decimal(str(usd)).scaleb(2).quantize(Decimal("1"), rounding=ROUND_HALF_UP)
36
+ return max(0, int(cents))
37
+
38
+
39
+ def _http_reason(exc: urllib.error.HTTPError) -> str:
40
+ """The backend's status reason phrase (``.reason``/``.msg``), else the bare code."""
41
+ reason = getattr(exc, "reason", None) or getattr(exc, "msg", None)
42
+ return str(reason).strip() if reason else str(exc.code)
43
+
44
+
45
+ def _http_error_detail(exc: urllib.error.HTTPError) -> str:
46
+ """Clean message from the backend's ``{"detail": {"error", "code"}}`` JSON error body, else
47
+ the status reason (never a bare code)."""
48
+
49
+ def _fallback() -> str:
50
+ return f"billing failed ({exc.code} {_http_reason(exc)})"
51
+
52
+ try:
53
+ body = json.loads(exc.read() or b"{}")
54
+ except (ValueError, OSError):
55
+ return _fallback()
56
+ detail = body.get("detail") if isinstance(body, dict) else None
57
+ if isinstance(detail, dict):
58
+ return str(detail.get("error") or detail.get("code") or _fallback())
59
+ if isinstance(detail, str) and detail:
60
+ return detail
61
+ return _fallback()
62
+
63
+
64
def _post_billing(*, token: str, path: str, body: dict) -> dict:
    """POST ``body`` as JSON to the backend billing ``path`` and return the parsed reply.

    Args:
        token: Bearer token sent in the ``Authorization`` header.
        path: Path appended to ``freesolo_base_url()``.
        body: JSON-serialisable request payload.

    Raises:
        BillingError: with the backend's status code on an HTTP error response; ``503``
            when the backend cannot be reached; ``502`` when the reply is not valid JSON.
    """
    endpoint = f"{freesolo_base_url()}{path}"
    payload = json.dumps(body).encode("utf-8")
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    request = urllib.request.Request(
        endpoint, data=payload, method="POST", headers=request_headers
    )
    try:
        with urllib.request.urlopen(request, timeout=_CHARGE_TIMEOUT_S) as response:
            raw = response.read()
    except urllib.error.HTTPError as exc:
        raise BillingError(exc.code, _http_error_detail(exc)) from exc
    except (urllib.error.URLError, OSError) as exc:
        raise BillingError(503, f"billing service unavailable: {exc}") from exc

    try:
        parsed = json.loads(raw or b"{}")
    except ValueError as exc:
        # The backend did answer, just not with JSON: report a bad gateway, not an outage.
        raise BillingError(502, f"billing service returned an invalid response: {exc}") from exc
    return parsed
92
+
93
+
94
def charge_completed_run(*, internal_key: str, status) -> dict:
    """Charge one completed run to the backend using its stored billing context.

    Builds the charge payload from ``status`` (org id, run id, final ``cost_usd``, GPU,
    provider, algorithm, model) and posts it to ``_COMPLETION_CHARGE_PATH`` with
    ``internal_key`` as the bearer token.

    Args:
        internal_key: Operator internal key used to authenticate the POST.
        status: Run status object; this function reads ``billing_context``, ``spec``,
            ``remote``, ``cost_usd`` and ``run_id`` from it.

    Raises:
        BillingError: ``400`` when the billing context has no org id; otherwise whatever
            ``_post_billing`` raises.
    """
    raw_context = status.billing_context
    context = raw_context if isinstance(raw_context, dict) else {}
    org_id = str(context.get("org_id") or "").strip()
    if not org_id:
        raise BillingError(400, "missing billing org id for completed training run")

    spec = status.spec or {}
    remote = status.remote or {}

    # Prefer the GPU recorded under `remote`; fall back to the type named in the spec.
    gpu = remote.get("allocated_gpu")
    if not gpu:
        gpu = (spec.get("gpu") or {}).get("type")

    total_usd = float(status.cost_usd or 0.0)
    estimate = {
        "totalUsd": total_usd,
        "costBasis": "final",
        "costSource": "run_status.cost_usd",
    }
    charge = {
        "orgId": org_id,
        "runId": status.run_id,
        "costCents": _cents(total_usd),
        "gpu": gpu,
        "provider": remote.get("provider"),
        "method": spec.get("algorithm"),
        "model": spec.get("model"),
        "estimate": estimate,
    }
    return _post_billing(token=internal_key, path=_COMPLETION_CHARGE_PATH, body=charge)
flash/server/checkpoints.py ADDED
@@ -0,0 +1,110 @@
1
+ """Mirror a run's deployable RL checkpoints to the freesolo backend.
2
+
3
+ The worker streams each step's LoRA adapter to the run's HF repo; HF is the source of truth
4
+ for what's deployable. This module persists that list to the backend's ``run_checkpoints``
5
+ store so the dashboard/SDK can enumerate a run's checkpoints without crawling HF, and so a
6
+ cancelled run's checkpoints survive in one queryable place.
7
+
8
+ Like ``flash.server.billing``, the POST is authenticated with the operator INTERNAL key (the
9
+ control plane never persists a user's freesolo key) and carries the org id from the run's
10
+ non-secret billing context. Unlike billing, checkpoint persistence is STRICTLY best-effort:
11
+ a failure here must never disturb a run or a deploy, so the public entry point swallows
12
+ everything."""
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import os
18
+ import urllib.error
19
+ import urllib.request
20
+
21
+ from flash.runner.checkpoints import list_checkpoints
22
+ from flash.spec import JobSpec
23
+
24
+ from .auth import INTERNAL_KEY_ENV, freesolo_base_url
25
+
26
+ _TIMEOUT_S = 10.0
27
+ _RECORD_PATH = "/api/runs/internal/checkpoints"
28
+
29
+
30
def _post_checkpoints(*, token: str, body: dict) -> dict:
    """POST the checkpoint batch to the backend and return its parsed JSON reply.

    ``urllib`` errors (non-2xx response, unreachable backend) propagate to the caller. A
    reply that is not valid JSON yields ``{}``.

    Args:
        token: Bearer token sent in the ``Authorization`` header.
        body: JSON-serialisable payload.
    """
    endpoint = f"{freesolo_base_url()}{_RECORD_PATH}"
    payload = json.dumps(body).encode("utf-8")
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    request = urllib.request.Request(
        endpoint, data=payload, method="POST", headers=request_headers
    )
    with urllib.request.urlopen(request, timeout=_TIMEOUT_S) as response:
        raw = response.read()
    try:
        parsed = json.loads(raw or b"{}")
    except ValueError:
        parsed = {}
    return parsed
51
+
52
+
53
def register_run_checkpoints(*, internal_key: str, status, checkpoints: list[dict]) -> dict:
    """Send ``checkpoints`` for one run to the backend checkpoint store.

    The org id comes from ``status.billing_context``. ``repoId`` and ``repoType`` are taken
    from the first checkpoint (``repoType`` defaults to ``"dataset"``). Each checkpoint
    contributes only its ``step`` and ``subfolder``.

    Args:
        internal_key: Operator internal key used to authenticate the POST.
        status: Run status object; ``billing_context``, ``spec`` and ``run_id`` are read.
        checkpoints: Checkpoint dicts, each with ``step`` and ``subfolder`` keys.

    Raises:
        ValueError: when ``checkpoints`` is empty or the billing context has no org id.
        urllib.error.URLError: (and other ``urllib`` errors) on a backend failure.
    """
    if not checkpoints:
        raise ValueError("no checkpoints to record")

    raw_context = status.billing_context
    context = raw_context if isinstance(raw_context, dict) else {}
    org_id = str(context.get("org_id") or "").strip()
    if not org_id:
        raise ValueError("missing org id for run checkpoints")

    spec = status.spec or {}
    head = checkpoints[0]
    entries = []
    for checkpoint in checkpoints:
        entries.append({"step": checkpoint["step"], "subfolder": checkpoint["subfolder"]})
    payload = {
        "orgId": org_id,
        "runId": status.run_id,
        "baseModel": spec.get("model"),
        "repoId": head.get("repo_id"),
        "repoType": head.get("repo_type", "dataset"),
        "checkpoints": entries,
    }
    return _post_checkpoints(token=internal_key, body=payload)
79
+
80
+
81
def register_checkpoints_best_effort(status, *, log=None) -> int:
    """List a run's deployable checkpoints and mirror them to the backend, best-effort.

    This is the guarded entry point: any failure while parsing the spec, listing
    checkpoints, or posting to the backend is logged and turned into a return value of 0.

    Args:
        status: Run status object; ``spec`` and ``run_id`` are read here, and the object is
            passed on to ``register_run_checkpoints``.
        log: Optional file-like object for log lines (flushed after each line). When
            ``None``, lines go to standard output.

    Returns:
        The number of checkpoints submitted, or 0 when there were none or when
        registration was skipped or failed.
    """

    def _log(msg: str) -> None:
        if log is None:
            print(msg)
        else:
            print(msg, file=log, flush=True)

    internal_key = os.environ.get(INTERNAL_KEY_ENV, "").strip()
    if not internal_key:
        return 0  # no internal key configured: nothing to authenticate the POST with
    try:
        spec = JobSpec.from_dict(status.spec)
    except Exception as exc:
        _log(f"[ckpt] register skipped ({status.run_id}): bad spec: {exc}")
        return 0
    # Listing was previously unguarded, so any error it raised escaped this "never
    # raises" entry point. Treat a listing failure like any other miss.
    try:
        checkpoints = list_checkpoints(spec)
    except Exception as exc:
        _log(f"[ckpt] list warn ({status.run_id}): {exc}")
        return 0
    if not checkpoints:
        return 0
    try:
        register_run_checkpoints(
            internal_key=internal_key, status=status, checkpoints=checkpoints
        )
    except Exception as exc:
        # Deliberately broad at this best-effort boundary: a malformed checkpoint dict
        # (KeyError/TypeError) or an http.client.HTTPException must not escape, as they
        # did when only ValueError and the urllib/OSError family were caught.
        _log(f"[ckpt] backend register warn ({status.run_id}): {exc}")
        return 0
    _log(f"[ckpt] registered {len(checkpoints)} checkpoint(s) for {status.run_id}")
    return len(checkpoints)
flash/server/db.py ADDED
@@ -0,0 +1,160 @@
1
+ """SQLite store for the managed control plane: API keys + run ownership.
2
+
3
+ Run *state* stays in the runner's JSON files (``runner.RUNS_DIR``) — the
4
+ battle-tested submit/attach/cancel paths all read those. This database is only the
5
+ key registry and the run -> key ownership index that makes the server multi-tenant.
6
+
7
+ Connections are opened per operation (cheap for SQLite, avoids cross-thread state;
8
+ the runner runs jobs in daemon threads inside the same process).
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import hashlib
14
+ import sqlite3
15
+ import time
16
+ from pathlib import Path
17
+
18
+ _SCHEMA = """
19
+ CREATE TABLE IF NOT EXISTS api_keys (
20
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
21
+ key_hash TEXT NOT NULL UNIQUE,
22
+ key_prefix TEXT NOT NULL,
23
+ email TEXT,
24
+ created_at REAL NOT NULL,
25
+ last_used_at REAL,
26
+ disabled INTEGER NOT NULL DEFAULT 0
27
+ );
28
+ CREATE TABLE IF NOT EXISTS runs (
29
+ run_id TEXT PRIMARY KEY,
30
+ key_id INTEGER NOT NULL REFERENCES api_keys(id),
31
+ kind TEXT NOT NULL DEFAULT 'train',
32
+ created_at REAL NOT NULL
33
+ );
34
+ CREATE INDEX IF NOT EXISTS runs_key_idx ON runs(key_id);
35
+ """
36
+
37
+
38
+ # Fixed location for the keys/run-ownership SQLite DB (not operator-configurable). Tests
39
+ # point it elsewhere with monkeypatch.setattr(db, "DB_PATH", tmp).
40
+ DB_PATH = str(Path.home() / ".flash" / "server.db")
41
+
42
+
43
def db_path() -> str:
    """Return the SQLite database path, read from ``DB_PATH`` at call time."""
    location: str = DB_PATH
    return location
45
+
46
+
47
+ def _connect() -> sqlite3.Connection:
48
+ path = db_path()
49
+ Path(path).parent.mkdir(parents=True, exist_ok=True)
50
+ conn = sqlite3.connect(path, timeout=30.0)
51
+ conn.row_factory = sqlite3.Row
52
+ conn.execute("PRAGMA foreign_keys=ON")
53
+ conn.execute("PRAGMA journal_mode=WAL")
54
+ conn.executescript(_SCHEMA)
55
+ return conn
56
+
57
+
58
def hash_key(api_key: str) -> str:
    """Return the hex SHA-256 digest of ``api_key``, the form used for lookup by hash.

    The hash is unsalted and fast on purpose, which keeps lookup by hash O(1).
    NOTE(review): the original rationale assumes presented keys are high-entropy machine
    tokens rather than passwords; confirm that holds for every key type accepted.
    """
    digest = hashlib.sha256()
    digest.update(api_key.encode())
    return digest.hexdigest()
64
+
65
+
66
def ensure_internal_key(api_key: str) -> dict:
    """Provision the row for the shared internal/service key and return it (idempotent).

    Inserts a row keyed by the key's hash with prefix ``"internal"`` if none exists, sets
    that row's email to the fixed service address, then resolves it via ``lookup_key``.

    Raises:
        RuntimeError: if the row cannot be resolved after provisioning.
    """
    now = time.time()
    internal_email = "internal@freesolo.co"
    with _connect() as conn:
        digest = hash_key(api_key)
        conn.execute(
            "INSERT OR IGNORE INTO api_keys (key_hash, key_prefix, email, created_at) "
            "VALUES (?, ?, ?, ?)",
            (digest, "internal", internal_email, now),
        )
        # Also refresh the email on a row inserted by an earlier call.
        conn.execute(
            "UPDATE api_keys SET email = ? WHERE key_hash = ? AND key_prefix = ?",
            (internal_email, digest, "internal"),
        )
    provisioned = lookup_key(api_key)
    if provisioned is None:  # pragma: no cover - the row was just inserted
        raise RuntimeError("failed to provision the internal service key")
    return provisioned
91
+
92
+
93
def ensure_external_key(
    api_key: str, *, key_prefix: str | None = None, email: str | None = None
) -> dict | None:
    """Provision a per-token row for a verified external key and return it (idempotent).

    Only the token's hash is stored. An existing row is left untouched by the
    ``INSERT OR IGNORE``.

    Args:
        api_key: The presented token.
        key_prefix: Display prefix to store; ``"freesolo"`` when falsy.
        email: Email to store with a newly inserted row.

    Returns:
        The row as resolved by ``lookup_key``. That is ``None`` when the row exists but is
        disabled, because ``lookup_key`` filters disabled rows.
    """
    created_at = time.time()
    stored_prefix = key_prefix or "freesolo"
    with _connect() as conn:
        conn.execute(
            "INSERT OR IGNORE INTO api_keys (key_hash, key_prefix, email, created_at) "
            "VALUES (?, ?, ?, ?)",
            (hash_key(api_key), stored_prefix, email, created_at),
        )
    return lookup_key(api_key)
114
+
115
+
116
def lookup_key(api_key: str) -> dict | None:
    """Resolve a presented key to its enabled row and record the access time.

    Returns:
        The row as a dict, captured before ``last_used_at`` is updated, or ``None`` when
        no enabled row matches the key's hash.
    """
    with _connect() as conn:
        cursor = conn.execute(
            "SELECT * FROM api_keys WHERE key_hash = ? AND disabled = 0",
            (hash_key(api_key),),
        )
        match = cursor.fetchone()
        if match is not None:
            touched_at = time.time()
            conn.execute(
                "UPDATE api_keys SET last_used_at = ? WHERE id = ?", (touched_at, match["id"])
            )
            return dict(match)
        return None
127
+
128
+
129
def record_run(run_id: str, key_id: int) -> None:
    """Insert an ownership row linking ``run_id`` to ``key_id``, with kind ``"train"``."""
    ownership = (run_id, key_id, "train", time.time())
    with _connect() as conn:
        conn.execute(
            "INSERT INTO runs (run_id, key_id, kind, created_at) VALUES (?, ?, ?, ?)",
            ownership,
        )
135
+
136
+
137
def delete_run(run_id: str) -> None:
    """Delete the ownership row for ``run_id``; a no-op if there is none."""
    params = (run_id,)
    with _connect() as conn:
        conn.execute("DELETE FROM runs WHERE run_id = ?", params)
140
+
141
+
142
def run_owner(run_id: str) -> int | None:
    """Return the ``key_id`` that owns ``run_id``, or ``None`` if the run is not recorded."""
    with _connect() as conn:
        cursor = conn.execute("SELECT key_id FROM runs WHERE run_id = ?", (run_id,))
        match = cursor.fetchone()
    if match is None:
        return None
    return match["key_id"]
146
+
147
+
148
def runs_for_key(key_id: int) -> list[dict]:
    """Return the runs owned by ``key_id`` as dicts, ordered by ``created_at``.

    Each dict has ``run_id``, ``kind`` and ``created_at``.
    """
    with _connect() as conn:
        cursor = conn.execute(
            "SELECT run_id, kind, created_at FROM runs WHERE key_id = ? ORDER BY created_at",
            (key_id,),
        )
        fetched = cursor.fetchall()
    owned = []
    for record in fetched:
        owned.append(dict(record))
    return owned
155
+
156
+
157
def all_runs() -> list[dict]:
    """Return every recorded run as a dict of ``run_id``, ``key_id``, ``kind``, ``created_at``."""
    with _connect() as conn:
        cursor = conn.execute("SELECT run_id, key_id, kind, created_at FROM runs")
        fetched = cursor.fetchall()
    return list(map(dict, fetched))