benchmax 0.1.2.dev29__py3-none-any.whl → 0.1.2.dev31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
benchmax/bundle.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import importlib
3
4
  import inspect
4
5
  import io
5
6
  import json
@@ -76,6 +77,7 @@ def dump_bundle(
76
77
  pip_dependencies: list[str] | None = None,
77
78
  local_modules: list[ModuleType] | None = None,
78
79
  env_class_source: str | None = None,
80
+ auto_local_modules: bool = True,
79
81
  ) -> Bundle:
80
82
  """Pickle ``(env_class, constructor_args)`` and stamp metadata.
81
83
 
@@ -90,6 +92,10 @@ def dump_bundle(
90
92
  recover it — e.g. a class produced by ``exec()`` into an in-memory
91
93
  namespace, which has no source file on disk. When ``None``
92
94
  (default), source is introspected from ``env_class``.
95
+ auto_local_modules: When True (default), any local module the pickle
96
+ references but that wasn't passed in ``local_modules`` is imported
97
+ and pickled by value automatically (a warning names them). When
98
+ False, such a reference raises ``BundlingError`` instead.
93
99
 
94
100
  Raises:
95
101
  BundlingError: bad env_class, cloudpickle failure, or pickle references
@@ -124,6 +130,46 @@ def dump_bundle(
124
130
  except Exception:
125
131
  pass
126
132
 
133
+ if auto_local_modules and _unregistered_local_refs(pickled):
134
+ # Import each referenced local module and re-dump with it pickled by
135
+ # value. Loop because a by-value module can surface further local refs;
136
+ # registrations accumulate (and are torn down once at the end) so an
137
+ # earlier module stays by-value while we resolve the ones it pulled in.
138
+ seen: set[str] = {m.__name__ for m in local_modules}
139
+ registered: list[ModuleType] = []
140
+ with _BUNDLE_LOCK:
141
+ try:
142
+ for _ in range(10):
143
+ pending = [
144
+ m for m in _unregistered_local_refs(pickled) if m not in seen
145
+ ]
146
+ if not pending:
147
+ break
148
+ new_mods: list[ModuleType] = []
149
+ for name in pending:
150
+ seen.add(name) # unimportable names fall through to the guard
151
+ try:
152
+ new_mods.append(importlib.import_module(name))
153
+ except Exception:
154
+ pass
155
+ if not new_mods:
156
+ break
157
+ logger.warning(
158
+ "[bundle] %s: auto-bundling local module(s): %s ",
159
+ env_class.__name__,
160
+ ", ".join(sorted(m.__name__ for m in new_mods)),
161
+ )
162
+ for mod in new_mods:
163
+ cloudpickle.register_pickle_by_value(mod)
164
+ registered.append(mod)
165
+ pickled = cloudpickle.dumps((env_class, constructor_args))
166
+ finally:
167
+ for mod in registered:
168
+ try:
169
+ cloudpickle.unregister_pickle_by_value(mod)
170
+ except Exception:
171
+ pass
172
+
127
173
  risky = _unregistered_local_refs(pickled)
128
174
  if risky:
129
175
  msg = (
@@ -259,6 +305,15 @@ def _referenced_modules(pickled: bytes) -> set[str]:
259
305
  # Hooks find_class so we see every (module, name) the unpickler would import —
260
306
  # i.e. exactly what'd raise ModuleNotFoundError on a fresh interpreter. The stub
261
307
  # lets unpickling proceed past missing classes so we collect every ref.
308
+ #
309
+ # find_class alone has a blind spot: a bare ``import foo`` that leaves a
310
+ # module *object* in the env's globals is pickled as
311
+ # ``cloudpickle.subimport("foo")`` — the module name is a REDUCE argument,
312
+ # not a find_class path, so we'd only see ``cloudpickle.cloudpickle`` (which
313
+ # looks installed) and miss ``foo``. We shim subimport to record its arg and
314
+ # return a stub instead of importing, so a missing module is captured rather
315
+ # than aborting the whole load early. (``dynamic_subimport`` is by-value /
316
+ # self-contained — leave it to the real find_class so we don't flag it.)
262
317
  refs: set[str] = set()
263
318
 
264
319
  class _Stub:
@@ -271,9 +326,28 @@ def _referenced_modules(pickled: bytes) -> set[str]:
271
326
  def __reduce__(self) -> tuple:
272
327
  return (type(self), ())
273
328
 
329
+ def _recording_subimport(name: str, *a: Any, **kw: Any) -> ModuleType:
330
+ refs.add(name)
331
+ return ModuleType(str(name))
332
+
333
+ def _noop_setstate(obj: Any, *a: Any, **kw: Any) -> Any:
334
+ # cloudpickle's _make_skeleton_class resolves the class_tracker_id back
335
+ # to the *live* class (it was tracked when env_class was dumped), so the
336
+ # real ``_class_setstate``/``_function_setstate`` would setattr the
337
+ # reconstructed (stub-globals) members onto the live class/function —
338
+ # mutating the caller's class mid-bundle and poisoning any later dump.
339
+ # We only need the refs from ``state``, which are already recorded while
340
+ # it's unpickled; the setter itself is a no-op here.
341
+ return obj
342
+
274
343
  class _Recorder(pickle.Unpickler):
275
344
  def find_class(self, module: str, name: str) -> Any:
276
345
  refs.add(module)
346
+ if module.startswith("cloudpickle"):
347
+ if name == "subimport":
348
+ return _recording_subimport
349
+ if name in ("_class_setstate", "_function_setstate"):
350
+ return _noop_setstate
277
351
  try:
278
352
  return super().find_class(module, name)
279
353
  except Exception:
benchmax/cli.py CHANGED
@@ -11,19 +11,19 @@ from __future__ import annotations
11
11
  import argparse
12
12
  import sys
13
13
 
14
+ from benchmax import config
14
15
  from benchmax.platform import credentials
15
16
  from benchmax.platform.device_auth import DeviceAuthError
16
17
  from benchmax.platform.login import _login
17
18
 
18
19
 
19
- def _cmd_login(args: argparse.Namespace) -> int:
20
- env = "staging" if args.env == "staging" else None
20
+ def _cmd_login(_args: argparse.Namespace) -> int:
21
21
  try:
22
- _login(env)
22
+ _login()
23
23
  except DeviceAuthError as exc:
24
24
  print(f"Login failed: {exc}", file=sys.stderr)
25
25
  return 1
26
- print(f"\n✓ Logged in to {args.env}.")
26
+ print(f"\n✓ Logged in to {config.base_domain()}.")
27
27
  return 0
28
28
 
29
29
 
@@ -38,19 +38,18 @@ def _cmd_whoami(_args: argparse.Namespace) -> int:
38
38
  if not session:
39
39
  print("Not logged in. Run `castform login`.", file=sys.stderr)
40
40
  return 1
41
- env = session.get("env", "prod")
42
41
  jwt = credentials._session_jwt() # mints from the session; None if invalid/expired/offline
43
42
  if not jwt:
44
43
  print(
45
- f"Session present (env: {env}), but couldn't reach auth-service to "
46
- "verify it (offline, or the session expired). If this persists, run "
44
+ "Session present, but couldn't reach auth-service to verify it "
45
+ "(offline, or the session expired). If this persists, run "
47
46
  "`castform login` again.",
48
47
  file=sys.stderr,
49
48
  )
50
49
  return 1
51
50
  claims = credentials._jwt_claims(jwt)
52
51
  who = claims.get("email") or claims.get("sub", "<unknown>")
53
- print(f"Logged in as {who} (env: {env}).")
52
+ print(f"Logged in as {who} ({config.base_domain()}).")
54
53
  return 0
55
54
 
56
55
 
@@ -59,12 +58,6 @@ def main(argv: list[str] | None = None) -> int:
59
58
  sub = parser.add_subparsers(dest="command", required=True)
60
59
 
61
60
  p_login = sub.add_parser("login", help="Sign in via your browser")
62
- p_login.add_argument(
63
- "--env",
64
- choices=["prod", "staging"],
65
- default="prod",
66
- help="Environment to sign in to (staging is internal-only)",
67
- )
68
61
  p_login.set_defaults(func=_cmd_login)
69
62
 
70
63
  sub.add_parser("logout", help="Clear the cached session").set_defaults(func=_cmd_logout)
benchmax/config.py CHANGED
@@ -1,8 +1,11 @@
1
1
  """Centralized URL configuration for the Castform platform.
2
2
 
3
- All URLs derive from a single base domain. Set ``CASTFORM_BASE_DOMAIN`` to
4
- point at a different environment (e.g. ``staging.castform.com``); individual
5
- URL components may be overridden via their own env vars when needed.
3
+ All URLs derive from a single base domain, resolved from exactly two places: the
4
+ ``CASTFORM_BASE_DOMAIN`` env var, or the built-in ``castform.com`` default.
5
+ Individual URLs may be overridden via their own env vars
6
+ (``CASTFORM_PLATFORM_URL`` / ``CASTFORM_LLM_URL`` / ``CASTFORM_AUTH_URL`` /
7
+ ``CASTFORM_WEB_APP_URL``) — e.g. point platform at ``http://localhost:3000`` for
8
+ local dev while auth keeps talking to the real host.
6
9
 
7
10
  Usage::
8
11
 
@@ -16,38 +19,10 @@ DEFAULT_BASE_DOMAIN = "castform.com"
16
19
 
17
20
 
18
21
  def base_domain() -> str:
19
- """Resolve the platform base domain.
20
-
21
- Precedence: explicit ``CASTFORM_BASE_DOMAIN`` → the cached device-auth
22
- session's ``env`` (``staging`` → ``castform.dev``) ``prod`` default
23
- (``castform.com``). The ``env`` claim travels with the credential, so a
24
- logged-in SDK routes to the same environment it authenticated against —
25
- URL and credential can't desync. A prod session carries no ``env`` marker
26
- (``None`` → prod), so only internal staging logins deviate from the default.
27
- """
28
- override = os.environ.get("CASTFORM_BASE_DOMAIN")
29
- if override:
30
- return override
31
- if _session_env() == "staging":
32
- return "castform.dev"
33
- return DEFAULT_BASE_DOMAIN
34
-
35
-
36
- def _session_env() -> str | None:
37
- """The ``env`` from the cached device-auth session, if any.
38
-
39
- Lazy import: ``config`` is a leaf that ``benchmax.platform`` depends on, so a
40
- top-level import would cycle (platform/__init__ → client → config)."""
41
- try:
42
- from benchmax.platform.credentials import read_castform_session
43
-
44
- session = read_castform_session()
45
- except Exception:
46
- return None
47
- if not session:
48
- return None
49
- env = session.get("env")
50
- return env if isinstance(env, str) else None
22
+ """Resolve the platform base domain: ``CASTFORM_BASE_DOMAIN`` or the
23
+ ``castform.com`` default. To target another environment (e.g. internal
24
+ staging), export ``CASTFORM_BASE_DOMAIN=castform.dev``."""
25
+ return os.environ.get("CASTFORM_BASE_DOMAIN") or DEFAULT_BASE_DOMAIN
51
26
 
52
27
 
53
28
  def platform_url() -> str:
@@ -75,7 +50,6 @@ def auth_url() -> str:
75
50
  """Auth-service base URL (device-authorization + JWT mint endpoints).
76
51
 
77
52
  Used by ``castform login`` and the per-process session→JWT mint. Derives from
78
- the same base domain as everything else, so a session minted against ``staging``
79
- talks to ``auth.castform.dev`` and a ``prod`` session to ``auth.castform.com``.
53
+ the same base domain as everything else, or ``CASTFORM_AUTH_URL`` to override.
80
54
  """
81
55
  return os.environ.get("CASTFORM_AUTH_URL") or f"https://auth.{base_domain()}"
@@ -285,8 +285,14 @@ tags. Cite your sources inline using [Source: <source_id>] next to each claim.
285
285
  if not text.strip():
286
286
  return zeros
287
287
 
288
- t = task or {}
288
+ # No final <answer> block → no answer to score. Return all-zero
289
+ # rewards so conciseness / citations / efficiency can't accrue
290
+ # from reasoning or tool-call text alone.
289
291
  answer = extract_answer_block(text)
292
+ if not answer:
293
+ return zeros
294
+
295
+ t = task or {}
290
296
  prompt = str(t.get("question") or t.get("prompt") or "")
291
297
  gt_str = str(t.get("ground_truth") or "")
292
298
  reference_chunks = t.get("reference_chunks", [])
@@ -82,9 +82,16 @@ def extract_completion_text(completion: str | list[dict[str, Any]]) -> str:
82
82
 
83
83
 
84
84
  def extract_answer_block(text: str) -> str:
85
- """Extract content from <answer> tags, or return full text."""
85
+ """Extract content from ``<answer>`` tags.
86
+
87
+ Returns the (stripped) tag contents when an ``<answer>…</answer>`` block
88
+ is present, otherwise ``""``. A missing answer block is treated as "no
89
+ final answer" rather than silently falling back to the full completion —
90
+ consumers can gate rewards on a non-empty result. ``<answer></answer>``
91
+ likewise yields ``""``.
92
+ """
86
93
  match = _ANSWER_TAG_RE.search(text or "")
87
- return (match.group(1) if match else text).strip()
94
+ return match.group(1).strip() if match else ""
88
95
 
89
96
 
90
97
  def clip01(value: Any) -> float:
@@ -12,13 +12,15 @@ Run it from the benchmax project root (the ``telestich`` extra pulls in the
12
12
  env's word-list / rhyme dependencies):
13
13
 
14
14
  cd core/benchmax
15
- CASTFORM_API_KEY=sk_... \
16
- uv run --extra telestich python -m benchmax.envs.telestich.example
15
+ uv run --extra telestich python -m benchmax.envs.telestich.example
17
16
 
18
- (``CASTFORM_LLM_API_KEY`` is optional it defaults to ``CASTFORM_API_KEY``.)
17
+ Auth is the device-auth session (``ensure_session()`` opens a browser login if
18
+ ``~/.castform`` has no valid session) — no API key needed. ``CASTFORM_API_KEY``
19
+ / ``CASTFORM_LLM_API_KEY`` are only consulted by the offline dataset-generation
20
+ helpers, not the launch path.
19
21
 
20
- By default this is a 2-example smoke run. Set ``TELESTICH_FULL_RUN=1`` to launch
21
- a real run on the full seed dataset (~90/10 train/eval split).
22
+ This launches a real training run on the full committed seed dataset
23
+ (~90/10 train/eval split).
22
24
  """
23
25
 
24
26
  import asyncio
@@ -42,7 +44,7 @@ from benchmax.rubrics import rubric as rubric_mod
42
44
  #
43
45
  # Defaults route through ``benchmax.config``: the prod LLM endpoint is
44
46
  # ``https://llm.castform.com/v1`` and the platform control plane is
45
- # ``https://api.castform.com``. Point at staging or a different env by setting
47
+ # ``https://api.castform.com``. Point at a different environment by setting
46
48
  # ``CASTFORM_BASE_DOMAIN`` (or override URLs individually via
47
49
  # ``CASTFORM_PLATFORM_URL`` / ``CASTFORM_LLM_URL``).
48
50
  from benchmax import config
@@ -59,6 +61,12 @@ EXPERIMENT_PREFIX = "telestich"
59
61
  DATASET_PATH = str(Path(__file__).parent / "telestich_dataset.jsonl")
60
62
  NUM_EXAMPLES = 400
61
63
  CONCURRENCY = 15
64
+ # Trainer model — the launch `model` arg selects the trainer YAML (and thus the GPU
65
+ # pool) server-side. Supported: "Qwen/Qwen3.5-4B" (gpu4) or "Qwen/Qwen3.5-35B-A3B"
66
+ # (gpu8). Override via TELESTICH_MODEL.
67
+ MODEL = os.environ.get("TELESTICH_MODEL", "Qwen/Qwen3.5-4B")
68
+ # Run name — defaults to a unique telestich-full-<uuid>. Override via TELESTICH_RUN_NAME.
69
+ RUN_NAME = os.environ.get("TELESTICH_RUN_NAME", "")
62
70
 
63
71
  # (model, weight). Weights reflect observed reliability on our checks:
64
72
  # - Both grok models leak banned example words and rubber-stamp the CoT self-check.
@@ -552,55 +560,40 @@ def get_dataset():
552
560
  # alongside the pickle so a UI can show "what code is in this env" without
553
561
  # unpickling.
554
562
  if __name__ == "__main__":
555
- import tempfile
556
563
  import uuid
557
564
 
565
+ from benchmax.platform import ensure_session
558
566
  from benchmax.platform.client import TrainerClient
559
567
  from benchmax.platform.training_run import upload_training_run
560
568
  from benchmax.platform.validation import validate_env
561
569
 
562
- if not API_KEY:
563
- raise SystemExit("Set CASTFORM_API_KEY before running this example.")
570
+ # Device-auth session bootstrap: browser login if no credential resolves.
571
+ # After this the platform bearer comes from ~/.castform — no API key needed,
572
+ # so we pass api_key="" to the platform calls below (resolves via the seam).
573
+ ensure_session()
564
574
 
565
575
  print(f"Platform URL: {BASE_URL}")
566
576
  print(f"LLM URL: {LLM_BASE_URL}\n")
567
577
 
568
- # 1. Build the dataset.
569
- # Full run (TELESTICH_FULL_RUN=1): the committed seed dataset, topped up
570
- # to NUM_EXAMPLES via the platform LLM if short, split ~90/10 train/eval.
571
- # Default: a 2-example smoke that just exercises gen -> bundle -> upload
572
- # -> launch (and the key-less judge path), not a real training job.
573
- full_run = bool(os.environ.get("TELESTICH_FULL_RUN"))
574
- if full_run:
575
- examples = get_dataset()
576
- if len(examples) < 2:
577
- raise SystemExit(f"Need >=2 examples for a full run, got {len(examples)}.")
578
- # Hold out a representative eval set at random; keep TRAIN in curriculum
579
- # order (simpler first) so the difficulty ramp is preserved.
580
- n_eval = max(1, len(examples) // 10)
581
- eval_idx = set(random.sample(range(len(examples)), n_eval))
582
- eval_data = [e for i, e in enumerate(examples) if i in eval_idx]
583
- train_data = [e for i, e in enumerate(examples) if i not in eval_idx]
584
- print(f"Full run: {len(train_data)} train (curriculum order) / {len(eval_data)} eval.\n")
585
- else:
586
- with tempfile.TemporaryDirectory() as tmp:
587
- gen_path = Path(tmp) / "gen.jsonl"
588
- print(f"Generating 2 examples via {LLM_BASE_URL} ...")
589
- asyncio.run(generate_dataset(n=2, path=str(gen_path), concurrency=2))
590
- examples = load_dataset(str(gen_path))
591
- if len(examples) < 2:
592
- raise SystemExit(f"Needed 2 examples, only got {len(examples)}.")
593
- train_data, eval_data = examples[:1], examples[1:2]
594
- print(f"Smoke run: generated {len(examples)} examples — 1 train, 1 eval.\n")
578
+ # 1. Build the dataset from the committed seed file (curriculum order). Hold out a
579
+ # representative eval set at random; keep TRAIN in curriculum order (simpler first)
580
+ # so the difficulty ramp is preserved.
581
+ examples = get_dataset()
582
+ if len(examples) < 2:
583
+ raise SystemExit(f"Need >=2 examples, got {len(examples)}.")
584
+ n_eval = max(1, len(examples) // 10)
585
+ eval_idx = set(random.sample(range(len(examples)), n_eval))
586
+ eval_data = [e for i, e in enumerate(examples) if i in eval_idx]
587
+ train_data = [e for i, e in enumerate(examples) if i not in eval_idx]
588
+ print(f"{len(train_data)} train (curriculum order) / {len(eval_data)} eval.\n")
595
589
 
596
590
  # 2. Bundle the env class and upload everything to platform storage.
597
591
  # Bundle config, defined once so the pre-flight validation below exercises
598
592
  # the EXACT same env_args / by-value modules / deps as the launch.
599
593
  # - local_modules: ship env + rubric by value (the platform's installed
600
594
  # benchmax may not contain this version of these modules).
601
- # - judge_api_key="": satisfies the constructor without leaking a key; the
602
- # judge resolves its bearer at runtime via the platform act-as seam.
603
- constructor_args = {"judge_base_url": LLM_BASE_URL, "judge_api_key": ""}
595
+ # - judge bearer resolves at runtime via the device-auth / platform seam.
596
+ constructor_args = {"judge_base_url": LLM_BASE_URL}
604
597
  local_modules = [telestich_env_mod, rubric_mod]
605
598
  # All three are still required (is_valid_word → correctness; pronouncing →
606
599
  # rhyme). Removing word_bank did NOT free any of them.
@@ -617,23 +610,25 @@ if __name__ == "__main__":
617
610
  eval_dataset=eval_data[:2],
618
611
  local_modules=local_modules,
619
612
  pip_dependencies=pip_dependencies,
620
- api_key=API_KEY,
613
+ api_key="", # session bearer via ensure_session()
621
614
  base_url=BASE_URL,
622
615
  llm_base_url=LLM_BASE_URL,
623
616
  llm_api_key="",
624
617
  remote_examples=2,
625
618
  ):
626
- raise SystemExit("Env validation failed — aborting before launch (see output above).")
619
+ raise SystemExit(
620
+ "Env validation failed — aborting before launch (see output above)."
621
+ )
627
622
 
628
623
  # 3. Bundle the env class and upload everything to platform storage.
629
- run_name = f"telestich-{'full' if full_run else 'example'}-{uuid.uuid4().hex[:8]}"
624
+ run_name = RUN_NAME or f"telestich-full-{uuid.uuid4().hex[:8]}"
630
625
  print(f"\nUploading bundle + datasets as {run_name!r} ...")
631
626
  uploaded = upload_training_run(
632
627
  env_class=TelestichEnv,
633
628
  train_dataset=train_data,
634
629
  eval_dataset=eval_data,
635
630
  run_name=run_name,
636
- api_key=API_KEY,
631
+ api_key="", # session bearer via ensure_session()
637
632
  base_url=BASE_URL,
638
633
  local_modules=local_modules,
639
634
  constructor_args=constructor_args,
@@ -647,9 +642,10 @@ if __name__ == "__main__":
647
642
  ):
648
643
  print(f" {label:<14}: {path}")
649
644
 
650
- # 4. Launch the training run. ``simple`` is the deployed 4B/gpu4 template.
651
- print("\nLaunching training run ...")
652
- with TrainerClient(api_key=API_KEY, base_url=BASE_URL) as trainer:
645
+ # 4. Launch the training run. training_run_type="simple" + the `model` arg select
646
+ # the trainer YAML/pool server-side (Qwen3.5-4B→gpu4, Qwen3.5-35B-A3B→gpu8).
647
+ print(f"\nLaunching training run (model={MODEL}) ...")
648
+ with TrainerClient(api_key="", base_url=BASE_URL) as trainer:
653
649
  run_id = trainer.launch_training_run(
654
650
  training_run_type="simple",
655
651
  env_cls_path=uploaded.env_cls_path,
@@ -658,10 +654,10 @@ if __name__ == "__main__":
658
654
  eval_dataset_path=uploaded.eval_dataset_path,
659
655
  name=run_name,
660
656
  # num_epochs: passes over the train set (platform default is 5).
661
- # max_response_len 3000: a brief reason + 1-2 tool rounds + poem fits well
657
+ # max_rollout_len 3000: a brief reason + 1-2 tool rounds + poem fits well
662
658
  # under this; lowered from 4000 to cut off in-head enumeration rambles
663
659
  # sooner (they truncate to a 0-reward anyway).
664
- launcher_args={"max_response_len": 3000, "num_epochs": 10},
660
+ launcher_args={"model": MODEL, "max_rollout_len": 3000, "num_epochs": 10},
665
661
  )
666
662
 
667
663
  print(f"\n✓ Launched run_id={run_id}")