eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,506 @@
1
+ """
2
+ MCP Execution Management
3
+
4
+ Unified class that handles both session management and rollout execution.
5
+ Combines the functionality of SessionManager and RolloutManager.
6
+ """
7
+
8
+ import asyncio
9
+ from dataclasses import dataclass
10
+ import json
11
+ import logging
12
+ import os
13
+ import time
14
+ import threading
15
+ from concurrent.futures import ThreadPoolExecutor, as_completed
16
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
17
+
18
+ from ..client.connection import MCPConnectionManager
19
+ from ..types import MCPSession, MCPToolCall, Trajectory, TerminationReason, LLMUsageStats
20
+
21
+ from tau2.user.user_simulator import UserSimulator
22
+ from tau2.data_model.message import AssistantMessage, UserMessage
23
+
24
+ if TYPE_CHECKING:
25
+ from ..session.manager import GeneralMCPVectorEnv
26
+ from .policy import LLMBasePolicy
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class ExecutionManager:
32
+ """
33
+ Unified manager that handles both MCP session lifecycle and rollout execution.
34
+
35
+ Combines the functionality of SessionManager and RolloutManager for better
36
+ organization and reduced complexity.
37
+ """
38
+
39
def __init__(self):
    """Set up the manager with its own MCP connection manager."""
    # All session open/close calls made by this class go through this object.
    connection_manager = MCPConnectionManager()
    self.connection_manager = connection_manager
42
+
43
async def initialize_sessions(self, sessions: List[MCPSession]) -> None:
    """
    Open every given MCP session concurrently.

    One ``initialize_session`` coroutine is created per session and all of
    them are awaited together, so the first failure propagates to the caller.

    Args:
        sessions: List of MCPSessions to initialize
    """
    open_session = self.connection_manager.initialize_session
    await asyncio.gather(*(open_session(item) for item in sessions))
52
+
53
async def close_sessions(self, sessions: List[MCPSession]) -> None:
    """
    Shut down every given MCP session concurrently.

    Each close runs as its own task. Errors raised by individual closes are
    collected rather than propagated, and a cancellation of the combined
    wait is logged at debug level and swallowed.

    Args:
        sessions: List of MCPSessions to close
    """
    pending = []
    for item in sessions:
        pending.append(asyncio.create_task(self.connection_manager.close_session(item)))

    # Nothing was scheduled, so there is nothing to wait for.
    if not pending:
        return

    try:
        # return_exceptions=True keeps one failing close from aborting the rest.
        await asyncio.gather(*pending, return_exceptions=True)
    except asyncio.CancelledError:
        # Cancellation is tolerated here (the original notes Python 3.12 in particular).
        logger.debug("Close operation was cancelled, but sessions are marked as closed")
69
+
70
async def execute_rollouts(
    self,
    envs: "GeneralMCPVectorEnv",
    policy: Union["LLMBasePolicy", Callable],
    steps: int = 512,
    openai_format_log_file: Optional[str] = None,
    max_concurrent_rollouts: int = 8,
) -> List[Trajectory]:
    """
    Execute general rollouts using tool calling interface with automatic record/playback.

    One rollout is run per environment in ``envs`` (``envs.n`` in total). The
    rollouts are coroutines scheduled on the current event loop; a semaphore
    limits how many of them run at the same time.

    Args:
        envs: GeneralMCPVectorEnv instance
        policy: Policy that takes tool schemas, observations, prompts and returns tool calls
        steps: Maximum steps per rollout
        openai_format_log_file: Optional file to log clean OpenAI format entries to.
            The file is truncated before the rollouts start.
        max_concurrent_rollouts: Maximum number of rollouts running at once (must be >= 1)

    Environment Variable Control:
        EP_PLAYBACK_FILE: Controls record/playback mode
        - Not set: Normal live mode
        - Set but file doesn't exist: Record mode (file will be created)
        - Set and file exists: Playback mode (uses recorded data)

    Returns:
        List of Trajectory objects with complete rollout data. Every trajectory's
        ``duration`` is set to the wall-clock time of the whole batch.

    Raises:
        ValueError: If ``max_concurrent_rollouts`` is smaller than 1.
    """
    # A semaphore created with 0 would block every rollout forever, and a
    # negative value is rejected by asyncio anyway; fail fast and clearly.
    if max_concurrent_rollouts < 1:
        raise ValueError(f"max_concurrent_rollouts must be >= 1, got {max_concurrent_rollouts}")

    start_time = time.time()

    # Check for record/playback mode
    playback_file = os.environ.get("EP_PLAYBACK_FILE")
    playback_file_exists = bool(playback_file) and os.path.exists(playback_file)
    recording_mode = bool(playback_file) and not playback_file_exists
    playback_mode = playback_file_exists

    if recording_mode:
        logger.info(f"📝 Recording mode: Will record to {playback_file}")
    elif playback_mode:
        logger.info(f"🎬 Playback mode: Using recorded data from {playback_file}")
    else:
        logger.info("🚀 Live mode: No recording/playback")

    # Initialize OpenAI format logging
    openai_logger = None
    if openai_format_log_file:
        # Opening in "w" mode truncates any previous log content.
        with open(openai_format_log_file, "w"):
            pass

        def _write_openai_entry(data):
            return self._log_openai_entry(openai_format_log_file, data)

        openai_logger = _write_openai_entry

    logger.info(f"🧵 Starting {envs.n} rollouts with max {max_concurrent_rollouts} concurrent threads...")

    semaphore = asyncio.Semaphore(max_concurrent_rollouts)

    async def _execute_with_semaphore(idx):
        async with semaphore:
            return await self._execute_rollout(
                envs, policy, idx, steps, openai_logger, recording_mode, playback_mode, start_time
            )

    # Wrap each rollout in a task so that unfinished siblings can be cancelled
    # if one of them fails.
    tasks = [asyncio.ensure_future(_execute_with_semaphore(i)) for i in range(envs.n)]
    try:
        trajectories = await asyncio.gather(*tasks)

        # Calculate durations
        total_duration = time.time() - start_time
        for trajectory in trajectories:
            trajectory.duration = total_duration
    except BaseException:
        # gather() propagates the first failure but leaves the other rollouts
        # running; stop them and wait for them to unwind before the
        # environments are closed underneath them.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise
    finally:
        # Clean up even when a rollout failed, so sessions are not leaked.
        await envs.close()

    # Enhanced reporting with control plane info
    total = len(trajectories)
    successful = sum(1 for traj in trajectories if traj.total_reward > 0)
    terminated_by_control_plane = sum(
        1
        for traj in trajectories
        if traj.control_plane_summary.get("termination_reason") == "control_plane_signal"
    )

    logger.info(f"📊 Rollout complete: {successful}/{total} reached goal")
    logger.info(f"🎛️ Control plane terminations: {terminated_by_control_plane}/{total}")
    logger.info(f"⏱️ Total duration: {total_duration:.2f}s")
    logger.info(f"🧵 Used {max_concurrent_rollouts} concurrent threads")

    # Print log file locations if created
    if openai_format_log_file:
        logger.info(f"💬 OpenAI format log: {openai_format_log_file}")
    if recording_mode:
        logger.info(f"📝 Recorded trajectory: {playback_file}")
        # Add note about control plane separation
        logger.info("🎛️ Trajectories include control plane separation")

    return trajectories
169
+
170
async def _execute_rollout(
    self,
    envs: "GeneralMCPVectorEnv",
    policy: Union["LLMBasePolicy", Callable],
    rollout_idx: int,
    steps: int,
    openai_logger: Optional[Callable],
    recording_mode: bool,
    playback_mode: bool,
    start_time: float,
) -> Trajectory:
    """
    Execute a single rollout for one environment and return its Trajectory.

    The rollout resets the environment for ``envs.sessions[rollout_idx]``,
    builds a conversation from the dataset row's system prompt and the
    formatted first observation, and then repeatedly asks ``policy`` for tool
    calls and executes them through ``envs.step`` until the trajectory is
    terminated or ``steps`` is reached.

    Args:
        envs: Vector environment providing ``sessions``, ``dataset_rows``,
            ``reset``, ``step``, ``format_user_prompt`` and ``format_tool_response``.
        policy: Awaitable callable returning ``(tool_calls, usage_stats)``; it is
            also expected to offer ``add_tool_response`` and, in recording mode,
            ``log_conversation_state_for_playback``.
        rollout_idx: Index of the environment/session to run.
        steps: Maximum number of steps for this rollout.
        openai_logger: Optional callable that receives the final conversation
            entry when the environment signals the end of the rollout.
        recording_mode: When true, conversation state is logged for playback.
        playback_mode: Not read by this method.
        start_time: Not read by this method.

    Returns:
        The populated Trajectory for this rollout.
    """
    # NOTE(review): the block nesting below was reconstructed from a flattened
    # diff view in which indentation was lost -- verify against the upstream file.
    session = envs.sessions[rollout_idx]
    dataset_row = envs.dataset_rows[rollout_idx]

    # Initialize trajectory
    trajectory = Trajectory(
        session=session,
        observations=[],
        actions=[],
        rewards=[],
        terminated=False,
        total_reward=0.0,
        steps=0,
        duration=0.0,
        control_plane_steps=[],
        control_plane_summary={},
        termination_reason="",
        conversation_history=[],
        llm_usage_summary={
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        },
    )

    current_observation, tool_schema = await envs.reset(session)
    system_prompt = dataset_row.system_prompt

    # Record initial observation
    trajectory.observations.append(current_observation)

    # Create user simulator for this rollout if configured in dataset
    user_simulator = None
    user_simulator_state = None

    # If user simulation is enabled, initial message is from the simulated user
    if dataset_row.user_simulation and dataset_row.user_simulation.get("enabled", False):
        user_simulator = UserSimulator(
            instructions=dataset_row.user_simulation.get("system_prompt"),
            llm=dataset_row.user_simulation.get("llm", "gpt-4.1"),
            llm_args=dataset_row.user_simulation.get("llm_args", {"temperature": 0.7}),
        )

        # Get initial messages in tau2-bench format for user simulator
        user_simulator_state = user_simulator.get_init_state()
        # The simulator is prompted with a fixed assistant greeting; its reply
        # replaces the environment's first observation as the opening user turn.
        user_message, user_simulator_state = user_simulator.generate_next_message(
            AssistantMessage(role="assistant", content="Hi! How can I help you today?"),
            user_simulator_state,
        )
        current_observation = user_message.content if user_message.content else ""

    user_prompt = envs.format_user_prompt(rollout_idx, current_observation)
    conversation_history = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Usage stats returned by the policy, summed into the trajectory at the end.
    usage_stats_list: List[LLMUsageStats] = []

    logger.info(f"🎯 Starting rollout {rollout_idx} in thread {threading.current_thread().name}")

    # Run rollout loop for this specific environment
    step = 0
    rollout_end = False

    while step < steps and not trajectory.terminated:
        # NOTE(review): turn_completed is never set to True in this method, so
        # the inner loop only ends through termination or a break -- confirm intent.
        turn_completed = False
        info = {}
        reward = 0.0
        observation = current_observation
        tool_calls = []

        if user_simulator and user_simulator_state:
            # Get user simulator messages and find the last assistant message
            user_simulator_messages = self._get_user_simulator_messages(conversation_history)

            # Last message was agent, simulated user response
            if user_simulator_messages and isinstance(user_simulator_messages[-1], AssistantMessage):
                # Generate user response using the simulator
                user_message, user_simulator_state = user_simulator.generate_next_message(
                    user_simulator_messages[-1], user_simulator_state
                )
                user_content = user_message.content if user_message.content else ""

                user_prompt = envs.format_user_prompt(rollout_idx, user_content)
                conversation_history.append({"role": "user", "content": user_prompt})

                # Check if user simulator signaled termination
                if UserSimulator.is_stop(user_message):
                    trajectory.terminated = True
                    trajectory.termination_reason = TerminationReason.USER_STOP

        # In each turn: keep looping until assistant is ready to provide final response
        while not turn_completed and not trajectory.terminated:
            tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)

            # A single sentinel "tool call" signals that no real tool was requested
            if len(tool_calls) == 1:
                # No tool calls means the policy is ready to provide final response on this turn
                # NOTE(review): this terminates the whole trajectory, not just the
                # turn, and leaves termination_reason unset here -- confirm intent.
                # NOTE(review): breaking here appears to skip the usage_stats
                # bookkeeping further down for this policy call -- verify.
                if tool_calls[0].tool_name == "_no_tool_call":
                    trajectory.terminated = True
                    break
                elif tool_calls[0].tool_name == "_playback_terminate":
                    trajectory.terminated = True
                    break

            # Execute each tool call sequentially
            for tool_call in tool_calls:

                # Execute tool call for this environment
                observation, reward, rollout_end, info = await envs.step(rollout_idx, tool_call)

                tool_response = envs.format_tool_response(observation)

                # The policy receives conversation_history here together with the
                # tool result; presumably it appends the tool message -- verify.
                policy.add_tool_response(
                    rollout_idx,
                    tool_call,
                    tool_response,
                    conversation_history,
                    reward,
                    rollout_end,
                    info,
                )

                # Update trajectory with both data and control plane information
                trajectory.observations.append(observation)

                # Record action (tool call)
                action_str = f"{tool_call.tool_name}({tool_call.arguments})"
                trajectory.actions.append(action_str)

                # Record control plane (reward/termination)
                trajectory.rewards.append(reward)
                trajectory.total_reward += reward

                # Non-user simulator step counter: each tool call is a step
                if user_simulator is None:
                    step += 1
                    trajectory.steps = step

                    # step was just incremented, so "step - 1" is the zero-based index
                    control_plane_step = {
                        "step": step - 1,
                        "reward": reward,
                        "terminated": rollout_end,
                        "info": info.get("control_plane", {}),
                        "tool_calls": [f"{tool_call.tool_name}({tool_call.arguments})"],
                        "num_tool_calls": 1,
                    }
                    trajectory.control_plane_steps.append(control_plane_step)

                    # Log conversation state for playback if in recording mode
                    if recording_mode:
                        policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)

                # Stop executing further tool calls once the environment ends
                # the rollout or the step budget is exhausted.
                if rollout_end:
                    trajectory.terminated = True
                    trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL
                    break
                elif step >= steps:
                    trajectory.terminated = True
                    trajectory.termination_reason = TerminationReason.MAX_STEPS
                    break

            # Update current observation for potential next turn
            if observation is not None:
                current_observation = observation

            # Collect the LLM usage stats of this policy call, if any were returned
            if usage_stats:
                usage_stats_list.append(usage_stats)

        # With user simulator, increment step after an entire conversation step
        if user_simulator is not None:
            step += 1
            trajectory.steps = step

            # Enhanced trajectory recording with control plane info
            # Create summary of the tool calls from the most recent policy response
            tool_calls_summary = [f"{tc.tool_name}({tc.arguments})" for tc in tool_calls]

            # reward / rollout_end / info hold the values of the last executed tool call
            control_plane_step = {
                "step": step - 1,
                "reward": reward,
                "terminated": rollout_end,
                "info": info.get("control_plane", {}),
                "tool_calls": tool_calls_summary,
                "num_tool_calls": len(tool_calls),
            }
            trajectory.control_plane_steps.append(control_plane_step)

            # Log conversation state for playback if in recording mode
            if recording_mode:
                policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)

        # Use control plane information for termination decision
        if rollout_end:
            trajectory.terminated = True
            trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL

            # Add final control plane summary
            # (only written on this path, i.e. when the environment ended the rollout)
            trajectory.control_plane_summary.update(
                {
                    "total_reward": trajectory.total_reward,
                    "termination_reason": trajectory.termination_reason,
                    "final_step": step - 1,
                    "control_plane_source": info.get("control_plane", {}),
                }
            )

            # Log final OpenAI conversation for terminated trajectories only
            if openai_logger:
                if conversation_history and len(conversation_history) > 0:
                    # "success" reflects the reward of the last executed tool call,
                    # not the accumulated total_reward.
                    openai_logger(
                        {
                            "messages": conversation_history,
                            "metadata": {
                                "session_id": session.session_id,
                                "seed": session.seed,
                                "total_steps": trajectory.steps,
                                "total_reward": trajectory.total_reward,
                                "terminated": True,
                                "success": reward > 0,
                                "control_plane_summary": trajectory.control_plane_summary,
                            },
                        }
                    )

            logger.info(
                f"🏁 Rollout {rollout_idx} terminated at step {step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}"
            )
            break

        # Progress logging
        if step % 10 == 0:
            logger.debug(f"Rollout {rollout_idx} step {step}, reward: {trajectory.total_reward:.2f}")

    # Set termination reason if not already set (e.g., due to step limit)
    if not trajectory.termination_reason and step >= steps:
        trajectory.termination_reason = TerminationReason.MAX_STEPS

    trajectory.conversation_history = conversation_history

    # Sum the collected per-call usage stats into the trajectory summary
    for usage_stats in usage_stats_list:
        trajectory.llm_usage_summary["prompt_tokens"] += usage_stats.prompt_tokens
        trajectory.llm_usage_summary["completion_tokens"] += usage_stats.completion_tokens
        trajectory.llm_usage_summary["total_tokens"] += usage_stats.total_tokens

    logger.info(
        f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}"
    )
    return trajectory
436
+
437
async def _get_control_plane_status(self, session) -> Optional[Dict[str, Any]]:
    """
    Query the control plane status endpoint directly for a session.

    The endpoint is ``<base>/control/status`` where ``<base>`` is the
    session's ``base_url`` with a trailing ``/mcp`` path segment (and any
    trailing slashes) removed. The session ID is sent in the
    ``mcp-session-id`` header.

    Args:
        session: MCP session object providing ``base_url`` and ``session_id``

    Returns:
        The decoded JSON body when the endpoint answers with HTTP 200, or
        None when the session has no ID, the endpoint answers with another
        status code, or the query fails for any reason.
    """
    session_id = session.session_id
    # Check the ID first: without it there is nothing to query, and the
    # failure logging below slices it.
    if not session_id:
        logger.debug("Control plane query failed: No session ID")
        return None

    session_label = session_id[:16]

    try:
        # Imported lazily, as in the original, so failures to import are
        # handled like any other query failure.
        import httpx

        # Remove the "/mcp" suffix as a whole segment. str.rstrip("/mcp")
        # would strip any trailing run of the characters '/', 'm', 'c', 'p'
        # and corrupt hosts such as "http://example.com/mcp".
        base_url = session.base_url.rstrip("/")
        if base_url.endswith("/mcp"):
            base_url = base_url[: -len("/mcp")].rstrip("/")

        headers = {"mcp-session-id": session_id}

        # Short timeout for performance; set once on the client.
        async with httpx.AsyncClient(timeout=2.0) as client:
            status_response = await client.get(f"{base_url}/control/status", headers=headers)

            if status_response.status_code == 200:
                return status_response.json()

            logger.debug(
                f"Control plane endpoint returned {status_response.status_code} for session {session_label}"
            )
            return None

    except asyncio.TimeoutError:
        logger.debug(f"Control plane query timed out for session {session_label}")
        return None
    except Exception as e:
        # Deliberately best-effort: any failure is reported as "no status".
        logger.debug(f"Control plane query failed for session {session_label}: {e}")
        return None
483
+
484
+ def _log_openai_entry(self, log_file: str, data: Dict[str, Any]):
485
+ """Helper function to log OpenAI format entries."""
486
+ with open(log_file, "a") as f:
487
+ f.write(json.dumps(data) + "\n")
488
+
489
+ def _get_user_simulator_messages(self, conversation_history: List[Dict[str, Any]]) -> List:
490
+ """
491
+ Filter conversation history for user simulator and convert to tau2-bench format.
492
+ """
493
+ tau2_messages = []
494
+
495
+ for message in conversation_history:
496
+ role = message.get("role")
497
+ content = message.get("content", "")
498
+
499
+ if role == "assistant":
500
+ if "tool_calls" not in message or not message.get("tool_calls"):
501
+ tau2_messages.append(AssistantMessage(role="assistant", content=content))
502
+
503
+ elif role == "user":
504
+ tau2_messages.append(UserMessage(role="user", content=content))
505
+
506
+ return tau2_messages