eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
eval_protocol/server.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
|
6
|
+
|
|
7
|
+
import uvicorn
|
|
8
|
+
from fastapi import FastAPI, HTTPException, Request
|
|
9
|
+
from pydantic import BaseModel, Field
|
|
10
|
+
|
|
11
|
+
from .models import EvaluateResult
|
|
12
|
+
|
|
13
|
+
# Module-wide logger for this server module.
logger = logging.getLogger(__name__)
# NOTE: this runs at import time and configures the root logger at INFO level.
logging.basicConfig(level=logging.INFO)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Message(BaseModel):
    """Model for a conversation message.

    Only ``role`` and ``content`` are declared. Any additional keys sent by the
    client are kept on the model because ``extra`` is set to ``"allow"``.
    """

    role: str
    content: str

    # Pydantic v2 style configuration; replaces the deprecated inner ``class Config``.
    model_config = {"extra": "allow"}  # Allow extra fields
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class RewardRequest(BaseModel):
    """Request model for reward endpoints.

    ``messages`` and ``ground_truth`` are the declared fields. Any other
    top-level keys are retained (``extra="allow"``) so the endpoints can forward
    them to the reward function as keyword arguments.
    """

    messages: List[Message] = Field(..., description="List of conversation messages")
    ground_truth: Optional[Union[str, List[Message]]] = Field(
        None, description="Ground truth data (string or list of messages) for context"
    )

    # Pydantic v2 style configuration; replaces the deprecated inner ``class Config``.
    model_config = {"extra": "allow"}  # Allow extra fields for arbitrary kwargs
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class RewardServer:
    """
    Server for hosting reward functions.

    This class creates a FastAPI server that can host reward functions.

    Args:
        func_path: Path to the reward function to host (e.g., "module.path:function_name")
        host: Host to bind the server to
        port: Port to bind the server to
    """

    def __init__(
        self,
        func_path: str,
        host: str = "0.0.0.0",
        port: int = 8000,
    ):
        self.func_path = func_path
        self.host = host
        self.port = port
        self.app = FastAPI(title="Reward Function Server")

        # Load the reward function now so a bad func_path fails at construction time.
        self.reward_func = self._load_function()

        # Register the endpoints
        self._setup_routes()

    def _load_function(self):
        """Import and return the reward function named by ``self.func_path``.

        Raises:
            ValueError: If ``func_path`` does not contain a ``":"`` separator.
            ImportError: If the module cannot be imported or does not define the function.
        """
        if ":" not in self.func_path:
            raise ValueError(f"Invalid func_path format: {self.func_path}, expected 'module.path:function_name'")

        module_path, func_name = self.func_path.split(":", 1)
        try:
            module = importlib.import_module(module_path)
            func = getattr(module, func_name)
        except (ImportError, AttributeError) as e:
            # Chain the original error so the real cause stays in the traceback.
            raise ImportError(f"Failed to load function from path {self.func_path}: {str(e)}") from e

        logger.info(f"Loaded reward function {func_name} from {module_path}")
        return func

    @staticmethod
    def _format_result(result: Any) -> Dict[str, Any]:
        """Convert a reward function's return value into a JSON-serialisable dict.

        Accepted inputs:
        - a dict containing ``"score"`` (returned unchanged);
        - an ``EvaluateResult``, which is what ``@reward_function`` returns after
          validating the output (dumped with ``model_dump``);
        - a legacy ``(score, metrics)`` tuple.

        Raises:
            TypeError: For any other return type.
        """
        if isinstance(result, dict) and "score" in result:
            return result
        if isinstance(result, EvaluateResult):
            return result.model_dump()
        if isinstance(result, tuple) and len(result) == 2:  # Legacy tuple
            logger.warning("Reward function returned legacy tuple format to server.")
            score, components = result
            return {"score": score, "metrics": components}
        raise TypeError(f"Invalid return type from reward function after decoration: {type(result)}")

    def _setup_routes(self):
        """Register the ``/``, ``/reward`` and ``/health`` routes on ``self.app``."""
        import inspect  # local import: only used to detect awaitable (async) results

        @self.app.get("/")
        async def root():
            """Get server info."""
            return {
                "status": "ok",
                "reward_function": self.func_path,
                "endpoints": ["/reward"],
            }

        @self.app.post("/reward")
        async def reward(request: RewardRequest):
            """
            Get reward score for messages.

            Args:
                request: RewardRequest object with messages and optional parameters

            Returns:
                A dict form of the reward result (e.g. ``score`` and ``metrics``).
            """
            try:
                # Hand plain dicts to the reward function rather than this module's
                # request models; @reward_function coerces dicts into Message objects
                # when the function is annotated for them.
                messages = [msg.model_dump() for msg in request.messages]
                ground_truth_data: Optional[Union[str, List[Dict[str, Any]]]]
                if isinstance(request.ground_truth, list):
                    ground_truth_data = [msg.model_dump() for msg in request.ground_truth]
                else:
                    ground_truth_data = request.ground_truth

                # Any extra top-level request fields are forwarded as keyword arguments.
                kwargs = request.model_dump(exclude={"messages", "ground_truth"})

                if ground_truth_data is None:
                    # Default: use the conversation context (all but the last message).
                    ground_truth_data = messages[:-1] if messages else []

                result = self.reward_func(
                    messages=messages,
                    ground_truth=ground_truth_data,
                    **kwargs,
                )
                if inspect.isawaitable(result):
                    # @reward_function supports ``async def`` functions; await their result.
                    result = await result

                return self._format_result(result)

            except Exception as e:
                logger.exception("Error processing reward request: %s", e)
                raise HTTPException(status_code=500, detail=str(e)) from e

        @self.app.get("/health")
        async def health():
            """Health check endpoint."""
            return {"status": "ok"}

    def run(self):
        """Run the server."""
        logger.info(f"Starting reward server on {self.host}:{self.port}")
        uvicorn.run(self.app, host=self.host, port=self.port)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def serve(func_path: str, host: str = "0.0.0.0", port: int = 8000):
    """
    Expose a reward function over HTTP.

    Builds a ``RewardServer`` for ``func_path`` and starts it.

    Args:
        func_path: Import path of the reward function, written as
            "module.path:function_name".
        host: Address the server binds to.
        port: Port the server binds to.
    """
    RewardServer(func_path=func_path, host=host, port=port).run()
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
# ngrok-based serve_tunnel is deprecated in favor of Serveo via subprocess_manager.
|
|
168
|
+
# def serve_tunnel(func_path: str, port: int = 8000):
|
|
169
|
+
# """
|
|
170
|
+
# Serve a reward function with an ngrok tunnel.
|
|
171
|
+
# DEPRECATED.
|
|
172
|
+
# """
|
|
173
|
+
# try:
|
|
174
|
+
# import pyngrok.ngrok as ngrok # type: ignore
|
|
175
|
+
# except ImportError:
|
|
176
|
+
# raise ImportError(
|
|
177
|
+
# "The 'pyngrok' package is required to use serve_tunnel. "
|
|
178
|
+
# "Please install it with 'pip install pyngrok'."
|
|
179
|
+
# )
|
|
180
|
+
#
|
|
181
|
+
# # Open the tunnel
|
|
182
|
+
# tunnel = ngrok.connect(port)
|
|
183
|
+
# public_url = tunnel.public_url
|
|
184
|
+
#
|
|
185
|
+
# # Print the tunnel URL
|
|
186
|
+
# logger.info(f"Reward function available at: {public_url}/reward")
|
|
187
|
+
#
|
|
188
|
+
# # Start the server
|
|
189
|
+
# serve(func_path=func_path, host="0.0.0.0", port=port)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def create_app(reward_func: Callable[..., EvaluateResult]) -> FastAPI:
    """
    Create a FastAPI app for the given reward function.

    This function creates a FastAPI app that can be used to serve a reward function.
    It's particularly useful for testing or when you want to manage the lifecycle
    of the app yourself.

    Args:
        reward_func: The reward function to serve. It is called with ``messages``,
            ``ground_truth`` and any extra request fields as keyword arguments. It may
            be a plain callable or an ``async def`` function; awaitable results are awaited.

    Returns:
        A FastAPI app instance exposing ``/``, ``/reward`` and ``/health``.
    """
    import inspect  # local import: only used to detect awaitable (async) results

    app = FastAPI(title="Reward Function Server")

    def _format_result(result: Any) -> Dict[str, Any]:
        """Convert the reward function's return value into a JSON-serialisable dict.

        ``@reward_function`` returns a validated ``EvaluateResult``, so that is the
        normal case. Dicts containing ``"score"`` and legacy ``(score, metrics)``
        tuples are also accepted. Anything else raises ``TypeError``.
        """
        if isinstance(result, dict) and "score" in result:
            return result
        if isinstance(result, EvaluateResult):
            return result.model_dump()
        if isinstance(result, tuple) and len(result) == 2:  # Legacy tuple
            logger.warning("Reward function passed to create_app returned legacy tuple format.")
            score, components = result
            return {"score": score, "metrics": components}
        raise TypeError(f"Invalid return type from reward function after decoration: {type(result)}")

    @app.get("/")
    async def root():
        """Get server info."""
        return {"status": "ok", "endpoints": ["/reward"]}

    @app.post("/reward")
    async def reward(request_data: RewardRequest):
        """
        Get reward score for messages.

        Args:
            request_data: RewardRequest object with messages and optional parameters

        Returns:
            A dict form of the reward result (e.g. ``score`` and ``metrics``).
        """
        try:
            # Convert Pydantic models to dictionaries using model_dump (Pydantic v2)
            messages = [msg.model_dump() for msg in request_data.messages]
            ground_truth_data: Optional[Union[str, List[Dict[str, Any]]]] = None

            if isinstance(request_data.ground_truth, str):
                ground_truth_data = request_data.ground_truth
            elif isinstance(request_data.ground_truth, list):
                ground_truth_data = [msg.model_dump() for msg in request_data.ground_truth]

            # Extra top-level request fields are forwarded as keyword arguments.
            kwargs = request_data.model_dump(exclude={"messages", "ground_truth"})

            if ground_truth_data is None:
                # Default: use the conversation context (all but the last message).
                ground_truth_data = messages[:-1] if messages else []

            result = reward_func(messages=messages, ground_truth=ground_truth_data, **kwargs)
            if inspect.isawaitable(result):
                # @reward_function supports ``async def`` functions; await their result.
                result = await result

            return _format_result(result)

        except Exception as e:
            logger.exception("Error processing reward request: %s", e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.get("/health")
    async def health():
        """Health check endpoint."""
        return {"status": "ok"}

    return app
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from functools import wraps
|
|
3
|
+
from typing import (
|
|
4
|
+
Any,
|
|
5
|
+
Callable,
|
|
6
|
+
Dict,
|
|
7
|
+
List,
|
|
8
|
+
Literal,
|
|
9
|
+
Optional,
|
|
10
|
+
Protocol,
|
|
11
|
+
TypeVar,
|
|
12
|
+
Union,
|
|
13
|
+
cast,
|
|
14
|
+
get_args,
|
|
15
|
+
get_origin,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from pydantic import TypeAdapter, ValidationError
|
|
19
|
+
|
|
20
|
+
# EvaluateResult and StepOutput are now extended/defined in models.py
|
|
21
|
+
from .models import ( # Removed StepOutput as it's not used here directly
|
|
22
|
+
EvaluateResult,
|
|
23
|
+
Message,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
# Import resource types
|
|
27
|
+
from .resources import ResourceDict
|
|
28
|
+
|
|
29
|
+
# Pydantic validators for the decorator's two output shapes:
# a single EvaluateResult (pointwise) and a list of them (batch).
_single_res_adapter = TypeAdapter(EvaluateResult)
_list_res_adapter = TypeAdapter(List[EvaluateResult])

# Accepted values for the decorator's ``mode`` keyword.
EvaluationMode = Literal["pointwise", "batch"]

# Generic type of the decorated callable, so the decorator can declare that it
# returns the same kind of callable it was given.
F = TypeVar("F", bound=Callable[..., Any])
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def reward_function(
    _func: Optional[F] = None,
    *,
    mode: EvaluationMode = "pointwise",
    id: Optional[str] = None,
    requirements: Optional[List[str]] = None,
    resources: Optional[ResourceDict] = None,  # Resource management
    concurrency: Optional[int] = None,
    timeout: Optional[int] = None,
) -> Union[F, Callable[[F], F]]:
    """
    Decorator for user-defined reward and evaluation functions with resource management.

    It handles:
    - Coercing input messages (and ground truths if applicable) to Pydantic `Message` objects
      if the decorated function is type-hinted to receive them. This targets parameters named
      'messages' (pointwise mode, annotated ``List[Message]``), 'rollouts_messages' (batch mode,
      annotated ``List[List[Message]]``) and 'ground_truth' (annotated ``List[Message]``).
    - Validating that the output conforms to `EvaluateResult` (for pointwise) or
      `List[EvaluateResult]` (for batch); the wrapper returns the validated object(s).
    - Setting up declared resources (LLMs, databases, etc.) and injecting their clients into
      every call as the ``resources`` keyword argument.

    Args:
        _func: The user's reward/evaluation function. Optional for decorator usage with args.
        mode: Specifies the operational mode. Defaults to "pointwise".
            - "pointwise": Function processes one rollout. Expected output: `EvaluateResult`.
            - "batch": Function processes a batch of rollouts. Expected output: `List[EvaluateResult]`.
        id: Optional identifier for the reward function, used for deployment.
        requirements: Optional list of requirement strings (e.g. requirements.txt lines)
            used for deployment.
        resources: Optional dictionary mapping a resource type to a list of resource instances.
            Example: {"llms": [llm_resource]}
            Each resource's ``setup()`` is called once, when the decorator is applied. On every
            call the function receives ``resources={resource_type: [client, ...]}`` built from
            each resource's ``get_client()``. This decorator performs no cleanup.
        concurrency: Optional number of concurrent requests to the reward function. Intended to
            take effect only when the function is async or has async resources bound to it
            (e.g. an LLM resource). The wrapper itself does not enforce it; it only records it
            as ``_reward_function_concurrency``.
        timeout: Optional timeout for the reward function, with the same async-only intent.
            It is recorded as ``_reward_function_timeout`` and is not enforced by the wrapper.

    Returns:
        A decorator if `_func` is None, or the decorated function.

    Raises:
        ValueError: When decorating, if the function does not accept ``**kwargs``. When the
            wrapped function is called, if input coercion fails or the return value does not
            validate for ``mode``.
    """

    def decorator(func: F) -> F:
        sig = inspect.signature(func)
        params = sig.parameters

        # Validate that the function accepts **kwargs; injected ``resources`` and any extra
        # caller keywords are delivered through it.
        has_var_keyword = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values())

        if not has_var_keyword:
            raise ValueError(
                f"Function '{func.__name__}' must accept **kwargs parameter. "
                f"Please add '**kwargs' to the function signature."
            )

        # Setup resources once when the decorator is applied (not per call).
        resource_managers = {}
        if resources:
            for resource_type, resource_list in resources.items():
                managers = []
                for resource in resource_list:
                    resource.setup()
                    managers.append(resource)
                resource_managers[resource_type] = managers

        # Detect if the user supplied function is a coroutine (async def)
        _is_async_function = inspect.iscoroutinefunction(func)

        # Name of the function's *args parameter, if it declares one.
        var_positional_name = next(
            (name for name, p in params.items() if p.kind == inspect.Parameter.VAR_POSITIONAL),
            None,
        )

        def _prepare_final_args(*args: Any, **kwargs: Any):
            """Prepare final positional and keyword arguments for the user function call.

            This includes Pydantic coercion and resource injection. Returns a tuple of
            (call_args, call_kwargs).
            """
            # Bind arguments to handle *args and **kwargs correctly for the wrapped function
            bound_args = sig.bind_partial(*args, **kwargs)
            bound_args.apply_defaults()

            # Create a mutable copy of arguments to modify
            final_func_args = dict(bound_args.arguments)

            def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Message]:
                if not isinstance(data_list, list):
                    raise TypeError(f"Expected a list for '{arg_name_for_error}', got {type(data_list)}")
                typed_list = []
                for i, item_data in enumerate(data_list):
                    if isinstance(item_data, Message):
                        typed_list.append(item_data)
                    elif isinstance(item_data, dict):
                        typed_list.append(Message(**item_data))
                    else:
                        raise TypeError(f"Unexpected type for item {i} in '{arg_name_for_error}': {type(item_data)}")
                return typed_list

            # 1. Conditional Pydantic conversion for 'messages' (pointwise) or 'rollouts_messages' (batch)
            if mode == "pointwise" and "messages" in params and "messages" in final_func_args:
                messages_param_annotation = params["messages"].annotation
                if (
                    get_origin(messages_param_annotation) in (list, List)
                    and get_args(messages_param_annotation)
                    and get_args(messages_param_annotation)[0] == Message
                ):
                    try:
                        final_func_args["messages"] = _coerce_to_list_message(final_func_args["messages"], "messages")
                    except Exception as err:
                        raise ValueError(f"Input 'messages' failed Pydantic validation: {err}") from None

            elif mode == "batch" and "rollouts_messages" in params and "rollouts_messages" in final_func_args:
                param_annotation = params["rollouts_messages"].annotation
                inner = get_args(param_annotation)[0] if get_args(param_annotation) else None
                if get_origin(param_annotation) == list and inner and get_origin(inner) == list:
                    if get_args(inner) and get_args(inner)[0] == Message:
                        try:
                            coerced_rollouts = []
                            for i, rollout_data in enumerate(final_func_args["rollouts_messages"]):
                                coerced_rollouts.append(
                                    _coerce_to_list_message(rollout_data, f"rollouts_messages[{i}]")
                                )
                            final_func_args["rollouts_messages"] = coerced_rollouts
                        except Exception as err:
                            raise ValueError(f"Input 'rollouts_messages' failed Pydantic validation: {err}") from None

            # Ground truth coercion (if needed)
            if "ground_truth" in params and "ground_truth" in final_func_args:
                gt_ann = params["ground_truth"].annotation
                if get_origin(gt_ann) in (list, List) and get_args(gt_ann) and get_args(gt_ann)[0] == Message:
                    if final_func_args["ground_truth"] is not None:
                        try:
                            final_func_args["ground_truth"] = _coerce_to_list_message(
                                final_func_args["ground_truth"], "ground_truth"
                            )
                        except Exception as err:
                            raise ValueError(
                                f"Input 'ground_truth' failed Pydantic validation for List[Message]: {err}"
                            ) from None

            # Inject resource clients (resources are already setup)
            if resource_managers:
                final_func_args["resources"] = {
                    resource_type: [manager.get_client() for manager in managers]
                    for resource_type, managers in resource_managers.items()
                }

            # Reconstruct args and kwargs for the call to func. Named parameters are normally
            # passed by keyword. However, if the function declares *args and extra positional
            # values were supplied, the parameters before *args must also be passed
            # positionally; otherwise Python reports "multiple values" for them.
            has_extra_positional = bool(var_positional_name and final_func_args.get(var_positional_name))
            call_args: List[Any] = []
            call_kwargs: Dict[str, Any] = {}
            for p_name, p_obj in params.items():  # params from inspect.signature(func).parameters
                if p_obj.kind == inspect.Parameter.VAR_POSITIONAL:
                    call_args.extend(final_func_args.get(p_name, ()))
                elif p_obj.kind == inspect.Parameter.VAR_KEYWORD:  # **kwargs
                    call_kwargs.update(final_func_args.get(p_name, {}))
                elif p_name in final_func_args:  # Named parameters
                    if p_obj.kind == inspect.Parameter.POSITIONAL_ONLY or (
                        p_obj.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and has_extra_positional
                    ):
                        call_args.append(final_func_args[p_name])
                    else:  # POSITIONAL_OR_KEYWORD, KEYWORD_ONLY
                        call_kwargs[p_name] = final_func_args[p_name]

            # The loop above only visits declared parameters. If the function has no named
            # ``resources`` parameter, the injected clients must go through **kwargs, or
            # they would be silently dropped.
            if "resources" in final_func_args and "resources" not in params:
                call_kwargs["resources"] = final_func_args["resources"]

            return call_args, call_kwargs

        def _validate_output(result: Any):
            if mode == "pointwise":
                if isinstance(result, EvaluateResult):
                    return result
                return _single_res_adapter.validate_python(result)
            elif mode == "batch":
                if isinstance(result, list) and all(isinstance(item, EvaluateResult) for item in result):
                    return result
                return _list_res_adapter.validate_python(result)
            else:
                raise ValueError(f"Internal error: Invalid mode '{mode}' in wrapper.")

        if _is_async_function:

            @wraps(func)
            async def async_wrapper(
                *args: Any,
                **kwargs: Any,
            ) -> Union[EvaluateResult, List[EvaluateResult]]:
                call_args, call_kwargs = _prepare_final_args(*args, **kwargs)
                result = await func(*call_args, **call_kwargs)  # type: ignore[misc]
                try:
                    return _validate_output(result)
                except ValidationError as err:
                    raise ValueError(
                        f"Return value from function '{func.__name__}' failed Pydantic validation for mode '{mode}':\n{err}"
                    ) from None

            wrapper_fn = async_wrapper

        else:

            @wraps(func)
            def sync_wrapper(
                *args: Any,
                **kwargs: Any,
            ) -> Union[EvaluateResult, List[EvaluateResult]]:
                call_args, call_kwargs = _prepare_final_args(*args, **kwargs)
                result = func(*call_args, **call_kwargs)
                try:
                    return _validate_output(result)
                except ValidationError as err:
                    raise ValueError(
                        f"Return value from function '{func.__name__}' failed Pydantic validation for mode '{mode}':\n{err}"
                    ) from None

            wrapper_fn = sync_wrapper

        # Set attributes for introspection and deployment
        wrapper_fn._reward_function_id = id  # type: ignore[attr-defined]
        wrapper_fn._reward_function_requirements = requirements  # type: ignore[attr-defined]
        wrapper_fn._reward_function_mode = mode  # type: ignore[attr-defined]
        wrapper_fn._reward_function_resources = resources  # type: ignore[attr-defined]
        wrapper_fn._reward_function_timeout = timeout  # type: ignore[attr-defined]
        wrapper_fn._reward_function_concurrency = concurrency  # type: ignore[attr-defined]

        return cast(F, wrapper_fn)

    if _func is None:  # Decorator called with arguments, e.g., @reward_function(mode="batch")
        return decorator
    else:  # Decorator called without arguments, e.g., @reward_function (defaults to pointwise)
        return decorator(_func)
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
# This file makes the 'utils' directory a Python package.
|
|
2
|
+
|
|
3
|
+
# You can selectively expose functions or classes from modules within 'utils' here
|
|
4
|
+
# for easier access, e.g.:
|
|
5
|
+
# from .dataset_helpers import load_jsonl_to_hf_dataset
|
|
6
|
+
|
|
7
|
+
# For now, allow direct import of modules like:
|
|
8
|
+
# from eval_protocol.utils.dataset_helpers import ...
|