eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,479 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import re
|
|
3
|
+
from typing import Any, Dict, List, Optional
|
|
4
|
+
|
|
5
|
+
from eval_protocol.models import EvaluateResult, Message, MetricResult
|
|
6
|
+
from eval_protocol.reward_function import reward_function
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@reward_function
def lean_prover_reward(
    messages: List[Message],
    ground_truth: Optional[str],  # This is the expected_answer (proof string)
    **kwargs: Any,
) -> EvaluateResult:
    """
    Evaluates a Lean proof by analyzing the response for valid syntax, proof completion,
    and correctness based on the DeepSeek-Prover-V2 benchmark approach.

    The analysis is purely lexical: regular expressions are matched against the
    response text and no Lean toolchain is invoked. A proof counts as "complete"
    when it declares a theorem/lemma/example and contains neither ``sorry`` nor
    ``admitted``.

    Scoring:
        - No theorem/lemma/example declaration: 0.0.
        - Declaration with ``sorry``/``admitted``: ``0.1 + 0.03`` per tactic
          occurrence, capped at 0.4, if ``check_partial_progress``; otherwise 0.1.
        - Declaration without ``sorry``/``admitted``: 0.5 plus 0.08 per tactic
          occurrence, capped at 0.9 (reached at 5 occurrences).
        - If the whitespace-stripped ``ground_truth`` occurs case-insensitively in
          the response, the score is overridden to 1.0.

    Args:
        messages: List of conversation messages. The last message is the assistant's response.
        ground_truth: The expected proof string. Corresponds to 'expected_answer' in original
            kwargs. Surrounding whitespace is ignored; an empty or whitespace-only value
            disables the reference-match check.
        **kwargs: Must include 'statement' (str); it is only checked for presence. Optional:
            'lean_version' (str, default "4"; accepted but not used for scoring),
            'check_partial_progress' (bool, default True),
            'verbose' (bool, default False; when True, per-signal metrics are returned).

    Returns:
        EvaluateResult with score and metrics. On invalid input the score is 0.0 and an
        "error" metric is set; otherwise ``metrics`` is empty unless ``verbose`` is True.
    """
    statement: Optional[str] = kwargs.get("statement")
    check_partial_progress: bool = kwargs.get("check_partial_progress", True)
    verbose: bool = kwargs.get("verbose", False)
    # Stripping lets a reference stored with trailing newlines still match, and stops a
    # whitespace-only reference from "matching" every response that contains a space.
    expected_answer: Optional[str] = (ground_truth or "").strip() or None

    if not statement:
        return EvaluateResult(
            score=0.0,
            reason="Statement not provided in kwargs.",
            metrics={"error": MetricResult(score=0.0, is_score_valid=False, reason="Statement not provided.")},
        )

    if (
        not messages
        or not isinstance(messages[-1], Message)
        or messages[-1].role != "assistant"
        or messages[-1].content is None
    ):
        return EvaluateResult(
            score=0.0,
            reason="Invalid or missing assistant response in messages.",
            metrics={
                "error": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Last message not a valid assistant response.",
                )
            },
        )

    response = messages[-1].content
    if not response:
        return EvaluateResult(
            score=0.0,
            reason="Assistant response content is empty.",
            metrics={
                "error": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Empty assistant response content.",
                )
            },
        )

    # A leading ``\b`` stops keywords embedded in longer words (e.g. "throw x" or
    # "behave h :=") from being counted as Lean syntax.
    declaration_patterns = (
        r"\btheorem\s+\w+(\s*\{[^}]*\})?(\s*\([^)]*\))?\s*:=?",
        r"\blemma\s+\w+(\s*\{[^}]*\})?(\s*\([^)]*\))?\s*:=?",
        r"\bexample\s*(\{[^}]*\})?(\s*\([^)]*\))?\s*:=?",
    )
    tactic_patterns = (
        # The stated type may contain "=" (``have h : a = b := ...``) but not ":=".
        r"\bhave\s+\w+(\s*:\s*(?:[^:=]|(?<!:)=)+)?\s*:=",
        r"\bapply\s+[\w\.]+",
        r"\bintro\s+\w+",
        r"\brw\s+[\[\]\w\s\.\,]+",
        r"\bsimp(\s+[\[\]\w\s\.\,]+)?",
        r"\bexact\s+[\w\.]+",
        r"\bcalc\s+",
    )

    has_theorem_def = any(re.search(pattern, response) for pattern in declaration_patterns)
    # Any occurrence (including inside comments) is treated as an unfinished proof.
    is_incomplete = "sorry" in response or "admitted" in response
    tactics_count = sum(len(re.findall(pattern, response)) for pattern in tactic_patterns)

    if not has_theorem_def:
        # Tactics outside any theorem/lemma/example declaration are not a proof attempt.
        score = 0.0
        reason = "No valid Lean proof attempt"
    elif is_incomplete:
        if check_partial_progress:
            score = min(0.4, 0.1 + (tactics_count / 10) * 0.3)
            reason = f"Incomplete proof with {tactics_count} tactics"
        else:
            score = 0.1
            reason = "Incomplete proof (has sorry/admitted)"
    elif tactics_count >= 5:
        score = 0.5 + 0.4
        reason = f"Complete proof with good complexity ({tactics_count} tactics)"
    else:
        score = 0.5 + (tactics_count / 5) * 0.4
        reason = f"Complete proof with {tactics_count} tactics"

    expected_match = expected_answer is not None and expected_answer.lower() in response.lower()
    if expected_match:
        # Containing the reference proof overrides the heuristic score.
        score = 1.0
        reason = "Perfect match with expected proof"

    metrics: Dict[str, MetricResult] = {}
    if verbose:
        metrics = {
            "syntax": MetricResult(
                score=float(has_theorem_def),
                is_score_valid=has_theorem_def,
                reason=("Has valid theorem definition" if has_theorem_def else "Missing theorem definition"),
            ),
            "completeness": MetricResult(
                score=0.0 if is_incomplete else 1.0,
                is_score_valid=not is_incomplete,
                reason=("Incomplete proof (has sorry/admitted)" if is_incomplete else "Complete proof"),
            ),
            "tactics": MetricResult(
                score=min(1.0, tactics_count / 10),
                is_score_valid=tactics_count > 0,
                reason=f"Used {tactics_count} tactics",
            ),
        }
        if expected_answer is not None:
            metrics["expected_match"] = MetricResult(
                score=1.0 if expected_match else 0.0,
                is_score_valid=expected_match,
                reason=("Matches expected proof" if expected_match else "Doesn't match expected proof"),
            )

    return EvaluateResult(score=score, reason=reason, metrics=metrics)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@reward_function
def deepseek_prover_v2_reward(
    messages: List[Message],
    ground_truth: Optional[str],  # This is the expected_proof
    **kwargs: Any,
) -> EvaluateResult:
    """
    Evaluates a Lean proof based on the DeepSeek-Prover-V2 methodology, rewarding
    subgoal decomposition on top of the base score from ``lean_prover_reward``.

    The evaluation is lexical (regular expressions and indentation); no Lean
    toolchain is invoked. When ``check_subgoals`` is set and the base score is at
    least 0.5, the base score is raised by:
        - 0.03 per subgoal-pattern occurrence (``have``/``suffices``/``let``
          bindings, decomposition comments, recursion/induction keywords), capped
          at 0.3;
        - up to 0.2 for indentation depth (deepest leading-space run / 40, capped at 1);
    and the total is capped at 1.0. Lower base scores are returned unchanged.

    Args:
        messages: List of conversation messages. The last message is the assistant's response.
        ground_truth: The expected proof string. Corresponds to 'expected_proof' in original
            kwargs. Forwarded to ``lean_prover_reward`` as its reference proof.
        **kwargs: Must include 'statement' (str). Optional:
            'check_subgoals' (bool, default True), 'verbose' (bool, default False).

    Returns:
        EvaluateResult with score and metrics. When a statement is provided, metrics hold
        whatever ``lean_prover_reward`` returned plus "subgoal_decomposition" and
        "hierarchical_structure".
    """
    statement: Optional[str] = kwargs.get("statement")
    expected_proof: Optional[str] = ground_truth
    check_subgoals: bool = kwargs.get("check_subgoals", True)
    verbose: bool = kwargs.get("verbose", False)

    if not statement:
        return EvaluateResult(
            score=0.0,
            reason="Statement not provided in kwargs for deepseek_prover_v2_reward.",
            metrics={"error": MetricResult(score=0.0, is_score_valid=False, reason="Statement not provided.")},
        )

    base_evaluate_result: EvaluateResult = lean_prover_reward(
        messages=messages,
        ground_truth=expected_proof,
        statement=statement,
        check_partial_progress=True,
        verbose=verbose,
    )

    base_score = base_evaluate_result.score
    top_level_reason = base_evaluate_result.reason or "Formal proof evaluation"
    metrics = (base_evaluate_result.metrics or {}).copy()

    subgoal_patterns = (
        # ``have``/``let`` types may contain "=" (``have h : a = b := ...``) but not ":=".
        r"\bhave\s+(\w+)(\s*:\s*(?:[^:=]|(?<!:)=)+)?\s*:=",
        r"\bsuffices\s+(\w+)(\s*:\s*[^,]+)?\s*,",
        r"\blet\s+(\w+)(\s*:\s*(?:[^:=]|(?<!:)=)+)?\s*:=",
        # Lean comments start with "--" or "/-"; the C-style "//" and "/*" markers are
        # still accepted for backward compatibility. Keywords match in any case.
        r"(--|/-|//|/\*)\s*(?i:decomposing|breaking down|subgoal|step \d+)",
        r"(recursion|induction|structural|recursive)",
    )

    response_content = ""
    if (
        messages
        and isinstance(messages[-1], Message)
        and messages[-1].role == "assistant"
        and messages[-1].content is not None
    ):
        response_content = messages[-1].content

    final_score = base_score
    subgoal_count = 0
    hierarchy_depth: float = 0.0
    subgoal_score: float = 0.0
    hierarchy_score: float = 0.0

    if check_subgoals and response_content:
        subgoal_count = sum(len(re.findall(pattern, response_content)) for pattern in subgoal_patterns)

        # The deepest run of leading spaces is used as a proxy for proof nesting.
        max_indent = max(len(line) - len(line.lstrip(" ")) for line in response_content.split("\n"))
        hierarchy_depth = min(1.0, max_indent / 40)
        subgoal_score = min(0.3, (subgoal_count / 10) * 0.3)
        hierarchy_score = hierarchy_depth * 0.2

        # Structure bonuses only apply to proofs the base evaluation deems complete.
        if base_score >= 0.5:
            final_score = min(1.0, base_score + subgoal_score + hierarchy_score)
            # Only claim decomposition when subgoal patterns were actually found.
            if subgoal_count > 0:
                top_level_reason = f"{top_level_reason} with good subgoal decomposition"

    subgoal_decomposition_score_normalized = subgoal_score / 0.3 if subgoal_score > 0 else 0.0
    metrics["subgoal_decomposition"] = MetricResult(
        score=min(1.0, subgoal_decomposition_score_normalized),
        is_score_valid=subgoal_decomposition_score_normalized > 0.5,
        reason=f"Found {subgoal_count} subgoal patterns",
    )
    metrics["hierarchical_structure"] = MetricResult(
        score=hierarchy_depth,
        is_score_valid=hierarchy_depth > 0.5,
        reason=f"Hierarchical depth: {hierarchy_depth:.2f}",
    )

    return EvaluateResult(
        score=final_score,
        reason=top_level_reason,
        metrics=metrics,
    )
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
# Per-process cache so each benchmark dataset is loaded once instead of on every reward call.
_PROVER_DATASET_CACHE: Dict[str, Any] = {}


def _load_prover_dataset(dataset_name: str) -> Any:
    """Return the Hugging Face dataset ``dataset_name``, loading it at most once per process.

    Raises:
        ImportError: If the optional ``datasets`` package is not installed.
    """
    if dataset_name not in _PROVER_DATASET_CACHE:
        try:
            from datasets import load_dataset
        except ImportError as err:
            raise ImportError(
                "The 'datasets' package is required to use this reward function. "
                "Please install it with 'pip install datasets'."
            ) from err
        _PROVER_DATASET_CACHE[dataset_name] = load_dataset(dataset_name)
    return _PROVER_DATASET_CACHE[dataset_name]


@reward_function
def deepseek_huggingface_prover_benchmark(
    messages: List[Message],
    ground_truth: Dict[str, Any],
    **kwargs: Any,
) -> EvaluateResult:
    """
    Evaluates a Lean proof against the DeepSeek ProverBench dataset from Hugging Face.
    This reward function is specifically designed to work with the
    deepseek-ai/DeepSeek-ProverBench dataset.

    Flow:
        1. If no ``dataset_item`` is supplied, the dataset is loaded (once per process
           per ``dataset_name``). The first item whose ``statement`` contains the
           stripped statement is used. Failing that, the item with the highest
           ``difflib`` similarity above 0.7 is used. No match yields score 0.0.
        2. If ``check_for_answer`` is set and an answer is available (from
           ``ground_truth`` or the dataset item) but its string form does not occur in
           the response, 0.2 is returned without evaluating the proof.
        3. Otherwise the proof is scored by ``deepseek_prover_v2_reward``. The reference
           is the expected proof (from ``ground_truth`` or the item), falling back to the
           item's ``reference_solution``. That score is returned.

    Args:
        messages: List of conversation messages. The last message is the assistant's response.
        ground_truth: A dictionary containing ground truth information. Expected keys:
            'statement' (str): The theorem statement.
            Optionally 'dataset_item' (dict): Pre-loaded dataset item; when given, the
                dataset is not loaded and ``datasets`` is not required.
            Optionally 'expected_proof' (str): The reference proof.
            Optionally 'answer' (str): A short answer if applicable.
        **kwargs: Optional: 'dataset_name' (str), 'check_for_answer' (bool), 'verbose' (bool).

    Returns:
        EvaluateResult with score and metrics

    Raises:
        ImportError: If a dataset lookup is needed and ``datasets`` is not installed.
    """
    # A missing ground-truth dict is reported as a missing statement instead of raising.
    ground_truth = ground_truth or {}
    statement: Optional[str] = ground_truth.get("statement")
    # A whitespace-only statement would strip to "" (a substring of every statement)
    # and silently "match" the first dataset item, so treat it as missing.
    if isinstance(statement, str) and not statement.strip():
        statement = None
    dataset_item: Optional[Dict[str, Any]] = ground_truth.get("dataset_item")
    expected_proof_from_gt: Optional[str] = ground_truth.get("expected_proof")
    answer_from_gt: Optional[str] = ground_truth.get("answer")

    dataset_name: str = kwargs.get("dataset_name", "deepseek-ai/DeepSeek-ProverBench")
    check_for_answer: bool = kwargs.get("check_for_answer", True)
    verbose: bool = kwargs.get("verbose", False)

    if not statement:
        return EvaluateResult(
            score=0.0,
            reason="Statement not found in ground_truth dict for HuggingFace benchmark.",
            metrics={
                "error": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Statement not provided in ground_truth.",
                )
            },
        )

    if (
        not messages
        or not isinstance(messages[-1], Message)
        or messages[-1].role != "assistant"
        or messages[-1].content is None
    ):
        return EvaluateResult(
            score=0.0,
            reason="Invalid or missing assistant response in messages.",
            metrics={
                "error": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Last message not a valid assistant response.",
                )
            },
        )

    response = messages[-1].content
    if not response:
        return EvaluateResult(
            score=0.0,
            reason="Assistant response content is empty for HuggingFace benchmark.",
            metrics={
                "error": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Empty assistant response content.",
                )
            },
        )

    metrics: Dict[str, MetricResult] = {}

    if dataset_item is None:
        dataset = _load_prover_dataset(dataset_name)
        target = statement.strip()
        matched_item: Optional[Dict[str, Any]] = None

        # Pass 1: an item whose statement contains the given statement verbatim.
        for split in dataset.keys():
            for item in dataset[split]:
                if target in (item.get("statement") or ""):
                    matched_item = item
                    break
            if matched_item:
                break

        if matched_item:
            metrics["dataset_match"] = MetricResult(
                score=1.0, is_score_valid=True, reason="Found exact match in dataset"
            )
        else:
            # Pass 2: the most similar statement, provided the similarity exceeds 0.7.
            from difflib import SequenceMatcher

            best_ratio: float = 0.0
            for split in dataset.keys():
                for item in dataset[split]:
                    ratio = SequenceMatcher(None, target, item.get("statement") or "").ratio()
                    if ratio > best_ratio and ratio > 0.7:
                        best_ratio = ratio
                        matched_item = item
            if not matched_item:
                return EvaluateResult(
                    score=0.0,
                    reason="No matching problem found in the dataset",
                    metrics={
                        "dataset_match": MetricResult(
                            score=0.0,
                            is_score_valid=False,
                            reason="No matching problem found in the dataset",
                        )
                    },
                )
            metrics["dataset_match"] = MetricResult(
                score=best_ratio,
                is_score_valid=best_ratio > 0.7,
                reason=f"Found similar problem with {best_ratio:.2f} similarity",
            )
        dataset_item = matched_item

    # Reference proof: explicit ground truth first, then the dataset item's fields.
    expected_proof = expected_proof_from_gt
    reference_solution = None
    if dataset_item:
        if not expected_proof:
            expected_proof = dataset_item.get("expected_proof", None)
        reference_solution = dataset_item.get("reference_solution", None)
    proof_reference = expected_proof or reference_solution

    current_top_level_reason = "Evaluation against DeepSeek ProverBench dataset."
    answer_to_check = answer_from_gt
    if not answer_to_check and dataset_item:
        answer_to_check = dataset_item.get("answer")

    if check_for_answer and answer_to_check:
        expected_answer_str = str(answer_to_check)
        if expected_answer_str not in response:
            # A missing final answer short-circuits the proof evaluation.
            metrics["answer_match"] = MetricResult(
                score=0.0,
                is_score_valid=False,
                reason=f"Expected answer '{expected_answer_str}' not found in response",
            )
            return EvaluateResult(
                score=0.2,
                reason=f"Expected answer '{expected_answer_str}' not found.",
                metrics=metrics,
            )
        metrics["answer_match"] = MetricResult(
            score=1.0,
            is_score_valid=True,
            reason="Expected answer found in response",
        )
        current_top_level_reason += " Expected answer found."

    eval_result_from_deepseek: EvaluateResult = deepseek_prover_v2_reward(
        messages=messages,
        ground_truth=proof_reference,
        statement=statement,
        check_subgoals=True,
        verbose=verbose,
    )

    result_score = eval_result_from_deepseek.score
    result_reason = eval_result_from_deepseek.reason
    result_metrics = eval_result_from_deepseek.metrics or {}
    # Sub-evaluation metrics take precedence on key collisions.
    combined_metrics = {**metrics, **result_metrics}

    if result_reason and result_reason not in current_top_level_reason:
        current_top_level_reason += f" Sub-evaluation: {result_reason}"

    if verbose:
        combined_metrics["dataset_info"] = MetricResult(
            score=1.0,
            is_score_valid=True,
            reason=json.dumps(
                {
                    "id": dataset_item.get("id", ""),
                    "has_expected_proof": expected_proof is not None,
                    "has_reference_solution": reference_solution is not None,
                    "has_answer": "answer" in dataset_item if dataset_item else False,
                }
            ),
        )

    return EvaluateResult(score=result_score, reason=current_top_level_reason, metrics=combined_metrics)
|