benchmax 0.1.2.dev16__tar.gz → 0.1.2.dev18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/PKG-INFO +1 -1
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/pyproject.toml +1 -1
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/adapters/benchmax_wrapper.py +4 -3
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/base_env.py +46 -21
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/crm/workdir/reward_fn.py +3 -2
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/excel/excel_env.py +2 -2
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/excel/workdir/reward_fn.py +2 -8
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/math/workdir/reward_fn.py +5 -3
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/example_workdir/reward_fn.py +7 -4
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/parallel_mcp_env.py +2 -2
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/tracking.py +2 -2
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/types.py +4 -2
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/wikipedia/wiki_env.py +12 -7
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax.egg-info/PKG-INFO +1 -1
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/LICENSE +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/README.md +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/setup.cfg +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/adapters/__init__.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/adapters/skyrl/benchmax_data_process.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/adapters/skyrl/skyrl_adapter.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/bundle/__init__.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/bundle/bundler.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/bundle/errors.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/bundle/loader.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/bundle/payload.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/bundle/validator.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/__init__.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/crm/crm_env.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/crm/workdir/salesforce_mcp.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/excel/data_utils.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/excel/workdir/__init__.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/excel/workdir/excel_code_runner_mcp.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/excel/workdir/excel_utils.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/math/math_env.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/__init__.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/example_workdir/demo_mcp_server.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/provisioners/__init__.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/provisioners/base_provisioner.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/provisioners/local_provisioner.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/provisioners/manual_provisioner.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/provisioners/skypilot_provisioner.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/provisioners/utils.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/proxy_server.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/server_pool.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/utils.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/wikipedia/utils.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/prompts/__init__.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/prompts/tools.py +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax.egg-info/SOURCES.txt +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax.egg-info/dependency_links.txt +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax.egg-info/requires.txt +0 -0
- {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax.egg-info/top_level.txt +0 -0
|
@@ -5,6 +5,7 @@ from ray.actor import ActorClass, ActorProxy
|
|
|
5
5
|
from typing import Dict, List, Any, Optional, Type, Union
|
|
6
6
|
|
|
7
7
|
from benchmax.envs.base_env import BaseEnv
|
|
8
|
+
from benchmax.envs.types import Completion
|
|
8
9
|
|
|
9
10
|
# 5 minutes timeout in seconds
|
|
10
11
|
RAY_GET_TIMEOUT = 300
|
|
@@ -67,7 +68,7 @@ class BenchmaxEnv:
|
|
|
67
68
|
|
|
68
69
|
@ray.method
|
|
69
70
|
async def compute_reward(
|
|
70
|
-
self, rollout_id: str, completion:
|
|
71
|
+
self, rollout_id: str, completion: Completion, ground_truth: Any, **kwargs: Any
|
|
71
72
|
) -> Dict[str, float]:
|
|
72
73
|
return await self._env.compute_reward(
|
|
73
74
|
rollout_id=rollout_id,
|
|
@@ -258,7 +259,7 @@ class BenchmaxEnvWrapper:
|
|
|
258
259
|
async def compute_reward(
|
|
259
260
|
self,
|
|
260
261
|
rollout_id: str,
|
|
261
|
-
completion:
|
|
262
|
+
completion: Completion,
|
|
262
263
|
ground_truth: Any,
|
|
263
264
|
**kwargs: Any,
|
|
264
265
|
) -> Dict[str, float]:
|
|
@@ -271,7 +272,7 @@ class BenchmaxEnvWrapper:
|
|
|
271
272
|
def compute_reward_sync(
|
|
272
273
|
self,
|
|
273
274
|
rollout_id: str,
|
|
274
|
-
completion:
|
|
275
|
+
completion: Completion,
|
|
275
276
|
ground_truth: Any,
|
|
276
277
|
**kwargs: Any,
|
|
277
278
|
) -> Dict[str, float]:
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
+
from functools import wraps
|
|
2
3
|
from pathlib import Path
|
|
3
4
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
|
4
5
|
|
|
5
|
-
from benchmax.envs.tracking import TrackingConfig, log_env,
|
|
6
|
-
from benchmax.envs.types import StandardizedExample, ToolDefinition
|
|
6
|
+
from benchmax.envs.tracking import TrackingConfig, log_env, tracking_context
|
|
7
|
+
from benchmax.envs.types import Completion, StandardizedExample, ToolDefinition
|
|
7
8
|
from benchmax.prompts.tools import render_tools_prompt
|
|
8
9
|
|
|
9
10
|
if TYPE_CHECKING:
|
|
@@ -16,32 +17,28 @@ class BaseEnv(ABC):
|
|
|
16
17
|
system_prompt: str = ""
|
|
17
18
|
_tracking_config: TrackingConfig | None = None
|
|
18
19
|
|
|
19
|
-
def
|
|
20
|
-
|
|
20
|
+
def __init__(self, **kwargs):
|
|
21
|
+
self._tracking_config: Optional[TrackingConfig] = None
|
|
21
22
|
|
|
22
|
-
|
|
23
|
-
if compute_reward is None:
|
|
24
|
-
return
|
|
25
|
-
if getattr(compute_reward, "__benchmax_tracking_wrapped__", False):
|
|
26
|
-
return
|
|
27
|
-
|
|
28
|
-
wrapped = with_tracking(lambda self, *a, **kw: self.get_tracking_config())(
|
|
29
|
-
compute_reward
|
|
30
|
-
)
|
|
31
|
-
setattr(wrapped, "__benchmax_tracking_wrapped__", True)
|
|
32
|
-
setattr(cls, "compute_reward", wrapped)
|
|
33
|
-
|
|
34
|
-
def __init__(
|
|
23
|
+
def enable_tracking(
|
|
35
24
|
self,
|
|
36
25
|
experiment_id: Optional[str] = None,
|
|
37
26
|
api_key: Optional[str] = None,
|
|
38
|
-
|
|
39
|
-
|
|
27
|
+
) -> None:
|
|
28
|
+
"""Enable experiment tracking. Wraps compute_reward on this instance with a tracking context."""
|
|
40
29
|
self._tracking_config = TrackingConfig(
|
|
41
30
|
experiment_id=experiment_id, api_key=api_key
|
|
42
31
|
)
|
|
32
|
+
cls_compute_reward = type(self).compute_reward
|
|
43
33
|
|
|
44
|
-
|
|
34
|
+
@wraps(cls_compute_reward)
|
|
35
|
+
async def _tracked(*args, **kwargs):
|
|
36
|
+
with tracking_context(self._tracking_config):
|
|
37
|
+
return await cls_compute_reward(self, *args, **kwargs)
|
|
38
|
+
|
|
39
|
+
self.compute_reward = _tracked
|
|
40
|
+
|
|
41
|
+
def get_tracking_config(self) -> Optional[TrackingConfig]:
|
|
45
42
|
return self._tracking_config
|
|
46
43
|
|
|
47
44
|
def log_env(self, rollout_id: str, message: str) -> None:
|
|
@@ -106,7 +103,7 @@ class BaseEnv(ABC):
|
|
|
106
103
|
|
|
107
104
|
@abstractmethod
|
|
108
105
|
async def compute_reward(
|
|
109
|
-
self, rollout_id: str, completion:
|
|
106
|
+
self, rollout_id: str, completion: Completion, ground_truth: Any, **kwargs: Any
|
|
110
107
|
) -> Dict[str, float]:
|
|
111
108
|
"""Compute rewards using registered functions
|
|
112
109
|
|
|
@@ -114,6 +111,34 @@ class BaseEnv(ABC):
|
|
|
114
111
|
"""
|
|
115
112
|
pass
|
|
116
113
|
|
|
114
|
+
async def compute_group_reward(
|
|
115
|
+
self,
|
|
116
|
+
rollout_ids: List[str],
|
|
117
|
+
completions: List[str | List[Dict[str, str]]],
|
|
118
|
+
ground_truths: List[Any],
|
|
119
|
+
**kwargs: Any,
|
|
120
|
+
) -> List[Dict[str, float]]:
|
|
121
|
+
"""Compute rewards across a group of rollouts jointly.
|
|
122
|
+
|
|
123
|
+
Override this when reward computation requires cross-rollout context (e.g.,
|
|
124
|
+
relative scoring, group normalization, or deduplication). Can be used alongside
|
|
125
|
+
``compute_reward`` — the two are not mutually exclusive. The default implementation
|
|
126
|
+
returns empty reward dicts, deferring entirely to per-rollout ``compute_reward`` calls.
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
rollout_ids: Identifiers for each rollout in the group.
|
|
130
|
+
completions: Model outputs, one per rollout. Each entry is either a
|
|
131
|
+
plain string or a list of message dicts.
|
|
132
|
+
ground_truths: Reference answers, one per rollout.
|
|
133
|
+
**kwargs: Additional environment-specific arguments.
|
|
134
|
+
|
|
135
|
+
Returns:
|
|
136
|
+
A list of reward dicts (one per rollout), each mapping reward function
|
|
137
|
+
names to their computed scores. An empty dict signals that no group
|
|
138
|
+
reward was computed for that rollout.
|
|
139
|
+
"""
|
|
140
|
+
return [{} for _ in rollout_ids]
|
|
141
|
+
|
|
117
142
|
async def get_system_prompt(self, add_tool_defs: bool = False) -> str:
|
|
118
143
|
"""Get system prompt. To add tool definitions, set add_tool_defs to True."""
|
|
119
144
|
if add_tool_defs:
|
|
@@ -95,7 +95,7 @@ def get_all_metrics(proposed_answer: str, ground_truth: str) -> float:
|
|
|
95
95
|
|
|
96
96
|
|
|
97
97
|
def crm_matching_reward_function(
|
|
98
|
-
completion: str,
|
|
98
|
+
completion: List[Dict[str, Any]],
|
|
99
99
|
ground_truth: List[str],
|
|
100
100
|
mcp_client: Client,
|
|
101
101
|
workspace: Path,
|
|
@@ -119,7 +119,8 @@ def crm_matching_reward_function(
|
|
|
119
119
|
if not reward_metric:
|
|
120
120
|
raise ValueError("kwargs must contain reward metric")
|
|
121
121
|
|
|
122
|
-
|
|
122
|
+
completion_text = completion[-1].get("content", "") if completion else ""
|
|
123
|
+
proposed_answer = completion_text.strip() if completion_text else ""
|
|
123
124
|
proposed_answer = parse_answers(proposed_answer)
|
|
124
125
|
|
|
125
126
|
if reward_metric == "exact_match":
|
|
@@ -7,7 +7,7 @@ from benchmax.envs.mcp.parallel_mcp_env import ParallelMcpEnv
|
|
|
7
7
|
from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner
|
|
8
8
|
from benchmax.envs.mcp.provisioners.local_provisioner import LocalProvisioner
|
|
9
9
|
from benchmax.envs.mcp.provisioners.skypilot_provisioner import SkypilotProvisioner
|
|
10
|
-
from benchmax.envs.types import StandardizedExample
|
|
10
|
+
from benchmax.envs.types import Completion, StandardizedExample
|
|
11
11
|
from .data_utils import download_and_extract
|
|
12
12
|
|
|
13
13
|
# Using library shared with mcp workdir
|
|
@@ -162,7 +162,7 @@ Output Path: {output_filename}"""
|
|
|
162
162
|
await self.copy_to_workspace(rollout_id, Path(input_src_path))
|
|
163
163
|
|
|
164
164
|
async def compute_reward(
|
|
165
|
-
self, rollout_id: str, completion:
|
|
165
|
+
self, rollout_id: str, completion: Completion, ground_truth: Any, **kwargs: Any
|
|
166
166
|
) -> Dict[str, float]:
|
|
167
167
|
answer_position: Optional[str] = kwargs.get("answer_position")
|
|
168
168
|
output_filename: Optional[str] = kwargs.get("output_filename")
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from pathlib import Path
|
|
2
|
-
from typing import Any, Awaitable, Callable, Dict, Optional, Union
|
|
2
|
+
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
|
|
3
3
|
|
|
4
4
|
from fastmcp import Client
|
|
5
5
|
|
|
@@ -13,7 +13,7 @@ RewardFunction = Callable[..., Union[float, Awaitable[float]]]
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def spreadsheet_comparison_reward(
|
|
16
|
-
completion: str,
|
|
16
|
+
completion: List[Dict[str, Any]],
|
|
17
17
|
ground_truth: dict,
|
|
18
18
|
mcp_client: Client,
|
|
19
19
|
workspace: Path,
|
|
@@ -34,23 +34,17 @@ def spreadsheet_comparison_reward(
|
|
|
34
34
|
|
|
35
35
|
output_path = workspace / output_filename
|
|
36
36
|
ground_truth_path = workspace / ground_truth_filename
|
|
37
|
-
rollout_id = kwargs.get("rollout_id", "unknown_rollout")
|
|
38
|
-
log_env(
|
|
39
|
-
rollout_id, f"excel_reward:compare_files={ground_truth_filename}:{output_filename}:{answer_position}"
|
|
40
|
-
)
|
|
41
37
|
|
|
42
38
|
# Return 1.0 score if the output completely matches the ground truth
|
|
43
39
|
try:
|
|
44
40
|
match, _ = compare_excel_cells(
|
|
45
41
|
str(ground_truth_path), str(output_path), answer_position
|
|
46
42
|
)
|
|
47
|
-
log_env(rollout_id, f"excel_reward:spreadsheet_match={float(match)}")
|
|
48
43
|
return 1.0 if match else 0.0
|
|
49
44
|
except Exception as e:
|
|
50
45
|
print(
|
|
51
46
|
f"Error comparing spreadsheets {ground_truth_path} and {output_path}: {e}"
|
|
52
47
|
)
|
|
53
|
-
log_env(rollout_id, f"excel_reward:error={str(e)}")
|
|
54
48
|
return 0.0
|
|
55
49
|
|
|
56
50
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from pathlib import Path
|
|
2
2
|
import re
|
|
3
|
-
from typing import Any, Callable, Dict, Union, Awaitable
|
|
3
|
+
from typing import Any, Callable, Dict, List, Union, Awaitable
|
|
4
4
|
from fastmcp import Client
|
|
5
5
|
from html import unescape
|
|
6
6
|
|
|
@@ -8,7 +8,7 @@ RewardFunction = Callable[..., Union[float, Awaitable[float]]]
|
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
async def text_match_reward(
|
|
11
|
-
completion: str,
|
|
11
|
+
completion: List[Dict[str, Any]],
|
|
12
12
|
ground_truth: str,
|
|
13
13
|
mcp_client: Client,
|
|
14
14
|
workspace: Path,
|
|
@@ -21,9 +21,11 @@ async def text_match_reward(
|
|
|
21
21
|
Falls back to 0 if the tag is missing or empty.
|
|
22
22
|
"""
|
|
23
23
|
|
|
24
|
+
completion_text = completion[-1].get("content", "") if completion else ""
|
|
25
|
+
|
|
24
26
|
# Grab only the text inside the first <answer> … </answer> pair (case-insensitive).
|
|
25
27
|
m = re.search(
|
|
26
|
-
r"<answer>(.*?)</answer>",
|
|
28
|
+
r"<answer>(.*?)</answer>", completion_text, flags=re.IGNORECASE | re.DOTALL
|
|
27
29
|
)
|
|
28
30
|
if m is None:
|
|
29
31
|
return 0.0
|
{benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/example_workdir/reward_fn.py
RENAMED
|
@@ -23,7 +23,7 @@ and reaading from the workspace that the MCP is operating in.
|
|
|
23
23
|
"""
|
|
24
24
|
|
|
25
25
|
from pathlib import Path
|
|
26
|
-
from typing import Any, Callable, Dict, Union, Awaitable
|
|
26
|
+
from typing import Any, Callable, Dict, List, Union, Awaitable
|
|
27
27
|
from mcp.types import TextContent
|
|
28
28
|
from fastmcp import Client
|
|
29
29
|
from fastmcp.exceptions import ToolError
|
|
@@ -35,7 +35,7 @@ RewardFunction = Callable[..., Union[float, Awaitable[float]]]
|
|
|
35
35
|
# Reward 0: Stateless completion check
|
|
36
36
|
# -------------------------------
|
|
37
37
|
async def completion_match_reward(
|
|
38
|
-
completion: str,
|
|
38
|
+
completion: List[Dict[str, Any]],
|
|
39
39
|
ground_truth: dict,
|
|
40
40
|
mcp_client: Client,
|
|
41
41
|
workspace: Path,
|
|
@@ -47,6 +47,9 @@ async def completion_match_reward(
|
|
|
47
47
|
Uses: ground_truth['completion'] (str)
|
|
48
48
|
"""
|
|
49
49
|
expected = ground_truth.get("completion", "")
|
|
50
|
+
if isinstance(completion, list):
|
|
51
|
+
completion = completion[-1].get("content", "") if completion else ""
|
|
52
|
+
completion = str(completion)
|
|
50
53
|
return 1.0 if completion.strip() == expected.strip() else 0.0
|
|
51
54
|
|
|
52
55
|
|
|
@@ -54,7 +57,7 @@ async def completion_match_reward(
|
|
|
54
57
|
# Reward 1: Tool call variable in memory check
|
|
55
58
|
# -------------------------------
|
|
56
59
|
async def variable_in_memory_reward(
|
|
57
|
-
completion: str, ground_truth: dict, mcp_client: Client, workspace: Path, **kwargs
|
|
60
|
+
completion: str | list[dict], ground_truth: dict, mcp_client: Client, workspace: Path, **kwargs
|
|
58
61
|
) -> float:
|
|
59
62
|
"""
|
|
60
63
|
Reward uses tool call to match in-memory variable value.
|
|
@@ -96,7 +99,7 @@ async def variable_in_memory_reward(
|
|
|
96
99
|
# Reward 2: Workspace log check
|
|
97
100
|
# -------------------------------
|
|
98
101
|
async def log_in_workspace_reward(
|
|
99
|
-
completion: str,
|
|
102
|
+
completion: str | list[dict],
|
|
100
103
|
ground_truth: dict,
|
|
101
104
|
mcp_client: Client,
|
|
102
105
|
workspace: Path,
|
|
@@ -20,7 +20,7 @@ except ModuleNotFoundError as e:
|
|
|
20
20
|
|
|
21
21
|
from benchmax.envs.base_env import BaseEnv
|
|
22
22
|
from benchmax.envs.tracking import to_tracking_payload
|
|
23
|
-
from benchmax.envs.types import ToolDefinition
|
|
23
|
+
from benchmax.envs.types import Completion, ToolDefinition
|
|
24
24
|
from .server_pool import ServerPool
|
|
25
25
|
from .provisioners.base_provisioner import BaseProvisioner
|
|
26
26
|
from .utils import (
|
|
@@ -344,7 +344,7 @@ class ParallelMcpEnv(BaseEnv):
|
|
|
344
344
|
return str(e)
|
|
345
345
|
|
|
346
346
|
async def compute_reward(
|
|
347
|
-
self, rollout_id: str, completion:
|
|
347
|
+
self, rollout_id: str, completion: Completion, ground_truth: Any, **kwargs: Any
|
|
348
348
|
) -> Dict[str, float]:
|
|
349
349
|
"""
|
|
350
350
|
Compute reward and cleanup rollout.
|
|
@@ -79,13 +79,13 @@ def tracking_context(config: TrackingConfig | None) -> Iterator[None]:
|
|
|
79
79
|
_ACTIVE_TRACKER.reset(token)
|
|
80
80
|
|
|
81
81
|
|
|
82
|
-
def log_env(rollout_id: str, message: str) -> None:
|
|
82
|
+
def log_env(rollout_id: str, message: str, commit: bool = False, flush_k: int = 30) -> None:
|
|
83
83
|
tracker = _ACTIVE_TRACKER.get()
|
|
84
84
|
if tracker is None:
|
|
85
85
|
return
|
|
86
86
|
|
|
87
87
|
try:
|
|
88
|
-
tracker.log_environment(rollout_id, str(message))
|
|
88
|
+
tracker.log_environment(rollout_id, str(message), commit, flush_k)
|
|
89
89
|
except Exception as e:
|
|
90
90
|
LOGGER.debug("log_environment failed: %s", e)
|
|
91
91
|
|
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
|
-
from typing import Any, Dict, Optional, TypedDict
|
|
2
|
+
from typing import Any, Dict, List, Optional, TypedDict
|
|
3
|
+
|
|
4
|
+
Completion = List[Dict[str, Any]]
|
|
3
5
|
|
|
4
6
|
|
|
5
7
|
class StandardizedExample(TypedDict):
|
|
6
|
-
prompt: str
|
|
8
|
+
prompt: str | List[Dict[str, Any]]
|
|
7
9
|
ground_truth: Any
|
|
8
10
|
init_rollout_args: Optional[Dict[str, Any]]
|
|
9
11
|
|
|
@@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
5
5
|
|
|
6
6
|
from benchmax.envs.base_env import BaseEnv
|
|
7
7
|
from benchmax.envs.tracking import log_env
|
|
8
|
-
from benchmax.envs.types import ToolDefinition, StandardizedExample
|
|
8
|
+
from benchmax.envs.types import Completion, ToolDefinition, StandardizedExample
|
|
9
9
|
from benchmax.envs.wikipedia.utils import APIKeyRotator, clean_html, safe_request
|
|
10
10
|
|
|
11
11
|
SYSTEM_PROMPT = """Please use the tools provided to get accurate, up-to-date information.
|
|
@@ -14,12 +14,12 @@ Write your complete answer on the final line only as a concise entity, within th
|
|
|
14
14
|
"""
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
def text_match_reward_function(completion:
|
|
17
|
+
def text_match_reward_function(completion: Completion, ground_truth: str, rollout_id: str, **kwargs) -> float:
|
|
18
18
|
"""
|
|
19
19
|
Score 1.0 if ground truth appears in <answer> tags, else 0.0.
|
|
20
20
|
|
|
21
21
|
Args:
|
|
22
|
-
completion: The model's generated text
|
|
22
|
+
completion: The model's generated text (str or list of message dicts)
|
|
23
23
|
ground_truth: Expected answer (case-insensitive)
|
|
24
24
|
**kwargs: Catch-all for BaseEnv compatibility
|
|
25
25
|
|
|
@@ -27,17 +27,22 @@ def text_match_reward_function(completion: str, ground_truth: str, rollout_id: s
|
|
|
27
27
|
1.0 if ground_truth matches the answer text, else 0.0
|
|
28
28
|
"""
|
|
29
29
|
assert ground_truth is not None
|
|
30
|
+
completion_str = ""
|
|
31
|
+
if isinstance(completion, list):
|
|
32
|
+
completion_str = completion[-1].get("content", "") if completion else ""
|
|
33
|
+
elif isinstance(completion, str):
|
|
34
|
+
completion_str = completion
|
|
35
|
+
else:
|
|
36
|
+
completion_str = ""
|
|
30
37
|
|
|
31
38
|
m = re.search(
|
|
32
|
-
r"<answer>(.*?)</answer>",
|
|
39
|
+
r"<answer>(.*?)</answer>", completion_str, flags=re.IGNORECASE | re.DOTALL
|
|
33
40
|
)
|
|
34
41
|
if not m:
|
|
35
|
-
log_env(rollout_id, "wikipedia_reward:no_answer_tag")
|
|
36
42
|
return 0.0
|
|
37
43
|
|
|
38
44
|
answer_text = unescape(m.group(1)).strip().lower()
|
|
39
45
|
score = float(ground_truth.lower() == answer_text)
|
|
40
|
-
log_env(rollout_id, f"wikipedia_reward:text_match={score}")
|
|
41
46
|
return score
|
|
42
47
|
|
|
43
48
|
|
|
@@ -264,7 +269,7 @@ class WikipediaEnv(BaseEnv):
|
|
|
264
269
|
pass
|
|
265
270
|
|
|
266
271
|
async def compute_reward(
|
|
267
|
-
self, rollout_id: str, completion:
|
|
272
|
+
self, rollout_id: str, completion: Completion, ground_truth: Any, **kwargs: Any
|
|
268
273
|
) -> Dict[str, float]:
|
|
269
274
|
"""Compute rewards using the text match reward function."""
|
|
270
275
|
return {
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/adapters/skyrl/benchmax_data_process.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/crm/workdir/salesforce_mcp.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/excel/workdir/excel_utils.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/provisioners/__init__.py
RENAMED
|
File without changes
|
{benchmax-0.1.2.dev16 → benchmax-0.1.2.dev18}/src/benchmax/envs/mcp/provisioners/base_provisioner.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|