eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,637 @@
1
+ """
2
+ MCP-Gym Framework - North Star Implementation
3
+
4
+ This module provides the core McpGym base class that implements the north star vision
5
+ for universal RL environment integration via MCP protocol.
6
+
7
+ Key Features:
8
+ - Unified MCP server with FastMCP integration
9
+ - Simple tool registration with @self.mcp.tool() decorator
10
+ - Clean separation between data plane (MCP tool calls) and control plane (custom endpoints)
11
+ - Compatible with CondaServerProcessManager
12
+ - Session-aware control plane endpoints via @control_plane_endpoint decorator
13
+ """
14
+
15
+ import os
16
+ import hashlib
17
+ import threading
18
+ import inspect
19
+ import json
20
+ import logging
21
+ from abc import ABC, abstractmethod
22
+ from typing import Any, Callable, Dict, Optional, Tuple
23
+
24
+ from mcp.server.fastmcp import Context, FastMCP
25
+ from starlette.requests import Request
26
+ from starlette.responses import JSONResponse
27
+
28
+ from .adapter import EnvironmentAdapter
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
def control_plane_endpoint(path: str) -> Callable:
    """
    Mark a method as a session-aware control plane endpoint.

    The decorator does not wrap the function; it only tags it with two
    attributes (``_is_control_plane_endpoint`` and ``_control_plane_path``)
    and returns the very same function object.

    Control plane endpoints expose rewards, termination status, and other
    metadata without adding entries to the tool namespace used by LLMs.

    Args:
        path: URL path for the endpoint (e.g., "/control/reward")

    Returns:
        A decorator that tags the function and returns it unchanged.

    Example:
        @control_plane_endpoint("/control/reward")
        def get_reward(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
            control_plane = session_data.get("session_data", {}).get("control_plane", {})
            return {
                "reward": control_plane.get("reward", 0.0),
                "step_count": control_plane.get("step_count", 0)
            }
    """

    def decorator(func: Callable) -> Callable:
        # Attach the marker and the route path as plain function attributes.
        markers = (
            ("_is_control_plane_endpoint", True),
            ("_control_plane_path", path),
        )
        for attribute, value in markers:
            setattr(func, attribute, value)
        return func

    return decorator
59
+
60
+
61
class McpGym(ABC):
    """
    Base class for MCP-Gym environments implementing the north star vision.

    This class provides the universal adapter pattern for RL environments,
    bridging training infrastructure, production MCP standards, and high-quality
    environments through a clean, standardized interface.

    Key Design Principles:
    - Data Plane: JSON tool calls/responses via MCP (state transitions/actions)
    - Control Plane: Rewards/termination signals via session-aware HTTP endpoints
    - Environment Implementation: Single-process MCP server per environment

    Diagnostics are emitted through the module ``logger`` rather than
    ``print`` so that nothing is written to stdout, which the "stdio"
    transport accepted by ``run()`` uses for protocol traffic.
    """

    def __init__(self, server_name: str, adapter: EnvironmentAdapter, seed: Optional[int] = None):
        """
        Initialize the MCP-Gym environment.

        Args:
            server_name: Name for the MCP server
            adapter: Environment adapter instance
            seed: Optional seed for reproducible environments
        """
        self.adapter = adapter

        # Create FastMCP server; the port comes from $PORT (default 8000).
        self.mcp = FastMCP(
            server_name,
            host="0.0.0.0",
            port=int(os.environ.get("PORT", 8000)),
        )

        # Multi-session support
        # session_id -> {"env": env, "obs": obs, "session_data": data, "session_id": id}
        self.sessions = {}
        # Plain (non-reentrant) lock: never call a method that acquires it
        # while already holding it.
        self.session_lock = threading.Lock()

        # Control plane endpoints dictionary (method name -> bound method)
        self._control_plane_endpoints: Dict[str, Callable] = {}

        # Control plane state for the single-session (backward compatible) path
        self.control_plane_state = self._default_control_plane()

        # Create the default environment, seeded if a seed was provided
        self.env, self.obs, _info = self._new_env(seed=seed)

        # Register tools and control plane endpoints
        self._register_tools()
        self._discover_and_register_control_plane_endpoints()

    @staticmethod
    def _default_control_plane() -> Dict[str, Any]:
        """Return a fresh control plane state dictionary with zeroed values."""
        return {
            "reward": 0.0,
            "terminated": False,
            "truncated": False,
            "info": {},
            "step_count": 0,
            "total_reward": 0.0,
        }

    @staticmethod
    def _client_info_with_extra(ctx: Any) -> Optional[Tuple[Any, Dict[str, Any]]]:
        """
        Return ``(client_info, extra_data)`` from an MCP context, if present.

        Walks ``ctx.session.client_params.clientInfo._extra`` and returns
        ``None`` as soon as a link is missing, ``clientInfo`` is falsy, or
        ``_extra`` is not a non-empty dict.
        """
        session = getattr(ctx, "session", None)
        client_params = getattr(session, "client_params", None)
        client_info = getattr(client_params, "clientInfo", None)
        if not client_info:
            return None

        extra_data = getattr(client_info, "_extra", None)
        if not extra_data or not isinstance(extra_data, dict):
            return None

        return client_info, extra_data

    def _get_session_id(self, ctx: Context) -> str:
        """
        Derive a session ID from the MCP context.

        When the client supplied extra data, the ID is the MD5 hex digest of a
        key-sorted JSON document built from ``seed``, ``config``,
        ``dataset_row_id``, ``model_id`` and the client's name and version, so
        the same client parameters always map to the same session. Otherwise
        an ID based on ``id(ctx)`` is returned, which is only as stable as the
        context object itself.

        Args:
            ctx: MCP request context

        Returns:
            Session identifier string
        """
        found = self._client_info_with_extra(ctx)

        if found is not None:
            client_info, extra_data = found
            seed_value = extra_data.get("seed")

            stable_data = {
                "seed": seed_value,
                "config": extra_data.get("config", {}),
                "dataset_row_id": extra_data.get("dataset_row_id"),
                "model_id": extra_data.get("model_id"),
                "name": client_info.name,
                "version": client_info.version,
            }

            # MD5 is used purely as a stable fingerprint, not for security.
            stable_str = json.dumps(stable_data, sort_keys=True)
            session_id = hashlib.md5(stable_str.encode()).hexdigest()
            logger.debug("🎯 Generated stable session_id: %s for seed: %s", session_id, seed_value)
            return session_id

        # Fallback for testing or other scenarios without client extra data
        session_id = f"gym_{id(ctx)}"
        logger.debug("🎯 Generated fallback session_id: %s", session_id)
        return session_id

    def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]:
        """
        Get or create session data for the given context.

        A new session gets its own environment, created with the seed found in
        the client's extra data (if any).

        Args:
            ctx: MCP request context

        Returns:
            The session dictionary with keys "env", "obs", "session_data"
            and "session_id".
        """
        session_id = self._get_session_id(ctx)

        with self.session_lock:
            if session_id not in self.sessions:
                seed = None
                # Copy so client overrides never mutate the adapter's defaults.
                config = dict(self._get_default_config() or {})

                found = self._client_info_with_extra(ctx)
                if found is not None:
                    _client_info, extra_data = found
                    seed = extra_data.get("seed")
                    if "config" in extra_data:
                        config.update(extra_data["config"])

                # NOTE(review): the merged config is only logged; _new_env()
                # takes just the seed and rebuilds the config from the adapter
                # defaults. Confirm whether client overrides should be applied.
                logger.debug(
                    "🔍 _get_or_create_session: creating session %s with seed %s, config %s",
                    session_id,
                    seed,
                    config,
                )

                env, obs, _info = self._new_env(seed=seed)

                self.sessions[session_id] = {
                    "env": env,
                    "obs": obs,
                    "session_data": {},  # Subclasses can store additional data here
                    "session_id": session_id,
                }
                logger.debug("🎮 Created new session %s... with seed %s", session_id[:16], seed)
            else:
                logger.debug("🔍 _get_or_create_session: Returning existing session %s", session_id)

            return self.sessions[session_id]

    def _make_endpoint_handler(self, func: Callable) -> Callable:
        """
        Build the HTTP handler that serves one control plane endpoint.

        The handler reads the session from the ``mcp-session-id`` header and
        calls ``func(session_data=...)``, awaiting it when it is a coroutine
        function. It answers 400 when the header is missing, 404 when the
        session is unknown, and 500 when anything raises. The one exception to
        the 404 rule is ``get_initial_state_endpoint``, for which a missing
        session is created on demand with an unseeded environment.

        Args:
            func: Bound endpoint method to expose

        Returns:
            Async request handler suitable for ``FastMCP.custom_route``
        """

        async def endpoint_handler(request: Request) -> JSONResponse:
            try:
                session_id = request.headers.get("mcp-session-id")
                if not session_id:
                    return JSONResponse(
                        {"error": "Missing mcp-session-id header"},
                        status_code=400,
                    )

                with self.session_lock:
                    session_data = self.sessions.get(session_id)
                    if not session_data:
                        if func.__name__ != "get_initial_state_endpoint":
                            return JSONResponse(
                                {"error": f"Session {session_id} not found"},
                                status_code=404,
                            )

                        # No seed is available on this path, so the
                        # environment is created unseeded.
                        env, obs, _info = self._new_env(seed=None)
                        session_data = {
                            "env": env,
                            "obs": obs,
                            "session_data": {},  # Subclasses can store additional data here
                            "session_id": session_id,
                        }
                        self.sessions[session_id] = session_data

                # Call the endpoint outside the lock so endpoint code may use
                # lock-taking helpers.
                if inspect.iscoroutinefunction(func):
                    result = await func(session_data=session_data)
                else:
                    result = func(session_data=session_data)

                return JSONResponse(result)

            except Exception as e:
                # Boundary handler: log the traceback, report a 500.
                logger.exception("Control plane endpoint %s failed", func.__name__)
                return JSONResponse({"error": str(e)}, status_code=500)

        return endpoint_handler

    def _discover_and_register_control_plane_endpoints(self):
        """
        Discover and register control plane endpoints on the subclass instance.

        Scans bound methods for the marker set by @control_plane_endpoint and
        registers each one as a GET custom route on the FastMCP server.
        """
        # 1. Discover control plane endpoints on the subclass instance
        discovered_endpoints = {
            method.__name__: method
            for _name, method in inspect.getmembers(self, predicate=inspect.ismethod)
            if hasattr(method, "_is_control_plane_endpoint")
        }
        self._control_plane_endpoints = discovered_endpoints

        # 2. Register the discovered endpoints as FastMCP custom routes
        for endpoint_func in discovered_endpoints.values():
            handler = self._make_endpoint_handler(endpoint_func)
            self.mcp.custom_route(endpoint_func._control_plane_path, methods=["GET"])(handler)

        if discovered_endpoints:
            logger.info(f"✅ Registered {len(discovered_endpoints)} session-aware control plane endpoints")
            for name, endpoint in discovered_endpoints.items():
                logger.info(f"  - {name}: {endpoint._control_plane_path}")
        else:
            logger.info("⚠️  No session-aware control plane endpoints discovered")

    def _update_control_plane(self, reward: float, terminated: bool, truncated: bool, info: Dict[str, Any]):
        """
        Update control plane state after environment step (single session).

        Args:
            reward: Reward from environment step
            terminated: Whether episode terminated
            truncated: Whether episode truncated
            info: Info dictionary from environment
        """
        state = self.control_plane_state
        state["reward"] = reward
        state["terminated"] = terminated
        state["truncated"] = truncated
        state["info"] = info
        state["step_count"] += 1
        state["total_reward"] += reward

        logger.debug(
            "🎛️  Control plane updated: reward=%s, terminated=%s, step=%s",
            reward,
            terminated,
            state["step_count"],
        )

    def _get_or_create_session_control_plane(self, session_id: str) -> Dict[str, Any]:
        """
        Get or create control plane state for a specific session.

        Acquires ``self.session_lock``; do not call while holding it.

        Args:
            session_id: Session identifier

        Returns:
            The session's live control plane dict, or an empty dict when the
            session does not exist.
        """
        with self.session_lock:
            if session_id not in self.sessions:
                return {}

            store = self.sessions[session_id]["session_data"]
            if "control_plane" not in store:
                store["control_plane"] = self._default_control_plane()

            return store["control_plane"]

    def _update_session_control_plane(
        self,
        session_id: str,
        reward: float,
        terminated: bool,
        truncated: bool,
        info: Dict[str, Any],
    ):
        """
        Update control plane state for a specific session.

        Args:
            session_id: Session identifier
            reward: Reward from environment step
            terminated: Whether episode terminated
            truncated: Whether episode truncated
            info: Info dictionary from environment

        Raises:
            KeyError: If the session does not exist.
        """
        control_plane = self._get_or_create_session_control_plane(session_id)
        if not control_plane:
            # An existing control plane is never empty, so an empty dict means
            # the session is unknown. Fail clearly instead of on a missing key.
            raise KeyError(f"Session {session_id} not found")

        control_plane["reward"] = reward
        control_plane["terminated"] = terminated
        control_plane["truncated"] = truncated
        control_plane["info"] = info
        control_plane["step_count"] += 1
        control_plane["total_reward"] += reward

        logger.debug(
            "🎛️  Session %s... control plane: reward=%s, terminated=%s, step=%s",
            session_id[:16],
            reward,
            terminated,
            control_plane["step_count"],
        )

    def get_control_plane_state(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Get control plane state for a specific session (for rollout system).

        Args:
            session_id: Session identifier

        Returns:
            A shallow copy of the session's control plane state, or None when
            the session does not exist.
        """
        with self.session_lock:
            known = session_id in self.sessions
        if not known:
            return None

        # Called after releasing the lock: the helper acquires the same
        # non-reentrant lock, so calling it while holding it would deadlock.
        return self._get_or_create_session_control_plane(session_id).copy()

    def _execute_environment_step(self, action_int: int) -> Dict[str, Any]:
        """
        Execute environment step and update control plane (single session).

        Args:
            action_int: Parsed action integer

        Returns:
            Data plane response (observation only, no rewards)
        """
        obs, reward, terminated, truncated, info = self.adapter.step_environment(self.env, action_int)

        # Update global observation state
        self.obs = obs

        # Update control plane (separate from data plane)
        self._update_control_plane(reward, terminated, truncated, info)

        # Return ONLY data plane information (no rewards/termination)
        return self._render(obs)

    def _execute_session_environment_step(self, session_id: str, action: Any) -> Dict[str, Any]:
        """
        Execute environment step for a specific session and update control plane.

        Args:
            session_id: Session identifier
            action: Action forwarded to the adapter's step_environment

        Returns:
            Data plane response (observation only, no rewards)

        Raises:
            KeyError: If the session does not exist.
        """
        session_data = self.sessions[session_id]
        env = session_data["env"]

        obs, reward, terminated, truncated, info = self.adapter.step_environment(env, action)

        # Update session observation state
        session_data["obs"] = obs

        # Update control plane for this session
        self._update_session_control_plane(session_id, reward, terminated, truncated, info)

        # Return ONLY data plane information (no rewards/termination)
        return self.format_observation(obs, env)

    def _new_env(self, seed: Optional[int] = None) -> Tuple[Any, Any, Dict]:
        """
        Create new environment and return initial state.

        Args:
            seed: Optional seed; any non-None value (including 0) selects the
                adapter's seeded construction path.

        Returns:
            Tuple of (environment, initial observation, info)
        """
        config = self.adapter.get_default_config()

        # Compare against None so that seed 0 is honoured as a real seed.
        if seed is not None:
            env, obs, info = self.adapter.create_environment_with_seed(config, seed=seed)
        else:
            env = self.adapter.create_environment(config)
            obs, info = self.adapter.reset_environment(env, seed=seed)

        return env, obs, info

    def _render(self, obs) -> Dict[str, Any]:
        """Format observation for the default environment using format_observation."""
        return self.format_observation(obs, self.env)

    def _get_default_config(self) -> Dict[str, Any]:
        """
        Get default configuration from adapter.

        Returns an empty dict when the adapter raises AttributeError, e.g.
        because it does not implement get_default_config.
        """
        try:
            return self.adapter.get_default_config()
        except AttributeError:
            return {}

    # ===== SESSION-AWARE CONTROL PLANE ENDPOINTS =====
    # These provide session-specific control plane data via HTTP endpoints
    # instead of global MCP resources, enabling proper multi-session support.

    @control_plane_endpoint("/control/reward")
    def get_reward_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Get current reward information for this session."""
        control_plane = self._get_session_control_plane_from_data(session_data)
        return {
            "reward": control_plane.get("reward", 0.0),
            "step_count": control_plane.get("step_count", 0),
        }

    @control_plane_endpoint("/control/status")
    def get_status_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Get current episode status for this session."""
        control_plane = self._get_session_control_plane_from_data(session_data)
        return {
            "terminated": control_plane.get("terminated", False),
            "truncated": control_plane.get("truncated", False),
            "step_count": control_plane.get("step_count", 0),
            "total_reward": control_plane.get("total_reward", 0.0),
        }

    @control_plane_endpoint("/control/info")
    def get_info_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Get current environment info for this session."""
        control_plane = self._get_session_control_plane_from_data(session_data)
        return control_plane.get("info", {})

    @control_plane_endpoint("/control/initial_state")
    def get_initial_state_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Get initial state for this session.

        Returns the formatted stored observation, an error payload when
        formatting fails, or a "session_not_initialized" placeholder when the
        session has no environment or observation.
        """
        env = session_data.get("env")
        obs = session_data.get("obs")

        # Identity checks: a valid env object may define __len__/__bool__ and
        # evaluate as falsy.
        if env is None or obs is None:
            return {
                "observation": "session_not_initialized",
                "session_id": session_data.get("session_id", "unknown"),
            }

        try:
            return self.format_observation(obs, env)
        except Exception as e:
            logger.error(f"❌ Error in format_observation: {e}")
            return {
                "error": f"Failed to format observation: {str(e)}",
                "observation_type": str(type(obs)),
                "session_id": session_data.get("session_id", "unknown"),
            }

    def _get_session_control_plane_from_data(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Extract control plane state from session data.

        Returns a fresh default state (not stored anywhere) when the session
        has no control plane yet.
        """
        store = session_data.get("session_data", {})
        if "control_plane" in store:
            return store["control_plane"]
        return self._default_control_plane()

    @abstractmethod
    def _register_tools(self):
        """
        Register domain-specific MCP tools.

        Subclasses must implement this method to register their specific tools
        using the @self.mcp.tool() decorator pattern.

        IMPORTANT: Tools should only return data plane information (observations).
        Control plane information (rewards, termination) is served by the
        control plane endpoints.
        """
        pass

    def format_observation(self, obs: Any, env: Any) -> Dict[str, Any]:
        """
        Format observation for MCP response.

        Args:
            obs: Raw observation from environment
            env: Environment instance (unused by this default implementation)

        Returns:
            Formatted observation dictionary (DATA PLANE ONLY). A serialized
            dict is returned as-is; anything else is wrapped as
            {"observation": value}.

        Implementation Note:
            Overrides can use self._to_json_serializable(obs) as a starting
            point for most standard serialization needs.
        """
        serialized_obs = self._to_json_serializable(obs)

        if isinstance(serialized_obs, dict):
            return serialized_obs
        return {"observation": serialized_obs}

    def run(self, transport: str = "streamable-http", **kwargs):
        """
        Run the unified MCP-Gym server.

        The startup banner is logged at INFO level instead of printed, so
        stdout stays clean for the "stdio" transport.

        Args:
            transport: MCP transport protocol ("stdio", "sse", "streamable-http")
            **kwargs: Additional arguments passed to FastMCP.run()
        """
        logger.info("🚀 %s MCP-Gym Server Starting...", self.mcp.name)
        logger.info("📡 Transport: %s", transport)
        logger.info("🎯 MCP Pattern: HTTP endpoints for control plane, tools for data plane")
        logger.info("🔗 Session-aware control plane endpoints:")

        # List registered control plane endpoints
        for endpoint_name, endpoint_func in self._control_plane_endpoints.items():
            logger.info("  - %s: %s", endpoint_name, endpoint_func._control_plane_path)

        if not self._control_plane_endpoints:
            logger.info("  - No control plane endpoints registered")

        # Run the unified server
        self.mcp.run(transport=transport, **kwargs)

    def _to_json_serializable(self, obj: Any) -> Any:
        """Convert any object to JSON-serializable format.

        Handles Pydantic models, dataclass instances, lists, tuples, sets,
        dicts, datetimes, enums, and primitive types. Objects exposing a
        ``__dict__`` are converted from their public attributes; anything else
        becomes ``str(obj)``. Dictionary keys are left unchanged.

        Args:
            obj: Object to convert

        Returns:
            A structure built from dicts, lists, and primitives.
        """
        from pydantic import BaseModel
        import dataclasses
        from datetime import datetime, date
        from enum import Enum

        # None and primitive types pass through untouched
        if obj is None or isinstance(obj, (str, int, float, bool)):
            return obj

        if isinstance(obj, (datetime, date)):
            return obj.isoformat()

        if isinstance(obj, Enum):
            return obj.value

        # Pydantic models: recurse because model_dump() may still contain
        # datetimes, enums, or sets.
        if isinstance(obj, BaseModel):
            return self._to_json_serializable(obj.model_dump())

        # Dataclass *instances* only: is_dataclass() is also true for the
        # class object, on which asdict() raises TypeError.
        if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
            return self._to_json_serializable(dataclasses.asdict(obj))

        if isinstance(obj, dict):
            return {k: self._to_json_serializable(v) for k, v in obj.items()}

        # Sequences and sets all become lists
        if isinstance(obj, (list, tuple, set, frozenset)):
            return [self._to_json_serializable(item) for item in obj]

        # Fallback: public attributes of arbitrary objects
        if hasattr(obj, "__dict__"):
            result = {}
            for key, value in obj.__dict__.items():
                if key.startswith("_"):  # Skip private attributes
                    continue
                try:
                    result[key] = self._to_json_serializable(value)
                except Exception:
                    # If conversion fails, store as string
                    result[key] = str(value)
            return result

        # Final fallback - convert to string
        return str(obj)