eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,762 @@
1
+ """
2
+ Math reward function for evaluating mathematical answer correctness.
3
+
4
+ This module provides functions to evaluate the correctness of mathematical
5
+ answers by extracting numerical values from text using regex patterns and
6
+ comparing them with expected answers.
7
+ """
8
+
9
+ import math
10
+ import re
11
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
12
+
13
+ from ..models import EvaluateResult, Message, MetricResult
14
+ from ..typed_interface import reward_function
15
+
16
# Single letters treated as algebraic variable names; a number immediately
# followed by one of these is considered a coefficient rather than an answer.
_ALGEBRAIC_VARS_SET: Set[str] = set("xyzabcntqpruvw")
32
+
33
+
34
def _parse_numeric_string(s: str) -> Optional[float]:
    """Parse a plain decimal number or an ``a/b`` fraction into a float.

    Args:
        s: Candidate string; surrounding whitespace is ignored.

    Returns:
        The parsed value, or ``None`` when the string is neither a plain
        (optionally negative) decimal nor a fraction of two such decimals,
        or when the fraction's denominator is zero.
    """
    candidate = s.strip()
    try:
        if re.fullmatch(r"-?\d+(\.\d+)?", candidate) is not None:
            return float(candidate)

        fraction = re.fullmatch(r"(-?\d+(?:\.\d+)?)\s*/\s*(-?\d+(?:\.\d+)?)", candidate)
        if fraction is None:
            return None

        numerator = float(fraction.group(1))
        denominator = float(fraction.group(2))
        if denominator == 0:
            return None
        return numerator / denominator
    except (ValueError, ZeroDivisionError):
        return None
47
+
48
+
49
def _is_coefficient(
    text_content: str,
    match_obj: re.Match,
    num_group_idx: int = 1,
    unit_group_idx: Optional[int] = None,
) -> bool:
    """Decide whether a matched number looks like an algebraic coefficient.

    Args:
        text_content: The string that ``match_obj`` was matched against.
        match_obj: Regex match whose group ``num_group_idx`` is the number.
        num_group_idx: Index of the group holding the number.
        unit_group_idx: Index of an optional group holding a trailing unit.

    Returns:
        True when the captured unit is a single letter from
        ``_ALGEBRAIC_VARS_SET``, or when the number is followed (directly or
        after exactly one space) by such a letter that is itself not followed
        by an alphanumeric character; False otherwise.
    """
    # Case 1: the pattern itself captured a one-letter "unit" such as "x".
    if unit_group_idx is not None and unit_group_idx <= len(match_obj.groups()):
        captured_unit = match_obj.group(unit_group_idx)
        if captured_unit:
            unit = captured_unit.strip()
            if len(unit) == 1 and unit.lower() in _ALGEBRAIC_VARS_SET:
                return True

    text_length = len(text_content)

    def lone_variable_at(pos: int) -> bool:
        # A variable letter at `pos` that ends the text or is followed by a
        # non-alphanumeric character (so "5x" qualifies but "5xy" does not).
        if not (pos < text_length and text_content[pos].lower() in _ALGEBRAIC_VARS_SET):
            return False
        return pos + 1 == text_length or not text_content[pos + 1].isalnum()

    after_number = match_obj.end(num_group_idx)

    # Case 2: the letter is glued to the number ("5x").
    if lone_variable_at(after_number):
        return True

    # Case 3: exactly one space separates number and letter ("5 x").
    single_space_follows = after_number + 1 < text_length and text_content[after_number] == " "
    return single_space_follows and lone_variable_at(after_number + 1)
79
+
80
+
81
def _extract_html_tag_answers(text: str) -> List[Tuple[str, Union[float, str]]]:
    """Extract numeric answers wrapped in ``<answer>`` or ``<ans>`` tags.

    Tag names are matched case-insensitively and the tag body may span lines.
    The body is stripped of leading ``$``/``(``/``[`` runs and trailing
    ``$``/``)``/``]`` characters, then interpreted, in order, as: a plain
    number or ``a/b`` fraction, a LaTeX ``\\frac{a}{b}``, a number with
    thousands separators and/or an exponent, or a number followed by a unit.

    Args:
        text: Text to scan.

    Returns:
        ``(full tag text, value)`` pairs for every tag whose body could be
        interpreted; tags with uninterpretable bodies are skipped.
    """
    tag_pattern = re.compile(
        r"<(?P<tag>answer|ans)\b[^>]*>(?P<inner>.*?)</(?P=tag)>",
        re.IGNORECASE | re.DOTALL,
    )

    def interpret(inner: str) -> Optional[float]:
        # Try each supported notation in priority order; first success wins.
        plain_value = _parse_numeric_string(inner)
        if plain_value is not None:
            return plain_value

        latex_fraction = re.fullmatch(r"\\frac\{(-?\d+(?:\.\d+)?)\}\{(-?\d+(?:\.\d+)?)\}", inner)
        if latex_fraction:
            try:
                numerator = float(latex_fraction.group(1))
                denominator = float(latex_fraction.group(2))
                if denominator != 0:
                    return numerator / denominator
            except (ValueError, ZeroDivisionError):
                pass

        grouped_or_scientific = re.fullmatch(r"([-+]?\d[\d,]*(?:\.\d+)?(?:[eE][-+]?\d+)?)", inner)
        if grouped_or_scientific:
            try:
                return float(grouped_or_scientific.group(1).replace(",", ""))
            except ValueError:
                pass

        number_with_unit = re.fullmatch(r"(-?\d+(?:\.\d+)?)[ ]*[a-zA-Z%]+", inner)
        if number_with_unit:
            try:
                return float(number_with_unit.group(1))
            except ValueError:
                pass

        return None

    answers: List[Tuple[str, Union[float, str]]] = []
    for tag_match in tag_pattern.finditer(text):
        body = tag_match.group("inner").strip()
        body = re.sub(r"^\$+|^\(+|^\[+|(\$|\)|\])+?$", "", body).strip()
        value = interpret(body)
        if value is not None:
            answers.append((tag_match.group(0), value))
    return answers
125
+
126
+
127
def _extract_boxed_latex_answers(
    text: str,
) -> Tuple[List[Tuple[str, Union[float, str]]], bool]:
    """Extract answers from LaTeX ``\\boxed{...}`` expressions.

    Each non-empty boxed body is interpreted, in order, as: an "A or B"
    alternative (kept verbatim as a string), a single multiple-choice letter
    A-E (upper-cased), a ``\\frac{a}{b}``, a plain number or ``a/b``
    fraction, or a number followed by a unit. Bodies matching none of these
    contribute no answer.

    When several answers were extracted, all are numeric, and all are equal
    to the first within 1e-9, only the first one is kept.

    Args:
        text: Text to scan.

    Returns:
        ``(answers, found_any)`` where ``answers`` holds
        ``(full boxed expression, value)`` pairs and ``found_any`` tells
        whether any ``\\boxed{...}`` expression occurred at all, even one
        that produced no answer.
    """

    def interpret(content: str) -> Optional[Union[float, str]]:
        if " or " in content.lower():
            return content
        if re.fullmatch(r"[A-Ea-e]", content):
            return content.upper()

        latex_fraction = re.fullmatch(r"\\frac\{(-?\d+(?:\.\d+)?)\}\{(-?\d+(?:\.\d+)?)\}", content)
        if latex_fraction:
            try:
                numerator = float(latex_fraction.group(1))
                denominator = float(latex_fraction.group(2))
                if denominator != 0:
                    return numerator / denominator
            except (ValueError, ZeroDivisionError):
                pass

        plain_value = _parse_numeric_string(content)
        if plain_value is not None:
            return plain_value

        number_with_unit = re.fullmatch(r"(-?\d+(?:\.\d+)?)\s*([a-zA-Z%]+)", content)
        if number_with_unit:
            try:
                return float(number_with_unit.group(1))
            except ValueError:
                pass
        return None

    boxed_answers: List[Tuple[str, Union[float, str]]] = []
    found_any_boxed_expr = False
    # The pattern tolerates one level of nested braces inside the box.
    for boxed in re.finditer(r"\\boxed\s*\{\s*((?:[^{}]|\{[^{}]*\})*?)\s*\}", text):
        found_any_boxed_expr = True
        content = boxed.group(1).strip()
        if not content:
            continue
        value = interpret(content)
        if value is not None:
            boxed_answers.append((boxed.group(0), value))

    # Collapse repeated statements of the same numeric answer into one.
    if len(boxed_answers) > 1 and all(isinstance(value, (float, int)) for _, value in boxed_answers):
        reference = float(boxed_answers[0][1])
        if all(
            math.isclose(value, reference, rel_tol=1e-9, abs_tol=1e-9)
            for _, value in boxed_answers[1:]
        ):
            boxed_answers = boxed_answers[:1]

    return boxed_answers, found_any_boxed_expr
193
+
194
+
195
def extract_numbers(text: str) -> List[Tuple[str, Union[float, str]]]:
    """Extract mathematical answers from text, trying sources by priority.

    The sources are consulted in this order and the first applicable one
    determines the result:

    1. ``<answer>``/``<ans>`` tags, if they yield any answer.
    2. ``\\boxed{...}`` expressions, if any such expression is present
       (even when none of them could be interpreted, in which case the
       result is empty).
    3. GSM8K-style ``#### 123`` markers, if they yield any answer.
    4. General numeric / LaTeX-formatted numbers as a fallback.

    Args:
        text: The text to extract answers from.

    Returns:
        A list of ``(original matched string, normalized value)`` tuples;
        values are floats for numbers and strings for multiple-choice
        letters or "A or B" style expressions. Empty when nothing was
        extracted.
    """
    tagged = _extract_html_tag_answers(text)
    if tagged:
        return tagged

    boxed, saw_boxed_expression = _extract_boxed_latex_answers(text)
    if saw_boxed_expression:
        # A boxed expression is authoritative even if it was unparseable.
        return boxed

    return _extract_gsm8k_answers(text) or _extract_general_numeric_answers(text) or []
229
+
230
+
231
def _extract_gsm8k_answers(text: str) -> List[Tuple[str, Union[float, str]]]:
    """Extract answers from GSM8K-style final answer markers (``#### ...``).

    Accepts an optionally negative number written either with thousands
    separators (``1,234,567.8``) or without (``1234567.8``).

    Args:
        text: Text to scan.

    Returns:
        ``(marker text, value)`` pairs, one per ``####`` marker followed by
        a number, in order of appearance.
    """
    # The comma-grouped form must contain at least one ",ddd" group.
    # Otherwise regex alternation would accept just the first three digits of
    # an un-grouped number (e.g. "123" out of "1234") and stop there.
    number_pattern = r"-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?"
    return [
        (marker.group(0), float(marker.group(1).replace(",", "")))
        for marker in re.finditer(rf"####\s*({number_pattern})", text)
    ]
244
+
245
+
246
def _extract_general_numeric_answers(text: str) -> List[Tuple[str, Union[float, str]]]:
    """Extract general numeric or LaTeX-formatted numbers as a fallback.

    Candidates are gathered from two places:

    * Inside ``$...$`` / ``$$...$$`` blocks (blocks consisting solely of a
      ``\\boxed{...}`` are skipped): ``\\frac{a}{b}``, ``a \\times 10^{b}``
      and standalone numbers that are not coefficients.
    * Anywhere in the text: scientific notation, ``a/b`` fractions,
      comma-grouped numbers, decimals and integers, each optionally followed
      by a unit, again skipping coefficients.

    Candidates are then ordered by start position, longer match first, then
    by pattern priority, and overlapping candidates are dropped so each
    region of the text contributes at most one answer.

    Args:
        text: Text to scan.

    Returns:
        ``(matched text, float value)`` pairs in order of appearance; empty
        when no number was found.
    """
    candidates: List[Dict[str, Any]] = []

    def record(matched_text: str, value: float, span: Tuple[int, int], priority: int) -> None:
        candidates.append(
            {"text": matched_text, "value": value, "span": span, "type_priority": priority}
        )

    # ---- Numbers inside LaTeX math blocks --------------------------------
    for latex_block in re.finditer(r"\$\$(.*?)\$\$|\$(.*?)\$", text, re.DOTALL):
        group_idx = 1 if latex_block.group(1) is not None else 2
        content = latex_block.group(group_idx)
        if not content:
            continue
        stripped = content.strip()
        if stripped.startswith("\\boxed{") and stripped.endswith("}"):
            continue
        # Spans are reported relative to `text`, not to the block content.
        offset = latex_block.start(group_idx)

        for frac in re.finditer(r"\\frac\{(-?\d+(?:\.\d+)?)\}\{(-?\d+(?:\.\d+)?)\}", content):
            numerator, denominator = float(frac.group(1)), float(frac.group(2))
            if denominator != 0:
                record(
                    frac.group(0),
                    numerator / denominator,
                    (frac.start(0) + offset, frac.end(0) + offset),
                    1,
                )

        for sci in re.finditer(r"(-?\d+(?:\.\d+)?)\s*\\times\s*10\^\{(.*?)\}", content):
            try:
                value = float(sci.group(1)) * (10 ** float(sci.group(2)))
            except (ValueError, OverflowError):
                # ValueError: the exponent is not a number.
                # OverflowError: float power out of range (e.g. 10^{400}).
                continue
            record(sci.group(0), value, (sci.start(0) + offset, sci.end(0) + offset), 2)

        for plain in re.finditer(r"(?<![a-zA-Z0-9_])(-?\d+(?:\.\d+)?)(?![a-zA-Z0-9_])", content):
            if _is_coefficient(text_content=content, match_obj=plain, num_group_idx=1):
                continue
            record(
                plain.group(1),
                float(plain.group(1)),
                (plain.start(1) + offset, plain.end(1) + offset),
                3,
            )

    # ---- Fractions anywhere in the text ----------------------------------
    frac_pattern = r"(?<!\d/)(?<!\d)(?<!\.)(-?\d+)\s*/\s*(-?\d+)(?!\.\d)(?!\d*/)(?:\s+(?!(?:and|or)\b)([a-zA-Z%]+)\b)?"
    for m in re.finditer(frac_pattern, text):
        if _is_coefficient(text_content=text, match_obj=m, num_group_idx=1, unit_group_idx=3):
            continue
        numerator_text, denominator_text = m.group(1), m.group(2)
        denominator = float(denominator_text)
        if denominator == 0:
            continue
        # Normalised display text: no spaces around "/", single space before unit.
        display_text = f"{numerator_text}/{denominator_text}"
        unit_text = m.group(3) or ""
        if unit_text:
            display_text += f" {unit_text}"
        record(display_text, float(numerator_text) / denominator, m.span(), 5)

    # ---- Numbers with an optional trailing unit --------------------------
    # (priority, pattern); group 1 is the number, group 2 the optional unit.
    unit_suffixed_patterns = (
        # scientific notation, e.g. 1.5e-3
        (4, r"(?<![a-zA-Z0-9_])(-?\d+\.?\d*[eE][-+]?\d+)(?:\s*([a-zA-Z%]+))?"),
        # comma-grouped numbers, e.g. 1,234.5
        (6, r"(?<![a-zA-Z0-9_])(-?\d{1,3}(?:,\d{3})*(?:\.\d+)?)(?:\s*([a-zA-Z%]+))?"),
        # decimals, e.g. 3.14
        (7, r"(?<![a-zA-Z0-9_])(?<!,\d{3})(-?\d+\.\d+)(?!\d*[eE])(?:\s*([a-zA-Z%]+))?"),
        # integers, e.g. 42
        (
            8,
            r"(?<![a-zA-Z0-9_])(?<!\d\.)(-?\d+)(?!\.\d)(?![eE][-+]?\d+)(?!,\d{3})(?!\s*/\s*\d+)(?:\s*([a-zA-Z%]+))?",
        ),
    )
    for priority, pattern in unit_suffixed_patterns:
        for m in re.finditer(pattern, text):
            if _is_coefficient(text_content=text, match_obj=m, num_group_idx=1, unit_group_idx=2):
                continue
            # Only the comma-grouped pattern can contain commas; stripping
            # them is a no-op for the other three.
            record(m.group(0), float(m.group(1).replace(",", "")), m.span(), priority)

    # ---- Resolve overlaps -------------------------------------------------
    # Earliest start first; at equal start the longest match wins, then the
    # lowest priority number.
    candidates.sort(key=lambda c: (c["span"][0], -(c["span"][1] - c["span"][0]), c["type_priority"]))
    answers: List[Tuple[str, Union[float, str]]] = []
    covered_until = -1
    for candidate in candidates:
        start, end = candidate["span"]
        if start >= covered_until:
            answers.append((candidate["text"], float(candidate["value"])))
            covered_until = end
    return answers
416
+
417
+
418
def compare_numbers(
    expected: float,
    actual: float,
    relative_tolerance: float = 1e-5,
    absolute_tolerance: float = 1e-8,
) -> Tuple[bool, float]:
    """Compare two numbers within tolerance and report a similarity score.

    Args:
        expected: Reference value.
        actual: Value under test.
        relative_tolerance: ``rel_tol`` passed to ``math.isclose``; also the
            scale for the relative-error similarity.
        absolute_tolerance: ``abs_tol`` passed to ``math.isclose``; also the
            scale for the similarity when ``expected`` is zero.

    Returns:
        ``(True, 1.0)`` when the values are close. Otherwise
        ``(False, similarity)`` where similarity is
        ``1 - min(1, error / tolerance)`` clamped at 0, using absolute error
        when ``expected`` is zero and relative error otherwise; 0.0 if that
        computation divides by zero or overflows.
    """
    if math.isclose(expected, actual, rel_tol=relative_tolerance, abs_tol=absolute_tolerance):
        return True, 1.0

    try:
        if expected == 0:
            scaled_error = abs(actual) / absolute_tolerance
        else:
            scaled_error = abs((expected - actual) / expected) / relative_tolerance
        similarity = max(0.0, 1.0 - min(1.0, scaled_error))
    except (ZeroDivisionError, OverflowError):
        similarity = 0.0
    return False, similarity
437
+
438
+
439
def _has_unit_text(full_extracted_text: str, numeric_value: float) -> bool:
    """Check whether the extracted text for a number likely contains a unit.

    The number is located in the text by its ``str()`` form (and, for
    integral values, also by its integer form, e.g. "5" for 5.0). The first
    space-delimited token after it counts as a unit unless it is empty,
    looks like another number, or is the word "or".

    Args:
        full_extracted_text: The text the number was extracted from; a
            surrounding ``\\boxed{...}`` wrapper is removed first.
        numeric_value: The extracted value. Non-finite values (inf/nan) are
            accepted and simply have no integer form.

    Returns:
        True if a unit-like token follows the number, otherwise False.
    """
    content = full_extracted_text
    if content.startswith("\\boxed{") and content.endswith("}"):
        content = content[7:-1].strip()

    search_terms = [str(numeric_value)]
    # float.is_integer() is False for inf/nan, so int() is never called on a
    # value it would reject with OverflowError/ValueError.
    if float(numeric_value).is_integer():
        integer_form = str(int(numeric_value))
        if integer_form != search_terms[0]:
            search_terms.append(integer_form)

    for term in search_terms:
        found_at = content.find(term)
        if found_at == -1:
            continue
        # First space-delimited token after the number ("" if nothing follows).
        suffix = content[found_at + len(term):].strip().split(" ")[0]
        if suffix and not suffix.replace(".", "", 1).isdigit() and suffix.lower() != "or":
            return True
    return False
460
+
461
+
462
def _check_unboxed_or_strictness(
    model_response_content: str,
    gen_answers_extracted: List[Tuple[str, Union[float, str]]],
    metrics: Dict[str, MetricResult],
) -> Optional[EvaluateResult]:
    """Check for the "unboxed or" strictness violation.

    The violation occurs when the raw response contains " or ", more than
    one numeric answer can be extracted from it, and the already extracted
    answers are not a single string answer that itself contains " or "
    (i.e. the alternatives were not stated as one explicit answer).

    Args:
        model_response_content: Raw text of the model response.
        gen_answers_extracted: Answers currently attributed to the response.
        metrics: Metric dictionary; on violation the entry
            ``strictness_penalty_unboxed_or`` is added to it.

    Returns:
        A zero-score ``EvaluateResult`` on violation, otherwise ``None``.
    """
    # Cheap checks first: the full extraction below is only needed when the
    # response actually contains " or " outside a single explicit answer.
    if " or " not in model_response_content.lower():
        return None

    is_single_explicit_or_answer = (
        len(gen_answers_extracted) == 1
        and isinstance(gen_answers_extracted[0][1], str)
        and " or " in gen_answers_extracted[0][1].lower()
    )
    if is_single_explicit_or_answer:
        return None

    numeric_answer_count = sum(
        1 for _, value in extract_numbers(model_response_content) if isinstance(value, (float, int))
    )
    if numeric_answer_count <= 1:
        return None

    specific_reason_detail = (
        "Generated answer offers multiple numeric alternatives with an unboxed 'or' in the raw response."
    )
    full_reason = f"Strictness fail (Issue #1 - Unboxed 'or'): {specific_reason_detail}"
    metrics["strictness_penalty_unboxed_or"] = MetricResult(
        score=0.0, is_score_valid=False, reason=specific_reason_detail
    )
    return EvaluateResult(score=0.0, reason=full_reason, metrics=metrics)
487
+
488
+
489
def _check_ambiguity_strictness(
    orig_answers_extracted: List[Tuple[str, Union[float, str]]],
    gen_answers_extracted: List[Tuple[str, Union[float, str]]],
    metrics: Dict[str, MetricResult],
) -> Optional[EvaluateResult]:
    """Check for the ambiguity strictness violation.

    The violation occurs when the ground truth holds exactly one answer but
    more than one answer was extracted from the generated response.

    Args:
        orig_answers_extracted: Answers extracted from the ground truth.
        gen_answers_extracted: Answers extracted from the generated response.
        metrics: Metric dictionary; on violation the entry
            ``strictness_penalty_ambiguity`` is added to it.

    Returns:
        A zero-score ``EvaluateResult`` on violation, otherwise ``None``.
    """
    truth_is_specific = len(orig_answers_extracted) == 1
    generation_is_ambiguous = len(gen_answers_extracted) > 1
    if not (truth_is_specific and generation_is_ambiguous):
        return None

    specific_reason_detail = (
        "Ground truth is specific (one answer), but generated answer is ambiguous "
        "(multiple answers extracted, even after potential leniency)."
    )
    metrics["strictness_penalty_ambiguity"] = MetricResult(
        score=0.0, is_score_valid=False, reason=specific_reason_detail
    )
    return EvaluateResult(
        score=0.0,
        reason=f"Strictness fail (Issue #2 - Ambiguity): {specific_reason_detail}",
        metrics=metrics,
    )
503
+
504
+
505
def _check_conflicting_answers_strictness(
    orig_answers_extracted: List[Tuple[str, Union[float, str]]],
    gen_answers_extracted: List[Tuple[str, Union[float, str]]],
    best_match_score: float,
    match_found_flag: bool,
    is_single_orig_boxed_truth: bool,
    has_matching_gen_boxed_answer: bool,
    tolerance: float,
    absolute_tolerance: float,
    current_best_reason: str,
    metrics: Dict[str, MetricResult],
) -> Tuple[float, bool, str]:
    """Check for the conflicting-answers strictness violation.

    Only applies when a match was found with a score above 0.75 and the
    boxed-answer leniency (single boxed ground truth matched by a boxed
    generated answer) is not in effect. In that case every numeric generated
    value must be close to some numeric ground-truth value; any that is not
    counts as a conflict.

    Args:
        orig_answers_extracted: Answers extracted from the ground truth.
        gen_answers_extracted: Answers extracted from the generated response.
        best_match_score: Best similarity found so far.
        match_found_flag: Whether the best comparison was a match.
        is_single_orig_boxed_truth: Ground truth is one boxed numeric answer.
        has_matching_gen_boxed_answer: A boxed generated answer matched it.
        tolerance: Relative tolerance for ``math.isclose``.
        absolute_tolerance: Absolute tolerance for ``math.isclose``.
        current_best_reason: Reason text describing the best match.
        metrics: Metric dictionary; on violation the entry
            ``strictness_penalty_conflicting_answers`` is added to it.

    Returns:
        ``(score, match_found, reason)``: the inputs unchanged when there is
        no violation, otherwise ``(0.0, False, <violation reason>)``.
    """
    unchanged = (best_match_score, match_found_flag, current_best_reason)

    if not match_found_flag or not (best_match_score > 0.75):
        return unchanged
    if is_single_orig_boxed_truth and has_matching_gen_boxed_answer:
        return unchanged

    reference_values = [
        orig_value for _, orig_value in orig_answers_extracted if isinstance(orig_value, (float, int))
    ]

    def matches_some_reference(candidate: float) -> bool:
        return any(
            math.isclose(candidate, reference, rel_tol=tolerance, abs_tol=absolute_tolerance)
            for reference in reference_values
        )

    conflicting_values = [
        gen_value
        for _, gen_value in gen_answers_extracted
        if isinstance(gen_value, (float, int)) and not matches_some_reference(gen_value)
    ]
    if not conflicting_values:
        return unchanged

    formatted_conflicting = ", ".join(str(value) for value in sorted(set(conflicting_values)))
    specific_reason_detail = (
        f"Generated answer, while containing a match for an original answer, "
        f"also includes other distinct numerical values not matching any original answer: [{formatted_conflicting}]"
    )
    metrics["strictness_penalty_conflicting_answers"] = MetricResult(
        score=0.0, is_score_valid=False, reason=specific_reason_detail
    )
    violation_reason = (
        f"Strictness fail (Conflicting Answers): {specific_reason_detail}. Initial match was: {current_best_reason}"
    )
    return 0.0, False, violation_reason
556
+
557
+
558
@reward_function
def math_reward(
    messages: List[Message],
    *,
    ground_truth: str,
    tolerance: float = 0.001,
    absolute_tolerance: float = 1e-8,
    require_units: bool = False,
    **kwargs: Any,
) -> EvaluateResult:
    """Score the last assistant message against a ground-truth math answer.

    Answers are extracted from both the response and ``ground_truth`` with
    ``extract_numbers``. After the leniency and strictness checks, every
    ground-truth answer is compared with every generated answer and the best
    similarity becomes the score.

    Args:
        messages: Conversation; the last entry must be an assistant
            ``Message`` with non-``None`` content.
        ground_truth: Text containing the expected answer(s).
        tolerance: Relative tolerance for numeric comparison.
        absolute_tolerance: Absolute tolerance for numeric comparison.
        require_units: When true, a numeric pair only counts as comparable
            if both texts agree on whether a unit is present.
        **kwargs: Accepted and ignored by this function.

    Returns:
        An ``EvaluateResult`` whose ``metrics`` record the extracted answers,
        any leniency or strictness decision, and the final comparison. The
        score is 0.0 when inputs are invalid, nothing can be extracted from
        either side, or a strictness check fails.
    """

    def error_result(reason: str, detail: str) -> EvaluateResult:
        return EvaluateResult(
            score=0.0,
            reason=reason,
            metrics={"error": MetricResult(score=0.0, is_score_valid=False, reason=detail)},
        )

    def format_extracted(items: List[Tuple[str, Union[float, str]]]) -> str:
        if not items:
            return "None"
        return ", ".join(f"'{text}' ({value})" for text, value in items)

    # ---- Input validation --------------------------------------------------
    last_message = messages[-1] if messages else None
    if (
        not isinstance(last_message, Message)
        or last_message.role != "assistant"
        or last_message.content is None
    ):
        return error_result(
            "Invalid or missing assistant response in messages.",
            "Last message not a valid assistant response.",
        )
    model_response_content = last_message.content

    if ground_truth is None or ground_truth == "":
        return error_result(
            "Missing or empty ground_truth (expected math answer string).",
            "Invalid ground_truth string.",
        )

    # ---- Extraction --------------------------------------------------------
    gen_answers_extracted_initial = extract_numbers(model_response_content)
    orig_answers_extracted = extract_numbers(ground_truth)
    gen_answers_extracted = list(gen_answers_extracted_initial)

    metrics: Dict[str, MetricResult] = {}
    metrics["extracted_original_answers"] = MetricResult(
        score=0.0,
        is_score_valid=bool(orig_answers_extracted),
        reason=f"Extracted from original: {format_extracted(orig_answers_extracted)}",
    )
    metrics["extracted_generated_answers"] = MetricResult(
        score=0.0,
        is_score_valid=bool(gen_answers_extracted_initial),
        reason=f"Extracted from generated (initial): {format_extracted(gen_answers_extracted_initial)}",
    )

    if not orig_answers_extracted:
        return EvaluateResult(
            score=0.0,
            reason="Could not extract answers from original message (ground truth).",
            metrics=metrics,
        )
    if not gen_answers_extracted_initial:
        return EvaluateResult(
            score=0.0,
            reason="Could not extract answers from generated message, but original message has answers.",
            metrics=metrics,
        )

    # ---- Demo leniency -----------------------------------------------------
    # If the ground truth is a single boxed number and the response contains
    # a boxed number matching it, only that boxed answer is considered.
    orig_boxed_value = None
    if len(orig_answers_extracted) == 1:
        only_text, only_value = orig_answers_extracted[0]
        if only_text.startswith("\\boxed{") and isinstance(only_value, (float, int)):
            orig_boxed_value = only_value
    is_single_orig_boxed_truth = orig_boxed_value is not None

    lenient_match = None
    if is_single_orig_boxed_truth:
        lenient_match = next(
            (
                (candidate_text, candidate_value)
                for candidate_text, candidate_value in gen_answers_extracted_initial
                if candidate_text.startswith("\\boxed{")
                and isinstance(candidate_value, (float, int))
                and math.isclose(
                    candidate_value,
                    orig_boxed_value,
                    rel_tol=tolerance,
                    abs_tol=absolute_tolerance,
                )
            ),
            None,
        )
    has_matching_gen_boxed_answer = lenient_match is not None
    if has_matching_gen_boxed_answer:
        gen_answers_extracted = [lenient_match]
        metrics["demo_leniency_info"] = MetricResult(
            score=1.0,
            is_score_valid=True,
            reason=f"Demo Leniency: Matching boxed answer '{lenient_match[0]}' found. Simplified gen_answers to this match.",
        )

    # ---- Strictness checks that end the evaluation early --------------------
    unboxed_or_result = _check_unboxed_or_strictness(model_response_content, gen_answers_extracted, metrics)
    if unboxed_or_result:
        return unboxed_or_result

    ambiguity_result = _check_ambiguity_strictness(orig_answers_extracted, gen_answers_extracted, metrics)
    if ambiguity_result:
        return ambiguity_result

    # ---- Pairwise comparison -------------------------------------------------
    best_match_score = 0.0
    best_match_reason = "No matching answer found"
    match_found_flag = False
    first_comparison_details_for_no_match = ""

    for orig_text, orig_value in orig_answers_extracted:
        for gen_text, gen_value in gen_answers_extracted:
            current_match = False
            current_similarity = 0.0

            if isinstance(orig_value, (float, int)) and isinstance(gen_value, (float, int)):
                units_disagree = require_units and (
                    _has_unit_text(orig_text, float(orig_value)) != _has_unit_text(gen_text, float(gen_value))
                )
                if units_disagree:
                    comparison_details = f"Unit presence mismatch (require_units=True). Orig_text: '{orig_text}', Gen_text: '{gen_text}'"
                else:
                    current_match, current_similarity = compare_numbers(
                        float(orig_value),
                        float(gen_value),
                        tolerance,
                        absolute_tolerance,
                    )
                    comparison_details = (
                        f"Numeric match: {'Yes' if current_match else 'No'}, Similarity: {current_similarity:.3f}"
                    )
            elif isinstance(orig_value, str) and isinstance(gen_value, str):
                # String answers (MCQ letters, "A or B") match case-insensitively.
                current_match = orig_value.lower() == gen_value.lower()
                current_similarity = 1.0 if current_match else 0.0
                comparison_details = (
                    f"String match: {'Yes' if current_match else 'No'} (value: '{gen_value}' vs '{orig_value}')"
                )
            else:
                comparison_details = (
                    f"Type mismatch: Gen({type(gen_value).__name__}) vs Orig({type(orig_value).__name__})"
                )

            pair_summary = (
                f"Gen='{gen_text}' ({gen_value}) vs Orig='{orig_text}' ({orig_value}).\n"
                f"{comparison_details}"
            )

            if not first_comparison_details_for_no_match:
                first_comparison_details_for_no_match = f"Initial comparison: {pair_summary}"

            if current_similarity > best_match_score:
                best_match_score = current_similarity
                match_found_flag = current_match
                best_match_reason = f"Best match: {pair_summary}"
            elif best_match_score == 0 and not match_found_flag and current_similarity == 0:
                best_match_reason = f"No score match: {pair_summary}"

    if (
        best_match_score == 0
        and not match_found_flag
        and first_comparison_details_for_no_match
        and best_match_reason == "No matching answer found"
    ):
        best_match_reason = first_comparison_details_for_no_match

    # ---- Final strictness check and result ------------------------------------
    best_match_score, match_found_flag, best_match_reason = _check_conflicting_answers_strictness(
        orig_answers_extracted,
        gen_answers_extracted,
        best_match_score,
        match_found_flag,
        is_single_orig_boxed_truth,
        has_matching_gen_boxed_answer,
        tolerance,
        absolute_tolerance,
        best_match_reason,
        metrics,
    )

    metrics["answer_comparison"] = MetricResult(
        score=best_match_score,
        is_score_valid=match_found_flag and best_match_score > 0,
        reason=best_match_reason,
    )
    return EvaluateResult(score=best_match_score, reason=best_match_reason, metrics=metrics)