eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,331 @@
1
+ import ast
2
+ import json
3
+ import logging
4
+ import re
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ from eval_protocol.models import EvaluateResult, Message, MetricResult
8
+ from eval_protocol.reward_function import reward_function
9
+
10
+ # Import the new execution utility
11
+ from .apps_execution_utils import check_correctness
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ # Helper function to extract code from the assistant's response
17
+ def _extract_python_code(response_content: str) -> Optional[str]:
18
+ """
19
+ Extracts Python code from a string.
20
+ Tries to find code within ```python ... ``` or ``` ... ``` blocks.
21
+ If not found, tries to find the first 'def ' and takes from there.
22
+ It also attempts to remove <think>...</think> blocks first.
23
+ """
24
+ # Attempt to remove <think>...</think> blocks first
25
+ cleaned_content = re.sub(r"<think>[\s\S]*?</think>", "", response_content, flags=re.IGNORECASE).strip()
26
+ if cleaned_content != response_content.strip(): # Log if <think> block was actually removed
27
+ logger.debug(
28
+ "Removed <think>...</think> block(s). Content after removal (stripped): "
29
+ + repr(cleaned_content[:200])
30
+ + "..."
31
+ )
32
+ if not cleaned_content: # If stripping results in empty string
33
+ logger.warning("Content became empty after removing <think> block and stripping.")
34
+ return None
35
+ else: # No <think> block found or removing it resulted in the same stripped string
36
+ cleaned_content = response_content.strip() # Ensure we work with stripped content if no <think> block
37
+
38
+ # Try to find ```python ... ``` in the cleaned content
39
+ match = re.search(r"```python\s*(.*?)\s*```", cleaned_content, re.DOTALL)
40
+ if match:
41
+ logger.debug("Extracted code using ```python ... ``` block.")
42
+ return match.group(1).strip()
43
+
44
+ # Try to find ``` ... ``` in the cleaned content
45
+ match = re.search(r"```\s*(.*?)\s*```", cleaned_content, re.DOTALL)
46
+ if match:
47
+ logger.debug("Extracted code using ``` ... ``` block.")
48
+ return match.group(1).strip()
49
+
50
+ # Try to find the first 'def ' in the cleaned content
51
+ def_index = cleaned_content.find("def ")
52
+ if def_index != -1:
53
+ logger.debug("Extracted code starting from the first 'def '.")
54
+ return cleaned_content[def_index:].strip()
55
+
56
+ # If no specific markers, return the cleaned content stripped.
57
+ # The warning about parsing the entire response if no markers are found is now more accurate.
58
+ if not match and def_index == -1: # if no ``` or def was found
59
+ # Log if we are falling back to the full (cleaned) content
60
+ logger.warning(
61
+ "No specific code markers (```python, ```, def) found. Attempting to parse content after <think> removal (if any)."
62
+ )
63
+ return cleaned_content # This is already stripped if <think> was removed, or original stripped content
64
+
65
+
66
def _apps_error_result(reason: str, metric_reason: Optional[str] = None) -> EvaluateResult:
    """Build a zero-score result carrying a single invalid ``error`` metric."""
    return EvaluateResult(
        score=0.0,
        reason=reason,
        metrics={
            "error": MetricResult(
                score=0.0,
                reason=reason if metric_reason is None else metric_reason,
                is_score_valid=False,
            )
        },
    )


def _dump_for_metric(payload: Any) -> str:
    """Serialize diagnostics to JSON; never let an odd value abort scoring."""
    try:
        # default=str is deliberate: these strings are diagnostics only.
        return json.dumps(payload, default=str)
    except (TypeError, ValueError):
        return repr(payload)


@reward_function
def evaluate_apps_solution(messages: List[Message], ground_truth: Optional[str], **kwargs) -> EvaluateResult:
    """
    Score an APPS-style coding answer by running it against ground-truth test cases.

    The Python code is extracted from the last message, executed through
    ``check_correctness`` and scored as the fraction of test cases whose result
    is exactly ``True``.

    Args:
        messages: Conversation; only the content of the last message is used.
        ground_truth: JSON string, or an already parsed dict, that must contain
            the keys ``"inputs"`` and ``"outputs"``. An optional ``"fn_name"``
            key is removed before execution (see the comment in the body).
        **kwargs: Optional settings.
            ``execution_timeout`` (default ``10``) is forwarded as ``timeout``.
            ``debug_execution`` (default ``True``) is forwarded as ``debug``.

    Returns:
        An ``EvaluateResult`` whose score is ``passed / total``. On success the
        metrics contain ``pass_rate`` and ``raw_results``; execution metadata is
        attached as ``execution_error_details`` or ``execution_metadata``.
        Invalid input yields score ``0.0`` with an ``error`` metric; an
        unexpected exception yields score ``0.0`` with ``evaluation_error``.
    """
    if not messages:
        return _apps_error_result("No messages provided.", "No messages provided for evaluation.")

    raw_solution_content = messages[-1].content
    # Non-string content (e.g. None) cannot contain code; treat it as empty.
    code_solution = _extract_python_code(raw_solution_content) if isinstance(raw_solution_content, str) else None

    if not code_solution or not code_solution.strip():
        if raw_solution_content:
            logger.warning(
                f"Code extraction resulted in empty solution. Raw content was: '{str(raw_solution_content)[:200]}...'"
            )
        else:
            logger.warning("Code extraction resulted in empty solution. Raw content was empty.")
        return EvaluateResult(
            score=0.0,
            metrics={
                "parsability": MetricResult(
                    score=0.0,
                    reason="Empty code solution after extraction.",
                    is_score_valid=True,
                ),
                "error": MetricResult(
                    score=0.0,
                    reason="Empty code solution after extraction.",
                    is_score_valid=False,
                ),
            },
            reason="The provided code solution was empty after extraction.",
        )

    logger.debug(f"Extracted code for execution: \n---\n{code_solution[:500]}...\n---")

    # --- Parse the ground truth into a dict of test cases -------------------
    in_outs: Optional[Dict[str, Any]] = None
    if isinstance(ground_truth, str):
        logger.debug(f"Raw ground_truth string for sample: {ground_truth[:1000]}")
        try:
            in_outs = json.loads(ground_truth)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse ground_truth JSON string: {e}. GT (first 200 chars): {ground_truth[:200]}")
            return _apps_error_result(f"Ground_truth JSONDecodeError: {e}")
    elif isinstance(ground_truth, dict):
        # Already parsed (the original comment suggests the JSONL loader does this).
        logger.debug(f"ground_truth is already a dict: {str(ground_truth)[:1000]}")
        in_outs = ground_truth
    else:
        logger.error(
            f"ground_truth is neither a string nor a dict. Type: {type(ground_truth)}. Value (first 200 chars): {str(ground_truth)[:200]}"
        )
        return _apps_error_result("Invalid ground_truth type.", f"Invalid ground_truth type: {type(ground_truth)}")

    if not isinstance(in_outs, dict) or "inputs" not in in_outs or "outputs" not in in_outs:
        logger.error(
            f"Parsed ground_truth is not in the expected format (dict with 'inputs' and 'outputs'). Parsed: {str(in_outs)[:200]}"
        )
        return _apps_error_result("Invalid ground_truth structure after parsing.")

    fn_name_from_gt = in_outs.get("fn_name")
    if not fn_name_from_gt:
        logger.warning("fn_name not found in ground_truth dict, will rely on system prompt for main() or full script.")
    if logger.isEnabledFor(logging.INFO):
        # Guarded: serializing every test case only to log 500 characters is wasteful.
        logger.info(
            f"Using fn_name from ground_truth (if present): {fn_name_from_gt}. Parsed in_outs (first 500 chars of dump): {_dump_for_metric(in_outs)[:500]}"
        )

    timeout = kwargs.get("execution_timeout", 10)
    debug_execution = kwargs.get("debug_execution", True)

    # 'fn_name' is dropped on purpose. Per the original design notes, its
    # absence makes the testing utility feed each test input through stdin
    # instead of calling a named function, so the extracted solution has to be
    # a runnable script (e.g. a main() that is actually invoked).
    in_outs_for_check = {key: value for key, value in in_outs.items() if key != "fn_name"}
    if "fn_name" in in_outs:
        logger.info("Removed 'fn_name' from in_outs for check_correctness to use standard_input path.")

    score = 0.0
    reason_msg = "Evaluation did not complete successfully."
    metrics: Dict[str, MetricResult] = {}

    try:
        results_list, exec_metadata_list = check_correctness(
            in_outs=in_outs_for_check,
            generation=code_solution,
            timeout=timeout,
            debug=debug_execution,
        )

        if not results_list:
            reason_msg = "Execution utility returned no results."
            logger.error(reason_msg)
            metrics["execution_error"] = MetricResult(score=0.0, reason=reason_msg, is_score_valid=False)
        else:
            # Only the literal True counts as a pass; the original comments
            # describe -1 / -2 as runtime/timeout and compilation error codes,
            # and those are truthy integers.
            num_tests = len(results_list)
            passed_count = sum(1 for res in results_list if res is True)
            score = float(passed_count) / num_tests
            reason_msg = f"Passed {passed_count}/{num_tests} test cases."
            logger.info(f"Execution result: {reason_msg}")

            metrics["pass_rate"] = MetricResult(score=score, reason=f"{passed_count}/{num_tests}")
            metrics["raw_results"] = MetricResult(
                score=0.0, reason=_dump_for_metric(results_list), is_score_valid=False
            )

        if exec_metadata_list:
            first_meta = exec_metadata_list[0]
            first_meta_dict = first_meta if isinstance(first_meta, dict) else {}
            if len(exec_metadata_list) == 1 and first_meta_dict.get("error"):
                # A lone entry with an "error" key is treated as applying to the whole attempt.
                reason_msg += f" Execution Error: {first_meta_dict['error']}"
                metrics["execution_error_details"] = MetricResult(
                    score=0.0,
                    reason=_dump_for_metric(first_meta),
                    is_score_valid=False,
                )
            else:
                metrics["execution_metadata"] = MetricResult(
                    score=0.0,
                    reason=_dump_for_metric(exec_metadata_list),
                    is_score_valid=False,
                )
                if score == 0.0 and first_meta_dict.get("error_message") == "Wrong Answer":
                    reason_msg += (
                        f". First fail details: Inputs: {first_meta_dict.get('inputs', 'N/A')}, "
                        f"Expected: {first_meta_dict.get('expected', 'N/A')}, "
                        f"Got: {first_meta_dict.get('output', 'N/A')}"
                    )

    except Exception as e:
        # Boundary handler: any unexpected failure must yield a zero score, not a crash.
        score = 0.0
        reason_msg = f"Error during code execution or result processing: {type(e).__name__}: {e}"
        logger.error(reason_msg, exc_info=True)
        metrics["evaluation_error"] = MetricResult(score=0.0, reason=reason_msg, is_score_valid=False)

    return EvaluateResult(score=score, metrics=metrics, reason=reason_msg)
@@ -0,0 +1,149 @@
1
+ # Copyright 2024 PRIME team and/or its affiliates (Adapted for reward-kit)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import multiprocessing
17
+ import os
18
+ import sys
19
+ import traceback
20
+ from typing import Any, Dict, List, Optional
21
+
22
+ # Adapted import to point to our local apps_testing_util.py
23
+ from .apps_testing_util import run_test
24
+
25
+ # Note: The original file had a compute_score function.
26
+ # We are primarily interested in check_correctness and its helper _temp_run.
27
+ # The main reward function in apps_coding_reward.py will call check_correctness.
28
+
29
+
30
+ def _temp_run(
31
+ sample: Dict[str, Any],
32
+ generation: str,
33
+ debug: bool,
34
+ result_list: list,
35
+ metadata_list: list,
36
+ timeout: int,
37
+ ):
38
+ """
39
+ Helper function to run a single test in a separate process context.
40
+ Manages stdout/stderr redirection and captures results/metadata.
41
+ """
42
+ # Redirect stdout/stderr to prevent interference and capture output if needed by run_test
43
+ # Note: run_test itself might also capture stdout for standard_input type problems.
44
+ # This top-level redirection ensures the process itself doesn't pollute console.
45
+ original_stdout = sys.stdout
46
+ original_stderr = sys.stderr
47
+ # Temporarily disable stdout/stderr redirection to see debug prints from run_test
48
+ # sys.stdout = open(os.devnull, "w")
49
+ # sys.stderr = open(os.devnull, "w")
50
+ print(f"[_temp_run] Executing run_test for sample. Debug prints from run_test should be visible.")
51
+
52
+ try:
53
+ res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)
54
+ result_list.append(res)
55
+ metadata_list.append(metadata)
56
+ except Exception:
57
+ # This catch-all is for unexpected errors within _temp_run itself or run_test if it raises
58
+ # instead of returning error codes in `res`.
59
+ # run_test is designed to return error codes like -1 or -2 in `res` for failures.
60
+ num_inputs = len(sample.get("inputs", []))
61
+ tb_str = traceback.format_exc()
62
+ print(f"[_temp_run] Exception caught: {tb_str}")
63
+ result_list.append([-1] * num_inputs if num_inputs > 0 else [-1]) # Mark all as error
64
+ metadata_list.append({"error": "Exception in _temp_run/run_test", "traceback": tb_str})
65
+ finally:
66
+ # Restore stdout/stderr
67
+ # if sys.stdout is not original_stdout and hasattr(sys.stdout, 'close'): # Check if it was replaced and closable
68
+ # sys.stdout.close()
69
+ # if sys.stderr is not original_stderr and hasattr(sys.stderr, 'close'): # Check if it was replaced and closable
70
+ # sys.stderr.close()
71
+ sys.stdout = original_stdout # Always restore
72
+ sys.stderr = original_stderr # Always restore
73
+
74
+
75
def check_correctness(
    in_outs: Optional[dict], generation: str, timeout: int = 10, debug: bool = True
) -> tuple[List[Any], List[Dict[str, Any]]]:
    """
    Run the generated code in a subprocess guarded by a global timeout.

    The global timeout is a safety net for cases the timeouts inside
    ``run_test`` do not catch.

    Args:
        in_outs: Dictionary with "inputs" and "outputs" lists, and optionally "fn_name".
        generation: The code string to test.
        timeout: Seconds forwarded to ``run_test``; the subprocess as a whole
            is given ``timeout + 1`` seconds.
        debug: Debug flag passed to ``run_test``.

    Returns:
        A tuple ``(results, metadata)``:
        - ``results``: the first entry the subprocess produced. If it produced
          nothing, a list of ``-1`` with one item per input. ``[-1]`` for an
          invalid ``in_outs``.
        - ``metadata``: the metadata entries the subprocess produced, or
          synthetic error dicts describing the failure.
    """
    if not in_outs or "inputs" not in in_outs or not isinstance(in_outs["inputs"], list):
        return [-1], [{"error": "Invalid or missing in_outs structure"}]

    # The context manager shuts the manager's server process down on exit;
    # without it every call leaks one helper process.
    with multiprocessing.Manager() as manager:
        result_proxy = manager.list()
        metadata_proxy = manager.list()

        process = multiprocessing.Process(
            target=_temp_run,
            args=(in_outs, generation, debug, result_proxy, metadata_proxy, timeout),
        )
        process.start()
        # NOTE(review): the budget is timeout + 1 in total, not per test case;
        # confirm this is intended for samples with many inputs.
        process.join(timeout=timeout + 1)

        if process.is_alive():
            process.kill()
            process.join()  # reap the killed child so it does not linger as a zombie

        # Copy out of the proxies while the manager is still running.
        collected_results = list(result_proxy)
        collected_metadata = list(metadata_proxy)

    if collected_results:
        # _temp_run appends a single entry holding run_test's result.
        return collected_results[0], collected_metadata

    # Nothing came back: the subprocess was killed or died before reporting.
    num_inputs = len(in_outs["inputs"])
    if num_inputs > 0:
        final_results: List[Any] = [-1] * num_inputs
        # Build a distinct dict per input so later mutation cannot alias.
        final_metadata: List[Dict[str, Any]] = [
            {
                "error": "Global timeout or process killed prematurely",
                "details": "No results returned from subprocess.",
            }
            for _ in range(num_inputs)
        ]
    else:
        final_results = [-1]
        final_metadata = [{"error": "Global timeout"}]
    if debug:
        print("Global timeout or process killed before results could be appended.")
    return final_results, final_metadata