eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
from contextlib import AsyncExitStack
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
import aiohttp # Still needed for type hints if we expose the session, but primary interaction changes
|
|
8
|
+
import mcp.types # Reverted to mcp.types; Explicit import for clarity
|
|
9
|
+
from mcp import types as mcp_types # Reverted to mcp.types; Explicit import for clarity
|
|
10
|
+
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
|
|
11
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
12
|
+
from omegaconf import DictConfig
|
|
13
|
+
|
|
14
|
+
# Module-level logger, named after this module so its output can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class IntermediaryMCPClient:
    """
    Client for interacting with the RewardKitIntermediaryServer using mcp.client components.

    The streamable-HTTP transport and the MCP ``ClientSession`` are both entered on a
    single ``AsyncExitStack`` so that :meth:`close` tears them down together. The
    client can also be used as an async context manager (``async with``), which
    connects on entry and closes on exit.

    Each public helper invokes one tool on the intermediary server and decodes the
    text of the first content item of the tool result as JSON.
    """

    def __init__(self, intermediary_server_url: str):
        """
        Args:
            intermediary_server_url: URL of the intermediary MCP endpoint, e.g.
                ``http://localhost:8001/mcp``. Trailing slashes are stripped.

        Raises:
            ValueError: If ``intermediary_server_url`` is empty or None.
        """
        if not intermediary_server_url:
            raise ValueError("intermediary_server_url must be provided.")
        self.server_url = intermediary_server_url.rstrip("/")  # Should be like http://localhost:8001/mcp

        # Both are None while disconnected; connect() sets them and close() resets them.
        self._exit_stack: Optional[AsyncExitStack] = None
        self._mcp_session: Optional[ClientSession] = None

    async def connect(self):
        """Establish the HTTP transport and initialize the MCP session.

        This is a no-op when a session is already established. If any step fails,
        the partially opened resources are closed, the client is left disconnected,
        and the original exception is re-raised.
        """
        # ClientSession has no public ``is_closed`` attribute, so checking it would
        # raise AttributeError. A non-None session is the "connected" marker, since
        # close() and a failed connect() always reset it to None.
        if self._mcp_session is not None:
            logger.debug("Already connected.")
            return

        self._exit_stack = AsyncExitStack()
        try:
            logger.debug(f"Attempting to connect to Intermediary MCP server at {self.server_url}")
            # The third element (HTTP/session info) is not needed by this client.
            read_stream, write_stream, _http_session_info = await self._exit_stack.enter_async_context(
                streamablehttp_client(self.server_url)
            )
            self._mcp_session = await self._exit_stack.enter_async_context(
                ClientSession(read_stream, write_stream, client_info=DEFAULT_CLIENT_INFO)
            )
            await self._mcp_session.initialize()
            logger.info(f"IntermediaryMCPClient connected and MCP session initialized with {self.server_url}")
        except Exception as e:
            exit_stack, self._exit_stack = self._exit_stack, None
            self._mcp_session = None
            if exit_stack is not None:  # pragma: no branch
                try:
                    await exit_stack.aclose()
                except Exception:  # pragma: no cover - best-effort cleanup
                    # Never let a teardown failure mask the original connection error.
                    logger.debug("Error while closing resources after a failed connect.", exc_info=True)
            logger.error(
                f"Failed to connect or initialize MCP session with {self.server_url}: {e}",
                exc_info=True,
            )
            raise

    async def close(self):
        """Close the MCP session and underlying transport.

        This does nothing when the client is not connected, so it is safe to call
        more than once.
        """
        exit_stack = self._exit_stack
        if exit_stack is None:
            return
        logger.debug(f"Closing IntermediaryMCPClient connection to {self.server_url}")
        # Mark the client as disconnected before awaiting teardown, so a failure
        # while closing the transport cannot leave a stale session behind.
        self._exit_stack = None
        self._mcp_session = None
        await exit_stack.aclose()
        logger.info(f"IntermediaryMCPClient connection to {self.server_url} closed.")

    async def _ensure_connected(self):
        """Connect lazily if no session is established yet.

        Raises:
            RuntimeError: If no session exists after attempting to connect.
        """
        # ClientSession doesn't have a public is_closed, so _mcp_session being None
        # is the signal to (re)connect. The AsyncExitStack handles actual closure.
        if not self._mcp_session:
            logger.debug("Session not established, attempting to connect...")
            await self.connect()

        # Defensive: connect() raises on failure, but guard against a missing session anyway.
        if not self._mcp_session:
            raise RuntimeError("Failed to establish or re-establish MCP session.")

    async def _call_intermediary_tool(self, tool_name: str, tool_args_payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Helper to make a raw tool call to the intermediary server and parse the result.

        The tool_args_payload is the "arguments" field for the intermediary's tool.

        Returns:
            The JSON-decoded text of the first content item of the tool result.

        Raises:
            RuntimeError: If the session cannot be established, the tool reports an
                error, the result has no text content, or that text is not valid JSON.
        """
        await self._ensure_connected()
        if not self._mcp_session:  # For type checker
            raise RuntimeError("MCP session not available after ensure_connected.")

        logger.debug(f"Calling intermediary tool '{tool_name}' with payload: {tool_args_payload}")
        mcp_response: mcp_types.CallToolResult = await self._mcp_session.call_tool(tool_name, tool_args_payload)
        logger.debug(f"Raw MCP response from intermediary for '{tool_name}': {mcp_response}")

        # Only the first content item is inspected. An item without text (or with
        # text=None) is treated as a failed call rather than crashing json.loads.
        content = mcp_response.content
        text = getattr(content[0], "text", None) if content else None

        if mcp_response.isError or text is None:
            error_message = f"Tool call '{tool_name}' to intermediary failed."
            if mcp_response.isError:
                if text is not None:
                    error_message += f" Details: {text}"
                else:
                    error_message += " No detailed error message in content."
            logger.error(error_message)
            # The error text may itself be JSON of the form {"error": ...}; surface it if so.
            if text is not None:
                try:
                    parsed_error = json.loads(text)
                except (json.JSONDecodeError, TypeError):
                    parsed_error = None
                if isinstance(parsed_error, dict) and "error" in parsed_error:
                    raise RuntimeError(f"{error_message} Nested error: {parsed_error['error']}")
            raise RuntimeError(error_message)

        try:
            parsed_result = json.loads(text)
        except json.JSONDecodeError as e:
            logger.error(
                f"Failed to parse JSON from intermediary's tool '{tool_name}' response content: {text}. Error: {e}"
            )
            raise RuntimeError(f"Failed to parse JSON response from intermediary tool '{tool_name}'.") from e
        logger.debug(f"Parsed JSON result from intermediary for '{tool_name}': {parsed_result}")
        return parsed_result

    async def initialize_session(self, backend_requests: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Initializes a session with the IntermediaryServer, requesting backend instances.

        Args:
            backend_requests: Backend request descriptors, sent as ``args.backends``.

        Returns:
            The decoded JSON result of the intermediary's ``initialize_session`` tool.
        """
        payload_for_intermediary_tool = {"args": {"backends": backend_requests}}
        return await self._call_intermediary_tool(
            tool_name="initialize_session",
            tool_args_payload=payload_for_intermediary_tool,
        )

    async def call_backend_tool(
        self,
        rk_session_id: str,
        instance_id: str,
        backend_name_ref: str,
        tool_name: str,
        tool_args: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Calls a tool on a specific backend instance managed by the IntermediaryServer.

        Returns:
            The decoded JSON result of the intermediary's ``call_backend_tool`` tool.
        """
        payload_for_intermediary_tool = {
            "args": {
                "rk_session_id": rk_session_id,
                "instance_id": instance_id,
                "backend_name_ref": backend_name_ref,
                "tool_name": tool_name,
                "tool_args": tool_args,
            }
        }
        return await self._call_intermediary_tool(
            tool_name="call_backend_tool",
            tool_args_payload=payload_for_intermediary_tool,
        )

    async def list_backend_tools(
        self, rk_session_id: str, instance_id: str, backend_name_ref: str
    ) -> mcp_types.ListToolsResult:
        """
        Lists tools available on a specific backend instance via the IntermediaryServer.

        Returns:
            The intermediary's JSON result re-hydrated as a ``ListToolsResult``.
        """
        payload_for_intermediary_tool = {
            "args": {
                "rk_session_id": rk_session_id,
                "instance_id": instance_id,
                "backend_name_ref": backend_name_ref,
            }
        }
        # The intermediary's response is expected to be the model_dump of a ListToolsResult.
        raw_result_dict = await self._call_intermediary_tool(
            tool_name="list_backend_tools",
            tool_args_payload=payload_for_intermediary_tool,
        )
        # Parse the dictionary back into the Pydantic model
        return mcp_types.ListToolsResult(**raw_result_dict)

    async def cleanup_session(self, rk_session_id: str) -> Dict[str, Any]:
        """
        Cleans up a session on the IntermediaryServer.

        Returns:
            The decoded JSON result of the intermediary's ``cleanup_session`` tool.
        """
        payload_for_intermediary_tool = {"args": {"rk_session_id": rk_session_id}}
        return await self._call_intermediary_tool(
            tool_name="cleanup_session", tool_args_payload=payload_for_intermediary_tool
        )

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MCP Execution Framework
|
|
3
|
+
|
|
4
|
+
This module handles policy execution, tool calling, and rollout coordination.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .base_policy import LLMBasePolicy
|
|
8
|
+
from .policy import AnthropicPolicy, OpenAIPolicy, FireworksPolicy
|
|
9
|
+
from .manager import ExecutionManager
|
|
10
|
+
|
|
11
|
+
# FireworksPolicy is conditionally imported by policy.py. Presumably it is None
# when its optional dependency is missing; confirm against policy.py.
_FIREWORKS_AVAILABLE = FireworksPolicy is not None

# Public API of the execution package. "FireworksPolicy" is appended last, and
# only when it resolved to a real object.
_CORE_EXPORTS = ["LLMBasePolicy", "AnthropicPolicy", "OpenAIPolicy", "ExecutionManager"]
__all__ = _CORE_EXPORTS + (["FireworksPolicy"] if _FIREWORKS_AVAILABLE else [])
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base Policy for LLM Policies
|
|
3
|
+
|
|
4
|
+
This module contains the LLMBasePolicy abstract base class that provides
|
|
5
|
+
common functionality for all LLM-based policies (Fireworks, OpenAI, Anthropic, etc.)
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
import os
|
|
12
|
+
from abc import ABC, abstractmethod
|
|
13
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
14
|
+
|
|
15
|
+
from ...playback_policy import PlaybackPolicyBase
|
|
16
|
+
from ..types import LLMUsageStats, MCPToolCall
|
|
17
|
+
|
|
18
|
+
# Module-level logger, named after this module so its output can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class LLMBasePolicy(PlaybackPolicyBase, ABC):
    """
    Base class for LLM policies that work with MCP environments via tool calling.

    This abstraction enables shared code between FireworksPolicy and OpenAIPolicy.
    Conversation history lists are owned by the caller and passed into each method.
    This class appends assistant and tool messages to them so that trajectories
    follow the OpenAI chat format.

    Subclasses must implement :meth:`_make_llm_call` and
    :meth:`_convert_mcp_tools_to_llm_format`.
    """

    def __init__(
        self,
        model_id: str,
        temperature: float = 0.2,
        max_tokens: int = 4096,
        max_tools_per_turn: Optional[int] = None,
        **kwargs,
    ):
        """
        Initialize base policy with automatic record/playback detection.

        Args:
            model_id: Model identifier
            temperature: Sampling temperature (0.0 to 2.0)
            max_tokens: Maximum tokens to generate per request
            max_tools_per_turn: Maximum number of tool calls per turn (None = unlimited, 1 = single tool)
            **kwargs: Forwarded to the parent initializer (PlaybackPolicyBase).
        """
        # Initialize playback functionality (parent class handles EP_PLAYBACK_FILE automatically)
        super().__init__(**kwargs)

        # Store policy configuration
        self.model_id = model_id
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_tools_per_turn = max_tools_per_turn

        # Initialize conversation state tracking for proper OpenAI trajectories
        self.initialized = False

    @abstractmethod
    async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
        """
        Make an LLM API call. Subclasses must implement this.

        Args:
            messages: Conversation messages
            tools: Available tools in OpenAI format

        Returns:
            LLM response with choices[0].message containing content and tool_calls,
            plus a "usage" dict with prompt_tokens, completion_tokens and total_tokens.
        """
        pass

    @abstractmethod
    def _convert_mcp_tools_to_llm_format(self, mcp_tools: List[Dict]) -> List[Dict]:
        """
        Convert MCP tool schemas to LLM-specific format.

        Args:
            mcp_tools: List of MCP tool definitions

        Returns:
            List of LLM-compatible tool definitions
        """
        pass

    def add_tool_response(
        self,
        env_index: int,
        tool_call: MCPToolCall,
        tool_response: Union[str, List[Dict[str, Any]]],
        conversation_history: List[Dict[str, Any]],
        reward: float = 0.0,
        terminated: bool = False,
        info: Optional[Dict[str, Any]] = None,
    ):
        """Add tool call and response to conversation history with control plane metadata.

        Args:
            env_index: Environment index (currently unused here).
            tool_call: The tool call being answered; its ``tool_call_id`` is required.
            tool_response: Tool output stored as the message ``content``.
            conversation_history: History list that receives the new "tool" message.
            reward: Step reward stored in ``metadata``.
            terminated: Whether the episode ended, stored in ``metadata``.
            info: Extra info stored in ``metadata``.

        Raises:
            ValueError: If ``tool_call.tool_call_id`` is None.
        """
        # Use the preserved tool_call_id directly
        if tool_call.tool_call_id is None:
            raise ValueError("Tool call ID is required for tool response recording")

        tool_message = {
            "role": "tool",
            "tool_call_id": tool_call.tool_call_id,
            "content": tool_response,
        }

        # Control-plane metadata is only attached when it carries non-default data.
        if reward != 0.0 or terminated or info:
            tool_message["metadata"] = {
                "reward": reward,
                "terminated": terminated,
                "info": info or {},
            }

        conversation_history.append(tool_message)

    def log_conversation_state_for_playback(
        self, env_index: int, step: int, conversation_history: List[Dict[str, Any]]
    ):
        """
        Log the current conversation state in the format required for playback.

        Expected format: {"env_index": 0, "step": 0, "messages": [{..}, {..}]}
        One JSON object is appended per call (JSON Lines). Nothing is written
        unless the EP_PLAYBACK_FILE environment variable is set.

        Args:
            env_index: Environment index
            step: Current step number
            conversation_history: Messages to record for this step
        """
        # Use EP_PLAYBACK_FILE environment variable for recording
        playback_file = os.environ.get("EP_PLAYBACK_FILE")
        if not playback_file:
            return  # No recording file specified

        playback_entry = {
            "env_index": env_index,
            "step": step,
            "messages": conversation_history.copy(),
        }

        # TODO: because we're using threads now, the ordering will be wrong.
        with open(playback_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(playback_entry) + "\n")

    @staticmethod
    def _parse_tool_arguments(raw_arguments: Any) -> Dict[str, Any]:
        """Decode a tool call's ``arguments`` field into a dict.

        The field is normally a JSON-encoded string. An already-decoded dict is
        returned as-is, and None or a blank string (an argument-less call) yields
        ``{}``. Malformed JSON still raises ``json.JSONDecodeError``.
        """
        if isinstance(raw_arguments, dict):
            return raw_arguments
        if raw_arguments is None or not str(raw_arguments).strip():
            return {}
        return json.loads(raw_arguments)

    async def _generate_live_tool_calls(
        self,
        tool_schemas: List[Dict],
        env_index: int,
        conversation_history: List[Dict[str, Any]],
    ) -> Tuple[List[MCPToolCall], LLMUsageStats]:
        """
        Generate tool calls using conversation history for proper OpenAI trajectories.

        The assistant reply is appended to ``conversation_history`` in place.

        Args:
            tool_schemas: Available MCP tools for this environment
            env_index: Environment index
            conversation_history: Messages so far; sent to the LLM and extended
                with the assistant reply.

        Returns:
            A tuple ``(tool_calls, usage_stats)``. ``tool_calls`` holds at most
            ``max_tools_per_turn`` MCPToolCall objects (when that limit is set), or a
            single ``_no_tool_call`` sentinel when the model requested no tools.
        """
        # Convert MCP tools to LLM format
        llm_tools = self._convert_mcp_tools_to_llm_format(tool_schemas)

        logger.debug(
            f"Environment {env_index} - Converted {len(tool_schemas)} MCP tools to {len(llm_tools)} LLM tools"
        )
        logger.debug(f"Environment {env_index} - Conversation length: {len(conversation_history)} messages")

        try:
            # Make API call with conversation history
            response = await self._make_llm_call(conversation_history, llm_tools)
        except Exception as e:
            logger.error(f"LLM API call failed for env {env_index}: {e}")
            raise

        message = response["choices"][0]["message"]
        usage = response["usage"]
        usage_stats = LLMUsageStats(
            prompt_tokens=usage["prompt_tokens"],
            completion_tokens=usage["completion_tokens"],
            total_tokens=usage["total_tokens"],
        )
        logger.debug(f"Environment {env_index} - Response message: {message}")

        raw_tool_calls = message.get("tool_calls") or []
        if self.max_tools_per_turn:
            # Truncate BEFORE recording the assistant message. Every tool_call listed
            # in an assistant message must later receive a matching "tool" message;
            # recording calls that are then dropped leaves an invalid history for
            # the next chat-completions request.
            raw_tool_calls = raw_tool_calls[: self.max_tools_per_turn]

        # The assistant message (with the real API tool call IDs) must be in the
        # history so add_tool_response can answer those IDs.
        assistant_message_for_history = {
            "role": "assistant",
            "content": message["content"],
        }
        if raw_tool_calls:
            assistant_message_for_history["tool_calls"] = raw_tool_calls
        conversation_history.append(assistant_message_for_history)

        if raw_tool_calls:
            mcp_tool_calls = [
                MCPToolCall(
                    tool_name=tool_call["function"]["name"],
                    arguments=self._parse_tool_arguments(tool_call["function"].get("arguments")),
                    tool_call_id=tool_call["id"],
                )
                for tool_call in raw_tool_calls
            ]
            return mcp_tool_calls, usage_stats

        # No tool calls in response - this is normal when episode ends or LLM provides only text
        logger.info(f"No tool calls in response for env {env_index}, message content: {message.get('content')}")
        return [
            MCPToolCall(
                tool_name="_no_tool_call",
                arguments={
                    "reason": "no_tool_call_generated",
                },
            )
        ], usage_stats
|
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Fireworks AI Policy Implementation
|
|
3
|
+
|
|
4
|
+
This module contains the FireworksPolicy class that integrates with Fireworks AI's LLM API
|
|
5
|
+
for tool calling and conversation management in MCP environments.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
import os
|
|
12
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
13
|
+
from typing import Any, Dict, List, Optional
|
|
14
|
+
|
|
15
|
+
from .base_policy import LLMBasePolicy
|
|
16
|
+
from ..types import MCPToolCall
|
|
17
|
+
|
|
18
|
+
# Module-level logger, named after this module so its output can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class FireworksPolicy(LLMBasePolicy):
    """
    Fireworks AI policy implementation that works with ANY MCP environment via tool calling.

    NO environment-specific logic - everything comes from MCP tools and dataset prompts.
    Supports both live mode (using Fireworks LLM) and playback mode (replaying recorded trajectories).
    """

    # NOTE(review): this import runs when the class is defined, so the module cannot
    # be imported at all (not even for playback mode) unless `fireworks-ai` is
    # installed. The lazy import in __init__ therefore only guards the `LLM` name.
    from fireworks import DeploymentTypeLiteral

    def __init__(
        self,
        model_id: str,
        temperature: float = 0.2,
        deployment_type: DeploymentTypeLiteral = "serverless",
        max_tokens: int = 4096,
        max_tools_per_turn: Optional[int] = None,
        **kwargs,
    ):
        """
        Initialize Fireworks policy.

        Args:
            model_id: Fireworks model identifier (e.g., "accounts/fireworks/models/qwen3-235b-a22b")
            temperature: Sampling temperature (0.0 to 2.0)
            deployment_type: "serverless", "on-demand", "auto", or "on-demand-lora"
            max_tokens: Maximum tokens to generate per request
            max_tools_per_turn: Maximum number of tool calls per turn (None = unlimited, 1 = single tool)
            **kwargs: Forwarded to LLMBasePolicy / PlaybackPolicyBase.

        Raises:
            ImportError: In live mode, if the ``fireworks-ai`` package is missing.
            ValueError: In live mode, if no Fireworks API key can be found.
            RuntimeError: In live mode, if the Fireworks LLM cannot be constructed.
        """
        super().__init__(model_id, temperature, max_tokens, max_tools_per_turn, **kwargs)

        self.deployment_type = deployment_type

        # Only initialize Fireworks LLM in live mode (not in playback mode).
        # _is_playback is presumably set by PlaybackPolicyBase.__init__.
        if not self._is_playback:
            # Import Fireworks Build SDK - optional at module level
            try:
                from fireworks import LLM
            except ImportError as e:
                raise ImportError(
                    "The 'fireworks-ai' package is required for FireworksPolicy. "
                    "Please install it with 'pip install fireworks-ai'"
                ) from e

            # Verify authentication
            from ...auth import get_fireworks_api_key

            api_key = get_fireworks_api_key()
            if not api_key:
                raise ValueError(
                    "FIREWORKS_API_KEY environment variable or ~/.fireworks/auth.ini file is required "
                    "to use FireworksPolicy. See the reward-kit documentation for setup instructions."
                )

            # Set the API key for the Fireworks SDK (process-wide side effect)
            os.environ["FIREWORKS_API_KEY"] = api_key

            # Initialize the LLM instance using Build SDK
            try:
                self.llm = LLM(
                    model=self.model_id,
                    deployment_type=self.deployment_type,
                    temperature=self.temperature,
                )
            except Exception as e:
                raise RuntimeError(f"Failed to initialize Fireworks LLM '{self.model_id}': {e}") from e
            logger.info(f"✅ Initialized Fireworks LLM: {self.model_id} ({self.deployment_type})")

            # Dedicated executor so the blocking SDK call does not block the event loop
            self.llm_executor = ThreadPoolExecutor(
                max_workers=16,  # Allow up to 16 concurrent LLM API calls
                thread_name_prefix="fireworks-api",
            )
        else:
            # In playback mode, skip expensive LLM initialization
            self.llm = None
            logger.info("🎬 Playback mode: Skipping Fireworks LLM initialization for performance")

    def __del__(self):
        """Clean up executor on garbage collection (it only exists in live mode)."""
        executor = getattr(self, "llm_executor", None)
        if executor is not None:
            executor.shutdown(wait=False)

    def _clean_messages_for_api(self, messages: List[Dict]) -> List[Dict]:
        """
        Clean messages by removing metadata fields that Fireworks API doesn't accept.

        Args:
            messages: Conversation messages with potential metadata

        Returns:
            Shallow copies of the messages without the "metadata" key; the inputs
            are not modified.
        """
        return [{key: value for key, value in msg.items() if key != "metadata"} for msg in messages]

    async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
        """
        Make a Fireworks API call.

        Args:
            messages: Conversation messages (may contain metadata)
            tools: Available tools in OpenAI format

        Returns:
            API response in OpenAI format. Token counts are reported as 0 when the
            SDK response carries no usage information.

        Raises:
            RuntimeError: If the LLM was not initialized (e.g. in playback mode).
        """
        llm = self.llm
        if llm is None:
            raise RuntimeError("Fireworks LLM not initialized")

        # Clean messages by removing metadata before sending to API
        clean_messages = self._clean_messages_for_api(messages)

        current_request = {
            "messages": clean_messages,
            "tools": tools,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }

        # The SDK call is synchronous; run it on the dedicated thread pool.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            self.llm_executor, lambda: llm.chat.completions.create(**current_request)
        )

        # Convert Fireworks response to standard format
        response_message = response.choices[0].message
        tool_calls = [
            {
                "id": tc.id,
                "type": tc.type,
                "function": {
                    "name": tc.function.name,
                    "arguments": tc.function.arguments,
                },
            }
            for tc in (response_message.tool_calls or [])
        ]
        usage = response.usage
        return {
            "choices": [
                {
                    "message": {
                        "content": response_message.content,
                        "tool_calls": tool_calls,
                    }
                }
            ],
            "usage": {
                "prompt_tokens": usage.prompt_tokens if usage is not None else 0,
                "completion_tokens": usage.completion_tokens if usage is not None else 0,
                "total_tokens": usage.total_tokens if usage is not None else 0,
            },
        }

    def _convert_mcp_tools_to_llm_format(self, mcp_tools: List[Dict]) -> List[Dict]:
        """
        Convert MCP tool schemas to OpenAI function calling format for Fireworks.

        The parameter schema is read from ``input_schema`` or, failing that, the MCP
        spec's ``inputSchema`` key. An empty object schema is used when neither is set.

        Args:
            mcp_tools: List of MCP tool definitions

        Returns:
            List of OpenAI-compatible tool definitions
        """
        openai_tools = []

        for mcp_tool in mcp_tools:
            parameters = mcp_tool.get("input_schema")
            if parameters is None:
                parameters = mcp_tool.get("inputSchema")
            if parameters is None:
                # Fresh dict per tool so no schema object is shared between tools.
                parameters = {"type": "object", "properties": {}, "required": []}

            openai_tools.append(
                {
                    "type": "function",
                    "function": {
                        "name": mcp_tool["name"],
                        "description": mcp_tool.get("description", f"Execute {mcp_tool['name']} action"),
                        "parameters": parameters,
                    },
                }
            )

        return openai_tools
|