benchmax 0.1.2.dev16__tar.gz → 0.1.2.dev17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/PKG-INFO +1 -1
  2. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/pyproject.toml +1 -1
  3. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/adapters/benchmax_wrapper.py +4 -3
  4. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/base_env.py +46 -21
  5. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/crm/workdir/reward_fn.py +3 -2
  6. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/excel/excel_env.py +2 -2
  7. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/excel/workdir/reward_fn.py +2 -8
  8. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/math/workdir/reward_fn.py +5 -3
  9. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/mcp/example_workdir/reward_fn.py +7 -4
  10. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/mcp/parallel_mcp_env.py +2 -2
  11. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/types.py +4 -2
  12. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/wikipedia/wiki_env.py +12 -7
  13. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax.egg-info/PKG-INFO +1 -1
  14. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/LICENSE +0 -0
  15. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/README.md +0 -0
  16. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/setup.cfg +0 -0
  17. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/adapters/__init__.py +0 -0
  18. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/adapters/skyrl/benchmax_data_process.py +0 -0
  19. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/adapters/skyrl/skyrl_adapter.py +0 -0
  20. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/bundle/__init__.py +0 -0
  21. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/bundle/bundler.py +0 -0
  22. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/bundle/errors.py +0 -0
  23. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/bundle/loader.py +0 -0
  24. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/bundle/payload.py +0 -0
  25. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/bundle/validator.py +0 -0
  26. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/__init__.py +0 -0
  27. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/crm/crm_env.py +0 -0
  28. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/crm/workdir/salesforce_mcp.py +0 -0
  29. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/excel/data_utils.py +0 -0
  30. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/excel/workdir/__init__.py +0 -0
  31. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/excel/workdir/excel_code_runner_mcp.py +0 -0
  32. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/excel/workdir/excel_utils.py +0 -0
  33. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/math/math_env.py +0 -0
  34. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/mcp/__init__.py +0 -0
  35. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/mcp/example_workdir/demo_mcp_server.py +0 -0
  36. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/mcp/provisioners/__init__.py +0 -0
  37. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/mcp/provisioners/base_provisioner.py +0 -0
  38. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/mcp/provisioners/local_provisioner.py +0 -0
  39. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/mcp/provisioners/manual_provisioner.py +0 -0
  40. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/mcp/provisioners/skypilot_provisioner.py +0 -0
  41. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/mcp/provisioners/utils.py +0 -0
  42. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/mcp/proxy_server.py +0 -0
  43. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/mcp/server_pool.py +0 -0
  44. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/mcp/utils.py +0 -0
  45. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/tracking.py +0 -0
  46. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/envs/wikipedia/utils.py +0 -0
  47. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/prompts/__init__.py +0 -0
  48. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax/prompts/tools.py +0 -0
  49. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax.egg-info/SOURCES.txt +0 -0
  50. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax.egg-info/dependency_links.txt +0 -0
  51. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax.egg-info/requires.txt +0 -0
  52. {benchmax-0.1.2.dev16 → benchmax-0.1.2.dev17}/src/benchmax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: benchmax
3
- Version: 0.1.2.dev16
3
+ Version: 0.1.2.dev17
4
4
  Summary: Framework-Agnostic RL Environments for LLM Fine-Tuning
5
5
  Author: cgft.io
6
6
  Classifier: Programming Language :: Python :: 3
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "benchmax"
3
- version = "0.1.2.dev16"
3
+ version = "0.1.2.dev17"
4
4
  description = "Framework-Agnostic RL Environments for LLM Fine-Tuning"
5
5
  readme = "README.md"
6
6
  authors = [{ name = "cgft.io" }]
@@ -5,6 +5,7 @@ from ray.actor import ActorClass, ActorProxy
5
5
  from typing import Dict, List, Any, Optional, Type, Union
6
6
 
7
7
  from benchmax.envs.base_env import BaseEnv
8
+ from benchmax.envs.types import Completion
8
9
 
9
10
  # 5 minutes timeout in seconds
10
11
  RAY_GET_TIMEOUT = 300
@@ -67,7 +68,7 @@ class BenchmaxEnv:
67
68
 
68
69
  @ray.method
69
70
  async def compute_reward(
70
- self, rollout_id: str, completion: str, ground_truth: Any, **kwargs: Any
71
+ self, rollout_id: str, completion: Completion, ground_truth: Any, **kwargs: Any
71
72
  ) -> Dict[str, float]:
72
73
  return await self._env.compute_reward(
73
74
  rollout_id=rollout_id,
@@ -258,7 +259,7 @@ class BenchmaxEnvWrapper:
258
259
  async def compute_reward(
259
260
  self,
260
261
  rollout_id: str,
261
- completion: str,
262
+ completion: Completion,
262
263
  ground_truth: Any,
263
264
  **kwargs: Any,
264
265
  ) -> Dict[str, float]:
@@ -271,7 +272,7 @@ class BenchmaxEnvWrapper:
271
272
  def compute_reward_sync(
272
273
  self,
273
274
  rollout_id: str,
274
- completion: str,
275
+ completion: Completion,
275
276
  ground_truth: Any,
276
277
  **kwargs: Any,
277
278
  ) -> Dict[str, float]:
@@ -1,9 +1,10 @@
1
1
  from abc import ABC, abstractmethod
2
+ from functools import wraps
2
3
  from pathlib import Path
3
4
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
4
5
 
5
- from benchmax.envs.tracking import TrackingConfig, log_env, with_tracking
6
- from benchmax.envs.types import StandardizedExample, ToolDefinition
6
+ from benchmax.envs.tracking import TrackingConfig, log_env, tracking_context
7
+ from benchmax.envs.types import Completion, StandardizedExample, ToolDefinition
7
8
  from benchmax.prompts.tools import render_tools_prompt
8
9
 
9
10
  if TYPE_CHECKING:
@@ -16,32 +17,28 @@ class BaseEnv(ABC):
16
17
  system_prompt: str = ""
17
18
  _tracking_config: TrackingConfig | None = None
18
19
 
19
- def __init_subclass__(cls, **kwargs):
20
- super().__init_subclass__(**kwargs)
20
+ def __init__(self, **kwargs):
21
+ self._tracking_config: Optional[TrackingConfig] = None
21
22
 
22
- compute_reward = cls.__dict__.get("compute_reward")
23
- if compute_reward is None:
24
- return
25
- if getattr(compute_reward, "__benchmax_tracking_wrapped__", False):
26
- return
27
-
28
- wrapped = with_tracking(lambda self, *a, **kw: self.get_tracking_config())(
29
- compute_reward
30
- )
31
- setattr(wrapped, "__benchmax_tracking_wrapped__", True)
32
- setattr(cls, "compute_reward", wrapped)
33
-
34
- def __init__(
23
+ def enable_tracking(
35
24
  self,
36
25
  experiment_id: Optional[str] = None,
37
26
  api_key: Optional[str] = None,
38
- **kwargs,
39
- ):
27
+ ) -> None:
28
+ """Enable experiment tracking. Wraps compute_reward on this instance with a tracking context."""
40
29
  self._tracking_config = TrackingConfig(
41
30
  experiment_id=experiment_id, api_key=api_key
42
31
  )
32
+ cls_compute_reward = type(self).compute_reward
43
33
 
44
- def get_tracking_config(self) -> TrackingConfig | None:
34
+ @wraps(cls_compute_reward)
35
+ async def _tracked(*args, **kwargs):
36
+ with tracking_context(self._tracking_config):
37
+ return await cls_compute_reward(self, *args, **kwargs)
38
+
39
+ self.compute_reward = _tracked
40
+
41
+ def get_tracking_config(self) -> Optional[TrackingConfig]:
45
42
  return self._tracking_config
46
43
 
47
44
  def log_env(self, rollout_id: str, message: str) -> None:
@@ -106,7 +103,7 @@ class BaseEnv(ABC):
106
103
 
107
104
  @abstractmethod
108
105
  async def compute_reward(
109
- self, rollout_id: str, completion: str, ground_truth: Any, **kwargs: Any
106
+ self, rollout_id: str, completion: Completion, ground_truth: Any, **kwargs: Any
110
107
  ) -> Dict[str, float]:
111
108
  """Compute rewards using registered functions
112
109
 
@@ -114,6 +111,34 @@ class BaseEnv(ABC):
114
111
  """
115
112
  pass
116
113
 
114
+ async def compute_group_reward(
115
+ self,
116
+ rollout_ids: List[str],
117
+ completions: List[str | List[Dict[str, str]]],
118
+ ground_truths: List[Any],
119
+ **kwargs: Any,
120
+ ) -> List[Dict[str, float]]:
121
+ """Compute rewards across a group of rollouts jointly.
122
+
123
+ Override this when reward computation requires cross-rollout context (e.g.,
124
+ relative scoring, group normalization, or deduplication). Can be used alongside
125
+ ``compute_reward`` — the two are not mutually exclusive. The default implementation
126
+ returns empty reward dicts, deferring entirely to per-rollout ``compute_reward`` calls.
127
+
128
+ Args:
129
+ rollout_ids: Identifiers for each rollout in the group.
130
+ completions: Model outputs, one per rollout. Each entry is either a
131
+ plain string or a list of message dicts.
132
+ ground_truths: Reference answers, one per rollout.
133
+ **kwargs: Additional environment-specific arguments.
134
+
135
+ Returns:
136
+ A list of reward dicts (one per rollout), each mapping reward function
137
+ names to their computed scores. An empty dict signals that no group
138
+ reward was computed for that rollout.
139
+ """
140
+ return [{} for _ in rollout_ids]
141
+
117
142
  async def get_system_prompt(self, add_tool_defs: bool = False) -> str:
118
143
  """Get system prompt. To add tool definitions, set add_tool_defs to True."""
119
144
  if add_tool_defs:
@@ -95,7 +95,7 @@ def get_all_metrics(proposed_answer: str, ground_truth: str) -> float:
95
95
 
96
96
 
97
97
  def crm_matching_reward_function(
98
- completion: str,
98
+ completion: List[Dict[str, Any]],
99
99
  ground_truth: List[str],
100
100
  mcp_client: Client,
101
101
  workspace: Path,
@@ -119,7 +119,8 @@ def crm_matching_reward_function(
119
119
  if not reward_metric:
120
120
  raise ValueError("kwargs must contain reward metric")
121
121
 
122
- proposed_answer = completion.strip() if completion else ""
122
+ completion_text = completion[-1].get("content", "") if completion else ""
123
+ proposed_answer = completion_text.strip() if completion_text else ""
123
124
  proposed_answer = parse_answers(proposed_answer)
124
125
 
125
126
  if reward_metric == "exact_match":
@@ -7,7 +7,7 @@ from benchmax.envs.mcp.parallel_mcp_env import ParallelMcpEnv
7
7
  from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner
8
8
  from benchmax.envs.mcp.provisioners.local_provisioner import LocalProvisioner
9
9
  from benchmax.envs.mcp.provisioners.skypilot_provisioner import SkypilotProvisioner
10
- from benchmax.envs.types import StandardizedExample
10
+ from benchmax.envs.types import Completion, StandardizedExample
11
11
  from .data_utils import download_and_extract
12
12
 
13
13
  # Using library shared with mcp workdir
@@ -162,7 +162,7 @@ Output Path: {output_filename}"""
162
162
  await self.copy_to_workspace(rollout_id, Path(input_src_path))
163
163
 
164
164
  async def compute_reward(
165
- self, rollout_id: str, completion: str, ground_truth: Any, **kwargs: Any
165
+ self, rollout_id: str, completion: Completion, ground_truth: Any, **kwargs: Any
166
166
  ) -> Dict[str, float]:
167
167
  answer_position: Optional[str] = kwargs.get("answer_position")
168
168
  output_filename: Optional[str] = kwargs.get("output_filename")
@@ -1,5 +1,5 @@
1
1
  from pathlib import Path
2
- from typing import Any, Awaitable, Callable, Dict, Optional, Union
2
+ from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
3
3
 
4
4
  from fastmcp import Client
5
5
 
@@ -13,7 +13,7 @@ RewardFunction = Callable[..., Union[float, Awaitable[float]]]
13
13
 
14
14
 
15
15
  def spreadsheet_comparison_reward(
16
- completion: str,
16
+ completion: List[Dict[str, Any]],
17
17
  ground_truth: dict,
18
18
  mcp_client: Client,
19
19
  workspace: Path,
@@ -34,23 +34,17 @@ def spreadsheet_comparison_reward(
34
34
 
35
35
  output_path = workspace / output_filename
36
36
  ground_truth_path = workspace / ground_truth_filename
37
- rollout_id = kwargs.get("rollout_id", "unknown_rollout")
38
- log_env(
39
- rollout_id, f"excel_reward:compare_files={ground_truth_filename}:{output_filename}:{answer_position}"
40
- )
41
37
 
42
38
  # Return 1.0 score if the output completely matches the ground truth
43
39
  try:
44
40
  match, _ = compare_excel_cells(
45
41
  str(ground_truth_path), str(output_path), answer_position
46
42
  )
47
- log_env(rollout_id, f"excel_reward:spreadsheet_match={float(match)}")
48
43
  return 1.0 if match else 0.0
49
44
  except Exception as e:
50
45
  print(
51
46
  f"Error comparing spreadsheets {ground_truth_path} and {output_path}: {e}"
52
47
  )
53
- log_env(rollout_id, f"excel_reward:error={str(e)}")
54
48
  return 0.0
55
49
 
56
50
 
@@ -1,6 +1,6 @@
1
1
  from pathlib import Path
2
2
  import re
3
- from typing import Any, Callable, Dict, Union, Awaitable
3
+ from typing import Any, Callable, Dict, List, Union, Awaitable
4
4
  from fastmcp import Client
5
5
  from html import unescape
6
6
 
@@ -8,7 +8,7 @@ RewardFunction = Callable[..., Union[float, Awaitable[float]]]
8
8
 
9
9
 
10
10
  async def text_match_reward(
11
- completion: str,
11
+ completion: List[Dict[str, Any]],
12
12
  ground_truth: str,
13
13
  mcp_client: Client,
14
14
  workspace: Path,
@@ -21,9 +21,11 @@ async def text_match_reward(
21
21
  Falls back to 0 if the tag is missing or empty.
22
22
  """
23
23
 
24
+ completion_text = completion[-1].get("content", "") if completion else ""
25
+
24
26
  # Grab only the text inside the first <answer> … </answer> pair (case-insensitive).
25
27
  m = re.search(
26
- r"<answer>(.*?)</answer>", completion, flags=re.IGNORECASE | re.DOTALL
28
+ r"<answer>(.*?)</answer>", completion_text, flags=re.IGNORECASE | re.DOTALL
27
29
  )
28
30
  if m is None:
29
31
  return 0.0
@@ -23,7 +23,7 @@ and reaading from the workspace that the MCP is operating in.
23
23
  """
24
24
 
25
25
  from pathlib import Path
26
- from typing import Any, Callable, Dict, Union, Awaitable
26
+ from typing import Any, Callable, Dict, List, Union, Awaitable
27
27
  from mcp.types import TextContent
28
28
  from fastmcp import Client
29
29
  from fastmcp.exceptions import ToolError
@@ -35,7 +35,7 @@ RewardFunction = Callable[..., Union[float, Awaitable[float]]]
35
35
  # Reward 0: Stateless completion check
36
36
  # -------------------------------
37
37
  async def completion_match_reward(
38
- completion: str,
38
+ completion: List[Dict[str, Any]],
39
39
  ground_truth: dict,
40
40
  mcp_client: Client,
41
41
  workspace: Path,
@@ -47,6 +47,9 @@ async def completion_match_reward(
47
47
  Uses: ground_truth['completion'] (str)
48
48
  """
49
49
  expected = ground_truth.get("completion", "")
50
+ if isinstance(completion, list):
51
+ completion = completion[-1].get("content", "") if completion else ""
52
+ completion = str(completion)
50
53
  return 1.0 if completion.strip() == expected.strip() else 0.0
51
54
 
52
55
 
@@ -54,7 +57,7 @@ async def completion_match_reward(
54
57
  # Reward 1: Tool call variable in memory check
55
58
  # -------------------------------
56
59
  async def variable_in_memory_reward(
57
- completion: str, ground_truth: dict, mcp_client: Client, workspace: Path, **kwargs
60
+ completion: str | list[dict], ground_truth: dict, mcp_client: Client, workspace: Path, **kwargs
58
61
  ) -> float:
59
62
  """
60
63
  Reward uses tool call to match in-memory variable value.
@@ -96,7 +99,7 @@ async def variable_in_memory_reward(
96
99
  # Reward 2: Workspace log check
97
100
  # -------------------------------
98
101
  async def log_in_workspace_reward(
99
- completion: str,
102
+ completion: str | list[dict],
100
103
  ground_truth: dict,
101
104
  mcp_client: Client,
102
105
  workspace: Path,
@@ -20,7 +20,7 @@ except ModuleNotFoundError as e:
20
20
 
21
21
  from benchmax.envs.base_env import BaseEnv
22
22
  from benchmax.envs.tracking import to_tracking_payload
23
- from benchmax.envs.types import ToolDefinition
23
+ from benchmax.envs.types import Completion, ToolDefinition
24
24
  from .server_pool import ServerPool
25
25
  from .provisioners.base_provisioner import BaseProvisioner
26
26
  from .utils import (
@@ -344,7 +344,7 @@ class ParallelMcpEnv(BaseEnv):
344
344
  return str(e)
345
345
 
346
346
  async def compute_reward(
347
- self, rollout_id: str, completion: str, ground_truth: Any, **kwargs: Any
347
+ self, rollout_id: str, completion: Completion, ground_truth: Any, **kwargs: Any
348
348
  ) -> Dict[str, float]:
349
349
  """
350
350
  Compute reward and cleanup rollout.
@@ -1,9 +1,11 @@
1
1
  from dataclasses import dataclass
2
- from typing import Any, Dict, Optional, TypedDict
2
+ from typing import Any, Dict, List, Optional, TypedDict
3
+
4
+ Completion = List[Dict[str, Any]]
3
5
 
4
6
 
5
7
  class StandardizedExample(TypedDict):
6
- prompt: str
8
+ prompt: str | List[Dict[str, Any]]
7
9
  ground_truth: Any
8
10
  init_rollout_args: Optional[Dict[str, Any]]
9
11
 
@@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
5
5
 
6
6
  from benchmax.envs.base_env import BaseEnv
7
7
  from benchmax.envs.tracking import log_env
8
- from benchmax.envs.types import ToolDefinition, StandardizedExample
8
+ from benchmax.envs.types import Completion, ToolDefinition, StandardizedExample
9
9
  from benchmax.envs.wikipedia.utils import APIKeyRotator, clean_html, safe_request
10
10
 
11
11
  SYSTEM_PROMPT = """Please use the tools provided to get accurate, up-to-date information.
@@ -14,12 +14,12 @@ Write your complete answer on the final line only as a concise entity, within th
14
14
  """
15
15
 
16
16
 
17
- def text_match_reward_function(completion: str, ground_truth: str, rollout_id: str, **kwargs) -> float:
17
+ def text_match_reward_function(completion: Completion, ground_truth: str, rollout_id: str, **kwargs) -> float:
18
18
  """
19
19
  Score 1.0 if ground truth appears in <answer> tags, else 0.0.
20
20
 
21
21
  Args:
22
- completion: The model's generated text
22
+ completion: The model's generated text (str or list of message dicts)
23
23
  ground_truth: Expected answer (case-insensitive)
24
24
  **kwargs: Catch-all for BaseEnv compatibility
25
25
 
@@ -27,17 +27,22 @@ def text_match_reward_function(completion: str, ground_truth: str, rollout_id: s
27
27
  1.0 if ground_truth matches the answer text, else 0.0
28
28
  """
29
29
  assert ground_truth is not None
30
+ completion_str = ""
31
+ if isinstance(completion, list):
32
+ completion_str = completion[-1].get("content", "") if completion else ""
33
+ elif isinstance(completion, str):
34
+ completion_str = completion
35
+ else:
36
+ completion_str = ""
30
37
 
31
38
  m = re.search(
32
- r"<answer>(.*?)</answer>", completion, flags=re.IGNORECASE | re.DOTALL
39
+ r"<answer>(.*?)</answer>", completion_str, flags=re.IGNORECASE | re.DOTALL
33
40
  )
34
41
  if not m:
35
- log_env(rollout_id, "wikipedia_reward:no_answer_tag")
36
42
  return 0.0
37
43
 
38
44
  answer_text = unescape(m.group(1)).strip().lower()
39
45
  score = float(ground_truth.lower() == answer_text)
40
- log_env(rollout_id, f"wikipedia_reward:text_match={score}")
41
46
  return score
42
47
 
43
48
 
@@ -264,7 +269,7 @@ class WikipediaEnv(BaseEnv):
264
269
  pass
265
270
 
266
271
  async def compute_reward(
267
- self, rollout_id: str, completion: str, ground_truth: Any, **kwargs: Any
272
+ self, rollout_id: str, completion: Completion, ground_truth: Any, **kwargs: Any
268
273
  ) -> Dict[str, float]:
269
274
  """Compute rewards using the text match reward function."""
270
275
  return {
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: benchmax
3
- Version: 0.1.2.dev16
3
+ Version: 0.1.2.dev17
4
4
  Summary: Framework-Agnostic RL Environments for LLM Fine-Tuning
5
5
  Author: cgft.io
6
6
  Classifier: Programming Language :: Python :: 3
File without changes
File without changes
File without changes