eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,499 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MCP Connection Management
|
|
3
|
+
|
|
4
|
+
Handles MCP client connections, session initialization, and resource/tool discovery.
|
|
5
|
+
Extracted from mcp_env.py to improve modularity.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import hashlib
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
from contextlib import AsyncExitStack
|
|
13
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
14
|
+
|
|
15
|
+
from mcp.client.session import ClientSession
|
|
16
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
17
|
+
|
|
18
|
+
from ..types import MCPSession
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class MCPConnectionManager:
    """Manages MCP client connections and session lifecycle.

    The manager is stateless: every method operates on the ``MCPSession``
    passed in, storing the live ``ClientSession`` on ``session._mcp_session``
    and the owning ``AsyncExitStack`` on ``session._exit_stack``.
    """

    # ------------------------------------------------------------------
    # Pure helpers (no I/O, no logging)
    # ------------------------------------------------------------------

    @staticmethod
    def _control_plane_base_url(base_url: str) -> str:
        """
        Derive the control plane root URL from an MCP endpoint URL.

        Removes one trailing ``/mcp`` path segment and any trailing slashes.
        ``str.rstrip("/mcp")`` must not be used for this: it strips a *set of
        characters*, so ``http://example.com/mcp`` would become
        ``http://example.co``.

        Args:
            base_url: The MCP endpoint URL (e.g. ``http://host:8000/mcp/``)

        Returns:
            The URL without the trailing ``/mcp`` segment or trailing slashes
        """
        suffix = "/mcp"
        trimmed = base_url.rstrip("/")
        if trimmed.endswith(suffix):
            trimmed = trimmed[: -len(suffix)]
        return trimmed.rstrip("/")

    @staticmethod
    def _compute_server_session_id(extra_data: Dict[str, Any], name: str, version: str) -> str:
        """
        Compute the session ID from the client info payload.

        The ID is the MD5 hex digest of the key-sorted JSON encoding of the
        seed, config, dataset row id, model id, client name and client version.
        Per the caller's intent this mirrors the server's calculation, so the
        field names and defaults must not change.

        Args:
            extra_data: The ``_extra`` dict attached to the client info
            name: Client implementation name
            version: Client implementation version

        Returns:
            Hex digest used as the session ID
        """
        stable_data = {
            "seed": extra_data.get("seed"),
            "config": extra_data.get("config", {}),
            "dataset_row_id": extra_data.get("dataset_row_id"),
            "model_id": extra_data.get("model_id"),
            "name": name,
            "version": version,
        }
        stable_str = json.dumps(stable_data, sort_keys=True)
        return hashlib.md5(stable_str.encode()).hexdigest()

    @staticmethod
    def _decode_resource_content(resource_content: Any) -> Tuple[str, Any]:
        """
        Turn the result of ``read_resource`` into an observation.

        Two layouts are accepted: an object exposing ``.text`` directly, or an
        object exposing a non-empty ``.contents`` sequence whose first item is
        used.

        Args:
            resource_content: Value returned by ``ClientSession.read_resource``

        Returns:
            Tuple of (kind, observation) where kind is one of:
            ``"json"`` - text was valid JSON, observation is the parsed value;
            ``"text"`` - text was not JSON, observation wraps the raw text;
            ``"object"`` - first content item has no text, observation wraps its ``str()``;
            ``"unknown"`` - neither layout matched, observation is None
        """
        if hasattr(resource_content, "text"):
            text = resource_content.text
        else:
            contents = getattr(resource_content, "contents", None)
            if not contents:
                return "unknown", None
            first_item = contents[0]
            if not hasattr(first_item, "text"):
                return "object", {"observation": str(first_item)}
            text = first_item.text

        try:
            return "json", json.loads(text)
        except json.JSONDecodeError:
            return "text", {"observation": text}

    @staticmethod
    def _step_count(observation: Any) -> Any:
        """
        Read the step counter from an observation.

        Args:
            observation: Parsed tool observation (any JSON value)

        Returns:
            ``observation["moves"]`` if present, else ``observation["steps"]``,
            else 0. Non-dict observations (e.g. a JSON list) yield 0 instead of
            raising.
        """
        if not isinstance(observation, dict):
            return 0
        return observation.get("moves", observation.get("steps", 0))

    # ------------------------------------------------------------------
    # Control plane HTTP helper
    # ------------------------------------------------------------------

    async def _fetch_control_json(
        self, url: str, headers: Dict[str, str], timeout: float, label: str
    ) -> Optional[Any]:
        """
        GET a control plane endpoint and return its JSON body.

        Failures are logged as warnings and reported as None; only a missing
        ``httpx`` package propagates (as ImportError) so the caller's outer
        handler can report it.

        Args:
            url: Full endpoint URL
            headers: Request headers (carries the ``mcp-session-id``)
            timeout: Timeout in seconds for the client and the request
            label: Human-readable endpoint name used in log messages

        Returns:
            Parsed JSON body on HTTP 200, otherwise None
        """
        # Imported lazily, as in the rest of this module's control plane code.
        import httpx

        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.get(url, headers=headers, timeout=timeout)
                if response.status_code == 200:
                    return response.json()
                logger.warning(f"Control plane {label} endpoint returned {response.status_code}")
        except httpx.TimeoutException:
            logger.warning(f"Control plane {label} endpoint timed out after {timeout}s")
        except Exception as e:
            logger.warning(f"Failed to query {label} endpoint: {e}")
        return None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def initialize_session(self, session: MCPSession) -> None:
        """
        Initialize a persistent MCP session.

        Any session already attached to ``session`` is closed first. On success
        the new ``ClientSession`` and its exit stack are stored on ``session``
        and, when client info was sent, ``session.session_id`` is replaced by
        the hash computed from that client info.

        Args:
            session: The MCPSession to initialize

        Raises:
            Exception: Whatever the transport or ``ClientSession.initialize``
                raises; the partially opened connection is closed before the
                error propagates.
        """
        if session._mcp_session:
            # If a session exists, close it before creating a new one.
            if session._exit_stack:
                try:
                    await session._exit_stack.aclose()
                except asyncio.CancelledError:
                    # Handle cancellation gracefully (especially important for Python 3.12)
                    logger.debug(f"Session {session.session_id} reinit close was cancelled")
                except Exception as e:
                    logger.warning(f"Error closing existing session {session.session_id} during reinit: {e}")
                finally:
                    session._exit_stack = None
                    session._mcp_session = None

        exit_stack = AsyncExitStack()

        client_info = None
        has_environment_context = bool(session.dataset_row and session.dataset_row.environment_context)
        if session.seed is not None or has_environment_context:
            from mcp.types import Implementation

            client_info = Implementation(name="reward-kit", version="1.0.0", _extra={})
            if session.seed is not None:
                client_info._extra["seed"] = session.seed
            if has_environment_context:
                client_info._extra["config"] = session.dataset_row.environment_context
            if session.dataset_row and session.dataset_row.id:
                client_info._extra["dataset_row_id"] = session.dataset_row.id
            if session.model_id:
                client_info._extra["model_id"] = session.model_id

        try:
            read_stream, write_stream, _ = await exit_stack.enter_async_context(
                streamablehttp_client(session.base_url, terminate_on_close=True)
            )
            mcp_session = await exit_stack.enter_async_context(
                ClientSession(read_stream, write_stream, client_info=client_info)
            )
            await mcp_session.initialize()
        except BaseException:
            # Nothing else owns exit_stack yet: release whatever was opened so a
            # failed handshake does not leak the HTTP transport.
            try:
                await exit_stack.aclose()
            except Exception as close_error:
                logger.debug(f"Error closing partially initialized session {session.session_id}: {close_error}")
            raise

        session._mcp_session = mcp_session
        session._exit_stack = exit_stack

        # Update session ID to match server's calculation (for control plane sync)
        if client_info and hasattr(client_info, "_extra"):
            extra_data = client_info._extra
            if extra_data and isinstance(extra_data, dict):
                server_session_id = self._compute_server_session_id(
                    extra_data, client_info.name, client_info.version
                )

                # Update the session ID to match what the server generated
                session.session_id = server_session_id
                logger.debug(f"Updated session ID to match server: {server_session_id}")

    async def discover_tools(self, session: MCPSession) -> List[Dict]:
        """
        Discover available tools from an MCP session.

        Args:
            session: The MCPSession to discover tools from

        Returns:
            List of tool schemas, each with ``name``, ``description`` and
            ``input_schema`` keys

        Raises:
            RuntimeError: If the session has not been initialized
        """
        if not session._mcp_session:
            raise RuntimeError("Session not initialized")

        # Get available tools from MCP server
        tools_response = await session._mcp_session.list_tools()
        tools = getattr(tools_response, "tools", [])

        # Convert every listed tool to schema format (no filtering is applied here)
        return [
            {
                "name": tool.name,
                "description": tool.description,
                "input_schema": getattr(tool, "inputSchema", {}),
            }
            for tool in tools
        ]

    async def get_initial_state(self, session: MCPSession) -> Any:
        """
        Get initial state from session-aware control plane endpoint.

        Tries ``GET {base}/control/initial_state`` first, then falls back to MCP
        resources, then to a default placeholder observation.

        Args:
            session: The MCPSession to get initial state from

        Returns:
            Initial observation/state

        Raises:
            RuntimeError: If the session has not been initialized
        """
        if not session._mcp_session:
            raise RuntimeError("Session not initialized")

        # Try to get initial state from control plane endpoint first
        initial_observation = None

        try:
            base_url = self._control_plane_base_url(session.base_url)
            session_id = session.session_id

            if session_id:
                # Use shorter timeout for playback mode
                timeout = 3.0 if getattr(session, "_is_playback_mode", False) else 5.0
                initial_observation = await self._fetch_control_json(
                    f"{base_url}/control/initial_state",
                    {"mcp-session-id": session_id},
                    timeout,
                    "initial state",
                )
                if initial_observation is not None:
                    logger.info(
                        f"Session {session.session_id}: ✅ Successfully fetched session-aware initial state from control plane endpoint"
                    )
        except Exception as e:
            logger.warning(f"Failed to query control plane initial state endpoint: {e}")

        # Fallback to MCP resource if control plane endpoint fails (backward compatibility)
        if initial_observation is None:
            logger.debug(f"Session {session.session_id}: Falling back to MCP resource for initial state")
            initial_observation = await self._get_initial_state_from_mcp_resource(session)

        # Ensure we have some observation
        if initial_observation is None:
            logger.debug(f"Session {session.session_id}: Using default initial state")
            initial_observation = {
                "observation": "default_initial_state",
                "session_id": session.session_id,
            }

        return initial_observation

    async def _get_initial_state_from_mcp_resource(self, session: MCPSession) -> Any:
        """
        Fallback method to get initial state from MCP resources.
        This is kept for backward compatibility but should be replaced by control plane endpoints.

        The first resource whose name or URI contains "initial", "state",
        "observation" or "start" is read; if none matches, the first listed
        resource is read instead.

        Args:
            session: The MCPSession whose resources are inspected

        Returns:
            The decoded observation, None when no usable resource exists, or a
            generic placeholder dict when listing/reading resources raised
        """
        mcp_session = session._mcp_session
        initial_observation = None

        try:
            # List available resources - this is where initial state should come from
            logger.debug(f"Session {session.session_id}: Discovering MCP resources for initial state...")
            resources_response = await mcp_session.list_resources()
            resources = getattr(resources_response, "resources", [])
            logger.debug(f"Session {session.session_id}: Found {len(resources)} MCP resources")
            for resource in resources:
                logger.debug(f"Session {session.session_id}: Resource: {resource.name} | URI: {resource.uri}")

            # Try to identify initial state resource based on common patterns
            keywords = ("initial", "state", "observation", "start")
            initial_state_resource = None
            for resource in resources:
                name_lower = resource.name.lower()
                uri_lower = str(resource.uri).lower()  # Convert AnyUrl to string first
                if any(keyword in name_lower or keyword in uri_lower for keyword in keywords):
                    initial_state_resource = resource
                    logger.debug(
                        f"Session {session.session_id}: ✅ Found initial state resource: {resource.name} | URI: {resource.uri}"
                    )
                    break

            if initial_state_resource:
                # Read the initial state resource
                logger.debug(
                    f"Session {session.session_id}: Reading initial state from resource: {initial_state_resource.uri}"
                )
                resource_content = await mcp_session.read_resource(initial_state_resource.uri)

                kind, observation = self._decode_resource_content(resource_content)
                if kind == "unknown":
                    logger.warning(f"Session {session.session_id}: Resource content is empty or unrecognized format")
                    logger.warning(f"Session {session.session_id}: Unexpected resource format")
                else:
                    initial_observation = observation
                    if kind == "json":
                        # Build the preview defensively: the payload may be any JSON
                        # value and a logging failure must not discard the state.
                        layout = observation.get("grid_layout", "N/A") if isinstance(observation, dict) else "N/A"
                        logger.info(
                            f"Session {session.session_id}: ✅ Successfully parsed JSON initial state with grid_layout: {str(layout)[:20]}..."
                        )
            else:
                logger.warning(
                    f"Session {session.session_id}: ❌ No initial state resource found among {len(resources)} resources"
                )
                # Fallback: if no initial state resource, try first available resource
                if resources:
                    first_resource = resources[0]
                    logger.debug(
                        f"Session {session.session_id}: No initial state resource found, using first resource: {first_resource.name}"
                    )
                    logger.debug(
                        f"Session {session.session_id}: About to call mcp_session.read_resource with fallback URI: {first_resource.uri}"
                    )

                    resource_content = await mcp_session.read_resource(first_resource.uri)

                    logger.debug(
                        f"Session {session.session_id}: fallback read_resource returned type: {type(resource_content)}"
                    )
                    logger.debug(
                        f"Session {session.session_id}: fallback read_resource returned value: {resource_content}"
                    )
                    logger.debug(
                        f"Session {session.session_id}: fallback read_resource dir(): {dir(resource_content)}"
                    )

                    kind, observation = self._decode_resource_content(resource_content)
                    if kind == "unknown":
                        logger.warning(f"Session {session.session_id}: Fallback resource has unexpected format")
                        initial_observation = {"observation": str(resource_content)}
                    else:
                        initial_observation = observation
                else:
                    logger.debug(f"Session {session.session_id}: No resources available from MCP server")

        except Exception as e:
            # If resources are not available, fall back to a default observation
            # This maintains backward compatibility with servers that don't expose resources
            logger.warning(f"Session {session.session_id}: Failed to read initial state from MCP resources: {e}")
            logger.warning(f"Session {session.session_id}: Exception type: {type(e)}")
            logger.warning(f"Session {session.session_id}: Exception args: {e.args}")
            import traceback

            logger.warning(f"Session {session.session_id}: Full traceback: {traceback.format_exc()}")
            initial_observation = {
                "observation": "initial_state",
                "message": "Session established",
            }

        return initial_observation

    async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict) -> Tuple[Any, float, bool, Dict]:
        """
        Execute a tool call via MCP protocol with control plane separation.

        This method implements the control plane separation architecture:
        1. Execute tool call (data plane) - contains only observations
        2. Query control plane endpoints for reward/termination info
        3. Return combined result maintaining strict plane separation

        Control plane failures are logged and leave the defaults in place
        (reward 0.0, not terminated, not truncated).

        Args:
            session: The MCPSession to execute the tool call on
            tool_name: Name of the tool to call
            arguments: Arguments for the tool call

        Returns:
            Tuple of (observation, reward, done, info) with control plane data

        Raises:
            RuntimeError: If the session has not been initialized
        """
        if not session._mcp_session:
            raise RuntimeError("Session not initialized")

        # 1. Execute the tool call via MCP protocol (DATA PLANE)
        tool_result = await session._mcp_session.call_tool(tool_name, arguments)

        # Extract data plane results (observation only)
        if not tool_result.content:
            # Handle completely empty tool result
            logger.warning(f"Session {session.session_id}: Tool {tool_name} returned empty result")
            observation = {
                "observation": "no_response",
                "session_id": session.session_id,
            }
        else:
            content = tool_result.content[0]
            if not hasattr(content, "text"):
                # Handle non-text content
                observation = {
                    "observation": str(content),
                    "session_id": session.session_id,
                }
            elif not content.text or not content.text.strip():
                # Handle empty responses gracefully
                logger.warning(f"Session {session.session_id}: Empty tool response from {tool_name}")
                observation = {
                    "observation": "empty_response",
                    "session_id": session.session_id,
                }
            else:
                try:
                    observation = json.loads(content.text)
                except json.JSONDecodeError as e:
                    logger.warning(
                        f"Session {session.session_id}: Invalid JSON from {tool_name}: {content.text}. Error: {e}"
                    )
                    # Create a structured response from the raw text
                    observation = {
                        "observation": content.text,
                        "session_id": session.session_id,
                        "error": "invalid_json_response",
                    }

        # 2. Query CONTROL PLANE endpoints for reward/termination info
        reward = 0.0
        terminated = False
        truncated = False
        control_plane_info = {}

        try:
            base_url = self._control_plane_base_url(session.base_url)
            # Use the session ID from the established MCP session
            session_id = session.session_id

            if session_id:
                headers = {"mcp-session-id": session_id}
                # Use shorter timeout for better responsiveness
                timeout = 3.0

                reward_data = await self._fetch_control_json(
                    f"{base_url}/control/reward", headers, timeout, "reward"
                )
                if isinstance(reward_data, dict):
                    reward = reward_data.get("reward", 0.0)
                    control_plane_info["reward_source"] = "control_plane_endpoint"
                elif reward_data is not None:
                    logger.warning(
                        f"Failed to query reward endpoint: unexpected payload type {type(reward_data).__name__}"
                    )

                status_data = await self._fetch_control_json(
                    f"{base_url}/control/status", headers, timeout, "status"
                )
                if isinstance(status_data, dict):
                    terminated = status_data.get("terminated", False)
                    truncated = status_data.get("truncated", False)
                    control_plane_info["status_source"] = "control_plane_endpoint"
                elif status_data is not None:
                    logger.warning(
                        f"Failed to query status endpoint: unexpected payload type {type(status_data).__name__}"
                    )

        except Exception as e:
            logger.warning(f"Failed to query control plane endpoints: {e}")

        # 3. Combine results maintaining strict separation
        done = terminated or truncated

        info = {
            "steps": self._step_count(observation),
            "tool_call": tool_name,
            "arguments": arguments,
            "control_plane": control_plane_info,  # Mark control plane data
        }

        # Log control plane separation. A tool may return any JSON value, so only
        # list keys when the observation is a dict.
        data_plane_summary = list(observation.keys()) if isinstance(observation, dict) else type(observation).__name__
        logger.debug(
            f"Session {session.session_id}: Data plane: {data_plane_summary}, Control plane: reward={reward}, terminated={terminated}"
        )

        return observation, reward, done, info

    async def close_session(self, session: MCPSession) -> None:
        """
        Close an MCP session and clean up resources.

        Errors raised while closing are logged at debug level and suppressed;
        the session's connection attributes are always cleared.

        Args:
            session: The MCPSession to close
        """
        if session._exit_stack:
            try:
                await session._exit_stack.aclose()
            except asyncio.CancelledError:
                # Handle cancellation gracefully (especially important for Python 3.12)
                logger.debug(f"Session {session.session_id} close was cancelled")
            except Exception as e:
                # Hitting this error, probably because of use of threads: "Attempted to exit cancel scope in a different task than it was entered in"
                logger.debug(f"Error closing session {session.session_id}: {e}")
            finally:
                session._exit_stack = None
                session._mcp_session = None
|