eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,454 @@
1
+ # pylint: disable=all
2
+ """
3
+ Reward functions for accuracy evaluation.
4
+
5
+ This module provides reward functions that evaluate the accuracy of model responses
6
+ by comparing them with ground truth answers, optionally using preprocessing steps
7
+ like normalization and LaTeX parsing.
8
+ """
9
+
10
+ import re
11
+ from typing import Any, Callable, Dict, List, Optional, Union, cast
12
+
13
+ from ..models import EvaluateResult, Message, MetricResult
14
+ from ..typed_interface import reward_function
15
+
16
+
17
def normalize_text(text: str) -> str:
    """
    Normalize text for comparison by lowercasing, collapsing whitespace and
    removing punctuation.

    Periods are removed except when they sit between two digits, so decimal
    numbers keep their value ("3.5" stays distinct from "35"). Commas are
    always removed (e.g. "1,000" -> "1000"). Brackets are dropped but their
    contents kept, "×" / "÷" are mapped to "*" / "/", and any other symbol
    that is not a word character, whitespace, ".", "+", "-", "*", "/" or "="
    is removed.

    Args:
        text: The text to normalize

    Returns:
        Normalized text string
    """
    text = text.lower()
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r'[,;:!?"\']', "", text)

    # Drop periods unless they are a decimal point (digit on both sides).
    text = re.sub(r"(?<!\d)\.|\.(?!\d)", "", text)

    # Remove parentheses, brackets, etc. that often appear in math expressions
    # but keep their contents.
    text = re.sub(r"[()\[\]{}]", "", text)

    # Map the typographic operators *before* the symbol filter below, which
    # would otherwise delete them and fuse the operands ("3×4" -> "34").
    text = text.replace("×", "*").replace("÷", "/")

    # Keep word characters, whitespace, decimal points and basic operators.
    text = re.sub(r"[^\w\s.+\-*/=]", "", text)

    return text.strip()
38
+
39
+
40
def extract_math_expression(text: str) -> str:
    """
    Extract the most likely final answer from a (mathematical) response.

    Heuristics are tried in order; the first one that yields a value wins:

    1. Keyword/equation patterns such as "answer is 42", "therefore 7",
       "x = 4", "= 3/4" or "approximately 2.5". A bare "answer is π" /
       "= pi" returns "3.14159"; a captured "Nπ" coefficient is multiplied
       by 3.14159.
    2. The last number on one of the last three lines that mentions
       "answer", "result" or "solution".
    3. For texts shorter than 200 characters: the only standalone number,
       or the last standalone number when the text has fewer than 30 words.
    4. For capital/city/country/president/largest/smallest texts: a
       capitalized phrase following the word "is".
    5. LaTeX math spans ("$...$" and the inline/display delimiters): the
       last span, reduced to the right-hand side of any "=", then its last
       number, else the span with LaTeX syntax characters stripped.
    6. A whole-word, case-insensitive occurrence of "Paris", "London",
       "yes", "no", "true" or "false".
    7. ``normalize_text(text)`` for texts shorter than 50 characters.

    Args:
        text: Text that might contain mathematical expressions

    Returns:
        The extracted answer string, or "" when nothing could be extracted
        from a text of 50 or more characters.
    """
    # "pi" needs a word boundary so words like "pizza" or "pink" are not
    # mistaken for the constant pi.
    pi_only_pattern = r"(?:answer|result|value|=)\s*(?:is\s*)?(?:π|pi\b)"

    # Try to find answer patterns like "= 42" or "answer is 42"
    answer_patterns = [
        # Common exact answer formats
        r"(?:answer|result|solution)(?:\s+is|\s*[:=])\s*(?:x\s*=\s*)?([-+]?\d+(?:\.\d+)?(?:/\d+)?)",
        r"(?:therefore|thus|so)[,:]?\s*(?:x\s*=\s*)?([-+]?\d+(?:\.\d+)?(?:/\d+)?)",
        r"(?:the value of|value)\s*(?:x|y|z)\s*(?:is|=)\s*([-+]?\d+(?:\.\d+)?(?:/\d+)?)",
        r"x\s*=\s*([-+]?\d+(?:\.\d+)?(?:/\d+)?)",  # x = 4
        r"(?:=|equals)\s*([-+]?\d+(?:\.\d+)?(?:/\d+)?)",
        # Common answer formats with parentheses
        r"(?:answer|result|solution)[^0-9\n.]*?is[^0-9\n.]*?((?:\([-+]?\))?(?:\d+(?:\.\d+)?(?:/\d+)?))",
        r"(?:answer|result|value)[^0-9\n.]*?((?:\([-+]?\))?(?:\d+(?:\.\d+)?(?:/\d+)?))",
        # Special cases for pi
        pi_only_pattern,
        r"(?:answer|result|value|=)\s*(?:is\s*)?(\d+(?:\.\d+)?π)",
        r"(?:answer|result|value|=)\s*(?:is\s*)?π(?:\s*=\s*)?(?:≈\s*)?(3\.14\d*)",
        # Numerical answers with units
        r"(?:answer|result|value|=)\s*(?:is\s*)?([-+]?\d+(?:\.\d+)?)\s*(?:meters|feet|kg|seconds)",
        # LaTeX patterns
        r"\$x\s*=\s*([-+]?\d+(?:\.\d+)?(?:/\d+)?)\$",  # LaTeX: $x = 4$
        # Decimal approximations
        r"(?:approximately|about|≈|~)\s*([-+]?\d+\.\d+)",
    ]

    # re.IGNORECASE already handles case differences, so one pass over the
    # original text is enough (a second pass over text.lower() adds nothing).
    for pattern in answer_patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if not match:
            continue
        if pattern == pi_only_pattern:
            return "3.14159"  # Standard pi approximation

        result = match.group(1).strip()
        # Clean up any trailing punctuation
        result = re.sub(r"[.,;:]$", "", result)

        # Handle pi symbols in the answer (e.g. "2π" -> 2 * pi)
        if "π" in result or "pi" in result.lower():
            result = result.replace("π", "").replace("Pi", "").replace("pi", "")
            try:
                if result.strip() in ("", "1"):
                    return "3.14159"  # π alone or 1π
                coef = float(result.strip())
                return str(coef * 3.14159)
            except (ValueError, TypeError):
                # If conversion fails, return the stripped coefficient text
                return result

        return result

    # Check the last three lines (latest first) for an explicit answer line.
    for line in reversed(text.strip().split("\n")[-3:]):
        line = line.strip()
        lowered = line.lower()
        if "answer" in lowered or "result" in lowered or "solution" in lowered:
            numbers = re.findall(r"[-+]?\d+(?:\.\d+)?", line)
            if numbers:
                return numbers[-1]  # Take the last number

    # Fallback for short responses: pick a standalone number.
    if len(text) < 200:
        # Lookarounds (rather than consuming the surrounding delimiters) so
        # that adjacent numbers such as "7 8" are all found.
        numbers = re.findall(r"(?<!\w)([-+]?\d+(?:\.\d+)?)(?!\w)", text)
        if len(numbers) == 1:  # A single number is likely the answer
            return numbers[0]
        if numbers and len(text.split()) < 30:  # Very short text: take the last number
            return numbers[-1]

    # Look for capitalized city names or other proper nouns as answers
    if re.search(r"capital|city|country|president|largest|smallest", text.lower()):
        match = re.search(r"\bis\s+([A-Z][a-zA-Z]+(?:\s+[A-Z][a-zA-Z]+)*)", text)
        if match:
            return match.group(1).strip()

    # Look for LaTeX math expressions
    latex_patterns = [
        r"\$x\s*=\s*([^$]+)\$",  # Inline math with x = ...
        r"\$([^$]+)\$",  # Inline math: $...$
        r"\\\((.*?)\\\)",  # Inline math: \(...\)
        r"\\\[(.*?)\\\]",  # Display math: \[...\]
    ]

    for pattern in latex_patterns:
        matches = re.findall(pattern, text)
        if matches:
            # The last match is often the final answer
            latex_expr = matches[-1].strip()

            # If there's an equals sign, take what's on the right
            if "=" in latex_expr:
                latex_expr = latex_expr.split("=")[-1].strip()

            nums = re.findall(r"[-+]?\d+(?:\.\d+)?", latex_expr)
            if nums:
                return nums[-1]

            # If no plain numbers, return the LaTeX with syntax chars removed
            return re.sub(r"[\\{}\[\]]", "", latex_expr)

    # Common one-word answers; whole-word matching so that e.g. "know" does
    # not count as "no" and "eyes" does not count as "yes".
    for word in ["Paris", "London", "yes", "no", "true", "false"]:
        if re.search(rf"\b{re.escape(word)}\b", text, re.IGNORECASE):
            return word

    # Fall back to normalized text for short texts
    if len(text) < 50:
        return normalize_text(text)
    return ""
173
+
174
+
175
def _parse_numeric_answer(text: str) -> Optional[float]:
    """
    Parse ``text`` as a plain decimal number or a simple ``a/b`` fraction.

    Whitespace and commas (thousands separators) are ignored and trailing
    periods are dropped, so "1,000" -> 1000.0 and "42." -> 42.0.

    Returns:
        The numeric value, or None when the text is not a plain number or
        fraction, or when the fraction's denominator is zero.
    """
    compact = re.sub(r"[\s,]", "", text).rstrip(".")
    number = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)"
    if re.fullmatch(number, compact):
        return float(compact)
    fraction = re.fullmatch(rf"({number})/({number})", compact)
    if fraction is None:
        return None
    denominator = float(fraction.group(2))
    if denominator == 0:
        return None
    return float(fraction.group(1)) / denominator


def _numeric_similarity(pred_value: float, gt_value: float) -> float:
    """
    Grade two numbers on a tolerance ladder.

    Exact matches, or matches within an absolute tolerance (0.001 for
    |gt| < 0.1, 0.01 for |gt| < 1, otherwise 0.1), score 1.0. Otherwise the
    relative error maps to 1.0 / 0.9 / 0.8 / 0.5 / 0.3 / 0.0 at thresholds
    0.1% / 1% / 5% / 10% / 30%. A ground truth of zero uses absolute error
    instead: < 0.01 -> 1.0, < 0.1 -> 0.5, else 0.0.
    """
    if pred_value == gt_value:
        return 1.0

    abs_error = abs(pred_value - gt_value)
    if abs(gt_value) < 0.1:
        abs_tolerance = 0.001
    elif abs(gt_value) < 1.0:
        abs_tolerance = 0.01
    else:
        abs_tolerance = 0.1
    if abs_error <= abs_tolerance:
        return 1.0

    if gt_value == 0:
        if abs_error < 0.01:
            return 1.0
        if abs_error < 0.1:
            return 0.5
        return 0.0

    relative_error = abs_error / abs(gt_value)
    for threshold, score in ((0.001, 1.0), (0.01, 0.9), (0.05, 0.8), (0.1, 0.5), (0.3, 0.3)):
        if relative_error < threshold:
            return score
    return 0.0


def compare_math_expressions(pred: str, gt: str) -> float:
    """
    Compare two mathematical expressions for equivalence.

    If both sides are plain numbers or simple fractions they are compared
    numerically (see ``_numeric_similarity``). Otherwise, in order:
    normalized-text equality, case-insensitive substring containment for
    non-numeric ground truths, a loose "3.14..." match for pi
    approximations, numeric comparison of the normalized forms, and finally
    word-level Jaccard similarity.

    Args:
        pred: Predicted math expression
        gt: Ground truth math expression

    Returns:
        Similarity score between 0.0 and 1.0
    """
    if not pred and not gt:
        return 1.0
    if not pred or not gt:
        return 0.0

    # Compare plain numbers/fractions before any text normalization, which
    # strips punctuation and could make distinct values such as "3.5" and
    # "35" look identical.
    pred_number = _parse_numeric_answer(pred)
    gt_number = _parse_numeric_answer(gt)
    if pred_number is not None and gt_number is not None:
        return _numeric_similarity(pred_number, gt_number)

    pred_norm = normalize_text(pred)
    gt_norm = normalize_text(gt)
    if pred_norm == gt_norm:
        return 1.0

    # Substring containment for textual answers (e.g. "Paris" vs
    # "Paris, France"). Skipped for numeric ground truths; the contained
    # string must be longer than two characters, so a fragment such as "a"
    # cannot match "Paris".
    if len(gt) > 2 and not gt.replace(".", "").isdigit():
        pred_lower, gt_lower = pred.lower(), gt.lower()
        if gt_lower in pred_lower or (len(pred) > 2 and pred_lower in gt_lower):
            return 1.0

    pred_clean = pred_norm.replace(" ", "")
    gt_clean = gt_norm.replace(" ", "")

    # Loose match for approximations of pi written with extra text.
    if pred_clean.startswith("3.14") and gt_clean.startswith("3.14"):
        return 1.0

    pred_value = _parse_numeric_answer(pred_clean)
    gt_value = _parse_numeric_answer(gt_clean)
    if pred_value is None or gt_value is None:
        return string_similarity(pred_norm, gt_norm)
    return _numeric_similarity(pred_value, gt_value)
303
+
304
+
305
def string_similarity(s1: str, s2: str) -> float:
    """
    Jaccard similarity between the whitespace-separated word sets of two strings.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        1.0 when both strings are empty (or both contain only whitespace),
        0.0 when exactly one string is empty, otherwise
        |shared words| / |all distinct words|.
    """
    if not (s1 or s2):
        return 1.0
    if not (s1 and s2):
        return 0.0

    tokens_a = set(s1.split())
    tokens_b = set(s2.split())
    if not (tokens_a or tokens_b):
        return 1.0

    shared = tokens_a & tokens_b
    combined = tokens_a | tokens_b
    return len(shared) / len(combined) if combined else 0.0
316
+
317
+
318
def _accuracy_failure(reason: str, metric_reason: str) -> EvaluateResult:
    """Build the zero-score result returned by ``accuracy_reward`` for invalid input."""
    return EvaluateResult(
        score=0.0,
        reason=reason,
        metrics={"accuracy": MetricResult(score=0.0, is_score_valid=False, reason=metric_reason)},
    )


@reward_function
def accuracy_reward(
    messages: Union[List[Message], List[Dict[str, Any]]],
    ground_truth: Union[List[Message], List[Dict[str, Any]]],
    extract_fn: Optional[Callable[[str], str]] = None,
    compare_fn: Optional[Callable[[str, str], float]] = None,
    **kwargs: Any,
) -> EvaluateResult:
    """
    Score how closely the final assistant message matches the ground-truth answer.

    An answer is extracted from the last message with ``extract_fn``
    (default: ``extract_math_expression``) and compared with the content of
    the first ground-truth message using ``compare_fn`` (default:
    ``compare_math_expressions``). If extraction yields nothing, but the
    ground-truth text is longer than two characters and appears in the
    response (case-insensitively), the ground-truth text itself is used as
    the extracted answer.

    Args:
        messages: Conversation. The last entry must be an assistant message
            (``Message`` or dict) with non-None content.
        ground_truth: List whose first entry (``Message`` or dict) holds the
            expected answer as its content.
        extract_fn: Optional callable mapping response text to an answer string.
        compare_fn: Optional callable returning a similarity score for
            (extracted answer, expected answer).
        **kwargs: Not used by this function body.

    Returns:
        EvaluateResult whose score is the similarity. The ``answer_accuracy``
        metric is marked valid when the similarity is at least 0.9. Invalid
        input yields a 0.0 score with a single invalid ``accuracy`` metric.
    """
    if not messages:
        return _accuracy_failure("No messages provided.", "No messages provided.")

    # --- Response text from the last message ---------------------------
    last = messages[-1]
    if isinstance(last, Message):
        if last.role != "assistant" or last.content is None:
            return _accuracy_failure("Last message not valid assistant response.", "Invalid assistant response.")
        response_text = last.content
    elif isinstance(last, dict):
        if last.get("role") != "assistant" or last.get("content") is None:
            return _accuracy_failure(
                "Last message not valid assistant response (dict).",
                "Invalid assistant response (dict).",
            )
        response_text = last.get("content", "")
    else:
        return _accuracy_failure(f"Unexpected type for last message: {type(last)}.", "Invalid message type.")

    # --- Expected text from the first ground-truth message ---------------
    if not ground_truth or not isinstance(ground_truth, list):
        return _accuracy_failure("Ground truth not provided/invalid.", "Invalid ground truth format.")

    reference = ground_truth[0]
    if isinstance(reference, Message):
        if reference.content is None:
            return _accuracy_failure("First GT message has no content.", "Ground truth content missing.")
        expected_text = reference.content
    elif isinstance(reference, dict):
        if reference.get("content") is None:
            return _accuracy_failure("First GT message (dict) has no content.", "GT content missing (dict).")
        expected_text = reference.get("content", "")
    else:
        return _accuracy_failure(
            f"Unexpected type for first GT message: {type(reference)}.",
            "Invalid GT message type.",
        )

    # --- Extract and compare ---------------------------------------------
    extract = extract_fn or extract_math_expression
    compare = compare_fn or compare_math_expressions

    extracted = extract(response_text)
    gt_mentioned = len(expected_text) > 2 and expected_text.lower() in response_text.lower()
    if not extracted and response_text and gt_mentioned:
        extracted = expected_text

    has_extracted = bool(extracted)
    similarity = compare(extracted, expected_text)
    success = similarity >= 0.9

    extraction_reason = f"Extracted answer: '{extracted}'" if has_extracted else "Failed to extract answer"
    metrics = {
        "answer_extraction": MetricResult(
            score=1.0 if has_extracted else 0.0,
            is_score_valid=has_extracted,
            reason=extraction_reason,
        ),
        "answer_accuracy": MetricResult(
            score=similarity,
            is_score_valid=success,
            reason=f"Answer similarity: {similarity:.2f}",
        ),
    }
    summary = f"Expected: '{expected_text}', Extracted: '{extracted}', Similarity: {similarity:.2f}"
    return EvaluateResult(score=similarity, reason=summary, metrics=metrics)
@@ -0,0 +1,173 @@
1
+ """
2
+ Reward function that combines accuracy with cosine-scaled length rewards.
3
+
4
+ This module provides a reward function that evaluates both the accuracy of
5
+ model responses and their length efficiency, combining them into a single
6
+ reward score.
7
+ """
8
+
9
+ import math
10
+ from typing import Any, Callable, Dict, List, Optional, Union
11
+
12
+ from ..models import EvaluateResult, Message, MetricResult
13
+ from ..typed_interface import reward_function
14
+ from .accuracy import accuracy_reward
15
+ from .length import count_tokens
16
+
17
+
18
@reward_function
def cosine_scaled_accuracy_length_reward(
    messages: List[Message],
    ground_truth: Optional[List[Message]] = None,
    extract_fn: Optional[Callable[[str], str]] = None,
    compare_fn: Optional[Callable[[str, str], float]] = None,
    max_length: int = 1000,
    min_value_wrong: float = 0.0,
    max_value_wrong: float = 0.3,
    min_value_correct: float = 0.5,
    max_value_correct: float = 1.0,
    token_method: str = "whitespace",
    correctness_weight: float = 0.7,
    length_weight: float = 0.3,
    **kwargs: Any,
) -> EvaluateResult:
    """
    Reward function that combines accuracy with cosine-scaled length rewards.

    Accuracy comes from ``accuracy_reward``. The length score follows a
    cosine curve over ``progress = min(1, token_count / max_length)``:

    * correct answers (``answer_accuracy`` metric valid) go from
      ``max_value_correct`` at zero tokens down to ``min_value_correct`` at
      ``max_length`` tokens, so shorter correct answers earn more;
    * incorrect answers use the inverted range, going from
      ``min_value_wrong`` up to ``max_value_wrong``, so longer wrong answers
      (showing work) earn slightly more.

    The final score is ``accuracy * correctness_weight + length_score *
    length_weight``, clamped to [0, 1].

    Args:
        messages: List of conversation messages; the last one must be an
            assistant message with non-empty content (``Message`` or dict).
        ground_truth: Expected correct answer
        extract_fn: Optional function to extract answer from text
        compare_fn: Optional function to compare answers
        max_length: Token count at which the length curve bottoms out. A
            non-positive value treats every response as at the limit.
        min_value_wrong: Minimum reward for wrong answers
        max_value_wrong: Maximum reward for wrong answers
        min_value_correct: Minimum reward for correct answers
        max_value_correct: Maximum reward for correct answers
        token_method: Method to count tokens ('whitespace', 'character', etc)
        correctness_weight: Weight for the accuracy component (default: 0.7)
        length_weight: Weight for the length component (default: 0.3)
        **kwargs: Additional arguments (not used by this function body)

    Returns:
        EvaluateResult with score combining accuracy and length
    """
    if not messages:
        return EvaluateResult(
            score=0.0,
            reason="No messages provided",
            metrics={"combined_reward": MetricResult(score=0.0, is_score_valid=False, reason="No messages provided")},
        )

    response = messages[-1]
    # Accept plain dict messages as well as Message objects, mirroring
    # accuracy_reward (which handles both); attribute access alone would
    # raise AttributeError on a dict.
    if isinstance(response, dict):
        role = response.get("role")
        content = response.get("content")
    else:
        role = response.role
        content = response.content

    if role != "assistant" or not content:
        return EvaluateResult(
            score=0.0,
            reason="No assistant response found or response has no content",
            metrics={
                "combined_reward": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Message not from assistant or has no content",
                )
            },
        )
    text: str = content

    # Step 1: Evaluate accuracy
    accuracy_eval_result = accuracy_reward(
        messages=messages,  # Pass the full messages list
        ground_truth=ground_truth,  # Pass the ground_truth list
        extract_fn=extract_fn,
        compare_fn=compare_fn,
    )

    accuracy_score = accuracy_eval_result.score
    # Invalid-input results from accuracy_reward carry no "answer_accuracy"
    # metric; treat that as "not correct".
    answer_accuracy_metric = accuracy_eval_result.metrics.get(
        "answer_accuracy",
        MetricResult(score=0.0, is_score_valid=False, reason="Accuracy metric not found"),
    )
    accuracy_success = answer_accuracy_metric.is_score_valid
    accuracy_reason = accuracy_eval_result.reason or "No reason from accuracy_reward"

    # Step 2: Calculate length-based score
    token_count = count_tokens(text, method=token_method)

    # Normalize token count relative to max_length. A non-positive
    # max_length would divide by zero (or flip the sign of progress), so
    # treat every response as having reached the limit.
    progress = min(1.0, token_count / max_length) if max_length > 0 else 1.0

    # cos(0) = 1 at zero tokens, cos(pi) = -1 at/after max_length
    cosine_factor = math.cos(progress * math.pi)

    if accuracy_success:
        # Correct answers: shorter is better
        min_value = min_value_correct
        max_value = max_value_correct
        success = True
    else:
        # Incorrect answers: range deliberately inverted so longer is
        # slightly better (showing work)
        min_value = max_value_wrong
        max_value = min_value_wrong
        success = False

    # Interpolates from max_value (progress 0) to min_value (progress 1)
    scale_factor = 0.5 * (max_value - min_value) * (1.0 + cosine_factor)
    length_score = min_value + scale_factor

    # Step 3: Combine components and clamp to [0, 1]
    acc_component = accuracy_score * correctness_weight
    len_component = length_score * length_weight
    combined_score = max(0.0, min(1.0, acc_component + len_component))

    reward_type = "reward" if accuracy_success else "penalty"
    length_reason = (
        f"Length-based {reward_type}: {token_count}/{max_length} tokens, " f"cosine factor: {cosine_factor:.2f}"
    )

    combined_reason = (
        f"Combined score (acc:{accuracy_score:.2f}*{correctness_weight:.1f} + "
        f"len:{length_score:.2f}*{length_weight:.1f} = {combined_score:.2f}). "
        f"Accuracy: {accuracy_reason}. Length: {length_reason}"
    )

    metrics = {
        "combined_reward": MetricResult(
            score=combined_score,
            is_score_valid=success,
            reason=f"Combined score: {combined_score:.2f}",
        ),
        "accuracy": MetricResult(
            score=accuracy_score,
            is_score_valid=accuracy_success,
            reason=f"Accuracy: {accuracy_score:.2f}",
        ),
        "length": MetricResult(
            score=length_score,
            is_score_valid=token_count <= max_length,
            reason=f"Length: {token_count}/{max_length} tokens, score: {length_score:.2f}",  # noqa
        ),
        "token_count": MetricResult(
            score=min(1.0, max(0.0, 1.0 - progress)),
            is_score_valid=token_count <= max_length,
            reason=f"Token count: {token_count}/{max_length}",
        ),
    }

    return EvaluateResult(
        score=combined_score,
        reason=combined_reason,
        metrics=metrics,
        is_score_valid=combined_score > 0.0,
    )
+ )