eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,541 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
import warnings
|
|
5
|
+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|
6
|
+
|
|
7
|
+
# Import OpenAI at module level for mocking in tests
|
|
8
|
+
try:
|
|
9
|
+
import openai
|
|
10
|
+
from openai import OpenAI
|
|
11
|
+
except ImportError:
|
|
12
|
+
# Type to mock in tests
|
|
13
|
+
OpenAI = None # type: ignore
|
|
14
|
+
|
|
15
|
+
import copy
|
|
16
|
+
from collections import Counter
|
|
17
|
+
|
|
18
|
+
from ..models import EvaluateResult, Message, MetricResult
|
|
19
|
+
from ..typed_interface import reward_function
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def match_function_call(
    messages: List[Dict[str, Any]],
    function_name: str,
    parsed_arguments: Dict[str, Any],
    expected_call_schema: Dict[str, Any],
    argument_match_strictness: str = "exact",
    **kwargs,
) -> EvaluateResult:
    """
    Evaluate how well a function call matches an expected schema.

    The final score is the mean of two sub-scores: a name score (1.0 when
    ``function_name`` equals the schema's ``"name"``) and an argument score
    that depends on ``argument_match_strictness``.

    Args:
        messages: The conversation messages. Not read by this function.
        function_name: The parsed function name.
        parsed_arguments: The parsed arguments from the function call.
        expected_call_schema: The expected schema for the function call. Its
            ``"name"`` entry is the expected function name and its
            ``"arguments"`` entry maps argument names to schemas whose
            ``"type"`` may be "string", "number", "boolean", "array" or
            "object". Any other (or absent) type is not checked.
        argument_match_strictness: How strict to be with argument matching:
            - "exact": 1.0 only when there are no missing arguments, no extra
              arguments and no type mismatches; otherwise 0.0.
            - "partial": 0.0 when there is an extra argument, a type mismatch
              or no provided argument at all; otherwise the fraction of
              provided arguments that matched. Missing arguments are ignored.
            - "flexible" (alias "permissive"): extra arguments are tolerated,
              but a missing argument or a type mismatch scores 0.0.
        **kwargs: Ignored.

    Returns:
        EvaluateResult with the averaged score and the metrics
        "function_name_match" and "arguments_match".

    Raises:
        ValueError: If ``argument_match_strictness`` is not a supported value.
    """
    # JSON-schema type name -> predicate on the parsed value. ``bool`` is
    # excluded from "number" explicitly: Python's bool subclasses int, but a
    # JSON boolean is not a JSON number.
    type_checks = {
        "string": lambda value: isinstance(value, str),
        "number": lambda value: isinstance(value, (int, float)) and not isinstance(value, bool),
        "boolean": lambda value: isinstance(value, bool),
        "array": lambda value: isinstance(value, list),
        "object": lambda value: isinstance(value, dict),
    }

    metrics = {}

    # 1. Function name match
    expected_name = expected_call_schema.get("name", "")
    name_match = function_name == expected_name
    name_score = 1.0 if name_match else 0.0
    name_reason = f"Function name {'matches' if name_match else 'does not match'}: expected '{expected_name}', got '{function_name}'"
    metrics["function_name_match"] = MetricResult(score=name_score, reason=name_reason, is_score_valid=name_match)

    # 2. Arguments match
    expected_args = expected_call_schema.get("arguments", {})
    arg_details = []

    missing_args = []
    extra_args = []
    type_mismatches = []
    perfect_matches = []

    for arg_name, arg_schema in expected_args.items():
        # A schema entry that is not a dict carries no type constraint.
        expected_type = arg_schema.get("type", "any") if isinstance(arg_schema, dict) else "any"

        if arg_name not in parsed_arguments:
            missing_args.append(arg_name)
            arg_details.append(f"Missing argument: {arg_name}")
            continue

        arg_value = parsed_arguments[arg_name]
        type_check = type_checks.get(expected_type)
        if type_check is not None and not type_check(arg_value):
            type_mismatches.append(arg_name)
            arg_details.append(
                f"Type mismatch for {arg_name}: expected {expected_type}, got {type(arg_value).__name__}"
            )
        else:
            # Unknown/unspecified types are not checked and count as a match.
            perfect_matches.append(arg_name)
            arg_details.append(f"Argument {arg_name} matches expected type {expected_type}")

    for arg_name in parsed_arguments:
        if arg_name not in expected_args:
            extra_args.append(arg_name)
            arg_details.append(f"Unexpected argument: {arg_name}")

    if argument_match_strictness == "exact":
        arg_score = 0.0 if (missing_args or extra_args or type_mismatches) else 1.0
    elif argument_match_strictness == "partial":
        total_provided = len(parsed_arguments)
        if extra_args or type_mismatches or total_provided == 0:
            arg_score = 0.0
        else:
            arg_score = len(perfect_matches) / total_provided
    elif argument_match_strictness in ("permissive", "flexible"):
        arg_score = 0.0 if (missing_args or type_mismatches) else 1.0
    else:
        raise ValueError(f"Invalid argument_match_strictness: {argument_match_strictness}")

    arg_reason = "\n".join(arg_details)
    metrics["arguments_match"] = MetricResult(
        score=arg_score,
        reason=arg_reason,
        # With no expected arguments there is nothing to invalidate the score.
        is_score_valid=arg_score == 1.0 if len(expected_args) > 0 else True,
    )

    # 3. Calculate final score
    final_score = (name_score + arg_score) / 2.0
    final_reason = f"Overall score based on name match ({name_score:.2f}) and argument match ({arg_score:.2f})."

    return EvaluateResult(score=final_score, reason=final_reason, metrics=metrics)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def calculate_jaccard_similarity(set1: Set, set2: Set) -> float:
    """
    Compute the Jaccard index of two sets.

    The index is the number of shared elements divided by the number of
    distinct elements across both sets. Two empty sets are treated as
    identical and score 1.0.

    Args:
        set1: First set
        set2: Second set

    Returns:
        Jaccard similarity score between 0.0 and 1.0
    """
    # Both empty: the union would be empty, so short-circuit to "identical".
    if not (set1 or set2):
        return 1.0

    shared_count = len(set1.intersection(set2))
    combined_count = len(set1.union(set2))
    return shared_count / combined_count
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def extract_schema_properties(schema: Dict[str, Any]) -> Set[Tuple[str, str]]:
    """
    Extract properties from a JSON schema as a set of (name, type) tuples.

    Walks ``properties``, ``patternProperties`` and ``items`` of the schema.
    Nested paths are written as ``parent.child`` for properties,
    ``parent[pattern]`` for pattern properties and ``parent[]`` for array
    items. Recursion only descends into sub-schemas whose type is (or
    includes) "object".

    A sub-schema without a ``type``, or one that is not a dict (for example a
    boolean schema), is recorded with type "any". A list-valued ``type`` such
    as ``["string", "null"]`` is recorded as its sorted members joined by
    "|" ("null|string"), because a list cannot be stored in a set.

    Args:
        schema: JSON schema object

    Returns:
        Set of (property_name, property_type) tuples; empty if ``schema`` is
        not a dict.
    """
    properties: Set[Tuple[str, str]] = set()

    def record(path: str, sub_schema: Any) -> None:
        # Register one sub-schema and descend when it describes an object.
        declared = sub_schema.get("type", "any") if isinstance(sub_schema, dict) else "any"
        if isinstance(declared, (list, tuple)):
            # Type unions are unhashable as lists; flatten to a stable label.
            names = sorted(str(name) for name in declared)
            label = "|".join(names) or "any"
            describes_object = "object" in names
        else:
            label = declared
            describes_object = declared == "object"

        properties.add((path, label))
        if describes_object:
            walk(sub_schema, path)

    def walk(schema_obj: Any, prefix: str = "") -> None:
        # Visit the three keywords that can introduce child schemas.
        if not isinstance(schema_obj, dict):
            return

        props = schema_obj.get("properties")
        if isinstance(props, dict):
            for prop_name, prop_schema in props.items():
                record(f"{prefix}.{prop_name}" if prefix else prop_name, prop_schema)

        pattern_props = schema_obj.get("patternProperties")
        if isinstance(pattern_props, dict):
            for pattern, pattern_schema in pattern_props.items():
                record(f"{prefix}[{pattern}]" if prefix else f"[{pattern}]", pattern_schema)

        items = schema_obj.get("items")
        if items and isinstance(items, dict):
            record(f"{prefix}[]" if prefix else "[]", items)

    walk(schema)
    return properties
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def normalize_schema(schema: Union[Dict[str, Any], str]) -> Dict[str, Any]:
    """
    Coerce a schema given as a dict or a JSON string into a dict.

    Args:
        schema: JSON schema as dictionary or string

    Returns:
        The schema dictionary, or an empty dict when the string is not valid
        JSON or the (decoded) value is not a dictionary.
    """
    candidate = schema
    if isinstance(candidate, str):
        try:
            candidate = json.loads(candidate)
        except json.JSONDecodeError:
            return {}

    # Valid JSON can still decode to a list, number, etc.; only dicts qualify.
    return candidate if isinstance(candidate, dict) else {}
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
# New Exact Tool Match Reward Function and Helpers
|
|
230
|
+
# VVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVV
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def maybe_deserialize_tool_call_arguments(
    tool_calls: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """
    Return copies of the tool calls with JSON-string 'arguments' decoded.

    Input tool_calls are expected to be in OpenAI format:
    [{'id': ..., 'type': 'function', 'function': {'name': ..., 'arguments': 'JSON_STRING_ARGS'}}, ...]

    Entries that are not dicts, lack a 'function' dict, or whose 'function'
    has no 'arguments' key are dropped from the result. The input entries are
    deep-copied, so the caller's tool calls are not modified.
    """
    if not tool_calls:
        return []

    def decode(raw_args):
        # Non-string payloads are taken as already structured.
        if not isinstance(raw_args, str):
            return raw_args
        # An empty/blank string stands for "no arguments".
        if not raw_args.strip():
            return {}
        try:
            return json.loads(raw_args)
        except json.JSONDecodeError:
            # If arguments string is not valid JSON, keep it as a string.
            # This matches behavior of some models that might return non-JSON arguments.
            return raw_args

    results = []
    for call in tool_calls:
        if not isinstance(call, dict):
            continue
        if "function" not in call:
            continue

        fn = call.get("function", {})
        if not (isinstance(fn, dict) and "arguments" in fn):
            continue

        decoded_args = decode(fn["arguments"])
        clone = copy.deepcopy(call)
        clone["function"]["arguments"] = decoded_args
        results.append(clone)

    return results
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def parse_tool_calls(completion: str) -> list:
    """
    Extract the JSON payloads wrapped in ``<tool_call>...</tool_call>`` tags.

    Args:
        completion: Text that may contain zero or more tagged tool calls.

    Returns:
        The decoded JSON value of every tagged span that parses, in order of
        appearance. Spans that are not valid JSON are skipped. A non-string
        ``completion`` (for example ``None``) yields an empty list.
    """
    # There is nothing to scan unless the content is text.
    if not isinstance(completion, str):
        return []

    row_tool_calls = []
    # DOTALL lets a tool call span several lines; the lazy group stops at the
    # first closing tag so adjacent calls are not merged.
    for match in re.findall(r"<tool_call>(.*?)</tool_call>", completion, re.DOTALL):
        try:
            row_tool_calls.append(json.loads(match.strip()))
        except (ValueError, RecursionError):
            # Malformed (JSONDecodeError is a ValueError) or too deeply nested:
            # best-effort parsing skips the span instead of failing the batch.
            continue
    return row_tool_calls
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def compare_tool_calls(generated_tool_calls: list, gt_tool_calls: list) -> bool:
    """
    Report whether two tool-call lists are identical, position by position.

    Each call is serialized to JSON with sorted keys, so key order inside a
    call does not matter, while the order of the calls themselves does.

    Args:
        generated_tool_calls: Tool calls produced by the model.
        gt_tool_calls: Ground-truth tool calls.

    Returns:
        True when both lists have the same length and equal serialized calls.
    """

    def canonical(calls: list) -> list:
        # Key-sorted JSON text makes dict ordering irrelevant to equality.
        return [json.dumps(call, sort_keys=True) for call in calls]

    if len(generated_tool_calls) != len(gt_tool_calls):
        return False

    generated_canonical = canonical(generated_tool_calls)
    expected_canonical = canonical(gt_tool_calls)
    return generated_canonical == expected_canonical
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def eval_tool_call(generation: dict, ground_truth: dict) -> bool:
    """
    Check whether the tool calls of a generated message equal the ground truth.

    Generated calls are taken from ``generation["tool_calls"]`` when that is
    non-empty; otherwise they are parsed from ``<tool_call>`` tags in
    ``generation["content"]``. Both sides are reduced to their inner
    ``function`` payloads (with JSON-string arguments decoded) before being
    compared with ``compare_tool_calls``.

    Args:
        generation: The generated message as a dict.
        ground_truth: Expected message; only its ``"tool_calls"`` entry is
            read. ``None`` or a missing entry means no calls are expected.

    Returns:
        True when the generated and expected calls match exactly.
    """

    def to_simple_format(openai_calls):
        # Decode JSON-string arguments, then keep only the function payloads.
        decoded = maybe_deserialize_tool_call_arguments(openai_calls)
        return [call["function"] for call in decoded if "function" in call]

    has_expected = ground_truth is not None and "tool_calls" in ground_truth
    expected_calls = ground_truth["tool_calls"] if has_expected else []
    expected_simple = to_simple_format(expected_calls or [])

    actual_simple = []
    structured_calls = generation.get("tool_calls")

    if structured_calls:
        # Accept both model objects exposing model_dump() and plain dicts;
        # anything else is ignored.
        as_dicts = []
        for call in structured_calls:
            if hasattr(call, "model_dump"):
                as_dicts.append(call.model_dump())
            elif isinstance(call, dict):
                as_dicts.append(call)
        actual_simple = to_simple_format(as_dicts)
    elif generation.get("content") and "<tool_call>" in generation["content"]:
        wrapped = []
        for parsed in parse_tool_calls(generation["content"]):
            if not isinstance(parsed, dict):
                continue
            if "function" in parsed and "type" in parsed:
                # Already shaped like an OpenAI tool call.
                wrapped.append(parsed)
            elif "name" in parsed and "arguments" in parsed:
                # Bare {"name", "arguments"} payload: wrap it in OpenAI form,
                # serializing dict arguments so both shapes go through the
                # same decoding step afterwards.
                raw_args = parsed["arguments"]
                if isinstance(raw_args, dict):
                    raw_args = json.dumps(raw_args)
                wrapped.append(
                    {
                        "id": f"parsed_call_{len(wrapped)}",
                        "type": "function",
                        "function": {"name": parsed["name"], "arguments": raw_args},
                    }
                )

        if wrapped:
            actual_simple = to_simple_format(wrapped)

    return compare_tool_calls(actual_simple, expected_simple)
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
@reward_function
def exact_tool_match_reward(
    messages: Union[List[Message], List[Dict[str, Any]]],
    ground_truth: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> EvaluateResult:
    """
    Score the last message by exact match of its tool calls against ground truth.

    Args:
        messages: Conversation messages; ``messages[-1]`` is evaluated. It may
            be a ``Message`` or a dict.
        ground_truth: Expected message as a dict (or a JSON string encoding
            one) whose ``"tool_calls"`` entry lists the expected calls. When
            ``None``, the score is 1.0 if the generation contains no tool
            calls and 0.0 otherwise.
        **kwargs: Ignored.

    Returns:
        EvaluateResult with score 1.0 on an exact match and 0.0 otherwise.
        A score of 0.0 with an explanatory reason is also returned for empty
        ``messages``, an unsupported message type, or an unusable
        ``ground_truth``.
    """
    if not messages:
        return EvaluateResult(score=0.0, reason="No messages provided for evaluation.", metrics={})

    generation_message_obj = messages[-1]
    generation_dict: Dict[str, Any]

    if isinstance(generation_message_obj, Message):
        generation_dict = {
            "role": generation_message_obj.role,
            "content": generation_message_obj.content,
        }
        if generation_message_obj.tool_calls:
            generation_dict["tool_calls"] = [
                tc.model_dump() if hasattr(tc, "model_dump") else tc for tc in generation_message_obj.tool_calls
            ]
    elif isinstance(generation_message_obj, dict):
        generation_dict = generation_message_obj
    else:
        return EvaluateResult(
            score=0.0,
            reason=f"Unexpected type for generation message: {type(generation_message_obj)}",
            metrics={},
        )

    if ground_truth is None:
        # The "content" key can be present with a None (or other non-text)
        # value, in which case .get()'s default does not apply; normalise
        # before the substring test so it cannot raise TypeError.
        content = generation_dict.get("content")
        if not isinstance(content, str):
            content = ""

        has_generation_tool_calls = False
        if generation_dict.get("tool_calls"):
            has_generation_tool_calls = True
        elif "<tool_call>" in content and parse_tool_calls(content):
            has_generation_tool_calls = True

        score = 1.0 if not has_generation_tool_calls else 0.0
        reason = (
            "Ground truth not provided. Score based on absence (1.0) or presence (0.0) of tool calls in generation."
        )
        return EvaluateResult(score=score, reason=reason, metrics={})

    if isinstance(ground_truth, str):
        try:
            ground_truth = json.loads(ground_truth)
        except json.JSONDecodeError:
            return EvaluateResult(
                score=0.0,
                reason=f"Ground truth was a string but failed to parse as JSON: {ground_truth[:100]}...",
                metrics={},
            )

    if not isinstance(ground_truth, dict):
        return EvaluateResult(
            score=0.0,
            reason=f"Ground truth is not a dictionary (even after attempting parse): {type(ground_truth)}",
            metrics={},
        )

    score = float(eval_tool_call(generation_dict, ground_truth))
    reason = f"Exact tool match evaluation score: {score}"
    return EvaluateResult(score=score, reason=reason, metrics={})
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
416
|
+
# End of New Exact Tool Match Reward Function and Helpers
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
@reward_function
def schema_jaccard_reward(
    messages: Union[List[Message], List[Dict[str, Any]]],
    ground_truth: Optional[Dict[str, Any]] = None,
    function_call: Optional[Dict[str, Any]] = None,
    expected_schema: Optional[Union[Dict[str, Any], str]] = None,
    **kwargs,
) -> EvaluateResult:
    """
    DEPRECATED: use `exact_tool_match_reward` instead.

    Emits a DeprecationWarning and forwards ``messages``, ``ground_truth``
    and ``**kwargs`` to `exact_tool_match_reward`. No Jaccard similarity is
    computed here.

    Args:
        messages: List of conversation messages.
        ground_truth: Expected assistant response as a dictionary.
        function_call: Accepted for signature compatibility; not used.
        expected_schema: Accepted for signature compatibility; not used.
        **kwargs: Additional keyword arguments, forwarded unchanged.

    Returns:
        EvaluateResult from exact_tool_match_reward.
    """
    deprecation_notice = (
        "`schema_jaccard_reward` is deprecated and will be removed in a future version. "
        "Please use `exact_tool_match_reward`."
    )
    warnings.warn(deprecation_notice, category=DeprecationWarning, stacklevel=2)

    result = exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs)
    return result
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
@reward_function
def llm_judge_reward(
    messages: Union[List[Message], List[Dict[str, Any]]],
    ground_truth: Optional[Dict[str, Any]] = None,
    function_call: Optional[Dict[str, Any]] = None,
    expected_schema: Optional[Union[Dict[str, Any], str]] = None,
    expected_behavior: Optional[str] = None,
    openai_api_key: Optional[str] = None,
    model: str = "gpt-4o-mini",
    temperature: float = 0.0,
    **kwargs,
) -> EvaluateResult:
    """
    DEPRECATED: use `exact_tool_match_reward` instead.

    Emits a DeprecationWarning and forwards ``messages``, ``ground_truth``
    and ``**kwargs`` to `exact_tool_match_reward`. No LLM is called here.

    Args:
        messages: List of conversation messages.
        ground_truth: Expected assistant response as a dictionary.
        function_call: Accepted for signature compatibility; not used.
        expected_schema: Accepted for signature compatibility; not used.
        expected_behavior: Accepted for signature compatibility; not used.
        openai_api_key: Accepted for signature compatibility; not used.
        model: Accepted for signature compatibility; not used.
        temperature: Accepted for signature compatibility; not used.
        **kwargs: Additional keyword arguments, forwarded unchanged.

    Returns:
        EvaluateResult from exact_tool_match_reward.
    """
    deprecation_notice = (
        "`llm_judge_reward` is deprecated and will be removed in a future version. "
        "Please use `exact_tool_match_reward`."
    )
    warnings.warn(deprecation_notice, category=DeprecationWarning, stacklevel=2)

    result = exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs)
    return result
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
@reward_function
def composite_function_call_reward(
    messages: Union[List[Message], List[Dict[str, Any]]],
    ground_truth: Optional[Dict[str, Any]] = None,
    function_call: Optional[Dict[str, Any]] = None,
    expected_schema: Optional[Union[Dict[str, Any], str]] = None,
    expected_behavior: Optional[str] = None,
    openai_api_key: Optional[str] = None,
    llm_model: str = "gpt-4o-mini",
    weights: Optional[Dict[str, float]] = None,
    **kwargs,
) -> EvaluateResult:
    """
    DEPRECATED: use `exact_tool_match_reward` instead.

    Emits a DeprecationWarning and forwards ``messages``, ``ground_truth``
    and ``**kwargs`` to `exact_tool_match_reward`, which performs an exact
    match evaluation of the tool calls in ``messages[-1]``.

    Args:
        messages: List of conversation messages, where `messages[-1]` is the model's response.
        ground_truth: Expected assistant response as a dictionary, typically containing 'tool_calls'.
            This is passed directly to exact_tool_match_reward.
        function_call: Accepted for signature compatibility; not used.
        expected_schema: Accepted for signature compatibility; not used.
        expected_behavior: Accepted for signature compatibility; not used.
        openai_api_key: Accepted for signature compatibility; not used.
        llm_model: Accepted for signature compatibility; not used.
        weights: Accepted for signature compatibility; not used.
        **kwargs: Additional keyword arguments passed to exact_tool_match_reward.

    Returns:
        EvaluateResult with score and metrics from exact_tool_match_reward.
    """
    deprecation_notice = (
        "`composite_function_call_reward` is deprecated and will be removed in a future version. "
        "Please use `exact_tool_match_reward`."
    )
    warnings.warn(deprecation_notice, category=DeprecationWarning, stacklevel=2)

    result = exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs)
    return result
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
# JSON schema reward functions have been moved to json_schema.py module
|