benchmax 0.1.2.dev28__py3-none-any.whl → 0.1.2.dev30__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (30)
  1. benchmax/cli.py +71 -0
  2. benchmax/config.py +19 -4
  3. benchmax/envs/postgres_search/search_env.py +7 -1
  4. benchmax/envs/reward_helpers.py +9 -2
  5. benchmax/envs/telestich/example.py +28 -39
  6. benchmax/envs/telestich/telestich_env.py +627 -414
  7. benchmax/platform/__init__.py +10 -0
  8. benchmax/platform/client.py +303 -16
  9. benchmax/platform/credentials.py +223 -4
  10. benchmax/platform/device_auth.py +81 -0
  11. benchmax/platform/login.py +92 -0
  12. benchmax/platform/training_run.py +5 -3
  13. benchmax/platform/validation.py +151 -7
  14. benchmax/rag/corpus/chroma/client.py +97 -0
  15. benchmax/rag/corpus/chroma/source.py +35 -5
  16. benchmax/rag/corpus/postgres/client.py +9 -1
  17. benchmax/rag/corpus/postgres/source.py +21 -11
  18. benchmax/rag/qa_generation/filters/env_rollout.py +9 -1
  19. benchmax/rag/qa_generation/filters/grounding_llm.py +9 -1
  20. benchmax/rag/qa_generation/filters/hop_count_validity.py +7 -6
  21. benchmax/rag/qa_generation/filters/retrieval_llm.py +8 -1
  22. benchmax/rag/qa_generation/pipeline.py +10 -4
  23. benchmax/rag/qa_generation/pipeline_config.py +7 -3
  24. benchmax/rubrics/rubric.py +101 -26
  25. {benchmax-0.1.2.dev28.dist-info → benchmax-0.1.2.dev30.dist-info}/METADATA +1 -1
  26. {benchmax-0.1.2.dev28.dist-info → benchmax-0.1.2.dev30.dist-info}/RECORD +30 -26
  27. benchmax-0.1.2.dev30.dist-info/entry_points.txt +2 -0
  28. {benchmax-0.1.2.dev28.dist-info → benchmax-0.1.2.dev30.dist-info}/WHEEL +0 -0
  29. {benchmax-0.1.2.dev28.dist-info → benchmax-0.1.2.dev30.dist-info}/licenses/LICENSE +0 -0
  30. {benchmax-0.1.2.dev28.dist-info → benchmax-0.1.2.dev30.dist-info}/top_level.txt +0 -0
benchmax/cli.py ADDED
@@ -0,0 +1,71 @@
1
+ """``castform`` CLI — browser-based login for the SDK.
2
+
3
+ Commands: ``login`` (device authorization), ``logout``, ``whoami``. The login
4
+ flow + the reusable ``ensure_session`` live in :mod:`benchmax.platform.login`;
5
+ this module is the thin argparse wrapper. After ``castform login`` the SDK
6
+ resolves its bearer from ``~/.castform`` automatically — no API key or URL.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ import sys
13
+
14
+ from benchmax import config
15
+ from benchmax.platform import credentials
16
+ from benchmax.platform.device_auth import DeviceAuthError
17
+ from benchmax.platform.login import _login
18
+
19
+
20
+ def _cmd_login(_args: argparse.Namespace) -> int:
21
+ try:
22
+ _login()
23
+ except DeviceAuthError as exc:
24
+ print(f"Login failed: {exc}", file=sys.stderr)
25
+ return 1
26
+ print(f"\n✓ Logged in to {config.base_domain()}.")
27
+ return 0
28
+
29
+
30
+ def _cmd_logout(_args: argparse.Namespace) -> int:
31
+ credentials.clear_castform_session()
32
+ print("✓ Logged out.")
33
+ return 0
34
+
35
+
36
+ def _cmd_whoami(_args: argparse.Namespace) -> int:
37
+ session = credentials.read_castform_session()
38
+ if not session:
39
+ print("Not logged in. Run `castform login`.", file=sys.stderr)
40
+ return 1
41
+ jwt = credentials._session_jwt() # mints from the session; None if invalid/expired/offline
42
+ if not jwt:
43
+ print(
44
+ "Session present, but couldn't reach auth-service to verify it "
45
+ "(offline, or the session expired). If this persists, run "
46
+ "`castform login` again.",
47
+ file=sys.stderr,
48
+ )
49
+ return 1
50
+ claims = credentials._jwt_claims(jwt)
51
+ who = claims.get("email") or claims.get("sub", "<unknown>")
52
+ print(f"Logged in as {who} ({config.base_domain()}).")
53
+ return 0
54
+
55
+
56
+ def main(argv: list[str] | None = None) -> int:
57
+ parser = argparse.ArgumentParser(prog="castform", description="Castform CLI")
58
+ sub = parser.add_subparsers(dest="command", required=True)
59
+
60
+ p_login = sub.add_parser("login", help="Sign in via your browser")
61
+ p_login.set_defaults(func=_cmd_login)
62
+
63
+ sub.add_parser("logout", help="Clear the cached session").set_defaults(func=_cmd_logout)
64
+ sub.add_parser("whoami", help="Show the current login").set_defaults(func=_cmd_whoami)
65
+
66
+ args = parser.parse_args(argv)
67
+ return args.func(args)
68
+
69
+
70
+ if __name__ == "__main__":
71
+ sys.exit(main())
benchmax/config.py CHANGED
@@ -1,8 +1,11 @@
1
1
  """Centralized URL configuration for the Castform platform.
2
2
 
3
- All URLs derive from a single base domain. Set ``CASTFORM_BASE_DOMAIN`` to
4
- point at a different environment (e.g. ``staging.castform.com``); individual
5
- URL components may be overridden via their own env vars when needed.
3
+ All URLs derive from a single base domain, resolved from exactly two places: the
4
+ ``CASTFORM_BASE_DOMAIN`` env var, or the built-in ``castform.com`` default.
5
+ Individual URLs may be overridden via their own env vars
6
+ (``CASTFORM_PLATFORM_URL`` / ``CASTFORM_LLM_URL`` / ``CASTFORM_AUTH_URL`` /
7
+ ``CASTFORM_WEB_APP_URL``) — e.g. point platform at ``http://localhost:3000`` for
8
+ local dev while auth keeps talking to the real host.
6
9
 
7
10
  Usage::
8
11
 
@@ -16,7 +19,10 @@ DEFAULT_BASE_DOMAIN = "castform.com"
16
19
 
17
20
 
18
21
  def base_domain() -> str:
19
- return os.environ.get("CASTFORM_BASE_DOMAIN", DEFAULT_BASE_DOMAIN)
22
+ """Resolve the platform base domain: ``CASTFORM_BASE_DOMAIN`` or the
23
+ ``castform.com`` default. To target another environment (e.g. internal
24
+ staging), export ``CASTFORM_BASE_DOMAIN=castform.dev``."""
25
+ return os.environ.get("CASTFORM_BASE_DOMAIN") or DEFAULT_BASE_DOMAIN
20
26
 
21
27
 
22
28
  def platform_url() -> str:
@@ -38,3 +44,12 @@ def web_app_url() -> str:
38
44
  def llm_url() -> str:
39
45
  """OpenAI-compatible LLM endpoint hosted by the platform."""
40
46
  return os.environ.get("CASTFORM_LLM_URL") or f"https://llm.{base_domain()}/v1"
47
+
48
+
49
+ def auth_url() -> str:
50
+ """Auth-service base URL (device-authorization + JWT mint endpoints).
51
+
52
+ Used by ``castform login`` and the per-process session→JWT mint. Derives from
53
+ the same base domain as everything else, or ``CASTFORM_AUTH_URL`` to override.
54
+ """
55
+ return os.environ.get("CASTFORM_AUTH_URL") or f"https://auth.{base_domain()}"
benchmax/envs/postgres_search/search_env.py CHANGED
@@ -285,8 +285,14 @@ tags. Cite your sources inline using [Source: <source_id>] next to each claim.
285
285
  if not text.strip():
286
286
  return zeros
287
287
 
288
- t = task or {}
288
+ # No final <answer> block → no answer to score. Return all-zero
289
+ # rewards so conciseness / citations / efficiency can't accrue
290
+ # from reasoning or tool-call text alone.
289
291
  answer = extract_answer_block(text)
292
+ if not answer:
293
+ return zeros
294
+
295
+ t = task or {}
290
296
  prompt = str(t.get("question") or t.get("prompt") or "")
291
297
  gt_str = str(t.get("ground_truth") or "")
292
298
  reference_chunks = t.get("reference_chunks", [])
benchmax/envs/reward_helpers.py CHANGED
@@ -82,9 +82,16 @@ def extract_completion_text(completion: str | list[dict[str, Any]]) -> str:
82
82
 
83
83
 
84
84
  def extract_answer_block(text: str) -> str:
85
- """Extract content from <answer> tags, or return full text."""
85
+ """Extract content from ``<answer>`` tags.
86
+
87
+ Returns the (stripped) tag contents when an ``<answer>…</answer>`` block
88
+ is present, otherwise ``""``. A missing answer block is treated as "no
89
+ final answer" rather than silently falling back to the full completion —
90
+ consumers can gate rewards on a non-empty result. ``<answer></answer>``
91
+ likewise yields ``""``.
92
+ """
86
93
  match = _ANSWER_TAG_RE.search(text or "")
87
- return (match.group(1) if match else text).strip()
94
+ return match.group(1).strip() if match else ""
88
95
 
89
96
 
90
97
  def clip01(value: Any) -> float:
benchmax/envs/telestich/example.py CHANGED
@@ -17,8 +17,8 @@ env's word-list / rhyme dependencies):
17
17
 
18
18
  (``CASTFORM_LLM_API_KEY`` is optional — it defaults to ``CASTFORM_API_KEY``.)
19
19
 
20
- By default this is a 2-example smoke run. Set ``TELESTICH_FULL_RUN=1`` to launch
21
- a real run on the full seed dataset (~90/10 train/eval split).
20
+ This launches a real training run on the full committed seed dataset
21
+ (~90/10 train/eval split).
22
22
  """
23
23
 
24
24
  import asyncio
@@ -42,7 +42,7 @@ from benchmax.rubrics import rubric as rubric_mod
42
42
  #
43
43
  # Defaults route through ``benchmax.config``: the prod LLM endpoint is
44
44
  # ``https://llm.castform.com/v1`` and the platform control plane is
45
- # ``https://api.castform.com``. Point at staging or a different env by setting
45
+ # ``https://api.castform.com``. Point at a different environment by setting
46
46
  # ``CASTFORM_BASE_DOMAIN`` (or override URLs individually via
47
47
  # ``CASTFORM_PLATFORM_URL`` / ``CASTFORM_LLM_URL``).
48
48
  from benchmax import config
@@ -59,6 +59,10 @@ EXPERIMENT_PREFIX = "telestich"
59
59
  DATASET_PATH = str(Path(__file__).parent / "telestich_dataset.jsonl")
60
60
  NUM_EXAMPLES = 400
61
61
  CONCURRENCY = 15
62
+ # Trainer model — the launch `model` arg selects the trainer YAML (and thus the GPU
63
+ # pool) server-side. Supported: "Qwen/Qwen3.5-4B" (gpu4) or "Qwen/Qwen3.5-35B-A3B"
64
+ # (gpu8). Override via TELESTICH_MODEL.
65
+ MODEL = os.environ.get("TELESTICH_MODEL", "Qwen/Qwen3.5-4B")
62
66
 
63
67
  # (model, weight). Weights reflect observed reliability on our checks:
64
68
  # - Both grok models leak banned example words and rubber-stamp the CoT self-check.
@@ -552,7 +556,6 @@ def get_dataset():
552
556
  # alongside the pickle so a UI can show "what code is in this env" without
553
557
  # unpickling.
554
558
  if __name__ == "__main__":
555
- import tempfile
556
559
  import uuid
557
560
 
558
561
  from benchmax.platform.client import TrainerClient
@@ -565,42 +568,25 @@ if __name__ == "__main__":
565
568
  print(f"Platform URL: {BASE_URL}")
566
569
  print(f"LLM URL: {LLM_BASE_URL}\n")
567
570
 
568
- # 1. Build the dataset.
569
- # Full run (TELESTICH_FULL_RUN=1): the committed seed dataset, topped up
570
- # to NUM_EXAMPLES via the platform LLM if short, split ~90/10 train/eval.
571
- # Default: a 2-example smoke that just exercises gen -> bundle -> upload
572
- # -> launch (and the key-less judge path), not a real training job.
573
- full_run = bool(os.environ.get("TELESTICH_FULL_RUN"))
574
- if full_run:
575
- examples = get_dataset()
576
- if len(examples) < 2:
577
- raise SystemExit(f"Need >=2 examples for a full run, got {len(examples)}.")
578
- # Hold out a representative eval set at random; keep TRAIN in curriculum
579
- # order (simpler first) so the difficulty ramp is preserved.
580
- n_eval = max(1, len(examples) // 10)
581
- eval_idx = set(random.sample(range(len(examples)), n_eval))
582
- eval_data = [e for i, e in enumerate(examples) if i in eval_idx]
583
- train_data = [e for i, e in enumerate(examples) if i not in eval_idx]
584
- print(f"Full run: {len(train_data)} train (curriculum order) / {len(eval_data)} eval.\n")
585
- else:
586
- with tempfile.TemporaryDirectory() as tmp:
587
- gen_path = Path(tmp) / "gen.jsonl"
588
- print(f"Generating 2 examples via {LLM_BASE_URL} ...")
589
- asyncio.run(generate_dataset(n=2, path=str(gen_path), concurrency=2))
590
- examples = load_dataset(str(gen_path))
591
- if len(examples) < 2:
592
- raise SystemExit(f"Needed 2 examples, only got {len(examples)}.")
593
- train_data, eval_data = examples[:1], examples[1:2]
594
- print(f"Smoke run: generated {len(examples)} examples — 1 train, 1 eval.\n")
571
+ # 1. Build the dataset from the committed seed file (curriculum order). Hold out a
572
+ # representative eval set at random; keep TRAIN in curriculum order (simpler first)
573
+ # so the difficulty ramp is preserved.
574
+ examples = get_dataset()
575
+ if len(examples) < 2:
576
+ raise SystemExit(f"Need >=2 examples, got {len(examples)}.")
577
+ n_eval = max(1, len(examples) // 10)
578
+ eval_idx = set(random.sample(range(len(examples)), n_eval))
579
+ eval_data = [e for i, e in enumerate(examples) if i in eval_idx]
580
+ train_data = [e for i, e in enumerate(examples) if i not in eval_idx]
581
+ print(f"{len(train_data)} train (curriculum order) / {len(eval_data)} eval.\n")
595
582
 
596
583
  # 2. Bundle the env class and upload everything to platform storage.
597
584
  # Bundle config, defined once so the pre-flight validation below exercises
598
585
  # the EXACT same env_args / by-value modules / deps as the launch.
599
586
  # - local_modules: ship env + rubric by value (the platform's installed
600
587
  # benchmax may not contain this version of these modules).
601
- # - judge_api_key="": satisfies the constructor without leaking a key; the
602
- # judge resolves its bearer at runtime via the platform act-as seam.
603
- constructor_args = {"judge_base_url": LLM_BASE_URL, "judge_api_key": ""}
588
+ # - judge bearer resolves at runtime via the device-auth / platform seam.
589
+ constructor_args = {"judge_base_url": LLM_BASE_URL}
604
590
  local_modules = [telestich_env_mod, rubric_mod]
605
591
  # All three are still required (is_valid_word → correctness; pronouncing →
606
592
  # rhyme). Removing word_bank did NOT free any of them.
@@ -623,10 +609,12 @@ if __name__ == "__main__":
623
609
  llm_api_key="",
624
610
  remote_examples=2,
625
611
  ):
626
- raise SystemExit("Env validation failed — aborting before launch (see output above).")
612
+ raise SystemExit(
613
+ "Env validation failed — aborting before launch (see output above)."
614
+ )
627
615
 
628
616
  # 3. Bundle the env class and upload everything to platform storage.
629
- run_name = f"telestich-{'full' if full_run else 'example'}-{uuid.uuid4().hex[:8]}"
617
+ run_name = f"telestich-full-{uuid.uuid4().hex[:8]}"
630
618
  print(f"\nUploading bundle + datasets as {run_name!r} ...")
631
619
  uploaded = upload_training_run(
632
620
  env_class=TelestichEnv,
@@ -647,8 +635,9 @@ if __name__ == "__main__":
647
635
  ):
648
636
  print(f" {label:<14}: {path}")
649
637
 
650
- # 4. Launch the training run. ``simple`` is the deployed 4B/gpu4 template.
651
- print("\nLaunching training run ...")
638
+ # 4. Launch the training run. training_run_type="simple" + the `model` arg select
639
+ # the trainer YAML/pool server-side (Qwen3.5-4B→gpu4, Qwen3.5-35B-A3B→gpu8).
640
+ print(f"\nLaunching training run (model={MODEL}) ...")
652
641
  with TrainerClient(api_key=API_KEY, base_url=BASE_URL) as trainer:
653
642
  run_id = trainer.launch_training_run(
654
643
  training_run_type="simple",
@@ -661,7 +650,7 @@ if __name__ == "__main__":
661
650
  # max_response_len 3000: a brief reason + 1-2 tool rounds + poem fits well
662
651
  # under this; lowered from 4000 to cut off in-head enumeration rambles
663
652
  # sooner (they truncate to a 0-reward anyway).
664
- launcher_args={"max_response_len": 3000, "num_epochs": 10},
653
+ launcher_args={"model": MODEL, "max_response_len": 3000, "num_epochs": 10},
665
654
  )
666
655
 
667
656
  print(f"\n✓ Launched run_id={run_id}")