eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Reward functions for evaluating response length.
|
|
3
|
+
|
|
4
|
+
This module provides reward functions that evaluate the length of model responses,
|
|
5
|
+
either by simple token/character count or using cosine-scaled rewards to promote
|
|
6
|
+
token efficiency.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import math
|
|
10
|
+
import re
|
|
11
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
|
12
|
+
|
|
13
|
+
from ..models import EvaluateResult, Message, MetricResult
|
|
14
|
+
from ..typed_interface import reward_function
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def count_tokens(text: str, method: str = "whitespace") -> int:
    """
    Count tokens in text using different methods.

    Args:
        text: The text to tokenize
        method: Tokenization method to use ('whitespace', 'character', or 'words').
            Any other value falls back to 'whitespace'.

    Returns:
        Token count based on the selected method. Empty or whitespace-only text
        yields 0 for the 'whitespace' and 'words' methods.
    """
    if method == "character":
        return len(text)
    if method == "words":
        # Each maximal run of word characters (letters, digits, underscore)
        # counts once; punctuation and whitespace separate words.
        return len(re.findall(r"\b[\w\d]+\b", text))
    # "whitespace" and any unrecognized method: split on runs of whitespace.
    # str.split() with no separator discards empty fields, so "" and "   "
    # count as 0 tokens (re.split on an empty string would yield [''] -> 1).
    return len(text.split())
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@reward_function  # type: ignore[arg-type]
def length_reward(
    messages: Union[List[Message], List[Dict[str, Any]]],
    *,
    ground_truth: Optional[
        Union[List[Message], List[Dict[str, Any]]]
    ] = None,  # Not used by this function but part of standard signature
    target_length: Optional[int] = None,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    token_method: str = "whitespace",
    scaling: str = "linear",
    reward_range: Optional[List[float]] = None,
    **kwargs: Any,
) -> EvaluateResult:
    """
    Score a model response by its length.

    The response is ``messages[-1]``. It must be an assistant message (either a
    ``Message`` or a dict with ``role``/``content`` keys) with non-empty content;
    otherwise a 0.0 result with an invalid ``length`` metric is returned.

    Only the first applicable constraint is used, in this priority order:

    1. ``target_length``: the reward decreases with the relative deviation
       ``|tokens - target| / target`` (deviation is 1.0 when target <= 0).
       Success means deviation < 0.2.
    2. ``min_length`` and ``max_length``: full reward inside the range; below
       the minimum the reward ramps up with ``tokens / min_length``; above the
       maximum it ramps down with ``excess / (max_length - min_length)``
       (denominator 1 if the range is empty), saturating at the low end.
    3. ``min_length`` only: full reward at or above the minimum.
    4. ``max_length`` only: full reward at or below the maximum; above it the
       reward ramps down with ``excess / max_length``.
    5. No constraint: shorter is better, measured against 100 tokens.

    Args:
        messages: Conversation messages; ``messages[-1]`` is the model's response.
        ground_truth: Optional expected trajectory. Not used by this reward.
        target_length: Optional target token count (optimal length).
        min_length: Minimum acceptable token count.
        max_length: Maximum acceptable token count.
        token_method: Token counting method ('whitespace', 'character', or 'words').
        scaling: 'cosine' for cosine-shaped ramps; any other value is linear.
        reward_range: ``[low, high]`` reward bounds, default ``[0.0, 1.0]``.
        **kwargs: Additional arguments (ignored).

    Returns:
        EvaluateResult with the length score plus ``length`` and
        ``token_count`` metrics.
    """
    if not messages:
        return EvaluateResult(
            score=0.0,
            reason="No messages provided",
            metrics={"length": MetricResult(score=0.0, is_score_valid=False, reason="No messages provided")},
        )

    last = messages[-1]
    if isinstance(last, Message):
        role, content = last.role, last.content
    elif isinstance(last, dict):
        role, content = last.get("role"), last.get("content")
    else:
        return EvaluateResult(
            score=0.0,
            reason="Last message is of unexpected type.",
            metrics={
                "length": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Invalid message type in messages.",
                )
            },
        )

    if role != "assistant" or not content:
        return EvaluateResult(
            score=0.0,
            reason="No assistant response found",
            metrics={
                "length": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Message not from assistant or has no content",
                )
            },
        )

    token_count = count_tokens(content, method=token_method)

    lo, hi = reward_range if reward_range is not None else (0.0, 1.0)
    span = hi - lo
    cosine = scaling == "cosine"

    def ramp(fraction: float) -> float:
        # Shape a fraction in [0, 1]: identity when linear, 1 - cos(f * pi/2) when cosine.
        return (1.0 - math.cos(fraction * math.pi / 2.0)) if cosine else fraction

    def half_cosine(fraction: float) -> float:
        # Decreasing cosine curve: fraction 0 maps to hi, fraction 1 maps to lo.
        return lo + span * (1.0 + math.cos(fraction * math.pi)) / 2.0

    if target_length is not None:
        deviation = abs(token_count - target_length) / target_length if target_length > 0 else 1.0
        score = half_cosine(min(1.0, deviation)) if cosine else max(lo, hi - deviation * span)
        reason = f"Response length ({token_count} tokens) deviated by {deviation:.2f} from target ({target_length})"
        success = deviation < 0.2
    elif min_length is not None and max_length is not None:
        if token_count < min_length:
            score = lo + span * ramp(token_count / min_length)
            reason = f"Response length ({token_count} tokens) is below minimum ({min_length})"
            success = False
        elif token_count > max_length:
            # An empty/inverted range falls back to a window of 1 token.
            window = max_length - min_length if max_length > min_length else 1
            score = hi - span * ramp(min(1.0, (token_count - max_length) / window))
            reason = f"Response length ({token_count} tokens) exceeds maximum ({max_length})"
            success = False
        else:
            score = hi
            reason = f"Response length ({token_count} tokens) is within acceptable range ({min_length}-{max_length})"
            success = True
    elif min_length is not None:
        if token_count < min_length:
            score = lo + span * ramp(token_count / min_length)
            reason = f"Response length ({token_count} tokens) is below minimum ({min_length})"
            success = False
        else:
            score = hi
            reason = f"Response length ({token_count} tokens) meets minimum requirement ({min_length})"
            success = True
    elif max_length is not None:
        if token_count > max_length:
            overshoot = token_count - max_length
            fraction = min(1.0, overshoot / max_length) if max_length > 0 else 1.0
            score = hi - span * ramp(fraction)
            reason = f"Response length ({token_count} tokens) exceeds maximum ({max_length})"
            success = False
        else:
            score = hi
            reason = f"Response length ({token_count} tokens) is within maximum limit ({max_length})"
            success = True
    else:
        # Useful alongside correctness metrics, e.g. shorter correct answers >
        # longer correct answers > incorrect answers.
        reference_length = 100  # Default length for normalization
        fraction = min(1.0, token_count / reference_length)
        score = half_cosine(fraction) if cosine else hi - fraction * span
        reason = f"Response length: {token_count} tokens"
        success = True

    scale = target_length or max_length or min_length or 100
    metrics = {
        "length": MetricResult(score=score, is_score_valid=success, reason=reason),
        "token_count": MetricResult(
            score=min(1.0, float(token_count) / scale),
            is_score_valid=success,
            reason=f"Token count: {token_count}",
        ),
    }

    return EvaluateResult(score=score, reason=reason, metrics=metrics)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
@reward_function  # type: ignore[arg-type]
def cosine_length_reward(
    messages: Union[List[Message], List[Dict[str, Any]]],
    *,
    ground_truth: Optional[
        Union[List[Message], List[Dict[str, Any]]]
    ] = None,  # Not used by this function but part of standard signature
    correctness: Optional[float] = None,
    is_correct: Optional[bool] = None,
    max_length: int = 1000,
    min_value_wrong: float = 0.0,
    max_value_wrong: float = 0.3,
    min_value_correct: float = 0.5,
    max_value_correct: float = 1.0,
    token_method: str = "whitespace",
    **kwargs: Any,
) -> EvaluateResult:
    """
    Reward function that scales based on completion length using a cosine schedule.
    The model's response is assumed to be the last message in the `messages` list.

    Inspired by the OpenR1 implementation (https://github.com/huggingface/open-r1) and
    Kimi Technical Report (https://arxiv.org/abs/2501.12599).

    Shorter correct solutions are rewarded more than longer ones.
    Longer incorrect solutions are penalized less than shorter ones.

    With ``progress = min(1, tokens / max_length)`` the score is
    ``long + 0.5 * (short - long) * (1 + cos(pi * progress))``. This equals the
    "short" value for a zero-length response and the "long" value at
    ``max_length`` tokens or beyond. Correct solutions use
    short=max_value_correct and long=min_value_correct. Incorrect solutions use
    short=min_value_wrong and long=max_value_wrong.

    Args:
        messages: List of conversation messages, where `messages[-1]` is the model's response.
        ground_truth: Optional. Expected assistant response trajectory. Not directly used by this length reward.
        correctness: Optional float (0-1) indicating solution correctness; values >= 0.9
            count as correct. Ignored when ``is_correct`` is given.
        is_correct: Optional boolean indicating if the solution is correct. Takes precedence
            over ``correctness``. If neither is given, the solution is treated as incorrect.
        max_length: Token count at which the length schedule saturates. A non-positive
            value treats every response as saturated (progress 1.0).
        min_value_wrong: Reward for the shortest wrong answers (the harshest penalty).
        max_value_wrong: Reward for wrong answers at or beyond ``max_length`` tokens.
        min_value_correct: Reward for correct answers at or beyond ``max_length`` tokens.
        max_value_correct: Reward for the shortest correct answers.
        token_method: Method to count tokens ('whitespace', 'character', or 'words')
        **kwargs: Additional arguments (ignored)

    Returns:
        EvaluateResult with score based on cosine-scaled length evaluation, plus
        ``cosine_length``, ``token_count`` and ``correctness`` metrics.
    """
    if not messages:
        return EvaluateResult(
            score=0.0,
            reason="No messages provided",
            metrics={"cosine_length": MetricResult(score=0.0, is_score_valid=False, reason="No messages provided")},
        )

    response = messages[-1]

    if isinstance(response, Message):
        if response.role != "assistant" or not response.content:
            return EvaluateResult(
                score=0.0,
                reason="No assistant response found",
                metrics={
                    "cosine_length": MetricResult(
                        score=0.0,
                        is_score_valid=False,
                        reason="Message not from assistant or has no content",
                    )
                },
            )
        text = response.content
    elif isinstance(response, dict):
        if response.get("role") != "assistant" or not response.get("content"):
            return EvaluateResult(
                score=0.0,
                reason="No assistant response found",
                metrics={
                    "cosine_length": MetricResult(
                        score=0.0,
                        is_score_valid=False,
                        reason="Message not from assistant or has no content",
                    )
                },
            )
        text = response.get("content", "")
    else:
        return EvaluateResult(
            score=0.0,
            reason="Last message is of unexpected type.",
            metrics={
                "cosine_length": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Invalid message type in messages.",
                )
            },
        )

    token_count = count_tokens(text, method=token_method)

    # An explicit boolean verdict wins over the float score; with neither
    # supplied the solution is treated as incorrect.
    if is_correct is not None:
        solution_is_correct = is_correct
    elif correctness is not None:
        solution_is_correct = correctness >= 0.9
    else:
        solution_is_correct = False

    # Fraction of the length budget used, saturating at 1.0. A non-positive
    # budget cannot be divided by (and would give a negative progress), so
    # every response is treated as saturated instead of raising.
    progress = min(1.0, token_count / max_length) if max_length > 0 else 1.0
    cosine_factor = math.cos(progress * math.pi)  # 1.0 at progress 0, -1.0 at progress 1

    if solution_is_correct:
        short_value, long_value = max_value_correct, min_value_correct
    else:
        # Incorrect answers: short ones get the harsher min_value_wrong.
        short_value, long_value = min_value_wrong, max_value_wrong

    score = long_value + 0.5 * (short_value - long_value) * (1.0 + cosine_factor)

    if solution_is_correct:
        success = True
        reason = f"Correct solution with length penalty: {token_count} tokens"
    else:
        success = False
        reason = f"Incorrect solution with length consideration: {token_count} tokens"

    detailed_reason = (
        f"Length-based {'reward' if solution_is_correct else 'penalty'}: "
        f"{token_count}/{max_length} tokens, cosine factor: {cosine_factor:.2f}"
    )

    metrics = {
        "cosine_length": MetricResult(
            score=score,
            is_score_valid=success,
            reason=detailed_reason,
        ),
        "token_count": MetricResult(
            score=progress,  # same value as min(1.0, tokens / max_length) when max_length > 0
            is_score_valid=success,
            reason=f"Token count: {token_count}/{max_length}",
        ),
        "correctness": MetricResult(
            score=1.0 if solution_is_correct else 0.0,
            is_score_valid=solution_is_correct,
            reason=f"Solution is {'correct' if solution_is_correct else 'incorrect'}",
        ),
    }

    return EvaluateResult(score=score, reason=reason, metrics=metrics)
|
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Reward function for comparing lists of numbers, often found in math answers
|
|
3
|
+
like sets of divisors, roots, etc.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import re
|
|
7
|
+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|
8
|
+
|
|
9
|
+
from ..models import EvaluateResult, Message, MetricResult
|
|
10
|
+
from ..typed_interface import reward_function
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def parse_number_list_from_string(s: str) -> Optional[List[float]]:
    """
    Parse a comma-separated list of numbers, e.g. "1, 2, 3.5, 4" -> [1.0, 2.0, 3.5, 4.0].

    Dollar signs are removed first. The text is then split on commas, and any
    whitespace around the commas is discarded. Empty fields are skipped. Every
    remaining field must be accepted by ``float()``. The result is None if any
    field is rejected or if no fields remain.
    """
    cleaned = s.replace("$", "").strip()
    chunks = [chunk.strip() for chunk in re.split(r"\s*,\s*", cleaned)]
    if not any(chunks):
        return None

    values: List[float] = []
    for chunk in chunks:
        if not chunk:
            continue
        try:
            values.append(float(chunk))
        except ValueError:
            return None
    return values or None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def extract_number_list(text: str) -> List[List[float]]:
    """
    Extracts lists of numbers from text.

    Candidates are searched in three tiers. The first tier that yields at least
    one parsable list wins, and later tiers are not consulted:

    1. Contents of ``\\boxed{...}`` (one level of nested braces is allowed).
    2. Contents of ``$$...$$`` or ``$...$``.
    3. The whole text, parsed as a single comma-separated list (fallback).

    Within the winning tier, every expression that parses as a number list
    contributes its own list, in order of appearance. Expressions that do not
    parse are skipped.

    Args:
        text: The text to extract number lists from.

    Returns:
        A list of extracted number lists (each inner list contains floats), or
        an empty list if nothing could be parsed.
        Example: "\\boxed{1,2,3}, $4,5$" -> [[1.0, 2.0, 3.0]]  (boxed tier wins)
        Example: "$1,2$ and $3$" -> [[1.0, 2.0], [3.0]]
    """
    extracted_lists: List[List[float]] = []

    # Tier 1: boxed LaTeX expressions.
    for content in re.findall(r"\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}", text):
        parsed_list = parse_number_list_from_string(content)
        if parsed_list:
            extracted_lists.append(parsed_list)
    if extracted_lists:
        return extracted_lists

    # Tier 2: $$...$$ or $...$. With two groups, findall yields a tuple per
    # match; the group belonging to the alternative that did not match is "".
    for double_dollar, single_dollar in re.findall(r"\$\$(.*?)\$\$|\$(.*?)\$", text, re.DOTALL):
        content = double_dollar or single_dollar
        if content:
            parsed_list = parse_number_list_from_string(content.strip())
            if parsed_list:
                extracted_lists.append(parsed_list)
    if extracted_lists:
        return extracted_lists

    # Tier 3: fallback, less reliable — treat the whole text as one list.
    full_text_parsed_list = parse_number_list_from_string(text)
    if full_text_parsed_list:
        extracted_lists.append(full_text_parsed_list)

    return extracted_lists
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@reward_function  # type: ignore[arg-type]
def list_comparison_math_reward(
    messages: List[Message],
    *,
    ground_truth: str,
    order_matters: bool = False,
    **kwargs: Any,
) -> EvaluateResult:
    """
    Score a response whose answer is a list or set of numbers.

    Number lists are extracted from the assistant's reply (``messages[-1].content``)
    and from ``ground_truth`` with ``extract_number_list``. Only the first list
    found on each side is compared. The score is 1.0 on a match and 0.0
    otherwise; floats are compared exactly.

    Args:
        messages: Conversation messages; the last one must be an assistant ``Message``.
        ground_truth: Text containing the expected list of numbers.
        order_matters: When True, compare the lists element by element, so order
            and repeated values count. When False (default), compare them as
            sets, so only which values are present matters.
        **kwargs: Additional keyword arguments (ignored).

    Returns:
        EvaluateResult with the score and extraction/comparison metrics. An
        invalid response or empty inputs yield 0.0 with an ``error`` metric.
    """
    last_message = messages[-1] if messages else None
    if (
        not isinstance(last_message, Message)
        or last_message.role != "assistant"
        or last_message.content is None
    ):
        return EvaluateResult(
            score=0.0,
            reason="Invalid or missing assistant response in messages.",
            metrics={
                "error": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Last message not a valid assistant response.",
                )
            },
        )

    generated_text = last_message.content
    if not generated_text:
        return EvaluateResult(
            score=0.0,
            reason="Assistant response content is empty.",
            metrics={
                "error": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Empty generated message content.",
                )
            },
        )
    if not ground_truth:
        return EvaluateResult(
            score=0.0,
            reason="Ground truth string (expected list) is empty.",
            metrics={"error": MetricResult(score=0.0, is_score_valid=False, reason="Empty ground truth string.")},
        )

    generated_lists = extract_number_list(generated_text)
    original_lists = extract_number_list(ground_truth)

    metrics: Dict[str, MetricResult] = {
        "extracted_original_lists": MetricResult(
            score=1.0 if original_lists else 0.0,
            is_score_valid=bool(original_lists),
            reason=f"Original lists: {original_lists}",
        ),
        "extracted_generated_lists": MetricResult(
            score=1.0 if generated_lists else 0.0,
            is_score_valid=bool(generated_lists),
            reason=f"Generated lists: {generated_lists}",
        ),
    }

    if not original_lists:
        return EvaluateResult(
            score=0.0,
            reason="Could not extract any number list from original message (ground truth).",
            metrics=metrics,
        )
    if not generated_lists:
        return EvaluateResult(
            score=0.0,
            reason="Could not extract any number list from generated message.",
            metrics=metrics,
        )

    # Only the first extracted list on each side is compared; handling several
    # boxed lists is a possible future improvement.
    expected = original_lists[0]
    predicted = generated_lists[0]

    if order_matters:
        matched = predicted == expected
        verdict = "Exact list match" if matched else "List mismatch"
        match_reason = f"{verdict} (order matters). Gen: {predicted} vs Orig: {expected}"
    else:
        # Exact float equality; a tolerance-based comparison would be more robust.
        predicted_set = set(predicted)
        expected_set = set(expected)
        matched = predicted_set == expected_set
        comparison = f"Gen: {sorted(predicted_set)} vs Orig: {sorted(expected_set)}"
        if matched:
            match_reason = f"Set match (order does not matter). {comparison}"
        else:
            reason_parts = [f"Set mismatch (order does not matter). {comparison}."]
            missing = expected_set - predicted_set
            extra = predicted_set - expected_set
            if missing:
                reason_parts.append(f"Missing in generated: {sorted(missing)}.")
            if extra:
                reason_parts.append(f"Extra in generated: {sorted(extra)}.")
            match_reason = " ".join(reason_parts)

    score = 1.0 if matched else 0.0
    metrics["list_comparison"] = MetricResult(score=score, is_score_valid=matched, reason=match_reason)
    return EvaluateResult(score=score, reason=match_reason, metrics=metrics)
|