eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,248 @@
1
+ """
2
+ Model clients for generating responses from various LLM APIs.
3
+ """
4
+
5
+ import abc
6
+ import asyncio
7
+ import json # For parsing content as JSON
8
+ import logging
9
+ import uuid # For generating tool call IDs if not provided
10
+ from typing import Any, Dict, List, Optional
11
+
12
+ import aiohttp
13
+ from omegaconf import DictConfig
14
+ from pydantic import BaseModel, Field # Added for new models
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ # Pydantic models for structured tool calls and generation results
20
# Function part of a tool call: which function to invoke and its serialized arguments.
class ToolCallFunction(BaseModel):
    name: str  # Name of the function the model asked to call
    arguments: str  # Should be a JSON string
23
+
24
+
25
# A single tool call: an identifier, a type tag and the function payload.
class ToolCall(BaseModel):
    id: str  # Call identifier
    type: str = "function"  # OpenAI default
    function: ToolCallFunction  # Function name plus JSON-encoded arguments
29
+
30
+
31
# Result of one generation request; both fields default to None, so an
# argument-less GenerationResult() represents "nothing usable was produced".
class GenerationResult(BaseModel):
    content: Optional[str] = None  # Text content of the reply, if any
    tool_calls: Optional[List[ToolCall]] = None  # Tool calls extracted from the reply, if any
34
+
35
+
36
class ModelClient(abc.ABC):
    """Abstract base class for model clients.

    The constructor copies the common generation settings out of
    ``client_config`` into attributes; subclasses implement ``generate``.
    """

    def __init__(self, client_config: DictConfig):
        """Keep the config and read the generation settings from it.

        Args:
            client_config: Client configuration. Every setting read below
                falls back to the default passed to the corresponding
                ``get`` call when the key is missing.
        """
        self.client_config = client_config
        self.model_name = client_config.get("model_name", "unknown_model")
        self.temperature = client_config.get("temperature", 0.0)
        self.max_tokens = client_config.get("max_tokens", 1024)
        self.top_p = client_config.get("top_p", 0.95)
        self.top_k = client_config.get("top_k", 20)
        self.min_p = client_config.get("min_p", 0.0)
        # Add reasoning_effort, defaulting to None if not specified in config
        self.reasoning_effort = client_config.get("reasoning_effort", None)

    @abc.abstractmethod
    async def generate(
        self,
        messages: List[Dict[str, str]],
        session: aiohttp.ClientSession,
        tools: Optional[List[Dict[str, Any]]] = None,  # Added tools parameter
    ) -> GenerationResult:  # Changed return type
        """Generates a response from the model given a list of messages.

        Args:
            messages: Conversation to send to the model.
            session: aiohttp session supplied by the caller for the request.
            tools: Optional tool definitions to offer to the model.

        Returns:
            A ``GenerationResult`` carrying text content and/or tool calls.
        """
        pass
59
+
60
+
61
class FireworksModelClient(ModelClient):
    """Client for Fireworks AI models.

    Sends OpenAI-style chat-completion requests and normalizes the reply into
    a ``GenerationResult``. Tool calls are recognized both in the native
    ``tool_calls`` field and when the model emits them as JSON inside
    ``content``.
    """

    def __init__(self, client_config: DictConfig, api_key: str):
        """Store credentials and the API base URL.

        Args:
            client_config: Client configuration; ``api_base`` overrides the
                default Fireworks inference endpoint.
            api_key: Bearer token sent with every request.
        """
        super().__init__(client_config)
        self.api_key = api_key
        self.api_base = client_config.get("api_base", "https://api.fireworks.ai/inference/v1")
        # TODO: Initialize rate limiter, retry policy from client_config.api_params

    @staticmethod
    def _new_call_id() -> str:
        """Return a short random identifier for a tool call that has none."""
        return f"call_{uuid.uuid4().hex[:8]}"

    @staticmethod
    def _arguments_as_json(arguments: Any) -> str:
        """Return tool-call arguments as the JSON string ``ToolCallFunction`` expects.

        Some models return an already-decoded object instead of a JSON string;
        serializing it here avoids a validation error that would otherwise
        discard the whole response.
        """
        return arguments if isinstance(arguments, str) else json.dumps(arguments)

    @staticmethod
    def _retry_after_seconds(header_value: Optional[str], default: float = 5.0) -> float:
        """Parse a ``Retry-After`` header into a non-negative delay in seconds.

        Values that are missing, non-numeric (e.g. the HTTP-date form),
        negative or non-finite fall back to ``default`` instead of raising.
        """
        if header_value is None:
            return default
        try:
            seconds = float(header_value)
        except (TypeError, ValueError):
            return default
        # The chained comparison is False for NaN, negatives and infinity.
        return seconds if 0 <= seconds < float("inf") else default

    @staticmethod
    def _payload_preview(payload: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``payload`` with the last message's content truncated for logging."""
        preview = dict(payload)
        messages = list(payload.get("messages") or [])
        if messages and messages[-1].get("content"):  # Check if content exists
            last_message = dict(messages[-1])
            last_message["content"] = str(last_message["content"])[:50] + "..."
            messages[-1] = last_message
        preview["messages"] = messages
        return preview

    @classmethod
    def _parse_native_tool_calls(cls, tool_calls_data: List[Dict[str, Any]]) -> List[ToolCall]:
        """Convert the API's native ``tool_calls`` entries into ``ToolCall`` models.

        Entries that are not of type ``function`` or lack a ``function``
        payload are skipped.
        """
        parsed: List[ToolCall] = []
        for tc_data in tool_calls_data:
            function = tc_data.get("function")
            if tc_data.get("type") == "function" and function:
                parsed.append(
                    ToolCall(
                        # Generate ID if missing (or explicitly null)
                        id=tc_data.get("id") or cls._new_call_id(),
                        type="function",
                        function=ToolCallFunction(
                            name=function["name"],
                            arguments=cls._arguments_as_json(function["arguments"]),
                        ),
                    )
                )
        return parsed

    @classmethod
    def _parse_tool_calls_from_content(cls, decoded_content: Any) -> List[ToolCall]:
        """Extract tool calls from JSON that the model placed in ``content``.

        Two shapes are recognized, as a single object or as a list of objects:

        * OpenAI style:
          ``{"type": "function", "function": {"name": ..., "arguments": ...}}``
        * the flat style some models produce:
          ``{"type": "function", "name": ..., "parameters": {...}}``
        """
        items = decoded_content if isinstance(decoded_content, list) else [decoded_content]
        parsed: List[ToolCall] = []
        for item in items:
            if not (isinstance(item, dict) and item.get("type") == "function"):
                continue

            func_details = item.get("function")  # OpenAI style
            if isinstance(func_details, dict) and "name" in func_details and "arguments" in func_details:
                parsed.append(
                    ToolCall(
                        id=item.get("id") or cls._new_call_id(),
                        type="function",
                        function=ToolCallFunction(
                            name=func_details["name"],
                            arguments=cls._arguments_as_json(func_details["arguments"]),
                        ),
                    )
                )
                continue  # Found valid OpenAI style tool call

            # Flat style: `parameters` is an object, so it is dumped to a
            # string for `ToolCallFunction.arguments`.
            llm_name = item.get("name")
            llm_params = item.get("parameters")
            if llm_name and isinstance(llm_params, dict):
                parsed.append(
                    ToolCall(
                        id=item.get("id") or cls._new_call_id(),
                        type="function",
                        function=ToolCallFunction(
                            name=llm_name,
                            arguments=json.dumps(llm_params),
                        ),
                    )
                )
        return parsed

    def _result_from_response(self, data: Dict[str, Any]) -> GenerationResult:
        """Turn a successful (HTTP 200) response body into a ``GenerationResult``."""
        choices = data.get("choices")
        if choices:
            message = choices[0].get("message") or {}

            # 1. Check for native OpenAI-style tool_calls field
            if message.get("tool_calls"):
                parsed_tool_calls = self._parse_native_tool_calls(message["tool_calls"])
                if parsed_tool_calls:
                    logger.debug(f"Parsed native tool_calls: {parsed_tool_calls}")
                    return GenerationResult(tool_calls=parsed_tool_calls)

            # 2. If no native tool_calls, check if content is a JSON string
            #    representing a tool call.
            content_str = message.get("content")
            if content_str:
                try:
                    decoded_content = json.loads(content_str)
                except json.JSONDecodeError:
                    # Content is not JSON, so it's a regular text response
                    logger.debug("Content is not JSON. Treating as text.")
                    return GenerationResult(content=content_str.strip())

                parsed_from_content = self._parse_tool_calls_from_content(decoded_content)
                if parsed_from_content:
                    logger.debug(f"Parsed tool_calls from content field: {parsed_from_content}")
                    return GenerationResult(tool_calls=parsed_from_content)

                # If JSON but not a recognized tool call, it's just JSON content
                logger.debug("Content was JSON, but not a recognized tool call structure. Treating as text.")
                return GenerationResult(content=content_str.strip())

        # Neither tool_calls nor usable content
        logger.warning(f"Fireworks API response malformed or no actionable content/tool_calls: {data}")
        return GenerationResult()

    async def generate(
        self,
        messages: List[Dict[str, str]],
        session: aiohttp.ClientSession,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> GenerationResult:
        """Request a chat completion and return the parsed result.

        Rate-limit (429) and server (5xx) responses are retried up to
        ``api_params.max_retries`` times (default 3). Authentication errors,
        other client errors, transport failures and unexpected exceptions are
        logged and yield an empty ``GenerationResult`` rather than raising.

        Args:
            messages: Conversation to send to the model.
            session: aiohttp session used to issue the request.
            tools: Optional tool definitions; when given, ``tool_choice`` is
                set to ``"auto"``.

        Returns:
            A ``GenerationResult`` with tool calls, text content, or neither.
        """
        url = f"{self.api_base}/chat/completions"

        payload: Dict[str, Any] = {
            "model": self.model_name,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        if self.top_p is not None:
            payload["top_p"] = self.top_p

        if tools:
            payload["tools"] = tools
            # "auto" is the common OpenAI value; adjust here if Fireworks
            # turns out to require a different tool_choice.
            payload["tool_choice"] = "auto"

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

        # Only build the truncated copy when it will actually be logged.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Calling Fireworks API: {url}, Payload: {self._payload_preview(payload)}")

        try:
            # `or {}` also covers an explicit `api_params: null` in the config.
            api_params = self.client_config.get("api_params") or {}
            total_attempts = api_params.get("max_retries", 3) + 1

            for attempt in range(total_attempts):
                async with session.post(url, json=payload, headers=headers) as response:
                    status = response.status
                    if status == 200:
                        return self._result_from_response(await response.json())
                    if status == 429:  # Rate limit
                        delay = self._retry_after_seconds(response.headers.get("Retry-After"))
                        logger.warning(f"Rate limited. Retrying after {delay:g}s (attempt {attempt+1}).")
                    elif status in (401, 403):  # Auth errors
                        error_text = await response.text()
                        logger.error(f"Fireworks API Auth Error ({status}): {error_text}")
                        return GenerationResult()  # Empty result on auth error
                    elif status >= 500:  # Server errors
                        delay = 2**attempt
                        logger.warning(f"Fireworks API Server Error ({status}). Retrying (attempt {attempt+1}).")
                    else:  # Other client errors
                        error_text = await response.text()
                        logger.error(f"Fireworks API request failed ({status}): {error_text}")
                        return GenerationResult()  # Empty result

                # Back off only when another attempt follows, and do it after
                # the response has been released so no connection is held
                # while waiting.
                if attempt + 1 < total_attempts:
                    await asyncio.sleep(delay)

            logger.error("Max retries reached for Fireworks API call.")
            return GenerationResult()
        except aiohttp.ClientError as e:
            logger.error(f"AIOHTTP client error: {e}")
            return GenerationResult()
        except Exception as e:
            logger.error(f"Unexpected error in FireworksModelClient: {e}", exc_info=True)
            return GenerationResult()
@@ -0,0 +1,165 @@
1
+ import importlib
2
+ import os
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import uvicorn
6
+ from fastapi import Depends, FastAPI, HTTPException, Request
7
+ from pydantic import BaseModel, ValidationError
8
+
9
+ # Assuming these models are correctly defined in eval_protocol.models
10
+ from eval_protocol.models import EvaluateResult, Message
11
+
12
+
13
+ # --- Request and Response Models ---
14
# Request body of the /evaluate endpoint.
class EvaluationRequest(BaseModel):
    messages: List[Dict[str, Any]]  # Could also be List[Message] if we enforce that model on input
    ground_truth: Optional[str] = None  # Passed to the reward function as ``ground_truth``
    kwargs: Optional[Dict[str, Any]] = {}  # Extra keyword arguments spread into the reward function call
18
+
19
+
20
+ # --- Global variable to store the loaded reward function ---
21
+ # This is a simple approach for a single-function server.
22
+ # If multiple functions were to be served by one instance, a different mechanism would be needed.
23
# Set by load_reward_function(); None means no function is available yet.
_LOADED_REWARD_FUNCTION = None
# Import string of the loaded function, used in log output and /health.
_REWARD_FUNCTION_NAME = "N/A"

# --- API Key Authentication Dependency ---
# Read once at import time; when unset or empty, verify_api_key accepts every request.
EXPECTED_API_KEY = os.environ.get("RK_ENDPOINT_API_KEY")
28
+
29
+
30
async def verify_api_key(request: Request):
    """FastAPI dependency that enforces the optional endpoint API key.

    When ``EXPECTED_API_KEY`` is unset or empty, every request is allowed.
    Otherwise the key is read from the ``X-Api-Key`` header, falling back to
    ``Authorization: Bearer <key>``.

    Args:
        request: Incoming request whose headers are inspected.

    Returns:
        ``True`` when the request may proceed.

    Raises:
        HTTPException: 401 if a key is required but none was sent, 403 if the
            supplied key does not match.
    """
    import hmac  # local import: only needed for the constant-time comparison

    if not EXPECTED_API_KEY:
        return True  # No key configured: authentication is disabled

    # Check for X-Api-Key header first
    client_api_key = request.headers.get("X-Api-Key")
    # If not found, check for Authorization: Bearer <key>
    if not client_api_key:
        auth_header = request.headers.get("Authorization")
        if auth_header and auth_header.startswith("Bearer "):
            client_api_key = auth_header.split(" ", 1)[1]

    if not client_api_key:
        raise HTTPException(status_code=401, detail="API key required but not provided.")

    # Compare in constant time so response timing does not reveal how much of
    # the key matched. Encoding to bytes keeps compare_digest from raising on
    # non-ASCII input.
    if not hmac.compare_digest(client_api_key.encode("utf-8"), EXPECTED_API_KEY.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Invalid API key.")
    return True
45
+
46
+
47
+ # --- FastAPI App ---
48
# ASGI application object; the route decorators below register on it and the
# __main__ block hands it to uvicorn.
app = FastAPI(
    title="Reward Kit Generic Reward Function Server",
    description="Serves a dynamically loaded reward function.",
    version="0.1.0",  # Or use eval_protocol.__version__
)
53
+
54
+
55
@app.post("/evaluate", response_model=EvaluateResult, dependencies=[Depends(verify_api_key)])
async def evaluate_endpoint(request: EvaluationRequest):
    """
    Endpoint to evaluate a given set of messages using the loaded reward function.
    Requires API key if RK_ENDPOINT_API_KEY environment variable is set.

    The reward function is called with ``messages``, ``ground_truth`` and the
    request's extra ``kwargs``. Both plain and ``async def`` reward functions
    are supported: an awaitable return value is awaited before validation.

    Raises:
        HTTPException: 500 if no reward function is loaded or the call fails,
            422 if the reward function raises a pydantic ``ValidationError``.
    """
    import inspect  # local import: only needed to detect awaitable results

    if _LOADED_REWARD_FUNCTION is None:
        raise HTTPException(status_code=500, detail="Reward function not loaded.")

    try:
        # The user's reward function is expected to match the @reward_function signature
        func_args = {
            "messages": request.messages,
            "ground_truth": request.ground_truth,
            **(request.kwargs or {}),
        }

        result = _LOADED_REWARD_FUNCTION(**func_args)
        # An async reward function returns a coroutine; without awaiting it
        # the result would be reported as an invalid return type.
        if inspect.isawaitable(result):
            result = await result

        if not isinstance(result, EvaluateResult):
            # This case should ideally not happen if functions are correctly decorated
            # and return EvaluateResult, but good to have a fallback.
            print(
                f"Warning: Reward function '{_REWARD_FUNCTION_NAME}' did not return an EvaluateResult instance. Type: {type(result)}"
            )
            return EvaluateResult(
                score=0.0,
                reason="Invalid return type from reward function, check server logs.",
                is_score_valid=False,
                metrics={},
            )

        return result
    except ValidationError as ve:  # Pydantic validation error from reward function's input/output
        print(f"Validation Error calling reward function '{_REWARD_FUNCTION_NAME}': {ve}")
        raise HTTPException(
            status_code=422,
            detail=f"Input/Output validation error for reward function: {ve.errors()}",
        ) from ve
    except Exception as e:
        print(f"Error during evaluation with reward function '{_REWARD_FUNCTION_NAME}': {e}")
        # Consider logging the full traceback here
        raise HTTPException(status_code=500, detail=f"Internal server error during evaluation: {str(e)}") from e
101
+
102
+
103
@app.get("/health")
async def health_check():
    """
    Report whether a reward function is currently loaded.

    Returns an ``ok`` status together with the function's import string, or
    an ``error`` status with a reason when nothing is loaded.
    """
    if not _LOADED_REWARD_FUNCTION:
        return {"status": "error", "reason": "Reward function not loaded"}
    return {"status": "ok", "reward_function": _REWARD_FUNCTION_NAME}
112
+
113
+
114
def load_reward_function(import_string: str):
    """
    Loads a reward function from an import string (e.g., 'my_module.my_function').

    On success the function and its import string are stored in the
    module-level ``_LOADED_REWARD_FUNCTION`` / ``_REWARD_FUNCTION_NAME``.
    On any failure those globals are reset and the error is re-raised.

    Args:
        import_string: Dotted path whose last component is the attribute to
            load from the module named by the preceding components.

    Raises:
        ValueError: If ``import_string`` contains no dot.
        TypeError: If the resolved attribute is not callable.
        Exception: Any import or attribute error raised while resolving it.
    """
    global _LOADED_REWARD_FUNCTION, _REWARD_FUNCTION_NAME
    try:
        module_path, separator, function_name = import_string.rpartition(".")
        if not separator:
            raise ValueError(
                f"Expected an import string like 'my_module.my_function', got {import_string!r}"
            )
        module = importlib.import_module(module_path)
        loaded = getattr(module, function_name)
        # Reject constants, sub-modules etc. now instead of failing on the
        # first /evaluate request.
        if not callable(loaded):
            raise TypeError(f"'{import_string}' resolved to a non-callable {type(loaded).__name__}")
        _LOADED_REWARD_FUNCTION = loaded
        _REWARD_FUNCTION_NAME = import_string
        print(f"Successfully loaded reward function: {_REWARD_FUNCTION_NAME}")
    except Exception as e:
        print(f"Error loading reward function from '{import_string}': {e}")
        _LOADED_REWARD_FUNCTION = None
        _REWARD_FUNCTION_NAME = "Error loading"
        raise  # Re-raise to make it fatal if loading fails on startup
130
+
131
+
132
# Script entry point: parse the CLI, load the reward function, then serve `app`.
if __name__ == "__main__":
    import argparse
    import sys  # sys.exit is always available; the bare exit() helper comes from the site module

    parser = argparse.ArgumentParser(description="Run the Generic Reward Function Server.")
    parser.add_argument(
        "import_string",
        type=str,
        help="Import string for the reward function (e.g., 'my_package.my_module.my_reward_function')",
    )
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Host to bind the server to.")
    parser.add_argument(
        "--port",
        type=int,
        default=8080,  # Standard port for Cloud Run, etc.
        help="Port to bind the server to.",
    )
    # Add --reload for uvicorn if needed for development
    # parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development.")

    args = parser.parse_args()

    try:
        load_reward_function(args.import_string)
    except Exception:
        # load_reward_function already printed the underlying error.
        print("Failed to load reward function. Exiting.")
        sys.exit(1)

    # Defensive: refuse to serve if loading left no usable function behind.
    if not _LOADED_REWARD_FUNCTION:
        print(f"Reward function {_REWARD_FUNCTION_NAME} could not be loaded. Server will not start correctly.")
        sys.exit(1)

    print(f"Starting server for reward function: {args.import_string} on http://{args.host}:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port)  # reload=args.reload for dev
@@ -0,0 +1,12 @@
1
+ """Integration helpers for Reward Kit."""
2
+
3
+ from .braintrust import reward_fn_to_scorer, scorer_to_reward_fn
4
+ from .openeval import adapt
5
+ from .trl import create_trl_adapter
6
+
7
+ __all__ = [
8
+ "adapt",
9
+ "scorer_to_reward_fn",
10
+ "reward_fn_to_scorer",
11
+ "create_trl_adapter",
12
+ ]
@@ -0,0 +1,51 @@
1
+ """Adapters for integrating Reward Kit with Braintrust scoring functions."""
2
+
3
+ from typing import Any, Callable, List, Optional
4
+
5
+ from eval_protocol.models import EvaluateResult, Message
6
+ from eval_protocol.typed_interface import reward_function
7
+
8
+ # Type alias for Braintrust scoring functions
9
+ BraintrustScorer = Callable[[Any, Any, Any], float]
10
+
11
+
12
def scorer_to_reward_fn(
    scorer: BraintrustScorer,
    *,
    messages_to_input: Optional[Callable[[List[Message]], Any]] = None,
    ground_truth_to_expected: Optional[Callable[[List[Message]], Any]] = None,
) -> Callable[[List[Message], Optional[List[Message]]], EvaluateResult]:
    """Wrap a Braintrust scorer as a Reward Kit reward function.

    Args:
        scorer: Callable invoked as ``scorer(input, output, expected)``.
        messages_to_input: Optional converter from the message list to the
            scorer's ``input``; by default the first message's content is used.
        ground_truth_to_expected: Optional converter from the ground-truth
            messages to the scorer's ``expected``; by default the last
            ground-truth message's content is used.

    Returns:
        A reward function decorated with ``@reward_function``. It scores the
        last message's content and returns a zero-score "No messages" result
        when called with an empty conversation.
    """

    @reward_function
    def reward_fn(messages: List[Message], ground_truth: Optional[List[Message]] = None, **kwargs) -> EvaluateResult:
        # Guard the messages[0] / messages[-1] lookups below; mirrors the
        # empty-input handling of the other integration adapters.
        if not messages:
            return EvaluateResult(score=0.0, reason="No messages", metrics={})

        input_val = messages_to_input(messages) if messages_to_input else messages[0].content
        output_val = messages[-1].content
        expected_val = None
        if ground_truth:
            expected_val = (
                ground_truth_to_expected(ground_truth) if ground_truth_to_expected else ground_truth[-1].content
            )
        score = scorer(input_val, output_val, expected_val)
        return EvaluateResult(score=score)

    return reward_fn
33
+
34
+
35
def reward_fn_to_scorer(
    reward_fn: Callable[[List[Message], Optional[List[Message]]], EvaluateResult],
) -> BraintrustScorer:
    """Create a Braintrust-compatible scorer from a Reward Kit reward function.

    The returned scorer turns its ``input``/``output`` pair into a two-message
    user/assistant conversation, wraps ``expected`` (when not ``None``) in a
    single assistant message as ground truth, and returns the reward
    function's ``score``.
    """

    def scorer(input_val: Any, output: Any, expected: Any) -> float:
        conversation = [
            Message(role="user", content=str(input_val)),
            Message(role="assistant", content=str(output)),
        ]
        reference = [Message(role="assistant", content=str(expected))] if expected is not None else None
        return reward_fn(messages=conversation, ground_truth=reference).score

    return scorer
@@ -0,0 +1,106 @@
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ from eval_protocol.models import EvaluateResult, MetricResult
4
+ from eval_protocol.typed_interface import reward_function
5
+
6
+ __all__ = ["adapt_metric"]
7
+
8
+ try:
9
+ from deepeval.metrics.base_metric import BaseConversationalMetric, BaseMetric
10
+ from deepeval.test_case import ConversationalTestCase, LLMTestCase
11
+ except Exception: # pragma: no cover - deepeval is optional
12
+ BaseMetric = None
13
+ BaseConversationalMetric = None
14
+ LLMTestCase = None
15
+ ConversationalTestCase = None
16
+
17
+
18
+ def _metric_name(metric: Any) -> str:
19
+ name = getattr(metric, "__name__", None)
20
+ if name and name not in {
21
+ "Base Metric",
22
+ "Base Conversational Metric",
23
+ "Base Multimodal Metric",
24
+ }:
25
+ return str(name)
26
+ name = getattr(metric, "name", None)
27
+ if name:
28
+ return str(name)
29
+ return metric.__class__.__name__
30
+
31
+
32
def adapt_metric(metric: Any):
    """Adapt a deepeval metric object into a reward-kit reward function.

    The returned function builds a deepeval test case from the conversation
    (one ``LLMTestCase`` per message for conversational metrics, otherwise a
    single case from the last two messages), runs ``metric.measure`` on it and
    reports the metric's score and reason.
    """

    @reward_function
    def wrapped(
        messages: List[Dict[str, Any]],
        ground_truth: Optional[str] = None,
        **kwargs: Any,
    ) -> EvaluateResult:
        if BaseMetric is None or LLMTestCase is None:
            raise ImportError("deepeval must be installed to use this integration")
        if not messages:
            return EvaluateResult(score=0.0, reason="No messages", metrics={})

        final_output = messages[-1].get("content", "")
        final_input = messages[-2].get("content", "") if len(messages) >= 2 else ""

        def _case_kwargs(turn_input: Any, turn_output: Any) -> Dict[str, Any]:
            """Build LLMTestCase keyword arguments for one input/output pair."""
            # Every test-case field this adapter knows how to populate.
            available: Dict[str, Any] = {
                "input": turn_input,
                "actual_output": turn_output,
                "expected_output": ground_truth,
                "context": kwargs.get("context"),
                "retrieval_context": kwargs.get("retrieval_context"),
                "tools_called": kwargs.get("tools_called"),
                "expected_tools": kwargs.get("expected_tools"),
            }
            requested = getattr(metric, "evaluation_params", None)
            if requested:
                # Supply only the fields the metric declares it evaluates.
                built: Dict[str, Any] = {}
                for param in requested:
                    if param.value in available:
                        built[param.value] = available[param.value]
            else:
                built = {
                    "input": turn_input,
                    "actual_output": turn_output,
                    "expected_output": ground_truth,
                }
            # These two are always provided, whatever the metric declared.
            built.setdefault("input", turn_input)
            built.setdefault("actual_output", turn_output)
            return built

        if isinstance(metric, BaseConversationalMetric):
            # One turn per message, using the preceding message (if any) as its input.
            turns = []
            for index, msg in enumerate(messages):
                previous_content = messages[index - 1].get("content", "") if index > 0 else ""
                turns.append(LLMTestCase(**_case_kwargs(previous_content, msg.get("content", ""))))
            test_case = ConversationalTestCase(turns=turns)
        else:
            test_case = LLMTestCase(**_case_kwargs(final_input, final_output))

        metric.measure(test_case, **kwargs)
        score = float(metric.score or 0.0)
        reason = getattr(metric, "reason", None)
        metric_key = _metric_name(metric)
        metrics = {metric_key: MetricResult(score=score, reason=reason or "", is_score_valid=True)}
        return EvaluateResult(score=score, reason=reason, metrics=metrics)

    return wrapped
@@ -0,0 +1,40 @@
1
+ from typing import Any, Callable, Dict, List, Union
2
+
3
+ from eval_protocol.models import EvaluateResult, MetricResult
4
+ from eval_protocol.typed_interface import reward_function
5
+
6
+ __all__ = ["adapt"]
7
+
8
+
9
def _convert_result(res: Dict[str, Any]) -> EvaluateResult:
    """Convert one OpenEvals result dict into an ``EvaluateResult``.

    Reads ``score``, ``comment`` and ``key`` from ``res``. A missing or
    ``None`` score counts as 0.0 and a missing or empty key falls back to
    ``"openeval"``.
    """
    raw_score = res.get("score")
    # dict.get's default does not apply when the key is present with a None
    # value, and float(None) would raise.
    score = float(raw_score) if raw_score is not None else 0.0
    reason = res.get("comment")
    key = res.get("key") or "openeval"
    metrics = {key: MetricResult(score=score, reason=reason or "", is_score_valid=True)}
    return EvaluateResult(score=score, reason=reason, metrics=metrics)
15
+
16
+
17
def adapt(openeval_fn: Callable[..., Union[Dict[str, Any], List[Dict[str, Any]]]]):
    """Adapt an OpenEvals evaluator into a reward-kit reward function.

    The returned function calls ``openeval_fn`` with the last message's
    content as ``outputs`` and the ground truth as ``reference_outputs`` (for
    a list of messages, the last message's content). When the evaluator
    returns a list, its first entry is used.
    """

    @reward_function
    def wrapped(
        messages: List[Dict[str, str]],
        ground_truth: Union[str, List[Dict[str, str]], None] = None,
        **kwargs: Any,
    ) -> EvaluateResult:
        if not messages:
            return EvaluateResult(score=0.0, reason="No messages", metrics={})
        output = messages[-1].get("content", "")
        reference = None
        if isinstance(ground_truth, list):
            if ground_truth:
                reference = ground_truth[-1].get("content")
        else:
            reference = ground_truth
        res = openeval_fn(outputs=output, reference_outputs=reference, **kwargs)
        if isinstance(res, list):
            # An evaluator that yields no results would otherwise raise
            # IndexError on res[0]; report it like the empty-messages case.
            if not res:
                return EvaluateResult(score=0.0, reason="Evaluator returned no results", metrics={})
            res = res[0]
        return _convert_result(res)

    return wrapped