eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,506 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MCP Execution Management
|
|
3
|
+
|
|
4
|
+
Unified class that handles both session management and rollout execution.
|
|
5
|
+
Combines the functionality of SessionManager and RolloutManager.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
import os
|
|
13
|
+
import time
|
|
14
|
+
import threading
|
|
15
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
16
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
|
17
|
+
|
|
18
|
+
from ..client.connection import MCPConnectionManager
|
|
19
|
+
from ..types import MCPSession, MCPToolCall, Trajectory, TerminationReason, LLMUsageStats
|
|
20
|
+
|
|
21
|
+
from tau2.user.user_simulator import UserSimulator
|
|
22
|
+
from tau2.data_model.message import AssistantMessage, UserMessage
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from ..session.manager import GeneralMCPVectorEnv
|
|
26
|
+
from .policy import LLMBasePolicy
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ExecutionManager:
    """
    Unified manager that handles both MCP session lifecycle and rollout execution.

    Combines the functionality of SessionManager and RolloutManager for better
    organization and reduced complexity. Rollouts run as asyncio tasks on the
    caller's event loop; ``execute_rollouts`` bounds how many run at once with an
    ``asyncio.Semaphore`` (despite the "threads" wording in some log messages, no
    worker threads are created by this class).
    """

    def __init__(self):
        """Initialize the execution manager with a fresh MCP connection manager."""
        self.connection_manager = MCPConnectionManager()

    async def initialize_sessions(self, sessions: List[MCPSession]) -> None:
        """
        Initialize multiple MCP sessions concurrently.

        Args:
            sessions: List of MCPSessions to initialize

        Raises:
            The first exception raised by any ``initialize_session`` call
            (``asyncio.gather`` is used without ``return_exceptions``).
        """
        tasks = [self.connection_manager.initialize_session(session) for session in sessions]
        await asyncio.gather(*tasks)

    async def close_sessions(self, sessions: List[MCPSession]) -> None:
        """
        Close multiple MCP sessions concurrently.

        Individual close failures are collected (``return_exceptions=True``) and
        ignored; cancellation while waiting is logged at debug level and swallowed.

        Args:
            sessions: List of MCPSessions to close
        """
        tasks = [asyncio.create_task(self.connection_manager.close_session(session)) for session in sessions]

        if tasks:
            try:
                # Wait for all close operations to complete
                await asyncio.gather(*tasks, return_exceptions=True)
            except asyncio.CancelledError:
                # Handle cancellation gracefully (especially important for Python 3.12)
                logger.debug("Close operation was cancelled, but sessions are marked as closed")

    async def execute_rollouts(
        self,
        envs: "GeneralMCPVectorEnv",
        policy: Union["LLMBasePolicy", Callable],
        steps: int = 512,
        openai_format_log_file: Optional[str] = None,
        max_concurrent_rollouts: int = 8,
    ) -> List[Trajectory]:
        """
        Execute general rollouts using tool calling interface with automatic record/playback.

        This works with ANY MCP environment because:
        1. Policy receives tool schemas and makes tool calls
        2. Environment prompts come from dataset
        3. No hardcoded environment logic

        Args:
            envs: GeneralMCPVectorEnv instance; one rollout is run per index in
                ``range(envs.n)``. ``envs.close()`` is always awaited at the end,
                even if a rollout raises.
            policy: Policy that takes tool schemas, observations, prompts and returns tool calls
            steps: Maximum steps per rollout
            openai_format_log_file: Optional JSONL file (truncated at start) that
                receives the final conversation of rollouts terminated by a
                control-plane signal.
            max_concurrent_rollouts: Maximum number of rollouts allowed to run at
                the same time (asyncio semaphore size). Must be >= 1.

        Environment Variable Control:
            EP_PLAYBACK_FILE: Controls record/playback mode
            - Not set: Normal live mode
            - Set but file doesn't exist: Record mode (file will be created)
            - Set and file exists: Playback mode (uses recorded data)

        Returns:
            List of Trajectory objects, in the same order as the environments.
            Each trajectory's ``duration`` is the wall-clock time of the whole
            batch, not of the individual rollout.

        Raises:
            ValueError: If ``max_concurrent_rollouts`` is less than 1 (a zero-sized
                semaphore would block every rollout forever).
        """
        if max_concurrent_rollouts < 1:
            raise ValueError(f"max_concurrent_rollouts must be >= 1, got {max_concurrent_rollouts}")

        start_time = time.time()

        # Check for record/playback mode
        playback_file = os.environ.get("EP_PLAYBACK_FILE")
        recording_mode = bool(playback_file and not os.path.exists(playback_file))
        playback_mode = bool(playback_file and os.path.exists(playback_file))

        if recording_mode:
            logger.info(f"📝 Recording mode: Will record to {playback_file}")
        elif playback_mode:
            logger.info(f"🎬 Playback mode: Using recorded data from {playback_file}")
        else:
            logger.info("🚀 Live mode: No recording/playback")

        # Initialize OpenAI format logging for terminated trajectories only
        openai_logger = None
        if openai_format_log_file:
            # Truncate the file so each run starts with an empty log
            with open(openai_format_log_file, "w"):
                pass
            openai_logger = lambda data: self._log_openai_entry(openai_format_log_file, data)

        logger.info(f"🧵 Starting {envs.n} rollouts with max {max_concurrent_rollouts} concurrent threads...")

        semaphore = asyncio.Semaphore(max_concurrent_rollouts)

        async def _execute_with_semaphore(idx):
            async with semaphore:
                return await self._execute_rollout(
                    envs, policy, idx, steps, openai_logger, recording_mode, playback_mode, start_time
                )

        tasks = [_execute_with_semaphore(i) for i in range(envs.n)]
        try:
            # gather preserves input order, so trajectories[i] belongs to env i
            trajectories = await asyncio.gather(*tasks)

            # Every trajectory records the duration of the whole batch
            total_duration = time.time() - start_time
            for trajectory in trajectories:
                trajectory.duration = total_duration
        finally:
            # Always release environment resources, even if a rollout failed
            await envs.close()

        # Enhanced reporting with control plane info
        successful = sum(1 for traj in trajectories if traj.total_reward > 0)
        # Accept both the enum member and its string form so the count is
        # correct regardless of how the reason was stored in the summary.
        terminated_by_control_plane = sum(
            1
            for traj in trajectories
            if traj.control_plane_summary.get("termination_reason")
            in (TerminationReason.CONTROL_PLANE_SIGNAL, "control_plane_signal")
        )

        logger.info(f"📊 Rollout complete: {successful}/{len(trajectories)} reached goal")
        logger.info(f"🎛️ Control plane terminations: {terminated_by_control_plane}/{len(trajectories)}")
        logger.info(f"⏱️ Total duration: {total_duration:.2f}s")
        logger.info(f"🧵 Used {max_concurrent_rollouts} concurrent threads")

        # Print log file locations if created
        if openai_format_log_file:
            logger.info(f"💬 OpenAI format log: {openai_format_log_file}")
        if recording_mode:
            logger.info(f"📝 Recorded trajectory: {playback_file}")
            # Add note about control plane separation
            logger.info("🎛️ Trajectories include control plane separation")

        return trajectories

    async def _execute_rollout(
        self,
        envs: "GeneralMCPVectorEnv",
        policy: Union["LLMBasePolicy", Callable],
        rollout_idx: int,
        steps: int,
        openai_logger: Optional[Callable],
        recording_mode: bool,
        playback_mode: bool,
        start_time: float,
    ) -> Trajectory:
        """
        Execute a single rollout for one environment.

        Runs as one asyncio task on the caller's event loop; ``execute_rollouts``
        bounds how many of these run concurrently. Step accounting depends on mode:

        - Without a user simulator, every executed tool call is one step, and a
          policy response without tool calls ends the rollout.
        - With a user simulator, one step is a whole conversation turn: an
          optional simulated-user message, then agent tool calls until the policy
          answers without a tool call (which completes the turn so the simulated
          user can reply on the next iteration).

        Args:
            envs: Vector environment providing sessions, dataset rows, ``reset``,
                ``step`` and prompt/tool-response formatting.
            policy: Awaitable called as ``policy(tool_schema, rollout_idx,
                conversation_history)`` returning ``(tool_calls, usage_stats)``.
                It must also provide ``add_tool_response(...)`` and, in recording
                mode, ``log_conversation_state_for_playback(...)``.
            rollout_idx: Index of the environment/session to run.
            steps: Maximum number of steps (see step accounting above).
            openai_logger: Optional callable receiving the final conversation when
                the control plane terminates the rollout.
            recording_mode: When True, conversation state is logged after each step.
            playback_mode: Not used by this method; kept for interface compatibility.
            start_time: Not used by this method; kept for interface compatibility.

        Returns:
            The populated Trajectory for this rollout.
        """
        session = envs.sessions[rollout_idx]
        dataset_row = envs.dataset_rows[rollout_idx]

        # Initialize trajectory
        trajectory = Trajectory(
            session=session,
            observations=[],
            actions=[],
            rewards=[],
            terminated=False,
            total_reward=0.0,
            steps=0,
            duration=0.0,
            control_plane_steps=[],
            control_plane_summary={},
            termination_reason="",
            conversation_history=[],
            llm_usage_summary={
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            },
        )

        current_observation, tool_schema = await envs.reset(session)
        system_prompt = dataset_row.system_prompt

        # Record initial observation
        trajectory.observations.append(current_observation)

        # Create user simulator for this rollout if configured in dataset
        user_simulator = None
        user_simulator_state = None

        # If user simulation is enabled, the first user message comes from the simulator
        if dataset_row.user_simulation and dataset_row.user_simulation.get("enabled", False):
            user_simulator = UserSimulator(
                instructions=dataset_row.user_simulation.get("system_prompt"),
                llm=dataset_row.user_simulation.get("llm", "gpt-4.1"),
                llm_args=dataset_row.user_simulation.get("llm_args", {"temperature": 0.7}),
            )

            # Seed the simulator with a canned assistant greeting (tau2-bench format)
            user_simulator_state = user_simulator.get_init_state()
            user_message, user_simulator_state = user_simulator.generate_next_message(
                AssistantMessage(role="assistant", content="Hi! How can I help you today?"),
                user_simulator_state,
            )
            current_observation = user_message.content if user_message.content else ""

        user_prompt = envs.format_user_prompt(rollout_idx, current_observation)
        conversation_history = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        usage_stats_list: List[LLMUsageStats] = []

        logger.info(f"🎯 Starting rollout {rollout_idx} in thread {threading.current_thread().name}")

        # Run rollout loop for this specific environment
        step = 0
        rollout_end = False

        while step < steps and not trajectory.terminated:
            turn_completed = False
            info = {}
            reward = 0.0
            observation = current_observation
            tool_calls = []

            if user_simulator and user_simulator_state:
                user_simulator_messages = self._get_user_simulator_messages(conversation_history)

                # The agent spoke last: let the simulated user reply
                if user_simulator_messages and isinstance(user_simulator_messages[-1], AssistantMessage):
                    user_message, user_simulator_state = user_simulator.generate_next_message(
                        user_simulator_messages[-1], user_simulator_state
                    )
                    user_content = user_message.content if user_message.content else ""

                    user_prompt = envs.format_user_prompt(rollout_idx, user_content)
                    conversation_history.append({"role": "user", "content": user_prompt})

                    # Check if user simulator signaled termination
                    if UserSimulator.is_stop(user_message):
                        trajectory.terminated = True
                        trajectory.termination_reason = TerminationReason.USER_STOP

            # In each turn: keep looping until the assistant answers without a tool call
            while not turn_completed and not trajectory.terminated:
                tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)
                # Treat a missing/empty tool-call list like "_no_tool_call" so the
                # policy is not re-queried forever with an unchanged history.
                tool_calls = tool_calls or []

                # Record usage right away so calls that end the turn via `break`
                # below are counted too.
                if usage_stats:
                    usage_stats_list.append(usage_stats)

                if not tool_calls or (len(tool_calls) == 1 and tool_calls[0].tool_name == "_no_tool_call"):
                    if user_simulator is not None:
                        # Final answer for this turn; the simulated user replies next iteration
                        turn_completed = True
                    else:
                        # Without a user simulator there is nobody to continue the conversation
                        trajectory.terminated = True
                    break
                if len(tool_calls) == 1 and tool_calls[0].tool_name == "_playback_terminate":
                    trajectory.terminated = True
                    break

                # Execute each tool call sequentially
                for tool_call in tool_calls:
                    observation, reward, rollout_end, info = await envs.step(rollout_idx, tool_call)

                    tool_response = envs.format_tool_response(observation)

                    policy.add_tool_response(
                        rollout_idx,
                        tool_call,
                        tool_response,
                        conversation_history,
                        reward,
                        rollout_end,
                        info,
                    )

                    # Data plane: observation and action
                    trajectory.observations.append(observation)
                    trajectory.actions.append(self._format_tool_call(tool_call))

                    # Control plane: reward accounting
                    trajectory.rewards.append(reward)
                    trajectory.total_reward += reward

                    # Without a user simulator, each tool call is a step
                    if user_simulator is None:
                        step += 1
                        trajectory.steps = step
                        trajectory.control_plane_steps.append(
                            self._build_control_plane_step(step - 1, reward, rollout_end, info, [tool_call])
                        )

                        # Log conversation state for playback if in recording mode
                        if recording_mode:
                            policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)

                    if rollout_end:
                        trajectory.terminated = True
                        trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL
                        break
                    elif step >= steps:
                        trajectory.terminated = True
                        trajectory.termination_reason = TerminationReason.MAX_STEPS
                        break

                # Update current observation for potential next turn
                if observation is not None:
                    current_observation = observation

            # With a user simulator, a step is one whole conversation turn
            if user_simulator is not None:
                step += 1
                trajectory.steps = step
                trajectory.control_plane_steps.append(
                    self._build_control_plane_step(step - 1, reward, rollout_end, info, tool_calls)
                )

                # Log conversation state for playback if in recording mode
                if recording_mode:
                    policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)

            # Use control plane information for termination decision
            if rollout_end:
                trajectory.terminated = True
                trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL

                # Add final control plane summary
                trajectory.control_plane_summary.update(
                    {
                        "total_reward": trajectory.total_reward,
                        "termination_reason": trajectory.termination_reason,
                        "final_step": step - 1,
                        "control_plane_source": info.get("control_plane", {}),
                    }
                )

                # Log final OpenAI conversation for control-plane-terminated trajectories only
                if openai_logger and conversation_history:
                    openai_logger(
                        {
                            "messages": conversation_history,
                            "metadata": {
                                "session_id": session.session_id,
                                "seed": session.seed,
                                "total_steps": trajectory.steps,
                                "total_reward": trajectory.total_reward,
                                "terminated": True,
                                "success": reward > 0,
                                "control_plane_summary": trajectory.control_plane_summary,
                            },
                        }
                    )

                logger.info(
                    f"🏁 Rollout {rollout_idx} terminated at step {step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}"
                )
                break

            # Progress logging
            if step % 10 == 0:
                logger.debug(f"Rollout {rollout_idx} step {step}, reward: {trajectory.total_reward:.2f}")

        # Set termination reason if not already set (e.g., due to step limit)
        if not trajectory.termination_reason and step >= steps:
            trajectory.termination_reason = TerminationReason.MAX_STEPS

        trajectory.conversation_history = conversation_history

        for usage_stats in usage_stats_list:
            trajectory.llm_usage_summary["prompt_tokens"] += usage_stats.prompt_tokens
            trajectory.llm_usage_summary["completion_tokens"] += usage_stats.completion_tokens
            trajectory.llm_usage_summary["total_tokens"] += usage_stats.total_tokens

        logger.info(
            f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}"
        )
        return trajectory

    @staticmethod
    def _format_tool_call(tool_call) -> str:
        """Render a tool call as ``name(arguments)`` for trajectory/action logs."""
        return f"{tool_call.tool_name}({tool_call.arguments})"

    @staticmethod
    def _build_control_plane_step(
        step_index: int,
        reward: float,
        rollout_end: bool,
        info: Dict[str, Any],
        tool_calls: List[Any],
    ) -> Dict[str, Any]:
        """Build one ``control_plane_steps`` entry for the given step and its tool calls."""
        return {
            "step": step_index,
            "reward": reward,
            "terminated": rollout_end,
            "info": info.get("control_plane", {}),
            "tool_calls": [ExecutionManager._format_tool_call(tc) for tc in tool_calls],
            "num_tool_calls": len(tool_calls),
        }

    @staticmethod
    def _control_plane_base_url(base_url: str) -> str:
        """
        Strip a trailing ``/mcp`` path segment (and slashes) from an MCP endpoint URL.

        ``str.rstrip("/mcp")`` would strip any trailing ``/``, ``m``, ``c`` or ``p``
        characters (turning ``http://example.com/mcp`` into ``http://example.co``),
        so the suffix is removed explicitly instead.
        """
        url = base_url.rstrip("/")
        if url.endswith("/mcp"):
            url = url[: -len("/mcp")]
        return url.rstrip("/")

    async def _get_control_plane_status(self, session) -> Optional[Dict[str, Any]]:
        """
        Query the control plane status endpoint directly for a session.

        Sends ``GET {base_url}/control/status`` with an ``mcp-session-id`` header
        and a 2 second timeout. All failures are logged at debug level.

        Args:
            session: MCP session object (reads ``base_url`` and ``session_id``)

        Returns:
            Parsed JSON of a 200 response, or None if the query fails for any reason
        """
        session_id = getattr(session, "session_id", None)
        # Safe for log messages even when the session has no id yet
        short_id = str(session_id)[:16] if session_id else "<none>"

        try:
            import httpx
        except ImportError as e:
            logger.debug(f"Control plane query failed for session {short_id}: {e}")
            return None

        try:
            base_url = self._control_plane_base_url(session.base_url)

            if not session_id:
                logger.debug("Control plane query failed: No session ID")
                return None

            headers = {"mcp-session-id": session_id}

            # Query status endpoint
            async with httpx.AsyncClient(timeout=2.0) as client:
                status_response = await client.get(
                    f"{base_url}/control/status",
                    headers=headers,
                    timeout=2.0,  # Short timeout for performance
                )

            if status_response.status_code == 200:
                return status_response.json()

            logger.debug(f"Control plane endpoint returned {status_response.status_code} for session {short_id}")
            return None

        except (asyncio.TimeoutError, httpx.TimeoutException):
            # httpx signals timeouts with its own exception hierarchy
            logger.debug(f"Control plane query timed out for session {short_id}")
            return None
        except Exception as e:
            logger.debug(f"Control plane query failed for session {short_id}: {e}")
            return None

    def _log_openai_entry(self, log_file: str, data: Dict[str, Any]):
        """Append ``data`` to ``log_file`` as one JSON line."""
        with open(log_file, "a") as f:
            f.write(json.dumps(data) + "\n")

    def _get_user_simulator_messages(self, conversation_history: List[Dict[str, Any]]) -> List:
        """
        Convert the OpenAI-style history into tau2-bench messages for the user simulator.

        Keeps user messages and assistant messages that carry no tool calls; system
        and tool messages, and assistant tool-call messages, are dropped.
        """
        tau2_messages = []

        for message in conversation_history:
            role = message.get("role")
            content = message.get("content", "")

            if role == "assistant":
                if not message.get("tool_calls"):
                    tau2_messages.append(AssistantMessage(role="assistant", content=content))
            elif role == "user":
                tau2_messages.append(UserMessage(role="user", content=content))

        return tau2_messages
|