eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,375 @@
1
+ """
2
+ Reward functions for evaluating response length.
3
+
4
+ This module provides reward functions that evaluate the length of model responses,
5
+ either by simple token/character count or using cosine-scaled rewards to promote
6
+ token efficiency.
7
+ """
8
+
9
+ import math
10
+ import re
11
+ from typing import Any, Callable, Dict, List, Optional, Union
12
+
13
+ from ..models import EvaluateResult, Message, MetricResult
14
+ from ..typed_interface import reward_function
15
+
16
+
17
def count_tokens(text: str, method: str = "whitespace") -> int:
    """
    Count tokens in text using different methods.

    Args:
        text: The text to tokenize.
        method: Tokenization method to use ('whitespace', 'character', or 'words').
            Any other value falls back to 'whitespace'.

    Returns:
        Token count based on the selected method. Empty or whitespace-only text
        counts as 0 tokens for the 'whitespace' and 'words' methods.
    """
    if method == "character":
        return len(text)
    if method == "words":
        # Runs of word characters; \w already includes digits.
        return len(re.findall(r"\b\w+\b", text))
    # 'whitespace' and any unrecognised method: split on runs of whitespace.
    # str.split() with no separator discards empty pieces, so "" and "   " count
    # as 0 tokens (re.split(r"\s+", "") returns [""] and would count them as 1).
    return len(text.split())
36
+
37
+
38
def _length_ramp_up_score(progress: float, min_reward: float, max_reward: float, scaling: str) -> float:
    """Score a response shorter than the minimum; `progress` is token_count / min_length."""
    span = max_reward - min_reward
    if scaling == "cosine":
        return min_reward + span * (1.0 - math.cos(progress * math.pi / 2.0))
    return min_reward + span * progress


def _length_ramp_down_score(progress: float, min_reward: float, max_reward: float, scaling: str) -> float:
    """Score a response longer than the maximum; `progress` is the overshoot, already clipped to [0, 1]."""
    span = max_reward - min_reward
    if scaling == "cosine":
        return max_reward - span * (1.0 - math.cos(progress * math.pi / 2.0))
    return max_reward - span * progress


@reward_function  # type: ignore[arg-type]
def length_reward(
    messages: Union[List[Message], List[Dict[str, Any]]],
    *,
    ground_truth: Optional[
        Union[List[Message], List[Dict[str, Any]]]
    ] = None,  # Not used by this function but part of standard signature
    target_length: Optional[int] = None,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    token_method: str = "whitespace",
    scaling: str = "linear",
    reward_range: Optional[List[float]] = None,
    **kwargs: Any,
) -> EvaluateResult:
    """
    Reward function that evaluates the length of model responses.
    The model's response is assumed to be the last message in the `messages` list.

    The length constraints are applied in this priority order:
      1. `target_length`: the score falls as the relative deviation from the target grows.
         `success` is True when the deviation is below 20%.
      2. `min_length` and `max_length`: full reward inside the window. Below the minimum
         the score ramps up with token_count / min_length. Above the maximum it ramps down
         with the overshoot relative to the window width.
      3. `min_length` only: ramp up below the minimum, full reward otherwise.
      4. `max_length` only: ramp down above the maximum (overshoot relative to
         max_length), full reward otherwise.
      5. No constraints: shorter is better, normalised against 100 tokens.

    Args:
        messages: List of conversation messages, where `messages[-1]` is the model's response.
            Each may be a `Message` or a plain dict with "role"/"content" keys.
        ground_truth: Optional. Expected assistant response trajectory. Not directly used by this length reward.
        target_length: Optional target token count (optimal length).
        min_length: Minimum acceptable token count.
        max_length: Maximum acceptable token count.
        token_method: Method to count tokens ('whitespace', 'character', or 'words').
        scaling: Scaling method for reward calculation ('linear' or 'cosine').
        reward_range: Two-element [min_reward, max_reward]; default is [0.0, 1.0].
        **kwargs: Additional arguments (ignored).

    Returns:
        EvaluateResult with score based on length evaluation, plus "length" and
        "token_count" metrics.
    """
    if not messages:
        return EvaluateResult(
            score=0.0,
            reason="No messages provided",
            metrics={"length": MetricResult(score=0.0, is_score_valid=False, reason="No messages provided")},
        )

    response = messages[-1]
    if isinstance(response, Message):
        role, content = response.role, response.content
    elif isinstance(response, dict):
        role, content = response.get("role"), response.get("content")
    else:
        return EvaluateResult(
            score=0.0,
            reason="Last message is of unexpected type.",
            metrics={
                "length": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Invalid message type in messages.",
                )
            },
        )

    if role != "assistant" or not content:
        return EvaluateResult(
            score=0.0,
            reason="No assistant response found",
            metrics={
                "length": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Message not from assistant or has no content",
                )
            },
        )
    text = content

    token_count = count_tokens(text, method=token_method)

    if reward_range is None:
        reward_range = [0.0, 1.0]
    min_reward, max_reward = reward_range

    if target_length is not None:
        # Relative deviation from the target; a non-positive target counts as fully missed.
        normalized_diff = abs(token_count - target_length) / target_length if target_length > 0 else 1.0
        if scaling == "cosine":
            progress = min(1.0, normalized_diff)
            score = min_reward + (max_reward - min_reward) * (1.0 + math.cos(progress * math.pi)) / 2.0
        else:
            score = max(
                min_reward,
                max_reward - normalized_diff * (max_reward - min_reward),
            )
        reason = (
            f"Response length ({token_count} tokens) deviated by {normalized_diff:.2f} from target ({target_length})"
        )
        success = normalized_diff < 0.2
    elif min_length is not None and max_length is not None:
        if token_count < min_length:
            # token_count >= 0 here, so min_length > 0 and the division is safe.
            score = _length_ramp_up_score(token_count / min_length, min_reward, max_reward, scaling)
            reason = f"Response length ({token_count} tokens) is below minimum ({min_length})"
            success = False
        elif token_count > max_length:
            # Overshoot is measured against the width of the acceptable window.
            # An empty or inverted window uses width 1 to avoid division by zero.
            range_size = max_length - min_length if max_length > min_length else 1
            progress = min(1.0, (token_count - max_length) / range_size)
            score = _length_ramp_down_score(progress, min_reward, max_reward, scaling)
            reason = f"Response length ({token_count} tokens) exceeds maximum ({max_length})"
            success = False
        else:
            score = max_reward
            reason = f"Response length ({token_count} tokens) is within acceptable range ({min_length}-{max_length})"
            success = True
    elif min_length is not None:
        if token_count < min_length:
            score = _length_ramp_up_score(token_count / min_length, min_reward, max_reward, scaling)
            reason = f"Response length ({token_count} tokens) is below minimum ({min_length})"
            success = False
        else:
            score = max_reward
            reason = f"Response length ({token_count} tokens) meets minimum requirement ({min_length})"
            success = True
    elif max_length is not None:
        if token_count > max_length:
            # Overshoot relative to max_length; a non-positive max_length counts as fully exceeded.
            progress = min(1.0, (token_count - max_length) / max_length) if max_length > 0 else 1.0
            score = _length_ramp_down_score(progress, min_reward, max_reward, scaling)
            reason = f"Response length ({token_count} tokens) exceeds maximum ({max_length})"
            success = False
        else:
            score = max_reward
            reason = f"Response length ({token_count} tokens) is within maximum limit ({max_length})"
            success = True
    else:
        # No constraints: shorter is better. This is useful when combined with correctness
        # metrics, e.g. shorter correct answers > longer correct answers > incorrect answers.
        reference_length = 100  # Default length for normalization
        progress = min(1.0, token_count / reference_length)
        if scaling == "cosine":
            score = min_reward + (max_reward - min_reward) * (1.0 + math.cos(progress * math.pi)) / 2.0
        else:
            score = max_reward - progress * (max_reward - min_reward)
        reason = f"Response length: {token_count} tokens"
        success = True

    metrics = {
        "length": MetricResult(score=score, is_score_valid=success, reason=reason),
        "token_count": MetricResult(
            score=min(
                1.0,
                float(token_count) / (target_length or max_length or min_length or 100),
            ),
            is_score_valid=success,
            reason=f"Token count: {token_count}",
        ),
    }

    return EvaluateResult(score=score, reason=reason, metrics=metrics)
230
+
231
+
232
@reward_function  # type: ignore[arg-type]
def cosine_length_reward(
    messages: Union[List[Message], List[Dict[str, Any]]],
    *,
    ground_truth: Optional[
        Union[List[Message], List[Dict[str, Any]]]
    ] = None,  # Not used by this function but part of standard signature
    correctness: Optional[float] = None,
    is_correct: Optional[bool] = None,
    max_length: int = 1000,
    min_value_wrong: float = 0.0,
    max_value_wrong: float = 0.3,
    min_value_correct: float = 0.5,
    max_value_correct: float = 1.0,
    token_method: str = "whitespace",
    **kwargs: Any,
) -> EvaluateResult:
    """
    Reward function that scales based on completion length using a cosine schedule.
    The model's response is assumed to be the last message in the `messages` list.

    Inspired by the Open R1 implementation (https://github.com/huggingface/open-r1) and
    the Kimi k1.5 technical report (https://arxiv.org/abs/2501.12599).

    Shorter correct solutions are rewarded more than longer ones.
    Longer incorrect solutions are penalized less than shorter ones.

    The score follows a half-cosine in progress = min(1, token_count / max_length):
      * correct:   from `max_value_correct` at 0 tokens down to `min_value_correct` at >= max_length.
      * incorrect: from `min_value_wrong` at 0 tokens up to `max_value_wrong` at >= max_length.

    Args:
        messages: List of conversation messages, where `messages[-1]` is the model's response.
            Each may be a `Message` or a plain dict with "role"/"content" keys.
        ground_truth: Optional. Expected assistant response trajectory. Not directly used by this length reward.
        correctness: Optional float (0-1). Counts as correct when >= 0.9. Ignored if `is_correct` is given.
        is_correct: Optional boolean indicating if the solution is correct. Takes precedence.
            If neither is given, the solution is treated as incorrect.
        max_length: Token budget used for scaling. Values <= 0 are treated as an
            already-exhausted budget (progress 1.0) instead of raising ZeroDivisionError.
        min_value_wrong: Reward for the shortest incorrect responses.
        max_value_wrong: Reward approached by incorrect responses at or beyond `max_length`.
        min_value_correct: Reward approached by correct responses at or beyond `max_length`.
        max_value_correct: Reward for the shortest correct responses.
        token_method: Method to count tokens.
        **kwargs: Additional arguments (ignored).

    Returns:
        EvaluateResult with score based on cosine-scaled length evaluation, plus
        "cosine_length", "token_count" and "correctness" metrics.
    """
    if not messages:
        return EvaluateResult(
            score=0.0,
            reason="No messages provided",
            metrics={"cosine_length": MetricResult(score=0.0, is_score_valid=False, reason="No messages provided")},
        )

    response = messages[-1]
    if isinstance(response, Message):
        role, content = response.role, response.content
    elif isinstance(response, dict):
        role, content = response.get("role"), response.get("content")
    else:
        return EvaluateResult(
            score=0.0,
            reason="Last message is of unexpected type.",
            metrics={
                "cosine_length": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Invalid message type in messages.",
                )
            },
        )

    if role != "assistant" or not content:
        return EvaluateResult(
            score=0.0,
            reason="No assistant response found",
            metrics={
                "cosine_length": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Message not from assistant or has no content",
                )
            },
        )
    text = content

    token_count = count_tokens(text, method=token_method)

    solution_is_correct = False
    if is_correct is not None:
        solution_is_correct = is_correct
    elif correctness is not None:
        solution_is_correct = correctness >= 0.9

    # Fraction of the length budget used, clipped to [0, 1]. A non-positive budget
    # is treated as exhausted rather than dividing by zero.
    progress = min(1.0, token_count / max_length) if max_length > 0 else 1.0
    cosine_factor = math.cos(progress * math.pi)

    if solution_is_correct:
        min_value = min_value_correct
        max_value = max_value_correct
    else:
        # Deliberately swapped so the score *rises* with length for incorrect answers.
        min_value = max_value_wrong
        max_value = min_value_wrong

    # cosine_factor goes from 1 (0 tokens) to -1 (>= max_length), so the score
    # moves from max_value to min_value.
    score = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine_factor)

    if solution_is_correct:
        success = True
        reason = f"Correct solution with length penalty: {token_count} tokens"
    else:
        success = False
        reason = f"Incorrect solution with length consideration: {token_count} tokens"

    detailed_reason = (
        f"Length-based {'reward' if solution_is_correct else 'penalty'}: "
        f"{token_count}/{max_length} tokens, cosine factor: {cosine_factor:.2f}"
    )

    metrics = {
        "cosine_length": MetricResult(
            score=score,
            is_score_valid=success,
            reason=detailed_reason,
        ),
        "token_count": MetricResult(
            score=progress,  # same value as min(1, token_count / max_length) for a positive budget
            is_score_valid=success,
            reason=f"Token count: {token_count}/{max_length}",
        ),
        "correctness": MetricResult(
            score=1.0 if solution_is_correct else 0.0,
            is_score_valid=solution_is_correct,
            reason=f"Solution is {'correct' if solution_is_correct else 'incorrect'}",
        ),
    }

    return EvaluateResult(score=score, reason=reason, metrics=metrics)
@@ -0,0 +1,221 @@
1
+ """
2
+ Reward function for comparing lists of numbers, often found in math answers
3
+ like sets of divisors, roots, etc.
4
+ """
5
+
6
+ import re
7
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
8
+
9
+ from ..models import EvaluateResult, Message, MetricResult
10
+ from ..typed_interface import reward_function
11
+
12
+
13
def parse_number_list_from_string(s: str) -> Optional[List[float]]:
    """
    Convert a comma-separated string of numbers into a list of floats.

    Dollar signs are stripped first, and empty items (e.g. from "1, , 2") are skipped.
    The whole string is rejected (None) when it contains no items, or when any
    item is not a valid float.
    e.g., "1, 2, 3.5, 4" -> [1.0, 2.0, 3.5, 4.0]
    """
    cleaned = s.replace("$", "").strip()
    items = [item.strip() for item in re.split(r"\s*,\s*", cleaned)]
    non_empty_items = [item for item in items if item]
    if not non_empty_items:
        return None

    values: List[float] = []
    for item in non_empty_items:
        try:
            values.append(float(item))
        except ValueError:
            return None
    return values
34
+
35
+
36
def extract_number_list(text: str) -> List[List[float]]:
    """
    Extracts lists of numbers from text.

    Sources are tried in priority order. The first tier that yields at least one
    parseable list wins, and lower tiers are not consulted:

      1. Contents of \\boxed{...} (one level of nested braces is tolerated).
      2. Contents of $$...$$ or $...$ math spans.
      3. The whole text, as a last-resort fallback.

    Within the winning tier, every parseable span contributes one list, in
    order of appearance.

    Args:
        text: The text to extract number lists from.

    Returns:
        A list of extracted number lists (each inner list contains floats), or an
        empty list if nothing could be parsed.
        Example: "\\boxed{1,2,3}, \\boxed{4,5}" -> [[1.0, 2.0, 3.0], [4.0, 5.0]]
        Example: "\\boxed{1,2,3}, $4,5$" -> [[1.0, 2.0, 3.0]]  (boxed tier wins)
    """
    extracted_lists: List[List[float]] = []

    # Priority 1: Boxed LaTeX expressions.
    for content in re.findall(r"\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}", text):
        parsed_list = parse_number_list_from_string(content)
        if parsed_list:
            extracted_lists.append(parsed_list)
    if extracted_lists:
        return extracted_lists

    # Priority 2: Content within $$...$$ or $...$.
    # findall yields (display, inline) tuples in which exactly one member matched.
    for display_content, inline_content in re.findall(r"\$\$(.*?)\$\$|\$(.*?)\$", text, re.DOTALL):
        content = display_content or inline_content
        if content:
            parsed_list = parse_number_list_from_string(content.strip())
            if parsed_list:
                extracted_lists.append(parsed_list)
    if extracted_lists:
        return extracted_lists

    # Priority 3: Fall back to parsing the whole text (least reliable).
    full_text_parsed_list = parse_number_list_from_string(text)
    if full_text_parsed_list:
        extracted_lists.append(full_text_parsed_list)

    return extracted_lists
82
+
83
+
84
@reward_function  # type: ignore[arg-type]
def list_comparison_math_reward(
    messages: List[Message],
    *,
    ground_truth: str,
    order_matters: bool = False,
    **kwargs: Any,
) -> EvaluateResult:
    """
    Score a response whose answer is a list or set of numbers.

    Number lists are pulled out of the assistant reply (messages[-1].content) and
    out of `ground_truth` with `extract_number_list`. Only the first list found on
    each side is compared.

    Args:
        messages: Conversation messages. The last one must be an assistant `Message`
            with non-None content.
        ground_truth: String holding the expected list of numbers.
        order_matters: True -> lists must be identical (order and multiplicity).
            False (default) -> compared as sets, so order and duplicates are ignored.
        **kwargs: Unused extra keyword arguments.

    Returns:
        EvaluateResult with score 1.0 on a match and 0.0 otherwise. Extraction
        diagnostics and the comparison outcome are recorded in `metrics`.
    """

    def _error_result(reason: str, metric_reason: str) -> EvaluateResult:
        # Uniform shape for early-exit input validation failures.
        return EvaluateResult(
            score=0.0,
            reason=reason,
            metrics={"error": MetricResult(score=0.0, is_score_valid=False, reason=metric_reason)},
        )

    last_message = messages[-1] if messages else None
    has_valid_response = (
        isinstance(last_message, Message)
        and last_message.role == "assistant"
        and last_message.content is not None
    )
    if not has_valid_response:
        return _error_result(
            "Invalid or missing assistant response in messages.",
            "Last message not a valid assistant response.",
        )

    generated_text = last_message.content
    if not generated_text:
        return _error_result("Assistant response content is empty.", "Empty generated message content.")
    if not ground_truth:
        return _error_result("Ground truth string (expected list) is empty.", "Empty ground truth string.")

    generated_lists = extract_number_list(generated_text)
    expected_lists = extract_number_list(ground_truth)

    metrics: Dict[str, MetricResult] = {
        "extracted_original_lists": MetricResult(
            score=1.0 if expected_lists else 0.0,
            is_score_valid=bool(expected_lists),
            reason=f"Original lists: {expected_lists}",
        ),
        "extracted_generated_lists": MetricResult(
            score=1.0 if generated_lists else 0.0,
            is_score_valid=bool(generated_lists),
            reason=f"Generated lists: {generated_lists}",
        ),
    }

    if not expected_lists:
        return EvaluateResult(
            score=0.0,
            reason="Could not extract any number list from original message (ground truth).",
            metrics=metrics,
        )
    if not generated_lists:
        return EvaluateResult(
            score=0.0,
            reason="Could not extract any number list from generated message.",
            metrics=metrics,
        )

    # Only the first list on each side takes part in the comparison.
    expected = expected_lists[0]
    generated = generated_lists[0]

    if order_matters:
        # Exact float equality; parsed values from identical text compare equal.
        is_match = generated == expected
        if is_match:
            match_reason = f"Exact list match (order matters). Gen: {generated} vs Orig: {expected}"
        else:
            match_reason = f"List mismatch (order matters). Gen: {generated} vs Orig: {expected}"
    else:
        generated_set = set(generated)
        expected_set = set(expected)
        generated_sorted = sorted(generated_set)
        expected_sorted = sorted(expected_set)
        is_match = generated_set == expected_set
        if is_match:
            match_reason = (
                f"Set match (order does not matter). Gen: {generated_sorted} vs Orig: {expected_sorted}"
            )
        else:
            reason_parts = [
                f"Set mismatch (order does not matter). Gen: {generated_sorted} vs Orig: {expected_sorted}."
            ]
            missing = expected_set - generated_set
            extra = generated_set - expected_set
            if missing:
                reason_parts.append(f"Missing in generated: {sorted(missing)}.")
            if extra:
                reason_parts.append(f"Extra in generated: {sorted(extra)}.")
            match_reason = " ".join(reason_parts)

    score = 1.0 if is_match else 0.0
    metrics["list_comparison"] = MetricResult(score=score, is_score_valid=(score == 1.0), reason=match_reason)
    return EvaluateResult(score=score, reason=match_reason, metrics=metrics)