eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,861 @@
1
+ """
2
+ C/C++ code execution reward functions for evaluating C/C++ code correctness.
3
+
4
+ This module provides functions to evaluate the correctness of C/C++ code by:
5
+ 1. Extracting code blocks from messages
6
+ 2. Executing the code using the Piston execution engine
7
+ 3. Comparing the output with expected results or running against test cases
8
+ """
9
+
10
+ import asyncio
11
+ import json
12
+ import os
13
+ import re
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Union
16
+
17
+ import aiohttp
18
+
19
+ from ..models import EvaluateResult, Message, MetricResult
20
+ from ..reward_function import reward_function
21
+
22
+
23
@dataclass
class TestResult:
    """Outcome of running the submitted program against one test case."""

    # Identifier of the test case.
    test_name: str
    # Score recorded for this test; 0.0 unless a run assigns another value.
    score: float = 0.0
    # Verdict label; stays "SKIPPED" unless a run assigns another value.
    status: str = "SKIPPED"
    # Free-form explanation attached to the verdict.
    feedback: str = ""
    # Output captured from the program for this test.
    actual_output: str = ""
    # Output the test case expects.
    expected_output: str = ""
35
+
36
+
37
class PistonError(Exception):
    """Error reported by, or while talking to, the Piston API."""
41
+
42
+
43
class PistonClient:
    """
    Asynchronous client for the Piston code execution API.

    Piston is a general purpose code execution engine:
    https://github.com/engineer-man/piston
    """

    def __init__(
        self,
        base_endpoint: str = "https://emkc.org/api/v2/piston",
        session: Optional[aiohttp.ClientSession] = None,
        timeout: int = 30,
    ):
        """
        Args:
            base_endpoint: Base URL of the Piston API (no trailing slash).
            session: Optional pre-built aiohttp session; one is created
                lazily on first use when omitted.
            timeout: Socket read timeout handed to ``aiohttp.ClientTimeout``
                for the lazily created session.
        """
        self.base_endpoint = base_endpoint
        self._session = session
        self.timeout = timeout

    @property
    def session(self):
        """Return the HTTP session, creating it on first access."""
        if self._session is None:
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(sock_read=self.timeout),
                connector=aiohttp.TCPConnector(
                    limit=10,
                    ttl_dns_cache=300,
                    keepalive_timeout=30,
                ),
            )
        return self._session

    async def close(self):
        """Close the session (if any) and forget it so it can be recreated."""
        if self._session:
            await self._session.close()
            self._session = None

    async def get_runtimes(self) -> List[Dict[str, Any]]:
        """
        Get list of supported runtimes.

        Raises:
            PistonError: If the API does not answer with HTTP 200.
        """
        async with self.session.get(f"{self.base_endpoint}/runtimes") as response:
            if response.status != 200:
                raise PistonError(f"Error getting runtimes: {response.status}")
            return await response.json()

    async def execute(
        self,
        language: str,
        version: str,
        files: List[Dict[str, str]],
        stdin: str = "",
        args: Optional[List[str]] = None,
        compile_timeout: int = 10000,
        run_timeout: int = 3000,
        compile_memory_limit: int = -1,
        run_memory_limit: int = -1,
    ) -> Dict[str, Any]:
        """
        Execute code using the Piston API.

        Args:
            language: Programming language (e.g., "c", "cpp")
            version: Version of the language (e.g., "10.2.0")
            files: List of files to include in execution (each with "name" and "content")
            stdin: Standard input to provide to the program
            args: Command-line arguments to pass to the program; ``None``
                (the default) sends an empty list
            compile_timeout: Maximum compilation time in milliseconds
            run_timeout: Maximum execution time in milliseconds
            compile_memory_limit: Maximum memory for compilation in bytes (-1 for unlimited)
            run_memory_limit: Maximum memory for execution in bytes (-1 for unlimited)

        Returns:
            Dictionary containing the execution results

        Raises:
            PistonError: If the API does not answer with HTTP 200, or the
                response body carries a "message" field.
        """
        payload = {
            "language": language,
            "version": version,
            "files": files,
            "stdin": stdin,
            # A None sentinel replaces the former shared mutable ``[]`` default.
            "args": [] if args is None else args,
            "compile_timeout": compile_timeout,
            "run_timeout": run_timeout,
            "compile_memory_limit": compile_memory_limit,
            "run_memory_limit": run_memory_limit,
        }

        async with self.session.post(
            f"{self.base_endpoint}/execute",
            json=payload,
            headers={"Content-Type": "application/json"},
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise PistonError(f"Error executing code: {response.status} - {error_text}")

            result = await response.json()

            # A "message" key in an otherwise successful response is treated as an error.
            if "message" in result:
                raise PistonError(result["message"])

            return result
143
+
144
+
145
def get_piston_client(endpoint: Optional[str] = None) -> PistonClient:
    """
    Build a PistonClient for the requested endpoint.

    The endpoint is resolved in this order: the ``endpoint`` argument when it
    is truthy, then the ``PISTON_ENDPOINT`` environment variable, then the
    public default ``https://emkc.org/api/v2/piston``.

    Args:
        endpoint: Optional custom Piston API endpoint

    Returns:
        PistonClient instance
    """
    if endpoint:
        resolved = endpoint
    else:
        resolved = os.environ.get("PISTON_ENDPOINT", "https://emkc.org/api/v2/piston")
    # Narrows the type for static checkers.
    assert isinstance(resolved, str)
    return PistonClient(base_endpoint=resolved)
158
+
159
+
160
def extract_code_blocks(text: str, language: str = "cpp") -> List[Dict[str, str]]:
    """
    Extract fenced (triple-backtick) code blocks from text.

    A block whose fence carries a language tag is kept only if the tag matches
    ``language``: for "cpp" the tags "cpp" and "c++" match, for "c" only "c",
    and for any other language the tag must equal it exactly. Blocks without a
    tag are always kept and reported with language "unknown". An empty
    ``language`` disables filtering.

    Args:
        text: The text to extract code blocks from
        language: Language to filter by (e.g., "cpp", "c")

    Returns:
        List of dictionaries with "code" and "language" keys
    """
    # The tag class includes "+", "#" and "-" in addition to word characters;
    # with a bare \w* a fence tagged "c++" could never match, which made the
    # "c++" acceptance below unreachable.
    pattern = r"```([\w+#-]*)\n([\s\S]*?)\n```"

    code_blocks = []
    for lang, code in re.findall(pattern, text):
        lang = lang.lower()

        if language and lang:
            if language == "cpp":
                if lang not in ("cpp", "c++"):
                    continue
            elif language == "c":
                if lang != "c":
                    continue
            elif language != lang:
                continue

        code_blocks.append({"language": lang or "unknown", "code": code.strip()})

    return code_blocks
190
+
191
+
192
def add_cpp_includes(code: str) -> str:
    """
    Prepend commonly needed C++ headers that the code does not already contain.

    Presence is detected by plain substring search for the exact directive
    text. ``using namespace std;`` is added only when the code contains
    neither that directive nor any ``std::`` qualifier.

    Args:
        code: C++ code

    Returns:
        The code, prefixed with any missing lines followed by a blank line;
        empty input is returned unchanged.
    """
    if not code:
        return code

    candidate_headers = (
        "#include <iostream>",
        "#include <vector>",
        "#include <string>",
        "#include <bits/stdc++.h>",
    )
    preamble = [header for header in candidate_headers if header not in code]

    uses_std = "using namespace std;" in code or "std::" in code
    if not uses_std:
        preamble.append("using namespace std;")

    if not preamble:
        return code
    return "\n".join(preamble) + "\n\n" + code
222
+
223
+
224
def add_c_includes(code: str) -> str:
    """
    Prepend commonly needed C headers that the code does not already contain.

    Presence is detected by plain substring search for the exact directive
    text of ``<stdio.h>``, ``<stdlib.h>`` and ``<string.h>``.

    Args:
        code: C code

    Returns:
        The code, prefixed with any missing includes followed by a blank line;
        empty input is returned unchanged.
    """
    if not code:
        return code

    candidate_headers = (
        "#include <stdio.h>",
        "#include <stdlib.h>",
        "#include <string.h>",
    )
    missing = [header for header in candidate_headers if header not in code]

    if not missing:
        return code
    return "\n".join(missing) + "\n\n" + code
250
+
251
+
252
async def execute_cpp_code(
    code: str,
    stdin: str = "",
    language: str = "cpp",
    version: str = "11.4.0",
    timeout: int = 5000,
    memory_limit: int = 512000000,
    piston_endpoint: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Execute C/C++ code using the Piston API.

    Missing common includes are prepended before submission. Failures are
    reported through the returned dictionary rather than raised.

    Args:
        code: C/C++ code to execute
        stdin: Standard input to provide to the program
        language: "c" or "cpp" (anything other than "cpp" is handled as C)
        version: Version of the compiler to use
        timeout: Maximum time in milliseconds, applied to both the compile
            and the run step
        memory_limit: Maximum memory in bytes for the run step
        piston_endpoint: Optional custom Piston API endpoint

    Returns:
        Dictionary with keys "success" (bool), "output" (stdout or None) and
        "error" (message or None)
    """
    if language == "cpp":
        code = add_cpp_includes(code)
    else:
        code = add_c_includes(code)

    client = get_piston_client(piston_endpoint)

    try:
        main_file = {
            "name": "main.cpp" if language == "cpp" else "main.c",
            "content": code,
        }

        result = await client.execute(
            language=language,
            version=version,
            files=[main_file],
            stdin=stdin,
            compile_timeout=timeout,
            run_timeout=timeout,
            run_memory_limit=memory_limit,
        )

        if "compile" in result and result["compile"]["code"] != 0:
            return {
                "success": False,
                "output": None,
                "error": f"Compilation error: {result['compile']['stderr']}",
            }

        if "run" in result:
            run_info = result["run"]
            if run_info["code"] == 0:
                return {
                    "success": True,
                    "output": run_info["stdout"],
                    "error": None,
                }
            return {
                "success": False,
                # Keep partial stdout when the program printed something before failing.
                "output": (run_info["stdout"] if run_info["stdout"] else None),
                "error": f"Runtime error (exit code {run_info['code']}): {run_info['stderr']}",
            }

        return {
            "success": False,
            "output": None,
            "error": "Unknown error during execution",
        }

    except PistonError as e:
        return {
            "success": False,
            "output": None,
            "error": f"Piston error: {str(e)}",
        }
    except Exception as e:
        # Boundary handler: any other failure (network, malformed response)
        # is reported through the result dictionary.
        return {"success": False, "output": None, "error": f"Error: {str(e)}"}
    finally:
        # Await the close instead of scheduling it with create_task(): a
        # fire-and-forget task may still be pending when the caller closes the
        # event loop, which leaks the HTTP session.
        await client.close()
337
+
338
+
339
def compare_outputs(actual: str, expected: str) -> float:
    """
    Compare actual and expected outputs to calculate a similarity score.

    The comparison proceeds in stages:
      1. Whitespace-insensitive exact match -> 1.0.
      2. If both outputs parse as a single number, a tiered score based on
         the relative difference (<=0.1% -> 1.0, <=1% -> 0.95, <=10% -> 0.7,
         otherwise ``1 - rel_diff`` clamped to [0, 1]).
      3. If either output has more than one non-blank line, a line-by-line
         similarity where earlier lines weigh more (weight 1/(i+1)),
         multiplied by a penalty for differing line counts.
      4. Otherwise, edit-distance similarity of the normalised strings.

    Args:
        actual: Actual output from code execution (None is treated as "")
        expected: Expected output (None is treated as "")

    Returns:
        Similarity score between 0.0 and 1.0
    """
    if actual is None:
        actual = ""
    if expected is None:
        expected = ""

    actual_norm = re.sub(r"\s+", " ", actual.strip())
    expected_norm = re.sub(r"\s+", " ", expected.strip())

    if actual_norm == expected_norm:
        return 1.0

    try:
        actual_num = float(actual_norm)
        expected_num = float(expected_norm)

        if expected_num == 0:
            # Relative difference is undefined; only an exact zero matches.
            return 1.0 if actual_num == 0 else 0.0

        rel_diff = abs(actual_num - expected_num) / abs(expected_num)

        if rel_diff <= 0.001:
            return 1.0
        elif rel_diff <= 0.01:
            return 0.95
        elif rel_diff <= 0.1:
            return 0.7
        else:
            return max(0.0, 1.0 - min(1.0, rel_diff))
    except (ValueError, TypeError):
        # Not a pair of plain numbers; fall through to textual comparison.
        pass

    # Split into lines BEFORE collapsing whitespace. The collapsed strings
    # above can never contain "\n", so testing them for newlines would make
    # the line-wise comparison unreachable.
    actual_lines = [
        line for line in (re.sub(r"\s+", " ", raw.strip()) for raw in actual.splitlines()) if line
    ]
    expected_lines = [
        line for line in (re.sub(r"\s+", " ", raw.strip()) for raw in expected.splitlines()) if line
    ]

    if len(actual_lines) > 1 or len(expected_lines) > 1:
        common_len = min(len(actual_lines), len(expected_lines))
        if common_len == 0:
            return 0.0

        line_similarities = []
        for actual_line, expected_line in zip(actual_lines, expected_lines):
            if actual_line == expected_line:
                line_similarities.append(1.0)
            else:
                line_similarities.append(string_similarity(actual_line, expected_line))

        # Harmonic weights: line i contributes with weight 1/(i+1).
        total_weight = sum(1 / (i + 1) for i in range(common_len))
        weighted_sum = sum((1 / (i + 1)) * sim for i, sim in enumerate(line_similarities))
        similarity = weighted_sum / total_weight if total_weight > 0 else 0.0

        length_penalty = common_len / max(len(actual_lines), len(expected_lines))

        return similarity * length_penalty

    return string_similarity(actual_norm, expected_norm)
405
+
406
+
407
def string_similarity(s1: str, s2: str) -> float:
    """
    Score how alike two strings are using normalised edit distance.

    Two empty strings count as identical (1.0); exactly one empty string
    scores 0.0. Otherwise the score is ``1 - distance / len(longer string)``.

    Args:
        s1: First string
        s2: Second string

    Returns:
        Similarity score between 0.0 and 1.0
    """
    if not s1:
        return 0.0 if s2 else 1.0
    if not s2:
        return 0.0

    edits = levenshtein_distance(s1, s2)
    longest = max(len(s1), len(s2))
    # Both strings are non-empty here, so ``longest`` is at least 1.
    return 1.0 - edits / longest
426
+
427
+
428
def levenshtein_distance(s1: str, s2: str) -> int:
    """
    Compute the Levenshtein (edit) distance between two strings.

    Uses a two-row dynamic programme; the shorter string indexes the columns
    so each row has ``len(shorter) + 1`` entries.

    Args:
        s1: First string
        s2: Second string

    Returns:
        Edit distance between strings
    """
    # Order the operands so the longer string drives the outer loop.
    if len(s1) < len(s2):
        longer, shorter = s2, s1
    else:
        longer, shorter = s1, s2

    if not shorter:
        return len(longer)

    prev_row: List[int] = list(range(len(shorter) + 1))
    for row_idx, ch_long in enumerate(longer, start=1):
        cur_row = [row_idx]
        for col_idx, ch_short in enumerate(shorter, start=1):
            best = min(
                prev_row[col_idx] + 1,  # insertion
                cur_row[col_idx - 1] + 1,  # deletion
                prev_row[col_idx - 1] + (ch_long != ch_short),  # substitution
            )
            cur_row.append(best)
        prev_row = cur_row

    return prev_row[-1]
456
+
457
+
458
async def run_cpp_test_cases(
    code: str,
    test_cases: List[Dict[str, Any]],
    language: str = "cpp",
    version: str = "11.4.0",
    timeout: int = 5000,
    memory_limit: int = 512000000,
    piston_endpoint: Optional[str] = None,
) -> List[TestResult]:
    """
    Run C/C++ code against multiple test cases.

    Tests run sequentially and stop at the first one that scores 0.0. Every
    test that was not executed because of that early stop is still reported,
    with status "SKIPPED" and score 0.0, so the returned list always has one
    entry per test case and pass rates are computed over the full suite.

    Verdicts assigned to executed tests: "AC" (similarity >= 0.99), "PA"
    (0 < similarity < 0.99), "WA" (similarity 0), "CE" (compilation error),
    "RE" (any other execution failure).

    Args:
        code: C/C++ code to execute
        test_cases: List of test cases with "input" and "expected_output" keys
            (and an optional "name")
        language: "c" or "cpp"
        version: Version of the compiler to use
        timeout: Maximum execution time in milliseconds
        memory_limit: Maximum memory in bytes
        piston_endpoint: Optional custom Piston API endpoint

    Returns:
        List of TestResult objects, one per test case, in input order
    """
    results: List[TestResult] = []

    for i, test_case in enumerate(test_cases):
        test_input = test_case.get("input", "")
        expected_output = test_case.get("expected_output", "")
        test_name = test_case.get("name", f"Test {i+1}")

        execution_result = await execute_cpp_code(
            code=code,
            stdin=test_input,
            language=language,
            version=version,
            timeout=timeout,
            memory_limit=memory_limit,
            piston_endpoint=piston_endpoint,
        )

        test_result = TestResult(test_name=test_name, expected_output=expected_output)

        if execution_result["success"]:
            actual_output = execution_result["output"]
            test_result.actual_output = actual_output
            similarity = compare_outputs(actual_output, expected_output)
            test_result.score = similarity

            if similarity >= 0.99:
                test_result.status = "AC"
            elif similarity > 0:
                test_result.status = "PA"
            else:
                test_result.status = "WA"
            test_result.feedback = f"Similarity: {similarity:.2f}"
        else:
            test_result.status = "CE" if "Compilation error" in execution_result["error"] else "RE"
            test_result.feedback = execution_result["error"]
            test_result.score = 0.0

        results.append(test_result)
        if test_result.score == 0.0:
            # Fail fast, but keep the remaining tests in the report. Dropping
            # them would shrink the denominator of any pass rate derived from
            # len(results) and inflate the score.
            for j, skipped_case in enumerate(test_cases[i + 1 :], start=i + 2):
                results.append(
                    TestResult(
                        test_name=skipped_case.get("name", f"Test {j}"),
                        expected_output=skipped_case.get("expected_output", ""),
                        feedback="Skipped because an earlier test scored 0.",
                    )
                )
            break

    return results
524
+
525
+
526
@reward_function
def ioi_cpp_code_reward(
    messages: List[Message],
    ground_truth: Union[Optional[str], Optional[List[Dict[str, Any]]]],
    language: str = "cpp",
    version: str = "11.4.0",
    timeout: int = 5000,
    memory_limit: int = 512000000,
    piston_endpoint: Optional[str] = None,
    pass_threshold: float = 0.99,
    **kwargs: Any,
) -> EvaluateResult:
    """
    Synchronous entry point for the C/C++ reward, usable with ``reward_function``.

    A fresh event loop is created and installed as the current loop, the
    evaluation in ``_ioi_cpp_code_reward_impl`` is run with all arguments
    forwarded unchanged, and the loop is closed afterwards whether or not the
    evaluation raised.
    """
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)

    try:
        evaluation = _ioi_cpp_code_reward_impl(
            messages=messages,
            ground_truth=ground_truth,
            language=language,
            version=version,
            timeout=timeout,
            memory_limit=memory_limit,
            piston_endpoint=piston_endpoint,
            pass_threshold=pass_threshold,
            **kwargs,
        )
    finally:
        event_loop.close()

    return evaluation
558
+
559
+
560
def _ioi_cpp_code_reward_impl(
    messages: List[Message],
    ground_truth: Union[Optional[str], Optional[List[Dict[str, Any]]]],
    language: str = "cpp",
    version: str = "11.4.0",
    timeout: int = 5000,
    memory_limit: int = 512000000,
    piston_endpoint: Optional[str] = None,
    pass_threshold: float = 0.99,
    **kwargs: Any,
) -> EvaluateResult:
    """
    Score the C/C++ code in the last assistant message via the Piston engine.

    The first matching fenced code block of the final assistant message is
    compiled and run. Depending on ``ground_truth`` the score is:
      * non-empty list of test-case dicts: fraction of returned test results
        whose score reaches ``pass_threshold``;
      * non-empty string: similarity between program output and that string;
      * anything falsy (None, "", []): 1.0 if the program runs successfully,
        0.0 otherwise.

    Coroutines are driven with ``asyncio.get_event_loop().run_until_complete``,
    so the caller must have a usable, non-running event loop installed.

    Args:
        messages: Generated conversation messages
        ground_truth: Expected output string or list of test case dictionaries.
        language: Programming language ("c" or "cpp")
        version: Version of the compiler to use
        timeout: Maximum execution time in milliseconds
        memory_limit: Maximum memory in bytes
        piston_endpoint: Optional custom Piston API endpoint
        pass_threshold: Similarity threshold for considering a test passed
        **kwargs: Additional keyword arguments (unused here)

    Returns:
        EvaluateResult with score and metrics
    """

    def _rejected(reason: str, metric_reason: str) -> EvaluateResult:
        # Zero-score result carrying a single invalid "error" metric.
        return EvaluateResult(
            score=0.0,
            reason=reason,
            metrics={
                "error": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason=metric_reason,
                )
            },
        )

    # --- validate the conversation -------------------------------------
    last_message = messages[-1] if messages else None
    if (
        not isinstance(last_message, Message)
        or last_message.role != "assistant"
        or last_message.content is None
    ):
        return _rejected(
            "Invalid or missing assistant response in messages.",
            "Last message not a valid assistant response.",
        )
    response_content = last_message.content

    # --- interpret ground_truth ----------------------------------------
    expected_text: Optional[str] = None
    test_cases: Optional[List[Dict[str, Any]]] = None

    if isinstance(ground_truth, str):
        expected_text = ground_truth
    elif isinstance(ground_truth, list):
        if not all(isinstance(item, dict) for item in ground_truth):
            return _rejected(
                "Invalid ground_truth format: if list, must be list of test case dicts.",
                "Invalid ground_truth list format.",
            )
        test_cases = ground_truth
    elif ground_truth is not None:
        return _rejected(
            "Invalid ground_truth format: expected string, list of test case dicts, or None.",
            "Invalid ground_truth format.",
        )

    # --- extract the code ----------------------------------------------
    code_blocks = extract_code_blocks(response_content, language)
    if not code_blocks:
        no_code_reason = f"No {language} code blocks found in model's response."
        return _rejected(no_code_reason, no_code_reason)

    code = code_blocks[0]["code"]

    metrics: Dict[str, MetricResult] = {
        "extracted_code": MetricResult(
            score=0.0,
            is_score_valid=True,
            reason=f"Extracted code:\n```{language}\n{code}\n```",
        )
    }

    if expected_text and not test_cases:
        metrics["expected_output"] = MetricResult(
            score=0.0,
            is_score_valid=True,
            reason=f"Expected output:\n{expected_text}",
        )

    # Options shared by every execution request below.
    run_options: Dict[str, Any] = {
        "language": language,
        "version": version,
        "timeout": timeout,
        "memory_limit": memory_limit,
        "piston_endpoint": piston_endpoint,
    }

    # --- mode 1: test cases --------------------------------------------
    if test_cases:
        results = asyncio.get_event_loop().run_until_complete(
            run_cpp_test_cases(code=code, test_cases=test_cases, **run_options)
        )

        total = len(results)
        passed = len([outcome for outcome in results if outcome.score >= pass_threshold])
        overall_score = passed / total if total > 0 else 0.0

        per_test_report = [
            {
                "test_name": outcome.test_name,
                "status": outcome.status,
                "score": outcome.score,
                "feedback": outcome.feedback,
            }
            for outcome in results
        ]
        metrics["test_results"] = MetricResult(
            score=overall_score,
            is_score_valid=overall_score >= pass_threshold,
            reason=json.dumps(per_test_report, indent=2),
        )
        metrics["pass_rate"] = MetricResult(
            score=overall_score,
            is_score_valid=overall_score == 1.0,
            reason=f"{passed}/{total} tests passed ({overall_score:.2%})",
        )

        return EvaluateResult(
            score=overall_score,
            reason=f"{passed}/{total} tests passed ({overall_score:.2%}).",
            metrics=metrics,
        )

    # --- modes 2 and 3: single execution --------------------------------
    execution_result = asyncio.get_event_loop().run_until_complete(
        execute_cpp_code(code=code, **run_options)
    )

    if not execution_result["success"]:
        error = execution_result["error"]
        metrics["execution_result"] = MetricResult(
            score=0.0,
            is_score_valid=False,
            reason=f"Code execution failed with error:\n{error}",
        )
        return EvaluateResult(
            score=0.0,
            reason=f"Code execution failed: {error}",
            metrics=metrics,
        )

    output = execution_result["output"]
    metrics["execution_result"] = MetricResult(
        score=1.0,
        is_score_valid=True,
        reason=f"Code executed successfully with output:\n{output}",
    )

    if not expected_text:
        # Nothing to compare against: a clean run earns full score.
        return EvaluateResult(
            score=1.0,
            reason="Code executed successfully (no expected output for comparison).",
            metrics=metrics,
        )

    similarity = compare_outputs(output, expected_text)
    metrics["output_match"] = MetricResult(
        score=similarity,
        is_score_valid=similarity >= pass_threshold,
        reason=f"Output similarity: {similarity:.2f}\n\nExpected:\n{expected_text}\n\nActual:\n{output}",
    )

    return EvaluateResult(
        score=similarity,
        reason="Code executed successfully." + f" Output similarity: {similarity:.2f}.",
        metrics=metrics,
    )
799
+
800
+
801
@reward_function
def binary_cpp_code_reward(
    messages: List[Message],
    ground_truth: Union[Optional[str], Optional[List[Dict[str, Any]]]],
    language: str = "cpp",
    version: str = "11.4.0",
    timeout: int = 5000,
    memory_limit: int = 512000000,
    piston_endpoint: Optional[str] = None,
    pass_threshold: float = 0.99,
    **kwargs: Any,
) -> EvaluateResult:
    """
    Evaluate C/C++ code correctness and collapse the score to pass/fail.

    Runs ``_ioi_cpp_code_reward_impl`` on a freshly created event loop and
    maps its score to 1.0 when it is at or above ``pass_threshold`` and to 0.0
    otherwise. The metrics of the underlying evaluation are copied and
    extended with a "binary_result" entry. The loop is closed on exit.

    Args:
        messages: Generated conversation messages
        ground_truth: Expected output string or list of test case dictionaries.
        language: Programming language ("c" or "cpp")
        version: Version of the compiler to use
        timeout: Maximum execution time in milliseconds
        memory_limit: Maximum memory in bytes
        piston_endpoint: Optional custom Piston API endpoint
        pass_threshold: Similarity threshold for considering a test passed
        **kwargs: Additional keyword arguments

    Returns:
        EvaluateResult with binary score (0.0 or 1.0) and metrics
    """
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)

    try:
        graded = _ioi_cpp_code_reward_impl(
            messages=messages,
            ground_truth=ground_truth,
            language=language,
            version=version,
            timeout=timeout,
            memory_limit=memory_limit,
            piston_endpoint=piston_endpoint,
            pass_threshold=pass_threshold,
            **kwargs,
        )

        score = graded.score
        if score >= pass_threshold:
            binary_score = 1.0
            verdict = "Passed"
        else:
            binary_score = 0.0
            verdict = "Failed"

        final_reason = f"Binary score based on threshold {pass_threshold:.2f}. Original score: {score:.2f}."

        # Copy so the underlying result's metrics mapping is left untouched.
        metrics = dict(graded.metrics)
        metrics["binary_result"] = MetricResult(
            score=binary_score,
            is_score_valid=binary_score == 1.0,
            reason=f"{verdict} (threshold: {pass_threshold:.2f}, actual: {score:.2f})",
        )

        return EvaluateResult(score=binary_score, reason=final_reason, metrics=metrics)
    finally:
        event_loop.close()