freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,883 @@
|
|
|
1
|
+
"""Adapter that runs Freesolo SDK environments on Flash.
|
|
2
|
+
|
|
3
|
+
Flash environment ids are Freesolo Hub slugs (``namespace/name``). Explicit
|
|
4
|
+
low-level refs remain parseable for compatibility. The canonical generated environment file is
|
|
5
|
+
``environment.py`` and its
|
|
6
|
+
``load_environment`` function must return a Freesolo SDK environment:
|
|
7
|
+
``EnvironmentSingleTurn`` or ``EnvironmentMultiTurn``.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import hashlib
|
|
13
|
+
import io
|
|
14
|
+
import json
|
|
15
|
+
import os
|
|
16
|
+
import re
|
|
17
|
+
import shutil
|
|
18
|
+
import tarfile
|
|
19
|
+
import tempfile
|
|
20
|
+
import urllib.error
|
|
21
|
+
import urllib.parse
|
|
22
|
+
import urllib.request
|
|
23
|
+
from dataclasses import dataclass
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
from typing import Any
|
|
26
|
+
|
|
27
|
+
from flash.envs.base import BaseEnvironment
|
|
28
|
+
|
|
29
|
+
# Ref used when a ``github:`` ref or github.com URL omits one, and for managed hub slugs.
_DEFAULT_GITHUB_REF = "main"
# Entrypoint file used when a ref omits a path, or when the path names a directory.
_DEFAULT_ENVIRONMENT_PATH = "environment.py"
# Repository that hosts managed ``namespace/name`` hub slugs.
_DEFAULT_MANAGED_ENV_REPO = "freesolo-co/environment-hub"
# On-disk cache of extracted environment trees (one subdirectory per hashed repo@sha:path).
# Overridable via the FLASH_ENV_CACHE_DIR environment variable.
_CACHE_ROOT = Path(os.environ.get("FLASH_ENV_CACHE_DIR", "/tmp/flash-env-cache"))
# 256 MiB cap, applied both to the downloaded archive size and to the summed
# uncompressed size of the members that get extracted.
_MAX_ARCHIVE_BYTES = 256 * 1024 * 1024
# Cap on the number of tar members iterated during extraction.
_MAX_ARCHIVE_MEMBERS = 5000
# A full 40-hex-digit commit sha, either case; abbreviated shas do not match.
_COMMIT_SHA_RE = re.compile(r"^[0-9a-f]{40}$", re.IGNORECASE)
# Allowed characters for a single owner / repo / ref / slug segment ("/" is excluded,
# so refs containing a slash are rejected by the parsers).
_GITHUB_SAFE_PART_RE = re.compile(r"^[A-Za-z0-9._-]+$")
# Tar header types that carry metadata (pax extended/global headers, GNU long
# names/links); extraction skips these instead of writing them to disk.
_TAR_METADATA_TYPES = {
    tarfile.XHDTYPE,
    tarfile.XGLTYPE,
    tarfile.GNUTYPE_LONGNAME,
    tarfile.GNUTYPE_LONGLINK,
}
# Keys of the canonical dataset record handed to the Freesolo SDK.
_CANONICAL_INPUT_KEY = "input"
_CANONICAL_OUTPUT_KEY = "output"
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class GitHubRateLimitError(RuntimeError):
    """GitHub rate-limited a request and kept doing so after the in-process retries.

    Covers HTTP 429 and HTTP 403 responses whose body mentions "rate limit". ``_urlopen``
    raises it only after its jittered retry budget is used up, so it marks a *persistent*
    limit. The worker's top-level handler catches it and stamps ``retriable=True``, letting
    the control plane reschedule the job on a fresh worker once the limit window resets
    instead of failing the run for good. The primary spawn-wave mitigation is elsewhere: the
    control plane resolves the ref once (``EnvironmentSpec.resolved_sha``) so workers can
    skip the GitHub resolve entirely.
    """
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass(frozen=True)
class GitHubEnvironmentRef:
    """Immutable pointer to an environment entrypoint inside a GitHub repository."""

    owner: str  # GitHub user or organisation
    repo: str  # repository name
    ref: str  # git ref, e.g. a branch name or a resolved 40-char commit sha
    path: str  # repository-relative path to the entrypoint

    @property
    def repo_full_name(self) -> str:
        """``owner/repo``, the form used in GitHub API URLs."""
        return "{}/{}".format(self.owner, self.repo)

    def canonical(self) -> str:
        """Serialize back into the ``github:owner/repo@ref:path`` form."""
        return "github:{}@{}:{}".format(self.repo_full_name, self.ref, self.path)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def is_github_environment_ref(value: str) -> bool:
    """Return True when *value* parses as a ``github:`` ref or a GitHub URL."""
    parsed = _parse_github_environment_ref(value)
    return parsed is not None
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def is_managed_environment_slug(value: str) -> bool:
    """Return True when *value* is a ``namespace/name`` hub slug."""
    slug = _parse_managed_environment_slug(value)
    return slug is not None
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def is_freesolo_environment_id(value: str) -> bool:
    """Return True for any id this adapter can resolve: a hub slug or a GitHub ref/URL."""
    if is_managed_environment_slug(value):
        return True
    return is_github_environment_ref(value)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def managed_slug_to_github_ref(value: str) -> str:
    """Expand a hub slug ``namespace/name`` into its explicit ``github:`` ref.

    The result points at ``<namespace>/<name>/environment.py`` on the default ref of the
    managed environment-hub repository.

    Raises:
        ValueError: if *value* is not a valid ``namespace/name`` slug.
    """
    slug = _parse_managed_environment_slug(value)
    if slug is None:
        raise ValueError(f"not a Freesolo environment slug: {value!r}")
    namespace, name = slug
    entry_path = f"{namespace}/{name}/{_DEFAULT_ENVIRONMENT_PATH}"
    return f"github:{_DEFAULT_MANAGED_ENV_REPO}@{_DEFAULT_GITHUB_REF}:{entry_path}"
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _parse_managed_environment_slug(value: str) -> tuple[str, str] | None:
    """Split a hub slug ``namespace/name`` into its two parts, or return None.

    Anything containing ``:`` (``github:`` refs, URLs), anything with a URL scheme or
    network location, and anything that is not exactly two safe segments is rejected.
    """
    text = (value or "").strip()
    if not text or ":" in text:
        return None
    url = urllib.parse.urlparse(text)
    if url.scheme or url.netloc:
        return None
    segments = tuple(text.split("/"))
    if len(segments) == 2 and _is_safe_github_path_parts(segments):
        return segments[0], segments[1]
    return None
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _parse_github_environment_ref(value: str) -> GitHubEnvironmentRef | None:
    """Parse a GitHub environment reference, returning None when *value* is not one.

    Accepted forms:

    * ``github:owner/repo[@ref][:path]``: ``ref`` defaults to ``main`` and ``path`` to
      ``environment.py``.
    * ``https://github.com/owner/repo[.git]``: default ref and path.
    * ``https://github.com/owner/repo/tree/<ref>[/<dir>]``: a directory on a ref, or the
      repo root when no directory is given; ``environment.py`` inside it is the entrypoint.
    * ``https://github.com/owner/repo/blob/<ref>/<path>``: an explicit file.

    A trailing ``.git`` on the repo name is dropped in both the ``github:`` and URL forms.
    Refs must be a single safe segment (no ``/``), so branch names containing a slash are
    not supported. Paths starting with ``/`` or containing ``.``/``..`` segments are rejected.
    """
    text = (value or "").strip()
    if not text:
        return None

    if text.startswith("github:"):
        body = text[len("github:") :]
        repo_ref, sep, raw_path = body.partition(":")
        try:
            path = _normalize_env_path(raw_path)
        except ValueError:
            return None
        if not sep:
            path = _DEFAULT_ENVIRONMENT_PATH
        repo_part, at, ref = repo_ref.partition("@")
        if not at:
            ref = _DEFAULT_GITHUB_REF
        if not ref or not _is_safe_github_path_parts((ref,)):
            return None
        owner_repo = repo_part.split("/")
        if len(owner_repo) != 2:
            return None
        owner, repo = owner_repo
        # Match the URL form below, which has always dropped a trailing ".git".
        repo = repo[:-4] if repo.endswith(".git") else repo
        if not _is_safe_github_path_parts((owner, repo)):
            return None
        return GitHubEnvironmentRef(owner, repo, ref, path)

    url = urllib.parse.urlparse(text)
    if url.scheme not in {"http", "https"}:
        return None
    if url.netloc.lower() not in {"github.com", "www.github.com"}:
        return None
    parts = [urllib.parse.unquote(p) for p in url.path.strip("/").split("/") if p]
    if len(parts) < 2:
        return None
    owner, repo = parts[0], parts[1]
    repo = repo[:-4] if repo.endswith(".git") else repo
    if not _is_safe_github_path_parts((owner, repo)):
        return None

    if len(parts) == 2:
        return GitHubEnvironmentRef(owner, repo, _DEFAULT_GITHUB_REF, _DEFAULT_ENVIRONMENT_PATH)

    # /tree/<ref>[/<dir>] or /blob/<ref>/<path>. Four segments (a bare /tree/<ref>) name the
    # repo root on that ref; the old ">= 5" check rejected that form outright.
    if len(parts) < 4 or parts[2] not in {"blob", "tree"}:
        return None
    ref = parts[3]
    if not _is_safe_github_path_parts((ref,)):
        return None
    raw_path = "/".join(parts[4:])
    if not raw_path:
        return GitHubEnvironmentRef(owner, repo, ref, _DEFAULT_ENVIRONMENT_PATH)
    if not _is_safe_environment_path(raw_path):
        return None
    try:
        path = _normalize_env_path(raw_path)
    except ValueError:
        return None
    if parts[2] == "tree" and not path.endswith(".py"):
        # A tree URL names a directory; its environment.py is the entrypoint.
        path = f"{path.rstrip('/')}/{_DEFAULT_ENVIRONMENT_PATH}"
    return GitHubEnvironmentRef(owner, repo, ref, path)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _normalize_env_path(path: str | None) -> str:
    """Return *path* as a clean, relative, forward-slash repository path.

    Empty or blank input yields the default entrypoint. Backslashes become ``/`` and
    empty segments are dropped.

    Raises:
        ValueError: for absolute paths or any ``.`` / ``..`` segment.
    """
    cleaned = (path or "").strip()
    if not cleaned:
        return _DEFAULT_ENVIRONMENT_PATH
    cleaned = cleaned.replace("\\", "/")
    if cleaned.startswith("/"):
        raise ValueError(f"unsafe environment path: {path!r}")
    segments = list(filter(None, cleaned.split("/")))
    if not segments:
        return _DEFAULT_ENVIRONMENT_PATH
    if {".", ".."} & set(segments):
        raise ValueError(f"unsafe environment path: {path!r}")
    return "/".join(segments)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _is_safe_environment_path(path: str) -> bool:
    """Return False for absolute paths or paths with ``.`` / ``..`` segments, True otherwise.

    Backslashes count as separators; empty or blank paths are considered safe.
    """
    if not path:
        return True
    normalized = path.strip().replace("\\", "/")
    if normalized.startswith("/"):
        return False
    return all(segment not in (".", "..") for segment in normalized.split("/") if segment)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _is_safe_github_path_parts(parts: list[str] | tuple[str, ...]) -> bool:
    """Return True when every segment is non-empty, not ``.``/``..``, and uses safe characters.

    An empty sequence is unsafe. The explicit ``.``/``..`` check is needed because both
    strings are made of characters the safe-segment regex allows.
    """
    if not parts:
        return False
    if {"", ".", ".."}.intersection(parts):
        return False
    return all(_GITHUB_SAFE_PART_RE.fullmatch(segment) for segment in parts)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _github_token() -> str | None:
    """Return the GitHub API token from the ``GITHUB_TOKEN`` environment variable, if set."""
    return os.getenv("GITHUB_TOKEN")
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _is_commit_sha(value: str) -> bool:
    """True for a full 40-hex-digit commit sha (either case); abbreviated shas are rejected."""
    match = _COMMIT_SHA_RE.fullmatch(value)
    return match is not None
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _resolve_ref_sha(
    parsed: GitHubEnvironmentRef,
    pinned_sha: str | None = None,
    *,
    timeout: float = 60.0,
    max_rate_limit_retries: int = 5,
) -> str:
    """Resolve ``parsed.ref`` to a full commit sha.

    Args:
        parsed: the GitHub environment reference to resolve.
        pinned_sha: a sha already resolved by the control plane; used as-is when it is a
            full 40-hex sha, otherwise ignored.
        timeout: per-request timeout passed to ``_urlopen``.
        max_rate_limit_retries: rate-limit retry budget passed to ``_urlopen``.

    Returns:
        A 40-character commit sha.

    Raises:
        GitHubRateLimitError: the GitHub commits API stayed rate-limited past the retries.
        RuntimeError: any other request failure, or a response without a valid ``sha``.
    """
    # Resolve-once hook: the control plane resolves ref->sha ONCE (runner._assign_resolved_env_sha)
    # and threads the pinned commit sha through. Every worker in a fan-out then short-circuits
    # here without hitting GitHub at all. That is what actually defuses a cold spawn wave; an
    # in-process cache could not, since each worker is a separate process. Only a real 40-char
    # sha is trusted; anything else falls through to a live resolve. The control plane passes a
    # short timeout + max_rate_limit_retries=0 so its best-effort pin can't block run creation;
    # the worker keeps the full retry budget.
    if pinned_sha and _is_commit_sha(pinned_sha):
        return pinned_sha
    if _is_commit_sha(parsed.ref):
        return parsed.ref
    # No pin (legacy spec / non-managed ref / control-plane resolve failed): resolve the symbolic
    # ref live every time. Symbolic refs are deliberately NOT cached in-process: a long-lived
    # process must see a moved branch (managed slugs point at environment-hub@main, which moves
    # on `flash env push`). The immutable-sha cases above already skip the network.
    headers = {"Accept": "application/vnd.github+json", "User-Agent": "freesolo-flash"}
    token = _github_token()
    if token:
        headers["Authorization"] = f"Bearer {token}"
    quoted_ref = urllib.parse.quote(parsed.ref, safe="")
    commit_url = f"https://api.github.com/repos/{parsed.repo_full_name}/commits/{quoted_ref}"
    req = urllib.request.Request(commit_url, headers=headers)
    data = _urlopen(req, timeout=timeout, max_rate_limit_retries=max_rate_limit_retries)
    try:
        payload = json.loads(data)
    except ValueError as exc:
        # JSONDecodeError, and UnicodeDecodeError for non-UTF-8 bytes, are both ValueErrors.
        raise RuntimeError(
            f"Failed to resolve GitHub environment ref {parsed.canonical()}: invalid response"
        ) from exc
    # The commits endpoint returns a JSON object. Anything else (a list, null, a proxy error
    # page that happens to parse) must surface as RuntimeError, not AttributeError.
    sha = payload.get("sha") if isinstance(payload, dict) else None
    if not isinstance(sha, str) or not _is_commit_sha(sha):
        raise RuntimeError(f"Failed to resolve GitHub environment ref {parsed.canonical()}")
    return sha
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _urlopen(
    req: urllib.request.Request, *, timeout: float = 60.0, max_rate_limit_retries: int = 5
) -> bytes:
    """Fetch bytes for a GitHub request, surviving GitHub's secondary rate limit.

    Rate-limit responses are a 429, or a 403 whose body mentions "rate limit". A cold spawn
    wave of workers all hitting the commits/tarball endpoints trips GitHub's abuse detection.
    On such a response, retry up to ``max_rate_limit_retries`` times after a jittered sleep
    (see ``_rate_limit_retry_delay``) so concurrent workers do not retry in lockstep.

    If the limit persists past the retries, raise ``GitHubRateLimitError``. The worker's
    top-level handler then stamps ``retriable=True`` and the run reschedules on a fresh
    worker, instead of hard-failing on a transient limit. Any other ``HTTPError`` /
    ``URLError`` raises a plain ``RuntimeError`` (non-retriable).

    ``max_rate_limit_retries=0`` fails fast (one request, no sleeps). The control plane uses
    that for its best-effort resolve-once pin, so a persistent limit can never block run
    creation; the long retry belongs on the worker, which can afford to wait.
    """
    import time

    attempt = 0
    while True:
        try:
            with urllib.request.urlopen(req, timeout=timeout) as resp:
                return resp.read()
        except urllib.error.HTTPError as exc:
            body = exc.read().decode("utf-8", "replace")
            is_rate_limit = exc.code == 429 or (exc.code == 403 and "rate limit" in body.lower())
            if is_rate_limit and attempt < max_rate_limit_retries:
                time.sleep(_rate_limit_retry_delay(exc, attempt))
                attempt += 1
                continue
            if is_rate_limit:
                # Persistent limit: signal retriable so the control plane reschedules (#209).
                raise GitHubRateLimitError(
                    f"GitHub API rate limit exceeded ({exc.code}): {body[:300]}"
                ) from exc
            raise RuntimeError(f"GitHub environment request failed ({exc.code}): {body[:500]}") from exc
        except urllib.error.URLError as exc:
            raise RuntimeError(f"GitHub environment request failed: {exc.reason}") from exc


def _rate_limit_retry_delay(exc: urllib.error.HTTPError, attempt: int) -> float:
    """Seconds to sleep before retry ``attempt + 1`` after a rate-limit response.

    The backoff is drawn uniformly from ``[max(10, 5*(attempt+1)), min(45, 15*(attempt+1))]``.
    That is the same range as before, but without the old clamp, which collapsed about half
    of the first-retry draws onto exactly 10 s (and most late draws onto exactly 45 s) and so
    re-synchronised the workers it was meant to spread out. When GitHub sends a numeric
    ``Retry-After`` header, the sleep is at least that long (capped at 120 s) plus up to 10 s
    of jitter, since GitHub asks clients not to retry before it elapses.
    """
    import random

    base = 10.0
    scaled = base * (attempt + 1)
    high = min(45.0, 1.5 * scaled)
    low = min(high, max(base, 0.5 * scaled))
    delay = random.uniform(low, high)
    retry_after = _retry_after_seconds(exc)
    if retry_after is not None:
        delay = max(delay, min(retry_after, 120.0) + random.uniform(0.0, base))
    return delay


def _retry_after_seconds(exc: urllib.error.HTTPError) -> float | None:
    """Return a numeric ``Retry-After`` header in seconds, or None if absent or unusable.

    HTTP-date values are not parsed and yield None (the jittered backoff is used instead).
    """
    headers = getattr(exc, "headers", None)
    if headers is None:
        return None
    raw = headers.get("Retry-After")
    if raw is None:
        return None
    try:
        seconds = float(str(raw).strip())
    except ValueError:
        return None
    # Rejects negatives and NaN (NaN >= 0 is False); +inf is capped by the caller.
    if not seconds >= 0:
        return None
    return seconds
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def _download_github_tarball(ref: GitHubEnvironmentRef) -> bytes:
    """Download the repository tarball for ``ref.repo_full_name`` at ``ref.ref``.

    Callers pass an already-resolved commit sha (see ``_resolve_github_environment_file``).
    Reuse is handled by the content-addressed on-disk cache under ``_CACHE_ROOT``, which
    never re-downloads on a hit. The archive bytes are therefore intentionally NOT memoised
    in a module-level cache: doing so kept up to ``_MAX_ARCHIVE_BYTES`` of RAM alive per
    process for the worker's lifetime.

    Raises:
        RuntimeError: the archive exceeds ``_MAX_ARCHIVE_BYTES``, or the request failed.
        GitHubRateLimitError: GitHub stayed rate-limited past ``_urlopen``'s retries.
    """
    quoted_ref = urllib.parse.quote(ref.ref, safe="")
    tarball_url = f"https://api.github.com/repos/{ref.repo_full_name}/tarball/{quoted_ref}"
    request_headers = {"Accept": "application/vnd.github+json", "User-Agent": "freesolo-flash"}
    token = _github_token()
    if token:
        request_headers["Authorization"] = f"Bearer {token}"
    archive = _urlopen(urllib.request.Request(tarball_url, headers=request_headers), timeout=120.0)
    size = len(archive)
    if size > _MAX_ARCHIVE_BYTES:
        raise RuntimeError(
            f"environment archive is too large ({size} bytes; limit {_MAX_ARCHIVE_BYTES} bytes)"
        )
    return archive
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def _safe_extract_archive(tar_bytes: bytes, dest: Path) -> Path:
    """Extract a gzip tarball into *dest* defensively and return its single top-level directory.

    Defences applied:

    * at most ``_MAX_ARCHIVE_MEMBERS`` members are iterated;
    * the summed size of extracted members may not exceed ``_MAX_ARCHIVE_BYTES``;
    * member names with a ``..`` segment, or resolving outside *dest*, are rejected;
    * only regular files and directories are written; hard links, symlinks, devices and
      FIFOs are skipped; pax / GNU long-name metadata headers are skipped;
    * where ``tarfile`` supports extraction filters (Python 3.12+ and the security
      backports), the ``"data"`` filter is applied as well. This is defence in depth, it
      avoids the 3.12/3.13 DeprecationWarning, and it pins behaviour across the 3.14
      default change.

    Raises:
        RuntimeError: too many members, too large, an unsafe member, or an archive that does
            not have exactly one top-level directory.
    """
    root = dest.resolve()
    # Feature-detect extraction filters as the tarfile docs recommend.
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    filter_errors = (tarfile.FilterError,) if hasattr(tarfile, "FilterError") else ()
    top_dirs: set[str] = set()
    total = 0
    with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode="r:gz") as tar:
        for count, member in enumerate(tar, start=1):
            if count > _MAX_ARCHIVE_MEMBERS:
                raise RuntimeError(
                    f"env package has too many members (limit {_MAX_ARCHIVE_MEMBERS})"
                )
            if member.type in _TAR_METADATA_TYPES:
                continue
            parts: list[str] = []
            for part in member.name.replace("\\", "/").split("/"):
                if not part or part == ".":
                    continue
                if part == "..":
                    raise RuntimeError(f"unsafe path in environment archive: {member.name!r}")
                parts.append(part)
            if not parts:
                continue
            normalized_name = "/".join(parts)
            target = (dest / normalized_name).resolve()
            if target != root and root not in target.parents:
                raise RuntimeError(f"unsafe path in environment archive: {member.name!r}")
            # Links and special files are skipped, never extracted.
            if member.islnk() or member.issym() or not (member.isreg() or member.isdir()):
                continue
            top_dirs.add(parts[0])
            total += max(0, member.size)
            if total > _MAX_ARCHIVE_BYTES:
                raise RuntimeError(
                    f"environment archive is too large uncompressed ({total} bytes; limit {_MAX_ARCHIVE_BYTES} bytes)"
                )
            member.name = normalized_name
            try:
                tar.extract(member, dest, **extract_kwargs)
            except filter_errors as exc:
                raise RuntimeError(
                    f"unsafe member in environment archive: {normalized_name!r}"
                ) from exc
    if len(top_dirs) != 1:
        raise RuntimeError("environment archive had an unexpected layout")
    extracted = dest / next(iter(top_dirs))
    if not extracted.is_dir():
        raise RuntimeError("environment archive did not extract to a directory")
    return extracted
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def _resolve_github_environment_file(env_ref: str, pinned_sha: str | None = None) -> Path:
    """Return a local path to the entrypoint of a GitHub environment ref.

    The ref is resolved to a commit sha; *pinned_sha* is used when the control plane
    already resolved it. The extracted repository is cached under
    ``_CACHE_ROOT/<hash(repo@sha:path)>``. On a miss the tarball is downloaded and extracted
    into a temporary directory, copied into a staging directory inside ``_CACHE_ROOT``, and
    published with a single atomic rename. Other processes therefore see either no cache
    entry or a complete one, and a concurrent publisher's finished copy is never deleted.

    Raises:
        ValueError: *env_ref* is not a GitHub environment ref.
        FileNotFoundError: the archive does not contain the required entrypoint.
        GitHubRateLimitError / RuntimeError: from resolving, downloading or extracting.
    """
    parsed = _parse_github_environment_ref(env_ref)
    if parsed is None:
        raise ValueError(f"not a GitHub environment ref: {env_ref!r}")
    # Only forward pinned_sha when set, so the unpatched _resolve_ref_sha(parsed) call shape stays
    # single-arg (a test monkeypatch like ``lambda parsed: ...`` keeps working).
    resolved_ref = (
        _resolve_ref_sha(parsed, pinned_sha=pinned_sha) if pinned_sha else _resolve_ref_sha(parsed)
    )
    cache_key = hashlib.sha256(
        f"github:{parsed.repo_full_name}@{resolved_ref}:{parsed.path}".encode()
    ).hexdigest()[:24]
    cache_dir = _CACHE_ROOT / cache_key
    env_file = cache_dir / parsed.path
    if env_file.is_dir():
        env_file = env_file / _DEFAULT_ENVIRONMENT_PATH
    if env_file.is_file():
        return env_file

    tmp_parent = Path(tempfile.mkdtemp(prefix="flash-env-github-"))
    resolved = GitHubEnvironmentRef(parsed.owner, parsed.repo, resolved_ref, parsed.path)
    try:
        extracted = _safe_extract_archive(_download_github_tarball(resolved), tmp_parent)
        candidate = extracted / parsed.path
        if candidate.is_dir():
            candidate = candidate / _DEFAULT_ENVIRONMENT_PATH
        entrypoint = candidate.relative_to(extracted)
        required_entrypoint = entrypoint.as_posix()
        if not candidate.is_file():
            raise FileNotFoundError(
                f"environment archive did not contain required entrypoint {required_entrypoint!r}"
            )
        cache_dir.parent.mkdir(parents=True, exist_ok=True)
        # Stage inside the cache root: the publishing rename must stay on one filesystem to be
        # atomic, and tmp_parent (system temp dir) may live on a different one.
        staging = Path(tempfile.mkdtemp(prefix=f".{cache_key}-", dir=cache_dir.parent))
        try:
            staged_tree = staging / "tree"
            shutil.copytree(extracted, staged_tree)
            _publish_cache_tree(staged_tree, cache_dir, entrypoint)
        finally:
            shutil.rmtree(staging, ignore_errors=True)
        return cache_dir / entrypoint
    finally:
        shutil.rmtree(tmp_parent, ignore_errors=True)


def _publish_cache_tree(staged: Path, cache_dir: Path, entrypoint: Path) -> None:
    """Move the fully-populated *staged* tree to *cache_dir*, tolerating concurrent publishers.

    ``os.rename`` is atomic on POSIX when source and target are on the same filesystem. On
    success *staged* no longer exists. If *cache_dir* already exists and contains
    *entrypoint*, another process published first: its copy is kept and *staged* is left for
    the caller to delete. An existing *cache_dir* without the entrypoint is treated as stale
    (for example a partial copy left by an older non-atomic writer) and is replaced.

    Raises:
        OSError: the tree could not be published and no usable cache entry exists.
    """
    try:
        os.rename(staged, cache_dir)
        return
    except OSError:
        pass
    if (cache_dir / entrypoint).is_file():
        return
    shutil.rmtree(cache_dir, ignore_errors=True)
    try:
        os.rename(staged, cache_dir)
    except OSError:
        # Lost a second race; accept the winner's copy if it is usable.
        if not (cache_dir / entrypoint).is_file():
            raise
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def _resolve_environment_reference(env_ref: str, pinned_sha: str | None = None) -> str:
    """Turn an environment id into a string the loader can use.

    * hub slug ``namespace/name``: path of the cached entrypoint from the managed repo;
    * ``github:`` ref or GitHub URL: path of the cached entrypoint;
    * anything else: ``str(Path(env_ref))`` if that exists locally, else *env_ref* unchanged.

    *pinned_sha* (the resolve-once hook) is a commit sha the control plane already resolved
    for this env. It lets the GitHub commits API be skipped entirely. None, the default,
    means the worker resolves the ref itself.

    NOTE(review): a two-segment relative path such as ``envs/foo`` also matches the slug
    syntax, so it is fetched from the hub and never read locally. Confirm this precedence
    is intended.
    """
    if is_managed_environment_slug(env_ref):
        github_ref = managed_slug_to_github_ref(env_ref)
        return str(_resolve_github_environment_file(github_ref, pinned_sha))
    if _parse_github_environment_ref(env_ref) is not None:
        return str(_resolve_github_environment_file(env_ref, pinned_sha))
    local = Path(env_ref)
    return str(local) if local.exists() else env_ref
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def _resolve_path_arg(value: object, base_dir: Path) -> object:
    """Anchor a relative path argument at *base_dir* when the anchored path exists.

    Non-strings, empty strings, URLs (anything with a scheme) and absolute paths are returned
    untouched, as is a relative path that does not exist under *base_dir*.
    """
    if isinstance(value, str) and value:
        has_scheme = bool(urllib.parse.urlparse(value).scheme)
        if not has_scheme and not Path(value).is_absolute():
            anchored = base_dir / value
            if anchored.exists():
                return str(anchored)
    return value
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def _load_contract_text(path: str | None) -> str:
    """Read an optional contract file as UTF-8 text.

    Returns "" when *path* is empty or does not name a regular file. Undecodable bytes are
    replaced with U+FFFD. The previous fallback re-read the file with the locale's default
    encoding instead, so the same file could decode differently from host to host.
    """
    if not path:
        return ""
    candidate = Path(path)
    if not candidate.is_file():
        return ""
    return candidate.read_text(encoding="utf-8", errors="replace")
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def _import_freesolo_environment_tools():
    """Lazily import the Freesolo SDK pieces this adapter needs and return them by name.

    Raises:
        ImportError: with installation guidance when the ``freesolo`` package is missing.
    """
    try:
        from freesolo.datasets.records import load_task_examples, task_example_from_record
        from freesolo.environments import (
            EnvironmentEpisode,
            EnvironmentMultiTurn,
            EnvironmentSingleTurn,
            EnvironmentTurn,
            load_environment,
        )
    except ImportError as exc:
        raise ImportError(
            "the 'freesolo' package is required to run Freesolo environments; "
            "install it (for example `uv pip install freesolo`) or use a worker image "
            "that includes the Freesolo SDK"
        ) from exc
    return dict(
        EnvironmentEpisode=EnvironmentEpisode,
        EnvironmentMultiTurn=EnvironmentMultiTurn,
        EnvironmentSingleTurn=EnvironmentSingleTurn,
        EnvironmentTurn=EnvironmentTurn,
        load_environment=load_environment,
        load_task_examples=load_task_examples,
        task_example_from_record=task_example_from_record,
    )
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
def _json_safe(value: Any) -> Any:
    """Return *value* unchanged if ``json.dumps`` accepts it, otherwise ``str(value)``.

    Both failure modes of ``json.dumps`` fall back to ``str``:

    * TypeError: unsupported types or dict keys;
    * ValueError: circular references, or ints beyond the int/str digit limit.

    Previously only TypeError was caught, so a self-referencing expected output crashed
    dataset loading.
    """
    try:
        json.dumps(value)
    except (TypeError, ValueError):
        return str(value)
    return value
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
class FreesoloEnvironment(BaseEnvironment):
|
|
484
|
+
"""Flash environment backed by ``freesolo.environments``."""
|
|
485
|
+
|
|
486
|
+
def __init__(
    self,
    sdk_env: object,
    env_id: str,
    *,
    source: object | None,
    contract_text: str = "",
):
    """Wrap a Freesolo SDK environment instance for Flash.

    Args:
        sdk_env: the SDK environment; treated as multi-turn when it is an
            ``EnvironmentMultiTurn`` instance.
        env_id: id forwarded to ``BaseEnvironment``.
        source: dataset source passed to ``load_task_examples``; when None, ``dataset()``
            falls back to the SDK environment's own ``dataset`` / ``examples`` attribute.
        contract_text: passed to ``start_episode`` when building prompts.

    Raises:
        ImportError: the Freesolo SDK is not installed.
    """
    super().__init__(id=env_id)
    self._env = sdk_env
    self._source = source
    self._contract_text = contract_text
    sdk = _import_freesolo_environment_tools()
    self._task_example_from_record = sdk["task_example_from_record"]
    self._load_task_examples = sdk["load_task_examples"]
    self._EnvironmentEpisode = sdk["EnvironmentEpisode"]
    self._EnvironmentMultiTurn = sdk["EnvironmentMultiTurn"]
    self._EnvironmentTurn = sdk["EnvironmentTurn"]
    self.multi_turn = isinstance(sdk_env, self._EnvironmentMultiTurn)
    self.is_tool_env = False
    # Memoised lazily by the max_turns property and dataset().
    self._max_turns_cache: int | None = None
    self._dataset_cache: list[dict] | None = None
|
|
508
|
+
|
|
509
|
+
@property
def max_turns(self) -> int:
    """Rollout turn ceiling that the worker uses as the batch-level loop cap.

    Environments that are not ``EnvironmentMultiTurn`` get 8. For multi-turn environments
    the cap is the largest ``max_episode_turns(example)`` over the dataset, clamped to
    [8, 64]: the deepest scenarios are not truncated (e.g. support-chat's 4-customer-turn
    rollouts need ~20 turns, not 8), while a pathological environment still cannot make
    rollouts unbounded. If no example yields a budget the cap falls back to 24. The exact
    per-example budget is still enforced in :meth:`rollout_done`. The result is memoised.
    """
    if self._max_turns_cache is not None:
        return self._max_turns_cache
    if not self.multi_turn:
        limit = 8
    else:
        deepest: int | None = None  # running max; no intermediate list for large datasets
        for record in self.dataset():  # memoised by dataset()
            # Skip, per row, any record whose budget cannot be computed. One malformed row
            # must not discard every budget and silently fall back to 24, which would
            # reintroduce the truncation this cap exists to prevent.
            try:
                turns = int(self._env.max_episode_turns(self._task_example(record)))
            except Exception:
                continue
            if deepest is None or turns > deepest:
                deepest = turns
        limit = 24 if deepest is None else max(8, min(64, deepest))
    self._max_turns_cache = limit
    return limit
|
|
542
|
+
|
|
543
|
+
def _task_example(self, example: dict):
    """Canonicalize a Flash dataset record and convert it into the SDK's task-example object."""
    canonical = self._canonical_record(example)
    return self._task_example_from_record(canonical)
|
|
545
|
+
|
|
546
|
+
@staticmethod
def _canonical_record(record: dict) -> dict:
    """Project *record* onto the canonical keys, dropping everything else.

    Kept keys: ``input`` (required), ``output`` (when present), ``id`` (when not None) and
    ``metadata`` (when it is a non-empty dict).

    Raises:
        ValueError: the record has no ``input`` key.
    """
    source = dict(record)
    if _CANONICAL_INPUT_KEY not in source:
        raise ValueError("Freesolo dataset records must contain an input field")
    canonical = {_CANONICAL_INPUT_KEY: source[_CANONICAL_INPUT_KEY]}
    if _CANONICAL_OUTPUT_KEY in source:
        canonical[_CANONICAL_OUTPUT_KEY] = source[_CANONICAL_OUTPUT_KEY]
    record_id = source.get("id")
    if record_id is not None:
        canonical["id"] = record_id
    metadata = source.get("metadata")
    if isinstance(metadata, dict) and metadata:
        canonical["metadata"] = metadata
    return canonical
|
|
561
|
+
|
|
562
|
+
def _reward_to_breakdown(self, reward) -> dict[str, float]:
    """Flatten an SDK reward into ``{metric name: score}`` plus a ``"total"`` entry.

    Metrics whose score is None are skipped. Unnamed metrics are called ``"metric"``, and
    duplicate names get ``_2``, ``_3``, ... suffixes. ``"total"`` is reserved for
    ``reward.score``: a metric literally named "total" is stored as ``"total_2"`` instead of
    being silently overwritten by the overall score, as it was before.
    """
    taken = {"total"}
    out: dict[str, float] = {}
    for metric in getattr(reward, "metrics", ()) or ():
        score = getattr(metric, "score", None)
        if score is None:
            continue
        name = str(getattr(metric, "name", "") or "metric")
        key, suffix = name, 1
        while key in taken:
            suffix += 1
            key = f"{name}_{suffix}"
        taken.add(key)
        out[key] = float(score)
    out["total"] = float(getattr(reward, "score", 0.0))
    return out
|
|
576
|
+
|
|
577
|
+
def dataset(self) -> list[dict]:
|
|
578
|
+
# Parse once and cache: the worker reads ``env.dataset()`` AND ``env.max_turns`` (which
|
|
579
|
+
# also scans the dataset), so without this a multi-turn run would parse/load the whole
|
|
580
|
+
# dataset twice at startup.
|
|
581
|
+
if self._dataset_cache is not None:
|
|
582
|
+
return self._dataset_cache
|
|
583
|
+
if self._source is None:
|
|
584
|
+
rows = getattr(self._env, "dataset", None) or getattr(self._env, "examples", None)
|
|
585
|
+
if rows is None:
|
|
586
|
+
raise ValueError(
|
|
587
|
+
"Freesolo environment has no dataset source. Set "
|
|
588
|
+
"[environment.params] dataset_path or records so Flash can train."
|
|
589
|
+
)
|
|
590
|
+
examples = self._load_task_examples(rows)
|
|
591
|
+
else:
|
|
592
|
+
examples = self._load_task_examples(self._source)
|
|
593
|
+
records = []
|
|
594
|
+
for example in examples:
|
|
595
|
+
raw = dict(getattr(example, "record", {}) or {})
|
|
596
|
+
task = getattr(example, "task", None)
|
|
597
|
+
if _CANONICAL_INPUT_KEY not in raw and task is not None:
|
|
598
|
+
raw[_CANONICAL_INPUT_KEY] = task
|
|
599
|
+
task_id = getattr(example, "task_id", None)
|
|
600
|
+
if task_id is not None:
|
|
601
|
+
raw.setdefault("id", task_id)
|
|
602
|
+
expected = getattr(example, "expected_output", None)
|
|
603
|
+
if expected is not None:
|
|
604
|
+
raw.setdefault(_CANONICAL_OUTPUT_KEY, _json_safe(expected))
|
|
605
|
+
metadata = getattr(example, "metadata", None)
|
|
606
|
+
if isinstance(metadata, dict) and metadata:
|
|
607
|
+
raw.setdefault("metadata", metadata)
|
|
608
|
+
record = self._canonical_record(raw)
|
|
609
|
+
records.append(record)
|
|
610
|
+
self._dataset_cache = records
|
|
611
|
+
return records
|
|
612
|
+
|
|
613
|
+
def prompt_messages(self, example: dict) -> list[dict]:
|
|
614
|
+
messages = self._env.start_episode(self._task_example(example), self._contract_text)
|
|
615
|
+
return [dict(message) for message in messages]
|
|
616
|
+
|
|
617
|
+
def sft_completion(self, example: dict) -> list[dict]:
|
|
618
|
+
"""Target completion messages to append after the prompt for one SFT example.
|
|
619
|
+
|
|
620
|
+
Delegates to the freesolo-sdk env's first-class ``Environment.sft_completion``, which turns
|
|
621
|
+
the record's ``output`` into the target messages: a MULTI-TURN target trajectory —
|
|
622
|
+
assistant turns, tool calls, tool results, replies (authored as ``output = {"messages":
|
|
623
|
+
[...]}`` or a bare message list) — when the row ships one, else a single assistant turn from
|
|
624
|
+
a scalar output. So the SFT example shape is owned by the freesolo-sdk dataset layer
|
|
625
|
+
(``freesolo.datasets.target_messages``), not a flash-only convention; ``len(...) > 1`` is
|
|
626
|
+
multi-turn. Falls back to reading the raw record only for an older installed SDK that
|
|
627
|
+
predates the method."""
|
|
628
|
+
fn = getattr(self._env, "sft_completion", None)
|
|
629
|
+
if callable(fn):
|
|
630
|
+
msgs = fn(self._task_example(example))
|
|
631
|
+
if msgs:
|
|
632
|
+
return [dict(m) for m in msgs]
|
|
633
|
+
value = example.get(_CANONICAL_OUTPUT_KEY)
|
|
634
|
+
if isinstance(value, list) and value and all(isinstance(m, dict) for m in value):
|
|
635
|
+
return [dict(m) for m in value]
|
|
636
|
+
if isinstance(value, dict) and list(value) == ["messages"] and isinstance(value["messages"], list):
|
|
637
|
+
return [dict(m) for m in value["messages"]]
|
|
638
|
+
return [{"role": "assistant", "content": "" if value is None else str(value)}]
|
|
639
|
+
|
|
640
|
+
def scores_breakdown(
|
|
641
|
+
self, completion: str, example: dict, state: dict | None = None
|
|
642
|
+
) -> dict[str, float]:
|
|
643
|
+
if state and self.multi_turn:
|
|
644
|
+
reward = self._score_episode(example, state)
|
|
645
|
+
else:
|
|
646
|
+
rewards = self._env.score_responses(self._task_example(example), [completion])
|
|
647
|
+
if len(rewards) != 1:
|
|
648
|
+
raise RuntimeError("Freesolo environment score_responses returned the wrong length")
|
|
649
|
+
reward = rewards[0]
|
|
650
|
+
return self._reward_to_breakdown(reward)
|
|
651
|
+
|
|
652
|
+
def reward(self, completion: str, example: dict, state: dict | None = None) -> float:
|
|
653
|
+
return float(self.scores_breakdown(completion, example, state)["total"])
|
|
654
|
+
|
|
655
|
+
def reward_many(self, items: list[tuple[dict, dict]]) -> list[float]:
|
|
656
|
+
"""Reward for many ``(example, state)`` rollouts at once, in input order.
|
|
657
|
+
|
|
658
|
+
For multi-turn, episodes that share a task go through ONE ``score_episodes`` call, which the
|
|
659
|
+
env scores concurrently (``Environment.max_score_concurrency``) — replacing one blocking
|
|
660
|
+
scoring call per rollout. For a judge / network-reward env (where scoring dominates) this is
|
|
661
|
+
the multi-turn analogue of batched generation. Equals one :meth:`reward` per item:
|
|
662
|
+
``score_episodes`` scores each episode independently, so batching changes only concurrency,
|
|
663
|
+
not values. Single-turn falls back to per-item :meth:`reward`."""
|
|
664
|
+
if not self.multi_turn:
|
|
665
|
+
# Single-turn scoring ignores state and grades the completion, so pass the rollout's
|
|
666
|
+
# actual response (stored on the state) — not "" (which would score every item empty).
|
|
667
|
+
return [self.reward(str(st.get("response_text") or ""), ex, st) for ex, st in items]
|
|
668
|
+
groups: dict[str, dict] = {}
|
|
669
|
+
order: list[str] = []
|
|
670
|
+
for i, (ex, st) in enumerate(items):
|
|
671
|
+
# Group rollouts of the same example (a GRPO group shares one prompt) so their episodes
|
|
672
|
+
# are scored together; the example dict is the stable grouping key.
|
|
673
|
+
key = json.dumps(ex, sort_keys=True, default=str)
|
|
674
|
+
grp = groups.get(key)
|
|
675
|
+
if grp is None:
|
|
676
|
+
grp = groups[key] = {
|
|
677
|
+
"task": st.get("task") or self._task_example(ex),
|
|
678
|
+
"idxs": [],
|
|
679
|
+
"episodes": [],
|
|
680
|
+
}
|
|
681
|
+
order.append(key)
|
|
682
|
+
grp["idxs"].append(i)
|
|
683
|
+
grp["episodes"].append(self._episode_from_state(st))
|
|
684
|
+
out: list[float] = [0.0] * len(items)
|
|
685
|
+
for key in order:
|
|
686
|
+
grp = groups[key]
|
|
687
|
+
rewards = self._env.score_episodes(grp["task"], grp["episodes"])
|
|
688
|
+
if len(rewards) != len(grp["episodes"]):
|
|
689
|
+
raise RuntimeError("Freesolo environment score_episodes returned the wrong length")
|
|
690
|
+
for idx, rw in zip(grp["idxs"], rewards, strict=True):
|
|
691
|
+
out[idx] = float(rw.score)
|
|
692
|
+
return out
|
|
693
|
+
|
|
694
|
+
@property
|
|
695
|
+
def reward_thread_safe(self) -> bool:
|
|
696
|
+
"""Whether ``reward`` may be called concurrently across rollouts (multiturn_rollout's
|
|
697
|
+
``_score_rollouts`` thread-pool fallback, used when the env has no ``reward_many``). The
|
|
698
|
+
verifiers reward contract is a pure scorer — ``score_responses`` reads the per-call inputs +
|
|
699
|
+
immutable env config — so the default is True. An underlying env whose scorer keeps mutable
|
|
700
|
+
state or a thread-bound client opts out with ``reward_thread_safe = False`` (scored serially)."""
|
|
701
|
+
return bool(getattr(self._env, "reward_thread_safe", True))
|
|
702
|
+
|
|
703
|
+
def grade(self, completion: str, example: dict, state: dict | None = None) -> bool:
|
|
704
|
+
if state and self.multi_turn:
|
|
705
|
+
reward = self._score_episode(example, state)
|
|
706
|
+
else:
|
|
707
|
+
rewards = self._env.score_responses(self._task_example(example), [completion])
|
|
708
|
+
if len(rewards) != 1:
|
|
709
|
+
raise RuntimeError("Freesolo environment score_responses returned the wrong length")
|
|
710
|
+
reward = rewards[0]
|
|
711
|
+
return bool(reward.resolved_success())
|
|
712
|
+
|
|
713
|
+
def tools(self) -> list:
|
|
714
|
+
return []
|
|
715
|
+
|
|
716
|
+
def new_rollout_state(self, example: dict) -> dict:
|
|
717
|
+
task = self._task_example(example)
|
|
718
|
+
prompt = [dict(message) for message in self._env.start_episode(task, self._contract_text)]
|
|
719
|
+
# Per-example turn budget (env's max_episode_turns) so rollout_done caps THIS
|
|
720
|
+
# rollout at its own budget rather than the batch-wide ceiling -- a single-turn
|
|
721
|
+
# scenario stops after a few turns while a deep one gets its full budget.
|
|
722
|
+
try:
|
|
723
|
+
episode_turns: int | None = int(self._env.max_episode_turns(task))
|
|
724
|
+
except Exception:
|
|
725
|
+
episode_turns = None
|
|
726
|
+
return {
|
|
727
|
+
"task": task,
|
|
728
|
+
"prompt": [dict(message) for message in prompt],
|
|
729
|
+
"messages": [dict(message) for message in prompt],
|
|
730
|
+
"turns": [],
|
|
731
|
+
"done": False,
|
|
732
|
+
"response_text": "",
|
|
733
|
+
"turn": 0,
|
|
734
|
+
"max_episode_turns": episode_turns,
|
|
735
|
+
}
|
|
736
|
+
|
|
737
|
+
def record_model_turn(self, state: dict, content: str) -> dict:
|
|
738
|
+
msg = {"role": "assistant", "content": content}
|
|
739
|
+
state.setdefault("messages", []).append(msg)
|
|
740
|
+
state.setdefault("turns", []).append(
|
|
741
|
+
self._EnvironmentTurn(role="assistant", content=content)
|
|
742
|
+
)
|
|
743
|
+
state["response_text"] = content
|
|
744
|
+
return msg
|
|
745
|
+
|
|
746
|
+
def env_reply(self, messages: list[dict], state: dict) -> list[dict]:
|
|
747
|
+
if not self.multi_turn:
|
|
748
|
+
return []
|
|
749
|
+
task = state.get("task")
|
|
750
|
+
if task is None:
|
|
751
|
+
raise RuntimeError("missing Freesolo rollout task state")
|
|
752
|
+
assistant_response = str(state.get("response_text") or "")
|
|
753
|
+
step = self._env.step_episode(task, list(messages), assistant_response)
|
|
754
|
+
state["done"] = bool(step.done)
|
|
755
|
+
if step.final_response_text is not None:
|
|
756
|
+
state["response_text"] = step.final_response_text
|
|
757
|
+
state["turn"] = int(state.get("turn", 0)) + 1
|
|
758
|
+
if step.metadata:
|
|
759
|
+
state.setdefault("step_metadata", []).append(step.metadata)
|
|
760
|
+
replies = [dict(message) for message in step.messages]
|
|
761
|
+
state.setdefault("messages", []).extend(replies)
|
|
762
|
+
for message in replies:
|
|
763
|
+
state.setdefault("turns", []).append(
|
|
764
|
+
self._EnvironmentTurn(
|
|
765
|
+
role=str(message.get("role", "")),
|
|
766
|
+
content=str(message.get("content", "")),
|
|
767
|
+
)
|
|
768
|
+
)
|
|
769
|
+
return replies
|
|
770
|
+
|
|
771
|
+
def rollout_done(self, state: dict, max_turns: int | None = None) -> bool:
|
|
772
|
+
if not self.multi_turn:
|
|
773
|
+
return True
|
|
774
|
+
if bool(state.get("done")):
|
|
775
|
+
return True
|
|
776
|
+
# Prefer THIS rollout's own per-example budget (set in new_rollout_state); fall
|
|
777
|
+
# back to the batch-wide cap the worker passes. The env normally terminates via
|
|
778
|
+
# step.done well before either, so this is a non-termination guard.
|
|
779
|
+
cap = state.get("max_episode_turns")
|
|
780
|
+
if cap is None:
|
|
781
|
+
cap = max_turns
|
|
782
|
+
return cap is not None and int(state.get("turn", 0)) >= int(cap)
|
|
783
|
+
|
|
784
|
+
def _episode_from_state(self, state: dict):
|
|
785
|
+
return self._EnvironmentEpisode(
|
|
786
|
+
messages=tuple(state.get("messages") or ()),
|
|
787
|
+
response_text=str(state.get("response_text") or ""),
|
|
788
|
+
turns=tuple(state.get("turns") or ()),
|
|
789
|
+
metadata={"steps": state.get("step_metadata", [])}
|
|
790
|
+
if state.get("step_metadata")
|
|
791
|
+
else {},
|
|
792
|
+
)
|
|
793
|
+
|
|
794
|
+
def _score_episode(self, example: dict, state: dict):
|
|
795
|
+
task = state.get("task") or self._task_example(example)
|
|
796
|
+
rewards = self._env.score_episodes(task, [self._episode_from_state(state)])
|
|
797
|
+
if len(rewards) != 1:
|
|
798
|
+
raise RuntimeError("Freesolo environment score_episodes returned the wrong length")
|
|
799
|
+
return rewards[0]
|
|
800
|
+
|
|
801
|
+
def reward_from_messages(
|
|
802
|
+
self, completion_msgs: list[dict], example: dict, prompt_msgs: list[dict] | None = None
|
|
803
|
+
) -> float:
|
|
804
|
+
messages = [*(prompt_msgs or []), *completion_msgs]
|
|
805
|
+
response_text = ""
|
|
806
|
+
turns = []
|
|
807
|
+
for message in completion_msgs:
|
|
808
|
+
content = str(message.get("content", ""))
|
|
809
|
+
role = str(message.get("role", ""))
|
|
810
|
+
turns.append(self._EnvironmentTurn(role=role, content=content))
|
|
811
|
+
if role == "assistant":
|
|
812
|
+
response_text = content
|
|
813
|
+
episode = self._EnvironmentEpisode(
|
|
814
|
+
messages=tuple(dict(m) for m in messages),
|
|
815
|
+
response_text=response_text,
|
|
816
|
+
turns=tuple(turns),
|
|
817
|
+
)
|
|
818
|
+
rewards = self._env.score_episodes(self._task_example(example), [episode])
|
|
819
|
+
if len(rewards) != 1:
|
|
820
|
+
raise RuntimeError("Freesolo environment score_episodes returned the wrong length")
|
|
821
|
+
return float(rewards[0].score)
|
|
822
|
+
|
|
823
|
+
|
|
824
|
+
def load_freesolo_environment(
|
|
825
|
+
env_id: str, pinned_sha: str | None = None, /, **kwargs
|
|
826
|
+
) -> FreesoloEnvironment:
|
|
827
|
+
# pinned_sha is a POSITIONAL-ONLY resolve-once hook (the control-plane-pinned commit sha). It is
|
|
828
|
+
# positional-only (the `/`) precisely so a user [environment.params] entry of ANY name — even
|
|
829
|
+
# one literally named "pinned_sha" — lands in **kwargs and is forwarded verbatim to the Freesolo
|
|
830
|
+
# SDK loader, never binding to or shadowing this internal pin. None (default) preserves today's
|
|
831
|
+
# behavior — the worker resolves the env ref->sha itself.
|
|
832
|
+
tools = _import_freesolo_environment_tools()
|
|
833
|
+
reference = _resolve_environment_reference(env_id, pinned_sha)
|
|
834
|
+
reference_path = Path(reference)
|
|
835
|
+
base_dir = reference_path.parent if reference_path.exists() else Path.cwd()
|
|
836
|
+
|
|
837
|
+
params = dict(kwargs)
|
|
838
|
+
source = params.pop("records", None)
|
|
839
|
+
dataset_path = params.get("dataset_path")
|
|
840
|
+
if source is None and dataset_path:
|
|
841
|
+
resolved_dataset_path = _resolve_path_arg(dataset_path, base_dir)
|
|
842
|
+
params["dataset_path"] = resolved_dataset_path
|
|
843
|
+
source = resolved_dataset_path
|
|
844
|
+
if source is None:
|
|
845
|
+
for rel in (
|
|
846
|
+
"datasets/train.jsonl",
|
|
847
|
+
"datasets/train.json",
|
|
848
|
+
"train.jsonl",
|
|
849
|
+
"train.json",
|
|
850
|
+
):
|
|
851
|
+
candidate = base_dir / rel
|
|
852
|
+
if candidate.is_file():
|
|
853
|
+
params.setdefault("dataset_path", str(candidate))
|
|
854
|
+
source = str(candidate)
|
|
855
|
+
break
|
|
856
|
+
|
|
857
|
+
contract_path = _resolve_path_arg(params.get("contract_path"), base_dir)
|
|
858
|
+
if isinstance(contract_path, str):
|
|
859
|
+
params["contract_path"] = contract_path
|
|
860
|
+
else:
|
|
861
|
+
params.setdefault("contract_path", str(base_dir / "TRAINING_CONTRACT.md"))
|
|
862
|
+
contract_text = str(
|
|
863
|
+
params.pop("contract_text", "") or _load_contract_text(params["contract_path"])
|
|
864
|
+
)
|
|
865
|
+
|
|
866
|
+
sdk_env = tools["load_environment"](reference, **params)
|
|
867
|
+
return FreesoloEnvironment(
|
|
868
|
+
sdk_env,
|
|
869
|
+
env_id,
|
|
870
|
+
source=source,
|
|
871
|
+
contract_text=contract_text,
|
|
872
|
+
)
|
|
873
|
+
|
|
874
|
+
|
|
875
|
+
__all__ = [
|
|
876
|
+
"FreesoloEnvironment",
|
|
877
|
+
"GitHubEnvironmentRef",
|
|
878
|
+
"is_freesolo_environment_id",
|
|
879
|
+
"is_github_environment_ref",
|
|
880
|
+
"is_managed_environment_slug",
|
|
881
|
+
"load_freesolo_environment",
|
|
882
|
+
"managed_slug_to_github_ref",
|
|
883
|
+
]
|