eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model clients for generating responses from various LLM APIs.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import abc
|
|
6
|
+
import asyncio
|
|
7
|
+
import json # For parsing content as JSON
|
|
8
|
+
import logging
|
|
9
|
+
import uuid # For generating tool call IDs if not provided
|
|
10
|
+
from typing import Any, Dict, List, Optional
|
|
11
|
+
|
|
12
|
+
import aiohttp
|
|
13
|
+
from omegaconf import DictConfig
|
|
14
|
+
from pydantic import BaseModel, Field # Added for new models
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Pydantic models for structured tool calls and generation results
|
|
20
|
+
class ToolCallFunction(BaseModel):
    """Function name and arguments carried by a single tool call."""

    # Name of the function the model wants to invoke.
    name: str
    # Arguments for the call; should be a JSON string.
    arguments: str
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ToolCall(BaseModel):
    """A single tool invocation requested by the model."""

    # Identifier of this call.
    id: str
    # Kind of call; "function" is the OpenAI default.
    type: str = "function"
    # The function to invoke together with its arguments.
    function: ToolCallFunction
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class GenerationResult(BaseModel):
    """Result of one generation request: text content and/or tool calls."""

    # Text reply, when there is one.
    content: Optional[str] = None
    # Structured tool calls, when there are any.
    tool_calls: Optional[List[ToolCall]] = None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ModelClient(abc.ABC):
    """Abstract base class for model clients.

    The constructor copies the model name and sampling settings out of
    ``client_config`` onto the instance; subclasses implement ``generate``.
    """

    def __init__(self, client_config: DictConfig):
        cfg = client_config
        self.client_config = cfg
        # Each setting falls back to the default shown when the key is absent.
        self.model_name = cfg.get("model_name", "unknown_model")
        self.temperature = cfg.get("temperature", 0.0)
        self.max_tokens = cfg.get("max_tokens", 1024)
        self.top_p = cfg.get("top_p", 0.95)
        self.top_k = cfg.get("top_k", 20)
        self.min_p = cfg.get("min_p", 0.0)
        # reasoning_effort stays None unless the config specifies it.
        self.reasoning_effort = cfg.get("reasoning_effort", None)

    @abc.abstractmethod
    async def generate(
        self,
        messages: List[Dict[str, str]],
        session: aiohttp.ClientSession,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> GenerationResult:
        """Generates a response from the model given a list of messages."""
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class FireworksModelClient(ModelClient):
    """Client for Fireworks AI models.

    Posts OpenAI-style chat-completion requests to
    ``{api_base}/chat/completions`` and converts the first returned choice
    into a ``GenerationResult``. HTTP errors, exhausted retries and client or
    unexpected exceptions are logged and reported as an empty
    ``GenerationResult`` instead of being raised.
    """

    # Seconds to wait on HTTP 429 when Retry-After is missing or unusable.
    _DEFAULT_RETRY_AFTER_SECONDS = 5

    def __init__(self, client_config: DictConfig, api_key: str):
        super().__init__(client_config)
        self.api_key = api_key
        self.api_base = client_config.get("api_base", "https://api.fireworks.ai/inference/v1")
        # TODO: Initialize rate limiter, retry policy from client_config.api_params

    @staticmethod
    def _build_tool_call(call_id: Any, name: Any, arguments: Any) -> ToolCall:
        """Build a ``ToolCall``, generating an id and serializing arguments as needed."""
        return ToolCall(
            # A missing/empty id would fail validation, so generate one instead.
            id=str(call_id) if call_id else f"call_{uuid.uuid4().hex[:8]}",
            type="function",
            function=ToolCallFunction(
                name=name,
                # ``arguments`` must be a JSON string; models sometimes emit an object.
                arguments=arguments if isinstance(arguments, str) else json.dumps(arguments),
            ),
        )

    @classmethod
    def _parse_native_tool_calls(cls, tool_calls_data: List[Dict[str, Any]]) -> List[ToolCall]:
        """Convert the OpenAI-style ``message["tool_calls"]`` list into ``ToolCall`` objects."""
        parsed: List[ToolCall] = []
        for tc_data in tool_calls_data:
            if tc_data.get("type") == "function" and tc_data.get("function"):
                function = tc_data["function"]
                parsed.append(cls._build_tool_call(tc_data.get("id"), function["name"], function["arguments"]))
        return parsed

    @classmethod
    def _parse_content_tool_calls(cls, decoded: Any) -> List[ToolCall]:
        """Extract tool calls from JSON that the model placed in the content field.

        Two shapes are recognized, either alone or inside a list:
        ``{"type": "function", "function": {"name": ..., "arguments": ...}}`` and
        ``{"type": "function", "name": ..., "parameters": {...}}``.
        """
        candidates = decoded if isinstance(decoded, list) else [decoded]
        parsed: List[ToolCall] = []
        for item in candidates:
            if not (isinstance(item, dict) and item.get("type") == "function"):
                continue

            # OpenAI style: details nested under "function".
            func_details = item.get("function")
            if isinstance(func_details, dict) and "name" in func_details and "arguments" in func_details:
                parsed.append(
                    cls._build_tool_call(item.get("id"), func_details["name"], func_details["arguments"])
                )
                continue

            # Observed LLM style: flat "name" plus a "parameters" object.
            llm_name = item.get("name")
            llm_params = item.get("parameters")
            if llm_name and isinstance(llm_params, dict):
                parsed.append(cls._build_tool_call(item.get("id"), llm_name, llm_params))
        return parsed

    @classmethod
    def _result_from_message(cls, message: Dict[str, Any]) -> Optional[GenerationResult]:
        """Turn one response message into a result, or None if it has nothing usable."""
        # 1. Check for native OpenAI-style tool_calls field
        if message.get("tool_calls"):
            native_calls = cls._parse_native_tool_calls(message["tool_calls"])
            if native_calls:
                logger.debug(f"Parsed native tool_calls: {native_calls}")
                return GenerationResult(tool_calls=native_calls)

        # 2. Otherwise the content may itself be JSON describing a tool call.
        content_str = message.get("content")
        if not content_str:
            return None
        try:
            decoded = json.loads(content_str)
        except json.JSONDecodeError:
            logger.debug("Content is not JSON. Treating as text.")
            return GenerationResult(content=content_str.strip())

        content_calls = cls._parse_content_tool_calls(decoded)
        if content_calls:
            logger.debug(f"Parsed tool_calls from content field: {content_calls}")
            return GenerationResult(tool_calls=content_calls)

        logger.debug("Content was JSON, but not a recognized tool call structure. Treating as text.")
        return GenerationResult(content=content_str.strip())

    @classmethod
    def _retry_after_seconds(cls, headers: Any) -> int:
        """Read Retry-After as whole seconds, falling back to the default.

        The header may also be an HTTP-date or malformed; ``int()`` on such a
        value used to raise and abort the retry loop.
        """
        try:
            return max(0, int(headers.get("Retry-After", cls._DEFAULT_RETRY_AFTER_SECONDS)))
        except (TypeError, ValueError):
            return cls._DEFAULT_RETRY_AFTER_SECONDS

    @staticmethod
    def _payload_for_log(payload: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``payload`` with the last message's content truncated."""
        loggable = dict(payload)
        messages = list(loggable.get("messages") or [])
        if messages and isinstance(messages[-1], dict) and messages[-1].get("content"):
            last_message = dict(messages[-1])
            last_message["content"] = str(last_message["content"])[:50] + "..."
            messages[-1] = last_message
            loggable["messages"] = messages
        return loggable

    async def generate(
        self,
        messages: List[Dict[str, str]],
        session: aiohttp.ClientSession,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> GenerationResult:
        """Request a chat completion and return its content or tool calls.

        Args:
            messages: Chat messages sent as the ``messages`` field of the request.
            session: aiohttp session used to issue the POST request.
            tools: Optional tool definitions; when given, ``tool_choice`` is
                set to ``"auto"``.

        Returns:
            A ``GenerationResult`` holding tool calls or text content. It is
            empty when the response is unusable, on non-retryable HTTP errors,
            when retries are exhausted, or when an exception occurs.
        """
        url = f"{self.api_base}/chat/completions"

        payload: Dict[str, Any] = {
            "model": self.model_name,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        if self.top_p is not None:
            payload["top_p"] = self.top_p

        if tools:
            payload["tools"] = tools
            # "auto" is the common OpenAI value; adjust here if Fireworks needs
            # a different tool_choice.
            payload["tool_choice"] = "auto"

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

        # Only build the truncated copy when it will actually be logged.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Calling Fireworks API: {url}, Payload: {self._payload_for_log(payload)}")

        try:
            max_retries = (self.client_config.get("api_params") or {}).get("max_retries", 3)
            for attempt in range(max_retries + 1):
                async with session.post(url, json=payload, headers=headers) as response:
                    status = response.status
                    if status == 200:
                        data = await response.json()
                        choices = data.get("choices")
                        if choices:
                            result = self._result_from_message(choices[0].get("message") or {})
                            if result is not None:
                                return result
                        # If neither tool_calls nor parsable content that looks like a tool call
                        logger.warning(f"Fireworks API response malformed or no actionable content/tool_calls: {data}")
                        return GenerationResult()

                    if status == 429:  # Rate limit
                        delay = self._retry_after_seconds(response.headers)
                        logger.warning(f"Rate limited. Retrying after {delay}s (attempt {attempt+1}).")
                    elif status in (401, 403):  # Auth errors
                        error_text = await response.text()
                        logger.error(f"Fireworks API Auth Error ({status}): {error_text}")
                        return GenerationResult()  # Empty result on auth error
                    elif status >= 500:  # Server errors
                        delay = 2**attempt
                        logger.warning(f"Fireworks API Server Error ({status}). Retrying (attempt {attempt+1}).")
                    else:  # Other client errors
                        error_text = await response.text()
                        logger.error(f"Fireworks API request failed ({status}): {error_text}")
                        return GenerationResult()  # Empty result

                # Back off only after the response is released, and not after
                # the final attempt, where no retry follows.
                if attempt < max_retries:
                    await asyncio.sleep(delay)

            logger.error("Max retries reached for Fireworks API call.")
            return GenerationResult()
        except aiohttp.ClientError as e:
            logger.error(f"AIOHTTP client error: {e}")
            return GenerationResult()
        except Exception as e:
            logger.error(f"Unexpected error in FireworksModelClient: {e}", exc_info=True)
            return GenerationResult()
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
import os
|
|
3
|
+
from typing import Any, Dict, List, Optional
|
|
4
|
+
|
|
5
|
+
import uvicorn
|
|
6
|
+
from fastapi import Depends, FastAPI, HTTPException, Request
|
|
7
|
+
from pydantic import BaseModel, ValidationError
|
|
8
|
+
|
|
9
|
+
# Assuming these models are correctly defined in eval_protocol.models
|
|
10
|
+
from eval_protocol.models import EvaluateResult, Message
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# --- Request and Response Models ---
|
|
14
|
+
class EvaluationRequest(BaseModel):
    """Request body accepted by the ``/evaluate`` endpoint."""

    # Conversation to evaluate, as raw dicts (could also be List[Message] if
    # that model were enforced on input).
    messages: List[Dict[str, Any]]
    # Optional reference value passed to the reward function.
    ground_truth: Optional[str] = None
    # Extra keyword arguments passed through to the reward function.
    kwargs: Optional[Dict[str, Any]] = {}
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# --- Global variable to store the loaded reward function ---
|
|
21
|
+
# This is a simple approach for a single-function server.
|
|
22
|
+
# If multiple functions were to be served by one instance, a different mechanism would be needed.
|
|
23
|
+
_LOADED_REWARD_FUNCTION: Optional[Any] = None
_REWARD_FUNCTION_NAME: str = "N/A"

# --- API Key Authentication Dependency ---
# Read once at import time; verify_api_key skips the check when this is unset.
EXPECTED_API_KEY: Optional[str] = os.getenv("RK_ENDPOINT_API_KEY")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
async def verify_api_key(request: Request):
    """FastAPI dependency that enforces the optional endpoint API key.

    When ``EXPECTED_API_KEY`` is unset or empty, every request is allowed.
    Otherwise the key is read from the ``X-Api-Key`` header, falling back to
    ``Authorization: Bearer <key>`` (scheme matched case-insensitively).

    Returns:
        True when the request may proceed.

    Raises:
        HTTPException: 401 when a key is required but absent, 403 when the
            supplied key does not match.
    """
    import hmac  # local import: only needed for the constant-time comparison

    if not EXPECTED_API_KEY:
        return True  # No key configured, so nothing to verify.

    # Check for X-Api-Key header first
    client_api_key = request.headers.get("X-Api-Key")
    # If not found, check for Authorization: Bearer <key>
    if not client_api_key:
        auth_header = request.headers.get("Authorization")
        if auth_header:
            scheme, _, credentials = auth_header.partition(" ")
            if scheme.lower() == "bearer":
                client_api_key = credentials

    if not client_api_key:
        raise HTTPException(status_code=401, detail="API key required but not provided.")

    # Constant-time comparison so response timing does not leak how much of
    # the key matched; bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    if not hmac.compare_digest(client_api_key.encode("utf-8"), EXPECTED_API_KEY.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Invalid API key.")
    return True
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# --- FastAPI App ---
|
|
48
|
+
# Application object; the routes below register themselves on it by decorator.
app = FastAPI(
    version="0.1.0",  # Or use eval_protocol.__version__
    title="Reward Kit Generic Reward Function Server",
    description="Serves a dynamically loaded reward function.",
)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@app.post("/evaluate", response_model=EvaluateResult, dependencies=[Depends(verify_api_key)])
async def evaluate_endpoint(request: EvaluationRequest):
    """
    Endpoint to evaluate a given set of messages using the loaded reward function.
    Requires API key if RK_ENDPOINT_API_KEY environment variable is set.

    The reward function is called with ``messages``, ``ground_truth`` and any
    extra ``kwargs`` from the request. Both plain and ``async def`` reward
    functions are supported. A return value that is not an ``EvaluateResult``
    is replaced by an invalid zero-score result.

    Raises:
        HTTPException: 500 when no reward function is loaded or the function
            raises; 422 when it raises a pydantic ``ValidationError``.
    """
    import inspect  # local import: only needed to detect async reward functions

    if _LOADED_REWARD_FUNCTION is None:
        raise HTTPException(status_code=500, detail="Reward function not loaded.")

    try:
        # The user's reward function is expected to match the @reward_function signature
        func_args = {
            "messages": request.messages,
            "ground_truth": request.ground_truth,
            **(request.kwargs or {}),
        }

        result = _LOADED_REWARD_FUNCTION(**func_args)
        # An async reward function hands back a coroutine; await it so its
        # real result is used instead of being rejected as an invalid type.
        if inspect.isawaitable(result):
            result = await result

        if not isinstance(result, EvaluateResult):
            # This case should ideally not happen if functions are correctly decorated
            # and return EvaluateResult, but good to have a fallback.
            print(
                f"Warning: Reward function '{_REWARD_FUNCTION_NAME}' did not return an EvaluateResult instance. Type: {type(result)}"
            )
            return EvaluateResult(
                score=0.0,
                reason="Invalid return type from reward function, check server logs.",
                is_score_valid=False,
                metrics={},
            )

        return result
    except ValidationError as ve:  # Pydantic validation error from reward function's input/output
        print(f"Validation Error calling reward function '{_REWARD_FUNCTION_NAME}': {ve}")
        raise HTTPException(
            status_code=422,
            detail=f"Input/Output validation error for reward function: {ve.errors()}",
        )
    except Exception as e:
        print(f"Error during evaluation with reward function '{_REWARD_FUNCTION_NAME}': {e}")
        # Consider logging the full traceback here
        raise HTTPException(status_code=500, detail=f"Internal server error during evaluation: {str(e)}")
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@app.get("/health")
async def health_check():
    """
    Health check endpoint: reports whether a reward function is loaded.
    """
    if not _LOADED_REWARD_FUNCTION:
        return {"status": "error", "reason": "Reward function not loaded"}
    return {"status": "ok", "reward_function": _REWARD_FUNCTION_NAME}
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def load_reward_function(import_string: str):
    """
    Loads a reward function from an import string and stores it in module state.

    Args:
        import_string: Dotted path such as ``'my_module.my_function'``; the
            ``'my_module:my_function'`` form is accepted as well.

    Raises:
        ValueError: If the string does not name both a module and an attribute.
        TypeError: If the resolved attribute is not callable.
        Exception: Any error raised while importing the module or looking up
            the attribute is re-raised after the module state is reset.
    """
    global _LOADED_REWARD_FUNCTION, _REWARD_FUNCTION_NAME
    try:
        if ":" in import_string:
            module_path, _, function_name = import_string.partition(":")
        else:
            module_path, _, function_name = import_string.rpartition(".")
        if not module_path or not function_name:
            raise ValueError(
                f"Expected an import string like 'my_module.my_function', got {import_string!r}"
            )

        module = importlib.import_module(module_path)
        loaded = getattr(module, function_name)
        # Fail at load time rather than on every /evaluate request.
        if not callable(loaded):
            raise TypeError(f"'{import_string}' resolved to a non-callable {type(loaded).__name__}")

        _LOADED_REWARD_FUNCTION = loaded
        _REWARD_FUNCTION_NAME = import_string
        print(f"Successfully loaded reward function: {_REWARD_FUNCTION_NAME}")
    except Exception as e:
        print(f"Error loading reward function from '{import_string}': {e}")
        _LOADED_REWARD_FUNCTION = None
        _REWARD_FUNCTION_NAME = "Error loading"
        raise  # Re-raise to make it fatal if loading fails on startup
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
if __name__ == "__main__":
    # Script entry point: parse CLI arguments, load the reward function, then
    # serve the app with uvicorn. Exits with status 1 if loading fails.
    import argparse

    parser = argparse.ArgumentParser(description="Run the Generic Reward Function Server.")
    parser.add_argument(
        "import_string",
        type=str,
        help="Import string for the reward function (e.g., 'my_package.my_module.my_reward_function')",
    )
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Host to bind the server to.")
    parser.add_argument(
        "--port",
        type=int,
        default=8080,  # Standard port for Cloud Run, etc.
        help="Port to bind the server to.",
    )
    # Add --reload for uvicorn if needed for development
    # parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development.")

    args = parser.parse_args()

    # SystemExit is raised directly because the bare ``exit()`` helper is
    # injected by the ``site`` module and is not always available.
    try:
        load_reward_function(args.import_string)
    except Exception:
        print("Failed to load reward function. Exiting.")
        raise SystemExit(1)

    if not _LOADED_REWARD_FUNCTION:
        print(f"Reward function {_REWARD_FUNCTION_NAME} could not be loaded. Server will not start correctly.")
        raise SystemExit(1)

    print(f"Starting server for reward function: {args.import_string} on http://{args.host}:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port)  # reload=args.reload for dev
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Integration helpers for Reward Kit."""
|
|
2
|
+
|
|
3
|
+
from .braintrust import reward_fn_to_scorer, scorer_to_reward_fn
|
|
4
|
+
from .openeval import adapt
|
|
5
|
+
from .trl import create_trl_adapter
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"adapt",
|
|
9
|
+
"scorer_to_reward_fn",
|
|
10
|
+
"reward_fn_to_scorer",
|
|
11
|
+
"create_trl_adapter",
|
|
12
|
+
]
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""Adapters for integrating Reward Kit with Braintrust scoring functions."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Callable, List, Optional
|
|
4
|
+
|
|
5
|
+
from eval_protocol.models import EvaluateResult, Message
|
|
6
|
+
from eval_protocol.typed_interface import reward_function
|
|
7
|
+
|
|
8
|
+
# Type alias for Braintrust scoring functions
|
|
9
|
+
BraintrustScorer = Callable[[Any, Any, Any], float]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def scorer_to_reward_fn(
    scorer: BraintrustScorer,
    *,
    messages_to_input: Optional[Callable[[List[Message]], Any]] = None,
    ground_truth_to_expected: Optional[Callable[[List[Message]], Any]] = None,
) -> Callable[[List[Message], Optional[List[Message]]], EvaluateResult]:
    """Wrap a Braintrust scorer as a Reward Kit reward function.

    Args:
        scorer: Callable invoked as ``scorer(input, output, expected)``.
        messages_to_input: Optional converter from the messages to the scorer
            input; by default the first message's content is used.
        ground_truth_to_expected: Optional converter from the ground-truth
            messages to the expected value; by default the last ground-truth
            message's content is used.

    Returns:
        A reward function whose result score is the scorer's return value.
        With no messages it returns a zero-score "No messages" result.
    """

    @reward_function
    def reward_fn(messages: List[Message], ground_truth: Optional[List[Message]] = None, **kwargs) -> EvaluateResult:
        # An empty conversation has no first/last message to index.
        if not messages:
            return EvaluateResult(score=0.0, reason="No messages", metrics={})

        input_val = messages_to_input(messages) if messages_to_input else messages[0].content
        output_val = messages[-1].content
        expected_val = None
        if ground_truth:
            expected_val = (
                ground_truth_to_expected(ground_truth) if ground_truth_to_expected else ground_truth[-1].content
            )
        score = scorer(input_val, output_val, expected_val)
        return EvaluateResult(score=score)

    return reward_fn
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def reward_fn_to_scorer(
    reward_fn: Callable[[List[Message], Optional[List[Message]]], EvaluateResult],
) -> BraintrustScorer:
    """Create a Braintrust-compatible scorer from a Reward Kit reward function.

    The returned scorer stringifies its arguments, wraps them as a user
    message, an assistant message and (when ``expected`` is not None) a
    single assistant ground-truth message, and returns the result's score.
    """

    def scorer(input_val: Any, output: Any, expected: Any) -> float:
        conversation = [
            Message(role="user", content=str(input_val)),
            Message(role="assistant", content=str(output)),
        ]
        reference = None if expected is None else [Message(role="assistant", content=str(expected))]
        return reward_fn(messages=conversation, ground_truth=reference).score

    return scorer
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional
|
|
2
|
+
|
|
3
|
+
from eval_protocol.models import EvaluateResult, MetricResult
|
|
4
|
+
from eval_protocol.typed_interface import reward_function
|
|
5
|
+
|
|
6
|
+
__all__ = ["adapt_metric"]
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
from deepeval.metrics.base_metric import BaseConversationalMetric, BaseMetric
|
|
10
|
+
from deepeval.test_case import ConversationalTestCase, LLMTestCase
|
|
11
|
+
except Exception: # pragma: no cover - deepeval is optional
|
|
12
|
+
BaseMetric = None
|
|
13
|
+
BaseConversationalMetric = None
|
|
14
|
+
LLMTestCase = None
|
|
15
|
+
ConversationalTestCase = None
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _metric_name(metric: Any) -> str:
    """Pick a display name for ``metric``.

    Preference order: the ``__name__`` attribute unless it is one of the
    generic base-metric names, then the ``name`` attribute, then the class
    name.
    """
    generic_names = {
        "Base Metric",
        "Base Conversational Metric",
        "Base Multimodal Metric",
    }
    dunder_name = getattr(metric, "__name__", None)
    if dunder_name and dunder_name not in generic_names:
        return str(dunder_name)

    plain_name = getattr(metric, "name", None)
    return str(plain_name) if plain_name else metric.__class__.__name__
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def adapt_metric(metric: Any):
    """Adapt a deepeval metric object into a reward-kit reward function.

    The returned function builds a deepeval test case from the messages, calls
    ``metric.measure`` on it and reports ``metric.score`` (0.0 when falsy) and
    ``metric.reason`` as an ``EvaluateResult`` with one named metric entry.
    """

    @reward_function
    def wrapped(
        messages: List[Dict[str, Any]],
        ground_truth: Optional[str] = None,
        **kwargs: Any,
    ) -> EvaluateResult:
        if BaseMetric is None or LLMTestCase is None:
            raise ImportError("deepeval must be installed to use this integration")
        if not messages:
            return EvaluateResult(score=0.0, reason="No messages", metrics={})

        def build_case_kwargs(case_input: Any, case_output: Any) -> Dict[str, Any]:
            # Values that can fill each test-case field the metric may ask for.
            available = {
                "input": case_input,
                "actual_output": case_output,
                "expected_output": ground_truth,
                "context": kwargs.get("context"),
                "retrieval_context": kwargs.get("retrieval_context"),
                "tools_called": kwargs.get("tools_called"),
                "expected_tools": kwargs.get("expected_tools"),
            }
            requested = getattr(metric, "evaluation_params", None)
            if requested:
                # Only supply the fields the metric declares; unknown ones are skipped.
                selected: Dict[str, Any] = {}
                for param in requested:
                    field = param.value
                    if field in available:
                        selected[field] = available[field]
            else:
                selected = {
                    "input": case_input,
                    "actual_output": case_output,
                    "expected_output": ground_truth,
                }
            # input and actual_output are always present in the result.
            selected.setdefault("input", case_input)
            selected.setdefault("actual_output", case_output)
            return selected

        if isinstance(metric, BaseConversationalMetric):
            # One turn per message; its input is the previous message's content.
            turns = []
            for index, msg in enumerate(messages):
                previous_content = messages[index - 1].get("content", "") if index > 0 else ""
                turn_kwargs = build_case_kwargs(previous_content, msg.get("content", ""))
                turns.append(LLMTestCase(**turn_kwargs))
            test_case = ConversationalTestCase(turns=turns)
        else:
            # Single case: last message is the output, the one before it the input.
            final_output = messages[-1].get("content", "")
            final_input = messages[-2].get("content", "") if len(messages) >= 2 else ""
            test_case = LLMTestCase(**build_case_kwargs(final_input, final_output))

        metric.measure(test_case, **kwargs)
        score = float(metric.score or 0.0)
        reason = getattr(metric, "reason", None)
        metric_entry = MetricResult(score=score, reason=reason or "", is_score_valid=True)
        return EvaluateResult(score=score, reason=reason, metrics={_metric_name(metric): metric_entry})

    return wrapped
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from typing import Any, Callable, Dict, List, Union
|
|
2
|
+
|
|
3
|
+
from eval_protocol.models import EvaluateResult, MetricResult
|
|
4
|
+
from eval_protocol.typed_interface import reward_function
|
|
5
|
+
|
|
6
|
+
__all__ = ["adapt"]
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _convert_result(res: Dict[str, Any]) -> EvaluateResult:
    """Convert one OpenEvals result dict into an ``EvaluateResult``.

    Reads ``score``, ``comment`` and ``key`` from ``res``. A missing or
    ``None`` score becomes 0.0 and a missing or empty key becomes
    ``"openeval"``.
    """
    # ``or 0.0`` guards against an explicit None, which float() would reject.
    score = float(res.get("score") or 0.0)
    reason = res.get("comment")
    key = res.get("key") or "openeval"
    metrics = {key: MetricResult(score=score, reason=reason or "", is_score_valid=True)}
    return EvaluateResult(score=score, reason=reason, metrics=metrics)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def adapt(openeval_fn: Callable[..., Union[Dict[str, Any], List[Dict[str, Any]]]]):
    """Adapt an OpenEvals evaluator into a reward-kit reward function.

    The returned function calls ``openeval_fn`` with the last message's
    content as ``outputs`` and the ground truth as ``reference_outputs`` (the
    last entry's content when ground truth is a list), forwarding extra
    keyword arguments. When the evaluator returns a list, its first result is
    used; an empty list yields an invalid zero-score result.
    """

    @reward_function
    def wrapped(
        messages: List[Dict[str, str]],
        ground_truth: Union[str, List[Dict[str, str]], None] = None,
        **kwargs: Any,
    ) -> EvaluateResult:
        if not messages:
            return EvaluateResult(score=0.0, reason="No messages", metrics={})
        output = messages[-1].get("content", "")

        reference = None
        if isinstance(ground_truth, list):
            if ground_truth:
                reference = ground_truth[-1].get("content")
        else:
            reference = ground_truth

        res = openeval_fn(outputs=output, reference_outputs=reference, **kwargs)
        if isinstance(res, list):
            # Indexing an empty result list would raise IndexError.
            if not res:
                return EvaluateResult(
                    score=0.0,
                    reason="Evaluator returned no results",
                    is_score_valid=False,
                    metrics={},
                )
            res = res[0]
        return _convert_result(res)

    return wrapped
|