eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,559 @@
|
|
|
1
|
+
# Copyright 2024 PRIME team and/or its affiliates
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# Borrowed from: https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/utils.py
|
|
16
|
+
# Adapted for reward-kit: Removed pyext.RuntimeModule, other minor adjustments may be needed.
|
|
17
|
+
|
|
18
|
+
import ast
|
|
19
|
+
import faulthandler
|
|
20
|
+
import importlib.util # Added for dynamic module loading
|
|
21
|
+
import json
|
|
22
|
+
import platform
|
|
23
|
+
import re # Added for re.search
|
|
24
|
+
|
|
25
|
+
# to run the solution files we're using a timing based approach
|
|
26
|
+
import signal
|
|
27
|
+
import sys
|
|
28
|
+
import textwrap # Added for dedenting model output
|
|
29
|
+
import traceback
|
|
30
|
+
|
|
31
|
+
# used for debugging to time steps
|
|
32
|
+
from datetime import datetime
|
|
33
|
+
from enum import Enum
|
|
34
|
+
|
|
35
|
+
# for capturing the stdout
|
|
36
|
+
from io import StringIO
|
|
37
|
+
|
|
38
|
+
# used for testing the code that reads from input
|
|
39
|
+
from unittest.mock import mock_open, patch
|
|
40
|
+
|
|
41
|
+
import numpy as np
|
|
42
|
+
|
|
43
|
+
# from pyext import RuntimeModule # Removed this problematic import
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def truncatefn(s, length=300):
    """Shorten *s* to roughly *length* characters for display in diagnostics.

    Strings no longer than *length* are returned unchanged. Longer strings keep
    their first ``length // 2`` and last ``length - length // 2`` characters,
    joined by the marker ``"...(truncated) ..."``. A *length* of 0 (or less)
    keeps no characters from either end, leaving only the marker.

    Raises:
        AssertionError: if *s* is not a ``str``.
    """
    assert isinstance(s, str)
    if len(s) <= length:
        return s

    length = max(length, 0)
    head = length // 2
    tail = length - head
    # Slice the tail by absolute index: ``s[-tail:]`` would return the whole
    # string when tail == 0 instead of an empty suffix.
    return s[:head] + "...(truncated) ..." + s[len(s) - tail :]
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class CODE_TYPE(Enum):
    """Selects how ``run_test`` exercises a generated solution."""

    # A named function (``in_outs["fn_name"]``) is called with decoded arguments.
    call_based = 0
    # The code is wrapped as a script: stdin is mocked and stdout is captured.
    standard_input = 1
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# used to capture stdout as a list
|
|
60
|
+
# from https://stackoverflow.com/a/16571630/6416660
|
|
61
|
+
# alternative use redirect_stdout() from contextlib
|
|
62
|
+
class Capturing(list):
    """Context manager that redirects ``sys.stdout`` into an in-memory buffer.

    On exit, the previous ``sys.stdout`` is restored and the entire captured
    text is appended to this list as ONE string (possibly empty).
    Adapted from https://stackoverflow.com/a/16571630/6416660;
    ``contextlib.redirect_stdout`` is an alternative.
    """

    def __enter__(self):
        buffer = StringIO()
        # Make close() a no-op so getvalue() still works in __exit__.
        buffer.close = lambda: None
        self._stdout = sys.stdout
        self._stringio = buffer
        sys.stdout = buffer
        return self

    def __exit__(self, *args):
        captured = self._stringio.getvalue()
        sys.stdout = self._stdout
        del self._stringio  # free up some memory
        self.append(captured)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def only_int_check(val):
    """Return True when *val* is an ``int`` instance (``bool`` included, as a subclass)."""
    is_int = isinstance(val, int)
    return is_int
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def string_int_check(val):
    """Return True when *val* is a non-empty string made only of digit characters."""
    if not isinstance(val, str):
        return False
    return val.isdigit()
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def combined_int_check(val):
    """Return True for an ``int`` (incl. ``bool``) or a non-empty all-digit string."""
    if isinstance(val, int):
        return True
    return isinstance(val, str) and val.isdigit()
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def clean_traceback(error_traceback):
    """Drop harness frames so a traceback starts at the dynamically executed code.

    The traceback is cut at the first ``File "<string>"`` frame, or, failing
    that, at ``File "<dynamic_module>"``. The cut text is re-prefixed with a
    ``Traceback (most recent call last):`` header. If neither marker is present,
    the text is returned unchanged.
    """
    for marker in ('File "<string>"', 'File "<dynamic_module>"'):
        start = error_traceback.find(marker)
        if start != -1:
            return "Traceback (most recent call last):\n " + error_traceback[start:]
    return error_traceback
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _load_module_from_string(module_name, code_string):
|
|
99
|
+
"""Loads a Python module from a string using importlib."""
|
|
100
|
+
spec = importlib.util.spec_from_loader(module_name, loader=None, origin="<generated_code>")
|
|
101
|
+
if spec is None:
|
|
102
|
+
raise ImportError(f"Could not create spec for dynamic module '{module_name}'")
|
|
103
|
+
|
|
104
|
+
module = importlib.util.module_from_spec(spec)
|
|
105
|
+
|
|
106
|
+
# Execute the code in the new module's namespace
|
|
107
|
+
# Ensure that the module is usable by adding it to sys.modules temporarily if needed,
|
|
108
|
+
# or by ensuring its __dict__ is correctly populated.
|
|
109
|
+
try:
|
|
110
|
+
exec(code_string, module.__dict__)
|
|
111
|
+
# sys.modules[module_name] = module # Optional: if other parts of the code expect it in sys.modules
|
|
112
|
+
except Exception as e:
|
|
113
|
+
raise
|
|
114
|
+
return module
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# Standard-library imports prepended to every generated solution, so that common
# names (``deque``, ``gcd``, ``heappush``, ...) are available without imports.
_SOLUTION_PRELUDE = (
    "from string import *\n"
    "from re import *\n"
    "from datetime import *\n"
    "from collections import *\n"
    "from heapq import *\n"
    "from bisect import *\n"
    "from copy import *\n"
    "from math import *\n"
    "from random import *\n"
    "from statistics import *\n"
    "from itertools import *\n"
    "from functools import *\n"
    "from operator import *\n"
    "from io import *\n"
    "from sys import *\n"
    "from json import *\n"
    "from builtins import *\n"
    "from typing import *\n"
    "import string\n"
    "import re\n"
    "import datetime\n"
    "import collections\n"
    "import heapq\n"
    "import bisect\n"
    "import copy\n"
    "import math\n"
    "import random\n"
    "import statistics\n"
    "import itertools\n"
    "import functools\n"
    "import operator\n"
    "import io\n"
    "import sys\n"
    "import json\n"
    "sys.setrecursionlimit(6*10**5)\n"
)


def _as_text(value):
    """Return *value* as one string: lists are newline-joined, other non-strings use str()."""
    if isinstance(value, list):
        return "\n".join(str(item) for item in value)
    return value if isinstance(value, str) else str(value)


def _safe_json_dumps(value):
    """Serialize *value* for diagnostics, falling back to repr() when JSON cannot encode it."""
    try:
        return json.dumps(value)
    except (TypeError, ValueError, RecursionError):
        return repr(value)


def _int_keys(mapping):
    """Return a copy of *mapping* with all-decimal-digit string keys converted to int.

    JSON forces object keys to be strings; this undoes that for integer-keyed dicts.
    """
    return {int(k) if isinstance(k, str) and k.isdecimal() else k: v for k, v in mapping.items()}


def _strip_main_guard(source):
    """Inline a trailing ``if __name__ == '__main__':`` block's body at top level.

    When such a block is found, the code is round-tripped through
    ``ast.unparse``, which drops comments. Unparseable or empty source is
    returned unchanged.
    """
    try:
        tree = ast.parse(source)
        last = tree.body[-1]
        if isinstance(last, ast.If) and ast.unparse(last.test).strip() == "__name__ == '__main__'":
            return ast.unparse(tree.body[:-1]) + "\n" + ast.unparse(last.body)
    except Exception:
        # Leave the code as-is; real syntax errors surface when the module is loaded.
        pass
    return source


def _wrap_stdin_solution(sol, source, debug=False):
    """Append *source* to *sol*: hoist import lines to module scope, wrap the rest in ``def code():``."""
    try:
        dedented = textwrap.dedent(source)
    except Exception as e:
        # In case dedent fails (e.g. on empty or malformed string), use original
        if debug:
            print(f"Warning: textwrap.dedent failed on model code: {e}. Using original code.")
        dedented = source

    # Heuristic: a defined-but-never-called main() is invoked at the end of code().
    main_defined = "def main(" in dedented
    main_called_at_toplevel = re.search(r"^\s*main\s*\(\s*\)", dedented, re.MULTILINE) is not None

    body_lines = []
    for line in dedented.split("\n"):
        stripped = line.strip()
        if stripped.startswith("from ") or stripped.startswith("import "):
            sol += stripped + "\n"  # imports go to module scope
        else:
            body_lines.append("\t" + line)  # everything else becomes the body of code()

    if main_defined and not main_called_at_toplevel:
        body_lines.append("\tmain()")
        if debug:
            print("Appended main() call as it was defined but not found called at top level.")

    return sol + "stdin = sys.stdin\nstdout = sys.stdout\ndef code():\n" + "\n".join(body_lines)


def _call_based_matches(output, expected):
    """Compare a call-based return value with the expected output.

    Accepts an exact match, a match against the sole element of a one-element
    expected list, and list-of-tuples output against expected list-of-lists
    (directly or wrapped in an outer list).
    """
    matched = output == expected
    if isinstance(expected, list) and len(expected) == 1:
        matched = matched or (output == expected[0])
    try:
        if (
            isinstance(output, list)
            and output
            and isinstance(output[0], tuple)
            and isinstance(expected, list)
            and expected
            and isinstance(expected[0], list)
        ):
            as_lists = [list(x) for x in output]
            matched = matched or (as_lists == expected)
            matched = matched or (as_lists == expected[0])
    except Exception:
        pass  # best-effort extra comparison; keep the verdict computed so far
    return matched


def run_test(in_outs, test=None, debug=False, timeout=15):
    """Run generated code ``test`` against the APPS-style cases in ``in_outs``.

    The mode is selected by ``in_outs.get("fn_name")``:

    * call-based (``fn_name`` set): ``test`` is loaded as a module, and
      ``fn_name`` is called with arguments JSON-decoded from each non-blank line
      of the case input. The call goes to a ``Solution()`` instance when the
      code contains ``class Solution``.
    * standard input (``fn_name`` missing/None): ``test`` is wrapped in a
      ``code()`` function and run with stdin mocked to the case input. Captured
      stdout is compared line by line (per-line whitespace stripped), with a
      float-tolerance fallback.

    Args:
        in_outs: dict with ``"inputs"`` and ``"outputs"`` lists and an optional
            ``"fn_name"``. NOTE: ``in_outs["outputs"]`` entries are
            JSON-decoded / key-normalized in place.
        test: the generated solution source; ``None`` raises ``AssertionError``.
        debug: print progress and diagnostics.
        timeout: seconds armed with ``signal.alarm`` around module loading and
            around each case (Unix only). No SIGALRM handler is installed in
            this module. NOTE(review): presumably the caller installs one, or
            runs this in a disposable worker process; confirm.

    Returns:
        ``None`` if ``in_outs`` is empty. Otherwise ``(results, metadata)``,
        where ``results`` has one entry per executed case:
        - the comparison result (``True`` on pass);
        - ``-1`` for a runtime error;
        - ``-2`` for a load/compile error.
        Execution stops at the first case that does not pass, and ``metadata``
        describes it. ``metadata`` is ``{}`` when every case passes.
    """
    # Disable functionalities that can make destructive changes to the test.
    reliability_guard()

    if debug:
        print(f"start = {datetime.now().time()}")

    if not in_outs:
        # Nothing to run; the "no result" value is kept for backward compatibility.
        return None

    if in_outs.get("fn_name") is None:
        which_type = CODE_TYPE.standard_input
        method_name = None
    else:
        which_type = CODE_TYPE.call_based
        method_name = in_outs["fn_name"]

    if debug:
        print(f"loaded input_output = {datetime.now().time()}")

    if test is None:
        raise AssertionError("should not happen: test code is none")

    results = []
    sol = _SOLUTION_PRELUDE
    if debug:
        print(f"loading test code = {datetime.now().time()}")

    if which_type == CODE_TYPE.call_based:
        sol += test
        if debug:
            print(f"sol = {sol}")
        module_name = "tmp_sol_call_based"
    else:
        test = _strip_main_guard(test)
        sol = _wrap_stdin_solution(sol, test, debug)
        if debug:
            print(f"Constructed sol for standard_input: {sol}")
        method_name = "code"  # the wrapper function built by _wrap_stdin_solution
        module_name = "tmp_sol_std_input"

    # Load the module (compile + run its top level) under the time limit; failures are -2.
    signal.alarm(timeout)  # Unix-only
    try:
        tmp_sol = _load_module_from_string(module_name, sol)
        if which_type == CODE_TYPE.call_based and "class Solution" in test:
            tmp = tmp_sol.Solution()
        else:
            tmp = tmp_sol
    except Exception as e:
        signal.alarm(0)
        error_traceback = traceback.format_exc()
        if debug:
            print(f"type {which_type.value} compilation error = {e}")
        results.append(-2)
        return results, {"error": repr(e), "traceback": clean_traceback(error_traceback)}
    finally:
        signal.alarm(0)

    if debug:
        print(f"get method = {datetime.now().time()}")

    try:
        method = getattr(tmp, method_name)
    except AttributeError:
        error_traceback = traceback.format_exc()
        results.append(-2)
        return results, {
            "error": f"AttributeError: Method '{method_name}' not found in dynamically loaded module.",
            "traceback": clean_traceback(error_traceback),
        }
    except Exception as e:  # e.g. a user-defined __getattr__ raising something else
        error_traceback = traceback.format_exc()
        results.append(-2)
        return results, {"error": repr(e), "traceback": clean_traceback(error_traceback)}

    for index, inputs_str in enumerate(in_outs["inputs"]):
        raw_inputs = inputs_str

        if which_type == CODE_TYPE.call_based:
            # Each non-blank input line is one JSON-encoded positional argument.
            current_inputs = [json.loads(line) for line in inputs_str.split("\n") if line.strip()]
            if isinstance(in_outs["outputs"][index], str):
                in_outs["outputs"][index] = json.loads(in_outs["outputs"][index])

            # Share a ~300-character budget across input lines (at least 1 per line).
            truncate_line_size = max(300 // (raw_inputs.count("\n") + 1), 1)
            raw_inputs_truncated = "\n".join(
                truncatefn(line, truncate_line_size) for line in raw_inputs.strip().split("\n")
            )
            expected = in_outs["outputs"][index]
            raw_outputs_truncated = truncatefn(
                expected if isinstance(expected, str) else _safe_json_dumps(expected), 200
            )

            # JSON stringifies dict keys: restore integer keys on a leading dict argument,
            # keeping any remaining arguments.
            if current_inputs and isinstance(current_inputs[0], dict):
                try:
                    current_inputs = [_int_keys(current_inputs[0])] + current_inputs[1:]
                except ValueError:
                    pass  # keep the decoded arguments unchanged
        else:  # standard_input: the whole input is fed through the mocked stdin
            current_inputs = inputs_str
            raw_inputs_truncated = truncatefn(_as_text(raw_inputs))
            raw_outputs_truncated = truncatefn(_as_text(in_outs["outputs"][index]), 200)

        # Same key restoration for the expected output (a dict, or a list starting with one).
        expected = in_outs["outputs"][index]
        try:
            if isinstance(expected, dict):
                in_outs["outputs"][index] = _int_keys(expected)
            elif isinstance(expected, list) and expected and isinstance(expected[0], dict):
                expected[0] = _int_keys(expected[0])
        except ValueError:
            pass

        if debug:
            print(
                f"time: {datetime.now().time()} testing index = {index} inputs = {current_inputs}, type = {which_type}"
            )

        if which_type == CODE_TYPE.call_based:
            signal.alarm(timeout)  # Unix-only
            faulthandler.enable()
            try:
                output = method(*current_inputs)
                # Ground-truth sequences are lists, never tuples.
                if isinstance(output, tuple):
                    output = list(output)
                tmp_result = _call_based_matches(output, in_outs["outputs"][index])
            except Exception as e:
                signal.alarm(0)
                error_traceback = traceback.format_exc()
                if debug:
                    print(f"Call-based runtime error or time limit exceeded error = {e}")
                results.append(-1)  # Indicate error
                return results, {"error": repr(e), "traceback": clean_traceback(error_traceback)}
            finally:
                # Disarm on every exit path (also SystemExit) before any result is reported,
                # so no alarm is left pending after a wrong-answer return.
                signal.alarm(0)
                faulthandler.disable()

            results.append(tmp_result)
            if tmp_result is not True:
                return results, {
                    "output": truncatefn(_safe_json_dumps(output), 200),
                    "expected": raw_outputs_truncated,
                    "inputs": raw_inputs_truncated,
                    "error_message": "Wrong Answer",
                }
        else:
            processed_inputs_str = _as_text(inputs_str)
            ground_truth_str = _as_text(in_outs["outputs"][index])

            faulthandler.enable()
            signal.alarm(timeout)  # Unix-only
            try:
                with Capturing() as output_lines:
                    call_method(method, processed_inputs_str)
                captured_output_str = "".join(output_lines).rstrip()
            except Exception as e:
                signal.alarm(0)
                error_traceback = traceback.format_exc()
                results.append(-1)  # Indicate error
                return results, {"error": repr(e), "traceback": clean_traceback(error_traceback)}
            finally:
                signal.alarm(0)
                faulthandler.disable()

            # Line-by-line comparison, ignoring whitespace around each line and trailing
            # blank lines on BOTH sides (the captured output was rstripped above).
            output_for_compare = [line.strip() for line in captured_output_str.splitlines()]
            expected_for_compare = [line.strip() for line in ground_truth_str.rstrip().splitlines()]
            tmp_result = output_for_compare == expected_for_compare

            # Fall back to a numeric comparison with floating-point tolerance.
            if not tmp_result and len(output_for_compare) == len(expected_for_compare):
                try:
                    output_float = [float(x) for x in output_for_compare]
                    gt_float = [float(x) for x in expected_for_compare]
                    if np.allclose(output_float, gt_float):
                        tmp_result = True
                except (ValueError, TypeError):
                    pass  # not every line is a single number; keep the string verdict

            results.append(tmp_result)
            if tmp_result is not True:
                return results, {
                    "output": truncatefn(captured_output_str, 200),
                    "expected": raw_outputs_truncated,
                    "inputs": raw_inputs_truncated,
                    "error_message": "Wrong Answer",
                }

    # If all test cases for this sample passed
    return results, {}
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def custom_compare_(output, ground_truth):
    """Legacy comparison of list-of-lines output against a ground-truth string.

    Non-list output never matches. A list matches when, ignoring leading and
    trailing whitespace of the whole text, either of these equals
    *ground_truth*:
    - its lines joined as-is; or
    - its lines joined after each is individually stripped.

    ``run_test`` now does its own comparisons; this is kept for any external
    callers.
    """
    if not isinstance(output, list):
        return False

    joined = "\n".join(output)
    if joined.strip() == ground_truth.strip():
        return True

    per_line = "\n".join(item.strip() for item in output)
    return per_line.strip() == ground_truth.strip()
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def stripped_string_compare(s1, s2):
    """Return True when *s1* and *s2* are equal after trimming surrounding whitespace."""
    return s1.strip() == s2.strip()
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def call_method(method, inputs_str_for_mock):
    """Call ``method()`` with stdin and ``open()`` mocked to serve *inputs_str_for_mock*.

    The patches are entered in this exact order. This matches the stacked
    ``mock.patch`` decorators it replaces, which are entered bottom-up.

    1. ``read`` / ``readlines`` / ``readline`` on the *current* ``sys.stdin``
       object. This serves code that grabbed a reference to that object before
       the call, e.g. the solution wrapper's module-level ``stdin = sys.stdin``.
    2. ``sys.stdin`` itself, replaced by a ``StringIO`` over the input. This
       serves fresh ``sys.stdin`` lookups such as ``input()``.
    3. ``builtins.open``, returning a mock file whose contents are the input.

    Returns whatever ``method()`` returns. A ``SystemExit`` raised by it is
    swallowed (the result is then ``None``); any other exception propagates.
    """
    remaining_lines = iter(inputs_str_for_mock.split("\n"))

    def _readline(*args):
        # One input line per call, newline-terminated as readline() would return it.
        return next(remaining_lines) + "\n"

    def _readlines(*args):
        return [line + "\n" for line in inputs_str_for_mock.split("\n")]

    def _read(*args):
        return inputs_str_for_mock

    with patch("sys.stdin.read", _read), \
            patch("sys.stdin.readlines", _readlines), \
            patch("sys.stdin.readline", _readline), \
            patch("sys.stdin", StringIO(inputs_str_for_mock)), \
            patch("builtins.open", mock_open(read_data=inputs_str_for_mock)):
        try:
            return method()
        except SystemExit:
            # Solutions may call exit(); treat that as a normal end of the program.
            return None
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
def reliability_guard(maximum_memory_bytes=None):
    """
    Apply light-weight resource settings before executing generated code.

    What this function actually does:

    * If ``maximum_memory_bytes`` is given, caps ``RLIMIT_AS`` and
      ``RLIMIT_DATA`` (when the platform defines them) and, except on macOS,
      ``RLIMIT_STACK``. This uses the POSIX-only ``resource`` module, so
      passing a limit raises ``ImportError`` where that module is unavailable.
    * Disables ``faulthandler``.
    * Sets ``OMP_NUM_THREADS=1`` in the process environment.

    Unlike the upstream helper it was derived from, it does NOT disable
    destructive builtins or ``os`` / ``shutil`` / ``subprocess`` functions.
    Isolation must come from a higher level (a separate process or a
    dedicated sandbox environment).

    WARNING
    This function is NOT a security sandbox. Untrusted code, including
    model-generated code, should not be blindly executed outside of one. See the
    Codex paper for more information about OpenAI's code sandbox, and proceed
    with caution.
    """
    if maximum_memory_bytes is not None:
        import resource  # POSIX-only, hence imported lazily

        limits = (maximum_memory_bytes, maximum_memory_bytes)
        if hasattr(resource, "RLIMIT_AS"):
            resource.setrlimit(resource.RLIMIT_AS, limits)
        if hasattr(resource, "RLIMIT_DATA"):
            resource.setrlimit(resource.RLIMIT_DATA, limits)
        # Skipped on macOS (Darwin), as upstream does.
        # NOTE(review): presumably setrlimit(RLIMIT_STACK, ...) fails there with these values; confirm.
        if platform.uname().system != "Darwin" and hasattr(resource, "RLIMIT_STACK"):
            resource.setrlimit(resource.RLIMIT_STACK, limits)

    faulthandler.disable()

    import os

    # Keep numeric libraries single-threaded inside the execution process.
    os.environ["OMP_NUM_THREADS"] = "1"
|