freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,883 @@
1
+ """Adapter that runs Freesolo SDK environments on Flash.
2
+
3
+ Flash environment ids are Freesolo Hub slugs (``namespace/name``). Explicit
4
+ low-level refs remain parseable for compatibility. The canonical generated environment file is
5
+ ``environment.py`` and its
6
+ ``load_environment`` function must return a Freesolo SDK environment:
7
+ ``EnvironmentSingleTurn`` or ``EnvironmentMultiTurn``.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import hashlib
13
+ import io
14
+ import json
15
+ import os
16
+ import re
17
+ import shutil
18
+ import tarfile
19
+ import tempfile
20
+ import urllib.error
21
+ import urllib.parse
22
+ import urllib.request
23
+ from dataclasses import dataclass
24
+ from pathlib import Path
25
+ from typing import Any
26
+
27
+ from flash.envs.base import BaseEnvironment
28
+
29
# Ref used when a GitHub environment reference does not spell out "@<ref>".
_DEFAULT_GITHUB_REF = "main"
# Entrypoint file used when a reference names no path (or names a directory).
_DEFAULT_ENVIRONMENT_PATH = "environment.py"
# Repository that managed ``namespace/name`` slugs are expanded into.
_DEFAULT_MANAGED_ENV_REPO = "freesolo-co/environment-hub"
# Root of the on-disk cache of extracted archives; FLASH_ENV_CACHE_DIR overrides the default.
_CACHE_ROOT = Path(os.environ.get("FLASH_ENV_CACHE_DIR", "/tmp/flash-env-cache"))
# 256 MiB ceiling, applied to both the downloaded archive and its uncompressed total.
_MAX_ARCHIVE_BYTES = 256 * 1024 * 1024
# Ceiling on how many tar members extraction will iterate over.
_MAX_ARCHIVE_MEMBERS = 5000
# A full 40-hex-digit commit sha (case-insensitive).
_COMMIT_SHA_RE = re.compile(r"^[0-9a-f]{40}$", re.IGNORECASE)
# Characters accepted in an owner, repo, ref, or slug component.
_GITHUB_SAFE_PART_RE = re.compile(r"^[A-Za-z0-9._-]+$")
# Tar header types that extraction skips instead of writing to disk.
_TAR_METADATA_TYPES = {
    tarfile.XHDTYPE,
    tarfile.XGLTYPE,
    tarfile.GNUTYPE_LONGNAME,
    tarfile.GNUTYPE_LONGLINK,
}
# Record keys that FreesoloEnvironment._canonical_record reads and keeps.
_CANONICAL_INPUT_KEY = "input"
_CANONICAL_OUTPUT_KEY = "output"
45
+
46
+
47
class GitHubRateLimitError(RuntimeError):
    """Signal that GitHub kept rate-limiting us after the in-process retries ran out.

    ``_urlopen`` raises this for an HTTP 429, or for a 403 whose body mentions "rate limit", but
    only once its jittered retry budget is spent — so seeing it means the limit is *persistent*.
    The worker's top-level handler catches it and stamps ``retriable=True``, which lets the
    control plane reschedule the job on a fresh worker after the limit window resets rather than
    failing the run for good. The real mitigation for a spawn wave is the control-plane
    resolve-once pin (EnvironmentSpec.resolved_sha), which lets workers skip the GitHub resolve
    altogether.
    """
56
+
57
+
58
@dataclass(frozen=True)
class GitHubEnvironmentRef:
    """Immutable pointer to an environment file inside a GitHub repository."""

    owner: str
    repo: str
    ref: str
    path: str

    @property
    def repo_full_name(self) -> str:
        """Return the repository as ``owner/repo``."""
        return "{}/{}".format(self.owner, self.repo)

    def canonical(self) -> str:
        """Return the ``github:owner/repo@ref:path`` spelling of this reference."""
        return "github:{}@{}:{}".format(self.repo_full_name, self.ref, self.path)
71
+
72
+
73
def is_github_environment_ref(value: str) -> bool:
    """Report whether ``value`` parses as a GitHub environment reference."""
    parsed = _parse_github_environment_ref(value)
    return parsed is not None
75
+
76
+
77
def is_managed_environment_slug(value: str) -> bool:
    """Report whether ``value`` parses as a managed ``namespace/name`` slug."""
    slug = _parse_managed_environment_slug(value)
    return slug is not None
79
+
80
+
81
def is_freesolo_environment_id(value: str) -> bool:
    """Report whether ``value`` is a managed slug or a GitHub environment reference."""
    if is_managed_environment_slug(value):
        return True
    return is_github_environment_ref(value)
83
+
84
+
85
def managed_slug_to_github_ref(value: str) -> str:
    """Expand a managed ``namespace/name`` slug into its ``github:`` reference.

    Raises:
        ValueError: if ``value`` is not a managed environment slug.
    """
    slug = _parse_managed_environment_slug(value)
    if slug is None:
        raise ValueError(f"not a Freesolo environment slug: {value!r}")
    namespace, name = slug
    location = f"{namespace}/{name}/{_DEFAULT_ENVIRONMENT_PATH}"
    return f"github:{_DEFAULT_MANAGED_ENV_REPO}@{_DEFAULT_GITHUB_REF}:{location}"
94
+
95
+
96
+ def _parse_managed_environment_slug(value: str) -> tuple[str, str] | None:
97
+ text = (value or "").strip()
98
+ if not text or ":" in text:
99
+ return None
100
+ parsed = urllib.parse.urlparse(text)
101
+ if parsed.scheme or parsed.netloc:
102
+ return None
103
+ parts = text.split("/")
104
+ if len(parts) != 2 or not _is_safe_github_path_parts(tuple(parts)):
105
+ return None
106
+ return parts[0], parts[1]
107
+
108
+
109
+ def _parse_github_environment_ref(value: str) -> GitHubEnvironmentRef | None:
110
+ text = (value or "").strip()
111
+ if not text:
112
+ return None
113
+ if text.startswith("github:"):
114
+ body = text[len("github:") :]
115
+ repo_ref, sep, path = body.partition(":")
116
+ try:
117
+ path = _normalize_env_path(path)
118
+ except ValueError:
119
+ return None
120
+ if not sep:
121
+ path = _DEFAULT_ENVIRONMENT_PATH
122
+ repo_part, at, ref = repo_ref.partition("@")
123
+ if not at:
124
+ ref = _DEFAULT_GITHUB_REF
125
+ if not ref:
126
+ return None
127
+ if not _is_safe_github_path_parts((ref,)):
128
+ return None
129
+ owner_repo = repo_part.split("/")
130
+ if len(owner_repo) == 2 and _is_safe_github_path_parts(owner_repo):
131
+ return GitHubEnvironmentRef(owner_repo[0], owner_repo[1], ref, path)
132
+ return None
133
+
134
+ parsed = urllib.parse.urlparse(text)
135
+ if parsed.scheme not in {"http", "https"} or parsed.netloc.lower() != "github.com":
136
+ return None
137
+ parts = [urllib.parse.unquote(p) for p in parsed.path.strip("/").split("/") if p]
138
+ if len(parts) < 2:
139
+ return None
140
+ owner, repo = parts[0], parts[1]
141
+ repo = repo[:-4] if repo.endswith(".git") else repo
142
+ if not _is_safe_github_path_parts((owner, repo)):
143
+ return None
144
+ if len(parts) >= 5 and parts[2] in {"blob", "tree"}:
145
+ ref = parts[3]
146
+ if not _is_safe_github_path_parts((ref,)):
147
+ return None
148
+ raw_path = "/".join(parts[4:])
149
+ if not _is_safe_environment_path(raw_path):
150
+ return None
151
+ try:
152
+ path = _normalize_env_path(raw_path)
153
+ except ValueError:
154
+ return None
155
+ if not raw_path:
156
+ path = _DEFAULT_ENVIRONMENT_PATH
157
+ if parts[2] == "tree" and raw_path and not path.endswith(".py"):
158
+ path = f"{path.rstrip('/')}/{_DEFAULT_ENVIRONMENT_PATH}"
159
+ elif len(parts) == 2:
160
+ ref = _DEFAULT_GITHUB_REF
161
+ path = _DEFAULT_ENVIRONMENT_PATH
162
+ else:
163
+ return None
164
+ return GitHubEnvironmentRef(owner, repo, ref, path)
165
+
166
+
167
+ def _normalize_env_path(path: str | None) -> str:
168
+ if not path:
169
+ return _DEFAULT_ENVIRONMENT_PATH
170
+ raw = path.strip()
171
+ if not raw:
172
+ return _DEFAULT_ENVIRONMENT_PATH
173
+ raw = raw.replace("\\", "/")
174
+ if raw.startswith("/"):
175
+ raise ValueError(f"unsafe environment path: {path!r}")
176
+ if not raw:
177
+ return _DEFAULT_ENVIRONMENT_PATH
178
+ parts = [part for part in raw.split("/") if part]
179
+ if not parts:
180
+ return _DEFAULT_ENVIRONMENT_PATH
181
+ if any(part == ".." or part == "." for part in parts):
182
+ raise ValueError(f"unsafe environment path: {path!r}")
183
+ return "/".join(parts)
184
+
185
+
186
+ def _is_safe_environment_path(path: str) -> bool:
187
+ if not path:
188
+ return True
189
+ raw = path.strip().replace("\\", "/")
190
+ if raw.startswith("/"):
191
+ return False
192
+ parts = [part for part in raw.split("/") if part]
193
+ if not parts:
194
+ return True
195
+ return not any(part in {".", ".."} for part in parts)
196
+
197
+
198
+ def _is_safe_github_path_parts(parts: list[str] | tuple[str, ...]) -> bool:
199
+ if not parts:
200
+ return False
201
+ if any(part in {".", "..", ""} for part in parts):
202
+ return False
203
+ return all(_GITHUB_SAFE_PART_RE.fullmatch(part) for part in parts)
204
+
205
+
206
+ def _github_token() -> str | None:
207
+ return os.environ.get("GITHUB_TOKEN")
208
+
209
+
210
def _is_commit_sha(value: str) -> bool:
    """Report whether ``value`` is a full 40-hex-digit commit sha."""
    return bool(_COMMIT_SHA_RE.fullmatch(value))
212
+
213
+
214
def _resolve_ref_sha(
    parsed: GitHubEnvironmentRef,
    pinned_sha: str | None = None,
    *,
    timeout: float = 60.0,
    max_rate_limit_retries: int = 5,
) -> str:
    """Resolve ``parsed.ref`` to a 40-character commit sha.

    Args:
        parsed: the reference to resolve.
        pinned_sha: a sha resolved earlier; returned as-is when it is a full commit sha.
        timeout: per-request timeout forwarded to ``_urlopen``.
        max_rate_limit_retries: rate-limit retry budget forwarded to ``_urlopen``.

    Returns:
        ``pinned_sha`` or ``parsed.ref`` when either is already a full sha (no network),
        otherwise the ``sha`` field of GitHub's commits API response.

    Raises:
        RuntimeError: if the response is not a JSON object carrying a valid ``sha``.
        GitHubRateLimitError: propagated from ``_urlopen`` on a persistent rate limit.
    """
    # Resolve-once hook: the control plane resolves ref->sha ONCE (runner._assign_resolved_env_sha)
    # and threads the pinned commit sha through, so every worker in a fan-out short-circuits here
    # without hitting GitHub at all — this is what actually defuses a cold spawn wave (the prior
    # in-process cache could not, since each worker is a separate process). Same effect as the
    # immutable-ref fast path below, but applied to a symbolic ref (e.g. "main"). Only a real
    # 40-char sha is trusted; anything else falls through to a live resolve. The control plane
    # passes a short timeout + max_rate_limit_retries=0 so its best-effort pin can't block run
    # creation; the worker keeps the full retry budget.
    if pinned_sha and _is_commit_sha(pinned_sha):
        return pinned_sha
    if _is_commit_sha(parsed.ref):
        return parsed.ref
    # No pin (legacy spec / non-managed ref / control-plane resolve failed): resolve the symbolic
    # ref live every time. We deliberately do NOT cache symbolic refs in-process — a long-lived
    # process must see a moved branch (managed slugs point at environment-hub@main, which moves on
    # `flash env push`), and the immutable-sha cases above already skip the network.
    headers = {"Accept": "application/vnd.github+json", "User-Agent": "freesolo-flash"}
    token = _github_token()
    if token:
        headers["Authorization"] = f"Bearer {token}"
    quoted_ref = urllib.parse.quote(parsed.ref, safe="")
    commit_url = f"https://api.github.com/repos/{parsed.repo_full_name}/commits/{quoted_ref}"
    req = urllib.request.Request(commit_url, headers=headers)
    data = _urlopen(req, timeout=timeout, max_rate_limit_retries=max_rate_limit_retries)
    try:
        payload = json.loads(data)
    except ValueError as exc:
        # ValueError covers json.JSONDecodeError and the UnicodeDecodeError json.loads raises
        # for bytes it cannot decode.
        raise RuntimeError(
            f"Failed to resolve GitHub environment ref {parsed.canonical()}: invalid response"
        ) from exc
    # A JSON array or scalar has no .get(); treat it like a missing sha instead of crashing.
    sha = payload.get("sha") if isinstance(payload, dict) else None
    if not isinstance(sha, str) or not _is_commit_sha(sha):
        raise RuntimeError(f"Failed to resolve GitHub environment ref {parsed.canonical()}")
    return sha
254
+
255
+
256
def _urlopen(
    req: urllib.request.Request, *, timeout: float = 60.0, max_rate_limit_retries: int = 5
) -> bytes:
    """Fetch bytes for a GitHub request, surviving GitHub's secondary rate limit.

    On a 429 or a 403 whose body says "rate limit" (a cold spawn wave of workers all hitting the
    commits/tarball endpoint trips GitHub's abuse detection), retry up to ``max_rate_limit_retries``
    times with jitter so concurrent workers don't all retry in lockstep. If the limit persists past
    the retries, raise ``GitHubRateLimitError`` so the worker's top-level handler stamps
    ``retriable=True`` and the run reschedules on a fresh worker, instead of hard-failing on a
    transient limit. Any other HTTP / URL error raises a plain ``RuntimeError`` (non-retriable).

    ``max_rate_limit_retries=0`` makes it fail fast (one request, no sleeps): the control plane uses
    that for its best-effort resolve-once pin so a persistent limit can never block run creation —
    the long retry belongs on the worker, which can afford to wait.

    Each retry sleeps for a delay drawn uniformly between 10 s and an upper bound that grows by
    10 s per attempt and is capped at 45 s.
    """
    import random
    import time

    _RATE_LIMIT_BASE_DELAY = 10.0
    _RATE_LIMIT_MAX_DELAY = 45.0
    attempt = 0
    while True:
        try:
            with urllib.request.urlopen(req, timeout=timeout) as resp:
                return resp.read()
        except urllib.error.HTTPError as exc:
            body = exc.read().decode("utf-8", "replace")
            is_rate_limit = exc.code == 429 or (exc.code == 403 and "rate limit" in body.lower())
            if is_rate_limit and attempt < max_rate_limit_retries:
                # Draw from [base, upper] rather than clamping a jittered value into that range:
                # clamping sends every draw below the floor (or above the cap) to the same
                # instant, which is the lockstep the jitter exists to prevent.
                upper = min(_RATE_LIMIT_MAX_DELAY, _RATE_LIMIT_BASE_DELAY * (attempt + 2))
                time.sleep(random.uniform(_RATE_LIMIT_BASE_DELAY, upper))
                attempt += 1
                continue
            if is_rate_limit:
                # Persistent limit: signal retriable so the control plane reschedules (#209).
                raise GitHubRateLimitError(
                    f"GitHub API rate limit exceeded ({exc.code}): {body[:300]}"
                ) from exc
            raise RuntimeError(
                f"GitHub environment request failed ({exc.code}): {body[:500]}"
            ) from exc
        except urllib.error.URLError as exc:
            raise RuntimeError(f"GitHub environment request failed: {exc.reason}") from exc
297
+
298
+
299
def _download_github_tarball(ref: GitHubEnvironmentRef) -> bytes:
    """Download the repository tarball for ``ref.ref`` and return its raw bytes.

    Raises:
        RuntimeError: if the archive exceeds ``_MAX_ARCHIVE_BYTES``.
    """
    # Callers download the tarball for the ALREADY-resolved commit sha (see
    # _resolve_github_environment_file), which extracts it into the content-addressed disk cache
    # under _CACHE_ROOT/<hash(repo@sha:path)> and never re-downloads on a hit. So reuse is handled
    # on disk; we deliberately do NOT also retain the (up to _MAX_ARCHIVE_BYTES) archive bytes in a
    # module-level cache for the worker's lifetime — that wasted hundreds of MiB of RAM per process.
    quoted_ref = urllib.parse.quote(ref.ref, safe="")
    url = f"https://api.github.com/repos/{ref.repo_full_name}/tarball/{quoted_ref}"
    request_headers = {
        "Accept": "application/vnd.github+json",
        "User-Agent": "freesolo-flash",
    }
    token = _github_token()
    if token:
        request_headers["Authorization"] = f"Bearer {token}"
    request = urllib.request.Request(url, headers=request_headers)
    archive = _urlopen(request, timeout=120.0)
    size = len(archive)
    if size > _MAX_ARCHIVE_BYTES:
        raise RuntimeError(
            f"environment archive is too large ({size} bytes; limit {_MAX_ARCHIVE_BYTES} bytes)"
        )
    return archive
320
+
321
+
322
def _safe_extract_archive(tar_bytes: bytes, dest: Path) -> Path:
    """Extract a gzip tarball into ``dest`` and return its single top-level directory.

    Only regular files and directories are written. Links, devices, and tar metadata headers are
    skipped. Member names are normalized (backslashes to "/", empty and "." segments dropped)
    before extraction.

    Raises:
        RuntimeError: if the archive has more than ``_MAX_ARCHIVE_MEMBERS`` members, declares
            more than ``_MAX_ARCHIVE_BYTES`` of content, contains a ".." segment or a path that
            resolves outside ``dest``, or does not unpack to exactly one top-level directory.
    """
    root = dest.resolve()
    # Extraction filters exist on 3.12+ and on patched 3.8–3.11 releases. Pass the "data" filter
    # where available so mode and ownership bits are sanitized and the behavior does not depend
    # on the interpreter's default filter; older interpreters keep the manual checks below.
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    top_dirs: set[str] = set()
    total = 0
    with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode="r:gz") as tar:
        for count, member in enumerate(tar, start=1):
            if count > _MAX_ARCHIVE_MEMBERS:
                raise RuntimeError(
                    f"env package has too many members (limit {_MAX_ARCHIVE_MEMBERS})"
                )
            if member.type in _TAR_METADATA_TYPES:
                continue
            parts: list[str] = []
            for part in member.name.replace("\\", "/").split("/"):
                if not part or part == ".":
                    continue
                if part == "..":
                    raise RuntimeError(f"unsafe path in environment archive: {member.name!r}")
                parts.append(part)
            if not parts:
                continue
            normalized_name = "/".join(parts)
            target = (dest / normalized_name).resolve()
            if target != root and root not in target.parents:
                raise RuntimeError(f"unsafe path in environment archive: {member.name!r}")
            # Never materialize links or special files.
            if member.islnk() or member.issym() or not (member.isreg() or member.isdir()):
                continue
            top_dirs.add(parts[0])
            # Budget against the declared sizes so an archive bomb is rejected before it is written.
            total += max(0, member.size)
            if total > _MAX_ARCHIVE_BYTES:
                raise RuntimeError(
                    f"environment archive is too large uncompressed ({total} bytes; limit {_MAX_ARCHIVE_BYTES} bytes)"
                )
            member.name = normalized_name
            tar.extract(member, dest, **extract_kwargs)
    if len(top_dirs) != 1:
        raise RuntimeError("environment archive had an unexpected layout")
    extracted = dest / next(iter(top_dirs))
    if not extracted.is_dir():
        raise RuntimeError("environment archive did not extract to a directory")
    return extracted
363
+
364
+
365
def _resolve_github_environment_file(env_ref: str, pinned_sha: str | None = None) -> Path:
    """Return a local path to the environment entrypoint named by ``env_ref``.

    The ref is resolved to a commit sha, and the extracted archive is cached under
    ``_CACHE_ROOT`` keyed by repo, sha, and path; a cache hit skips the download.

    Raises:
        ValueError: if ``env_ref`` is not a GitHub environment ref.
        FileNotFoundError: if the archive lacks the required entrypoint.
    """
    parsed = _parse_github_environment_ref(env_ref)
    if parsed is None:
        raise ValueError(f"not a GitHub environment ref: {env_ref!r}")
    # Only forward pinned_sha when set, so the unpatched _resolve_ref_sha(parsed) call shape stays
    # single-arg (a test monkeypatch like ``lambda parsed: ...`` keeps working).
    if pinned_sha:
        sha = _resolve_ref_sha(parsed, pinned_sha=pinned_sha)
    else:
        sha = _resolve_ref_sha(parsed)

    key_material = f"github:{parsed.repo_full_name}@{sha}:{parsed.path}"
    cache_key = hashlib.sha256(key_material.encode()).hexdigest()[:24]
    cache_dir = _CACHE_ROOT / cache_key

    cached_entry = cache_dir / parsed.path
    if cached_entry.is_dir():
        cached_entry = cached_entry / _DEFAULT_ENVIRONMENT_PATH
    if cached_entry.is_file():
        return cached_entry

    scratch = Path(tempfile.mkdtemp(prefix="flash-env-github-"))
    pinned_ref = GitHubEnvironmentRef(parsed.owner, parsed.repo, sha, parsed.path)
    try:
        archive = _download_github_tarball(pinned_ref)
        extracted = _safe_extract_archive(archive, scratch)
        entrypoint = extracted / parsed.path
        if entrypoint.is_dir():
            entrypoint = entrypoint / _DEFAULT_ENVIRONMENT_PATH
        relative_entry = entrypoint.relative_to(extracted)
        required_entrypoint = relative_entry.as_posix()
        if not entrypoint.is_file():
            raise FileNotFoundError(
                f"environment archive did not contain required entrypoint {required_entrypoint!r}"
            )
        # Replace whatever is cached under this key with the fresh extraction.
        cache_dir.parent.mkdir(parents=True, exist_ok=True)
        shutil.rmtree(cache_dir, ignore_errors=True)
        shutil.copytree(extracted, cache_dir)
        return cache_dir / relative_entry
    finally:
        shutil.rmtree(scratch, ignore_errors=True)
406
+
407
+
408
def _resolve_environment_reference(env_ref: str, pinned_sha: str | None = None) -> str:
    """Turn an environment id into something loadable.

    Managed slugs and GitHub refs are fetched and resolved to a local entrypoint path. Any other
    value is returned as a path string when it exists on disk, or unchanged when it does not.
    """
    # pinned_sha (resolve-once hook, item 3): when the control plane already resolved this env's
    # ref->sha, it threads the commit sha here so the GitHub commits API is skipped entirely. None
    # (the default and today's behavior) means the worker resolves the ref itself.
    if is_managed_environment_slug(env_ref):
        github_ref = managed_slug_to_github_ref(env_ref)
    elif _parse_github_environment_ref(env_ref) is not None:
        github_ref = env_ref
    else:
        local = Path(env_ref)
        return str(local) if local.exists() else env_ref
    return str(_resolve_github_environment_file(github_ref, pinned_sha))
423
+
424
+
425
+ def _resolve_path_arg(value: object, base_dir: Path) -> object:
426
+ if not isinstance(value, str) or not value:
427
+ return value
428
+ parsed = urllib.parse.urlparse(value)
429
+ if parsed.scheme or Path(value).is_absolute():
430
+ return value
431
+ candidate = base_dir / value
432
+ return str(candidate) if candidate.exists() else value
433
+
434
+
435
+ def _load_contract_text(path: str | None) -> str:
436
+ if not path:
437
+ return ""
438
+ candidate = Path(path)
439
+ if not candidate.is_file():
440
+ return ""
441
+ try:
442
+ return candidate.read_text(encoding="utf-8")
443
+ except UnicodeError:
444
+ return candidate.read_text(errors="replace")
445
+
446
+
447
+ def _import_freesolo_environment_tools():
448
+ try:
449
+ from freesolo.datasets.records import load_task_examples, task_example_from_record
450
+ from freesolo.environments import (
451
+ EnvironmentEpisode,
452
+ EnvironmentMultiTurn,
453
+ EnvironmentSingleTurn,
454
+ EnvironmentTurn,
455
+ load_environment,
456
+ )
457
+
458
+ return {
459
+ "EnvironmentEpisode": EnvironmentEpisode,
460
+ "EnvironmentMultiTurn": EnvironmentMultiTurn,
461
+ "EnvironmentSingleTurn": EnvironmentSingleTurn,
462
+ "EnvironmentTurn": EnvironmentTurn,
463
+ "load_environment": load_environment,
464
+ "load_task_examples": load_task_examples,
465
+ "task_example_from_record": task_example_from_record,
466
+ }
467
+ except ImportError as exc:
468
+ raise ImportError(
469
+ "the 'freesolo' package is required to run Freesolo environments; "
470
+ "install it (for example `uv pip install freesolo`) or use a worker image "
471
+ "that includes the Freesolo SDK"
472
+ ) from exc
473
+
474
+
475
+ def _json_safe(value: Any) -> Any:
476
+ try:
477
+ json.dumps(value)
478
+ return value
479
+ except TypeError:
480
+ return str(value)
481
+
482
+
483
+ class FreesoloEnvironment(BaseEnvironment):
484
+ """Flash environment backed by ``freesolo.environments``."""
485
+
486
def __init__(
    self,
    sdk_env: object,
    env_id: str,
    *,
    source: object | None,
    contract_text: str = "",
):
    """Wrap ``sdk_env`` under the Flash environment id ``env_id``.

    ``source`` is what ``dataset()`` hands to the SDK loader (None makes it fall back to the
    env's own ``dataset`` / ``examples`` attribute). ``contract_text`` is passed to the env's
    ``start_episode`` when building prompts.
    """
    super().__init__(id=env_id)
    self._env, self._source, self._contract_text = sdk_env, source, contract_text
    sdk = _import_freesolo_environment_tools()
    self._task_example_from_record = sdk["task_example_from_record"]
    self._load_task_examples = sdk["load_task_examples"]
    self._EnvironmentEpisode = sdk["EnvironmentEpisode"]
    self._EnvironmentMultiTurn = sdk["EnvironmentMultiTurn"]
    self._EnvironmentTurn = sdk["EnvironmentTurn"]
    # Multi-turn is decided once, from the SDK class of the wrapped env.
    self.multi_turn = isinstance(sdk_env, sdk["EnvironmentMultiTurn"])
    self.is_tool_env = False
    # Both caches are filled lazily on first use.
    self._max_turns_cache: int | None = None
    self._dataset_cache: list[dict] | None = None
508
+
509
@property
def max_turns(self) -> int:
    """Batch-level ceiling on rollout turns, computed once and cached.

    Single-turn envs get 8. A multi-turn env sets a *per-example* budget through
    ``max_episode_turns(example)``; the batch cap must be at least the largest of those or
    the rollout loop would cut the deepest scenarios short (e.g. support-chat's
    4-customer-turn rollouts need ~20 turns, not 8). So the dataset-wide maximum is used,
    clamped to [8, 64] so a pathological env can't make rollouts unbounded, with 24 as the
    fallback when no budget could be read at all. The exact per-example budget is still
    enforced in :meth:`rollout_done`.
    """
    cached = self._max_turns_cache
    if cached is not None:
        return cached
    limit = 8
    if self.multi_turn:
        limit = 24  # fallback when not a single per-example budget is readable
        deepest: int | None = None  # running max, so no list is built for large datasets
        for row in self.dataset():  # cached; see dataset()
            # Guard each row on its own: ONE malformed record (or an env whose
            # max_episode_turns raises on it) is skipped instead of throwing away every
            # budget and silently landing on 24, which would bring the truncation back.
            try:
                budget = int(self._env.max_episode_turns(self._task_example(row)))
            except Exception:
                continue
            deepest = budget if deepest is None else max(deepest, budget)
        if deepest is not None:
            limit = min(max(deepest, 8), 64)
    self._max_turns_cache = limit
    return limit
542
+
543
+ def _task_example(self, example: dict):
544
+ return self._task_example_from_record(self._canonical_record(example))
545
+
546
@staticmethod
def _canonical_record(record: dict) -> dict:
    """Reduce ``record`` to the canonical keys: input, output, id, metadata.

    The input key is required. The output key is copied when present, ``id`` when it is not
    None, and ``metadata`` only when it is a non-empty dict.

    Raises:
        ValueError: if the record has no input field.
    """
    source = dict(record)
    if _CANONICAL_INPUT_KEY not in source:
        raise ValueError("Freesolo dataset records must contain an input field")
    slim = {_CANONICAL_INPUT_KEY: source[_CANONICAL_INPUT_KEY]}
    if _CANONICAL_OUTPUT_KEY in source:
        slim[_CANONICAL_OUTPUT_KEY] = source[_CANONICAL_OUTPUT_KEY]
    record_id = source.get("id")
    if record_id is not None:
        slim["id"] = record_id
    meta = source.get("metadata")
    if isinstance(meta, dict) and meta:
        slim["metadata"] = meta
    return slim
561
+
562
+ def _reward_to_breakdown(self, reward) -> dict[str, float]:
563
+ out: dict[str, float] = {}
564
+ for metric in getattr(reward, "metrics", ()) or ():
565
+ score = getattr(metric, "score", None)
566
+ if score is not None:
567
+ name = str(getattr(metric, "name", "") or "metric")
568
+ key = name
569
+ idx = 1
570
+ while key in out:
571
+ idx += 1
572
+ key = f"{name}_{idx}"
573
+ out[key] = float(score)
574
+ out["total"] = float(getattr(reward, "score", 0.0))
575
+ return out
576
+
577
def dataset(self) -> list[dict]:
    """Return the environment's examples as canonical records, loading them at most once.

    Examples come from the configured source, or — when none was given — from the wrapped
    env's ``dataset`` / ``examples`` attribute.

    Raises:
        ValueError: if no dataset source is available, or a record lacks an input field.
    """
    # Parse once and cache: the worker reads ``env.dataset()`` AND ``env.max_turns`` (which
    # also scans the dataset), so without this a multi-turn run would parse/load the whole
    # dataset twice at startup.
    if self._dataset_cache is not None:
        return self._dataset_cache
    origin = self._source
    if origin is None:
        origin = getattr(self._env, "dataset", None) or getattr(self._env, "examples", None)
        if origin is None:
            raise ValueError(
                "Freesolo environment has no dataset source. Set "
                "[environment.params] dataset_path or records so Flash can train."
            )
    loaded = self._load_task_examples(origin)
    records = []
    for item in loaded:
        # Start from the example's own record; the typed attributes only fill in gaps.
        fields = dict(getattr(item, "record", {}) or {})
        task = getattr(item, "task", None)
        if task is not None and _CANONICAL_INPUT_KEY not in fields:
            fields[_CANONICAL_INPUT_KEY] = task
        task_id = getattr(item, "task_id", None)
        if task_id is not None:
            fields.setdefault("id", task_id)
        expected = getattr(item, "expected_output", None)
        if expected is not None:
            fields.setdefault(_CANONICAL_OUTPUT_KEY, _json_safe(expected))
        extra = getattr(item, "metadata", None)
        if isinstance(extra, dict) and extra:
            fields.setdefault("metadata", extra)
        records.append(self._canonical_record(fields))
    self._dataset_cache = records
    return records
612
+
613
def prompt_messages(self, example: dict) -> list[dict]:
    """Return the opening messages of an episode for ``example``, each as a fresh dict."""
    opening = self._env.start_episode(self._task_example(example), self._contract_text)
    return list(map(dict, opening))
616
+
617
def sft_completion(self, example: dict) -> list[dict]:
    """Target completion messages to append after the prompt for one SFT example.

    Prefers the freesolo-sdk env's first-class ``Environment.sft_completion``, which turns the
    record's ``output`` into the target messages: a MULTI-TURN target trajectory — assistant
    turns, tool calls, tool results, replies (authored as ``output = {"messages": [...]}`` or
    a bare message list) — when the row ships one, else a single assistant turn from a scalar
    output. So the SFT example shape is owned by the freesolo-sdk dataset layer
    (``freesolo.datasets.target_messages``), not a flash-only convention; ``len(...) > 1`` is
    multi-turn. The raw record is read directly only when the env has no such method (an
    older installed SDK) or the method returns nothing."""
    delegate = getattr(self._env, "sft_completion", None)
    if callable(delegate):
        produced = delegate(self._task_example(example))
        if produced:
            return [dict(message) for message in produced]
    target = example.get(_CANONICAL_OUTPUT_KEY)
    # {"messages": [...]} wrapper: the dict must hold exactly that one key.
    if (
        isinstance(target, dict)
        and list(target) == ["messages"]
        and isinstance(target["messages"], list)
    ):
        return [dict(message) for message in target["messages"]]
    # Bare, non-empty list of message dicts.
    if isinstance(target, list) and target and all(isinstance(m, dict) for m in target):
        return [dict(message) for message in target]
    content = "" if target is None else str(target)
    return [{"role": "assistant", "content": content}]
639
+
640
def scores_breakdown(
    self, completion: str, example: dict, state: dict | None = None
) -> dict[str, float]:
    """Score one rollout and return its per-metric breakdown (including "total").

    A multi-turn env with a non-empty ``state`` is scored as an episode; otherwise the
    completion text is scored through the env's ``score_responses``.

    Raises:
        RuntimeError: if ``score_responses`` does not return exactly one reward.
    """
    if state and self.multi_turn:
        return self._reward_to_breakdown(self._score_episode(example, state))
    rewards = self._env.score_responses(self._task_example(example), [completion])
    if len(rewards) != 1:
        raise RuntimeError("Freesolo environment score_responses returned the wrong length")
    return self._reward_to_breakdown(rewards[0])
651
+
652
def reward(self, completion: str, example: dict, state: dict | None = None) -> float:
    """Return the overall ("total") score of one rollout as a float."""
    breakdown = self.scores_breakdown(completion, example, state)
    return float(breakdown["total"])
654
+
655
+ def reward_many(self, items: list[tuple[dict, dict]]) -> list[float]:
656
+ """Reward for many ``(example, state)`` rollouts at once, in input order.
657
+
658
+ For multi-turn, episodes that share a task go through ONE ``score_episodes`` call, which the
659
+ env scores concurrently (``Environment.max_score_concurrency``) — replacing one blocking
660
+ scoring call per rollout. For a judge / network-reward env (where scoring dominates) this is
661
+ the multi-turn analogue of batched generation. Equals one :meth:`reward` per item:
662
+ ``score_episodes`` scores each episode independently, so batching changes only concurrency,
663
+ not values. Single-turn falls back to per-item :meth:`reward`."""
664
+ if not self.multi_turn:
665
+ # Single-turn scoring ignores state and grades the completion, so pass the rollout's
666
+ # actual response (stored on the state) — not "" (which would score every item empty).
667
+ return [self.reward(str(st.get("response_text") or ""), ex, st) for ex, st in items]
668
+ groups: dict[str, dict] = {}
669
+ order: list[str] = []
670
+ for i, (ex, st) in enumerate(items):
671
+ # Group rollouts of the same example (a GRPO group shares one prompt) so their episodes
672
+ # are scored together; the example dict is the stable grouping key.
673
+ key = json.dumps(ex, sort_keys=True, default=str)
674
+ grp = groups.get(key)
675
+ if grp is None:
676
+ grp = groups[key] = {
677
+ "task": st.get("task") or self._task_example(ex),
678
+ "idxs": [],
679
+ "episodes": [],
680
+ }
681
+ order.append(key)
682
+ grp["idxs"].append(i)
683
+ grp["episodes"].append(self._episode_from_state(st))
684
+ out: list[float] = [0.0] * len(items)
685
+ for key in order:
686
+ grp = groups[key]
687
+ rewards = self._env.score_episodes(grp["task"], grp["episodes"])
688
+ if len(rewards) != len(grp["episodes"]):
689
+ raise RuntimeError("Freesolo environment score_episodes returned the wrong length")
690
+ for idx, rw in zip(grp["idxs"], rewards, strict=True):
691
+ out[idx] = float(rw.score)
692
+ return out
693
+
694
@property
def reward_thread_safe(self) -> bool:
    """Whether ``reward`` may be called concurrently across rollouts.

    Consulted by multiturn_rollout's ``_score_rollouts`` thread-pool fallback, which is used
    when the env has no ``reward_many``. The default is True because the verifiers reward
    contract is a pure scorer: ``score_responses`` reads only the per-call inputs and
    immutable env config. An underlying env whose scorer keeps mutable state or a
    thread-bound client opts out by exposing ``reward_thread_safe = False`` and is then
    scored serially.
    """
    declared = getattr(self._env, "reward_thread_safe", True)
    return bool(declared)
702
+
703
+ def grade(self, completion: str, example: dict, state: dict | None = None) -> bool:
704
+ if state and self.multi_turn:
705
+ reward = self._score_episode(example, state)
706
+ else:
707
+ rewards = self._env.score_responses(self._task_example(example), [completion])
708
+ if len(rewards) != 1:
709
+ raise RuntimeError("Freesolo environment score_responses returned the wrong length")
710
+ reward = rewards[0]
711
+ return bool(reward.resolved_success())
712
+
713
def tools(self) -> list:
    """Return the tool list for this environment; this adapter always returns an empty list."""
    no_tools: list = []
    return no_tools
715
+
716
def new_rollout_state(self, example: dict) -> dict:
    """Build the initial mutable state dict for one rollout of ``example``.

    The opening prompt comes from the env's ``start_episode`` for the derived task and the
    contract text. ``prompt`` and ``messages`` receive separate shallow copies of it, so
    appending to the running transcript never touches the recorded prompt.
    """
    task = self._task_example(example)
    opening = [dict(message) for message in self._env.start_episode(task, self._contract_text)]

    # Per-example turn budget (the env's max_episode_turns) so rollout_done caps THIS rollout
    # at its own budget rather than the batch-wide ceiling. Best effort: any failure leaves
    # the budget unset.
    episode_turns: int | None
    try:
        episode_turns = int(self._env.max_episode_turns(task))
    except Exception:
        episode_turns = None

    return dict(
        task=task,
        prompt=[dict(message) for message in opening],
        messages=[dict(message) for message in opening],
        turns=[],
        done=False,
        response_text="",
        turn=0,
        max_episode_turns=episode_turns,
    )
736
+
737
def record_model_turn(self, state: dict, content: str) -> dict:
    """Append the model's assistant turn to the rollout state and return the new message.

    The message dict is added to ``state["messages"]``, a matching environment turn to
    ``state["turns"]`` (both lists are created when absent), and ``state["response_text"]``
    is set to ``content``. The returned dict is the same object stored in ``messages``.
    """
    assistant_message = {"role": "assistant", "content": content}
    transcript = state.setdefault("messages", [])
    transcript.append(assistant_message)

    turn_log = state.setdefault("turns", [])
    turn_log.append(self._EnvironmentTurn(role="assistant", content=content))

    state["response_text"] = content
    return assistant_message
745
+
746
def env_reply(self, messages: list[dict], state: dict) -> list[dict]:
    """Advance a multi-turn episode by one env step and return the env's reply messages.

    Calls the env's ``step_episode`` with the stored task, a copy of ``messages`` and the
    last assistant response, then updates ``state`` in place:

    * ``done`` mirrors the step's ``done`` flag;
    * ``response_text`` is overwritten only when the step supplies a final response text;
    * ``turn`` is incremented by one;
    * truthy step metadata is appended to ``step_metadata``;
    * the replies are appended to ``messages`` and, as environment turns, to ``turns``.

    Single-turn envs have no reply, so ``[]`` is returned and ``state`` is left untouched.

    Raises:
        RuntimeError: if ``state`` carries no ``task``.
    """

    def _as_text(value: object) -> str:
        # An explicit None (e.g. a tool-call message with null content) must not be
        # recorded as the literal string "None".
        return "" if value is None else str(value)

    if not self.multi_turn:
        return []
    task = state.get("task")
    if task is None:
        raise RuntimeError("missing Freesolo rollout task state")

    assistant_response = str(state.get("response_text") or "")
    step = self._env.step_episode(task, list(messages), assistant_response)

    state["done"] = bool(step.done)
    if step.final_response_text is not None:
        state["response_text"] = step.final_response_text
    state["turn"] = int(state.get("turn", 0)) + 1
    if step.metadata:
        state.setdefault("step_metadata", []).append(step.metadata)

    # Copy each reply so later edits to the state never alias the env's own objects.
    replies = [dict(message) for message in step.messages]
    state.setdefault("messages", []).extend(replies)
    for message in replies:
        state.setdefault("turns", []).append(
            self._EnvironmentTurn(
                role=_as_text(message.get("role")),
                content=_as_text(message.get("content")),
            )
        )
    return replies
770
+
771
+ def rollout_done(self, state: dict, max_turns: int | None = None) -> bool:
772
+ if not self.multi_turn:
773
+ return True
774
+ if bool(state.get("done")):
775
+ return True
776
+ # Prefer THIS rollout's own per-example budget (set in new_rollout_state); fall
777
+ # back to the batch-wide cap the worker passes. The env normally terminates via
778
+ # step.done well before either, so this is a non-termination guard.
779
+ cap = state.get("max_episode_turns")
780
+ if cap is None:
781
+ cap = max_turns
782
+ return cap is not None and int(state.get("turn", 0)) >= int(cap)
783
+
784
+ def _episode_from_state(self, state: dict):
785
+ return self._EnvironmentEpisode(
786
+ messages=tuple(state.get("messages") or ()),
787
+ response_text=str(state.get("response_text") or ""),
788
+ turns=tuple(state.get("turns") or ()),
789
+ metadata={"steps": state.get("step_metadata", [])}
790
+ if state.get("step_metadata")
791
+ else {},
792
+ )
793
+
794
+ def _score_episode(self, example: dict, state: dict):
795
+ task = state.get("task") or self._task_example(example)
796
+ rewards = self._env.score_episodes(task, [self._episode_from_state(state)])
797
+ if len(rewards) != 1:
798
+ raise RuntimeError("Freesolo environment score_episodes returned the wrong length")
799
+ return rewards[0]
800
+
801
+ def reward_from_messages(
802
+ self, completion_msgs: list[dict], example: dict, prompt_msgs: list[dict] | None = None
803
+ ) -> float:
804
+ messages = [*(prompt_msgs or []), *completion_msgs]
805
+ response_text = ""
806
+ turns = []
807
+ for message in completion_msgs:
808
+ content = str(message.get("content", ""))
809
+ role = str(message.get("role", ""))
810
+ turns.append(self._EnvironmentTurn(role=role, content=content))
811
+ if role == "assistant":
812
+ response_text = content
813
+ episode = self._EnvironmentEpisode(
814
+ messages=tuple(dict(m) for m in messages),
815
+ response_text=response_text,
816
+ turns=tuple(turns),
817
+ )
818
+ rewards = self._env.score_episodes(self._task_example(example), [episode])
819
+ if len(rewards) != 1:
820
+ raise RuntimeError("Freesolo environment score_episodes returned the wrong length")
821
+ return float(rewards[0].score)
822
+
823
+
824
def load_freesolo_environment(
    env_id: str, pinned_sha: str | None = None, /, **kwargs
) -> FreesoloEnvironment:
    """Resolve ``env_id``, load the Freesolo SDK environment and wrap it in the adapter.

    Remaining keyword arguments are forwarded to the SDK loader after three adjustments:
    ``records`` is removed and used as the dataset source, ``dataset_path`` and
    ``contract_path`` are resolved against the environment's directory (with defaults
    discovered there), and ``contract_text`` is removed and handed to the adapter instead.
    """
    # pinned_sha is a POSITIONAL-ONLY resolve-once hook (the control-plane-pinned commit sha).
    # It is positional-only (the `/`) precisely so a user [environment.params] entry of ANY
    # name — even one literally named "pinned_sha" — lands in **kwargs and is forwarded
    # verbatim to the Freesolo SDK loader, never binding to or shadowing this internal pin.
    # None (default) preserves today's behavior — the worker resolves the env ref->sha itself.
    tools = _import_freesolo_environment_tools()
    reference = _resolve_environment_reference(env_id, pinned_sha)

    # Relative paths are anchored next to the reference when it exists on disk, else at cwd.
    reference_path = Path(reference)
    if reference_path.exists():
        base_dir = reference_path.parent
    else:
        base_dir = Path.cwd()

    params = dict(kwargs)

    # Dataset source precedence: explicit records, then dataset_path, then a default file.
    source = params.pop("records", None)
    requested_dataset = params.get("dataset_path")
    if source is None and requested_dataset:
        source = _resolve_path_arg(requested_dataset, base_dir)
        params["dataset_path"] = source
    if source is None:
        default_locations = (
            "datasets/train.jsonl",
            "datasets/train.json",
            "train.jsonl",
            "train.json",
        )
        discovered = next(
            (
                base_dir / relative
                for relative in default_locations
                if (base_dir / relative).is_file()
            ),
            None,
        )
        if discovered is not None:
            params.setdefault("dataset_path", str(discovered))
            source = str(discovered)

    # Contract path: a resolved explicit path wins; otherwise default to the file beside the env.
    resolved_contract = _resolve_path_arg(params.get("contract_path"), base_dir)
    if isinstance(resolved_contract, str):
        params["contract_path"] = resolved_contract
    else:
        params.setdefault("contract_path", str(base_dir / "TRAINING_CONTRACT.md"))

    # Inline contract text takes priority over reading the contract file.
    inline_contract = params.pop("contract_text", "")
    contract_text = str(inline_contract or _load_contract_text(params["contract_path"]))

    sdk_env = tools["load_environment"](reference, **params)
    return FreesoloEnvironment(
        sdk_env,
        env_id,
        source=source,
        contract_text=contract_text,
    )
873
+
874
+
875
# Names exported by ``from <this module> import *``.
__all__ = [
    "FreesoloEnvironment",
    "GitHubEnvironmentRef",
    "is_freesolo_environment_id",
    "is_github_environment_ref",
    "is_managed_environment_slug",
    "load_freesolo_environment",
    "managed_slug_to_github_ref",
]