benchmax 0.1.2.dev28__py3-none-any.whl → 0.1.2.dev29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
benchmax/cli.py ADDED
@@ -0,0 +1,78 @@
1
+ """``castform`` CLI — browser-based login for the SDK.
2
+
3
+ Commands: ``login`` (device authorization), ``logout``, ``whoami``. The login
4
+ flow + the reusable ``ensure_session`` live in :mod:`benchmax.platform.login`;
5
+ this module is the thin argparse wrapper. After ``castform login`` the SDK
6
+ resolves its bearer from ``~/.castform`` automatically — no API key or URL.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ import sys
13
+
14
+ from benchmax.platform import credentials
15
+ from benchmax.platform.device_auth import DeviceAuthError
16
+ from benchmax.platform.login import _login
17
+
18
+
19
+ def _cmd_login(args: argparse.Namespace) -> int:
20
+ env = "staging" if args.env == "staging" else None
21
+ try:
22
+ _login(env)
23
+ except DeviceAuthError as exc:
24
+ print(f"Login failed: {exc}", file=sys.stderr)
25
+ return 1
26
+ print(f"\n✓ Logged in to {args.env}.")
27
+ return 0
28
+
29
+
30
+ def _cmd_logout(_args: argparse.Namespace) -> int:
31
+ credentials.clear_castform_session()
32
+ print("✓ Logged out.")
33
+ return 0
34
+
35
+
36
+ def _cmd_whoami(_args: argparse.Namespace) -> int:
37
+ session = credentials.read_castform_session()
38
+ if not session:
39
+ print("Not logged in. Run `castform login`.", file=sys.stderr)
40
+ return 1
41
+ env = session.get("env", "prod")
42
+ jwt = credentials._session_jwt() # mints from the session; None if invalid/expired/offline
43
+ if not jwt:
44
+ print(
45
+ f"Session present (env: {env}), but couldn't reach auth-service to "
46
+ "verify it (offline, or the session expired). If this persists, run "
47
+ "`castform login` again.",
48
+ file=sys.stderr,
49
+ )
50
+ return 1
51
+ claims = credentials._jwt_claims(jwt)
52
+ who = claims.get("email") or claims.get("sub", "<unknown>")
53
+ print(f"Logged in as {who} (env: {env}).")
54
+ return 0
55
+
56
+
57
+ def main(argv: list[str] | None = None) -> int:
58
+ parser = argparse.ArgumentParser(prog="castform", description="Castform CLI")
59
+ sub = parser.add_subparsers(dest="command", required=True)
60
+
61
+ p_login = sub.add_parser("login", help="Sign in via your browser")
62
+ p_login.add_argument(
63
+ "--env",
64
+ choices=["prod", "staging"],
65
+ default="prod",
66
+ help="Environment to sign in to (staging is internal-only)",
67
+ )
68
+ p_login.set_defaults(func=_cmd_login)
69
+
70
+ sub.add_parser("logout", help="Clear the cached session").set_defaults(func=_cmd_logout)
71
+ sub.add_parser("whoami", help="Show the current login").set_defaults(func=_cmd_whoami)
72
+
73
+ args = parser.parse_args(argv)
74
+ return args.func(args)
75
+
76
+
77
+ if __name__ == "__main__":
78
+ sys.exit(main())
benchmax/config.py CHANGED
@@ -16,7 +16,38 @@ DEFAULT_BASE_DOMAIN = "castform.com"
16
16
 
17
17
 
18
18
  def base_domain() -> str:
19
- return os.environ.get("CASTFORM_BASE_DOMAIN", DEFAULT_BASE_DOMAIN)
19
+ """Resolve the platform base domain.
20
+
21
+ Precedence: explicit ``CASTFORM_BASE_DOMAIN`` → the cached device-auth
22
+ session's ``env`` (``staging`` → ``castform.dev``) → ``prod`` default
23
+ (``castform.com``). The ``env`` claim travels with the credential, so a
24
+ logged-in SDK routes to the same environment it authenticated against —
25
+ URL and credential can't desync. A prod session carries no ``env`` marker
26
+ (``None`` → prod), so only internal staging logins deviate from the default.
27
+ """
28
+ override = os.environ.get("CASTFORM_BASE_DOMAIN")
29
+ if override:
30
+ return override
31
+ if _session_env() == "staging":
32
+ return "castform.dev"
33
+ return DEFAULT_BASE_DOMAIN
34
+
35
+
36
+ def _session_env() -> str | None:
37
+ """The ``env`` from the cached device-auth session, if any.
38
+
39
+ Lazy import: ``config`` is a leaf that ``benchmax.platform`` depends on, so a
40
+ top-level import would cycle (platform/__init__ → client → config)."""
41
+ try:
42
+ from benchmax.platform.credentials import read_castform_session
43
+
44
+ session = read_castform_session()
45
+ except Exception:
46
+ return None
47
+ if not session:
48
+ return None
49
+ env = session.get("env")
50
+ return env if isinstance(env, str) else None
20
51
 
21
52
 
22
53
  def platform_url() -> str:
@@ -38,3 +69,13 @@ def web_app_url() -> str:
38
69
  def llm_url() -> str:
39
70
  """OpenAI-compatible LLM endpoint hosted by the platform."""
40
71
  return os.environ.get("CASTFORM_LLM_URL") or f"https://llm.{base_domain()}/v1"
72
+
73
+
74
+ def auth_url() -> str:
75
+ """Auth-service base URL (device-authorization + JWT mint endpoints).
76
+
77
+ Used by ``castform login`` and the per-process session→JWT mint. Derives from
78
+ the same base domain as everything else, so a session minted against ``staging``
79
+ talks to ``auth.castform.dev`` and a ``prod`` session to ``auth.castform.com``.
80
+ """
81
+ return os.environ.get("CASTFORM_AUTH_URL") or f"https://auth.{base_domain()}"
@@ -1,15 +1,25 @@
1
1
  """Castform platform clients (storage, training runs, rollout)."""
2
2
 
3
3
  from .client import RolloutClient, StorageClient, TrainerClient
4
+ from .credentials import platform_bearer
4
5
  from .training_run import UploadedTrainingRun, upload_training_run
5
6
  from .validation import ValidationReport, validate_env
6
7
 
8
+ # Imported last: login depends on credentials/device_auth (siblings), so this
9
+ # stays cycle-free as long as those are already loaded by the imports above.
10
+ from .login import ensure_session
11
+
7
12
  __all__ = [
8
13
  "RolloutClient",
9
14
  "StorageClient",
10
15
  "TrainerClient",
11
16
  "UploadedTrainingRun",
12
17
  "ValidationReport",
18
+ # The seam token-getter: generated scripts pass it to a raw OpenAI client
19
+ # (e.g. the traces pivot), so it's part of the public surface alongside
20
+ # ensure_session — not just an internal credentials helper.
21
+ "platform_bearer",
22
+ "ensure_session",
13
23
  "upload_training_run",
14
24
  "validate_env",
15
25
  ]
@@ -14,10 +14,9 @@ from typing import TYPE_CHECKING, Any
14
14
 
15
15
  import httpx
16
16
 
17
- logger = logging.getLogger(__name__)
18
-
19
17
  from benchmax import config
20
18
 
19
+ from .credentials import TokenProvider, resolve_token_provider
21
20
  from .exceptions import (
22
21
  AuthenticationError,
23
22
  JobLaunchError,
@@ -28,6 +27,8 @@ from .exceptions import (
28
27
  TrainerError,
29
28
  )
30
29
 
30
+ logger = logging.getLogger(__name__)
31
+
31
32
  if TYPE_CHECKING:
32
33
  from types import ModuleType
33
34
 
@@ -74,10 +75,17 @@ class ValidationResult:
74
75
  """
75
76
 
76
77
  examples: list[ExampleValidation]
78
+ # Outcome of the compute_group_reward contract check, run on the real
79
+ # smoke-rollout transcripts. None when the env has no group reward, the
80
+ # env class wasn't supplied, or the check was skipped (deps not installed
81
+ # locally — it runs on the trainer instead). Its index is -1.
82
+ group_reward: ExampleValidation | None = None
77
83
 
78
84
  @property
79
85
  def ok(self) -> bool:
80
- return all(ex.ok for ex in self.examples)
86
+ rollouts_ok = all(ex.ok for ex in self.examples)
87
+ group_ok = self.group_reward is None or self.group_reward.ok
88
+ return rollouts_ok and group_ok
81
89
 
82
90
  def __bool__(self) -> bool:
83
91
  return self.ok
@@ -110,6 +118,22 @@ def _file_hash(content: bytes, length: int = 8) -> str:
110
118
  return hashlib.sha256(content).hexdigest()[:length]
111
119
 
112
120
 
121
+ class _BearerAuth(httpx.Auth):
122
+ """Resolve the platform bearer per request via ``token_provider``.
123
+
124
+ Built once but called on every request, so the auth header is never frozen
125
+ at construction — a rotating/expiring device or act-as token is picked up
126
+ each call (the "token expires mid-run" bug ``credentials.py`` warns about).
127
+ """
128
+
129
+ def __init__(self, token_provider: TokenProvider) -> None:
130
+ self._token_provider = token_provider
131
+
132
+ def auth_flow(self, request: httpx.Request):
133
+ request.headers["Authorization"] = f"Bearer {self._token_provider()}"
134
+ yield request
135
+
136
+
113
137
  @dataclass
114
138
  class StorageClient:
115
139
  """Client for uploading files to storage via pre-signed URLs.
@@ -117,6 +141,10 @@ class StorageClient:
117
141
  Uses the ``GET /api/storage/upload-url`` endpoint to obtain a pre-signed
118
142
  upload URL, then PUTs the file content directly to that URL.
119
143
 
144
+ ``api_key`` is optional: when omitted the bearer resolves per request via
145
+ the credential seam (``ACT_AS_TOKEN_PATH`` / ``PLATFORM_API_KEY``). Pass
146
+ ``api_key`` to override, or ``token_provider`` for a custom per-call source.
147
+
120
148
  Example:
121
149
  client = StorageClient(api_key="sk_...", base_url="http://localhost:3000")
122
150
  result = client.upload_file(
@@ -127,19 +155,22 @@ class StorageClient:
127
155
  print(f"Uploaded to {result['blobPath']}")
128
156
  """
129
157
 
130
- api_key: str
158
+ api_key: str | None = None
131
159
  base_url: str = field(default_factory=config.platform_url)
132
160
  timeout: float = 60.0
133
161
  # SAS-URL PUTs are bounded by file size, not API latency. Default to
134
162
  # 30 minutes so multi-GB datasets don't time out at the platform-API timeout.
135
163
  upload_timeout: float = 1800.0
164
+ token_provider: TokenProvider | None = None
165
+ _token_provider: TokenProvider = field(init=False, repr=False)
136
166
  _http_client: httpx.Client = field(init=False, repr=False)
137
167
 
138
168
  def __post_init__(self) -> None:
139
- """Initialize HTTP client with auth headers."""
169
+ """Initialize HTTP client; auth resolves per request, never baked here."""
170
+ self._token_provider = resolve_token_provider(self.api_key, self.token_provider)
140
171
  self._http_client = httpx.Client(
141
172
  base_url=self.base_url,
142
- headers={"Authorization": f"Bearer {self.api_key}"},
173
+ auth=_BearerAuth(self._token_provider),
143
174
  timeout=self.timeout,
144
175
  )
145
176
 
@@ -294,6 +325,10 @@ class StorageClient:
294
325
  class TrainerClient:
295
326
  """Client for launching and managing training runs.
296
327
 
328
+ ``api_key`` is optional: when omitted the bearer resolves per request via
329
+ the credential seam (``ACT_AS_TOKEN_PATH`` / ``PLATFORM_API_KEY``). Pass
330
+ ``api_key`` to override, or ``token_provider`` for a custom per-call source.
331
+
297
332
  Example:
298
333
  client = TrainerClient(api_key="sk_...", base_url="http://localhost:3000")
299
334
  run_id = client.launch_training_run(
@@ -306,16 +341,19 @@ class TrainerClient:
306
341
  print(f"Launched: {run_id}")
307
342
  """
308
343
 
309
- api_key: str
344
+ api_key: str | None = None
310
345
  base_url: str = field(default_factory=config.platform_url)
311
346
  timeout: float = 30.0
347
+ token_provider: TokenProvider | None = None
348
+ _token_provider: TokenProvider = field(init=False, repr=False)
312
349
  _http_client: httpx.Client = field(init=False, repr=False)
313
350
 
314
351
  def __post_init__(self) -> None:
315
- """Initialize HTTP client with auth headers."""
352
+ """Initialize HTTP client; auth resolves per request, never baked here."""
353
+ self._token_provider = resolve_token_provider(self.api_key, self.token_provider)
316
354
  self._http_client = httpx.Client(
317
355
  base_url=self.base_url,
318
- headers={"Authorization": f"Bearer {self.api_key}"},
356
+ auth=_BearerAuth(self._token_provider),
319
357
  timeout=self.timeout,
320
358
  )
321
359
 
@@ -716,23 +754,32 @@ class RolloutClient:
716
754
  raw file contents; they will be base64-encoded and sent inline.
717
755
 
718
756
  Args:
719
- api_key: Platform API key (``sk_``); forwarded as the Bearer token
720
- platform-service validates.
757
+ api_key: Platform API key forwarded as the Bearer token
758
+ platform-service validates. Optional — when omitted the
759
+ bearer resolves per request via the credential seam
760
+ (``ACT_AS_TOKEN_PATH`` / ``PLATFORM_API_KEY``).
721
761
  server_url: Base URL of platform-service. Defaults to
722
762
  ``config.platform_url()``; the ``/v1/rollout/stream`` path is
723
763
  appended per request.
724
764
  timeout: Per-request timeout in seconds (default 300 — rollouts can be slow).
765
+ token_provider: Custom per-call bearer source; overrides the seam when
766
+ ``api_key`` is unset.
725
767
  """
726
768
 
727
769
  _TERMINAL = {"rollout_completed", "worker_error", "cancelled", "error"}
728
770
 
729
771
  def __init__(
730
772
  self,
731
- api_key: str,
773
+ api_key: str | None = None,
732
774
  server_url: str | None = None,
733
775
  timeout: float = 300.0,
776
+ *,
777
+ token_provider: TokenProvider | None = None,
734
778
  ) -> None:
735
- self._api_key = api_key
779
+ # Bearer resolves per request (see stream_rollout): explicit api_key
780
+ # token_provider → platform_bearer seam. Optional so a logged-in/CI
781
+ # caller need not pass one.
782
+ self._token_provider = resolve_token_provider(api_key, token_provider)
736
783
  # Resolve at construction time, not import time, so env-var changes
737
784
  # take effect (mirrors StorageClient/TrainerClient default_factory pattern).
738
785
  # Target platform-service (the API-key gate), not the rollout-service
@@ -834,6 +881,12 @@ class RolloutClient:
834
881
  env_metadata_bytes,
835
882
  )
836
883
 
884
+ # Resolve the platform bearer once, per request (never frozen at
885
+ # construction): a rotating/expiring device or act-as token is picked
886
+ # up each call. Used for the platform-service header below AND, when the
887
+ # LLM leg hits the platform's own endpoint, as that leg's key.
888
+ bearer = self._token_provider()
889
+
837
890
  # Resolve LLM URL lazily. The platform key is only auto-forwarded when
838
891
  # the LLM endpoint is the platform's own LLM service — pointing at a
839
892
  # third-party host (Azure OpenAI, Anthropic) requires an explicit
@@ -842,7 +895,7 @@ class RolloutClient:
842
895
  resolved_llm_url = llm_base_url or platform_llm_url
843
896
  if not llm_api_key:
844
897
  if resolved_llm_url == platform_llm_url:
845
- llm_api_key = self._api_key
898
+ llm_api_key = bearer
846
899
  else:
847
900
  raise ValueError(
848
901
  "llm_api_key is required when llm_base_url points outside the "
@@ -870,7 +923,7 @@ class RolloutClient:
870
923
  # platform-service mounts the proxy at /v1/rollout/stream; it validates
871
924
  # the platform key and forwards to rollout-service with an act_as JWT.
872
925
  url = f"{self._server_url}/v1/rollout/stream"
873
- headers = {"Authorization": f"Bearer {self._api_key}"}
926
+ headers = {"Authorization": f"Bearer {bearer}"}
874
927
 
875
928
  with httpx.stream(
876
929
  "POST",
@@ -960,6 +1013,8 @@ class RolloutClient:
960
1013
  llm_api_key: str = "",
961
1014
  llm_model: str = _VALIDATION_MODEL,
962
1015
  max_turns: int = 4,
1016
+ check_group_reward: bool = True,
1017
+ group_reward_samples: int = 2,
963
1018
  verbose: bool = True,
964
1019
  ) -> ValidationResult:
965
1020
  """Run rollouts on the first *n* examples and report pass/fail.
@@ -991,6 +1046,19 @@ class RolloutClient:
991
1046
  ``llm_base_url`` points outside the platform LLM
992
1047
  endpoint (stream_rollout refuses to forward the
993
1048
  platform key to a third-party host).
1049
+ check_group_reward: After the rollouts, run a REAL same-example
1050
+ group through rollout-service (one example,
1051
+ samples_per_example=N) so the env's
1052
+ ``compute_group_reward`` executes server-side in
1053
+ the trainer image, over co-located siblings —
1054
+ the trainer/external-eval path. A server-side
1055
+ failure (raise or contract violation) comes back
1056
+ as ``group_reward_error`` and fails validation.
1057
+ Only fires when ``env_class`` is given and the
1058
+ env overrides the method.
1059
+ group_reward_samples: Size of that group (the batch's
1060
+ ``samples_per_example``). Costs this many extra
1061
+ rollouts; ignored unless the group check runs.
994
1062
  verbose: Print colored progress to stdout (default True for
995
1063
  interactive/notebook UX). Set False for programmatic
996
1064
  callers that consume the returned ValidationResult.
@@ -1026,6 +1094,18 @@ class RolloutClient:
1026
1094
  env_cls_path, env_metadata_path, env_cls_bytes, env_metadata_bytes
1027
1095
  )
1028
1096
 
1097
+ # compute_group_reward runs on a whole rollout GROUP, which the
1098
+ # per-example smoke above never forms (each is a group of 1). Run a real
1099
+ # same-example group server-side (run_group, below) so the env method
1100
+ # executes on the trainer's path; the server reports any failure as
1101
+ # group_reward_error. Needs the env_class, and is pointless unless the
1102
+ # env overrides the no-op default.
1103
+ want_group = False
1104
+ if check_group_reward and env_class is not None:
1105
+ from .validation import overrides_compute_group_reward
1106
+
1107
+ want_group = overrides_compute_group_reward(env_class)
1108
+
1029
1109
  sample = examples[:n]
1030
1110
  if verbose:
1031
1111
  print(
@@ -1066,7 +1146,55 @@ class RolloutClient:
1066
1146
  print(_err(f" Example {i} failed: {exc}"))
1067
1147
  per_example.append(ExampleValidation(index=i, ok=False, error=str(exc)))
1068
1148
 
1069
- result = ValidationResult(examples=per_example)
1149
+ group_reward: ExampleValidation | None = None
1150
+ if want_group and sample and group_reward_samples >= 1:
1151
+ # Faithful check: run a REAL same-example group through
1152
+ # rollout-service (one example, samples_per_example=N) so the env's
1153
+ # compute_group_reward runs server-side in the trainer image, over
1154
+ # co-located siblings — exactly the trainer/external-eval path. A
1155
+ # server-side failure comes back as group_reward_error per rollout.
1156
+ if verbose:
1157
+ print(
1158
+ _info(
1159
+ f"\n Group reward — {group_reward_samples} server-side "
1160
+ "sibling(s) of example 0"
1161
+ )
1162
+ )
1163
+ try:
1164
+ events = self.run_group(
1165
+ sample[0],
1166
+ samples=group_reward_samples,
1167
+ env_cls_path=env_cls_path,
1168
+ env_metadata_path=env_metadata_path,
1169
+ env_cls_bytes=env_cls_bytes,
1170
+ env_metadata_bytes=env_metadata_bytes,
1171
+ llm_base_url=llm_base_url,
1172
+ llm_api_key=llm_api_key,
1173
+ llm_model=llm_model,
1174
+ max_turns=max_turns,
1175
+ verbose=verbose,
1176
+ )
1177
+ group_reward = self._assess_group_events(
1178
+ events, group_reward_samples, verbose
1179
+ )
1180
+ except RolloutNotFound:
1181
+ # The batch proxy (/v1/rollout/batch/stream) isn't deployed on
1182
+ # this server yet — skip rather than fail, so the SDK can land
1183
+ # ahead of platform-service. group_reward stays None; the offline
1184
+ # local check still covered shape.
1185
+ if verbose:
1186
+ print(
1187
+ _info(
1188
+ " compute_group_reward: skipped — server has no "
1189
+ "/rollout/batch/stream yet"
1190
+ )
1191
+ )
1192
+ except (RolloutError, RuntimeError) as exc:
1193
+ if verbose:
1194
+ print(_err(f" group reward check failed: {exc}"))
1195
+ group_reward = ExampleValidation(index=-1, ok=False, error=str(exc))
1196
+
1197
+ result = ValidationResult(examples=per_example, group_reward=group_reward)
1070
1198
  if verbose:
1071
1199
  print()
1072
1200
  if result.ok:
@@ -1079,3 +1207,162 @@ class RolloutClient:
1079
1207
  )
1080
1208
 
1081
1209
  return result
1210
+
1211
+ def run_group(
1212
+ self,
1213
+ example: dict[str, Any],
1214
+ *,
1215
+ samples: int,
1216
+ env_cls_path: str | None = None,
1217
+ env_metadata_path: str | None = None,
1218
+ env_cls_bytes: bytes | None = None,
1219
+ env_metadata_bytes: bytes | None = None,
1220
+ llm_base_url: str | None = None,
1221
+ llm_api_key: str = "",
1222
+ llm_model: str = _VALIDATION_MODEL,
1223
+ max_turns: int = 4,
1224
+ verbose: bool = True,
1225
+ ) -> list[dict[str, Any]]:
1226
+ """Run ONE example as a real ``samples``-member group; return its
1227
+ ``rollout_completed`` events.
1228
+
1229
+ Submits a one-row batch with ``samples_per_example=samples`` to
1230
+ ``/v1/rollout/batch/stream``. rollout-service co-locates the siblings on
1231
+ one worker and runs ``env.compute_group_reward`` over them — the same
1232
+ path the trainer/external-eval use, in the trainer image. Each event
1233
+ carries ``success``, ``rewards`` and (on a server new enough to report
1234
+ it) ``group_reward_error``. Raises the same typed errors as
1235
+ ``stream_rollout`` on a non-200.
1236
+ """
1237
+ env = self._build_env(
1238
+ env_cls_path, env_metadata_path, env_cls_bytes, env_metadata_bytes
1239
+ )
1240
+ # Resolve the platform bearer once, per request (never frozen at
1241
+ # construction) — used for the request header below AND, when the LLM leg
1242
+ # hits the platform's own endpoint, as that leg's key. Mirrors stream_rollout.
1243
+ bearer = self._token_provider()
1244
+
1245
+ # The platform key is only auto-forwarded to the platform's own LLM host
1246
+ # (see stream_rollout for the no-leak rationale).
1247
+ platform_llm_url = config.llm_url()
1248
+ resolved_llm_url = llm_base_url or platform_llm_url
1249
+ if not llm_api_key:
1250
+ if resolved_llm_url == platform_llm_url:
1251
+ llm_api_key = bearer
1252
+ else:
1253
+ raise ValueError(
1254
+ "llm_api_key is required when llm_base_url points outside the "
1255
+ f"platform LLM endpoint ({platform_llm_url}). Refusing to "
1256
+ "forward the platform API key to a third-party host."
1257
+ )
1258
+
1259
+ payload = {
1260
+ "dataset_bytes": base64.b64encode(json.dumps(example).encode()).decode(),
1261
+ "is_dataset_standardized": False,
1262
+ # One example → one group; compute_group_reward needs all siblings
1263
+ # co-located on a single worker anyway. Pin to 1 so rollout-service
1264
+ # doesn't spin up extra workers that get an empty partition and crash.
1265
+ "concurrent_workers": 1,
1266
+ "env": env,
1267
+ "llm": {
1268
+ "base_url": resolved_llm_url,
1269
+ "api_key": llm_api_key,
1270
+ "model": llm_model,
1271
+ },
1272
+ "options": {"max_turns": max_turns, "samples_per_example": samples},
1273
+ }
1274
+ # platform-service mounts the batch proxy at /v1/rollout/batch/stream; it
1275
+ # validates the platform key and forwards to rollout-service with an
1276
+ # act_as JWT, same as the single /v1/rollout/stream proxy.
1277
+ url = f"{self._server_url}/v1/rollout/batch/stream"
1278
+ headers = {"Authorization": f"Bearer {bearer}"}
1279
+
1280
+ completed: list[dict[str, Any]] = []
1281
+ with httpx.stream(
1282
+ "POST", url, json=payload, headers=headers, timeout=self._timeout
1283
+ ) as response:
1284
+ if response.status_code != 200:
1285
+ body = response.read().decode()
1286
+ if response.status_code in (401, 403):
1287
+ raise AuthenticationError(body[:300], response.status_code)
1288
+ if response.status_code == 404:
1289
+ raise RolloutNotFound(body[:300], response.status_code)
1290
+ if 500 <= response.status_code < 600:
1291
+ raise RolloutServerError(body[:300], response.status_code)
1292
+ raise RolloutError(body[:300], response.status_code)
1293
+
1294
+ for event in _iter_sse(response):
1295
+ etype = event.get("event")
1296
+ if etype == "batch_started":
1297
+ if verbose:
1298
+ print(
1299
+ _info(
1300
+ f" group batch started ({event.get('total')} rollouts)"
1301
+ )
1302
+ )
1303
+ elif etype == "rollout_completed":
1304
+ completed.append(event)
1305
+ elif etype == "worker_error":
1306
+ # A sandbox process crashed. Non-fatal to the group on its
1307
+ # own — the verdict comes from the rollout_completed events
1308
+ # (_assess_group_events fails if none succeeded). Surfaced
1309
+ # for visibility, not raised.
1310
+ if verbose:
1311
+ print(_err(f" worker_error: {str(event.get('error'))[:200]}"))
1312
+ elif etype == "error":
1313
+ raise RolloutError(str(event.get("error"))[:300], 500)
1314
+ elif etype in ("batch_completed", "cancelled"):
1315
+ break
1316
+ return completed
1317
+
1318
+ def _assess_group_events(
1319
+ self,
1320
+ events: list[dict[str, Any]],
1321
+ samples: int,
1322
+ verbose: bool,
1323
+ ) -> ExampleValidation:
1324
+ """Turn a group's ``rollout_completed`` events into a pass/fail verdict.
1325
+
1326
+ Fails when the server reported a ``group_reward_error`` for any sibling
1327
+ — compute_group_reward raised or violated its contract (see
1328
+ rollout-service's ``_compute_group_rewards_safe``). Also fails if every
1329
+ group rollout failed (nothing to assess). Index -1.
1330
+
1331
+ Note: a server that predates ``group_reward_error`` can't report a
1332
+ failure, so a green verdict there means "no failure observed", not
1333
+ "verified" — the offline local check still covers shape regardless.
1334
+ """
1335
+ errors = [
1336
+ e["group_reward_error"] for e in events if e.get("group_reward_error")
1337
+ ]
1338
+ if errors:
1339
+ msg = str(errors[0])
1340
+ if verbose:
1341
+ print(_err(f" compute_group_reward FAILED server-side: {msg}"))
1342
+ return ExampleValidation(index=-1, ok=False, error=msg)
1343
+
1344
+ succeeded = [e for e in events if e.get("success")]
1345
+ if not succeeded:
1346
+ first = next(
1347
+ (e.get("error") for e in events if e.get("error")),
1348
+ "no successful group rollouts",
1349
+ )
1350
+ if verbose:
1351
+ print(
1352
+ _err(
1353
+ " group reward not validated — all group rollouts "
1354
+ f"failed: {first}"
1355
+ )
1356
+ )
1357
+ return ExampleValidation(
1358
+ index=-1, ok=False, error=f"all group rollouts failed: {first}"
1359
+ )
1360
+
1361
+ if verbose:
1362
+ print(
1363
+ _ok(
1364
+ " compute_group_reward OK server-side on a group of "
1365
+ f"{len(succeeded)}"
1366
+ )
1367
+ )
1368
+ return ExampleValidation(index=-1, ok=True)