eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,410 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
import importlib.util
|
|
3
|
+
import inspect
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import warnings
|
|
7
|
+
from functools import wraps
|
|
8
|
+
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, cast
|
|
9
|
+
|
|
10
|
+
import requests
|
|
11
|
+
|
|
12
|
+
from .models import EvaluateResult, MetricResult
|
|
13
|
+
from .typed_interface import reward_function
|
|
14
|
+
|
|
15
|
+
# NOTE(review): configuring the root logger at import time is a side effect of
# importing this library module; it can override the host application's own
# logging configuration. Consider removing it and letting applications call
# logging.basicConfig themselves.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by RewardFunction for errors, warnings and debug output.
logger = logging.getLogger(__name__)

# Type variable bound to callables that return an EvaluateResult.
# NOTE(review): T is not referenced anywhere else in this module; it may be
# imported by other modules — verify before removing.
T = TypeVar("T", bound=Callable[..., EvaluateResult])
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class RewardFunction:
    """
    A wrapper for reward functions that allows them to be run locally or remotely.

    The RewardFunction class wraps a reward function (either a local function or a remote endpoint)
    and provides a unified interface for calling it. It supports:

    - Local functions (mode="local")
    - Remote endpoints (mode="remote")
    - Fireworks-hosted models (mode="fireworks_hosted")

    Args:
        func: The local function to use (for mode="local").
        func_path: A string path to a function (e.g., "module.submodule:function_name"),
            used for mode="local" only when ``func`` is not given.
        mode: The mode of operation ("local", "remote", or "fireworks_hosted").
        endpoint: The URL of the remote endpoint (for mode="remote").
        name: Name of a deployed reward function (for mode="remote"). When ``endpoint``
            is not given, the endpoint becomes ``https://api.fireworks.ai/v1/reward/{name}``.
        model_id: The ID of the Fireworks-hosted model (for mode="fireworks_hosted").
            The endpoint is always derived from it, ignoring ``endpoint``.
        **kwargs: Additional keyword arguments forwarded on every call: passed to the
            local function, or merged into the JSON payload for remote modes.

    Raises:
        ValueError: If ``mode`` is unknown or an argument the mode requires is missing.
        ImportError: If ``func_path`` cannot be resolved to a module attribute.
    """

    # Value passed as ``timeout`` to ``requests.post`` for remote/hosted modes.
    # Without a timeout an unresponsive server blocks the caller forever.
    # Override on a subclass or instance; ``None`` waits indefinitely.
    request_timeout: Optional[float] = 300.0

    def __init__(
        self,
        func: Optional[Callable] = None,
        func_path: Optional[str] = None,
        mode: str = "local",
        endpoint: Optional[str] = None,
        name: Optional[str] = None,
        model_id: Optional[str] = None,
        **kwargs,
    ):
        self.mode = mode
        self.func = func
        self.func_path = func_path
        self.endpoint = endpoint
        self.name = name
        self.model_id = model_id
        self.kwargs = kwargs

        if mode == "local":
            if func is None and func_path is None:
                raise ValueError("Either 'func' or 'func_path' must be provided for local mode")
            if func_path and func is None:
                self.func = self._load_function_from_path(func_path)
        elif mode == "remote":
            if endpoint is None and name is None:
                raise ValueError("Either 'endpoint' or 'name' must be provided for remote mode")
            if name and endpoint is None:
                self.endpoint = f"https://api.fireworks.ai/v1/reward/{name}"
        elif mode == "fireworks_hosted":
            if model_id is None:
                raise ValueError("'model_id' must be provided for fireworks_hosted mode")
            # The hosted-model URL always replaces any explicit ``endpoint`` argument.
            self.endpoint = f"https://api.fireworks.ai/v1/models/{model_id}/reward"
        else:
            raise ValueError(f"Invalid mode: {mode}")

    def _load_function_from_path(self, func_path: str) -> Callable:
        """
        Load a function from a path string.

        The path string should be in the format 'module.submodule:function_name'
        (preferred) or 'module.submodule.function_name', in which case the last
        dotted component is taken as the attribute name.

        Raises:
            ValueError: If a colon-free path contains no '.' separating module and name.
            ImportError: If the module cannot be imported or lacks the attribute.
        """
        if ":" in func_path:
            module_path, func_name = func_path.split(":", 1)
        else:
            module_path, separator, func_name = func_path.rpartition(".")
            if not separator:
                raise ValueError(
                    f"Invalid func_path format: {func_path}, expected 'module.path:function_name' or 'module.path.function_name'"
                )

        try:
            module = importlib.import_module(module_path)
            return getattr(module, func_name)
        except (ImportError, AttributeError) as e:
            raise ImportError(f"Failed to load function from path {func_path}: {str(e)}") from e

    @staticmethod
    def _metrics_from_dict(raw_metrics: Optional[Dict[str, Any]]) -> Dict[str, MetricResult]:
        """
        Convert a plain ``{name: score_or_dict}`` mapping into MetricResult objects.

        Dict-valued entries may carry ``score`` (default 0.0), ``reason`` (default
        "<name> score") and a validity flag. The flag is read from ``is_score_valid``
        (the MetricResult field name), falling back to the legacy ``success`` key, and
        defaults to True. Scalar entries are converted with ``float``. A missing or
        ``None`` mapping yields no metrics.
        """
        metrics: Dict[str, MetricResult] = {}
        for key, value in (raw_metrics or {}).items():
            if isinstance(value, dict):
                metrics[key] = MetricResult(
                    score=value.get("score", 0.0),
                    reason=value.get("reason", f"{key} score"),
                    is_score_valid=value.get("is_score_valid", value.get("success", True)),
                )
            else:
                metrics[key] = MetricResult(
                    score=float(value),
                    reason=f"{key} score",
                    is_score_valid=True,
                )
        return metrics

    def __call__(
        self,
        messages: List[Dict[str, str]],
        ground_truth: Optional[Union[str, List[Dict[str, str]]]] = None,
        **kwargs,
    ) -> EvaluateResult:
        """
        Call the reward function with the provided messages.

        Args:
            messages: List of conversation messages, each with 'role' and 'content' keys
            ground_truth: Ground truth data, which can be a string (e.g., an expected answer)
                          or a list of original conversation messages (for context).
                          If None, defaults to messages[:-1] (or [] for no messages).
            **kwargs: Additional keyword arguments; they override the constructor's kwargs.

        Returns:
            EvaluateResult object with score and metrics

        Raises:
            TypeError: If a local function returns an unsupported type.
            ValueError: If no function/endpoint is configured, the mode is invalid, or a
                remote endpoint returns a body without a 'score'.
            requests.RequestException: On HTTP/transport failures in remote modes.
        """
        if ground_truth is None:
            # Keeps the historical "original_messages" default: the conversation minus
            # the final response. String ground truths are left to the reward function.
            ground_truth = messages[:-1] if messages else []

        combined_kwargs = {**self.kwargs, **kwargs}

        if self.mode == "local":
            if self.func is None:
                raise ValueError("No function provided for local mode")

            try:
                result = self.func(
                    messages=messages,
                    ground_truth=ground_truth,
                    **combined_kwargs,
                )

                if isinstance(result, EvaluateResult):
                    return result
                if isinstance(result, tuple) and len(result) == 2:
                    # Legacy (score, {name: score}) tuple format.
                    warnings.warn(
                        "Tuple return format is deprecated. Use EvaluateResult instead.",
                        DeprecationWarning,
                        stacklevel=2,
                    )
                    score, components = result
                    metrics = {
                        k: MetricResult(score=v, reason=f"{k} score", is_score_valid=True)
                        for k, v in components.items()
                    }
                    return EvaluateResult(score=score, metrics=metrics)
                if isinstance(result, dict) and "score" in result:
                    # Legacy plain-dict format.
                    warnings.warn(
                        "Dictionary return format is deprecated. Use EvaluateResult instead.",
                        DeprecationWarning,
                        stacklevel=2,
                    )
                    return EvaluateResult(
                        score=result["score"],
                        reason=result.get("reason"),
                        metrics=self._metrics_from_dict(result.get("metrics")),
                    )
                raise TypeError(
                    f"Invalid return type from reward function: {type(result)}. "
                    "Expected EvaluateResult, a (float, Dict[str, float]) tuple, "
                    "or a dict with a 'score' key."
                )

            except Exception as e:
                logger.error(f"Error calling local reward function: {str(e)}")
                raise

        if self.mode in ("remote", "fireworks_hosted"):
            if self.endpoint is None:
                raise ValueError(f"No endpoint provided for {self.mode} mode")

            payload = {
                "messages": messages,
                "ground_truth": ground_truth,
                **combined_kwargs,
            }

            headers = {"Content-Type": "application/json"}
            api_key = os.environ.get("FIREWORKS_API_KEY")
            if api_key:
                # Only send credentials when a key exists; an empty Authorization
                # header value is not a valid credential.
                headers["Authorization"] = f"Bearer {api_key}"

            try:
                response = requests.post(
                    self.endpoint,
                    json=payload,
                    headers=headers,
                    timeout=self.request_timeout,
                )
                response.raise_for_status()
                result = response.json()

                if not (isinstance(result, dict) and "score" in result):
                    raise ValueError(f"Invalid response from remote endpoint: {result}")

                return EvaluateResult(
                    score=result["score"],
                    reason=result.get("reason"),
                    metrics=self._metrics_from_dict(result.get("metrics")),
                )

            except Exception as e:
                logger.error(f"Error calling remote endpoint: {str(e)}")
                raise

        # Reachable only if ``self.mode`` was changed after construction.
        raise ValueError(f"Invalid mode: {self.mode}")

    @staticmethod
    def _completion_to_text(index: int, completion: Any) -> str:
        """
        Extract the assistant's text from one TRL completion.

        Accepts a plain string, or a list whose first element is an assistant message
        dict (``[{'role': 'assistant', 'content': ...}]``). Anything else is logged and
        stringified; an empty list yields "".
        """
        if isinstance(completion, str):
            return completion

        if isinstance(completion, list):
            if not completion:
                logger.warning(f"Adapter: completions[{index}] is an empty list. Using empty string for content.")
                return ""
            first_element = completion[0]
            if (
                isinstance(first_element, dict)
                and "content" in first_element
                and first_element.get("role") == "assistant"
            ):
                logger.debug(
                    f"Adapter: completions[{index}] is a list with an assistant message dict. Extracted content."
                )
                return str(first_element.get("content", ""))
            logger.warning(
                f"Adapter: completions[{index}] is a list, but its first element "
                f"is not the expected assistant message dict or is malformed: {first_element}. "
                f"Using str(first_element) as content."
            )
            return str(first_element)

        logger.warning(
            f"Adapter: completions[{index}] is of unexpected type: {type(completion)}. "
            f"Attempting to stringify for content: {completion}"
        )
        return str(completion)

    @staticmethod
    def _solution_to_ground_truth(index: int, solution: Any) -> Optional[str]:
        """
        Normalize one per-sample 'solution' value into a string ground truth.

        None stays None; a list contributes its first element (None if empty);
        any other value is converted with ``str``.
        """
        if solution is None:
            return None
        if isinstance(solution, list):
            logger.warning(f"Sample {index} solution is a list, attempting to use first element: {solution}")
            return str(solution[0]) if solution else None
        return str(solution)

    def get_trl_adapter(self) -> Callable:
        """
        Create an adapter function for use with TRL library.

        TRL reward functions receive a batch of prompts and a batch of completions and
        must return one float per sample. The returned adapter accepts:

        1. Conversational batches: prompts as lists of message dicts; completions as
           strings or lists of message dicts (e.g. ``[{'role': 'assistant', ...}]``).
        2. Standard batches: prompts as plain strings, wrapped as a single user message.
        3. ``completions=None``: each prompt is taken to already contain the complete
           conversation, including the assistant's response, and is scored as-is.

        A ``solution`` column in the TRL kwargs is passed to this RewardFunction as
        ``ground_truth``. A sample whose evaluation raises is scored 0.0.

        Returns:
            A callable function compatible with TRL's expected signature for reward functions.
        """

        def adapter(
            prompts: List[Union[str, List[Dict]]],
            completions: Optional[List[Any]] = None,
            **kwargs,
        ) -> List[float]:
            """
            Adapter function compatible with TRL's expected reward function signature.

            Args:
                prompts: A batch of prompts, either message lists
                         (e.g. [[{'role':'system',...}, {'role':'user',...}], ...]) or strings.
                completions: A batch of generated completions (strings or assistant message
                             lists). Optional. If None, each prompt is assumed to already
                             contain the full conversation including the assistant's response.
                **kwargs: Additional keyword arguments passed by TRL. These often include
                          other columns from the HuggingFace dataset being used for training
                          (e.g., 'solution', 'reference_answer'). These are expected to be
                          lists with at least as many entries as scored samples.

            Returns:
                A list of float reward scores for the batch, one score per sample.
            """
            batch_size = len(prompts)

            # Score only as many samples as both batches provide.
            if completions is not None and len(completions) != batch_size:
                logger.warning(
                    f"Batch size mismatch between prompts ({batch_size}) and "
                    f"completions ({len(completions)}). Using min length."
                )
                batch_size = min(batch_size, len(completions))

            # TRL passes dataset columns that weren't removed; 'solution' is used as
            # ground truth. A list longer than batch_size (e.g. after the min-length
            # truncation above) is still usable; only a too-short one is rejected.
            solutions = kwargs.get("solution", [None] * batch_size)
            if not isinstance(solutions, (list, tuple)) or len(solutions) < batch_size:
                logger.warning(
                    f"Expected 'solution' kwarg to be a list of size {batch_size}, but got {type(solutions)}. Ground truth might not be passed correctly."
                )
                solutions = [None] * batch_size

            results: List[float] = []
            for i in range(batch_size):
                prompt = prompts[i]
                # Standard (non-conversational) TRL datasets provide prompts as plain text.
                prompt_messages = [{"role": "user", "content": prompt}] if isinstance(prompt, str) else prompt

                if completions is None:
                    # The prompt already holds the whole conversation; appending an
                    # empty assistant turn would make the reward function score "".
                    messages = list(prompt_messages)
                else:
                    completion_text = self._completion_to_text(i, completions[i])
                    messages = prompt_messages + [{"role": "assistant", "content": completion_text}]

                current_solution = solutions[i]
                debug_solution_val_str = str(current_solution) if current_solution is not None else "None"
                logger.debug(
                    f"Adapter loop i={i}, type(current_solution)={type(current_solution)}, value='{debug_solution_val_str[:100]}...'"
                )
                ground_truth = self._solution_to_ground_truth(i, current_solution)

                try:
                    result = self(messages=messages, ground_truth=ground_truth)
                    results.append(result.score)
                except Exception as e:
                    logger.error(f"Error processing sample {i} in TRL adapter: {str(e)}")
                    # A failed sample gets a neutral-low score so the batch can proceed.
                    results.append(0.0)

            return results

        return adapter
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
# The legacy_reward_function decorator has been removed as it is no longer needed.
|
|
402
|
+
# Use the reward_function decorator from typed_interface instead.
|
|
403
|
+
#
|
|
404
|
+
# For deployment functionality, use the RewardFunction class or the deployment
|
|
405
|
+
# methods from the evaluation module directly.
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
# A former `reward_function` alias was removed from this module so that
# `from .typed_interface import reward_function` is the one used throughout the
# library, thus avoiding the deprecation warning when using the
# @reward_function decorator.
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Out-of-the-box reward functions for common use cases.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
# Import specific reward functions
|
|
6
|
+
from . import (
|
|
7
|
+
accuracy,
|
|
8
|
+
accuracy_length,
|
|
9
|
+
bfcl_reward,
|
|
10
|
+
code_execution,
|
|
11
|
+
cpp_code,
|
|
12
|
+
deepcoder_reward,
|
|
13
|
+
format,
|
|
14
|
+
function_calling,
|
|
15
|
+
json_schema,
|
|
16
|
+
language_consistency,
|
|
17
|
+
lean_prover,
|
|
18
|
+
length,
|
|
19
|
+
list_comparison_math_reward,
|
|
20
|
+
math,
|
|
21
|
+
multiple_choice_math_reward,
|
|
22
|
+
reasoning_steps,
|
|
23
|
+
repetition,
|
|
24
|
+
tag_count,
|
|
25
|
+
)
|
|
26
|
+
from .accuracy_length import cosine_scaled_accuracy_length_reward
|
|
27
|
+
|
|
28
|
+
# Import function separately to avoid name conflict with the module
|
|
29
|
+
from .bfcl_reward import bfcl_reward as bfcl_reward_function
|
|
30
|
+
|
|
31
|
+
# Directly import specific reward functions for easy access
|
|
32
|
+
from .code_execution import fractional_code_reward
|
|
33
|
+
from .cpp_code import binary_cpp_code_reward, ioi_cpp_code_reward
|
|
34
|
+
from .deepcoder_reward import deepcoder_code_reward
|
|
35
|
+
|
|
36
|
+
# To make individual functions directly importable from eval_protocol.rewards
|
|
37
|
+
# e.g., from eval_protocol.rewards import composite_function_call_reward
|
|
38
|
+
from .function_calling import (
|
|
39
|
+
composite_function_call_reward,
|
|
40
|
+
exact_tool_match_reward,
|
|
41
|
+
llm_judge_reward,
|
|
42
|
+
schema_jaccard_reward,
|
|
43
|
+
)
|
|
44
|
+
from .lean_prover import (
|
|
45
|
+
deepseek_huggingface_prover_benchmark,
|
|
46
|
+
deepseek_prover_v2_reward,
|
|
47
|
+
lean_prover_reward,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
# Import these with aliases to avoid name conflicts
|
|
51
|
+
from .list_comparison_math_reward import (
|
|
52
|
+
list_comparison_math_reward as list_comparison_math_reward_function,
|
|
53
|
+
)
|
|
54
|
+
from .multiple_choice_math_reward import (
|
|
55
|
+
multiple_choice_math_reward as multiple_choice_math_reward_function,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
# Public API of eval_protocol.rewards: first the reward submodules, then the
# individual reward functions re-exported for direct import, e.g.
# ``from eval_protocol.rewards import composite_function_call_reward``.
__all__ = [
    "function_calling",
    "json_schema",
    "math",
    "code_execution",
    "format",
    "tag_count",
    "accuracy",
    "language_consistency",
    "reasoning_steps",
    "length",
    "repetition",
    "cpp_code",
    "accuracy_length",
    "lean_prover",
    "deepcoder_reward",
    "multiple_choice_math_reward",
    "list_comparison_math_reward",
    "bfcl_reward",
] + [
    "fractional_code_reward",
    "deepcoder_code_reward",
    "multiple_choice_math_reward_function",
    "list_comparison_math_reward_function",
    "ioi_cpp_code_reward",
    "binary_cpp_code_reward",
    "cosine_scaled_accuracy_length_reward",
    "lean_prover_reward",
    "deepseek_prover_v2_reward",
    "deepseek_huggingface_prover_benchmark",
    "bfcl_reward_function",
    "composite_function_call_reward",
    "exact_tool_match_reward",
    "llm_judge_reward",
    "schema_jaccard_reward",
]
|