benchmax 0.1.2.dev33__py3-none-any.whl → 0.1.2.dev35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmax/cli/__init__.py +71 -0
- benchmax/{cli.py → cli/_auth.py} +16 -22
- benchmax/cli/_client.py +49 -0
- benchmax/cli/_output.py +134 -0
- benchmax/cli/_project.py +138 -0
- benchmax/cli/_providers.py +60 -0
- benchmax/cli/control.py +28 -0
- benchmax/cli/corpus.py +230 -0
- benchmax/cli/data.py +441 -0
- benchmax/cli/help.py +233 -0
- benchmax/cli/launch.py +241 -0
- benchmax/cli/runs.py +187 -0
- benchmax/cli/scaffold/CLAUDE.md +132 -0
- benchmax/cli/scaffold/STARTER.md +93 -0
- benchmax/cli/scaffold/__init__.py +0 -0
- benchmax/cli/scaffold/rag_run.py +72 -0
- benchmax/cli/scaffold/skills/design-environment/SKILL.md +327 -0
- benchmax/cli/scaffold/skills/generate-data/SKILL.md +192 -0
- benchmax/cli/scaffold/skills/launch-run/SKILL.md +68 -0
- benchmax/cli/scaffold/skills/verify-environment/SKILL.md +199 -0
- benchmax/cli/scaffold/skills/view-progress/SKILL.md +63 -0
- benchmax/cli/setup.py +286 -0
- benchmax/cli/validate.py +448 -0
- benchmax/envs/postgres_search/search_env.py +14 -3
- benchmax/envs/telestich/example.py +2 -3
- benchmax/platform/client.py +117 -9
- benchmax/platform/training_run.py +0 -1
- benchmax/platform/validation.py +12 -1
- benchmax/rag/corpus/embed.py +54 -0
- benchmax/rag/corpus/postgres/client.py +237 -12
- benchmax/rag/corpus/postgres/exceptions.py +2 -2
- benchmax/rag/corpus/postgres/source.py +93 -26
- benchmax/rag/qa_generation/batch_processor.py +138 -12
- benchmax/rag/qa_generation/filters/grounding_llm.py +117 -34
- benchmax/rag/qa_generation/filters/hop_count_validity.py +116 -31
- benchmax/rag/qa_generation/filters/retrieval_llm.py +131 -44
- benchmax/rag/qa_generation/generators/direct_llm.py +123 -43
- benchmax/rag/qa_generation/metadata_linker.py +179 -10
- benchmax/rag/qa_generation/pipeline.py +297 -205
- benchmax/rag/qa_generation/pipeline_config.py +89 -0
- benchmax/rag/qa_generation/search_agent_linker.py +59 -6
- benchmax/rag/qa_generation/wiki_chunk_linker.py +34 -6
- {benchmax-0.1.2.dev33.dist-info → benchmax-0.1.2.dev35.dist-info}/METADATA +4 -2
- {benchmax-0.1.2.dev33.dist-info → benchmax-0.1.2.dev35.dist-info}/RECORD +48 -25
- {benchmax-0.1.2.dev33.dist-info → benchmax-0.1.2.dev35.dist-info}/WHEEL +0 -0
- {benchmax-0.1.2.dev33.dist-info → benchmax-0.1.2.dev35.dist-info}/entry_points.txt +0 -0
- {benchmax-0.1.2.dev33.dist-info → benchmax-0.1.2.dev35.dist-info}/licenses/LICENSE +0 -0
- {benchmax-0.1.2.dev33.dist-info → benchmax-0.1.2.dev35.dist-info}/top_level.txt +0 -0
benchmax/cli/__init__.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""``castform`` CLI — a single argparse tree assembled from command groups.
|
|
2
|
+
|
|
3
|
+
Each command group lives in its own module exposing ``register(sub)``;
|
|
4
|
+
``build_parser`` wires them onto the top-level subparsers and ``main`` dispatches
|
|
5
|
+
to the selected handler's ``func``. Bundled with the benchmax SDK — entry point
|
|
6
|
+
``benchmax.cli:main`` (``pyproject.toml``). Argparse (not typer) is deliberate:
|
|
7
|
+
bundled packaging means a CLI dep would land in the training-engine closure; see
|
|
8
|
+
``docs/plans/castform-cli-rl-workflow.md`` slice 1.1.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import argparse
|
|
14
|
+
import sys
|
|
15
|
+
|
|
16
|
+
from benchmax.cli import (
|
|
17
|
+
_auth,
|
|
18
|
+
control,
|
|
19
|
+
corpus,
|
|
20
|
+
data,
|
|
21
|
+
help,
|
|
22
|
+
launch,
|
|
23
|
+
runs,
|
|
24
|
+
setup,
|
|
25
|
+
validate,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
# Re-export auth handlers — tests/unit/test_cli.py imports them as cli._cmd_*.
|
|
29
|
+
from benchmax.cli._auth import _cmd_login, _cmd_logout, _cmd_whoami
|
|
30
|
+
|
|
31
|
+
__all__ = ["build_parser", "main", "_cmd_login", "_cmd_logout", "_cmd_whoami"]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def build_parser() -> argparse.ArgumentParser:
    """Assemble the complete ``castform`` argument parser.

    Each command-group module attaches its subcommands through its
    ``register(sub)`` hook; the ``guide`` and ``help`` commands are wired up
    here directly. Tests snapshot the result's ``format_help()``.

    Returns:
        The top-level parser. ``main`` dispatches through the ``func``
        default stored on the parsed namespace.
    """
    parser = argparse.ArgumentParser(prog="castform", description="Castform CLI")
    sub = parser.add_subparsers(dest="command", required=True, metavar="<command>")

    # Groups are registered in this fixed order.
    for group in (_auth, runs, control, validate, launch, data, corpus, setup):
        group.register(sub)

    # `guide` renders the getting-started walkthrough (the renderer lives in
    # ``benchmax.cli.help``). It is named `guide`, not `quickstart`, because
    # `setup` is itself the quickstart flow — keep the two from blurring.
    guide_parser = sub.add_parser("guide", help="Walk through your first run")
    guide_parser.set_defaults(func=help._cmd_help)

    # `help` mirrors `-h`: it only lists the commands. Users reach for it by
    # habit; the walkthrough is `castform guide`, not here.
    def _list_commands(_args: argparse.Namespace) -> int:
        parser.print_help()
        return 0

    help_parser = sub.add_parser("help", help="List the available commands")
    help_parser.set_defaults(func=_list_commands)
    return parser
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def main(argv: list[str] | None = None) -> int:
    """Parse ``argv`` and run the selected command's handler.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        Whatever the selected handler (the namespace's ``func``) returns.
    """
    namespace = build_parser().parse_args(argv)
    handler = namespace.func
    return handler(namespace)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
if __name__ == "__main__":
|
|
71
|
+
sys.exit(main())
|
benchmax/{cli.py → cli/_auth.py}
RENAMED
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""castform auth commands: ``login`` / ``logout`` / ``whoami``.
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
resolves its bearer from ``~/.castform`` automatically — no API key or URL.
|
|
3
|
+
The device-auth flow + the reusable ``ensure_session`` live in
|
|
4
|
+
:mod:`benchmax.platform.login`; these handlers are thin argparse wrappers. After
|
|
5
|
+
``castform login`` the SDK resolves its bearer from ``~/.castform`` automatically.
|
|
7
6
|
"""
|
|
8
7
|
|
|
9
8
|
from __future__ import annotations
|
|
@@ -38,7 +37,7 @@ def _cmd_whoami(_args: argparse.Namespace) -> int:
|
|
|
38
37
|
if not session:
|
|
39
38
|
print("Not logged in. Run `castform login`.", file=sys.stderr)
|
|
40
39
|
return 1
|
|
41
|
-
jwt = credentials._session_jwt() #
|
|
40
|
+
jwt = credentials._session_jwt() # None if invalid/expired/offline
|
|
42
41
|
if not jwt:
|
|
43
42
|
print(
|
|
44
43
|
"Session present, but couldn't reach auth-service to verify it "
|
|
@@ -53,19 +52,14 @@ def _cmd_whoami(_args: argparse.Namespace) -> int:
|
|
|
53
52
|
return 0
|
|
54
53
|
|
|
55
54
|
|
|
56
|
-
def
|
|
57
|
-
|
|
58
|
-
sub
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
sub.add_parser("whoami", help="Show the current login").set_defaults(
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
return args.func(args)
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
if __name__ == "__main__":
|
|
71
|
-
sys.exit(main())
|
|
55
|
+
def register(sub: argparse._SubParsersAction) -> None:
    """Add the ``login``, ``logout`` and ``whoami`` subcommands to ``sub``.

    Each subcommand stores its handler under the ``func`` default.
    """
    auth_commands = (
        ("login", "Sign in via your browser", _cmd_login),
        ("logout", "Clear the cached session", _cmd_logout),
        ("whoami", "Show the current login", _cmd_whoami),
    )
    for name, summary, handler in auth_commands:
        sub.add_parser(name, help=summary).set_defaults(func=handler)
|
benchmax/cli/_client.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""Shared platform-client wiring + error handling for CLI command groups.
|
|
2
|
+
|
|
3
|
+
Read/control commands resolve their bearer through the credential seam
|
|
4
|
+
(:func:`benchmax.platform.credentials.platform_bearer`) — ``ACT_AS_TOKEN_PATH``
|
|
5
|
+
→ ``PLATFORM_API_KEY`` → cached ``~/.castform`` session — against the host from
|
|
6
|
+
:mod:`benchmax.config`. ``handle_errors`` keeps tracebacks out of normal failures.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import functools
|
|
12
|
+
import sys
|
|
13
|
+
from collections.abc import Callable
|
|
14
|
+
|
|
15
|
+
import httpx
|
|
16
|
+
|
|
17
|
+
from benchmax.platform.client import TrainerClient
|
|
18
|
+
from benchmax.platform.exceptions import AuthenticationError, TrainerError
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def trainer_client() -> TrainerClient:
    """Return a ``TrainerClient`` built with its defaults.

    No arguments are passed here, so host and bearer resolution are left to
    the client itself.
    """
    client = TrainerClient()
    return client
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def handle_errors(func: Callable) -> Callable:
    """Decorate a command handler so expected failures exit cleanly.

    The wrapped handler's return value is passed through on success. These
    failures are reported as one line on stderr and turned into exit code 1:

    * ``AuthenticationError`` — a fixed "run `castform login`" hint
    * ``TrainerError`` — ``Error: <exc.message>``
    * ``RuntimeError`` — ``Error: <exc>`` (e.g. no resolvable credential)
    * ``httpx.HTTPError`` — ``Network error: <exc>``

    Any other exception propagates unchanged.
    """

    def _fail(message: str) -> int:
        # Single place that writes the failure line and yields the exit code.
        print(message, file=sys.stderr)
        return 1

    @functools.wraps(func)
    def wrapper(args):
        try:
            outcome = func(args)
        except AuthenticationError:
            return _fail("Not logged in (or session expired). Run `castform login`.")
        except TrainerError as exc:
            return _fail(f"Error: {exc.message}")
        except RuntimeError as exc:  # platform_bearer with no resolvable credential
            return _fail(f"Error: {exc}")
        except httpx.HTTPError as exc:
            return _fail(f"Network error: {exc}")
        return outcome

    return wrapper
|
benchmax/cli/_output.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
"""Small, dependency-free output helpers shared by CLI command groups."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json as _json
|
|
6
|
+
import os
|
|
7
|
+
import re
|
|
8
|
+
import shutil
|
|
9
|
+
import sys
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Castform brand palette (web-app diagram tokens: blue #3b76f6, orange #f97316),
|
|
14
|
+
# softened for the terminal and tuned per background. We can't query the terminal
|
|
15
|
+
# background synchronously, but many emulators export COLORFGBG ("fg;bg"); when bg
|
|
16
|
+
# reads light we pick deeper, higher-contrast tones, otherwise softer pastels that
|
|
17
|
+
# don't glare on black. Falls back to the dark (pastel) set. Rendered as truecolor.
|
|
18
|
+
def _is_light_terminal() -> bool:
    """Guess from ``COLORFGBG`` whether the terminal background is light.

    The last ``;``-separated field of ``COLORFGBG`` is read as the background
    palette index.

    Returns:
        ``True`` for index 7 and for every index above 8. ``False`` for
        indexes 0-6 and 8, and when the variable is unset or its last field
        is not a plain ASCII number (e.g. ``default``).
    """
    background = os.environ.get("COLORFGBG", "").rpartition(";")[2].strip()
    # isascii() guards against digits like "²" that isdigit() accepts but
    # int() rejects.
    if not (background.isascii() and background.isdigit()):
        return False
    index = int(background)
    # 0-6 are the dark ANSI colours and 8 is "bright black" (dark grey), so a
    # plain ``>= 7`` would wrongly treat 8 as light. 7 (white) and 15 (bright
    # white) are the usual light backgrounds; other indexes above 8 keep
    # being treated as light.
    return index >= 7 and index != 8
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Chosen once at import time: a light background gets the deeper,
# higher-contrast pair; anything else (including an undetectable background)
# gets the softer pastel pair.
BLUE, ORANGE = (
    (
        (37, 99, 235),  # #2563eb — deeper, reads on white
        (194, 87, 28),  # #c2571c — terracotta
    )
    if _is_light_terminal()
    else (
        (125, 166, 232),  # #7da6e8 — soft sky
        (230, 166, 99),  # #e6a663 — soft amber
    )
)
|
|
31
|
+
|
|
32
|
+
_GREY = (140, 140, 140)
|
|
33
|
+
|
|
34
|
+
_ANSI_RE = re.compile(r"\x1b\[[0-9;]*m")
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def color_enabled() -> bool:
    """Report whether ANSI styling should be emitted.

    Styling is on only when ``sys.stdout`` is a TTY and the ``NO_COLOR``
    environment variable is unset or empty.
    """
    suppressed = bool(os.environ.get("NO_COLOR"))
    return sys.stdout.isatty() and not suppressed
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def paint(
    text: str,
    rgb: tuple[int, int, int] | None = None,
    *,
    bold: bool = False,
    italic: bool = False,
    dim: bool = False,
) -> str:
    """Style ``text`` with ANSI SGR codes when color output is enabled.

    Args:
        text: The string to style.
        rgb: Optional foreground color, emitted as a ``38;2;r;g;b`` sequence.
        bold: Add SGR code ``1``.
        italic: Add SGR code ``3``.
        dim: Add SGR code ``2``.

    Returns:
        ``text`` wrapped in an escape sequence and a trailing reset, or
        ``text`` unchanged when color is disabled or no style was requested.
    """
    if not color_enabled():
        return text
    # Codes are emitted in a fixed order: bold, dim, italic, then the color.
    sgr = [code for wanted, code in ((bold, "1"), (dim, "2"), (italic, "3")) if wanted]
    if rgb is not None:
        sgr.append("38;2;{};{};{}".format(rgb[0], rgb[1], rgb[2]))
    if not sgr:
        return text
    return "\x1b[{}m{}\x1b[0m".format(";".join(sgr), text)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _visible_len(s: str) -> int:
    """Return the length of ``s`` with ANSI color escapes removed.

    Used to pad text that has already been colored.
    """
    stripped = _ANSI_RE.sub("", s)
    return len(stripped)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def term_width(default: int = 80) -> int:
    """Return the terminal width in columns.

    Args:
        default: Column count handed to ``shutil.get_terminal_size`` as the
            fallback (paired with 24 rows) when the size cannot be found.
    """
    fallback = (default, 24)
    size = shutil.get_terminal_size(fallback)
    return size.columns
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def rule_label(text: str, color: tuple[int, int, int], width: int) -> str:
    """Build the standard section divider: a title centered between rules.

    The result looks like ``──────── title ────────`` and spans ``width``
    columns. When ``text`` is too long for ``width`` the rules are dropped
    and only the space-padded title remains. The whole line is painted bold
    in ``color``.
    """
    budget = max(width - len(text) - 2, 0)  # columns left over for the rules
    lead = budget // 2
    trail = budget - lead  # any odd column goes to the right-hand rule
    divider = "".join(("─" * lead, f" {text} ", "─" * trail))
    return paint(divider, color, bold=True)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def boxed(
    lines: list[str],
    *,
    color: tuple[int, int, int],
    width: int,
    title: str = "",
) -> list[str]:
    """Draw a rounded box around ``lines`` and return it row by row.

    Args:
        lines: Content rows. They may already contain ANSI codes or be nested
            boxes; widths are measured ignoring escapes so padding lines up.
        color: Border color.
        width: Width of the content area, excluding the one-space padding
            and the border on each side.
        title: Optional label centered in the top border.

    Returns:
        The rendered rows: top border, one row per input line, bottom border.
    """
    inner = width + 2  # one space of padding on each side
    label = f" {title} " if title else ""
    dashes = max(inner - _visible_len(label), 0)
    lead = dashes // 2
    top_row = "".join(("╭", "─" * lead, label, "─" * (dashes - lead), "╮"))
    bottom_row = "╰" + "─" * inner + "╯"

    edge = paint("│", color)
    rendered = [paint(top_row, color, bold=True)]
    for content in lines:
        # Lines wider than the content area are not truncated, just unpadded.
        padding = " " * max(width - _visible_len(content), 0)
        rendered.append(f"{edge} {content}{padding} {edge}")
    rendered.append(paint(bottom_row, color))
    return rendered
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def print_json(obj: Any) -> None:
    """Print ``obj`` as JSON indented by two spaces.

    Values the encoder cannot serialize are converted with ``str`` instead of
    raising.
    """
    rendered = _json.dumps(obj, indent=2, default=str)
    print(rendered)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def render_table(headers: list[str], rows: list[list[Any]]) -> None:
    """Print ``headers`` and ``rows`` as a left-aligned, fixed-width table.

    Each column is as wide as its longest cell or header. No styling is
    applied, so the output is pipe-friendly.
    """
    col_widths = [len(str(heading)) for heading in headers]
    for record in rows:
        for idx, cell in enumerate(record):
            col_widths[idx] = max(col_widths[idx], len(str(cell)))

    template = " ".join("{:<%d}" % w for w in col_widths)
    print(template.format(*headers))
    for record in rows:
        print(template.format(*map(str, record)))
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def fmt_value(value: Any) -> str:
    """Format a scalar for display.

    Floats are rendered with four significant digits (``.4g``); every other
    value goes through ``str`` unchanged.
    """
    return format(value, ".4g") if isinstance(value, float) else str(value)
|
benchmax/cli/_project.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""Load a benchmax project (env class + datasets) from a directory.
|
|
2
|
+
|
|
3
|
+
Convention mirrors the web-app scaffold (``buildAgentContextBody``): ``run.py``
|
|
4
|
+
defines a single :class:`BaseEnv` subclass; ``train_dataset.jsonl`` /
|
|
5
|
+
``eval_dataset.jsonl`` hold one JSON object per line. ``validate`` and ``launch``
|
|
6
|
+
share this loader. An importable module path (``--module``) is an alternative to
|
|
7
|
+
``run.py`` for shipped envs / fixtures.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import importlib
|
|
13
|
+
import importlib.util
|
|
14
|
+
import inspect
|
|
15
|
+
import json
|
|
16
|
+
import sys
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from types import ModuleType
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ProjectError(Exception):
    """Raised when a project cannot be loaded.

    Covers a missing or unimportable ``run.py``/module, a missing, empty or
    malformed dataset, and a missing or ambiguous env class.
    """
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class LoadedProject:
    """The env class, datasets and module resolved by ``load_project``."""

    # The env class picked by ``discover_env_class``.
    env_class: type
    # One dict per JSONL row of the training dataset.
    train_dataset: list[dict[str, Any]]
    # Rows of the eval dataset; empty when the eval file does not exist.
    eval_dataset: list[dict[str, Any]]
    # The module object the env class was discovered in.
    module: ModuleType
    # Loaded from a run.py path (pickle env by value) vs an importable module.
    from_file: bool
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _load_module_from_file(path: Path) -> ModuleType:
    """Import the Python file at ``path`` as a module named after its stem.

    The module is registered in ``sys.modules`` under ``path.stem`` before it
    is executed, so dataclass/pickle name resolution works. If execution
    fails, that registration is rolled back, so no half-initialised module is
    left behind and a previously imported module of the same name is restored.

    Args:
        path: Location of the source file (e.g. ``run.py``).

    Returns:
        The executed module.

    Raises:
        ProjectError: If no import spec/loader can be built for ``path``, or
            the file raises while being executed.
    """
    name = path.stem
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ProjectError(f"Could not load a module from {path}")
    module = importlib.util.module_from_spec(spec)

    missing = object()  # sentinel: distinguishes "absent" from any real entry
    previous = sys.modules.get(name, missing)
    sys.modules[name] = module  # so dataclass/pickle name resolution works
    try:
        spec.loader.exec_module(module)
    except Exception as exc:  # surface the user's import/syntax error cleanly
        # Undo the registration so a broken module is never importable by
        # name and an unrelated earlier module is not left clobbered.
        if previous is missing:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise ProjectError(f"Failed to import {path.name}: {exc}") from exc
    return module
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def discover_env_class(module: ModuleType, explicit: str | None = None) -> type:
    """Locate the env class inside ``module``.

    Args:
        module: The module whose namespace is searched.
        explicit: Optional class name. When given, the first ``BaseEnv``
            subclass in the namespace with that ``__name__`` is returned,
            whether it is defined in the module or imported into it.

    Returns:
        The matching class. Without ``explicit``, exactly one ``BaseEnv``
        subclass must be *defined in* the module (imported ones are ignored).

    Raises:
        ProjectError: If no class matches, or several are defined in the
            module and no ``explicit`` name was given.
    """
    from benchmax.envs.base_env import BaseEnv

    def _env_classes():
        # Lazily yield every strict BaseEnv subclass bound in the namespace.
        for obj in vars(module).values():
            if inspect.isclass(obj) and issubclass(obj, BaseEnv) and obj is not BaseEnv:
                yield obj

    if explicit:
        named = next((c for c in _env_classes() if c.__name__ == explicit), None)
        if named is None:
            raise ProjectError(f"No BaseEnv subclass named {explicit!r} in the module.")
        return named

    local = [c for c in _env_classes() if c.__module__ == module.__name__]
    if len(local) == 1:
        return local[0]
    if not local:
        raise ProjectError("No BaseEnv subclass defined in the module.")
    names = sorted(c.__name__ for c in local)
    raise ProjectError(f"Multiple env classes {names}; pass --env-class to pick one.")
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Read a JSON Lines dataset: one JSON value per non-blank line.

    The file is iterated in text mode, so only real newlines separate
    records. ``str.splitlines()`` is avoided on purpose: it also splits on
    U+2028, U+2029, U+0085 and similar characters, which may legally appear
    unescaped inside a JSON string and would cut a valid row in half.

    Args:
        path: Location of the ``.jsonl`` file (read as UTF-8).

    Returns:
        The parsed rows, in file order. Blank lines are skipped.

    Raises:
        ProjectError: If the file does not exist, a line is not valid JSON
            (the message carries ``path:line``), or no rows were found.
    """
    if not path.exists():
        raise ProjectError(f"Dataset not found: {path}")
    rows: list[dict[str, Any]] = []
    with path.open(encoding="utf-8") as handle:
        for n, raw in enumerate(handle, 1):
            line = raw.strip()
            if not line:
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as exc:
                raise ProjectError(f"{path}:{n}: invalid JSON ({exc})") from exc
    if not rows:
        raise ProjectError(f"Dataset is empty: {path}")
    return rows
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def load_project(
    *,
    directory: str = ".",
    run_file: str = "run.py",
    module_path: str | None = None,
    env_class_name: str | None = None,
    train_file: str = "train_dataset.jsonl",
    eval_file: str = "eval_dataset.jsonl",
    require_eval: bool = False,
) -> LoadedProject:
    """Load the env class and datasets for a project dir or importable module.

    Args:
        directory: Project directory. The datasets are always read from
            here, and so is ``run_file`` when no ``module_path`` is given.
        run_file: File inside ``directory`` that defines the env class.
        module_path: Dotted module to import instead of ``run_file``.
        env_class_name: Explicit env class name, passed on to
            ``discover_env_class``.
        train_file: Training dataset file name inside ``directory``.
        eval_file: Eval dataset file name inside ``directory``.
        require_eval: Fail when the eval dataset file is missing.

    Returns:
        A ``LoadedProject``; ``eval_dataset`` is empty when the eval file
        does not exist.

    Raises:
        ProjectError: If the module or ``run_file`` cannot be imported or
            found, the env class cannot be resolved, a dataset is missing,
            empty or malformed, or ``require_eval`` is set without eval data.
    """
    # Mirror the branch below exactly: an empty ``module_path`` falls through
    # to run.py, so it must be reported as loaded from a file as well.
    from_file = not module_path
    if module_path:
        try:
            module = importlib.import_module(module_path)
        except Exception as exc:  # missing dep, bad path, import-time error
            raise ProjectError(
                f"Could not import module {module_path!r}: {exc}"
            ) from exc
    else:
        path = Path(directory) / run_file
        if not path.exists():
            raise ProjectError(
                f"{run_file} not found in {directory!r} — run inside a project dir, "
                "or pass --module for an importable env."
            )
        module = _load_module_from_file(path)

    env_class = discover_env_class(module, env_class_name)
    base = Path(directory)
    train_dataset = _load_jsonl(base / train_file)
    eval_path = base / eval_file
    # A missing eval file is fine unless the caller requires it; an existing
    # but empty file is rejected by ``_load_jsonl`` itself.
    eval_dataset = _load_jsonl(eval_path) if eval_path.exists() else []
    if require_eval and not eval_dataset:
        raise ProjectError(f"Eval dataset required but not found: {eval_path}")
    return LoadedProject(
        env_class=env_class,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        module=module,
        from_file=from_file,
    )
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Single source of truth for the RAG corpus providers and their sandbox deps.
|
|
2
|
+
|
|
3
|
+
``PROVIDER_PIP`` mirrors the per-provider extras in ``pyproject.toml`` — a unit test
|
|
4
|
+
asserts byte-equality against ``[project.optional-dependencies]``, so the two can't
|
|
5
|
+
silently drift (chroma must carry ``snowballstemmer``, which chromadb's BM25 needs but
|
|
6
|
+
doesn't declare). ``data``/``validate``/``launch`` all read their ``--provider``
|
|
7
|
+
choices and install hints from here instead of repeating the SDK names procedurally.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
# Maps a provider key → the pip requirements its search SDK needs in the rollout
|
|
13
|
+
# sandbox. Mirrors pyproject.toml's [project.optional-dependencies] provider extras.
|
|
14
|
+
PROVIDER_PIP: dict[str, list[str]] = {
|
|
15
|
+
"turbopuffer": ["turbopuffer>=1.16.2"],
|
|
16
|
+
"pinecone": ["pinecone>=5.0.0"],
|
|
17
|
+
# chromadb's BM25 embedding function needs snowballstemmer but doesn't declare it.
|
|
18
|
+
"chroma": ["chromadb>=1.0.0", "snowballstemmer>=2.2.0"],
|
|
19
|
+
}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def provider_choices() -> list[str]:
    """Return the ``PROVIDER_PIP`` keys, for argparse ``choices=`` on ``--provider``."""
    return [key for key in PROVIDER_PIP]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def install_hint(provider: str) -> str:
    """Return the user-facing install line for ``provider``'s extra.

    The result has the form ``Install with: pip install castform[<provider>]``.
    """
    return "Install with: pip install castform[{}]".format(provider)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def resolve_pip_dependencies(
    explicit: list[str] | None,
    env_class: type | None = None,
    provider: str | None = None,
) -> list[str] | None:
    """Compose the rollout-sandbox pip requirements for ``validate``/``launch``.

    Sources are merged in this order, keeping the first occurrence of each
    requirement:

    1. ``explicit`` — the ``--pip`` values.
    2. The env's ``PIP_DEPENDENCIES`` class attribute, when it is a list or
       tuple (non-string entries are dropped).
    3. ``PROVIDER_PIP[provider]`` when ``--provider`` was passed.

    This is resolved CLI-side, not sandbox-side: ``dump_bundle`` pickles the
    env class by value but the sandbox installs from
    ``BundleMetadata.pip_dependencies``, so the env's slot must be read here
    and fed into the existing ``pip_dependencies=`` channel.

    Returns:
        The ordered, de-duplicated requirements, or ``None`` when nothing
        resolves — the old ``args.pip or None`` behaviour.
    """
    merged: list[str] = [*(explicit or [])]

    declared = getattr(env_class, "PIP_DEPENDENCIES", None)
    if isinstance(declared, (list, tuple)):  # guard: a list-of-str slot only
        merged += [dep for dep in declared if isinstance(dep, str)]

    if provider:
        merged += PROVIDER_PIP.get(provider, [])

    # dict keys keep insertion order, so this de-dupes on first-seen order.
    unique = list(dict.fromkeys(merged))
    return unique if unique else None
|
benchmax/cli/control.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""castform run-control verbs (slice 1.3).
|
|
2
|
+
|
|
3
|
+
Top-level ``castform stop <id>`` → ``POST /v1/train/runs/{id}/cancel``. Owner-only
|
|
4
|
+
(403 otherwise). A launched run emits ``training.cancelling`` + cancels its
|
|
5
|
+
launcher job; a run with no job is marked complete directly (no cancelling event).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import argparse
|
|
11
|
+
|
|
12
|
+
from benchmax.cli._client import handle_errors, trainer_client
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@handle_errors
def _cmd_stop(args: argparse.Namespace) -> int:
    """Ask the platform to cancel ``args.run_id`` and print a confirmation.

    The confirmation is the response's ``message``, or a generic
    "Cancellation requested" when that is missing or empty. Returns 0;
    failures are handled by the ``handle_errors`` decorator.
    """
    with trainer_client() as client:
        response = client.cancel_run(args.run_id)
    note = response.get("message") or "Cancellation requested"
    print(f"✓ {note}")
    return 0
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def register(sub: argparse._SubParsersAction) -> None:
    """Add the top-level ``stop`` verb to ``sub``.

    ``stop`` takes one positional ``run_id`` and dispatches to ``_cmd_stop``.
    """
    stop_parser = sub.add_parser("stop", help="Stop (cancel) a run you own")
    stop_parser.add_argument("run_id")
    stop_parser.set_defaults(func=_cmd_stop)
|