eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,637 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MCP-Gym Framework - North Star Implementation
|
|
3
|
+
|
|
4
|
+
This module provides the core McpGym base class that implements the north star vision
|
|
5
|
+
for universal RL environment integration via MCP protocol.
|
|
6
|
+
|
|
7
|
+
Key Features:
|
|
8
|
+
- Unified MCP server with FastMCP integration
|
|
9
|
+
- Simple tool registration with @self.mcp.tool() decorator
|
|
10
|
+
- Clean separation between data plane (MCP tool calls) and control plane (custom endpoints)
|
|
11
|
+
- Compatible with CondaServerProcessManager
|
|
12
|
+
- Session-aware control plane endpoints via @control_plane_endpoint decorator
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
import hashlib
|
|
17
|
+
import threading
|
|
18
|
+
import inspect
|
|
19
|
+
import json
|
|
20
|
+
import logging
|
|
21
|
+
from abc import ABC, abstractmethod
|
|
22
|
+
from typing import Any, Callable, Dict, Optional, Tuple
|
|
23
|
+
|
|
24
|
+
from mcp.server.fastmcp import Context, FastMCP
|
|
25
|
+
from starlette.requests import Request
|
|
26
|
+
from starlette.responses import JSONResponse
|
|
27
|
+
|
|
28
|
+
from .adapter import EnvironmentAdapter
|
|
29
|
+
|
|
30
|
+
# Module-level logger named after this module; McpGym reports control plane
# endpoint registration through it.
logger: logging.Logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def control_plane_endpoint(path: str) -> Callable:
    """
    Decorator to register session-aware control plane endpoints.

    Control plane endpoints provide rewards, termination status, and other
    metadata without polluting the tool namespace used by LLMs.

    The decorator does not wrap the function. It tags it with
    ``_is_control_plane_endpoint`` and ``_control_plane_path`` and returns
    the same function object. ``McpGym`` discovers tagged methods when it is
    constructed and exposes each one as an HTTP GET route at ``path``.

    The route handler calls the bound method with a single keyword argument,
    ``session_data``: the session record looked up via the request's
    ``mcp-session-id`` header. The method may be sync or async, and should
    return a JSON-serializable dict.

    Args:
        path: URL path for the endpoint (e.g., "/control/reward")

    Returns:
        A decorator that tags the decorated function and returns it unchanged.

    Example:
        @control_plane_endpoint("/control/reward")
        def get_reward(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
            # Per-session control plane state lives under session_data["session_data"].
            control_plane = session_data.get("session_data", {}).get("control_plane", {})
            return {
                "reward": control_plane.get("reward", 0.0),
                "step_count": control_plane.get("step_count", 0),
            }
    """

    def decorator(func: Callable) -> Callable:
        # Mark the function so McpGym's discovery pass can find it via attribute checks.
        func._is_control_plane_endpoint = True
        func._control_plane_path = path
        return func

    return decorator
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class McpGym(ABC):
    """
    Base class for MCP-Gym environments implementing the north star vision.

    This class provides the universal adapter pattern for RL environments,
    bridging training infrastructure, production MCP standards, and high-quality
    environments through a clean, standardized interface.

    Key Design Principles:
    - Data Plane: JSON tool calls/responses via MCP (state transitions/actions)
    - Control Plane: Rewards/termination signals via session-aware HTTP
      endpoints declared with ``@control_plane_endpoint``
    - Environment Implementation: Single-process MCP server per environment

    Subclasses must implement ``_register_tools`` and may override
    ``format_observation``.
    """

    def __init__(self, server_name: str, adapter: EnvironmentAdapter, seed: Optional[int] = None):
        """
        Initialize the MCP-Gym environment.

        Args:
            server_name: Name for the MCP server
            adapter: Environment adapter instance
            seed: Optional seed for reproducible environments
        """
        self.adapter = adapter

        # Create FastMCP server; the listening port can be overridden via $PORT.
        self.mcp = FastMCP(
            server_name,
            host="0.0.0.0",
            port=int(os.environ.get("PORT", 8000)),
        )

        # Multi-session support:
        # session_id -> {"env": env, "obs": obs, "session_data": data, "session_id": id}
        self.sessions: Dict[str, Dict[str, Any]] = {}
        # Re-entrant on purpose: get_control_plane_state() and
        # _update_session_control_plane() hold the lock while calling
        # _get_or_create_session_control_plane(), which acquires it again.
        # A plain threading.Lock deadlocks on that nested acquisition.
        self.session_lock = threading.RLock()

        # Control plane endpoints dictionary (method name -> bound method)
        self._control_plane_endpoints: Dict[str, Callable] = {}

        # Initialize control plane state (for backward compatibility - single session)
        self.control_plane_state = {
            "reward": 0.0,
            "terminated": False,
            "truncated": False,
            "info": {},
            "step_count": 0,
            "total_reward": 0.0,
        }

        # Default (single-session) environment, seeded if a seed was provided
        self.env, self.obs, _info = self._new_env(seed=seed)

        # Register tools and control plane endpoints
        self._register_tools()
        self._discover_and_register_control_plane_endpoints()

    # ===== SESSION MANAGEMENT =====

    @staticmethod
    def _extract_client_extra(ctx: Any) -> Tuple[Any, Optional[Dict[str, Any]]]:
        """
        Pull the client info and its extra initialization data out of an MCP context.

        Reads ``ctx.session.client_params.clientInfo`` and its ``_extra`` attribute.
        Each hop is optional, because contexts used outside a full MCP session may
        lack them.

        Returns:
            ``(client_info, extra_data)``. ``client_info`` is ``None`` when unavailable.
            ``extra_data`` is ``None`` unless ``_extra`` is a non-empty dict.
        """
        session = getattr(ctx, "session", None)
        client_params = getattr(session, "client_params", None)
        client_info = getattr(client_params, "clientInfo", None)
        if not client_info:
            return client_info, None

        extra_data = getattr(client_info, "_extra", None)
        if extra_data and isinstance(extra_data, dict):
            return client_info, extra_data
        return client_info, None

    def _get_session_id(self, ctx: Context) -> str:
        """
        Derive a session ID for the given MCP context.

        If the client supplied extra initialization data, the ID is the MD5 hex
        digest of a canonical JSON encoding of:
        - seed, config, dataset_row_id and model_id (from the extra data);
        - the client name and version.

        Reconnecting with the same configuration therefore maps to the same session.

        Otherwise a fallback ID ``gym_<id(ctx)>`` is returned. It is only unique
        while that context object is alive.
        """
        client_info, extra_data = self._extract_client_extra(ctx)

        if extra_data is not None:
            stable_data = {
                "seed": extra_data.get("seed"),
                "config": extra_data.get("config", {}),
                "dataset_row_id": extra_data.get("dataset_row_id"),
                "model_id": extra_data.get("model_id"),
                "name": client_info.name,
                "version": client_info.version,
            }
            # sort_keys makes the encoding (and therefore the hash) order-independent.
            stable_str = json.dumps(stable_data, sort_keys=True)
            session_id = hashlib.md5(stable_str.encode()).hexdigest()
            logger.debug("Generated stable session_id %s for seed %r", session_id, stable_data["seed"])
            return session_id

        # Fallback for testing or other scenarios
        session_id = f"gym_{id(ctx)}"
        logger.debug("Generated fallback session_id %s", session_id)
        return session_id

    def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]:
        """
        Get or create the session record for the given context.

        A new session's environment uses the seed from the client's extra data.
        Its config is the adapter's default config updated with the client's
        ``config`` dict, if one was provided.
        """
        session_id = self._get_session_id(ctx)

        # Creation happens under the lock so concurrent first requests for the
        # same session build only one environment.
        with self.session_lock:
            existing = self.sessions.get(session_id)
            if existing is not None:
                logger.debug("Returning existing session %s", session_id)
                return existing

            seed = None
            # Copy so client overrides never mutate a dict owned by the adapter.
            config = dict(self._get_default_config() or {})

            _client_info, extra_data = self._extract_client_extra(ctx)
            if extra_data is not None:
                seed = extra_data.get("seed")
                client_config = extra_data.get("config")
                if isinstance(client_config, dict):
                    config.update(client_config)

            # Pass the merged config through; previously it was computed and dropped.
            env, obs, _info = self._new_env(seed=seed, config=config)

            session = {
                "env": env,
                "obs": obs,
                "session_data": {},  # Subclasses can store additional data here
                "session_id": session_id,
            }
            self.sessions[session_id] = session

            logger.info("Created new session %s... with seed %r", session_id[:16], seed)
            return session

    # ===== CONTROL PLANE REGISTRATION =====

    def _discover_and_register_control_plane_endpoints(self):
        """
        Discover and register control plane endpoints on the subclass instance.

        Scans bound methods for the ``_is_control_plane_endpoint`` marker set by
        ``@control_plane_endpoint``. Each match is registered as a FastMCP custom
        GET route at its ``_control_plane_path``.
        """
        discovered_endpoints: Dict[str, Callable] = {
            method.__name__: method
            for _name, method in inspect.getmembers(self, predicate=inspect.ismethod)
            if hasattr(method, "_is_control_plane_endpoint")
        }
        self._control_plane_endpoints = discovered_endpoints

        for endpoint_func in discovered_endpoints.values():
            handler = self._create_control_plane_handler(endpoint_func)
            self.mcp.custom_route(endpoint_func._control_plane_path, methods=["GET"])(handler)

        if discovered_endpoints:
            logger.info("✅ Registered %d session-aware control plane endpoints", len(discovered_endpoints))
            for name, endpoint in discovered_endpoints.items():
                logger.info(" - %s: %s", name, endpoint._control_plane_path)
        else:
            logger.info("⚠️ No session-aware control plane endpoints discovered")

    def _create_control_plane_handler(self, func: Callable) -> Callable:
        """
        Wrap a control plane endpoint method in a Starlette request handler.

        The handler:
        - reads the ``mcp-session-id`` header and looks up the session;
        - calls ``func(session_data=...)``, awaiting it if it is a coroutine
          function;
        - returns the result as JSON.

        Responses:
        - 400 if the header is missing.
        - 404 if the session is unknown. The one exception is
          ``get_initial_state_endpoint``, which creates an unseeded session on
          demand instead.
        - 500 (with the error message) if the endpoint raises.
        """

        async def endpoint_handler(request: Request) -> JSONResponse:
            try:
                # Extract session ID from request headers (similar to StreamableHTTP pattern)
                session_id = request.headers.get("mcp-session-id")
                if not session_id:
                    return JSONResponse(
                        {"error": "Missing mcp-session-id header"},
                        status_code=400,
                    )

                with self.session_lock:
                    session_data = self.sessions.get(session_id)
                    if not session_data:
                        if getattr(func, "__name__", None) != "get_initial_state_endpoint":
                            return JSONResponse(
                                {"error": f"Session {session_id} not found"},
                                status_code=404,
                            )
                        # NOTE(review): this header-keyed session is created without
                        # a seed. Tool-call sessions are keyed by the hash from
                        # _get_session_id, so the two only coincide if the client
                        # sends that hash as mcp-session-id -- confirm intended.
                        env, obs, _info = self._new_env(seed=None)
                        session_data = {
                            "env": env,
                            "obs": obs,
                            "session_data": {},  # Subclasses can store additional data here
                            "session_id": session_id,
                        }
                        self.sessions[session_id] = session_data

                # Call the endpoint function with session data
                if inspect.iscoroutinefunction(func):
                    result = await func(session_data=session_data)
                else:
                    result = func(session_data=session_data)

                # Endpoint results (e.g. raw env info) may hold non-JSON values such
                # as numpy scalars; normalize before handing them to JSONResponse.
                return JSONResponse(self._to_json_serializable(result))

            except Exception as e:
                logger.exception("Control plane endpoint %s failed", getattr(func, "__name__", func))
                return JSONResponse({"error": str(e)}, status_code=500)

        return endpoint_handler

    # ===== CONTROL PLANE STATE =====

    def _update_control_plane(self, reward: float, terminated: bool, truncated: bool, info: Dict[str, Any]):
        """
        Update control plane state after environment step (single session).

        Args:
            reward: Reward from environment step
            terminated: Whether episode terminated
            truncated: Whether episode truncated
            info: Info dictionary from environment
        """
        self.control_plane_state["reward"] = reward
        self.control_plane_state["terminated"] = terminated
        self.control_plane_state["truncated"] = truncated
        self.control_plane_state["info"] = info
        self.control_plane_state["step_count"] += 1
        self.control_plane_state["total_reward"] += reward

        # Logged rather than printed: stdout carries the protocol under the stdio transport.
        logger.debug(
            "🎛️ Control plane updated: reward=%s, terminated=%s, step=%s",
            reward,
            terminated,
            self.control_plane_state["step_count"],
        )

    def _get_or_create_session_control_plane(self, session_id: str) -> Dict[str, Any]:
        """
        Get or create the control plane state for a specific session.

        Returns the live dict stored under ``session_data["control_plane"]``, so
        callers can mutate it in place. Returns an empty dict if the session does
        not exist.
        """
        with self.session_lock:
            if session_id not in self.sessions:
                return {}

            session_data = self.sessions[session_id]
            if "control_plane" not in session_data["session_data"]:
                session_data["session_data"]["control_plane"] = {
                    "reward": 0.0,
                    "terminated": False,
                    "truncated": False,
                    "info": {},
                    "step_count": 0,
                    "total_reward": 0.0,
                }

            return session_data["session_data"]["control_plane"]

    def _update_session_control_plane(
        self,
        session_id: str,
        reward: float,
        terminated: bool,
        truncated: bool,
        info: Dict[str, Any],
    ):
        """
        Update control plane state for a specific session.

        Raises:
            KeyError: If ``session_id`` does not name an existing session.
        """
        with self.session_lock:
            # Fail clearly instead of raising KeyError('step_count') on the empty
            # dict that _get_or_create_session_control_plane returns for unknown IDs.
            if session_id not in self.sessions:
                raise KeyError(f"Unknown session: {session_id}")

            control_plane = self._get_or_create_session_control_plane(session_id)
            control_plane["reward"] = reward
            control_plane["terminated"] = terminated
            control_plane["truncated"] = truncated
            control_plane["info"] = info
            control_plane["step_count"] += 1
            control_plane["total_reward"] += reward
            step_count = control_plane["step_count"]

        logger.debug(
            "🎛️ Session %s... control plane: reward=%s, terminated=%s, step=%s",
            session_id[:16],
            reward,
            terminated,
            step_count,
        )

    def get_control_plane_state(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Get control plane state for a specific session (for rollout system).

        Returns a shallow copy, or ``None`` if the session does not exist.
        """
        with self.session_lock:
            if session_id not in self.sessions:
                return None
            return self._get_or_create_session_control_plane(session_id).copy()

    # ===== ENVIRONMENT STEPPING =====

    def _execute_environment_step(self, action_int: int) -> Dict[str, Any]:
        """
        Execute environment step and update control plane (single session).

        Args:
            action_int: Parsed action integer

        Returns:
            Data plane response (observation only, no rewards)
        """
        obs, reward, terminated, truncated, info = self.adapter.step_environment(self.env, action_int)

        # Update global observation state
        self.obs = obs

        # Update control plane (separate from data plane)
        self._update_control_plane(reward, terminated, truncated, info)

        # Return ONLY data plane information (no rewards/termination)
        return self._render(obs)

    def _execute_session_environment_step(self, session_id: str, action: Any) -> Dict[str, Any]:
        """
        Execute environment step for a specific session and update control plane.

        Args:
            session_id: Session identifier
            action: Action passed unchanged to ``adapter.step_environment``

        Returns:
            Data plane response (observation only, no rewards)

        Raises:
            KeyError: If ``session_id`` does not name an existing session.
        """
        session_data = self.sessions[session_id]
        env = session_data["env"]

        obs, reward, terminated, truncated, info = self.adapter.step_environment(env, action)

        # Update session observation state
        session_data["obs"] = obs

        # Update control plane for this session
        self._update_session_control_plane(session_id, reward, terminated, truncated, info)

        # Return ONLY data plane information (no rewards/termination)
        return self.format_observation(obs, env)

    def _new_env(self, seed: Optional[int] = None, config: Optional[Dict[str, Any]] = None) -> Tuple[Any, Any, Dict]:
        """
        Create a new environment and return its initial state.

        Args:
            seed: Seed for the environment. Any non-``None`` value, including 0,
                goes through ``create_environment_with_seed``.
            config: Environment config; defaults to the adapter's default config.

        Returns:
            ``(env, obs, info)``
        """
        if config is None:
            config = self._get_default_config()

        # `is not None` so that seed 0 is honored (a truthiness check dropped it).
        if seed is not None:
            env, obs, info = self.adapter.create_environment_with_seed(config, seed=seed)
        else:
            env = self.adapter.create_environment(config)
            obs, info = self.adapter.reset_environment(env, seed=None)

        return env, obs, info

    def _render(self, obs) -> Dict[str, Any]:
        """Format observation of the default (single-session) env via format_observation."""
        return self.format_observation(obs, self.env)

    def _get_default_config(self) -> Dict[str, Any]:
        """
        Get default configuration from adapter.

        Returns ``{}`` if the adapter has no ``get_default_config``, or if that
        call raises AttributeError.
        """
        try:
            return self.adapter.get_default_config()
        except AttributeError:
            # Fallback for adapters that don't implement get_default_config
            return {}

    # ===== SESSION-AWARE CONTROL PLANE ENDPOINTS =====
    # These provide session-specific control plane data via HTTP endpoints
    # instead of global MCP resources, enabling proper multi-session support.

    @control_plane_endpoint("/control/reward")
    def get_reward_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Get current reward information for this session."""
        control_plane = self._get_session_control_plane_from_data(session_data)
        return {
            "reward": control_plane.get("reward", 0.0),
            "step_count": control_plane.get("step_count", 0),
        }

    @control_plane_endpoint("/control/status")
    def get_status_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Get current episode status for this session."""
        control_plane = self._get_session_control_plane_from_data(session_data)
        return {
            "terminated": control_plane.get("terminated", False),
            "truncated": control_plane.get("truncated", False),
            "step_count": control_plane.get("step_count", 0),
            "total_reward": control_plane.get("total_reward", 0.0),
        }

    @control_plane_endpoint("/control/info")
    def get_info_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Get current environment info for this session."""
        control_plane = self._get_session_control_plane_from_data(session_data)
        return control_plane.get("info", {})

    @control_plane_endpoint("/control/initial_state")
    def get_initial_state_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Get the formatted current observation for this session (initial state before any step)."""
        env = session_data.get("env")
        obs = session_data.get("obs")

        # Compare to None: an env object may define __len__/__bool__ and be falsy.
        if env is not None and obs is not None:
            try:
                return self.format_observation(obs, env)
            except Exception as e:
                logger.exception(f"❌ Error in format_observation: {e}")
                return {
                    "error": f"Failed to format observation: {str(e)}",
                    "observation_type": str(type(obs)),
                    "session_id": session_data.get("session_id", "unknown"),
                }

        # Fallback if session data is not available
        return {
            "observation": "session_not_initialized",
            "session_id": session_data.get("session_id", "unknown"),
        }

    def _get_session_control_plane_from_data(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract control plane state from session data, or a fresh all-zero default."""
        return session_data.get("session_data", {}).get(
            "control_plane",
            {
                "reward": 0.0,
                "terminated": False,
                "truncated": False,
                "info": {},
                "step_count": 0,
                "total_reward": 0.0,
            },
        )

    # ===== SUBCLASS HOOKS =====

    @abstractmethod
    def _register_tools(self):
        """
        Register domain-specific MCP tools.

        Subclasses must implement this method to register their specific tools
        using the @self.mcp.tool() decorator pattern.

        IMPORTANT: Tools should only return data plane information (observations).
        Control plane information (rewards, termination) is served by the
        session-aware HTTP endpoints declared with @control_plane_endpoint.
        """
        pass

    def format_observation(self, obs: Any, env: Any) -> Dict[str, Any]:
        """
        Format observation for MCP response.

        Args:
            obs: Raw observation from environment
            env: Environment instance

        Returns:
            Formatted observation dictionary (DATA PLANE ONLY)

        Implementation Note:
            You can use self._to_json_serializable(obs) as a starting point
            for most standard serialization needs.
        """
        serialized_obs = self._to_json_serializable(obs)

        # If it's already a dict, return as-is, otherwise wrap it
        if isinstance(serialized_obs, dict):
            return serialized_obs
        return {"observation": serialized_obs}

    def run(self, transport: str = "streamable-http", **kwargs):
        """
        Run the unified MCP-Gym server.

        Args:
            transport: MCP transport protocol ("stdio", "sse", "streamable-http")
            **kwargs: Additional arguments passed to FastMCP.run()
        """
        import sys

        # With the stdio transport, stdout carries the JSON-RPC stream, so the
        # banner must go to stderr to avoid corrupting the protocol.
        out = sys.stderr if transport == "stdio" else sys.stdout

        print(f"🚀 {self.mcp.name} MCP-Gym Server Starting...", file=out)
        print(f"📡 Transport: {transport}", file=out)
        print("🎯 MCP Pattern: HTTP endpoints for control plane, tools for data plane", file=out)
        print("🔗 Session-aware control plane endpoints:", file=out)

        # List registered control plane endpoints
        for endpoint_name, endpoint_func in self._control_plane_endpoints.items():
            print(f" - {endpoint_name}: {endpoint_func._control_plane_path}", file=out)

        if not self._control_plane_endpoints:
            print(" - No control plane endpoints registered", file=out)

        print(file=out)

        # Run the unified server
        self.mcp.run(transport=transport, **kwargs)

    def _to_json_serializable(self, obj: Any) -> Any:
        """Convert any object to JSON-serializable format.

        Handles:
        - primitives and None;
        - datetimes/dates;
        - enums;
        - Pydantic models and dataclass instances;
        - dicts, lists, tuples, sets and frozensets;
        - array-likes exposing ``tolist()`` (e.g. numpy arrays and scalars).

        Other objects are converted from their public ``__dict__`` entries, or
        finally stringified. This is a utility method that can be used by
        format_observation implementations.
        """
        from pydantic import BaseModel
        import dataclasses
        from datetime import datetime, date
        from enum import Enum

        # Handle None and primitive types
        if obj is None or isinstance(obj, (str, int, float, bool)):
            return obj

        # Classes themselves (including dataclass types) are not data; stringify.
        if isinstance(obj, type):
            return str(obj)

        # Handle datetime objects
        if isinstance(obj, (datetime, date)):
            return obj.isoformat()

        # Handle enums; the value itself may be non-primitive (e.g. a tuple)
        if isinstance(obj, Enum):
            return self._to_json_serializable(obj.value)

        # Handle Pydantic models; model_dump() keeps datetimes/enums as Python
        # objects, so recurse into the result.
        if isinstance(obj, BaseModel):
            return self._to_json_serializable(obj.model_dump())

        # Handle dataclass instances (asdict() can also return non-JSON leaves)
        if dataclasses.is_dataclass(obj):
            return self._to_json_serializable(dataclasses.asdict(obj))

        # Handle dictionaries
        if isinstance(obj, dict):
            return {k: self._to_json_serializable(v) for k, v in obj.items()}

        # Handle sequences and sets (converted to lists)
        if isinstance(obj, (list, tuple, set, frozenset)):
            return [self._to_json_serializable(item) for item in obj]

        # Array-likes (numpy arrays/scalars, array.array, memoryview) expose tolist()
        tolist = getattr(obj, "tolist", None)
        if callable(tolist):
            return self._to_json_serializable(tolist())

        # Handle objects with __dict__ (fallback), skipping private attributes
        if hasattr(obj, "__dict__"):
            result = {}
            for key, value in obj.__dict__.items():
                if not key.startswith("_"):
                    try:
                        result[key] = self._to_json_serializable(value)
                    except Exception:
                        # If conversion fails, store as string
                        result[key] = str(value)
            return result

        # Final fallback - convert to string
        return str(obj)
|