eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import re
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
from eval_protocol.models import EvaluateResult, Message, MetricResult
|
|
8
|
+
from eval_protocol.reward_function import reward_function
|
|
9
|
+
|
|
10
|
+
# Import the new execution utility
|
|
11
|
+
from .apps_execution_utils import check_correctness
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Helper function to extract code from the assistant's response
|
|
17
|
+
def _extract_python_code(response_content: str) -> Optional[str]:
    """
    Extracts Python code from an assistant response.

    <think>...</think> blocks are removed first (case-insensitively) and the
    remaining text is stripped. Extraction then tries, in order:

    1. The first fenced block tagged as Python: ```python, ```python3 or ```py
       (tag matched case-insensitively).
    2. The first fenced block of any kind (``` ... ```).
    3. Everything from the first occurrence of 'def ' onwards.
    4. The whole cleaned response (a warning is logged).

    Args:
        response_content: Raw text of the assistant message. Non-string values
            (for example None) are treated as "no code" instead of raising.

    Returns:
        The extracted code with surrounding whitespace stripped, or None when the
        input is not a string or nothing is left after removing <think> blocks.
    """
    if not isinstance(response_content, str):
        logger.warning(f"Cannot extract code from non-string content of type {type(response_content).__name__}.")
        return None

    stripped_original = response_content.strip()
    cleaned_content = re.sub(r"<think>[\s\S]*?</think>", "", response_content, flags=re.IGNORECASE).strip()

    if cleaned_content != stripped_original:  # A <think> block was actually removed
        logger.debug(
            "Removed <think>...</think> block(s). Content after removal (stripped): "
            + repr(cleaned_content[:200])
            + "..."
        )
        if not cleaned_content:
            logger.warning("Content became empty after removing <think> block and stripping.")
            return None

    # Python-tagged fence. The optional "3" and the word boundary keep tags such as
    # "python3" or "py" out of the captured code.
    match = re.search(r"```(?:python3?|py)\b\s*(.*?)\s*```", cleaned_content, re.DOTALL | re.IGNORECASE)
    if match:
        logger.debug("Extracted code using ```python ... ``` block.")
        return match.group(1).strip()

    # Any fenced block.
    match = re.search(r"```\s*(.*?)\s*```", cleaned_content, re.DOTALL)
    if match:
        logger.debug("Extracted code using ``` ... ``` block.")
        return match.group(1).strip()

    # No fence: take everything from the first function definition.
    def_index = cleaned_content.find("def ")
    if def_index != -1:
        logger.debug("Extracted code starting from the first 'def '.")
        return cleaned_content[def_index:].strip()

    # Nothing recognisable: fall back to the full cleaned (already stripped) content.
    logger.warning(
        "No specific code markers (```python, ```, def) found. Attempting to parse content after <think> removal (if any)."
    )
    return cleaned_content
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _apps_error_result(reason: str, metric_reason: Optional[str] = None) -> EvaluateResult:
    """Builds a zero-score EvaluateResult carrying a single invalid 'error' metric."""
    return EvaluateResult(
        score=0.0,
        reason=reason,
        metrics={
            "error": MetricResult(
                score=0.0,
                reason=reason if metric_reason is None else metric_reason,
                is_score_valid=False,
            )
        },
    )


@reward_function
def evaluate_apps_solution(messages: List[Message], ground_truth: Optional[str], **kwargs) -> EvaluateResult:
    """
    Evaluates a code solution for the APPS dataset by running it against test cases.

    Python code is extracted from the last message and executed through
    `check_correctness`. Any 'fn_name' entry is dropped from the test-case dict
    before execution so that the standard-input path is used: the extracted code
    is expected to be a runnable script that reads stdin and prints its answer
    (a bare `def main(): ...` that is never called will not produce output).

    Args:
        messages: Conversation messages; the last one holds the model's solution.
        ground_truth: Test cases, either a JSON string or an already parsed dict,
            containing "inputs" and "outputs" and optionally "fn_name".
        **kwargs: Optional settings:
            execution_timeout: Timeout passed to `check_correctness` (default 10).
            debug_execution: Debug flag passed to `check_correctness` (default True).

    Returns:
        An EvaluateResult whose score is the fraction of test cases that passed
        (0.0 on any error), together with diagnostic metrics and a reason string.
    """
    if not messages:
        return _apps_error_result("No messages provided.", "No messages provided for evaluation.")

    raw_solution_content = messages[-1].content
    # The extractor works on text only; any other content (e.g. None) is reported
    # below as an empty solution instead of raising inside the regex calls.
    code_solution = _extract_python_code(raw_solution_content) if isinstance(raw_solution_content, str) else None

    if not code_solution or not code_solution.strip():
        # Log the raw content if extraction resulted in empty/None
        if raw_solution_content:
            logger.warning(
                f"Code extraction resulted in empty solution. Raw content was: '{str(raw_solution_content)[:200]}...'"
            )
        else:
            logger.warning("Code extraction resulted in empty solution. Raw content was empty.")
        return EvaluateResult(
            score=0.0,
            metrics={
                "parsability": MetricResult(
                    score=0.0,
                    reason="Empty code solution after extraction.",
                    is_score_valid=True,
                ),
                "error": MetricResult(
                    score=0.0,
                    reason="Empty code solution after extraction.",
                    is_score_valid=False,
                ),
            },
            reason="The provided code solution was empty after extraction.",
        )

    logger.debug(f"Extracted code for execution: \n---\n{code_solution[:500]}...\n---")

    # Normalise ground_truth into a dict of test cases.
    in_outs: Optional[Dict[str, Any]] = None
    if isinstance(ground_truth, str):
        logger.debug(f"Raw ground_truth string for sample: {ground_truth[:1000]}")
        try:
            in_outs = json.loads(ground_truth)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse ground_truth JSON string: {e}. GT (first 200 chars): {ground_truth[:200]}")
            return _apps_error_result(f"Ground_truth JSONDecodeError: {e}")
    elif isinstance(ground_truth, dict):
        logger.debug(f"ground_truth is already a dict: {str(ground_truth)[:1000]}")
        in_outs = ground_truth  # It's already parsed (likely by JSONL loader)
    else:
        logger.error(
            f"ground_truth is neither a string nor a dict. Type: {type(ground_truth)}. Value (first 200 chars): {str(ground_truth)[:200]}"
        )
        return _apps_error_result("Invalid ground_truth type.", f"Invalid ground_truth type: {type(ground_truth)}")

    if not isinstance(in_outs, dict) or "inputs" not in in_outs or "outputs" not in in_outs:
        logger.error(
            f"Parsed ground_truth is not in the expected format (dict with 'inputs' and 'outputs'). Parsed: {str(in_outs)[:200]}"
        )
        return _apps_error_result("Invalid ground_truth structure after parsing.")

    fn_name_from_gt = in_outs.get("fn_name")
    if not fn_name_from_gt:
        logger.warning("fn_name not found in ground_truth dict, will rely on system prompt for main() or full script.")
    if logger.isEnabledFor(logging.INFO):
        # Only serialise the (possibly large) test-case dict when the line is emitted;
        # default=str keeps a non-JSON value from aborting the evaluation.
        logger.info(
            f"Using fn_name from ground_truth (if present): {fn_name_from_gt}. Parsed in_outs (first 500 chars of dump): {json.dumps(in_outs, default=str)[:500]}"
        )

    timeout = kwargs.get("execution_timeout", 10)
    debug_execution = bool(kwargs.get("debug_execution", True))

    # Work on a copy without 'fn_name' so the caller's dict is untouched and the
    # execution utility takes its standard-input path.
    in_outs_for_check = {key: value for key, value in in_outs.items() if key != "fn_name"}
    if "fn_name" in in_outs:
        logger.info("Removed 'fn_name' from in_outs for check_correctness to use standard_input path.")

    # Defaults used when execution does not complete.
    score = 0.0
    reason_msg = "Evaluation did not complete successfully."
    metrics: Dict[str, MetricResult] = {}

    try:
        results_list, exec_metadata_list = check_correctness(
            in_outs=in_outs_for_check,
            generation=code_solution,
            timeout=timeout,
            debug=debug_execution,
        )

        if not results_list:  # Should not happen if check_correctness returns properly
            reason_msg = "Execution utility returned no results."
            logger.error(reason_msg)
            metrics["execution_error"] = MetricResult(score=0.0, reason=reason_msg, is_score_valid=False)
        else:
            # Only a literal True counts as a pass; False and error codes
            # (-1 runtime/timeout, -2 compilation error) count as failures.
            num_tests = len(results_list)
            passed_count = sum(1 for res in results_list if res is True)
            score = float(passed_count) / num_tests
            reason_msg = f"Passed {passed_count}/{num_tests} test cases."
            logger.info(f"Execution result: {reason_msg}")

            metrics["pass_rate"] = MetricResult(score=score, reason=f"{passed_count}/{num_tests}")
            # Diagnostics are stored as JSON strings; default=str is deliberate so that a
            # non-JSON value in them cannot raise and wipe out the computed score.
            metrics["raw_results"] = MetricResult(
                score=0.0, reason=json.dumps(results_list, default=str), is_score_valid=False
            )

            if exec_metadata_list:
                first_meta = exec_metadata_list[0] if isinstance(exec_metadata_list[0], dict) else {}
                if len(exec_metadata_list) == 1 and first_meta.get("error"):
                    # A single entry with an 'error' is treated as applying to the whole attempt.
                    reason_msg += f" Execution Error: {first_meta['error']}"
                    metrics["execution_error_details"] = MetricResult(
                        score=0.0,
                        reason=json.dumps(first_meta, default=str),
                        is_score_valid=False,
                    )
                else:
                    metrics["execution_metadata"] = MetricResult(
                        score=0.0,
                        reason=json.dumps(exec_metadata_list, default=str),
                        is_score_valid=False,
                    )
                    # If it's a "Wrong Answer" and score is 0, enhance the reason_msg
                    if score == 0.0 and first_meta.get("error_message") == "Wrong Answer":
                        reason_msg += (
                            f". First fail details: Inputs: {first_meta.get('inputs', 'N/A')}, "
                            f"Expected: {first_meta.get('expected', 'N/A')}, "
                            f"Got: {first_meta.get('output', 'N/A')}"
                        )

    except Exception as e:
        score = 0.0  # Ensure score is 0 on any unexpected error in this block
        reason_msg = f"Error during code execution or result processing: {type(e).__name__}: {e}"
        logger.error(reason_msg, exc_info=True)
        metrics["evaluation_error"] = MetricResult(score=0.0, reason=reason_msg, is_score_valid=False)

    return EvaluateResult(score=score, metrics=metrics, reason=reason_msg)
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
# Copyright 2024 PRIME team and/or its affiliates (Adapted for reward-kit)
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import multiprocessing
|
|
17
|
+
import os
|
|
18
|
+
import sys
|
|
19
|
+
import traceback
|
|
20
|
+
from typing import Any, Dict, List, Optional
|
|
21
|
+
|
|
22
|
+
# Adapted import to point to our local apps_testing_util.py
|
|
23
|
+
from .apps_testing_util import run_test
|
|
24
|
+
|
|
25
|
+
# Note: The original file had a compute_score function.
|
|
26
|
+
# We are primarily interested in check_correctness and its helper _temp_run.
|
|
27
|
+
# The main reward function in apps_coding_reward.py will call check_correctness.
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _temp_run(
    sample: Dict[str, Any],
    generation: str,
    debug: bool,
    result_list: list,
    metadata_list: list,
    timeout: int,
):
    """
    Runs `run_test` on one sample and appends the outcome to the shared lists.

    Used as the subprocess target of `check_correctness`. Exactly one entry is
    appended to each list: the values returned by `run_test`, or, if it raises,
    a list of -1 error codes (one per input, at least one) plus a metadata dict
    holding the traceback.

    Args:
        sample: Test-case dict passed to `run_test` as `in_outs`.
        generation: The code string under test.
        debug: Passed to `run_test`; also enables this function's own prints.
        result_list: List (or list proxy) receiving the per-input results.
        metadata_list: List (or list proxy) receiving the metadata.
        timeout: Timeout passed to `run_test`.
    """
    # Remember the current streams so they can be restored afterwards
    # (run_test may replace sys.stdout/sys.stderr while capturing output).
    original_stdout = sys.stdout
    original_stderr = sys.stderr
    if debug:
        print("[_temp_run] Executing run_test for sample. Debug prints from run_test should be visible.")

    try:
        res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)
    except Exception:
        # run_test is designed to return error codes like -1 or -2 in `res`;
        # this catch-all covers the unexpected case where it raises instead.
        num_inputs = len(sample.get("inputs", []))
        tb_str = traceback.format_exc()
        if debug:
            # Write to the saved stream so the message is not lost in a redirected one.
            print(f"[_temp_run] Exception caught: {tb_str}", file=original_stdout)
        res = [-1] * num_inputs if num_inputs > 0 else [-1]  # Mark all as error
        metadata = {"error": "Exception in _temp_run/run_test", "traceback": tb_str}
    finally:
        sys.stdout = original_stdout  # Always restore
        sys.stderr = original_stderr  # Always restore

    # Appended outside the try block so each list receives exactly one entry.
    result_list.append(res)
    metadata_list.append(metadata)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def check_correctness(
    in_outs: Optional[dict], generation: str, timeout: int = 10, debug: bool = True
) -> tuple[List[Any], List[Dict[str, Any]]]:
    """
    Checks correctness of code generation with a global timeout using multiprocessing.

    `_temp_run` is executed in a separate process. The parent waits at most
    `timeout + 1` seconds for it, then kills it. The global timeout is to catch
    some extreme/rare cases not handled by the timeouts inside `run_test`.

    Args:
        in_outs: Dictionary with "inputs" and "outputs" lists, and optionally "fn_name".
        generation: The code string to test.
        timeout: Timeout in seconds passed to `run_test`; the subprocess as a
            whole is given `timeout + 1` seconds.
        debug: Debug flag passed to `run_test`.

    Returns:
        A tuple containing:
        - A list of results (e.g., booleans for pass/fail, or error codes).
          `[-1]` when `in_outs` is missing/malformed; one `-1` per input (at
          least one) when the subprocess produced no result.
        - A list of metadata dictionaries describing the run or the error.
    """
    if not in_outs or "inputs" not in in_outs or not isinstance(in_outs["inputs"], list):
        # Handle cases where in_outs might be None or malformed early
        return [-1], [{"error": "Invalid or missing in_outs structure"}]

    # The context manager shuts the manager's server process down on exit;
    # without it every call would leave one behind.
    with multiprocessing.Manager() as manager:
        result_proxy = manager.list()
        metadata_proxy = manager.list()

        process = multiprocessing.Process(
            target=_temp_run,
            args=(in_outs, generation, debug, result_proxy, metadata_proxy, timeout),
        )
        process.start()
        process.join(timeout=timeout + 1)  # Slightly longer than the inner timeout

        if process.is_alive():
            process.kill()  # Force kill if still alive after timeout
            process.join()  # Reap the killed child so it does not linger as a zombie

        # Copy out of the proxies while the manager is still running.
        collected_results = list(result_proxy)
        collected_metadata = list(metadata_proxy)

    if collected_results:
        # _temp_run appends a single entry: the list of per-input results from run_test.
        return collected_results[0], collected_metadata

    # Nothing was appended (e.g., process killed before finishing): report every input as an error.
    num_inputs = len(in_outs["inputs"])
    if debug:
        print("Global timeout or process killed before results could be appended.")
    if num_inputs == 0:
        return [-1], [{"error": "Global timeout"}]

    final_results = [-1] * num_inputs
    final_metadata = [
        {
            "error": "Global timeout or process killed prematurely",
            "details": "No results returned from subprocess.",
        }
        for _ in range(num_inputs)  # independent dicts rather than one shared object
    ]
    return final_results, final_metadata
|