eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,195 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ from contextlib import AsyncExitStack
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ import aiohttp # Still needed for type hints if we expose the session, but primary interaction changes
8
+ import mcp.types # Reverted to mcp.types; Explicit import for clarity
9
+ from mcp import types as mcp_types # Reverted to mcp.types; Explicit import for clarity
10
+ from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
11
+ from mcp.client.streamable_http import streamablehttp_client
12
+ from omegaconf import DictConfig
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class IntermediaryMCPClient:
    """
    Client for interacting with the RewardKitIntermediaryServer using mcp.client components.

    The client lazily opens a streamable-HTTP transport plus an MCP
    ``ClientSession`` and exposes thin wrappers around the intermediary's
    tools (``initialize_session``, ``call_backend_tool``,
    ``list_backend_tools``, ``cleanup_session``). It can also be used as an
    async context manager, which connects on entry and closes on exit.
    """

    def __init__(self, intermediary_server_url: str):
        """
        Args:
            intermediary_server_url: URL of the intermediary's MCP endpoint,
                e.g. ``http://localhost:8001/mcp``. Trailing slashes are stripped.

        Raises:
            ValueError: If the URL is empty.
        """
        if not intermediary_server_url:
            raise ValueError("intermediary_server_url must be provided.")
        self.server_url = intermediary_server_url.rstrip("/")  # Should be like http://localhost:8001/mcp

        # Both attributes are set together by connect() and cleared together by close().
        self._exit_stack: Optional[AsyncExitStack] = None
        self._mcp_session: Optional[ClientSession] = None

    async def connect(self):
        """
        Establishes connection and MCP session.

        Calling this while a session is already held is a no-op.

        Raises:
            Exception: Whatever the transport or session initialization raised;
                any partially opened resources are released first.
        """
        # ClientSession exposes no public "is_closed" attribute, so the only
        # reliable signal is whether we currently hold a session object.
        if self._mcp_session is not None:
            logger.debug("Already connected.")
            return

        exit_stack = AsyncExitStack()
        self._exit_stack = exit_stack
        try:
            logger.debug(f"Attempting to connect to Intermediary MCP server at {self.server_url}")
            # The third element carries transport/session info that is not needed here.
            read_stream, write_stream, _http_session_info = await exit_stack.enter_async_context(
                streamablehttp_client(self.server_url)
            )

            session = await exit_stack.enter_async_context(
                ClientSession(read_stream, write_stream, client_info=DEFAULT_CLIENT_INFO)
            )
            self._mcp_session = session
            await session.initialize()
            logger.info(f"IntermediaryMCPClient connected and MCP session initialized with {self.server_url}")
        except Exception as e:
            logger.error(
                f"Failed to connect or initialize MCP session with {self.server_url}: {e}",
                exc_info=True,
            )
            self._exit_stack = None
            self._mcp_session = None
            try:
                await exit_stack.aclose()
            except Exception:  # pragma: no cover
                # A cleanup failure must not mask the original connection error.
                logger.warning(
                    f"Error while releasing partially opened resources for {self.server_url}",
                    exc_info=True,
                )
            raise

    async def close(self):
        """Closes the MCP session and underlying transport."""
        exit_stack = self._exit_stack
        if exit_stack is None:
            return

        logger.debug(f"Closing IntermediaryMCPClient connection to {self.server_url}")
        # Reset state first so a failing aclose() cannot leave a stale session
        # behind that a later connect() would mistake for a live one.
        self._exit_stack = None
        self._mcp_session = None
        await exit_stack.aclose()
        logger.info(f"IntermediaryMCPClient connection to {self.server_url} closed.")

    async def _ensure_connected(self):
        """
        Connect if no session is held.

        Raises:
            RuntimeError: If no session exists after the connection attempt.
        """
        # ClientSession doesn't have a public is_closed.
        # We rely on _mcp_session being None or connect() re-establishing.
        # The AsyncExitStack handles actual closure of resources.
        if not self._mcp_session:
            logger.debug("Session not established, attempting to connect...")
            await self.connect()

        # After attempting to connect, if _mcp_session is still None, it means connection failed.
        if not self._mcp_session:
            raise RuntimeError("Failed to establish or re-establish MCP session.")

    async def _call_intermediary_tool(self, tool_name: str, tool_args_payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Helper to make a raw tool call to the intermediary server and parse the result.
        The tool_args_payload is the "arguments" field for the intermediary's tool.

        Returns:
            The JSON-decoded text of the first content item of the response.

        Raises:
            RuntimeError: If the response is flagged as an error, has no content,
                has a first content item without ``text``, or the text is not valid JSON.
        """
        await self._ensure_connected()
        session = self._mcp_session
        if not session:  # For type checker
            raise RuntimeError("MCP session not available after ensure_connected.")

        logger.debug(f"Calling intermediary tool '{tool_name}' with payload: {tool_args_payload}")

        mcp_response: mcp_types.CallToolResult = await session.call_tool(tool_name, tool_args_payload)

        logger.debug(f"Raw MCP response from intermediary for '{tool_name}': {mcp_response}")

        has_text = bool(mcp_response.content) and hasattr(mcp_response.content[0], "text")

        if mcp_response.isError or not has_text:
            error_message = f"Tool call '{tool_name}' to intermediary failed."
            if mcp_response.isError:
                if has_text:
                    error_message += f" Details: {mcp_response.content[0].text}"
                else:
                    error_message += " No detailed error message in content."
            logger.error(error_message)

            # The error text may itself be a JSON object carrying an "error" field;
            # surface it when present, otherwise fall back to the plain message.
            nested_error = None
            if has_text:
                try:
                    parsed_error = json.loads(mcp_response.content[0].text)
                except (json.JSONDecodeError, TypeError):
                    parsed_error = None
                if isinstance(parsed_error, dict) and "error" in parsed_error:
                    nested_error = parsed_error["error"]
            if nested_error is not None or (
                has_text and isinstance(parsed_error, dict) and "error" in parsed_error
            ):
                raise RuntimeError(f"{error_message} Nested error: {nested_error}")
            raise RuntimeError(error_message)

        raw_text = mcp_response.content[0].text
        try:
            parsed_result = json.loads(raw_text)
        except json.JSONDecodeError as e:
            logger.error(
                f"Failed to parse JSON from intermediary's tool '{tool_name}' response content: {raw_text}. Error: {e}"
            )
            raise RuntimeError(f"Failed to parse JSON response from intermediary tool '{tool_name}'.") from e

        logger.debug(f"Parsed JSON result from intermediary for '{tool_name}': {parsed_result}")
        return parsed_result

    async def initialize_session(self, backend_requests: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Initializes a session with the IntermediaryServer, requesting backend instances.
        """
        return await self._call_intermediary_tool(
            tool_name="initialize_session",
            tool_args_payload={"args": {"backends": backend_requests}},
        )

    async def call_backend_tool(
        self,
        rk_session_id: str,
        instance_id: str,
        backend_name_ref: str,
        tool_name: str,
        tool_args: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Calls a tool on a specific backend instance managed by the IntermediaryServer.
        """
        payload_for_intermediary_tool = {
            "args": {
                "rk_session_id": rk_session_id,
                "instance_id": instance_id,
                "backend_name_ref": backend_name_ref,
                "tool_name": tool_name,
                "tool_args": tool_args,
            }
        }
        return await self._call_intermediary_tool(
            tool_name="call_backend_tool",
            tool_args_payload=payload_for_intermediary_tool,
        )

    # The return annotation is quoted so it is resolved lazily rather than at
    # class-definition time.
    async def list_backend_tools(
        self, rk_session_id: str, instance_id: str, backend_name_ref: str
    ) -> "mcp_types.ListToolsResult":
        """
        Lists tools available on a specific backend instance via the IntermediaryServer.
        """
        payload_for_intermediary_tool = {
            "args": {
                "rk_session_id": rk_session_id,
                "instance_id": instance_id,
                "backend_name_ref": backend_name_ref,
            }
        }
        # _call_intermediary_tool returns a Dict[str, Any] which is the parsed JSON
        # from the intermediary's response. This dict should be the model_dump of ListToolsResult.
        raw_result_dict = await self._call_intermediary_tool(
            tool_name="list_backend_tools",
            tool_args_payload=payload_for_intermediary_tool,
        )
        # Parse the dictionary back into the Pydantic model
        return mcp_types.ListToolsResult(**raw_result_dict)

    async def cleanup_session(self, rk_session_id: str) -> Dict[str, Any]:
        """
        Cleans up a session on the IntermediaryServer.
        """
        return await self._call_intermediary_tool(
            tool_name="cleanup_session",
            tool_args_payload={"args": {"rk_session_id": rk_session_id}},
        )

    async def __aenter__(self):
        """Connect on entering the ``async with`` block and return the client."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the connection on leaving the ``async with`` block."""
        await self.close()
@@ -0,0 +1,23 @@
1
+ """
2
+ MCP Execution Framework
3
+
4
+ This module handles policy execution, tool calling, and rollout coordination.
5
+ """
6
+
7
+ from .base_policy import LLMBasePolicy
8
+ from .policy import AnthropicPolicy, OpenAIPolicy, FireworksPolicy
9
+ from .manager import ExecutionManager
10
+
11
# Names that are always part of this package's public API.
__all__ = ["LLMBasePolicy", "AnthropicPolicy", "OpenAIPolicy", "ExecutionManager"]

# FireworksPolicy is conditionally imported by policy.py
# (presumably bound to None when the import fails - confirm in policy.py),
# so it is only exported when it is actually available.
_FIREWORKS_AVAILABLE = FireworksPolicy is not None
if _FIREWORKS_AVAILABLE:
    __all__ += ["FireworksPolicy"]
@@ -0,0 +1,227 @@
1
+ """
2
+ Base Policy for LLM Policies
3
+
4
+ This module contains the LLMBasePolicy abstract base class that provides
5
+ common functionality for all LLM-based policies (Fireworks, OpenAI, Anthropic, etc.)
6
+ """
7
+
8
+ import asyncio
9
+ import json
10
+ import logging
11
+ import os
12
+ from abc import ABC, abstractmethod
13
+ from typing import Any, Dict, List, Optional, Tuple, Union
14
+
15
+ from ...playback_policy import PlaybackPolicyBase
16
+ from ..types import LLMUsageStats, MCPToolCall
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class LLMBasePolicy(PlaybackPolicyBase, ABC):
    """
    Base class for LLM policies that work with MCP environments via tool calling.

    This abstraction enables shared code between FireworksPolicy and OpenAIPolicy.
    Conversation histories are owned by the caller and passed into each method;
    this class appends OpenAI-style assistant and tool messages to them.
    """

    def __init__(
        self,
        model_id: str,
        temperature: float = 0.2,
        max_tokens: int = 4096,
        max_tools_per_turn: Optional[int] = None,
        **kwargs,
    ):
        """
        Initialize base policy with automatic record/playback detection.

        Args:
            model_id: Model identifier
            temperature: Sampling temperature (0.0 to 2.0)
            max_tokens: Maximum tokens to generate per request
            max_tools_per_turn: Maximum number of tool calls per turn (None = unlimited, 1 = single tool)
            **kwargs: Forwarded unchanged to ``PlaybackPolicyBase.__init__``.
        """
        # Initialize playback functionality (parent class handles EP_PLAYBACK_FILE automatically)
        super().__init__(**kwargs)

        # Store policy configuration
        self.model_id = model_id
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_tools_per_turn = max_tools_per_turn

        # Initialize conversation state tracking for proper OpenAI trajectories
        self.initialized = False

    @abstractmethod
    async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
        """
        Make an LLM API call. Subclasses must implement this.

        Args:
            messages: Conversation messages
            tools: Available tools in OpenAI format

        Returns:
            LLM response as a dict with ``choices[0]["message"]`` containing
            ``content`` and ``tool_calls``, plus a ``usage`` dict with
            ``prompt_tokens``, ``completion_tokens`` and ``total_tokens``.
        """

    @abstractmethod
    def _convert_mcp_tools_to_llm_format(self, mcp_tools: List[Dict]) -> List[Dict]:
        """
        Convert MCP tool schemas to LLM-specific format.

        Args:
            mcp_tools: List of MCP tool definitions

        Returns:
            List of LLM-compatible tool definitions
        """

    def add_tool_response(
        self,
        env_index: int,
        tool_call: MCPToolCall,
        tool_response: Union[str, List[Dict[str, Any]]],
        conversation_history: List[Dict[str, Any]],
        reward: float = 0.0,
        terminated: bool = False,
        info: Optional[Dict[str, Any]] = None,
    ):
        """
        Add tool call and response to conversation history with control plane metadata.

        Args:
            env_index: Environment index (not used by this implementation).
            tool_call: The tool call being answered; its ``tool_call_id`` must be set.
            tool_response: Content of the tool message.
            conversation_history: List that the tool message is appended to in place.
            reward: Reward to record in the message metadata.
            terminated: Termination flag to record in the message metadata.
            info: Extra info to record in the message metadata.

        Raises:
            ValueError: If ``tool_call.tool_call_id`` is None.
        """
        # Use the preserved tool_call_id directly
        if tool_call.tool_call_id is None:
            raise ValueError("Tool call ID is required for tool response recording")

        tool_message: Dict[str, Any] = {
            "role": "tool",
            "tool_call_id": tool_call.tool_call_id,
            "content": tool_response,
        }

        # Metadata is attached only when at least one control-plane value is non-default.
        if reward != 0.0 or terminated or info:
            tool_message["metadata"] = {
                "reward": reward,
                "terminated": terminated,
                "info": info or {},
            }

        conversation_history.append(tool_message)

    def log_conversation_state_for_playback(
        self, env_index: int, step: int, conversation_history: List[Dict[str, Any]]
    ):
        """
        Log the current conversation state in the format required for playback.

        Expected format: {"env_index": 0, "step": 0, "messages": [{..}, {..}]}

        Nothing is written when the EP_PLAYBACK_FILE environment variable is unset or empty.

        Args:
            env_index: Environment index
            step: Current step number
            conversation_history: Messages to record; a shallow copy is serialized.
        """
        # Use EP_PLAYBACK_FILE environment variable for recording
        playback_file = os.environ.get("EP_PLAYBACK_FILE")
        if not playback_file:
            return  # No recording file specified

        playback_entry = {
            "env_index": env_index,
            "step": step,
            "messages": conversation_history.copy(),
        }

        # TODO: because we're using threads now, the ordering will be wrong.

        # One JSON object per line, appended to the recording file.
        with open(playback_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(playback_entry) + "\n")

    async def _generate_live_tool_calls(
        self,
        tool_schemas: List[Dict],
        env_index: int,
        conversation_history: List[Dict[str, Any]],
    ) -> Tuple[List[MCPToolCall], LLMUsageStats]:
        """
        Generate tool calls using conversation history for proper OpenAI trajectories.

        The assistant message produced by the LLM is appended to
        ``conversation_history`` in place. When ``max_tools_per_turn`` is set,
        both the returned calls and the recorded assistant message are limited
        to the first ``max_tools_per_turn`` tool calls, so every recorded
        tool_call_id can later receive a matching tool response.

        Args:
            tool_schemas: Available MCP tools for this environment
            env_index: Environment index
            conversation_history: Messages sent to the LLM; mutated in place.

        Returns:
            Tuple of (list of MCPToolCall objects, LLMUsageStats). When the LLM
            returns no tool calls the list holds a single ``_no_tool_call`` placeholder.

        Raises:
            Exception: Any exception raised by ``_make_llm_call`` is logged and re-raised.
        """
        # Convert MCP tools to LLM format
        llm_tools = self._convert_mcp_tools_to_llm_format(tool_schemas)

        logger.debug(
            f"Environment {env_index} - Converted {len(tool_schemas)} MCP tools to {len(llm_tools)} LLM tools"
        )
        logger.debug(f"Environment {env_index} - Conversation length: {len(conversation_history)} messages")

        try:
            # Make API call with conversation history
            response = await self._make_llm_call(conversation_history, llm_tools)
        except Exception as e:
            logger.error(f"LLM API call failed for env {env_index}: {e}")
            raise

        message = response["choices"][0]["message"]
        assistant_message_for_history: Dict[str, Any] = {
            "role": "assistant",
            "content": message["content"],
        }
        usage = response["usage"]
        usage_stats = LLMUsageStats(
            prompt_tokens=usage["prompt_tokens"],
            completion_tokens=usage["completion_tokens"],
            total_tokens=usage["total_tokens"],
        )
        logger.debug(f"Environment {env_index} - Response message: {message}")

        # Apply the per-turn limit BEFORE recording, so the history never
        # contains tool_call ids that will not get a tool response.
        tool_calls = message.get("tool_calls") or []
        if self.max_tools_per_turn:
            tool_calls = tool_calls[: self.max_tools_per_turn]

        # Parse arguments before touching the history so a malformed response
        # does not leave a half-recorded turn behind.
        mcp_tool_calls = []
        for tool_call in tool_calls:
            raw_arguments = tool_call["function"]["arguments"]
            mcp_tool_calls.append(
                MCPToolCall(
                    tool_name=tool_call["function"]["name"],
                    # An empty arguments string means "no arguments".
                    arguments=json.loads(raw_arguments) if raw_arguments else {},
                    tool_call_id=tool_call["id"],
                )
            )

        # ADD ASSISTANT MESSAGE TO ACTUAL CONVERSATION HISTORY
        # This is crucial for proper tool call ID management in add_tool_response
        if tool_calls:
            assistant_message_for_history["tool_calls"] = tool_calls
        conversation_history.append(assistant_message_for_history)

        if mcp_tool_calls:
            return mcp_tool_calls, usage_stats

        # No tool calls in response - this is normal when episode ends or LLM provides only text
        logger.info(f"No tool calls in response for env {env_index}, message content: {message.get('content')}")
        return [
            MCPToolCall(
                tool_name="_no_tool_call",
                arguments={
                    "reason": "no_tool_call_generated",
                },
            )
        ], usage_stats
@@ -0,0 +1,209 @@
1
+ """
2
+ Fireworks AI Policy Implementation
3
+
4
+ This module contains the FireworksPolicy class that integrates with Fireworks AI's LLM API
5
+ for tool calling and conversation management in MCP environments.
6
+ """
7
+
8
+ import asyncio
9
+ import json
10
+ import logging
11
+ import os
12
+ from concurrent.futures import ThreadPoolExecutor
13
+ from typing import Any, Dict, List, Optional
14
+
15
+ from .base_policy import LLMBasePolicy
16
+ from ..types import MCPToolCall
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class FireworksPolicy(LLMBasePolicy):
    """
    Fireworks AI policy implementation that works with ANY MCP environment via tool calling.

    NO environment-specific logic - everything comes from MCP tools and dataset prompts.
    Supports both live mode (using Fireworks LLM) and playback mode (replaying recorded trajectories).
    """

    # Imported in the class body because it is only needed for the
    # ``deployment_type`` annotation below.
    from fireworks import DeploymentTypeLiteral

    def __init__(
        self,
        model_id: str,
        temperature: float = 0.2,
        deployment_type: DeploymentTypeLiteral = "serverless",
        max_tokens: int = 4096,
        max_tools_per_turn: Optional[int] = None,
        **kwargs,
    ):
        """
        Initialize Fireworks policy.

        Args:
            model_id: Fireworks model identifier (e.g., "accounts/fireworks/models/qwen3-235b-a22b")
            temperature: Sampling temperature (0.0 to 2.0)
            deployment_type: "serverless", "on-demand", "auto", or "on-demand-lora"
            max_tokens: Maximum tokens to generate per request
            max_tools_per_turn: Maximum number of tool calls per turn (None = unlimited, 1 = single tool)
            **kwargs: Forwarded to ``LLMBasePolicy.__init__``.

        Raises:
            ImportError: In live mode, if the Fireworks SDK cannot be imported.
            ValueError: In live mode, if no Fireworks API key can be found.
            RuntimeError: In live mode, if constructing the Fireworks LLM fails.
        """
        super().__init__(model_id, temperature, max_tokens, max_tools_per_turn, **kwargs)

        self.deployment_type = deployment_type

        if self._is_playback:
            # In playback mode, skip expensive LLM initialization
            self.llm = None
            logger.info("🎬 Playback mode: Skipping Fireworks LLM initialization for performance")
            return

        # Only initialize Fireworks LLM in live mode (not in playback mode)
        # Import Fireworks Build SDK - optional at module level
        try:
            from fireworks import LLM
        except ImportError as e:
            raise ImportError(
                "The 'fireworks-ai' package is required for FireworksPolicy. "
                "Please install it with 'pip install fireworks-ai'"
            ) from e

        # Verify authentication
        from ...auth import get_fireworks_api_key

        api_key = get_fireworks_api_key()
        if not api_key:
            raise ValueError(
                "FIREWORKS_API_KEY environment variable or ~/.fireworks/auth.ini file is required "
                "to use FireworksPolicy. See the reward-kit documentation for setup instructions."
            )

        # Set the API key for the Fireworks SDK
        os.environ["FIREWORKS_API_KEY"] = api_key

        # Initialize the LLM instance using Build SDK
        try:
            self.llm = LLM(
                model=self.model_id,
                deployment_type=self.deployment_type,
                temperature=self.temperature,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to initialize Fireworks LLM '{self.model_id}': {e}") from e
        logger.info(f"✅ Initialized Fireworks LLM: {self.model_id} ({self.deployment_type})")

        # Create dedicated executor for non-blocking LLM calls
        self.llm_executor = ThreadPoolExecutor(
            max_workers=16,  # Allow up to 16 concurrent LLM API calls
            thread_name_prefix="fireworks-api",
        )

    def __del__(self):
        """Clean up executor on garbage collection."""
        # The executor only exists in live mode (and only if __init__ got that far).
        executor = getattr(self, "llm_executor", None)
        if executor is not None:
            executor.shutdown(wait=False)

    def _clean_messages_for_api(self, messages: List[Dict]) -> List[Dict]:
        """
        Clean messages by removing metadata fields that Fireworks API doesn't accept.

        Args:
            messages: Conversation messages with potential metadata

        Returns:
            New list of shallow-copied messages without the "metadata" key;
            the input messages are not modified.
        """
        return [{key: value for key, value in msg.items() if key != "metadata"} for msg in messages]

    async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
        """
        Make a Fireworks API call.

        The SDK call is run on the policy's thread pool and awaited from the
        running event loop.

        Args:
            messages: Conversation messages (may contain metadata)
            tools: Available tools in OpenAI format

        Returns:
            API response in OpenAI format

        Raises:
            RuntimeError: If the Fireworks LLM was not initialized (playback mode).
        """
        llm = self.llm
        if llm is None:
            raise RuntimeError("Fireworks LLM not initialized")

        current_request = {
            # Clean messages by removing metadata before sending to API
            "messages": self._clean_messages_for_api(messages),
            "tools": tools,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }

        # Inside a coroutine the running loop is the right one to schedule on.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            self.llm_executor, lambda: llm.chat.completions.create(**current_request)
        )

        # Convert Fireworks response to standard format
        response_message = response.choices[0].message
        tool_calls = [
            {
                "id": tc.id,
                "type": tc.type,
                "function": {
                    "name": tc.function.name,
                    "arguments": tc.function.arguments,
                },
            }
            for tc in (response_message.tool_calls or [])
        ]
        usage = response.usage
        return {
            "choices": [
                {
                    "message": {
                        "content": response_message.content,
                        "tool_calls": tool_calls,
                    }
                }
            ],
            "usage": {
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
            },
        }

    def _convert_mcp_tools_to_llm_format(self, mcp_tools: List[Dict]) -> List[Dict]:
        """
        Convert MCP tool schemas to OpenAI function calling format for Fireworks.

        Args:
            mcp_tools: List of MCP tool definitions; each needs a "name" and may
                carry "description" and "input_schema".

        Returns:
            List of OpenAI-compatible tool definitions
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": mcp_tool["name"],
                    "description": mcp_tool.get("description", f"Execute {mcp_tool['name']} action"),
                    # Tools without a schema are exposed as taking no arguments.
                    "parameters": mcp_tool.get(
                        "input_schema",
                        {"type": "object", "properties": {}, "required": []},
                    ),
                },
            }
            for mcp_tool in mcp_tools
        ]