benchmax 0.1.2.dev28__py3-none-any.whl → 0.1.2.dev29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmax/cli.py +78 -0
- benchmax/config.py +42 -1
- benchmax/platform/__init__.py +10 -0
- benchmax/platform/client.py +303 -16
- benchmax/platform/credentials.py +224 -4
- benchmax/platform/device_auth.py +81 -0
- benchmax/platform/login.py +81 -0
- benchmax/platform/training_run.py +5 -3
- benchmax/platform/validation.py +151 -7
- benchmax/rag/corpus/postgres/client.py +9 -1
- benchmax/rag/corpus/postgres/source.py +21 -11
- benchmax/rag/qa_generation/filters/env_rollout.py +9 -1
- benchmax/rag/qa_generation/filters/grounding_llm.py +9 -1
- benchmax/rag/qa_generation/filters/hop_count_validity.py +7 -6
- benchmax/rag/qa_generation/filters/retrieval_llm.py +8 -1
- benchmax/rag/qa_generation/pipeline.py +10 -4
- benchmax/rag/qa_generation/pipeline_config.py +7 -3
- {benchmax-0.1.2.dev28.dist-info → benchmax-0.1.2.dev29.dist-info}/METADATA +1 -1
- {benchmax-0.1.2.dev28.dist-info → benchmax-0.1.2.dev29.dist-info}/RECORD +23 -19
- benchmax-0.1.2.dev29.dist-info/entry_points.txt +2 -0
- {benchmax-0.1.2.dev28.dist-info → benchmax-0.1.2.dev29.dist-info}/WHEEL +0 -0
- {benchmax-0.1.2.dev28.dist-info → benchmax-0.1.2.dev29.dist-info}/licenses/LICENSE +0 -0
- {benchmax-0.1.2.dev28.dist-info → benchmax-0.1.2.dev29.dist-info}/top_level.txt +0 -0
benchmax/cli.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""``castform`` CLI — browser-based login for the SDK.
|
|
2
|
+
|
|
3
|
+
Commands: ``login`` (device authorization), ``logout``, ``whoami``. The login
|
|
4
|
+
flow + the reusable ``ensure_session`` live in :mod:`benchmax.platform.login`;
|
|
5
|
+
this module is the thin argparse wrapper. After ``castform login`` the SDK
|
|
6
|
+
resolves its bearer from ``~/.castform`` automatically — no API key or URL.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import argparse
|
|
12
|
+
import sys
|
|
13
|
+
|
|
14
|
+
from benchmax.platform import credentials
|
|
15
|
+
from benchmax.platform.device_auth import DeviceAuthError
|
|
16
|
+
from benchmax.platform.login import _login
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _cmd_login(args: argparse.Namespace) -> int:
|
|
20
|
+
env = "staging" if args.env == "staging" else None
|
|
21
|
+
try:
|
|
22
|
+
_login(env)
|
|
23
|
+
except DeviceAuthError as exc:
|
|
24
|
+
print(f"Login failed: {exc}", file=sys.stderr)
|
|
25
|
+
return 1
|
|
26
|
+
print(f"\n✓ Logged in to {args.env}.")
|
|
27
|
+
return 0
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _cmd_logout(_args: argparse.Namespace) -> int:
|
|
31
|
+
credentials.clear_castform_session()
|
|
32
|
+
print("✓ Logged out.")
|
|
33
|
+
return 0
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _cmd_whoami(_args: argparse.Namespace) -> int:
|
|
37
|
+
session = credentials.read_castform_session()
|
|
38
|
+
if not session:
|
|
39
|
+
print("Not logged in. Run `castform login`.", file=sys.stderr)
|
|
40
|
+
return 1
|
|
41
|
+
env = session.get("env", "prod")
|
|
42
|
+
jwt = credentials._session_jwt() # mints from the session; None if invalid/expired/offline
|
|
43
|
+
if not jwt:
|
|
44
|
+
print(
|
|
45
|
+
f"Session present (env: {env}), but couldn't reach auth-service to "
|
|
46
|
+
"verify it (offline, or the session expired). If this persists, run "
|
|
47
|
+
"`castform login` again.",
|
|
48
|
+
file=sys.stderr,
|
|
49
|
+
)
|
|
50
|
+
return 1
|
|
51
|
+
claims = credentials._jwt_claims(jwt)
|
|
52
|
+
who = claims.get("email") or claims.get("sub", "<unknown>")
|
|
53
|
+
print(f"Logged in as {who} (env: {env}).")
|
|
54
|
+
return 0
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def main(argv: list[str] | None = None) -> int:
|
|
58
|
+
parser = argparse.ArgumentParser(prog="castform", description="Castform CLI")
|
|
59
|
+
sub = parser.add_subparsers(dest="command", required=True)
|
|
60
|
+
|
|
61
|
+
p_login = sub.add_parser("login", help="Sign in via your browser")
|
|
62
|
+
p_login.add_argument(
|
|
63
|
+
"--env",
|
|
64
|
+
choices=["prod", "staging"],
|
|
65
|
+
default="prod",
|
|
66
|
+
help="Environment to sign in to (staging is internal-only)",
|
|
67
|
+
)
|
|
68
|
+
p_login.set_defaults(func=_cmd_login)
|
|
69
|
+
|
|
70
|
+
sub.add_parser("logout", help="Clear the cached session").set_defaults(func=_cmd_logout)
|
|
71
|
+
sub.add_parser("whoami", help="Show the current login").set_defaults(func=_cmd_whoami)
|
|
72
|
+
|
|
73
|
+
args = parser.parse_args(argv)
|
|
74
|
+
return args.func(args)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
if __name__ == "__main__":
|
|
78
|
+
sys.exit(main())
|
benchmax/config.py
CHANGED
|
@@ -16,7 +16,38 @@ DEFAULT_BASE_DOMAIN = "castform.com"
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
def base_domain() -> str:
|
|
19
|
-
|
|
19
|
+
"""Resolve the platform base domain.
|
|
20
|
+
|
|
21
|
+
Precedence: explicit ``CASTFORM_BASE_DOMAIN`` → the cached device-auth
|
|
22
|
+
session's ``env`` (``staging`` → ``castform.dev``) → ``prod`` default
|
|
23
|
+
(``castform.com``). The ``env`` claim travels with the credential, so a
|
|
24
|
+
logged-in SDK routes to the same environment it authenticated against —
|
|
25
|
+
URL and credential can't desync. A prod session carries no ``env`` marker
|
|
26
|
+
(``None`` → prod), so only internal staging logins deviate from the default.
|
|
27
|
+
"""
|
|
28
|
+
override = os.environ.get("CASTFORM_BASE_DOMAIN")
|
|
29
|
+
if override:
|
|
30
|
+
return override
|
|
31
|
+
if _session_env() == "staging":
|
|
32
|
+
return "castform.dev"
|
|
33
|
+
return DEFAULT_BASE_DOMAIN
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _session_env() -> str | None:
|
|
37
|
+
"""The ``env`` from the cached device-auth session, if any.
|
|
38
|
+
|
|
39
|
+
Lazy import: ``config`` is a leaf that ``benchmax.platform`` depends on, so a
|
|
40
|
+
top-level import would cycle (platform/__init__ → client → config)."""
|
|
41
|
+
try:
|
|
42
|
+
from benchmax.platform.credentials import read_castform_session
|
|
43
|
+
|
|
44
|
+
session = read_castform_session()
|
|
45
|
+
except Exception:
|
|
46
|
+
return None
|
|
47
|
+
if not session:
|
|
48
|
+
return None
|
|
49
|
+
env = session.get("env")
|
|
50
|
+
return env if isinstance(env, str) else None
|
|
20
51
|
|
|
21
52
|
|
|
22
53
|
def platform_url() -> str:
|
|
@@ -38,3 +69,13 @@ def web_app_url() -> str:
|
|
|
38
69
|
def llm_url() -> str:
|
|
39
70
|
"""OpenAI-compatible LLM endpoint hosted by the platform."""
|
|
40
71
|
return os.environ.get("CASTFORM_LLM_URL") or f"https://llm.{base_domain()}/v1"
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def auth_url() -> str:
|
|
75
|
+
"""Auth-service base URL (device-authorization + JWT mint endpoints).
|
|
76
|
+
|
|
77
|
+
Used by ``castform login`` and the per-process session→JWT mint. Derives from
|
|
78
|
+
the same base domain as everything else, so a session minted against ``staging``
|
|
79
|
+
talks to ``auth.castform.dev`` and a ``prod`` session to ``auth.castform.com``.
|
|
80
|
+
"""
|
|
81
|
+
return os.environ.get("CASTFORM_AUTH_URL") or f"https://auth.{base_domain()}"
|
benchmax/platform/__init__.py
CHANGED
|
@@ -1,15 +1,25 @@
|
|
|
1
1
|
"""Castform platform clients (storage, training runs, rollout)."""
|
|
2
2
|
|
|
3
3
|
from .client import RolloutClient, StorageClient, TrainerClient
|
|
4
|
+
from .credentials import platform_bearer
|
|
4
5
|
from .training_run import UploadedTrainingRun, upload_training_run
|
|
5
6
|
from .validation import ValidationReport, validate_env
|
|
6
7
|
|
|
8
|
+
# Imported last: login depends on credentials/device_auth (siblings), so this
|
|
9
|
+
# stays cycle-free as long as those are already loaded by the imports above.
|
|
10
|
+
from .login import ensure_session
|
|
11
|
+
|
|
7
12
|
__all__ = [
|
|
8
13
|
"RolloutClient",
|
|
9
14
|
"StorageClient",
|
|
10
15
|
"TrainerClient",
|
|
11
16
|
"UploadedTrainingRun",
|
|
12
17
|
"ValidationReport",
|
|
18
|
+
# The seam token-getter: generated scripts pass it to a raw OpenAI client
|
|
19
|
+
# (e.g. the traces pivot), so it's part of the public surface alongside
|
|
20
|
+
# ensure_session — not just an internal credentials helper.
|
|
21
|
+
"platform_bearer",
|
|
22
|
+
"ensure_session",
|
|
13
23
|
"upload_training_run",
|
|
14
24
|
"validate_env",
|
|
15
25
|
]
|
benchmax/platform/client.py
CHANGED
|
@@ -14,10 +14,9 @@ from typing import TYPE_CHECKING, Any
|
|
|
14
14
|
|
|
15
15
|
import httpx
|
|
16
16
|
|
|
17
|
-
logger = logging.getLogger(__name__)
|
|
18
|
-
|
|
19
17
|
from benchmax import config
|
|
20
18
|
|
|
19
|
+
from .credentials import TokenProvider, resolve_token_provider
|
|
21
20
|
from .exceptions import (
|
|
22
21
|
AuthenticationError,
|
|
23
22
|
JobLaunchError,
|
|
@@ -28,6 +27,8 @@ from .exceptions import (
|
|
|
28
27
|
TrainerError,
|
|
29
28
|
)
|
|
30
29
|
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
31
32
|
if TYPE_CHECKING:
|
|
32
33
|
from types import ModuleType
|
|
33
34
|
|
|
@@ -74,10 +75,17 @@ class ValidationResult:
|
|
|
74
75
|
"""
|
|
75
76
|
|
|
76
77
|
examples: list[ExampleValidation]
|
|
78
|
+
# Outcome of the compute_group_reward contract check, run on the real
|
|
79
|
+
# smoke-rollout transcripts. None when the env has no group reward, the
|
|
80
|
+
# env class wasn't supplied, or the check was skipped (deps not installed
|
|
81
|
+
# locally — it runs on the trainer instead). Its index is -1.
|
|
82
|
+
group_reward: ExampleValidation | None = None
|
|
77
83
|
|
|
78
84
|
@property
|
|
79
85
|
def ok(self) -> bool:
|
|
80
|
-
|
|
86
|
+
rollouts_ok = all(ex.ok for ex in self.examples)
|
|
87
|
+
group_ok = self.group_reward is None or self.group_reward.ok
|
|
88
|
+
return rollouts_ok and group_ok
|
|
81
89
|
|
|
82
90
|
def __bool__(self) -> bool:
|
|
83
91
|
return self.ok
|
|
@@ -110,6 +118,22 @@ def _file_hash(content: bytes, length: int = 8) -> str:
|
|
|
110
118
|
return hashlib.sha256(content).hexdigest()[:length]
|
|
111
119
|
|
|
112
120
|
|
|
121
|
+
class _BearerAuth(httpx.Auth):
|
|
122
|
+
"""Resolve the platform bearer per request via ``token_provider``.
|
|
123
|
+
|
|
124
|
+
Built once but called on every request, so the auth header is never frozen
|
|
125
|
+
at construction — a rotating/expiring device or act-as token is picked up
|
|
126
|
+
each call (the "token expires mid-run" bug ``credentials.py`` warns about).
|
|
127
|
+
"""
|
|
128
|
+
|
|
129
|
+
def __init__(self, token_provider: TokenProvider) -> None:
|
|
130
|
+
self._token_provider = token_provider
|
|
131
|
+
|
|
132
|
+
def auth_flow(self, request: httpx.Request):
|
|
133
|
+
request.headers["Authorization"] = f"Bearer {self._token_provider()}"
|
|
134
|
+
yield request
|
|
135
|
+
|
|
136
|
+
|
|
113
137
|
@dataclass
|
|
114
138
|
class StorageClient:
|
|
115
139
|
"""Client for uploading files to storage via pre-signed URLs.
|
|
@@ -117,6 +141,10 @@ class StorageClient:
|
|
|
117
141
|
Uses the ``GET /api/storage/upload-url`` endpoint to obtain a pre-signed
|
|
118
142
|
upload URL, then PUTs the file content directly to that URL.
|
|
119
143
|
|
|
144
|
+
``api_key`` is optional: when omitted the bearer resolves per request via
|
|
145
|
+
the credential seam (``ACT_AS_TOKEN_PATH`` / ``PLATFORM_API_KEY``). Pass
|
|
146
|
+
``api_key`` to override, or ``token_provider`` for a custom per-call source.
|
|
147
|
+
|
|
120
148
|
Example:
|
|
121
149
|
client = StorageClient(api_key="sk_...", base_url="http://localhost:3000")
|
|
122
150
|
result = client.upload_file(
|
|
@@ -127,19 +155,22 @@ class StorageClient:
|
|
|
127
155
|
print(f"Uploaded to {result['blobPath']}")
|
|
128
156
|
"""
|
|
129
157
|
|
|
130
|
-
api_key: str
|
|
158
|
+
api_key: str | None = None
|
|
131
159
|
base_url: str = field(default_factory=config.platform_url)
|
|
132
160
|
timeout: float = 60.0
|
|
133
161
|
# SAS-URL PUTs are bounded by file size, not API latency. Default to
|
|
134
162
|
# 30 minutes so multi-GB datasets don't time out at the platform-API timeout.
|
|
135
163
|
upload_timeout: float = 1800.0
|
|
164
|
+
token_provider: TokenProvider | None = None
|
|
165
|
+
_token_provider: TokenProvider = field(init=False, repr=False)
|
|
136
166
|
_http_client: httpx.Client = field(init=False, repr=False)
|
|
137
167
|
|
|
138
168
|
def __post_init__(self) -> None:
|
|
139
|
-
"""Initialize HTTP client
|
|
169
|
+
"""Initialize HTTP client; auth resolves per request, never baked here."""
|
|
170
|
+
self._token_provider = resolve_token_provider(self.api_key, self.token_provider)
|
|
140
171
|
self._http_client = httpx.Client(
|
|
141
172
|
base_url=self.base_url,
|
|
142
|
-
|
|
173
|
+
auth=_BearerAuth(self._token_provider),
|
|
143
174
|
timeout=self.timeout,
|
|
144
175
|
)
|
|
145
176
|
|
|
@@ -294,6 +325,10 @@ class StorageClient:
|
|
|
294
325
|
class TrainerClient:
|
|
295
326
|
"""Client for launching and managing training runs.
|
|
296
327
|
|
|
328
|
+
``api_key`` is optional: when omitted the bearer resolves per request via
|
|
329
|
+
the credential seam (``ACT_AS_TOKEN_PATH`` / ``PLATFORM_API_KEY``). Pass
|
|
330
|
+
``api_key`` to override, or ``token_provider`` for a custom per-call source.
|
|
331
|
+
|
|
297
332
|
Example:
|
|
298
333
|
client = TrainerClient(api_key="sk_...", base_url="http://localhost:3000")
|
|
299
334
|
run_id = client.launch_training_run(
|
|
@@ -306,16 +341,19 @@ class TrainerClient:
|
|
|
306
341
|
print(f"Launched: {run_id}")
|
|
307
342
|
"""
|
|
308
343
|
|
|
309
|
-
api_key: str
|
|
344
|
+
api_key: str | None = None
|
|
310
345
|
base_url: str = field(default_factory=config.platform_url)
|
|
311
346
|
timeout: float = 30.0
|
|
347
|
+
token_provider: TokenProvider | None = None
|
|
348
|
+
_token_provider: TokenProvider = field(init=False, repr=False)
|
|
312
349
|
_http_client: httpx.Client = field(init=False, repr=False)
|
|
313
350
|
|
|
314
351
|
def __post_init__(self) -> None:
|
|
315
|
-
"""Initialize HTTP client
|
|
352
|
+
"""Initialize HTTP client; auth resolves per request, never baked here."""
|
|
353
|
+
self._token_provider = resolve_token_provider(self.api_key, self.token_provider)
|
|
316
354
|
self._http_client = httpx.Client(
|
|
317
355
|
base_url=self.base_url,
|
|
318
|
-
|
|
356
|
+
auth=_BearerAuth(self._token_provider),
|
|
319
357
|
timeout=self.timeout,
|
|
320
358
|
)
|
|
321
359
|
|
|
@@ -716,23 +754,32 @@ class RolloutClient:
|
|
|
716
754
|
raw file contents; they will be base64-encoded and sent inline.
|
|
717
755
|
|
|
718
756
|
Args:
|
|
719
|
-
api_key: Platform API key
|
|
720
|
-
platform-service validates.
|
|
757
|
+
api_key: Platform API key forwarded as the Bearer token
|
|
758
|
+
platform-service validates. Optional — when omitted the
|
|
759
|
+
bearer resolves per request via the credential seam
|
|
760
|
+
(``ACT_AS_TOKEN_PATH`` / ``PLATFORM_API_KEY``).
|
|
721
761
|
server_url: Base URL of platform-service. Defaults to
|
|
722
762
|
``config.platform_url()``; the ``/v1/rollout/stream`` path is
|
|
723
763
|
appended per request.
|
|
724
764
|
timeout: Per-request timeout in seconds (default 300 — rollouts can be slow).
|
|
765
|
+
token_provider: Custom per-call bearer source; overrides the seam when
|
|
766
|
+
``api_key`` is unset.
|
|
725
767
|
"""
|
|
726
768
|
|
|
727
769
|
_TERMINAL = {"rollout_completed", "worker_error", "cancelled", "error"}
|
|
728
770
|
|
|
729
771
|
def __init__(
|
|
730
772
|
self,
|
|
731
|
-
api_key: str,
|
|
773
|
+
api_key: str | None = None,
|
|
732
774
|
server_url: str | None = None,
|
|
733
775
|
timeout: float = 300.0,
|
|
776
|
+
*,
|
|
777
|
+
token_provider: TokenProvider | None = None,
|
|
734
778
|
) -> None:
|
|
735
|
-
|
|
779
|
+
# Bearer resolves per request (see stream_rollout): explicit api_key →
|
|
780
|
+
# token_provider → platform_bearer seam. Optional so a logged-in/CI
|
|
781
|
+
# caller need not pass one.
|
|
782
|
+
self._token_provider = resolve_token_provider(api_key, token_provider)
|
|
736
783
|
# Resolve at construction time, not import time, so env-var changes
|
|
737
784
|
# take effect (mirrors StorageClient/TrainerClient default_factory pattern).
|
|
738
785
|
# Target platform-service (the API-key gate), not the rollout-service
|
|
@@ -834,6 +881,12 @@ class RolloutClient:
|
|
|
834
881
|
env_metadata_bytes,
|
|
835
882
|
)
|
|
836
883
|
|
|
884
|
+
# Resolve the platform bearer once, per request (never frozen at
|
|
885
|
+
# construction): a rotating/expiring device or act-as token is picked
|
|
886
|
+
# up each call. Used for the platform-service header below AND, when the
|
|
887
|
+
# LLM leg hits the platform's own endpoint, as that leg's key.
|
|
888
|
+
bearer = self._token_provider()
|
|
889
|
+
|
|
837
890
|
# Resolve LLM URL lazily. The platform key is only auto-forwarded when
|
|
838
891
|
# the LLM endpoint is the platform's own LLM service — pointing at a
|
|
839
892
|
# third-party host (Azure OpenAI, Anthropic) requires an explicit
|
|
@@ -842,7 +895,7 @@ class RolloutClient:
|
|
|
842
895
|
resolved_llm_url = llm_base_url or platform_llm_url
|
|
843
896
|
if not llm_api_key:
|
|
844
897
|
if resolved_llm_url == platform_llm_url:
|
|
845
|
-
llm_api_key =
|
|
898
|
+
llm_api_key = bearer
|
|
846
899
|
else:
|
|
847
900
|
raise ValueError(
|
|
848
901
|
"llm_api_key is required when llm_base_url points outside the "
|
|
@@ -870,7 +923,7 @@ class RolloutClient:
|
|
|
870
923
|
# platform-service mounts the proxy at /v1/rollout/stream; it validates
|
|
871
924
|
# the platform key and forwards to rollout-service with an act_as JWT.
|
|
872
925
|
url = f"{self._server_url}/v1/rollout/stream"
|
|
873
|
-
headers = {"Authorization": f"Bearer {
|
|
926
|
+
headers = {"Authorization": f"Bearer {bearer}"}
|
|
874
927
|
|
|
875
928
|
with httpx.stream(
|
|
876
929
|
"POST",
|
|
@@ -960,6 +1013,8 @@ class RolloutClient:
|
|
|
960
1013
|
llm_api_key: str = "",
|
|
961
1014
|
llm_model: str = _VALIDATION_MODEL,
|
|
962
1015
|
max_turns: int = 4,
|
|
1016
|
+
check_group_reward: bool = True,
|
|
1017
|
+
group_reward_samples: int = 2,
|
|
963
1018
|
verbose: bool = True,
|
|
964
1019
|
) -> ValidationResult:
|
|
965
1020
|
"""Run rollouts on the first *n* examples and report pass/fail.
|
|
@@ -991,6 +1046,19 @@ class RolloutClient:
|
|
|
991
1046
|
``llm_base_url`` points outside the platform LLM
|
|
992
1047
|
endpoint (stream_rollout refuses to forward the
|
|
993
1048
|
platform key to a third-party host).
|
|
1049
|
+
check_group_reward: After the rollouts, run a REAL same-example
|
|
1050
|
+
group through rollout-service (one example,
|
|
1051
|
+
samples_per_example=N) so the env's
|
|
1052
|
+
``compute_group_reward`` executes server-side in
|
|
1053
|
+
the trainer image, over co-located siblings —
|
|
1054
|
+
the trainer/external-eval path. A server-side
|
|
1055
|
+
failure (raise or contract violation) comes back
|
|
1056
|
+
as ``group_reward_error`` and fails validation.
|
|
1057
|
+
Only fires when ``env_class`` is given and the
|
|
1058
|
+
env overrides the method.
|
|
1059
|
+
group_reward_samples: Size of that group (the batch's
|
|
1060
|
+
``samples_per_example``). Costs this many extra
|
|
1061
|
+
rollouts; ignored unless the group check runs.
|
|
994
1062
|
verbose: Print colored progress to stdout (default True for
|
|
995
1063
|
interactive/notebook UX). Set False for programmatic
|
|
996
1064
|
callers that consume the returned ValidationResult.
|
|
@@ -1026,6 +1094,18 @@ class RolloutClient:
|
|
|
1026
1094
|
env_cls_path, env_metadata_path, env_cls_bytes, env_metadata_bytes
|
|
1027
1095
|
)
|
|
1028
1096
|
|
|
1097
|
+
# compute_group_reward runs on a whole rollout GROUP, which the
|
|
1098
|
+
# per-example smoke above never forms (each is a group of 1). Run a real
|
|
1099
|
+
# same-example group server-side (run_group, below) so the env method
|
|
1100
|
+
# executes on the trainer's path; the server reports any failure as
|
|
1101
|
+
# group_reward_error. Needs the env_class, and is pointless unless the
|
|
1102
|
+
# env overrides the no-op default.
|
|
1103
|
+
want_group = False
|
|
1104
|
+
if check_group_reward and env_class is not None:
|
|
1105
|
+
from .validation import overrides_compute_group_reward
|
|
1106
|
+
|
|
1107
|
+
want_group = overrides_compute_group_reward(env_class)
|
|
1108
|
+
|
|
1029
1109
|
sample = examples[:n]
|
|
1030
1110
|
if verbose:
|
|
1031
1111
|
print(
|
|
@@ -1066,7 +1146,55 @@ class RolloutClient:
|
|
|
1066
1146
|
print(_err(f" Example {i} failed: {exc}"))
|
|
1067
1147
|
per_example.append(ExampleValidation(index=i, ok=False, error=str(exc)))
|
|
1068
1148
|
|
|
1069
|
-
|
|
1149
|
+
group_reward: ExampleValidation | None = None
|
|
1150
|
+
if want_group and sample and group_reward_samples >= 1:
|
|
1151
|
+
# Faithful check: run a REAL same-example group through
|
|
1152
|
+
# rollout-service (one example, samples_per_example=N) so the env's
|
|
1153
|
+
# compute_group_reward runs server-side in the trainer image, over
|
|
1154
|
+
# co-located siblings — exactly the trainer/external-eval path. A
|
|
1155
|
+
# server-side failure comes back as group_reward_error per rollout.
|
|
1156
|
+
if verbose:
|
|
1157
|
+
print(
|
|
1158
|
+
_info(
|
|
1159
|
+
f"\n Group reward — {group_reward_samples} server-side "
|
|
1160
|
+
"sibling(s) of example 0"
|
|
1161
|
+
)
|
|
1162
|
+
)
|
|
1163
|
+
try:
|
|
1164
|
+
events = self.run_group(
|
|
1165
|
+
sample[0],
|
|
1166
|
+
samples=group_reward_samples,
|
|
1167
|
+
env_cls_path=env_cls_path,
|
|
1168
|
+
env_metadata_path=env_metadata_path,
|
|
1169
|
+
env_cls_bytes=env_cls_bytes,
|
|
1170
|
+
env_metadata_bytes=env_metadata_bytes,
|
|
1171
|
+
llm_base_url=llm_base_url,
|
|
1172
|
+
llm_api_key=llm_api_key,
|
|
1173
|
+
llm_model=llm_model,
|
|
1174
|
+
max_turns=max_turns,
|
|
1175
|
+
verbose=verbose,
|
|
1176
|
+
)
|
|
1177
|
+
group_reward = self._assess_group_events(
|
|
1178
|
+
events, group_reward_samples, verbose
|
|
1179
|
+
)
|
|
1180
|
+
except RolloutNotFound:
|
|
1181
|
+
# The batch proxy (/v1/rollout/batch/stream) isn't deployed on
|
|
1182
|
+
# this server yet — skip rather than fail, so the SDK can land
|
|
1183
|
+
# ahead of platform-service. group_reward stays None; the offline
|
|
1184
|
+
# local check still covered shape.
|
|
1185
|
+
if verbose:
|
|
1186
|
+
print(
|
|
1187
|
+
_info(
|
|
1188
|
+
" compute_group_reward: skipped — server has no "
|
|
1189
|
+
"/rollout/batch/stream yet"
|
|
1190
|
+
)
|
|
1191
|
+
)
|
|
1192
|
+
except (RolloutError, RuntimeError) as exc:
|
|
1193
|
+
if verbose:
|
|
1194
|
+
print(_err(f" group reward check failed: {exc}"))
|
|
1195
|
+
group_reward = ExampleValidation(index=-1, ok=False, error=str(exc))
|
|
1196
|
+
|
|
1197
|
+
result = ValidationResult(examples=per_example, group_reward=group_reward)
|
|
1070
1198
|
if verbose:
|
|
1071
1199
|
print()
|
|
1072
1200
|
if result.ok:
|
|
@@ -1079,3 +1207,162 @@ class RolloutClient:
|
|
|
1079
1207
|
)
|
|
1080
1208
|
|
|
1081
1209
|
return result
|
|
1210
|
+
|
|
1211
|
+
def run_group(
|
|
1212
|
+
self,
|
|
1213
|
+
example: dict[str, Any],
|
|
1214
|
+
*,
|
|
1215
|
+
samples: int,
|
|
1216
|
+
env_cls_path: str | None = None,
|
|
1217
|
+
env_metadata_path: str | None = None,
|
|
1218
|
+
env_cls_bytes: bytes | None = None,
|
|
1219
|
+
env_metadata_bytes: bytes | None = None,
|
|
1220
|
+
llm_base_url: str | None = None,
|
|
1221
|
+
llm_api_key: str = "",
|
|
1222
|
+
llm_model: str = _VALIDATION_MODEL,
|
|
1223
|
+
max_turns: int = 4,
|
|
1224
|
+
verbose: bool = True,
|
|
1225
|
+
) -> list[dict[str, Any]]:
|
|
1226
|
+
"""Run ONE example as a real ``samples``-member group; return its
|
|
1227
|
+
``rollout_completed`` events.
|
|
1228
|
+
|
|
1229
|
+
Submits a one-row batch with ``samples_per_example=samples`` to
|
|
1230
|
+
``/v1/rollout/batch/stream``. rollout-service co-locates the siblings on
|
|
1231
|
+
one worker and runs ``env.compute_group_reward`` over them — the same
|
|
1232
|
+
path the trainer/external-eval use, in the trainer image. Each event
|
|
1233
|
+
carries ``success``, ``rewards`` and (on a server new enough to report
|
|
1234
|
+
it) ``group_reward_error``. Raises the same typed errors as
|
|
1235
|
+
``stream_rollout`` on a non-200.
|
|
1236
|
+
"""
|
|
1237
|
+
env = self._build_env(
|
|
1238
|
+
env_cls_path, env_metadata_path, env_cls_bytes, env_metadata_bytes
|
|
1239
|
+
)
|
|
1240
|
+
# Resolve the platform bearer once, per request (never frozen at
|
|
1241
|
+
# construction) — used for the request header below AND, when the LLM leg
|
|
1242
|
+
# hits the platform's own endpoint, as that leg's key. Mirrors stream_rollout.
|
|
1243
|
+
bearer = self._token_provider()
|
|
1244
|
+
|
|
1245
|
+
# The platform key is only auto-forwarded to the platform's own LLM host
|
|
1246
|
+
# (see stream_rollout for the no-leak rationale).
|
|
1247
|
+
platform_llm_url = config.llm_url()
|
|
1248
|
+
resolved_llm_url = llm_base_url or platform_llm_url
|
|
1249
|
+
if not llm_api_key:
|
|
1250
|
+
if resolved_llm_url == platform_llm_url:
|
|
1251
|
+
llm_api_key = bearer
|
|
1252
|
+
else:
|
|
1253
|
+
raise ValueError(
|
|
1254
|
+
"llm_api_key is required when llm_base_url points outside the "
|
|
1255
|
+
f"platform LLM endpoint ({platform_llm_url}). Refusing to "
|
|
1256
|
+
"forward the platform API key to a third-party host."
|
|
1257
|
+
)
|
|
1258
|
+
|
|
1259
|
+
payload = {
|
|
1260
|
+
"dataset_bytes": base64.b64encode(json.dumps(example).encode()).decode(),
|
|
1261
|
+
"is_dataset_standardized": False,
|
|
1262
|
+
# One example → one group; compute_group_reward needs all siblings
|
|
1263
|
+
# co-located on a single worker anyway. Pin to 1 so rollout-service
|
|
1264
|
+
# doesn't spin up extra workers that get an empty partition and crash.
|
|
1265
|
+
"concurrent_workers": 1,
|
|
1266
|
+
"env": env,
|
|
1267
|
+
"llm": {
|
|
1268
|
+
"base_url": resolved_llm_url,
|
|
1269
|
+
"api_key": llm_api_key,
|
|
1270
|
+
"model": llm_model,
|
|
1271
|
+
},
|
|
1272
|
+
"options": {"max_turns": max_turns, "samples_per_example": samples},
|
|
1273
|
+
}
|
|
1274
|
+
# platform-service mounts the batch proxy at /v1/rollout/batch/stream; it
|
|
1275
|
+
# validates the platform key and forwards to rollout-service with an
|
|
1276
|
+
# act_as JWT, same as the single /v1/rollout/stream proxy.
|
|
1277
|
+
url = f"{self._server_url}/v1/rollout/batch/stream"
|
|
1278
|
+
headers = {"Authorization": f"Bearer {bearer}"}
|
|
1279
|
+
|
|
1280
|
+
completed: list[dict[str, Any]] = []
|
|
1281
|
+
with httpx.stream(
|
|
1282
|
+
"POST", url, json=payload, headers=headers, timeout=self._timeout
|
|
1283
|
+
) as response:
|
|
1284
|
+
if response.status_code != 200:
|
|
1285
|
+
body = response.read().decode()
|
|
1286
|
+
if response.status_code in (401, 403):
|
|
1287
|
+
raise AuthenticationError(body[:300], response.status_code)
|
|
1288
|
+
if response.status_code == 404:
|
|
1289
|
+
raise RolloutNotFound(body[:300], response.status_code)
|
|
1290
|
+
if 500 <= response.status_code < 600:
|
|
1291
|
+
raise RolloutServerError(body[:300], response.status_code)
|
|
1292
|
+
raise RolloutError(body[:300], response.status_code)
|
|
1293
|
+
|
|
1294
|
+
for event in _iter_sse(response):
|
|
1295
|
+
etype = event.get("event")
|
|
1296
|
+
if etype == "batch_started":
|
|
1297
|
+
if verbose:
|
|
1298
|
+
print(
|
|
1299
|
+
_info(
|
|
1300
|
+
f" group batch started ({event.get('total')} rollouts)"
|
|
1301
|
+
)
|
|
1302
|
+
)
|
|
1303
|
+
elif etype == "rollout_completed":
|
|
1304
|
+
completed.append(event)
|
|
1305
|
+
elif etype == "worker_error":
|
|
1306
|
+
# A sandbox process crashed. Non-fatal to the group on its
|
|
1307
|
+
# own — the verdict comes from the rollout_completed events
|
|
1308
|
+
# (_assess_group_events fails if none succeeded). Surfaced
|
|
1309
|
+
# for visibility, not raised.
|
|
1310
|
+
if verbose:
|
|
1311
|
+
print(_err(f" worker_error: {str(event.get('error'))[:200]}"))
|
|
1312
|
+
elif etype == "error":
|
|
1313
|
+
raise RolloutError(str(event.get("error"))[:300], 500)
|
|
1314
|
+
elif etype in ("batch_completed", "cancelled"):
|
|
1315
|
+
break
|
|
1316
|
+
return completed
|
|
1317
|
+
|
|
1318
|
+
def _assess_group_events(
|
|
1319
|
+
self,
|
|
1320
|
+
events: list[dict[str, Any]],
|
|
1321
|
+
samples: int,
|
|
1322
|
+
verbose: bool,
|
|
1323
|
+
) -> ExampleValidation:
|
|
1324
|
+
"""Turn a group's ``rollout_completed`` events into a pass/fail verdict.
|
|
1325
|
+
|
|
1326
|
+
Fails when the server reported a ``group_reward_error`` for any sibling
|
|
1327
|
+
— compute_group_reward raised or violated its contract (see
|
|
1328
|
+
rollout-service's ``_compute_group_rewards_safe``). Also fails if every
|
|
1329
|
+
group rollout failed (nothing to assess). Index -1.
|
|
1330
|
+
|
|
1331
|
+
Note: a server that predates ``group_reward_error`` can't report a
|
|
1332
|
+
failure, so a green verdict there means "no failure observed", not
|
|
1333
|
+
"verified" — the offline local check still covers shape regardless.
|
|
1334
|
+
"""
|
|
1335
|
+
errors = [
|
|
1336
|
+
e["group_reward_error"] for e in events if e.get("group_reward_error")
|
|
1337
|
+
]
|
|
1338
|
+
if errors:
|
|
1339
|
+
msg = str(errors[0])
|
|
1340
|
+
if verbose:
|
|
1341
|
+
print(_err(f" compute_group_reward FAILED server-side: {msg}"))
|
|
1342
|
+
return ExampleValidation(index=-1, ok=False, error=msg)
|
|
1343
|
+
|
|
1344
|
+
succeeded = [e for e in events if e.get("success")]
|
|
1345
|
+
if not succeeded:
|
|
1346
|
+
first = next(
|
|
1347
|
+
(e.get("error") for e in events if e.get("error")),
|
|
1348
|
+
"no successful group rollouts",
|
|
1349
|
+
)
|
|
1350
|
+
if verbose:
|
|
1351
|
+
print(
|
|
1352
|
+
_err(
|
|
1353
|
+
" group reward not validated — all group rollouts "
|
|
1354
|
+
f"failed: {first}"
|
|
1355
|
+
)
|
|
1356
|
+
)
|
|
1357
|
+
return ExampleValidation(
|
|
1358
|
+
index=-1, ok=False, error=f"all group rollouts failed: {first}"
|
|
1359
|
+
)
|
|
1360
|
+
|
|
1361
|
+
if verbose:
|
|
1362
|
+
print(
|
|
1363
|
+
_ok(
|
|
1364
|
+
" compute_group_reward OK server-side on a group of "
|
|
1365
|
+
f"{len(succeeded)}"
|
|
1366
|
+
)
|
|
1367
|
+
)
|
|
1368
|
+
return ExampleValidation(index=-1, ok=True)
|