benchmax 0.1.2.dev29__py3-none-any.whl → 0.1.2.dev31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmax/bundle.py +74 -0
- benchmax/cli.py +7 -14
- benchmax/config.py +11 -37
- benchmax/envs/postgres_search/search_env.py +7 -1
- benchmax/envs/reward_helpers.py +9 -2
- benchmax/envs/telestich/example.py +44 -48
- benchmax/envs/telestich/telestich_env.py +627 -414
- benchmax/platform/client.py +6 -2
- benchmax/platform/credentials.py +1 -2
- benchmax/platform/login.py +28 -17
- benchmax/platform/validation.py +43 -1
- benchmax/rag/corpus/chroma/client.py +97 -0
- benchmax/rag/corpus/chroma/source.py +35 -5
- benchmax/rag/corpus/pinecone/index_client.py +78 -5
- benchmax/rag/corpus/pinecone/search.py +5 -0
- benchmax/rag/corpus/pinecone/source.py +52 -26
- benchmax/rag/corpus/turbopuffer/namespace.py +21 -0
- benchmax/rag/corpus/turbopuffer/search.py +15 -3
- benchmax/rag/corpus/turbopuffer/source.py +14 -8
- benchmax/rubrics/rubric.py +101 -26
- {benchmax-0.1.2.dev29.dist-info → benchmax-0.1.2.dev31.dist-info}/METADATA +1 -1
- {benchmax-0.1.2.dev29.dist-info → benchmax-0.1.2.dev31.dist-info}/RECORD +26 -26
- {benchmax-0.1.2.dev29.dist-info → benchmax-0.1.2.dev31.dist-info}/WHEEL +0 -0
- {benchmax-0.1.2.dev29.dist-info → benchmax-0.1.2.dev31.dist-info}/entry_points.txt +0 -0
- {benchmax-0.1.2.dev29.dist-info → benchmax-0.1.2.dev31.dist-info}/licenses/LICENSE +0 -0
- {benchmax-0.1.2.dev29.dist-info → benchmax-0.1.2.dev31.dist-info}/top_level.txt +0 -0
benchmax/bundle.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import importlib
|
|
3
4
|
import inspect
|
|
4
5
|
import io
|
|
5
6
|
import json
|
|
@@ -76,6 +77,7 @@ def dump_bundle(
|
|
|
76
77
|
pip_dependencies: list[str] | None = None,
|
|
77
78
|
local_modules: list[ModuleType] | None = None,
|
|
78
79
|
env_class_source: str | None = None,
|
|
80
|
+
auto_local_modules: bool = True,
|
|
79
81
|
) -> Bundle:
|
|
80
82
|
"""Pickle ``(env_class, constructor_args)`` and stamp metadata.
|
|
81
83
|
|
|
@@ -90,6 +92,10 @@ def dump_bundle(
|
|
|
90
92
|
recover it — e.g. a class produced by ``exec()`` into an in-memory
|
|
91
93
|
namespace, which has no source file on disk. When ``None``
|
|
92
94
|
(default), source is introspected from ``env_class``.
|
|
95
|
+
auto_local_modules: When True (default), any local module the pickle
|
|
96
|
+
references but that wasn't passed in ``local_modules`` is imported
|
|
97
|
+
and pickled by value automatically (a warning names them). When
|
|
98
|
+
False, such a reference raises ``BundlingError`` instead.
|
|
93
99
|
|
|
94
100
|
Raises:
|
|
95
101
|
BundlingError: bad env_class, cloudpickle failure, or pickle references
|
|
@@ -124,6 +130,46 @@ def dump_bundle(
|
|
|
124
130
|
except Exception:
|
|
125
131
|
pass
|
|
126
132
|
|
|
133
|
+
if auto_local_modules and _unregistered_local_refs(pickled):
|
|
134
|
+
# Import each referenced local module and re-dump with it pickled by
|
|
135
|
+
# value. Loop because a by-value module can surface further local refs;
|
|
136
|
+
# registrations accumulate (and are torn down once at the end) so an
|
|
137
|
+
# earlier module stays by-value while we resolve the ones it pulled in.
|
|
138
|
+
seen: set[str] = {m.__name__ for m in local_modules}
|
|
139
|
+
registered: list[ModuleType] = []
|
|
140
|
+
with _BUNDLE_LOCK:
|
|
141
|
+
try:
|
|
142
|
+
for _ in range(10):
|
|
143
|
+
pending = [
|
|
144
|
+
m for m in _unregistered_local_refs(pickled) if m not in seen
|
|
145
|
+
]
|
|
146
|
+
if not pending:
|
|
147
|
+
break
|
|
148
|
+
new_mods: list[ModuleType] = []
|
|
149
|
+
for name in pending:
|
|
150
|
+
seen.add(name) # unimportable names fall through to the guard
|
|
151
|
+
try:
|
|
152
|
+
new_mods.append(importlib.import_module(name))
|
|
153
|
+
except Exception:
|
|
154
|
+
pass
|
|
155
|
+
if not new_mods:
|
|
156
|
+
break
|
|
157
|
+
logger.warning(
|
|
158
|
+
"[bundle] %s: auto-bundling local module(s): %s ",
|
|
159
|
+
env_class.__name__,
|
|
160
|
+
", ".join(sorted(m.__name__ for m in new_mods)),
|
|
161
|
+
)
|
|
162
|
+
for mod in new_mods:
|
|
163
|
+
cloudpickle.register_pickle_by_value(mod)
|
|
164
|
+
registered.append(mod)
|
|
165
|
+
pickled = cloudpickle.dumps((env_class, constructor_args))
|
|
166
|
+
finally:
|
|
167
|
+
for mod in registered:
|
|
168
|
+
try:
|
|
169
|
+
cloudpickle.unregister_pickle_by_value(mod)
|
|
170
|
+
except Exception:
|
|
171
|
+
pass
|
|
172
|
+
|
|
127
173
|
risky = _unregistered_local_refs(pickled)
|
|
128
174
|
if risky:
|
|
129
175
|
msg = (
|
|
@@ -259,6 +305,15 @@ def _referenced_modules(pickled: bytes) -> set[str]:
|
|
|
259
305
|
# Hooks find_class so we see every (module, name) the unpickler would import —
|
|
260
306
|
# i.e. exactly what'd raise ModuleNotFoundError on a fresh interpreter. The stub
|
|
261
307
|
# lets unpickling proceed past missing classes so we collect every ref.
|
|
308
|
+
#
|
|
309
|
+
# find_class alone has a blind spot: a bare ``import foo`` that leaves a
|
|
310
|
+
# module *object* in the env's globals is pickled as
|
|
311
|
+
# ``cloudpickle.subimport("foo")`` — the module name is a REDUCE argument,
|
|
312
|
+
# not a find_class path, so we'd only see ``cloudpickle.cloudpickle`` (which
|
|
313
|
+
# looks installed) and miss ``foo``. We shim subimport to record its arg and
|
|
314
|
+
# return a stub instead of importing, so a missing module is captured rather
|
|
315
|
+
# than aborting the whole load early. (``dynamic_subimport`` is by-value /
|
|
316
|
+
# self-contained — leave it to the real find_class so we don't flag it.)
|
|
262
317
|
refs: set[str] = set()
|
|
263
318
|
|
|
264
319
|
class _Stub:
|
|
@@ -271,9 +326,28 @@ def _referenced_modules(pickled: bytes) -> set[str]:
|
|
|
271
326
|
def __reduce__(self) -> tuple:
|
|
272
327
|
return (type(self), ())
|
|
273
328
|
|
|
329
|
+
def _recording_subimport(name: str, *a: Any, **kw: Any) -> ModuleType:
|
|
330
|
+
refs.add(name)
|
|
331
|
+
return ModuleType(str(name))
|
|
332
|
+
|
|
333
|
+
def _noop_setstate(obj: Any, *a: Any, **kw: Any) -> Any:
|
|
334
|
+
# cloudpickle's _make_skeleton_class resolves the class_tracker_id back
|
|
335
|
+
# to the *live* class (it was tracked when env_class was dumped), so the
|
|
336
|
+
# real ``_class_setstate``/``_function_setstate`` would setattr the
|
|
337
|
+
# reconstructed (stub-globals) members onto the live class/function —
|
|
338
|
+
# mutating the caller's class mid-bundle and poisoning any later dump.
|
|
339
|
+
# We only need the refs from ``state``, which are already recorded while
|
|
340
|
+
# it's unpickled; the setter itself is a no-op here.
|
|
341
|
+
return obj
|
|
342
|
+
|
|
274
343
|
class _Recorder(pickle.Unpickler):
|
|
275
344
|
def find_class(self, module: str, name: str) -> Any:
|
|
276
345
|
refs.add(module)
|
|
346
|
+
if module.startswith("cloudpickle"):
|
|
347
|
+
if name == "subimport":
|
|
348
|
+
return _recording_subimport
|
|
349
|
+
if name in ("_class_setstate", "_function_setstate"):
|
|
350
|
+
return _noop_setstate
|
|
277
351
|
try:
|
|
278
352
|
return super().find_class(module, name)
|
|
279
353
|
except Exception:
|
benchmax/cli.py
CHANGED
|
@@ -11,19 +11,19 @@ from __future__ import annotations
|
|
|
11
11
|
import argparse
|
|
12
12
|
import sys
|
|
13
13
|
|
|
14
|
+
from benchmax import config
|
|
14
15
|
from benchmax.platform import credentials
|
|
15
16
|
from benchmax.platform.device_auth import DeviceAuthError
|
|
16
17
|
from benchmax.platform.login import _login
|
|
17
18
|
|
|
18
19
|
|
|
19
|
-
def _cmd_login(
|
|
20
|
-
env = "staging" if args.env == "staging" else None
|
|
20
|
+
def _cmd_login(_args: argparse.Namespace) -> int:
|
|
21
21
|
try:
|
|
22
|
-
_login(
|
|
22
|
+
_login()
|
|
23
23
|
except DeviceAuthError as exc:
|
|
24
24
|
print(f"Login failed: {exc}", file=sys.stderr)
|
|
25
25
|
return 1
|
|
26
|
-
print(f"\n✓ Logged in to {
|
|
26
|
+
print(f"\n✓ Logged in to {config.base_domain()}.")
|
|
27
27
|
return 0
|
|
28
28
|
|
|
29
29
|
|
|
@@ -38,19 +38,18 @@ def _cmd_whoami(_args: argparse.Namespace) -> int:
|
|
|
38
38
|
if not session:
|
|
39
39
|
print("Not logged in. Run `castform login`.", file=sys.stderr)
|
|
40
40
|
return 1
|
|
41
|
-
env = session.get("env", "prod")
|
|
42
41
|
jwt = credentials._session_jwt() # mints from the session; None if invalid/expired/offline
|
|
43
42
|
if not jwt:
|
|
44
43
|
print(
|
|
45
|
-
|
|
46
|
-
"
|
|
44
|
+
"Session present, but couldn't reach auth-service to verify it "
|
|
45
|
+
"(offline, or the session expired). If this persists, run "
|
|
47
46
|
"`castform login` again.",
|
|
48
47
|
file=sys.stderr,
|
|
49
48
|
)
|
|
50
49
|
return 1
|
|
51
50
|
claims = credentials._jwt_claims(jwt)
|
|
52
51
|
who = claims.get("email") or claims.get("sub", "<unknown>")
|
|
53
|
-
print(f"Logged in as {who} (
|
|
52
|
+
print(f"Logged in as {who} ({config.base_domain()}).")
|
|
54
53
|
return 0
|
|
55
54
|
|
|
56
55
|
|
|
@@ -59,12 +58,6 @@ def main(argv: list[str] | None = None) -> int:
|
|
|
59
58
|
sub = parser.add_subparsers(dest="command", required=True)
|
|
60
59
|
|
|
61
60
|
p_login = sub.add_parser("login", help="Sign in via your browser")
|
|
62
|
-
p_login.add_argument(
|
|
63
|
-
"--env",
|
|
64
|
-
choices=["prod", "staging"],
|
|
65
|
-
default="prod",
|
|
66
|
-
help="Environment to sign in to (staging is internal-only)",
|
|
67
|
-
)
|
|
68
61
|
p_login.set_defaults(func=_cmd_login)
|
|
69
62
|
|
|
70
63
|
sub.add_parser("logout", help="Clear the cached session").set_defaults(func=_cmd_logout)
|
benchmax/config.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
|
1
1
|
"""Centralized URL configuration for the Castform platform.
|
|
2
2
|
|
|
3
|
-
All URLs derive from a single base domain
|
|
4
|
-
|
|
5
|
-
|
|
3
|
+
All URLs derive from a single base domain, resolved from exactly two places: the
|
|
4
|
+
``CASTFORM_BASE_DOMAIN`` env var, or the built-in ``castform.com`` default.
|
|
5
|
+
Individual URLs may be overridden via their own env vars
|
|
6
|
+
(``CASTFORM_PLATFORM_URL`` / ``CASTFORM_LLM_URL`` / ``CASTFORM_AUTH_URL`` /
|
|
7
|
+
``CASTFORM_WEB_APP_URL``) — e.g. point platform at ``http://localhost:3000`` for
|
|
8
|
+
local dev while auth keeps talking to the real host.
|
|
6
9
|
|
|
7
10
|
Usage::
|
|
8
11
|
|
|
@@ -16,38 +19,10 @@ DEFAULT_BASE_DOMAIN = "castform.com"
|
|
|
16
19
|
|
|
17
20
|
|
|
18
21
|
def base_domain() -> str:
|
|
19
|
-
"""Resolve the platform base domain
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
(``castform.com``). The ``env`` claim travels with the credential, so a
|
|
24
|
-
logged-in SDK routes to the same environment it authenticated against —
|
|
25
|
-
URL and credential can't desync. A prod session carries no ``env`` marker
|
|
26
|
-
(``None`` → prod), so only internal staging logins deviate from the default.
|
|
27
|
-
"""
|
|
28
|
-
override = os.environ.get("CASTFORM_BASE_DOMAIN")
|
|
29
|
-
if override:
|
|
30
|
-
return override
|
|
31
|
-
if _session_env() == "staging":
|
|
32
|
-
return "castform.dev"
|
|
33
|
-
return DEFAULT_BASE_DOMAIN
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def _session_env() -> str | None:
|
|
37
|
-
"""The ``env`` from the cached device-auth session, if any.
|
|
38
|
-
|
|
39
|
-
Lazy import: ``config`` is a leaf that ``benchmax.platform`` depends on, so a
|
|
40
|
-
top-level import would cycle (platform/__init__ → client → config)."""
|
|
41
|
-
try:
|
|
42
|
-
from benchmax.platform.credentials import read_castform_session
|
|
43
|
-
|
|
44
|
-
session = read_castform_session()
|
|
45
|
-
except Exception:
|
|
46
|
-
return None
|
|
47
|
-
if not session:
|
|
48
|
-
return None
|
|
49
|
-
env = session.get("env")
|
|
50
|
-
return env if isinstance(env, str) else None
|
|
22
|
+
"""Resolve the platform base domain: ``CASTFORM_BASE_DOMAIN`` or the
|
|
23
|
+
``castform.com`` default. To target another environment (e.g. internal
|
|
24
|
+
staging), export ``CASTFORM_BASE_DOMAIN=castform.dev``."""
|
|
25
|
+
return os.environ.get("CASTFORM_BASE_DOMAIN") or DEFAULT_BASE_DOMAIN
|
|
51
26
|
|
|
52
27
|
|
|
53
28
|
def platform_url() -> str:
|
|
@@ -75,7 +50,6 @@ def auth_url() -> str:
|
|
|
75
50
|
"""Auth-service base URL (device-authorization + JWT mint endpoints).
|
|
76
51
|
|
|
77
52
|
Used by ``castform login`` and the per-process session→JWT mint. Derives from
|
|
78
|
-
the same base domain as everything else,
|
|
79
|
-
talks to ``auth.castform.dev`` and a ``prod`` session to ``auth.castform.com``.
|
|
53
|
+
the same base domain as everything else, or ``CASTFORM_AUTH_URL`` to override.
|
|
80
54
|
"""
|
|
81
55
|
return os.environ.get("CASTFORM_AUTH_URL") or f"https://auth.{base_domain()}"
|
|
@@ -285,8 +285,14 @@ tags. Cite your sources inline using [Source: <source_id>] next to each claim.
|
|
|
285
285
|
if not text.strip():
|
|
286
286
|
return zeros
|
|
287
287
|
|
|
288
|
-
|
|
288
|
+
# No final <answer> block → no answer to score. Return all-zero
|
|
289
|
+
# rewards so conciseness / citations / efficiency can't accrue
|
|
290
|
+
# from reasoning or tool-call text alone.
|
|
289
291
|
answer = extract_answer_block(text)
|
|
292
|
+
if not answer:
|
|
293
|
+
return zeros
|
|
294
|
+
|
|
295
|
+
t = task or {}
|
|
290
296
|
prompt = str(t.get("question") or t.get("prompt") or "")
|
|
291
297
|
gt_str = str(t.get("ground_truth") or "")
|
|
292
298
|
reference_chunks = t.get("reference_chunks", [])
|
benchmax/envs/reward_helpers.py
CHANGED
|
@@ -82,9 +82,16 @@ def extract_completion_text(completion: str | list[dict[str, Any]]) -> str:
|
|
|
82
82
|
|
|
83
83
|
|
|
84
84
|
def extract_answer_block(text: str) -> str:
|
|
85
|
-
"""Extract content from
|
|
85
|
+
"""Extract content from ``<answer>`` tags.
|
|
86
|
+
|
|
87
|
+
Returns the (stripped) tag contents when an ``<answer>…</answer>`` block
|
|
88
|
+
is present, otherwise ``""``. A missing answer block is treated as "no
|
|
89
|
+
final answer" rather than silently falling back to the full completion —
|
|
90
|
+
consumers can gate rewards on a non-empty result. ``<answer></answer>``
|
|
91
|
+
likewise yields ``""``.
|
|
92
|
+
"""
|
|
86
93
|
match = _ANSWER_TAG_RE.search(text or "")
|
|
87
|
-
return
|
|
94
|
+
return match.group(1).strip() if match else ""
|
|
88
95
|
|
|
89
96
|
|
|
90
97
|
def clip01(value: Any) -> float:
|
|
@@ -12,13 +12,15 @@ Run it from the benchmax project root (the ``telestich`` extra pulls in the
|
|
|
12
12
|
env's word-list / rhyme dependencies):
|
|
13
13
|
|
|
14
14
|
cd core/benchmax
|
|
15
|
-
|
|
16
|
-
uv run --extra telestich python -m benchmax.envs.telestich.example
|
|
15
|
+
uv run --extra telestich python -m benchmax.envs.telestich.example
|
|
17
16
|
|
|
18
|
-
(``
|
|
17
|
+
Auth is the device-auth session (``ensure_session()`` opens a browser login if
|
|
18
|
+
``~/.castform`` has no valid session) — no API key needed. ``CASTFORM_API_KEY``
|
|
19
|
+
/ ``CASTFORM_LLM_API_KEY`` are only consulted by the offline dataset-generation
|
|
20
|
+
helpers, not the launch path.
|
|
19
21
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
+
This launches a real training run on the full committed seed dataset
|
|
23
|
+
(~90/10 train/eval split).
|
|
22
24
|
"""
|
|
23
25
|
|
|
24
26
|
import asyncio
|
|
@@ -42,7 +44,7 @@ from benchmax.rubrics import rubric as rubric_mod
|
|
|
42
44
|
#
|
|
43
45
|
# Defaults route through ``benchmax.config``: the prod LLM endpoint is
|
|
44
46
|
# ``https://llm.castform.com/v1`` and the platform control plane is
|
|
45
|
-
# ``https://api.castform.com``. Point at
|
|
47
|
+
# ``https://api.castform.com``. Point at a different environment by setting
|
|
46
48
|
# ``CASTFORM_BASE_DOMAIN`` (or override URLs individually via
|
|
47
49
|
# ``CASTFORM_PLATFORM_URL`` / ``CASTFORM_LLM_URL``).
|
|
48
50
|
from benchmax import config
|
|
@@ -59,6 +61,12 @@ EXPERIMENT_PREFIX = "telestich"
|
|
|
59
61
|
DATASET_PATH = str(Path(__file__).parent / "telestich_dataset.jsonl")
|
|
60
62
|
NUM_EXAMPLES = 400
|
|
61
63
|
CONCURRENCY = 15
|
|
64
|
+
# Trainer model — the launch `model` arg selects the trainer YAML (and thus the GPU
|
|
65
|
+
# pool) server-side. Supported: "Qwen/Qwen3.5-4B" (gpu4) or "Qwen/Qwen3.5-35B-A3B"
|
|
66
|
+
# (gpu8). Override via TELESTICH_MODEL.
|
|
67
|
+
MODEL = os.environ.get("TELESTICH_MODEL", "Qwen/Qwen3.5-4B")
|
|
68
|
+
# Run name — defaults to a unique telestich-full-<uuid>. Override via TELESTICH_RUN_NAME.
|
|
69
|
+
RUN_NAME = os.environ.get("TELESTICH_RUN_NAME", "")
|
|
62
70
|
|
|
63
71
|
# (model, weight). Weights reflect observed reliability on our checks:
|
|
64
72
|
# - Both grok models leak banned example words and rubber-stamp the CoT self-check.
|
|
@@ -552,55 +560,40 @@ def get_dataset():
|
|
|
552
560
|
# alongside the pickle so a UI can show "what code is in this env" without
|
|
553
561
|
# unpickling.
|
|
554
562
|
if __name__ == "__main__":
|
|
555
|
-
import tempfile
|
|
556
563
|
import uuid
|
|
557
564
|
|
|
565
|
+
from benchmax.platform import ensure_session
|
|
558
566
|
from benchmax.platform.client import TrainerClient
|
|
559
567
|
from benchmax.platform.training_run import upload_training_run
|
|
560
568
|
from benchmax.platform.validation import validate_env
|
|
561
569
|
|
|
562
|
-
if
|
|
563
|
-
|
|
570
|
+
# Device-auth session bootstrap: browser login if no credential resolves.
|
|
571
|
+
# After this the platform bearer comes from ~/.castform — no API key needed,
|
|
572
|
+
# so we pass api_key="" to the platform calls below (resolves via the seam).
|
|
573
|
+
ensure_session()
|
|
564
574
|
|
|
565
575
|
print(f"Platform URL: {BASE_URL}")
|
|
566
576
|
print(f"LLM URL: {LLM_BASE_URL}\n")
|
|
567
577
|
|
|
568
|
-
# 1. Build the dataset.
|
|
569
|
-
#
|
|
570
|
-
#
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
# order (simpler first) so the difficulty ramp is preserved.
|
|
580
|
-
n_eval = max(1, len(examples) // 10)
|
|
581
|
-
eval_idx = set(random.sample(range(len(examples)), n_eval))
|
|
582
|
-
eval_data = [e for i, e in enumerate(examples) if i in eval_idx]
|
|
583
|
-
train_data = [e for i, e in enumerate(examples) if i not in eval_idx]
|
|
584
|
-
print(f"Full run: {len(train_data)} train (curriculum order) / {len(eval_data)} eval.\n")
|
|
585
|
-
else:
|
|
586
|
-
with tempfile.TemporaryDirectory() as tmp:
|
|
587
|
-
gen_path = Path(tmp) / "gen.jsonl"
|
|
588
|
-
print(f"Generating 2 examples via {LLM_BASE_URL} ...")
|
|
589
|
-
asyncio.run(generate_dataset(n=2, path=str(gen_path), concurrency=2))
|
|
590
|
-
examples = load_dataset(str(gen_path))
|
|
591
|
-
if len(examples) < 2:
|
|
592
|
-
raise SystemExit(f"Needed 2 examples, only got {len(examples)}.")
|
|
593
|
-
train_data, eval_data = examples[:1], examples[1:2]
|
|
594
|
-
print(f"Smoke run: generated {len(examples)} examples — 1 train, 1 eval.\n")
|
|
578
|
+
# 1. Build the dataset from the committed seed file (curriculum order). Hold out a
|
|
579
|
+
# representative eval set at random; keep TRAIN in curriculum order (simpler first)
|
|
580
|
+
# so the difficulty ramp is preserved.
|
|
581
|
+
examples = get_dataset()
|
|
582
|
+
if len(examples) < 2:
|
|
583
|
+
raise SystemExit(f"Need >=2 examples, got {len(examples)}.")
|
|
584
|
+
n_eval = max(1, len(examples) // 10)
|
|
585
|
+
eval_idx = set(random.sample(range(len(examples)), n_eval))
|
|
586
|
+
eval_data = [e for i, e in enumerate(examples) if i in eval_idx]
|
|
587
|
+
train_data = [e for i, e in enumerate(examples) if i not in eval_idx]
|
|
588
|
+
print(f"{len(train_data)} train (curriculum order) / {len(eval_data)} eval.\n")
|
|
595
589
|
|
|
596
590
|
# 2. Bundle the env class and upload everything to platform storage.
|
|
597
591
|
# Bundle config, defined once so the pre-flight validation below exercises
|
|
598
592
|
# the EXACT same env_args / by-value modules / deps as the launch.
|
|
599
593
|
# - local_modules: ship env + rubric by value (the platform's installed
|
|
600
594
|
# benchmax may not contain this version of these modules).
|
|
601
|
-
# -
|
|
602
|
-
|
|
603
|
-
constructor_args = {"judge_base_url": LLM_BASE_URL, "judge_api_key": ""}
|
|
595
|
+
# - judge bearer resolves at runtime via the device-auth / platform seam.
|
|
596
|
+
constructor_args = {"judge_base_url": LLM_BASE_URL}
|
|
604
597
|
local_modules = [telestich_env_mod, rubric_mod]
|
|
605
598
|
# All three are still required (is_valid_word → correctness; pronouncing →
|
|
606
599
|
# rhyme). Removing word_bank did NOT free any of them.
|
|
@@ -617,23 +610,25 @@ if __name__ == "__main__":
|
|
|
617
610
|
eval_dataset=eval_data[:2],
|
|
618
611
|
local_modules=local_modules,
|
|
619
612
|
pip_dependencies=pip_dependencies,
|
|
620
|
-
api_key=
|
|
613
|
+
api_key="", # session bearer via ensure_session()
|
|
621
614
|
base_url=BASE_URL,
|
|
622
615
|
llm_base_url=LLM_BASE_URL,
|
|
623
616
|
llm_api_key="",
|
|
624
617
|
remote_examples=2,
|
|
625
618
|
):
|
|
626
|
-
raise SystemExit(
|
|
619
|
+
raise SystemExit(
|
|
620
|
+
"Env validation failed — aborting before launch (see output above)."
|
|
621
|
+
)
|
|
627
622
|
|
|
628
623
|
# 3. Bundle the env class and upload everything to platform storage.
|
|
629
|
-
run_name = f"telestich-
|
|
624
|
+
run_name = RUN_NAME or f"telestich-full-{uuid.uuid4().hex[:8]}"
|
|
630
625
|
print(f"\nUploading bundle + datasets as {run_name!r} ...")
|
|
631
626
|
uploaded = upload_training_run(
|
|
632
627
|
env_class=TelestichEnv,
|
|
633
628
|
train_dataset=train_data,
|
|
634
629
|
eval_dataset=eval_data,
|
|
635
630
|
run_name=run_name,
|
|
636
|
-
api_key=
|
|
631
|
+
api_key="", # session bearer via ensure_session()
|
|
637
632
|
base_url=BASE_URL,
|
|
638
633
|
local_modules=local_modules,
|
|
639
634
|
constructor_args=constructor_args,
|
|
@@ -647,9 +642,10 @@ if __name__ == "__main__":
|
|
|
647
642
|
):
|
|
648
643
|
print(f" {label:<14}: {path}")
|
|
649
644
|
|
|
650
|
-
# 4. Launch the training run.
|
|
651
|
-
|
|
652
|
-
|
|
645
|
+
# 4. Launch the training run. training_run_type="simple" + the `model` arg select
|
|
646
|
+
# the trainer YAML/pool server-side (Qwen3.5-4B→gpu4, Qwen3.5-35B-A3B→gpu8).
|
|
647
|
+
print(f"\nLaunching training run (model={MODEL}) ...")
|
|
648
|
+
with TrainerClient(api_key="", base_url=BASE_URL) as trainer:
|
|
653
649
|
run_id = trainer.launch_training_run(
|
|
654
650
|
training_run_type="simple",
|
|
655
651
|
env_cls_path=uploaded.env_cls_path,
|
|
@@ -658,10 +654,10 @@ if __name__ == "__main__":
|
|
|
658
654
|
eval_dataset_path=uploaded.eval_dataset_path,
|
|
659
655
|
name=run_name,
|
|
660
656
|
# num_epochs: passes over the train set (platform default is 5).
|
|
661
|
-
#
|
|
657
|
+
# max_rollout_len 3000: a brief reason + 1-2 tool rounds + poem fits well
|
|
662
658
|
# under this; lowered from 4000 to cut off in-head enumeration rambles
|
|
663
659
|
# sooner (they truncate to a 0-reward anyway).
|
|
664
|
-
launcher_args={"
|
|
660
|
+
launcher_args={"model": MODEL, "max_rollout_len": 3000, "num_epochs": 10},
|
|
665
661
|
)
|
|
666
662
|
|
|
667
663
|
print(f"\n✓ Launched run_id={run_id}")
|