eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,458 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MCP Simulation Server Framework
|
|
3
|
+
|
|
4
|
+
This framework enforces the correct separation between production and simulation servers.
|
|
5
|
+
It ensures that:
|
|
6
|
+
1. No session management tools are exposed to models
|
|
7
|
+
2. Session initialization happens via client_info (MCP spec)
|
|
8
|
+
3. Only domain game tools are exposed
|
|
9
|
+
4. Simulation logic is handled internally using proper MCP session management
|
|
10
|
+
|
|
11
|
+
Usage:
|
|
12
|
+
class MyGameSimulation(SimulationServerBase):
|
|
13
|
+
def create_environment(self, config): ...
|
|
14
|
+
def reset_environment(self, env, seed): ...
|
|
15
|
+
# etc.
|
|
16
|
+
|
|
17
|
+
server = MyGameSimulation("MyGame")
|
|
18
|
+
server.run()
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
import asyncio
|
|
22
|
+
import contextlib
|
|
23
|
+
import functools
|
|
24
|
+
import inspect
|
|
25
|
+
import json
|
|
26
|
+
import logging
|
|
27
|
+
import threading
|
|
28
|
+
import time
|
|
29
|
+
import uuid
|
|
30
|
+
from abc import ABC, abstractmethod
|
|
31
|
+
from collections.abc import AsyncIterator
|
|
32
|
+
from contextlib import asynccontextmanager
|
|
33
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
34
|
+
|
|
35
|
+
import uvicorn
|
|
36
|
+
from mcp.server.lowlevel import Server
|
|
37
|
+
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
|
38
|
+
from starlette.applications import Starlette
|
|
39
|
+
from starlette.routing import Mount
|
|
40
|
+
from starlette.types import Receive, Scope, Send
|
|
41
|
+
|
|
42
|
+
# Module logger. Its level is intentionally left unset (NOTSET) so that the
# host application's logging configuration decides what gets emitted; forcing
# DEBUG here would push debug records to the root handlers even when the
# application configured a higher level. A client can still change it at
# runtime through the server's set_logging_level handler.
logger = logging.getLogger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ToolMismatchError(Exception):
    """Signals that the simulation server's tool set differs from the production server's."""
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class SignatureMismatchError(Exception):
    """Signals that one tool's signature differs from its production counterpart."""
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def simulation_tool(func: Callable) -> Callable:
    """
    Tag *func* so that SimulationServerBase exposes it as an MCP tool.

    No wrapper is created: the very same function object is returned, carrying
    an ``_is_simulation_tool = True`` attribute that tool discovery looks for.
    """
    setattr(func, "_is_simulation_tool", True)
    return func
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def simulation_resource(uri_pattern: str) -> Callable:
    """
    Build a decorator that tags a method as an MCP resource of a simulation server.

    The decorated function is returned unchanged, with two attributes added:
    ``_is_resource = True`` and ``_resource_uri = uri_pattern``. Resource discovery
    in SimulationServerBase uses these to register and route resource reads, and
    the registered handler passes the per-session state to the method.

    Args:
        uri_pattern: URI under which the resource is listed and matched.
    """

    def _mark_as_resource(target: Callable) -> Callable:
        setattr(target, "_is_resource", True)
        setattr(target, "_resource_uri", uri_pattern)
        return target

    return _mark_as_resource
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class SimulationServerBase(ABC):
    """
    Base class for simulation MCP servers using proper StreamableHTTPSessionManager.

    This framework enforces correct separation by:
    - Using StreamableHTTPSessionManager for proper session management
    - Extracting seeds from client_info during session initialization
    - Only exposing domain-specific game tools
    - Preventing session management tool pollution
    - Supporting MCP resources for initial state following proper MCP patterns

    Subclasses implement the abstract environment hooks below and decorate the
    methods they want exposed with ``@simulation_tool`` or ``@simulation_resource``.
    Those methods are called with keyword arguments ``ctx`` (the MCP request
    context) and ``session_state`` (the per-session dict built by
    ``_get_or_create_session_env``). Tools additionally receive the
    client-supplied tool arguments as keyword arguments.
    """

    # MCP log levels use RFC 5424 (syslog) names. Python's logging module has no
    # NOTICE/ALERT/EMERGENCY levels, so those map onto the closest existing level.
    _MCP_TO_PYTHON_LOG_LEVEL: Dict[str, int] = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "notice": logging.INFO,
        "warning": logging.WARNING,
        "error": logging.ERROR,
        "critical": logging.CRITICAL,
        "alert": logging.CRITICAL,
        "emergency": logging.CRITICAL,
    }

    def __init__(
        self,
        server_name: str,
        production_server_app=None,
    ):
        """
        Initialize simulation server framework.

        Args:
            server_name: Name for the MCP server.
            production_server_app: The production server app instance, kept for
                reference (optional). NOTE(review): no code in this class compares
                tools against it yet.
        """
        self.server_name = server_name
        self.production_server_app = production_server_app
        self._domain_tools: Dict[str, Callable] = {}
        self._domain_resources: Dict[str, Callable] = {}

        # Create low-level MCP server
        self.app = Server(server_name)

        # Per-session simulation state keyed by session id, guarded by session_lock.
        self.session_environments: Dict[str, Dict[str, Any]] = {}
        self.session_lock = threading.Lock()

        # Discover and register domain tools and resources
        self._discover_and_register_tools()
        self._discover_and_register_resources()
        self._register_session_handlers()

    # ------------------------------------------------------------------
    # Session handling
    # ------------------------------------------------------------------

    @staticmethod
    def _extract_client_extra(ctx) -> Tuple[Any, Optional[Dict[str, Any]]]:
        """
        Pull the client info and its extra initialization data from a request context.

        Follows ``ctx.session.client_params.clientInfo._extra``. Any missing or empty
        link in that chain is treated as "no data".

        Returns:
            ``(client_info, extra_data)``. ``client_info`` is the ``clientInfo`` object
            or ``None``. ``extra_data`` is ``_extra`` when it is a non-empty dict,
            otherwise ``None``.
        """
        session = getattr(ctx, "session", None)
        client_params = getattr(session, "client_params", None)
        client_info = getattr(client_params, "clientInfo", None)
        if not client_info:
            return client_info, None
        extra_data = getattr(client_info, "_extra", None)
        if extra_data and isinstance(extra_data, dict):
            return client_info, extra_data
        return client_info, None

    def _get_session_id_from_context(self, ctx) -> str:
        """
        Derive the simulation session ID for the current MCP request.

        If the client sent extra initialization data (``clientInfo._extra``), the ID
        is the MD5 hex digest of the seed, config, client name and client version.
        Every request carrying the same values therefore maps to the same simulation
        session. Otherwise the ID is ``sim_<id>`` of the transport session object, or
        of ``ctx`` itself when it has no session.
        """
        import hashlib

        client_info, extra_data = self._extract_client_extra(ctx)
        if extra_data is not None:
            # NOTE(review): clients sending identical seed/config/name/version get the
            # same ID and therefore share one environment — confirm this is intended.
            stable_data = {
                "seed": extra_data.get("seed"),
                "config": extra_data.get("config", {}),
                "name": client_info.name,
                "version": client_info.version,
            }
            # default=str: a non-JSON-serializable config value must not abort the request.
            stable_str = json.dumps(stable_data, sort_keys=True, default=str)
            session_id = hashlib.md5(stable_str.encode()).hexdigest()
            logger.debug(f"Generated stable session_id from client_info: {session_id}")
            return session_id

        # Fallback for testing or other scenarios. The request context is presumably
        # rebuilt for each request (NOTE(review): confirm against mcp's Server). Keying
        # on it would give every call a brand-new environment, so prefer the
        # longer-lived transport session object when one is present.
        session = getattr(ctx, "session", None)
        anchor = session if session is not None else ctx
        session_id = f"sim_{id(anchor)}"
        logger.debug(f"Generated fallback session_id: {session_id}")
        return session_id

    def _create_session_env(self, ctx, session_id: str) -> Dict[str, Any]:
        """Build a fresh session-state dict (environment, config, seed, counters) for *session_id*."""
        # Copy the defaults so per-session overrides never leak into a dict that a
        # subclass may hand out from get_default_config() on every call.
        default_config = self.get_default_config()
        config: Dict[str, Any] = dict(default_config) if default_config is not None else {}
        seed = None

        _, extra_data = self._extract_client_extra(ctx)
        if extra_data is not None:
            seed = extra_data.get("seed")
            logger.info(f"🎯 Extracted seed from client_info: {seed}")
            # A missing/None/empty "config" is skipped instead of failing in update(None).
            extra_config = extra_data.get("config")
            if extra_config:
                config.update(extra_config)

        # Environments such as FrozenLake need the seed at construction time, so use
        # create_environment_with_seed when the subclass provides it.
        create_with_seed = getattr(self, "create_environment_with_seed", None)
        if callable(create_with_seed):
            env, obs, _ = create_with_seed(config, seed=seed)
        else:
            env = self.create_environment(config)
            obs, _ = self.reset_environment(env, seed=seed)

        now = time.time()
        return {
            "env": env,
            "config": config,
            "seed": seed,
            "created_at": now,
            "initial_observation": self.format_observation(obs),
            "session_id": session_id,
            "steps": 0,
            "total_reward": 0.0,
            "last_used": now,
        }

    def _get_or_create_session_env(self, ctx) -> Dict[str, Any]:
        """
        Get or create session environment.

        Returns the cached state dict for this request's session ID. On first use, a
        new environment is created (seeded from client_info when present) and cached.
        ``last_used`` is refreshed on every call. The live cached dict is returned,
        not a copy.

        NOTE(review): entries are only removed at server shutdown; long-running
        servers accumulate one environment per session ID.
        """
        session_id = self._get_session_id_from_context(ctx)

        with self.session_lock:
            session_state = self.session_environments.get(session_id)
            if session_state is None:
                # Built while holding the lock so concurrent first requests for the same
                # session cannot create two environments.
                session_state = self._create_session_env(ctx, session_id)
                self.session_environments[session_id] = session_state
                logger.info(
                    f"🆕 Simulation session created: {session_id[:16]}... (seed={session_state['seed']})"
                )

            session_state["last_used"] = time.time()
            return session_state

    # ------------------------------------------------------------------
    # Tool / resource discovery
    # ------------------------------------------------------------------

    def _discover_and_register_tools(self):
        """
        Discover and register tools marked with @simulation_tool.

        The ``call_tool`` and ``list_tools`` handlers are registered on ``self.app``
        only when at least one tool was found.
        """
        # 1. Discover tools on the subclass instance
        discovered_tools = {
            method.__name__: method
            for _, method in inspect.getmembers(self, predicate=inspect.ismethod)
            if hasattr(method, "_is_simulation_tool")
        }
        self._domain_tools = discovered_tools

        # 2. Register the discovered tools with the MCP server
        if discovered_tools:

            @self.app.call_tool()
            async def call_tool(name: str, arguments: dict):
                from mcp.types import TextContent

                # Reject unknown tools before creating any session environment.
                tool_func = self._domain_tools.get(name)
                if tool_func is None:
                    raise ValueError(f"Unknown tool: {name}")

                # Get the current request context and this session's state
                ctx = self.app.request_context
                session_state = self._get_or_create_session_env(ctx)
                tool_kwargs = arguments or {}

                # Tools may be either coroutine functions or plain functions
                if inspect.iscoroutinefunction(tool_func):
                    result = await tool_func(ctx=ctx, session_state=session_state, **tool_kwargs)
                else:
                    result = tool_func(ctx=ctx, session_state=session_state, **tool_kwargs)

                # Return list of ContentBlock for low-level server
                result_str = result if isinstance(result, str) else json.dumps(result)
                return [TextContent(type="text", text=result_str)]

            @self.app.list_tools()
            async def list_tools():
                """List all available tools."""
                from mcp.types import Tool

                tools = []
                for tool_name, tool_func in self._domain_tools.items():
                    # Extract docstring as description
                    description = tool_func.__doc__ or f"Execute {tool_name} action"

                    # Basic input schema. NOTE(review): could be derived from the
                    # function signature, but stricter schemas may cause the MCP
                    # server to reject calls it currently accepts.
                    input_schema = {"type": "object", "properties": {}, "required": []}

                    tools.append(
                        Tool(
                            name=tool_name,
                            description=description,
                            inputSchema=input_schema,
                        )
                    )

                return tools

        logger.info(f"✅ Registered {len(discovered_tools)} domain tools")

    def _discover_and_register_resources(self):
        """
        Discover and register resources on the subclass instance.

        The ``read_resource`` and ``list_resources`` handlers are registered on
        ``self.app`` only when at least one ``@simulation_resource`` method was found.
        """
        # 1. Discover resources on the subclass instance
        discovered_resources = {
            method.__name__: method
            for _, method in inspect.getmembers(self, predicate=inspect.ismethod)
            if hasattr(method, "_is_resource")
        }
        self._domain_resources = discovered_resources

        # 2. Register the discovered resources with the MCP server
        if discovered_resources:

            @self.app.read_resource()
            async def read_resource(uri: str):
                # Get the current request context
                ctx = self.app.request_context
                # Convert URI to string once for pattern matching
                uri_str = str(uri)

                # First resource whose pattern equals, or is a suffix of, the URI wins.
                for resource_func in self._domain_resources.values():
                    pattern = resource_func._resource_uri
                    if uri_str != pattern and not uri_str.endswith(pattern):
                        continue

                    session_state = self._get_or_create_session_env(ctx)

                    # Resources may be either coroutine functions or plain functions
                    if inspect.iscoroutinefunction(resource_func):
                        result = await resource_func(ctx=ctx, session_state=session_state)
                    else:
                        result = resource_func(ctx=ctx, session_state=session_state)

                    # Strings are returned as-is; anything else is JSON-encoded.
                    return result if isinstance(result, str) else json.dumps(result)

                raise ValueError(f"Unknown resource: {uri}")

            @self.app.list_resources()
            async def list_resources():
                """List all available resources."""
                from mcp.types import Resource

                resources = []
                for resource_name, resource_func in self._domain_resources.items():
                    # Extract docstring as description
                    description = resource_func.__doc__ or f"Resource {resource_name}"

                    resources.append(
                        Resource(
                            uri=resource_func._resource_uri,
                            name=resource_name,
                            description=description,
                            mimeType="application/json",
                        )
                    )

                return resources

        logger.info(f"✅ Registered {len(discovered_resources)} domain resources")

    # ------------------------------------------------------------------
    # Protocol handlers
    # ------------------------------------------------------------------

    @classmethod
    def _to_python_log_level(cls, level: Any) -> int:
        """
        Map an MCP logging level name to a Python ``logging`` level number.

        The eight MCP level names are accepted case-insensitively. Any other name
        that is an integer constant on the ``logging`` module (e.g. ``"WARN"``) is
        still honoured for backward compatibility.

        Raises:
            ValueError: if ``level`` is not a recognised level name.
        """
        name = str(level).lower()
        mapped = cls._MCP_TO_PYTHON_LOG_LEVEL.get(name)
        if mapped is not None:
            return mapped
        numeric = getattr(logging, name.upper(), None)
        if isinstance(numeric, int):
            return numeric
        raise ValueError(f"Unsupported logging level: {level!r}")

    def _register_session_handlers(self):
        """Register session initialization and cleanup handlers."""

        @self.app.set_logging_level()
        async def set_logging_level(level):
            """Handle logging level requests."""
            logger.setLevel(self._to_python_log_level(level))
            return {}

        # NOTE: The low-level Server doesn't have built-in session lifecycle hooks.
        # client_info is therefore captured lazily on the first request of each
        # session (see _get_or_create_session_env).

    # ------------------------------------------------------------------
    # Abstract methods that subclasses MUST implement
    # ------------------------------------------------------------------

    @abstractmethod
    def create_environment(self, config: Dict[str, Any]) -> Any:
        """Create environment instance."""
        pass

    @abstractmethod
    def reset_environment(self, env: Any, seed: Optional[int] = None) -> Tuple[Any, Dict[str, Any]]:
        """Reset environment to initial state."""
        pass

    @abstractmethod
    def step_environment(self, env: Any, action: Any) -> Tuple[Any, float, bool, bool, Dict[str, Any]]:
        """Execute step in environment."""
        pass

    @abstractmethod
    def close_environment(self, env: Any) -> None:
        """Clean up environment resources."""
        pass

    @abstractmethod
    def parse_action(self, action_str: str) -> Any:
        """Parse action string to environment action."""
        pass

    @abstractmethod
    def format_observation(self, observation: Any) -> Any:
        """Format observation for JSON serialization."""
        pass

    @abstractmethod
    def get_default_config(self) -> Dict[str, Any]:
        """Get default environment configuration."""
        pass

    # ------------------------------------------------------------------
    # Serving
    # ------------------------------------------------------------------

    def run(self, port: int = 8000, host: str = "127.0.0.1", **kwargs):
        """
        Run the simulation server using StreamableHTTPSessionManager.

        Mounts the MCP endpoint at ``/mcp`` on a Starlette app and serves it with
        uvicorn (blocking). All session environments are closed on shutdown.

        Args:
            port: Port to listen on
            host: Host to bind to
            **kwargs: Additional arguments for uvicorn. ``debug`` is forwarded to
                Starlette and ``log_level`` (default ``"info"``) to uvicorn; all
                other keys are passed straight to ``uvicorn.run``.
        """
        print("📡 Starting simulation server with StreamableHTTPSessionManager")
        print(f"🎮 Domain tools: {list(self._domain_tools.keys())}")
        print(f"📦 Domain resources: {list(self._domain_resources.keys())}")
        if self.production_server_app:
            # NOTE(review): nothing in this class actually compares tools with
            # production_server_app (ToolMismatchError / SignatureMismatchError are
            # never raised here), so this message overstates what was checked.
            print("✅ Tool signatures validated against production server.")
        print("🚫 No session management tools exposed (framework enforced)")
        print()

        # Create the session manager with our app
        session_manager = StreamableHTTPSessionManager(
            app=self.app,
        )

        # ASGI handler for streamable HTTP connections
        async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
            await session_manager.handle_request(scope, receive, send)

        @contextlib.asynccontextmanager
        async def lifespan(app: Starlette) -> AsyncIterator[None]:
            """Context manager for managing session manager lifecycle."""
            async with session_manager.run():
                logger.info(f"🚀 {self.server_name} started with StreamableHTTP session manager!")
                try:
                    yield
                finally:
                    logger.info("🧹 Simulation server shutting down...")
                    # Close every session environment, tolerating individual failures.
                    with self.session_lock:
                        for session_id, session_data in self.session_environments.items():
                            env = session_data.get("env")
                            # `is not None` so envs that are falsy (e.g. define __len__) still get closed
                            if env is None:
                                continue
                            try:
                                self.close_environment(env)
                            except Exception as e:
                                logger.warning(f"⚠️ Error closing environment in session {session_id}: {e}")
                        self.session_environments.clear()
                    logger.info("✅ Simulation server shutdown complete")

        # Create an ASGI application using the transport
        starlette_app = Starlette(
            debug=kwargs.get("debug", False),
            routes=[
                Mount("/mcp", app=handle_streamable_http),
            ],
            lifespan=lifespan,
        )

        # Run the server
        uvicorn.run(
            starlette_app,
            host=host,
            port=port,
            log_level=kwargs.get("log_level", "info"),
            **{k: v for k, v in kwargs.items() if k not in ["debug", "log_level"]},
        )
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Any, Dict, List, Optional
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TerminationReason(str, Enum):
    """Why a rollout trajectory stopped.

    Members:
        MAX_STEPS: the step limit was reached.
        CONTROL_PLANE_SIGNAL: the control plane requested termination, for example
            because the environment reached its goal or a failure condition.
        USER_STOP: the simulated user asked to stop.

    Being a ``str`` subclass, each member compares equal to its string value.
    """

    # Step budget exhausted.
    MAX_STEPS = "max_steps"
    # Control plane reported the episode as finished.
    CONTROL_PLANE_SIGNAL = "control_plane_signal"
    # Simulated user ended the conversation.
    USER_STOP = "user_stop"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class MCPToolCall:
    """A single tool invocation that is to be sent to an MCP server."""

    # Name of the MCP tool to invoke.
    tool_name: str
    # Arguments passed to the tool.
    arguments: Dict[str, Any]
    # Optional call identifier; presumably echoes the LLM's tool-call id (verify with callers).
    tool_call_id: Optional[str] = field(default=None)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class DatasetRow:
    """One record read from the dataset JSONL file."""

    # Row identifier.
    id: str
    # Seed associated with this row.
    seed: int
    # System prompt text.
    system_prompt: str
    # User prompt template text.
    user_prompt_template: str
    # Free-form environment context for the row.
    environment_context: Dict[str, Any]
    # Optional user-simulation settings; None when the row provides none.
    user_simulation: Optional[Dict[str, Any]] = field(default=None)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class MCPSession:
    """State for one MCP session against one environment server."""

    # Identifier of this session.
    session_id: str
    # Server base URL.
    base_url: str
    # Seed for the environment, if any.
    seed: Optional[int]
    # Identifier of the model driving the session.
    model_id: str
    # Dataset row this session was created from, if any.
    dataset_row: Optional[DatasetRow] = field(default=None)
    # Whether the session has terminated.
    terminated: bool = field(default=False)
    # Most recent observation seen in this session.
    last_observation: Any = field(default=None)

    # Persistent MCP connection components (opaque handles set by the session owner).
    _exit_stack: Optional[Any] = field(default=None)
    _mcp_session: Optional[Any] = field(default=None)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class Trajectory:
    """Record of one complete rollout."""

    # Session the rollout ran in.
    session: MCPSession
    # Per-step observations, actions and rewards.
    observations: List[Any]
    actions: List[str]
    rewards: List[float]
    # Final status and aggregate statistics.
    terminated: bool
    total_reward: float
    steps: int
    duration: float
    # Control-plane information collected during the rollout.
    control_plane_steps: List[Dict[str, Any]]
    control_plane_summary: Dict[str, Any]
    # Reason the rollout ended (see TerminationReason for the known values).
    termination_reason: str
    # Full message history of the rollout.
    conversation_history: List[Dict[str, Any]]
    # Aggregated LLM token usage; each instance gets its own empty dict by default.
    llm_usage_summary: Dict[str, int] = field(default_factory=dict)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass
class LLMUsageStats:
    """Token counts reported for LLM usage."""

    # Tokens in the prompt.
    prompt_tokens: int
    # Tokens in the completion.
    completion_tokens: int
    # Total tokens as reported (not computed here).
    total_tokens: int
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Reward Kit MCP Agent Package
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Literal, Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class RemoteApiConfig(BaseModel):
    """Settings describing how to reach a remote orchestration API."""

    # Required: where the remote API lives.
    base_url: str = Field(..., description="Base URL of the remote orchestration API.")

    # Instance lifecycle endpoints.
    create_instance_endpoint: str = Field(
        default="/instances",
        description="Endpoint to create a new instance.",
    )
    delete_instance_endpoint_template: str = Field(
        default="/instances/{remote_instance_id}",
        description="Template for the endpoint to delete an instance. {remote_instance_id} will be replaced.",
    )
    call_tool_endpoint_template: Optional[str] = Field(
        default=None,
        description="Template for the endpoint to call a tool on an instance. Optional, if not provided, tools are called directly on the instance's mcp_endpoint_url.",
    )

    # Authentication settings.
    auth_type: Literal["none", "bearer_token", "custom_header"] = Field(
        default="none",
        description="Authentication type for the remote API.",
    )
    auth_details: Optional[Dict[str, str]] = Field(
        default=None,
        description="Authentication details, e.g., {'token': 'your_token'} or {'header_name': 'X-API-Key', 'header_value': 'your_key'}.",
    )
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class BackendServerConfig(BaseModel):
    """Configuration for a backend server that the intermediary can manage or proxy.

    A backend is orchestrated in one of two ways, selected by ``orchestration_mode``:

    * ``"local_docker"``: run as a local Docker container. ``docker_image`` is
      required, and ``container_port`` is required when ``mcp_transport`` is
      ``"http"``.
    * ``"remote_http_api"``: provisioned through a remote HTTP API.
      ``remote_resource_type_identifier`` is required, plus exactly one of
      ``remote_api_config_ref`` / ``remote_api_config_inline``.

    These cross-field requirements are enforced in :meth:`model_post_init`.
    """

    backend_name_ref: str = Field(
        ...,
        description="Unique reference name for this backend configuration (e.g., 'workspace_fs', 'shared_fetch_service').",
    )
    backend_type: Literal["filesystem", "duckdb", "memory", "everything", "fetch", "time"] = Field(
        ..., description="The type of the backend server."
    )
    orchestration_mode: Literal["local_docker", "remote_http_api"] = Field(
        ..., description="How this backend server is orchestrated."
    )
    instance_scoping: Literal["session", "shared_global"] = Field(
        "session",
        description="Defines if instances are per-session or shared globally. 'session' implies stateful, 'shared_global' implies stateless.",
    )
    mcp_transport: Literal["http", "stdio"] = Field(
        "http",
        description="MCP transport protocol used by the backend server. Defaults to 'http'. If 'stdio', container_port and http-based startup_check_mcp_tool are ignored.",
    )

    # --- Local Docker specific fields (used when orchestration_mode == "local_docker") ---
    docker_image: Optional[str] = Field(
        None, description="Docker image to use if orchestration_mode is 'local_docker'."
    )
    # Only enforced for the "http" transport; see model_post_init.
    container_port: Optional[int] = Field(
        None,
        description="Internal port of the MCP application within the container (for HTTP). Required if orchestration_mode is 'local_docker' and mcp_transport is 'http'.",
    )
    template_data_path_host: Optional[str] = Field(
        None,
        description="Path on the host machine to data used for pre-seeding state in a template container (for 'docker commit').",
    )
    container_template_data_path: Optional[str] = Field(
        None,
        description="Mount path inside the template container where 'template_data_path_host' will be mounted.",
    )
    docker_run_args: Optional[List[str]] = Field(None, description="Additional arguments for 'docker run'.")
    startup_check_mcp_tool: Optional[Dict[str, Any]] = Field(
        None,
        description="An MCP tool call (e.g., {'tool_name': 'ping', 'arguments': {}}) to verify container startup.",
    )
    # Renamed from container_command_args for clarity with docker-py's 'command' kwarg
    container_command: Optional[List[str]] = Field(
        None,
        description="Command to run in the container. Overrides Docker image's CMD or passed as args to ENTRYPOINT.",
    )
    container_volumes: Optional[Dict[str, Dict[str, str]]] = Field(
        None,
        description="Volume mounts for the container, e.g., {'/host/path': {'bind': '/container/path', 'mode': 'rw'}}.",
    )

    # --- Remote API specific fields (used when orchestration_mode == "remote_http_api") ---
    remote_api_config_ref: Optional[str] = Field(
        None,
        description="Reference to a globally defined RemoteApiConfig by its key/name. Used if orchestration_mode is 'remote_http_api'. Can be inline if not referencing a global one.",
    )
    remote_resource_type_identifier: Optional[str] = Field(
        None,
        description="Type identifier for the resource as known by the remote API (e.g., 'duckdb_v1', 'filesystem_large'). Required if orchestration_mode is 'remote_http_api'.",
    )
    # If remote_api_config_ref is not used, RemoteApiConfig can be defined inline
    remote_api_config_inline: Optional[RemoteApiConfig] = Field(
        None, description="Inline RemoteApiConfig if not using remote_api_config_ref."
    )

    # Pydantic v2 configuration: reject unknown keys. Replaces the class-based
    # ``class Config`` form, which is deprecated since Pydantic 2.0.
    model_config = {"extra": "forbid"}

    def model_post_init(self, __context: Any) -> None:
        """Enforce the cross-field requirements of the selected orchestration mode.

        Raises:
            ValueError: if ``docker_image`` is missing in ``local_docker`` mode;
                if ``container_port`` is missing in ``local_docker`` mode with
                the ``http`` transport; if ``remote_resource_type_identifier``
                is missing in ``remote_http_api`` mode; or if, in that mode,
                neither or both of ``remote_api_config_ref`` and
                ``remote_api_config_inline`` are set.
        """
        if self.orchestration_mode == "local_docker":
            if not self.docker_image:
                raise ValueError("docker_image must be set for local_docker orchestration mode.")
            # container_port is only required for http transport in local_docker mode;
            # a stdio backend talks over the container's stdin/stdout instead.
            if self.mcp_transport == "http" and not self.container_port:
                raise ValueError("container_port must be set for local_docker orchestration mode with http transport.")
        elif self.orchestration_mode == "remote_http_api":
            if not self.remote_resource_type_identifier:
                raise ValueError("remote_resource_type_identifier must be set for remote_http_api orchestration mode.")
            # Exactly one source for the remote API config: a global reference or an inline definition.
            if not self.remote_api_config_ref and not self.remote_api_config_inline:
                raise ValueError(
                    "Either remote_api_config_ref or remote_api_config_inline must be set for remote_http_api orchestration mode."
                )
            if self.remote_api_config_ref and self.remote_api_config_inline:
                raise ValueError("Cannot set both remote_api_config_ref and remote_api_config_inline.")
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class AppConfig(BaseModel):
    """Root configuration for the RewardKit Intermediary MCP Server.

    Holds the list of backend configurations, optional globally named remote API
    configurations that backends may reference via ``remote_api_config_ref``,
    and global options/defaults.
    """

    backends: List[BackendServerConfig] = Field(
        default_factory=list,
        description="List of configurations for all backend types the intermediary can manage or proxy.",
    )
    global_remote_apis: Optional[Dict[str, RemoteApiConfig]] = Field(
        default_factory=dict,
        description="Globally defined remote API configurations, keyed by a reference name.",
    )
    log_level: str = Field("INFO", description="Logging level for the server.")
    global_docker_options: Optional[Dict[str, Any]] = Field(
        default_factory=dict,
        description="Global Docker options, e.g., default network settings.",
    )
    global_remote_api_defaults: Optional[Dict[str, Any]] = Field(
        default_factory=dict,
        description="Global defaults for remote API interactions, e.g., default timeouts.",
    )

    # Pydantic v2 configuration: reject unknown keys. Replaces the class-based
    # ``class Config`` form, which is deprecated since Pydantic 2.0.
    model_config = {"extra": "forbid"}

    def get_remote_api_config(self, backend_cfg: BackendServerConfig) -> Optional[RemoteApiConfig]:
        """Resolve the remote API configuration that applies to ``backend_cfg``.

        An inline ``remote_api_config_inline`` takes precedence. Otherwise
        ``remote_api_config_ref`` is looked up in ``global_remote_apis``.

        Returns:
            The resolved ``RemoteApiConfig``, or ``None`` in any of these cases:
            the backend is not in ``remote_http_api`` mode; it has neither an
            inline config nor a reference; ``global_remote_apis`` is empty or
            ``None``; or the reference name is not present in it.
        """
        if backend_cfg.orchestration_mode != "remote_http_api":
            return None
        if backend_cfg.remote_api_config_inline:
            return backend_cfg.remote_api_config_inline
        # global_remote_apis is Optional, so guard against None before calling .get().
        if backend_cfg.remote_api_config_ref and self.global_remote_apis:
            return self.global_remote_apis.get(backend_cfg.remote_api_config_ref)
        return None
|