eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,454 @@
|
|
|
1
|
+
# pylint: disable=all
|
|
2
|
+
"""
|
|
3
|
+
Reward functions for accuracy evaluation.
|
|
4
|
+
|
|
5
|
+
This module provides reward functions that evaluate the accuracy of model responses
|
|
6
|
+
by comparing them with ground truth answers, optionally using preprocessing steps
|
|
7
|
+
like normalization and LaTeX parsing.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import re
|
|
11
|
+
from typing import Any, Callable, Dict, List, Optional, Union, cast
|
|
12
|
+
|
|
13
|
+
from ..models import EvaluateResult, Message, MetricResult
|
|
14
|
+
from ..typed_interface import reward_function
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def normalize_text(text: str) -> str:
    """
    Normalize text for comparison by removing excess whitespace, punctuation.

    The steps are applied in this order: lowercase; map the Unicode operators
    "×" and "÷" to "*" and "/"; drop sentence punctuation and quotes; drop
    brackets (keeping their contents); drop every remaining character that is
    not a word character, whitespace or one of ``+ - / * =``; collapse runs of
    whitespace and trim the ends.

    Note that "." counts as punctuation here, so decimal points are removed
    too ("3.14" becomes "314").

    Args:
        text: The text to normalize

    Returns:
        Normalized text string
    """
    text = text.lower()

    # Map the Unicode operators first: they are not word characters, so the
    # symbol filter below would delete them before they could be translated.
    text = text.replace("×", "*").replace("÷", "/")

    text = re.sub(r'[,.;:!?"\']', "", text)

    # Remove parentheses, brackets, etc. that often appear in math expressions
    # but keep their contents
    text = re.sub(r"[\(\)\[\]\{\}]", "", text)

    # The hyphen is escaped so it is a literal "-" rather than the accidental
    # range "+" .. "/" (which also spans "," and ".").
    text = re.sub(r"[^\w\s+\-/*=]", "", text)

    # Collapse whitespace last so the gaps left by removed symbols do not
    # survive as double spaces.
    return re.sub(r"\s+", " ", text).strip()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def extract_math_expression(text: str) -> str:
    """
    Extract mathematical expressions from text.

    This function attempts to find the final answer in mathematical texts,
    handling both numerical answers and expressions. Heuristics are tried in
    order and the first one that produces a value wins:

    1. Explicit answer phrases ("answer is 42", "x = 4", "about 3.5", ...).
       A number that is directly followed by "π"/"pi" is treated as a
       coefficient and multiplied by 3.14159.
    2. The last number on one of the final three lines, when that line
       mentions "answer", "result" or "solution".
    3. For texts shorter than 200 characters: the only number, or the last
       number when the text has fewer than 30 words.
    4. A capitalized phrase following "is" when the text mentions
       capital/city/country/president/largest/smallest.
    5. The last LaTeX math span (number inside it, or the cleaned span).
    6. One of a few fixed words (Paris, London, yes, no, true, false),
       matched as whole words.
    7. The normalized text itself when it is shorter than 50 characters.

    Args:
        text: Text that might contain mathematical expressions

    Returns:
        Extracted mathematical expression or normalized text if no clear
        expression is found ("" when nothing applies)
    """
    pi_approximation = "3.14159"
    pi_only_pattern = r"(?:answer|result|value|=)\s*(?:is\s*)?(?:π|pi)"

    # Try to find answer patterns like "= 42" or "answer is 42"
    answer_patterns = [
        # Common exact answer formats
        r"(?:answer|result|solution)(?:\s+is|\s*[:=])\s*(?:x\s*=\s*)?([-+]?\d+(?:\.\d+)?(?:/\d+)?)",
        r"(?:therefore|thus|so)[,:]?\s*(?:x\s*=\s*)?([-+]?\d+(?:\.\d+)?(?:/\d+)?)",
        r"(?:the value of|value)\s*(?:x|y|z)\s*(?:is|=)\s*([-+]?\d+(?:\.\d+)?(?:/\d+)?)",
        r"x\s*=\s*([-+]?\d+(?:\.\d+)?(?:/\d+)?)",  # x = 4
        r"(?:=|equals)\s*([-+]?\d+(?:\.\d+)?(?:/\d+)?)",
        # Common answer formats with parentheses
        r"(?:answer|result|solution)[^0-9\n.]*?is[^0-9\n.]*?((?:\([-+]?\))?(?:\d+(?:\.\d+)?(?:/\d+)?))",
        r"(?:answer|result|value)[^0-9\n.]*?((?:\([-+]?\))?(?:\d+(?:\.\d+)?(?:/\d+)?))",
        # Special cases for pi
        pi_only_pattern,
        r"(?:answer|result|value|=)\s*(?:is\s*)?(\d+(?:\.\d+)?π)",
        r"(?:answer|result|value|=)\s*(?:is\s*)?π(?:\s*=\s*)?(?:≈\s*)?(3\.14\d*)",
        # Numerical answers with units
        r"(?:answer|result|value|=)\s*(?:is\s*)?([-+]?\d+(?:\.\d+)?)\s*(?:meters|feet|kg|seconds)",
        # LaTeX patterns
        r"\$x\s*=\s*([-+]?\d+(?:\.\d+)?(?:/\d+)?)\$",  # LaTeX: $x = 4$
        # Decimal approximations
        r"(?:approximately|about|≈|~)\s*([-+]?\d+\.\d+)",
    ]

    # The search is case-insensitive, so a single pass over the original text
    # is enough (a second pass over text.lower() cannot find anything new).
    for pattern in answer_patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match is None:
            continue

        if pattern == pi_only_pattern:
            return pi_approximation  # Standard pi approximation

        if not match.groups():
            continue

        # Clean up any trailing punctuation
        result = re.sub(r"[.,;:]$", "", match.group(1).strip())

        # The generic numeric patterns above capture only the digits of "2π",
        # so also look at what immediately follows the captured number.
        mentions_pi = "π" in result or "pi" in result.lower()
        trailing_pi = re.match(r"\s*(?:π|pi\b)", text[match.end(1):], re.IGNORECASE)
        if not (mentions_pi or trailing_pi):
            return result

        coefficient = result.replace("π", "").replace("Pi", "").replace("pi", "").strip()
        if coefficient in ("", "1"):
            return pi_approximation  # π alone or 1π
        try:
            return str(float(coefficient) * 3.14159)
        except (ValueError, TypeError):
            # Not a plain decimal (e.g. a fraction): return it without pi
            return coefficient

    # Check for answers in the last lines (common in math problems),
    # starting from the very last one
    lines = text.strip().split("\n")
    for raw_line in reversed(lines[-3:]):
        line = raw_line.strip()
        lowered_line = line.lower()
        if any(keyword in lowered_line for keyword in ("answer", "result", "solution")):
            numbers = re.findall(r"[-+]?\d+(?:\.\d+)?", line)
            if numbers:
                return numbers[-1]  # Take the last number

    # Direct search for numbers that might be answers
    # Only use as a fallback for short responses with few numbers
    if len(text) < 200:
        # Lookarounds (instead of consumed delimiters) so that two numbers
        # separated by a single character are both found.
        numbers = re.findall(r"(?<!\w)([-+]?\d+(?:\.\d+)?)(?!\w)", text)
        if len(numbers) == 1:  # If there's only one number, it's likely the answer
            return numbers[0]
        if numbers and len(text.split()) < 30:  # Very short text with numbers
            return numbers[-1]

    # Look for capitalized city names or other proper nouns as answers
    if re.search(r"capital|city|country|president|largest|smallest", text.lower()):
        noun_match = re.search(r"is\s+([A-Z][a-zA-Z]+(?:\s+[A-Z][a-zA-Z]+)*)", text)
        if noun_match:
            return noun_match.group(1).strip()

    # Look for LaTeX math expressions
    latex_patterns = [
        r"\$x\s*=\s*([^$]+)\$",  # Inline math with x = ...
        r"\$([^$]+)\$",  # Inline math: $...$
        r"\\\((.*?)\\\)",  # Inline math: \(...\)
        r"\\\[(.*?)\\\]",  # Display math: \[...\]
    ]

    for pattern in latex_patterns:
        matches = re.findall(pattern, text)
        if not matches:
            continue

        # Process the last match which is often the final answer
        latex_expr = matches[-1].strip()

        # If there's an equals sign, take what's on the right
        if "=" in latex_expr:
            latex_expr = latex_expr.split("=")[-1].strip()

        # Extract plain numbers from LaTeX expression
        nums = re.findall(r"[-+]?\d+(?:\.\d+)?", latex_expr)
        if nums:
            return nums[-1]

        # If no plain numbers, return the cleaned LaTeX
        return re.sub(r"[\\{}\[\]]", "", latex_expr)

    # If we've reached here, try a more aggressive approach for common words.
    # Whole-word matching keeps "no" from matching inside "know" or "not".
    for word in ["Paris", "London", "yes", "no", "true", "false"]:
        if re.search(rf"\b{word}\b", text, re.IGNORECASE):
            return word

    # Fall back to normalized text for short texts
    if len(text) < 50:
        return normalize_text(text)
    return ""
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def compare_math_expressions(pred: str, gt: str) -> float:
    """
    Compare two mathematical expressions for equivalence.

    Both inputs are first reduced to a compact form that keeps decimal points
    (unlike ``normalize_text``, which strips them). When both forms parse as a
    decimal number or a simple ``a/b`` fraction they are compared numerically
    with graded tolerances; otherwise the comparison falls back to normalized
    string equality, substring containment (only for a non-numeric ground
    truth longer than two characters) and finally word-set similarity.

    Args:
        pred: Predicted math expression
        gt: Ground truth math expression

    Returns:
        Similarity score between 0.0 and 1.0
    """
    if not pred and not gt:
        return 1.0
    if not pred or not gt:
        return 0.0

    pred_clean = _numeric_form(pred)
    gt_clean = _numeric_form(gt)

    # Two approximations of pi written to different precision count as equal.
    if pred_clean.startswith("3.14") and gt_clean.startswith("3.14"):
        return 1.0

    pred_value = _parse_number(pred_clean)
    gt_value = _parse_number(gt_clean)

    if pred_value is not None and gt_value is not None:
        # Identical text is always a match (this also covers "nan" == "nan",
        # which a float comparison would reject).
        if pred_clean == gt_clean:
            return 1.0
        plain_decimals = "/" not in pred_clean and "/" not in gt_clean
        return _score_numeric_pair(pred_value, gt_value, plain_decimals)

    pred_norm = normalize_text(pred)
    gt_norm = normalize_text(gt)

    if pred_norm == gt_norm:
        return 1.0

    # Substring containment is only meaningful for textual answers; for a
    # numeric ground truth it would accept "1" as a match for "-10" or "1/2".
    if gt_value is None and len(gt) > 2:
        pred_lower, gt_lower = pred.lower(), gt.lower()
        if gt_lower in pred_lower or pred_lower in gt_lower:
            return 1.0

    return string_similarity(pred_norm, gt_norm)


def _numeric_form(expr: str) -> str:
    """Lowercase *expr* and drop whitespace/punctuation, keeping decimal points.

    A "." survives only when it is followed by a digit and not preceded by a
    letter or underscore, i.e. when it can be a decimal point.
    """
    text = expr.lower()
    text = re.sub(r"\.(?!\d)|(?<=[^\W\d])\.", "", text)
    return re.sub(r"[^\w+\-./*=]", "", text)


def _parse_number(expr: str) -> Optional[float]:
    """Parse a decimal or a single ``a/b`` fraction; return None if it is neither."""
    if expr.count("/") == 1:
        numerator, denominator = expr.split("/")
        try:
            return float(numerator) / float(denominator)
        except (ValueError, ZeroDivisionError):
            return None
    try:
        return float(expr)
    except ValueError:
        return None


def _score_numeric_pair(pred_value: float, gt_value: float, plain_decimals: bool) -> float:
    """Grade how close *pred_value* is to *gt_value* on a 0.0-1.0 scale.

    ``plain_decimals`` is True when neither operand was written as a fraction;
    only then are the decimal-prefix and 5% shortcuts applied.
    """
    if pred_value == gt_value:
        return 1.0

    abs_error = abs(pred_value - gt_value)

    if plain_decimals:
        pred_decimals = str(pred_value).partition(".")[2]
        gt_decimals = str(gt_value).partition(".")[2]
        # Same first two decimals and a tiny difference: same number written
        # with different precision.
        if (
            len(pred_decimals) >= 2
            and len(gt_decimals) >= 2
            and pred_decimals[:2] == gt_decimals[:2]
            and abs_error < 0.01
        ):
            return 1.0
        # The 0.001 floor keeps the ratio defined when the ground truth is 0.
        if abs_error / max(abs(gt_value), 0.001) < 0.05:
            return 0.9

    # Absolute tolerance shrinks with the magnitude of the ground truth.
    if abs(gt_value) < 0.1:
        abs_tolerance = 0.001
    elif abs(gt_value) < 1.0:
        abs_tolerance = 0.01
    else:
        abs_tolerance = 0.1

    if abs_error <= abs_tolerance:
        return 1.0

    if gt_value != 0:
        relative_error = abs_error / abs(gt_value)
        for bound, score in ((0.001, 1.0), (0.01, 0.9), (0.05, 0.8), (0.1, 0.5), (0.3, 0.3)):
            if relative_error < bound:
                return score
        return 0.0

    if abs_error < 0.01:
        return 1.0
    if abs_error < 0.1:
        return 0.5
    return 0.0
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def string_similarity(s1: str, s2: str) -> float:
    """
    Jaccard similarity of the whitespace-separated word sets of two strings.

    Args:
        s1: First string
        s2: Second string

    Returns:
        1.0 when both strings are empty or contain no words, 0.0 when exactly
        one string is empty, otherwise |shared words| / |all distinct words|
    """
    if not (s1 or s2):
        return 1.0
    if not (s1 and s2):
        return 0.0

    tokens_a = set(s1.split())
    tokens_b = set(s2.split())
    all_tokens = tokens_a | tokens_b

    # Both strings were whitespace only: nothing to tell them apart.
    if not all_tokens:
        return 1.0

    return len(tokens_a & tokens_b) / len(all_tokens)
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
@reward_function
def accuracy_reward(
    messages: Union[List[Message], List[Dict[str, Any]]],
    ground_truth: Union[List[Message], List[Dict[str, Any]]],
    extract_fn: Optional[Callable[[str], str]] = None,
    compare_fn: Optional[Callable[[str, str], float]] = None,
    **kwargs: Any,
) -> EvaluateResult:
    """
    Score the last assistant message against the first ground-truth message.

    The answer is pulled out of the response with ``extract_fn`` (default:
    ``extract_math_expression``) and compared with the ground-truth content
    using ``compare_fn`` (default: ``compare_math_expressions``).

    Args:
        messages: Conversation; the last entry must be an assistant message
            (``Message`` or dict) with non-None content
        ground_truth: Non-empty list whose first entry (``Message`` or dict)
            carries the expected answer in its content
        extract_fn: Optional function to extract the answer from the response
        compare_fn: Optional function returning a similarity for
            (extracted answer, ground truth)
        **kwargs: Additional arguments (unused here)

    Returns:
        EvaluateResult whose score is the similarity, with "answer_extraction"
        and "answer_accuracy" metrics. Invalid inputs yield a 0.0 score with a
        single invalid "accuracy" metric instead.
    """
    if not messages:
        return _accuracy_input_error("No messages provided.", "No messages provided.")

    # --- Response text: content of the final message, which must be the assistant's ---
    final_message = messages[-1]
    if isinstance(final_message, Message):
        if final_message.role != "assistant" or final_message.content is None:
            return _accuracy_input_error(
                "Last message not valid assistant response.",
                "Invalid assistant response.",
            )
        model_response_text = final_message.content
    elif isinstance(final_message, dict):
        if final_message.get("role") != "assistant" or final_message.get("content") is None:
            return _accuracy_input_error(
                "Last message not valid assistant response (dict).",
                "Invalid assistant response (dict).",
            )
        model_response_text = final_message.get("content", "")
    else:
        return _accuracy_input_error(
            f"Unexpected type for last message: {type(final_message)}.",
            "Invalid message type.",
        )

    # --- Expected text: content of the first ground-truth message ---
    if not (ground_truth and isinstance(ground_truth, list)):
        return _accuracy_input_error(
            "Ground truth not provided/invalid.",
            "Invalid ground truth format.",
        )

    expected_message = ground_truth[0]
    if isinstance(expected_message, Message):
        if expected_message.content is None:
            return _accuracy_input_error(
                "First GT message has no content.",
                "Ground truth content missing.",
            )
        ground_truth_comparison_text = expected_message.content
    elif isinstance(expected_message, dict):
        if expected_message.get("content") is None:
            return _accuracy_input_error(
                "First GT message (dict) has no content.",
                "GT content missing (dict).",
            )
        ground_truth_comparison_text = expected_message.get("content", "")
    else:
        return _accuracy_input_error(
            f"Unexpected type for first GT message: {type(expected_message)}.",
            "Invalid GT message type.",
        )

    # --- Extract and compare ---
    extractor = extract_fn if extract_fn else extract_math_expression
    extracted_answer = extractor(model_response_text)

    # Nothing could be extracted, but the response literally contains the
    # expected text: treat that text as the extracted answer.
    if not extracted_answer and model_response_text and len(ground_truth_comparison_text) > 2:
        if ground_truth_comparison_text.lower() in model_response_text.lower():
            extracted_answer = ground_truth_comparison_text

    comparer = compare_fn if compare_fn else compare_math_expressions
    similarity_score = comparer(extracted_answer, ground_truth_comparison_text)

    has_extracted = bool(extracted_answer)
    if has_extracted:
        extraction_reason = f"Extracted answer: '{extracted_answer}'"
    else:
        extraction_reason = "Failed to extract answer"

    extraction_metric = MetricResult(
        score=1.0 if has_extracted else 0.0,
        is_score_valid=has_extracted,
        reason=extraction_reason,
    )
    accuracy_metric = MetricResult(
        score=similarity_score,
        # A similarity of at least 0.9 is what counts as a successful answer.
        is_score_valid=similarity_score >= 0.9,
        reason=f"Answer similarity: {similarity_score:.2f}",
    )

    return EvaluateResult(
        score=similarity_score,
        reason=f"Expected: '{ground_truth_comparison_text}', Extracted: '{extracted_answer}', Similarity: {similarity_score:.2f}",
        metrics={"answer_extraction": extraction_metric, "answer_accuracy": accuracy_metric},
    )


def _accuracy_input_error(reason: str, metric_reason: str) -> EvaluateResult:
    """Build the zero-score result returned when the inputs cannot be evaluated."""
    return EvaluateResult(
        score=0.0,
        reason=reason,
        metrics={"accuracy": MetricResult(score=0.0, is_score_valid=False, reason=metric_reason)},
    )
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Reward function that combines accuracy with cosine-scaled length rewards.
|
|
3
|
+
|
|
4
|
+
This module provides a reward function that evaluates both the accuracy of
|
|
5
|
+
model responses and their length efficiency, combining them into a single
|
|
6
|
+
reward score.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import math
|
|
10
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
|
11
|
+
|
|
12
|
+
from ..models import EvaluateResult, Message, MetricResult
|
|
13
|
+
from ..typed_interface import reward_function
|
|
14
|
+
from .accuracy import accuracy_reward
|
|
15
|
+
from .length import count_tokens
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@reward_function
def cosine_scaled_accuracy_length_reward(
    messages: List[Message],
    ground_truth: Optional[List[Message]] = None,
    extract_fn: Optional[Callable[[str], str]] = None,
    compare_fn: Optional[Callable[[str, str], float]] = None,
    max_length: int = 1000,
    min_value_wrong: float = 0.0,
    max_value_wrong: float = 0.3,
    min_value_correct: float = 0.5,
    max_value_correct: float = 1.0,
    token_method: str = "whitespace",
    correctness_weight: float = 0.7,
    length_weight: float = 0.3,
    **kwargs: Any,
) -> EvaluateResult:
    """
    Reward function that combines accuracy with cosine-scaled length rewards.

    Evaluates both the accuracy of the response and its length efficiency,
    combining them into a single score. Shorter correct answers are rewarded
    more than longer ones, while maintaining separation between answers.

    The length score follows a half cosine over ``token_count / max_length``
    (capped at 1.0): for a correct answer it falls from ``max_value_correct``
    to ``min_value_correct`` as the response grows; for a wrong answer it
    rises from ``min_value_wrong`` to ``max_value_wrong``. The final score is
    ``accuracy * correctness_weight + length_score * length_weight``, clamped
    to [0.0, 1.0].

    Args:
        messages: List of conversation messages; the last one must be an
            assistant message with content (accessed via ``.role``/``.content``)
        ground_truth: Expected correct answer
        extract_fn: Optional function to extract answer from text
        compare_fn: Optional function to compare answers
        max_length: Maximum length for scaling (longer responses get penalized);
            must be positive
        min_value_wrong: Minimum reward for wrong answers
        max_value_wrong: Maximum reward for wrong answers
        min_value_correct: Minimum reward for correct answers
        max_value_correct: Maximum reward for correct answers
        token_method: Method to count tokens ('whitespace', 'character', etc)
        correctness_weight: Weight for the accuracy component (default: 0.7)
        length_weight: Weight for the length component (default: 0.3)
        **kwargs: Additional arguments

    Returns:
        EvaluateResult with score combining accuracy and length. Invalid
        inputs (no messages, no assistant content, non-positive max_length)
        yield a 0.0 score with an invalid "combined_reward" metric.
    """
    if not messages:
        return EvaluateResult(
            score=0.0,
            reason="No messages provided",
            metrics={"combined_reward": MetricResult(score=0.0, is_score_valid=False, reason="No messages provided")},
        )

    response = messages[-1]

    if response.role != "assistant" or not response.content:
        return EvaluateResult(
            score=0.0,
            reason="No assistant response found or response has no content",
            metrics={
                "combined_reward": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Message not from assistant or has no content",
                )
            },
        )

    # A zero max_length would divide by zero below and a negative one would
    # produce a negative progress, so reject both up front.
    if max_length <= 0:
        return EvaluateResult(
            score=0.0,
            reason="max_length must be a positive number of tokens",
            metrics={
                "combined_reward": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason=f"Invalid max_length: {max_length}",
                )
            },
        )

    text: str = response.content

    # Step 1: Evaluate accuracy
    accuracy_eval_result = accuracy_reward(
        messages=messages,  # Pass the full messages list
        ground_truth=ground_truth,  # Pass the ground_truth list
        extract_fn=extract_fn,
        compare_fn=compare_fn,
    )

    accuracy_score = accuracy_eval_result.score
    # Ensure answer_accuracy metric exists, provide a default if not
    answer_accuracy_metric = accuracy_eval_result.metrics.get(
        "answer_accuracy",
        MetricResult(score=0.0, is_score_valid=False, reason="Accuracy metric not found"),
    )
    # The validity flag of the accuracy metric is what decides "correct" here.
    accuracy_success = answer_accuracy_metric.is_score_valid
    accuracy_reason = accuracy_eval_result.reason or "No reason from accuracy_reward"

    # Step 2: Calculate length-based score
    token_count = count_tokens(text, method=token_method)

    # Normalize token count relative to max_length
    progress = min(1.0, token_count / max_length)

    # Apply cosine scaling: 1.0 for an empty response down to -1.0 at max_length
    cosine_factor = math.cos(progress * math.pi)

    # Determine reward range based on correctness
    if accuracy_success:
        # For correct answers: shorter is better
        min_value = min_value_correct
        max_value = max_value_correct
    else:
        # For incorrect answers: longer is slightly better (showing work),
        # so the range is deliberately swapped.
        min_value = max_value_wrong
        max_value = min_value_wrong

    # Calculate length-scaled score
    scale_factor = 0.5 * (max_value - min_value) * (1.0 + cosine_factor)
    length_score = min_value + scale_factor

    # Step 3: Calculate combined score (weighted average)
    acc_component = accuracy_score * correctness_weight
    len_component = length_score * length_weight
    combined_score = acc_component + len_component

    # Ensure the combined score is properly bounded
    combined_score = max(0.0, min(1.0, combined_score))

    # Prepare detailed reason
    reward_type = "reward" if accuracy_success else "penalty"
    length_reason = (
        f"Length-based {reward_type}: {token_count}/{max_length} tokens, " f"cosine factor: {cosine_factor:.2f}"
    )

    combined_reason = (
        f"Combined score (acc:{accuracy_score:.2f}*{correctness_weight:.1f} + "
        f"len:{length_score:.2f}*{length_weight:.1f} = {combined_score:.2f}). "
        f"Accuracy: {accuracy_reason}. Length: {length_reason}"
    )

    # Prepare metrics
    within_limit = token_count <= max_length
    metrics = {
        "combined_reward": MetricResult(
            score=combined_score,
            is_score_valid=bool(accuracy_success),
            reason=f"Combined score: {combined_score:.2f}",
        ),
        "accuracy": MetricResult(
            score=accuracy_score,
            is_score_valid=accuracy_success,
            reason=f"Accuracy: {accuracy_score:.2f}",
        ),
        "length": MetricResult(
            score=length_score,
            is_score_valid=within_limit,
            reason=f"Length: {token_count}/{max_length} tokens, score: {length_score:.2f}",  # noqa
        ),
        "token_count": MetricResult(
            score=min(1.0, max(0.0, 1.0 - progress)),
            is_score_valid=within_limit,
            reason=f"Token count: {token_count}/{max_length}",
        ),
    }

    return EvaluateResult(
        score=combined_score,
        reason=combined_reason,
        metrics=metrics,
        is_score_valid=combined_score > 0.0,
    )
|