benchmax 0.1.2.dev29__py3-none-any.whl → 0.1.2.dev30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmax/cli.py +7 -14
- benchmax/config.py +11 -37
- benchmax/envs/postgres_search/search_env.py +7 -1
- benchmax/envs/reward_helpers.py +9 -2
- benchmax/envs/telestich/example.py +28 -39
- benchmax/envs/telestich/telestich_env.py +627 -414
- benchmax/platform/credentials.py +1 -2
- benchmax/platform/login.py +28 -17
- benchmax/rag/corpus/chroma/client.py +97 -0
- benchmax/rag/corpus/chroma/source.py +35 -5
- benchmax/rubrics/rubric.py +101 -26
- {benchmax-0.1.2.dev29.dist-info → benchmax-0.1.2.dev30.dist-info}/METADATA +1 -1
- {benchmax-0.1.2.dev29.dist-info → benchmax-0.1.2.dev30.dist-info}/RECORD +17 -17
- {benchmax-0.1.2.dev29.dist-info → benchmax-0.1.2.dev30.dist-info}/WHEEL +0 -0
- {benchmax-0.1.2.dev29.dist-info → benchmax-0.1.2.dev30.dist-info}/entry_points.txt +0 -0
- {benchmax-0.1.2.dev29.dist-info → benchmax-0.1.2.dev30.dist-info}/licenses/LICENSE +0 -0
- {benchmax-0.1.2.dev29.dist-info → benchmax-0.1.2.dev30.dist-info}/top_level.txt +0 -0
benchmax/cli.py
CHANGED
|
@@ -11,19 +11,19 @@ from __future__ import annotations
|
|
|
11
11
|
import argparse
|
|
12
12
|
import sys
|
|
13
13
|
|
|
14
|
+
from benchmax import config
|
|
14
15
|
from benchmax.platform import credentials
|
|
15
16
|
from benchmax.platform.device_auth import DeviceAuthError
|
|
16
17
|
from benchmax.platform.login import _login
|
|
17
18
|
|
|
18
19
|
|
|
19
|
-
def _cmd_login(
|
|
20
|
-
env = "staging" if args.env == "staging" else None
|
|
20
|
+
def _cmd_login(_args: argparse.Namespace) -> int:
|
|
21
21
|
try:
|
|
22
|
-
_login(
|
|
22
|
+
_login()
|
|
23
23
|
except DeviceAuthError as exc:
|
|
24
24
|
print(f"Login failed: {exc}", file=sys.stderr)
|
|
25
25
|
return 1
|
|
26
|
-
print(f"\n✓ Logged in to {
|
|
26
|
+
print(f"\n✓ Logged in to {config.base_domain()}.")
|
|
27
27
|
return 0
|
|
28
28
|
|
|
29
29
|
|
|
@@ -38,19 +38,18 @@ def _cmd_whoami(_args: argparse.Namespace) -> int:
|
|
|
38
38
|
if not session:
|
|
39
39
|
print("Not logged in. Run `castform login`.", file=sys.stderr)
|
|
40
40
|
return 1
|
|
41
|
-
env = session.get("env", "prod")
|
|
42
41
|
jwt = credentials._session_jwt() # mints from the session; None if invalid/expired/offline
|
|
43
42
|
if not jwt:
|
|
44
43
|
print(
|
|
45
|
-
|
|
46
|
-
"
|
|
44
|
+
"Session present, but couldn't reach auth-service to verify it "
|
|
45
|
+
"(offline, or the session expired). If this persists, run "
|
|
47
46
|
"`castform login` again.",
|
|
48
47
|
file=sys.stderr,
|
|
49
48
|
)
|
|
50
49
|
return 1
|
|
51
50
|
claims = credentials._jwt_claims(jwt)
|
|
52
51
|
who = claims.get("email") or claims.get("sub", "<unknown>")
|
|
53
|
-
print(f"Logged in as {who} (
|
|
52
|
+
print(f"Logged in as {who} ({config.base_domain()}).")
|
|
54
53
|
return 0
|
|
55
54
|
|
|
56
55
|
|
|
@@ -59,12 +58,6 @@ def main(argv: list[str] | None = None) -> int:
|
|
|
59
58
|
sub = parser.add_subparsers(dest="command", required=True)
|
|
60
59
|
|
|
61
60
|
p_login = sub.add_parser("login", help="Sign in via your browser")
|
|
62
|
-
p_login.add_argument(
|
|
63
|
-
"--env",
|
|
64
|
-
choices=["prod", "staging"],
|
|
65
|
-
default="prod",
|
|
66
|
-
help="Environment to sign in to (staging is internal-only)",
|
|
67
|
-
)
|
|
68
61
|
p_login.set_defaults(func=_cmd_login)
|
|
69
62
|
|
|
70
63
|
sub.add_parser("logout", help="Clear the cached session").set_defaults(func=_cmd_logout)
|
benchmax/config.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
|
1
1
|
"""Centralized URL configuration for the Castform platform.
|
|
2
2
|
|
|
3
|
-
All URLs derive from a single base domain
|
|
4
|
-
|
|
5
|
-
|
|
3
|
+
All URLs derive from a single base domain, resolved from exactly two places: the
|
|
4
|
+
``CASTFORM_BASE_DOMAIN`` env var, or the built-in ``castform.com`` default.
|
|
5
|
+
Individual URLs may be overridden via their own env vars
|
|
6
|
+
(``CASTFORM_PLATFORM_URL`` / ``CASTFORM_LLM_URL`` / ``CASTFORM_AUTH_URL`` /
|
|
7
|
+
``CASTFORM_WEB_APP_URL``) — e.g. point platform at ``http://localhost:3000`` for
|
|
8
|
+
local dev while auth keeps talking to the real host.
|
|
6
9
|
|
|
7
10
|
Usage::
|
|
8
11
|
|
|
@@ -16,38 +19,10 @@ DEFAULT_BASE_DOMAIN = "castform.com"
|
|
|
16
19
|
|
|
17
20
|
|
|
18
21
|
def base_domain() -> str:
|
|
19
|
-
"""Resolve the platform base domain
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
(``castform.com``). The ``env`` claim travels with the credential, so a
|
|
24
|
-
logged-in SDK routes to the same environment it authenticated against —
|
|
25
|
-
URL and credential can't desync. A prod session carries no ``env`` marker
|
|
26
|
-
(``None`` → prod), so only internal staging logins deviate from the default.
|
|
27
|
-
"""
|
|
28
|
-
override = os.environ.get("CASTFORM_BASE_DOMAIN")
|
|
29
|
-
if override:
|
|
30
|
-
return override
|
|
31
|
-
if _session_env() == "staging":
|
|
32
|
-
return "castform.dev"
|
|
33
|
-
return DEFAULT_BASE_DOMAIN
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def _session_env() -> str | None:
|
|
37
|
-
"""The ``env`` from the cached device-auth session, if any.
|
|
38
|
-
|
|
39
|
-
Lazy import: ``config`` is a leaf that ``benchmax.platform`` depends on, so a
|
|
40
|
-
top-level import would cycle (platform/__init__ → client → config)."""
|
|
41
|
-
try:
|
|
42
|
-
from benchmax.platform.credentials import read_castform_session
|
|
43
|
-
|
|
44
|
-
session = read_castform_session()
|
|
45
|
-
except Exception:
|
|
46
|
-
return None
|
|
47
|
-
if not session:
|
|
48
|
-
return None
|
|
49
|
-
env = session.get("env")
|
|
50
|
-
return env if isinstance(env, str) else None
|
|
22
|
+
"""Resolve the platform base domain: ``CASTFORM_BASE_DOMAIN`` or the
|
|
23
|
+
``castform.com`` default. To target another environment (e.g. internal
|
|
24
|
+
staging), export ``CASTFORM_BASE_DOMAIN=castform.dev``."""
|
|
25
|
+
return os.environ.get("CASTFORM_BASE_DOMAIN") or DEFAULT_BASE_DOMAIN
|
|
51
26
|
|
|
52
27
|
|
|
53
28
|
def platform_url() -> str:
|
|
@@ -75,7 +50,6 @@ def auth_url() -> str:
|
|
|
75
50
|
"""Auth-service base URL (device-authorization + JWT mint endpoints).
|
|
76
51
|
|
|
77
52
|
Used by ``castform login`` and the per-process session→JWT mint. Derives from
|
|
78
|
-
the same base domain as everything else,
|
|
79
|
-
talks to ``auth.castform.dev`` and a ``prod`` session to ``auth.castform.com``.
|
|
53
|
+
the same base domain as everything else, or ``CASTFORM_AUTH_URL`` to override.
|
|
80
54
|
"""
|
|
81
55
|
return os.environ.get("CASTFORM_AUTH_URL") or f"https://auth.{base_domain()}"
|
benchmax/envs/postgres_search/search_env.py
CHANGED
|
@@ -285,8 +285,14 @@ tags. Cite your sources inline using [Source: <source_id>] next to each claim.
|
|
|
285
285
|
if not text.strip():
|
|
286
286
|
return zeros
|
|
287
287
|
|
|
288
|
-
|
|
288
|
+
# No final <answer> block → no answer to score. Return all-zero
|
|
289
|
+
# rewards so conciseness / citations / efficiency can't accrue
|
|
290
|
+
# from reasoning or tool-call text alone.
|
|
289
291
|
answer = extract_answer_block(text)
|
|
292
|
+
if not answer:
|
|
293
|
+
return zeros
|
|
294
|
+
|
|
295
|
+
t = task or {}
|
|
290
296
|
prompt = str(t.get("question") or t.get("prompt") or "")
|
|
291
297
|
gt_str = str(t.get("ground_truth") or "")
|
|
292
298
|
reference_chunks = t.get("reference_chunks", [])
|
benchmax/envs/reward_helpers.py
CHANGED
|
@@ -82,9 +82,16 @@ def extract_completion_text(completion: str | list[dict[str, Any]]) -> str:
|
|
|
82
82
|
|
|
83
83
|
|
|
84
84
|
def extract_answer_block(text: str) -> str:
|
|
85
|
-
"""Extract content from
|
|
85
|
+
"""Extract content from ``<answer>`` tags.
|
|
86
|
+
|
|
87
|
+
Returns the (stripped) tag contents when an ``<answer>…</answer>`` block
|
|
88
|
+
is present, otherwise ``""``. A missing answer block is treated as "no
|
|
89
|
+
final answer" rather than silently falling back to the full completion —
|
|
90
|
+
consumers can gate rewards on a non-empty result. ``<answer></answer>``
|
|
91
|
+
likewise yields ``""``.
|
|
92
|
+
"""
|
|
86
93
|
match = _ANSWER_TAG_RE.search(text or "")
|
|
87
|
-
return
|
|
94
|
+
return match.group(1).strip() if match else ""
|
|
88
95
|
|
|
89
96
|
|
|
90
97
|
def clip01(value: Any) -> float:
|
benchmax/envs/telestich/example.py
CHANGED
|
@@ -17,8 +17,8 @@ env's word-list / rhyme dependencies):
|
|
|
17
17
|
|
|
18
18
|
(``CASTFORM_LLM_API_KEY`` is optional — it defaults to ``CASTFORM_API_KEY``.)
|
|
19
19
|
|
|
20
|
-
|
|
21
|
-
|
|
20
|
+
This launches a real training run on the full committed seed dataset
|
|
21
|
+
(~90/10 train/eval split).
|
|
22
22
|
"""
|
|
23
23
|
|
|
24
24
|
import asyncio
|
|
@@ -42,7 +42,7 @@ from benchmax.rubrics import rubric as rubric_mod
|
|
|
42
42
|
#
|
|
43
43
|
# Defaults route through ``benchmax.config``: the prod LLM endpoint is
|
|
44
44
|
# ``https://llm.castform.com/v1`` and the platform control plane is
|
|
45
|
-
# ``https://api.castform.com``. Point at
|
|
45
|
+
# ``https://api.castform.com``. Point at a different environment by setting
|
|
46
46
|
# ``CASTFORM_BASE_DOMAIN`` (or override URLs individually via
|
|
47
47
|
# ``CASTFORM_PLATFORM_URL`` / ``CASTFORM_LLM_URL``).
|
|
48
48
|
from benchmax import config
|
|
@@ -59,6 +59,10 @@ EXPERIMENT_PREFIX = "telestich"
|
|
|
59
59
|
DATASET_PATH = str(Path(__file__).parent / "telestich_dataset.jsonl")
|
|
60
60
|
NUM_EXAMPLES = 400
|
|
61
61
|
CONCURRENCY = 15
|
|
62
|
+
# Trainer model — the launch `model` arg selects the trainer YAML (and thus the GPU
|
|
63
|
+
# pool) server-side. Supported: "Qwen/Qwen3.5-4B" (gpu4) or "Qwen/Qwen3.5-35B-A3B"
|
|
64
|
+
# (gpu8). Override via TELESTICH_MODEL.
|
|
65
|
+
MODEL = os.environ.get("TELESTICH_MODEL", "Qwen/Qwen3.5-4B")
|
|
62
66
|
|
|
63
67
|
# (model, weight). Weights reflect observed reliability on our checks:
|
|
64
68
|
# - Both grok models leak banned example words and rubber-stamp the CoT self-check.
|
|
@@ -552,7 +556,6 @@ def get_dataset():
|
|
|
552
556
|
# alongside the pickle so a UI can show "what code is in this env" without
|
|
553
557
|
# unpickling.
|
|
554
558
|
if __name__ == "__main__":
|
|
555
|
-
import tempfile
|
|
556
559
|
import uuid
|
|
557
560
|
|
|
558
561
|
from benchmax.platform.client import TrainerClient
|
|
@@ -565,42 +568,25 @@ if __name__ == "__main__":
|
|
|
565
568
|
print(f"Platform URL: {BASE_URL}")
|
|
566
569
|
print(f"LLM URL: {LLM_BASE_URL}\n")
|
|
567
570
|
|
|
568
|
-
# 1. Build the dataset.
|
|
569
|
-
#
|
|
570
|
-
#
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
# order (simpler first) so the difficulty ramp is preserved.
|
|
580
|
-
n_eval = max(1, len(examples) // 10)
|
|
581
|
-
eval_idx = set(random.sample(range(len(examples)), n_eval))
|
|
582
|
-
eval_data = [e for i, e in enumerate(examples) if i in eval_idx]
|
|
583
|
-
train_data = [e for i, e in enumerate(examples) if i not in eval_idx]
|
|
584
|
-
print(f"Full run: {len(train_data)} train (curriculum order) / {len(eval_data)} eval.\n")
|
|
585
|
-
else:
|
|
586
|
-
with tempfile.TemporaryDirectory() as tmp:
|
|
587
|
-
gen_path = Path(tmp) / "gen.jsonl"
|
|
588
|
-
print(f"Generating 2 examples via {LLM_BASE_URL} ...")
|
|
589
|
-
asyncio.run(generate_dataset(n=2, path=str(gen_path), concurrency=2))
|
|
590
|
-
examples = load_dataset(str(gen_path))
|
|
591
|
-
if len(examples) < 2:
|
|
592
|
-
raise SystemExit(f"Needed 2 examples, only got {len(examples)}.")
|
|
593
|
-
train_data, eval_data = examples[:1], examples[1:2]
|
|
594
|
-
print(f"Smoke run: generated {len(examples)} examples — 1 train, 1 eval.\n")
|
|
571
|
+
# 1. Build the dataset from the committed seed file (curriculum order). Hold out a
|
|
572
|
+
# representative eval set at random; keep TRAIN in curriculum order (simpler first)
|
|
573
|
+
# so the difficulty ramp is preserved.
|
|
574
|
+
examples = get_dataset()
|
|
575
|
+
if len(examples) < 2:
|
|
576
|
+
raise SystemExit(f"Need >=2 examples, got {len(examples)}.")
|
|
577
|
+
n_eval = max(1, len(examples) // 10)
|
|
578
|
+
eval_idx = set(random.sample(range(len(examples)), n_eval))
|
|
579
|
+
eval_data = [e for i, e in enumerate(examples) if i in eval_idx]
|
|
580
|
+
train_data = [e for i, e in enumerate(examples) if i not in eval_idx]
|
|
581
|
+
print(f"{len(train_data)} train (curriculum order) / {len(eval_data)} eval.\n")
|
|
595
582
|
|
|
596
583
|
# 2. Define the bundle config and run pre-flight validation of the env.
|
|
597
584
|
# Bundle config, defined once so the pre-flight validation below exercises
|
|
598
585
|
# the EXACT same env_args / by-value modules / deps as the launch.
|
|
599
586
|
# - local_modules: ship env + rubric by value (the platform's installed
|
|
600
587
|
# benchmax may not contain this version of these modules).
|
|
601
|
-
# -
|
|
602
|
-
|
|
603
|
-
constructor_args = {"judge_base_url": LLM_BASE_URL, "judge_api_key": ""}
|
|
588
|
+
# - judge bearer resolves at runtime via the device-auth / platform seam.
|
|
589
|
+
constructor_args = {"judge_base_url": LLM_BASE_URL}
|
|
604
590
|
local_modules = [telestich_env_mod, rubric_mod]
|
|
605
591
|
# All three are still required (is_valid_word → correctness; pronouncing →
|
|
606
592
|
# rhyme). Removing word_bank did NOT free any of them.
|
|
@@ -623,10 +609,12 @@ if __name__ == "__main__":
|
|
|
623
609
|
llm_api_key="",
|
|
624
610
|
remote_examples=2,
|
|
625
611
|
):
|
|
626
|
-
raise SystemExit(
|
|
612
|
+
raise SystemExit(
|
|
613
|
+
"Env validation failed — aborting before launch (see output above)."
|
|
614
|
+
)
|
|
627
615
|
|
|
628
616
|
# 3. Bundle the env class and upload everything to platform storage.
|
|
629
|
-
run_name = f"telestich-
|
|
617
|
+
run_name = f"telestich-full-{uuid.uuid4().hex[:8]}"
|
|
630
618
|
print(f"\nUploading bundle + datasets as {run_name!r} ...")
|
|
631
619
|
uploaded = upload_training_run(
|
|
632
620
|
env_class=TelestichEnv,
|
|
@@ -647,8 +635,9 @@ if __name__ == "__main__":
|
|
|
647
635
|
):
|
|
648
636
|
print(f" {label:<14}: {path}")
|
|
649
637
|
|
|
650
|
-
# 4. Launch the training run.
|
|
651
|
-
|
|
638
|
+
# 4. Launch the training run. training_run_type="simple" + the `model` arg select
|
|
639
|
+
# the trainer YAML/pool server-side (Qwen3.5-4B→gpu4, Qwen3.5-35B-A3B→gpu8).
|
|
640
|
+
print(f"\nLaunching training run (model={MODEL}) ...")
|
|
652
641
|
with TrainerClient(api_key=API_KEY, base_url=BASE_URL) as trainer:
|
|
653
642
|
run_id = trainer.launch_training_run(
|
|
654
643
|
training_run_type="simple",
|
|
@@ -661,7 +650,7 @@ if __name__ == "__main__":
|
|
|
661
650
|
# max_response_len 3000: a brief reason + 1-2 tool rounds + poem fits well
|
|
662
651
|
# under this; lowered from 4000 to cut off in-head enumeration rambles
|
|
663
652
|
# sooner (they truncate to a 0-reward anyway).
|
|
664
|
-
launcher_args={"max_response_len": 3000, "num_epochs": 10},
|
|
653
|
+
launcher_args={"model": MODEL, "max_response_len": 3000, "num_epochs": 10},
|
|
665
654
|
)
|
|
666
655
|
|
|
667
656
|
print(f"\n✓ Launched run_id={run_id}")
|