benchmax 0.1.2.dev14__tar.gz → 0.1.2.dev16__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/PKG-INFO +2 -1
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/pyproject.toml +2 -1
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/base_env.py +35 -2
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/excel/workdir/reward_fn.py +6 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/parallel_mcp_env.py +4 -1
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/provisioners/utils.py +11 -2
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/proxy_server.py +28 -6
- benchmax-0.1.2.dev16/src/benchmax/envs/tracking.py +134 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/wikipedia/wiki_env.py +7 -3
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax.egg-info/PKG-INFO +2 -1
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax.egg-info/SOURCES.txt +1 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax.egg-info/requires.txt +1 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/LICENSE +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/README.md +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/setup.cfg +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/adapters/__init__.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/adapters/benchmax_wrapper.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/adapters/skyrl/benchmax_data_process.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/adapters/skyrl/skyrl_adapter.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/bundle/__init__.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/bundle/bundler.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/bundle/errors.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/bundle/loader.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/bundle/payload.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/bundle/validator.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/__init__.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/crm/crm_env.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/crm/workdir/reward_fn.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/crm/workdir/salesforce_mcp.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/excel/data_utils.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/excel/excel_env.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/excel/workdir/__init__.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/excel/workdir/excel_code_runner_mcp.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/excel/workdir/excel_utils.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/math/math_env.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/math/workdir/reward_fn.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/__init__.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/example_workdir/demo_mcp_server.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/example_workdir/reward_fn.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/provisioners/__init__.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/provisioners/base_provisioner.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/provisioners/local_provisioner.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/provisioners/manual_provisioner.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/provisioners/skypilot_provisioner.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/server_pool.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/utils.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/types.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/wikipedia/utils.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/prompts/__init__.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/prompts/tools.py +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax.egg-info/dependency_links.txt +0 -0
- {benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: benchmax
|
|
3
|
-
Version: 0.1.2.
|
|
3
|
+
Version: 0.1.2.dev16
|
|
4
4
|
Summary: Framework-Agnostic RL Environments for LLM Fine-Tuning
|
|
5
5
|
Author: cgft.io
|
|
6
6
|
Classifier: Programming Language :: Python :: 3
|
|
@@ -12,6 +12,7 @@ Requires-Dist: aiohttp>=3.13.1
|
|
|
12
12
|
Requires-Dist: asyncio>=4.0.0
|
|
13
13
|
Requires-Dist: cloudpickle>=3.0.0
|
|
14
14
|
Requires-Dist: datasets>=4.0.0
|
|
15
|
+
Requires-Dist: expt-logger>=0.1.0.dev20
|
|
15
16
|
Provides-Extra: mcp
|
|
16
17
|
Requires-Dist: fastmcp~=2.12.0; extra == "mcp"
|
|
17
18
|
Requires-Dist: pyjwt>=2.10.1; extra == "mcp"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "benchmax"
|
|
3
|
-
version = "0.1.2.
|
|
3
|
+
version = "0.1.2.dev16"
|
|
4
4
|
description = "Framework-Agnostic RL Environments for LLM Fine-Tuning"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
authors = [{ name = "cgft.io" }]
|
|
@@ -10,6 +10,7 @@ dependencies = [
|
|
|
10
10
|
"asyncio>=4.0.0",
|
|
11
11
|
"cloudpickle>=3.0.0",
|
|
12
12
|
"datasets>=4.0.0",
|
|
13
|
+
"expt-logger>=0.1.0.dev20",
|
|
13
14
|
]
|
|
14
15
|
classifiers = [
|
|
15
16
|
"Programming Language :: Python :: 3",
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
-
from typing import TYPE_CHECKING, Dict, List, Any, Optional, Tuple
|
|
3
2
|
from pathlib import Path
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
|
4
4
|
|
|
5
|
-
from benchmax.envs.
|
|
5
|
+
from benchmax.envs.tracking import TrackingConfig, log_env, with_tracking
|
|
6
|
+
from benchmax.envs.types import StandardizedExample, ToolDefinition
|
|
6
7
|
from benchmax.prompts.tools import render_tools_prompt
|
|
7
8
|
|
|
8
9
|
if TYPE_CHECKING:
|
|
@@ -13,6 +14,38 @@ class BaseEnv(ABC):
|
|
|
13
14
|
"""Base benchmax environment for tool execution and reward computation"""
|
|
14
15
|
|
|
15
16
|
system_prompt: str = ""
|
|
17
|
+
_tracking_config: TrackingConfig | None = None
|
|
18
|
+
|
|
19
|
+
def __init_subclass__(cls, **kwargs):
    """Auto-wrap a subclass's own ``compute_reward`` with env tracking.

    Runs once per subclass definition. When the subclass body defines
    ``compute_reward`` directly as a plain function (sync or ``async def``),
    it is replaced by a ``with_tracking`` wrapper that activates the
    instance's tracking config (``self.get_tracking_config()``) for the
    duration of the call, so ``log_env`` calls made during reward
    computation see an active tracker.

    Only plain functions are wrapped. ``staticmethod``/``classmethod``
    objects and other callables stored on the class are left untouched:
    replacing them with a plain function would change how they bind (an
    extra ``self`` argument would be passed), and the config resolver
    relies on receiving ``self`` first.
    """
    import inspect

    super().__init_subclass__(**kwargs)

    compute_reward = cls.__dict__.get("compute_reward")
    if not inspect.isfunction(compute_reward):
        # Not overridden in this class body, or not a plain function whose
        # binding semantics we can preserve by wrapping.
        return
    if getattr(compute_reward, "__benchmax_tracking_wrapped__", False):
        # Already wrapped (e.g. the same function object reassigned).
        return

    wrapped = with_tracking(lambda self, *a, **kw: self.get_tracking_config())(
        compute_reward
    )
    setattr(wrapped, "__benchmax_tracking_wrapped__", True)
    setattr(cls, "compute_reward", wrapped)
|
|
33
|
+
|
|
34
|
+
def __init__(
    self,
    experiment_id: Optional[str] = None,
    api_key: Optional[str] = None,
    **kwargs,
):
    """Record the experiment-tracking settings for this environment.

    Args:
        experiment_id: expt_logger experiment id. When omitted, the
            ``EXPT_LOGGER_EXPERIMENT_ID`` environment variable is consulted
            later by ``TrackingConfig.resolved_experiment_id``.
        api_key: Optional expt_logger API key.
        **kwargs: Accepted so subclasses can forward their kwargs; ignored.
    """
    config = TrackingConfig(experiment_id=experiment_id, api_key=api_key)
    self._tracking_config = config
|
|
43
|
+
|
|
44
|
+
def get_tracking_config(self) -> TrackingConfig | None:
    """Return this environment's tracking config.

    This is ``None`` only when ``BaseEnv.__init__`` was never run, because
    the class-level default is ``None``.
    """
    config = self._tracking_config
    return config
|
|
46
|
+
|
|
47
|
+
def log_env(self, rollout_id: str, message: str) -> None:
    """Forward *message* for *rollout_id* to the shared tracking ``log_env``.

    This is a no-op unless a tracking context is active, for example during a
    tracking-wrapped ``compute_reward`` call.
    """
    log_env(rollout_id=rollout_id, message=message)
|
|
16
49
|
|
|
17
50
|
# Override this method if your example does not match the default structure
|
|
18
51
|
@classmethod
|
|
@@ -34,17 +34,23 @@ def spreadsheet_comparison_reward(
|
|
|
34
34
|
|
|
35
35
|
output_path = workspace / output_filename
|
|
36
36
|
ground_truth_path = workspace / ground_truth_filename
|
|
37
|
+
rollout_id = kwargs.get("rollout_id", "unknown_rollout")
|
|
38
|
+
log_env(
|
|
39
|
+
rollout_id, f"excel_reward:compare_files={ground_truth_filename}:{output_filename}:{answer_position}"
|
|
40
|
+
)
|
|
37
41
|
|
|
38
42
|
# Return 1.0 score if the output completely matches the ground truth
|
|
39
43
|
try:
|
|
40
44
|
match, _ = compare_excel_cells(
|
|
41
45
|
str(ground_truth_path), str(output_path), answer_position
|
|
42
46
|
)
|
|
47
|
+
log_env(rollout_id, f"excel_reward:spreadsheet_match={float(match)}")
|
|
43
48
|
return 1.0 if match else 0.0
|
|
44
49
|
except Exception as e:
|
|
45
50
|
print(
|
|
46
51
|
f"Error comparing spreadsheets {ground_truth_path} and {output_path}: {e}"
|
|
47
52
|
)
|
|
53
|
+
log_env(rollout_id, f"excel_reward:error={str(e)}")
|
|
48
54
|
return 0.0
|
|
49
55
|
|
|
50
56
|
|
|
@@ -19,6 +19,7 @@ except ModuleNotFoundError as e:
|
|
|
19
19
|
) from e
|
|
20
20
|
|
|
21
21
|
from benchmax.envs.base_env import BaseEnv
|
|
22
|
+
from benchmax.envs.tracking import to_tracking_payload
|
|
22
23
|
from benchmax.envs.types import ToolDefinition
|
|
23
24
|
from .server_pool import ServerPool
|
|
24
25
|
from .provisioners.base_provisioner import BaseProvisioner
|
|
@@ -96,7 +97,7 @@ class ParallelMcpEnv(BaseEnv):
|
|
|
96
97
|
provision_at_init: Whether to launch a server at the point of initialization
|
|
97
98
|
**kwargs: Additional keyword arguments (currently unused).
|
|
98
99
|
"""
|
|
99
|
-
super().__init__()
|
|
100
|
+
super().__init__(**kwargs)
|
|
100
101
|
|
|
101
102
|
self._workdir_path = Path(workdir_path).absolute()
|
|
102
103
|
self._provisioner = provisioner
|
|
@@ -373,6 +374,8 @@ class ParallelMcpEnv(BaseEnv):
|
|
|
373
374
|
payload = {
|
|
374
375
|
"completion": completion or "",
|
|
375
376
|
"ground_truth": ground_truth or "",
|
|
377
|
+
**to_tracking_payload(self.get_tracking_config()),
|
|
378
|
+
"rollout_id": rollout_id,
|
|
376
379
|
**kwargs,
|
|
377
380
|
}
|
|
378
381
|
|
|
@@ -15,7 +15,8 @@ def setup_sync_dir(workdir_path: Path) -> Path:
|
|
|
15
15
|
|
|
16
16
|
This creates a temp directory and copies:
|
|
17
17
|
1. proxy_server.py from the mcp/ directory
|
|
18
|
-
2.
|
|
18
|
+
2. env tracking helper for reward logging
|
|
19
|
+
3. All contents of the provided workdir_path
|
|
19
20
|
|
|
20
21
|
Args:
|
|
21
22
|
workdir_path: Path to workdir containing mcp_config.yaml, setup.sh, etc.
|
|
@@ -38,6 +39,14 @@ def setup_sync_dir(workdir_path: Path) -> Path:
|
|
|
38
39
|
)
|
|
39
40
|
shutil.copy(src_server_path, sync_dir / "proxy_server.py")
|
|
40
41
|
|
|
42
|
+
# Copy shared env tracking helper for reward_fn logging.
|
|
43
|
+
src_tracking_path = Path(__file__).parents[2] / "tracking.py"
|
|
44
|
+
if not src_tracking_path.exists():
|
|
45
|
+
raise FileNotFoundError(
|
|
46
|
+
f"Expected tracking helper at {src_tracking_path}, but not found."
|
|
47
|
+
)
|
|
48
|
+
shutil.copy(src_tracking_path, sync_dir / "env_tracking.py")
|
|
49
|
+
|
|
41
50
|
# Validate workdir exists and is a directory
|
|
42
51
|
if not workdir_path.exists():
|
|
43
52
|
raise FileNotFoundError(
|
|
@@ -91,7 +100,7 @@ def get_setup_command() -> str:
|
|
|
91
100
|
# Install uv
|
|
92
101
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
|
93
102
|
UV_VENV_CLEAR=1 uv venv ~/venv && source ~/venv/bin/activate
|
|
94
|
-
uv pip install fastmcp~=2.12.0 pyyaml psutil
|
|
103
|
+
uv pip install fastmcp~=2.12.0 pyyaml psutil expt-logger
|
|
95
104
|
bash setup.sh
|
|
96
105
|
"""
|
|
97
106
|
|
|
@@ -24,12 +24,30 @@ from starlette.requests import Request
|
|
|
24
24
|
from starlette.responses import PlainTextResponse, FileResponse, JSONResponse, Response
|
|
25
25
|
from starlette.datastructures import UploadFile
|
|
26
26
|
|
|
27
|
-
|
|
27
|
+
try:
|
|
28
|
+
from benchmax.envs.tracking import log_env, pop_tracking_config, tracking_context
|
|
29
|
+
except Exception:
|
|
30
|
+
# In provisioned MCP servers, this helper is copied as env_tracking.py.
|
|
31
|
+
from env_tracking import log_env, pop_tracking_config, tracking_context # type: ignore
|
|
32
|
+
|
|
33
|
+
from reward_fn import reward_functions as imported_reward_functions # type: ignore
|
|
28
34
|
|
|
29
35
|
RewardFunction = Callable[..., Union[float, Awaitable[float]]]
|
|
30
36
|
DEFAULT_API_SECRET = "dev_default_api_secret_please_change_me_32chars!"
|
|
31
37
|
|
|
32
38
|
|
|
39
|
+
def _with_log_env(func: RewardFunction) -> RewardFunction:
    """Bind the shared ``log_env`` callable into a reward function's globals.

    Reward functions in ``reward_fn`` may call ``log_env`` as a free name.
    This injects the proxy's ``log_env``, whose tracker is activated by
    ``tracking_context`` around reward computation, into the module namespace
    the function body resolves globals from.

    ``functools.partial`` objects and ``functools.wraps``-style decorated
    functions are unwrapped. The namespaces of both the outer callable and
    the innermost function are patched. Callables exposing no ``__globals__``
    (callable instances, builtins) are returned unchanged. Previously they
    raised ``AttributeError`` while this module was being imported.

    Returns:
        ``func`` itself (not a copy).
    """
    import functools
    import inspect

    inner = func
    while isinstance(inner, functools.partial):
        inner = inner.func
    try:
        inner = inspect.unwrap(inner)
    except ValueError:  # cyclic __wrapped__ chain; keep what we have
        pass

    for candidate in (func, inner):
        namespace = getattr(candidate, "__globals__", None)
        if isinstance(namespace, dict):
            namespace["log_env"] = log_env
    return func


# Reward functions exported by reward_fn (a missing/None export is treated as
# empty), each patched so a bare ``log_env`` call resolves to the proxy helper.
reward_functions: Dict[str, RewardFunction] = {
    name: _with_log_env(func)
    for name, func in (imported_reward_functions or {}).items()
}
|
|
49
|
+
|
|
50
|
+
|
|
33
51
|
# ---------------- Utility Functions ---------------- #
|
|
34
52
|
def setup_workspace(base_dir: Path) -> Path:
|
|
35
53
|
"""Create a unique workspace directory."""
|
|
@@ -329,14 +347,17 @@ class ProxyServer:
|
|
|
329
347
|
status_code=400,
|
|
330
348
|
)
|
|
331
349
|
|
|
350
|
+
payload_kwargs: Dict[str, Any] = {
|
|
351
|
+
k: v for k, v in data.items() if k not in ("completion", "ground_truth")
|
|
352
|
+
}
|
|
353
|
+
tracking_config = pop_tracking_config(payload_kwargs)
|
|
354
|
+
|
|
332
355
|
kwargs: Dict[str, Any] = {
|
|
333
356
|
"completion": completion,
|
|
334
357
|
"ground_truth": ground_truth,
|
|
335
358
|
"workspace": self.workspace,
|
|
336
359
|
"mcp_client": self.client,
|
|
337
|
-
**
|
|
338
|
-
k: v for k, v in data.items() if k not in ("completion", "ground_truth")
|
|
339
|
-
},
|
|
360
|
+
**payload_kwargs,
|
|
340
361
|
}
|
|
341
362
|
|
|
342
363
|
async def _call_reward(name: str, func: RewardFunction) -> Tuple[str, float]:
|
|
@@ -357,8 +378,9 @@ class ProxyServer:
|
|
|
357
378
|
rf: Dict[str, RewardFunction] = reward_functions or {}
|
|
358
379
|
|
|
359
380
|
try:
|
|
360
|
-
|
|
361
|
-
|
|
381
|
+
with tracking_context(tracking_config):
|
|
382
|
+
tasks = [_call_reward(name, func) for name, func in rf.items()]
|
|
383
|
+
results_list: List[Tuple[str, float]] = await asyncio.gather(*tasks)
|
|
362
384
|
results: Dict[str, float] = dict(results_list)
|
|
363
385
|
return JSONResponse(results)
|
|
364
386
|
except Exception as e:
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
from contextlib import contextmanager
|
|
7
|
+
from contextvars import ContextVar
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from functools import wraps
|
|
10
|
+
from typing import Any, Callable, Dict, Iterator, Optional
|
|
11
|
+
|
|
12
|
+
LOGGER = logging.getLogger(__name__)

# Reserved payload keys used to carry tracking settings alongside reward
# kwargs (written by to_tracking_payload, removed by pop_tracking_config).
TRACKING_EXPERIMENT_ID_KEY = "__benchmax_expt_logger_experiment_id"
TRACKING_API_KEY_KEY = "__benchmax_expt_logger_api_key"

# Tracker that log_env() uses in the current context; set and reset by
# tracking_context(). None means "no tracking active".
_ACTIVE_TRACKER: ContextVar[Any | None] = ContextVar(
    "benchmax_active_expt_logger_tracker", default=None
)
# Memo of trackers keyed by (resolved experiment id, api key). None results
# (tracking disabled or initialisation failed) are cached as well.
_TRACKER_CACHE: Dict[tuple[Optional[str], Optional[str]], Any | None] = {}
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass(frozen=True)
class TrackingConfig:
    """Settings that route env logs to an expt_logger experiment.

    Attributes:
        experiment_id: Explicit experiment id. When falsy, the
            ``EXPT_LOGGER_EXPERIMENT_ID`` environment variable is used.
        api_key: Optional API key passed to ``expt_logger.init``. It is
            masked in ``repr`` so the secret never ends up in logs or
            tracebacks.
    """

    experiment_id: Optional[str] = None
    api_key: Optional[str] = None

    def __repr__(self) -> str:
        # A user-defined __repr__ takes precedence over the dataclass one;
        # show only whether a key is present, never the key itself.
        masked_key = "'***'" if self.api_key is not None else "None"
        return (
            f"{type(self).__name__}(experiment_id={self.experiment_id!r}, "
            f"api_key={masked_key})"
        )

    def resolved_experiment_id(self) -> Optional[str]:
        """Return the explicit experiment id, else the env-var fallback."""
        return self.experiment_id or os.getenv("EXPT_LOGGER_EXPERIMENT_ID")

    def is_enabled(self) -> bool:
        """Return True when an experiment id is available from either source."""
        return bool(self.resolved_experiment_id())
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _build_tracker(config: TrackingConfig) -> Any | None:
|
|
36
|
+
if not config.is_enabled():
|
|
37
|
+
return None
|
|
38
|
+
|
|
39
|
+
try:
|
|
40
|
+
import expt_logger
|
|
41
|
+
except Exception as e:
|
|
42
|
+
LOGGER.debug("expt_logger import failed; env tracking disabled: %s", e)
|
|
43
|
+
return None
|
|
44
|
+
|
|
45
|
+
try:
|
|
46
|
+
run = expt_logger.init(
|
|
47
|
+
experiment_id=config.resolved_experiment_id(),
|
|
48
|
+
api_key=config.api_key,
|
|
49
|
+
)
|
|
50
|
+
except Exception as e:
|
|
51
|
+
LOGGER.debug("expt_logger init failed; env tracking disabled: %s", e)
|
|
52
|
+
return None
|
|
53
|
+
|
|
54
|
+
if hasattr(expt_logger, "log_environment"):
|
|
55
|
+
return expt_logger
|
|
56
|
+
if hasattr(run, "log_environment"):
|
|
57
|
+
return run
|
|
58
|
+
|
|
59
|
+
LOGGER.debug("expt_logger has no log_environment; env tracking disabled")
|
|
60
|
+
return None
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def get_tracker(config: TrackingConfig | None) -> Any | None:
    """Return the memoised tracker for *config*, or ``None``.

    Trackers are cached per ``(resolved experiment id, api key)`` pair, so a
    tracker is built once per distinct configuration. A ``None`` result
    (disabled, or failed initialisation) is cached too.
    """
    if config is None:
        return None

    cache_key = (config.resolved_experiment_id(), config.api_key)
    if cache_key in _TRACKER_CACHE:
        return _TRACKER_CACHE[cache_key]

    tracker = _build_tracker(config)
    _TRACKER_CACHE[cache_key] = tracker
    return tracker
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@contextmanager
def tracking_context(config: TrackingConfig | None) -> Iterator[None]:
    """Make the tracker for *config* the active one inside the ``with`` body.

    The previous value of the context variable is restored on exit, even if
    the body raises, so nested contexts unwind correctly.
    """
    tracker = get_tracker(config)
    reset_token = _ACTIVE_TRACKER.set(tracker)
    try:
        yield
    finally:
        _ACTIVE_TRACKER.reset(reset_token)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def log_env(rollout_id: str, message: str) -> None:
    """Send *message* for *rollout_id* to the active tracker, if any.

    This is a no-op outside a tracking context. Errors raised by the tracker
    are logged at DEBUG level and swallowed, so logging cannot fail the
    caller.
    """
    active = _ACTIVE_TRACKER.get()
    if active is not None:
        try:
            active.log_environment(rollout_id, str(message))
        except Exception as err:
            LOGGER.debug("log_environment failed: %s", err)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def with_tracking(
    config_resolver: Callable[..., TrackingConfig | None],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Wrap a function so calls run with an active env tracking context.

    ``config_resolver`` receives the same arguments as each wrapped call and
    returns the ``TrackingConfig`` to activate (or ``None``). Coroutine
    functions get an ``async`` wrapper so the context stays active across the
    whole ``await``. Anything else gets a plain wrapper.
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            with tracking_context(config_resolver(*args, **kwargs)):
                return func(*args, **kwargs)

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            with tracking_context(config_resolver(*args, **kwargs)):
                return await func(*args, **kwargs)

        return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper

    return decorator
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def to_tracking_payload(config: TrackingConfig | None) -> Dict[str, str]:
    """Serialise *config* into the reserved tracking payload keys.

    The experiment id is resolved here (explicit value or env var), so the
    receiving side gets a concrete value. Falsy values are omitted, and a
    ``None`` config yields an empty dict.
    """
    if config is None:
        return {}

    result: Dict[str, str] = {}
    experiment_id = config.resolved_experiment_id()
    if experiment_id:
        result[TRACKING_EXPERIMENT_ID_KEY] = experiment_id
    api_key = config.api_key
    if api_key:
        result[TRACKING_API_KEY_KEY] = api_key
    return result
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def pop_tracking_config(payload: Dict[str, Any]) -> TrackingConfig:
    """Remove the reserved tracking keys from *payload* in place.

    Returns the removed values as a ``TrackingConfig``; missing keys become
    ``None``. Because the keys are removed, a caller that forwards the
    remaining payload as kwargs does not pass the credentials along.
    """
    return TrackingConfig(
        experiment_id=payload.pop(TRACKING_EXPERIMENT_ID_KEY, None),
        api_key=payload.pop(TRACKING_API_KEY_KEY, None),
    )
|
|
@@ -4,6 +4,7 @@ import re
|
|
|
4
4
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
5
5
|
|
|
6
6
|
from benchmax.envs.base_env import BaseEnv
|
|
7
|
+
from benchmax.envs.tracking import log_env
|
|
7
8
|
from benchmax.envs.types import ToolDefinition, StandardizedExample
|
|
8
9
|
from benchmax.envs.wikipedia.utils import APIKeyRotator, clean_html, safe_request
|
|
9
10
|
|
|
@@ -13,7 +14,7 @@ Write your complete answer on the final line only as a concise entity, within th
|
|
|
13
14
|
"""
|
|
14
15
|
|
|
15
16
|
|
|
16
|
-
def text_match_reward_function(completion: str, ground_truth: str, **kwargs) -> float:
|
|
17
|
+
def text_match_reward_function(completion: str, ground_truth: str, rollout_id: str, **kwargs) -> float:
|
|
17
18
|
"""
|
|
18
19
|
Score 1.0 if ground truth appears in <answer> tags, else 0.0.
|
|
19
20
|
|
|
@@ -31,10 +32,13 @@ def text_match_reward_function(completion: str, ground_truth: str, **kwargs) ->
|
|
|
31
32
|
r"<answer>(.*?)</answer>", completion, flags=re.IGNORECASE | re.DOTALL
|
|
32
33
|
)
|
|
33
34
|
if not m:
|
|
35
|
+
log_env(rollout_id, "wikipedia_reward:no_answer_tag")
|
|
34
36
|
return 0.0
|
|
35
37
|
|
|
36
38
|
answer_text = unescape(m.group(1)).strip().lower()
|
|
37
|
-
|
|
39
|
+
score = float(ground_truth.lower() == answer_text)
|
|
40
|
+
log_env(rollout_id, f"wikipedia_reward:text_match={score}")
|
|
41
|
+
return score
|
|
38
42
|
|
|
39
43
|
|
|
40
44
|
def _make_wikipedia_tools(key_rotator: APIKeyRotator):
|
|
@@ -264,5 +268,5 @@ class WikipediaEnv(BaseEnv):
|
|
|
264
268
|
) -> Dict[str, float]:
|
|
265
269
|
"""Compute rewards using the text match reward function."""
|
|
266
270
|
return {
|
|
267
|
-
"text_match": text_match_reward_function(completion, ground_truth, **kwargs)
|
|
271
|
+
"text_match": text_match_reward_function(completion, ground_truth, rollout_id, **kwargs)
|
|
268
272
|
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: benchmax
|
|
3
|
-
Version: 0.1.2.
|
|
3
|
+
Version: 0.1.2.dev16
|
|
4
4
|
Summary: Framework-Agnostic RL Environments for LLM Fine-Tuning
|
|
5
5
|
Author: cgft.io
|
|
6
6
|
Classifier: Programming Language :: Python :: 3
|
|
@@ -12,6 +12,7 @@ Requires-Dist: aiohttp>=3.13.1
|
|
|
12
12
|
Requires-Dist: asyncio>=4.0.0
|
|
13
13
|
Requires-Dist: cloudpickle>=3.0.0
|
|
14
14
|
Requires-Dist: datasets>=4.0.0
|
|
15
|
+
Requires-Dist: expt-logger>=0.1.0.dev20
|
|
15
16
|
Provides-Extra: mcp
|
|
16
17
|
Requires-Dist: fastmcp~=2.12.0; extra == "mcp"
|
|
17
18
|
Requires-Dist: pyjwt>=2.10.1; extra == "mcp"
|
|
@@ -18,6 +18,7 @@ src/benchmax/bundle/payload.py
|
|
|
18
18
|
src/benchmax/bundle/validator.py
|
|
19
19
|
src/benchmax/envs/__init__.py
|
|
20
20
|
src/benchmax/envs/base_env.py
|
|
21
|
+
src/benchmax/envs/tracking.py
|
|
21
22
|
src/benchmax/envs/types.py
|
|
22
23
|
src/benchmax/envs/crm/crm_env.py
|
|
23
24
|
src/benchmax/envs/crm/workdir/reward_fn.py
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/adapters/skyrl/benchmax_data_process.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/crm/workdir/salesforce_mcp.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/excel/workdir/excel_utils.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/example_workdir/reward_fn.py
RENAMED
|
File without changes
|
{benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/provisioners/__init__.py
RENAMED
|
File without changes
|
{benchmax-0.1.2.dev14 → benchmax-0.1.2.dev16}/src/benchmax/envs/mcp/provisioners/base_provisioner.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|