eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130):
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,542 @@
1
+ import asyncio
2
+ import logging
3
+ import uuid
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import anyio # Added for debugging cancel scopes and tasks
7
+ from mcp import types as mcp_types # Added for type hinting
8
+ from pydantic import BaseModel, Field
9
+
10
+ from eval_protocol.mcp_agent.config import AppConfig, BackendServerConfig
11
+ from eval_protocol.mcp_agent.orchestration.base_client import (
12
+ AbstractOrchestrationClient,
13
+ ManagedInstanceInfo,
14
+ )
15
+ from eval_protocol.mcp_agent.orchestration.local_docker_client import (
16
+ LocalDockerOrchestrationClient,
17
+ )
18
+ from eval_protocol.mcp_agent.orchestration.remote_http_client import (
19
+ RemoteHttpOrchestrationClient,
20
+ )
21
+ from eval_protocol.mcp_agent.session import IntermediarySessionData
22
+
23
+ logger = logging.getLogger(__name__)
24
+ # logger.setLevel(logging.DEBUG) # Removed: Let level be set by main config
25
+
26
+ from mcp.server.fastmcp.server import Context as FastMCPContext
27
+ from mcp.server.fastmcp.server import FastMCP
28
+
29
+ # RequestContext is not directly used by handlers anymore, mcp_ctx is.
30
+
31
+
32
+ # Backend initialization models (moved here to avoid separate backends module)
33
class BackendInitRequest(BaseModel):
    """One backend entry in an ``initialize_session`` request.

    Unknown keys are rejected (``extra="forbid"``) so that a misspelled field
    fails validation instead of being silently ignored.
    """

    # Pydantic v2 style configuration. The rest of this module already relies
    # on v2 (``model_dump``), where a nested ``class Config`` is deprecated.
    model_config = {"extra": "forbid"}

    backend_name_ref: str = Field(
        ...,
        description="The unique reference name of the backend configuration to use (must match one in AppConfig.backends).",
    )
    num_instances: int = Field(
        1,
        ge=1,
        description="Number of instances of this backend to provision for the session.",
    )
    template_details: Optional[Any] = Field(
        None,
        description="Backend-specific details for initializing stateful instances from a template.",
    )
50
+
51
+
52
class BackendInitResult(BaseModel):
    """Outcome of provisioning one requested backend during session initialization."""

    # Reference name of the backend this result belongs to.
    backend_name_ref: str
    # Instances made available to the session; empty when provisioning failed.
    instances: List[ManagedInstanceInfo]
    # Failure description. The session initializer passes ``error_message=str(e)``
    # when provisioning raises; without this field that keyword is not part of
    # the model and never reaches the serialized response.
    error_message: Optional[str] = None
55
+
56
+
57
+ # Pydantic models for tool arguments
58
class InitializeSessionArgs(BaseModel):
    """Arguments accepted by the ``initialize_session`` tool."""

    # One entry per backend that should be made available to the new session.
    backends: List[BackendInitRequest]
60
+
61
+
62
class CallBackendToolArgs(BaseModel):
    """Arguments accepted by the ``call_backend_tool`` tool.

    The first three fields identify the target instance; the last two describe
    the tool invocation that is forwarded to it.
    """

    rk_session_id: str = Field(
        ...,
        description="The session ID obtained from initialize_session.",
    )
    backend_name_ref: str = Field(
        ...,
        description="The reference name of the backend to target.",
    )
    instance_id: str = Field(
        ...,
        description="The ID of the specific backend instance to use.",
    )
    tool_name: str = Field(
        ...,
        description="The name of the tool to call on the backend instance.",
    )
    tool_args: Dict[str, Any] = Field(
        default_factory=dict,
        description="Arguments for the backend tool.",
    )
68
+
69
+
70
class ListBackendToolsArgs(BaseModel):
    """Arguments accepted by the ``list_backend_tools`` tool."""

    rk_session_id: str = Field(
        ...,
        description="The session ID obtained from initialize_session.",
    )
    backend_name_ref: str = Field(
        ...,
        description="The reference name of the backend to target.",
    )
    instance_id: str = Field(
        ...,
        description="The ID of the specific backend instance to query for tools.",
    )
74
+
75
+
76
class CleanupSessionArgs(BaseModel):
    """Arguments accepted by the ``cleanup_session`` tool."""

    rk_session_id: str = Field(
        ...,
        description="The session ID to clean up.",
    )
78
+
79
+
80
+ # Ping might not need specific args if it uses session from mcp_ctx, or could take rk_session_id
81
class PingArgs(BaseModel):
    """Arguments accepted by the ``ping`` tool; the session ID may be omitted."""

    rk_session_id: Optional[str] = Field(
        default=None,
        description="Optional session ID for context.",
    )
83
+
84
+
85
class RewardKitIntermediaryServer(FastMCP):
    """FastMCP server that provisions backend instances per session and proxies tool calls to them.

    Construction registers five tools: ``initialize_session``,
    ``call_backend_tool``, ``list_backend_tools``, ``cleanup_session`` and
    ``ping``. The orchestration clients those tools depend on are only created
    by :meth:`startup`, which is not called from ``__init__``; :meth:`shutdown`
    releases sessions, shared instances and the clients again.
    """

    def __init__(self, app_config: AppConfig, **kwargs_for_fastmcp):
        """Set up empty server state, register the tools and apply the configured log level.

        Args:
            app_config: Application configuration; ``backends``,
                ``global_remote_apis``, ``global_remote_api_defaults`` and
                ``log_level`` are read by this class.
            **kwargs_for_fastmcp: Forwarded unchanged to ``FastMCP.__init__``.
        """
        super().__init__(
            name="RewardKitIntermediaryMCP",
            instructions="Intermediary Server for managing backend MCP resources for RewardKit RL rollouts.",
            **kwargs_for_fastmcp,
        )

        self.app_config = app_config
        self._local_docker_orchestrator: Optional[LocalDockerOrchestrationClient] = None
        # Keyed by the global remote API ref name, or by base_url for inline configs.
        self._remote_http_orchestrators: Dict[str, RemoteHttpOrchestrationClient] = {}
        # Both keyed by backend_name_ref.
        self._shared_global_instances: Dict[str, ManagedInstanceInfo] = {}
        self._shared_instance_locks: Dict[str, asyncio.Lock] = {}
        # Keyed by rk_session_id.
        self.intermediary_session_data: Dict[str, IntermediarySessionData] = {}

        logger.info("RewardKitIntermediaryServer (FastMCP based) initialized. AppConfig loaded.")

        # Register tools directly
        self.add_tool(self._initialize_session_actual, name="initialize_session")
        self.add_tool(self._call_backend_tool_actual, name="call_backend_tool")
        self.add_tool(self._list_backend_tools_actual, name="list_backend_tools")
        self.add_tool(self._cleanup_session_actual, name="cleanup_session")
        self.add_tool(self._ping_actual, name="ping")

        logger.info("Registered tools directly with FastMCP.")

        # Explicitly set this module's logger level based on app_config so it
        # overrides any prior default if external configuration did not apply.
        # A failure here must never prevent the server from being constructed.
        try:
            config_log_level_str = app_config.log_level.upper()
            config_log_level_int = getattr(logging, config_log_level_str, logging.INFO)
            if logger.getEffectiveLevel() != config_log_level_int:
                logger.info(
                    f"Overriding intermediary_server logger level from {logging.getLevelName(logger.getEffectiveLevel())} to {config_log_level_str}"
                )
            logger.setLevel(config_log_level_int)
            # Also ensure handlers attached directly to this logger respect it (if any)
            for handler in logger.handlers:
                handler.setLevel(config_log_level_int)
            logger.info(
                f"IntermediaryServer logger effective level: {logging.getLevelName(logger.getEffectiveLevel())}"
            )
        except Exception as e_log:
            logger.error(f"Error trying to set intermediary_server logger level: {e_log}")

    # ------------------------------------------------------------------
    # Small private helpers shared by several tools
    # ------------------------------------------------------------------

    @staticmethod
    def _current_task_name() -> str:
        """Return the current anyio task's name as text, or ``"unknown_task"`` if no task is reported."""
        task = anyio.get_current_task()
        return str(task.name) if task else "unknown_task"

    @staticmethod
    def _transport_session_id(mcp_ctx: Any) -> Optional[str]:
        """Return ``mcp_ctx.session.client_params.session_id`` when every link exists and is truthy, else None."""
        session = getattr(mcp_ctx, "session", None)
        if not session:
            return None
        client_params = getattr(session, "client_params", None)
        if not client_params:
            return None
        return getattr(client_params, "session_id", None) or None

    def _find_backend_config(self, backend_name_ref: str) -> Optional[BackendServerConfig]:
        """Return the first configured backend whose ``backend_name_ref`` matches, or None."""
        return next(
            (b for b in self.app_config.backends if b.backend_name_ref == backend_name_ref),
            None,
        )

    def _resolve_instance_and_client(
        self,
        session_data: IntermediarySessionData,
        rk_session_id: str,
        backend_name_ref: str,
        instance_id: str,
    ):
        """Return ``(instance, orchestration_client)`` for a session-scoped instance.

        Raises:
            ValueError: If the instance is not recorded in the session or the
                backend is not configured.
        """
        target_instances = session_data.get_managed_instances(backend_name_ref, instance_id)
        if not target_instances:
            raise ValueError(
                f"Instance '{instance_id}' for backend '{backend_name_ref}' not found in session '{rk_session_id}'."
            )
        backend_cfg = self._find_backend_config(backend_name_ref)
        if not backend_cfg:
            raise ValueError(f"Backend config '{backend_name_ref}' not found.")
        return target_instances[0], self._get_orchestration_client(backend_cfg)

    async def _deprovision_best_effort(self, instances: List[ManagedInstanceInfo], log_context: str) -> None:
        """Deprovision ``instances`` through their orchestrators, logging instead of raising on failure.

        Args:
            instances: Instances to release; grouped here by orchestration mode.
            log_context: Prefix placed in front of any error message.
        """
        local_docker_instances = [inst for inst in instances if inst.orchestration_mode == "local_docker"]
        if local_docker_instances and self._local_docker_orchestrator:
            try:
                await self._local_docker_orchestrator.deprovision_instances(local_docker_instances)
            except Exception as e:
                logger.error(
                    f"{log_context}: Error deprovisioning local Docker: {e}",
                    exc_info=True,
                )

        remote_instances_by_key: Dict[str, List[ManagedInstanceInfo]] = {}
        for inst in instances:
            if inst.orchestration_mode != "remote_http_api":
                continue
            key = self._get_orchestration_client_key_for_instance(inst)
            if key:
                remote_instances_by_key.setdefault(key, []).append(inst)

        for key, remote_list in remote_instances_by_key.items():
            orchestrator = self._remote_http_orchestrators.get(key)
            if not orchestrator or not remote_list:
                continue
            try:
                await orchestrator.deprovision_instances(remote_list)
            except Exception as e:
                logger.error(
                    f"{log_context}: Error deprovisioning remote for '{key}': {e}",
                    exc_info=True,
                )

    # ------------------------------------------------------------------
    # Orchestrator management
    # ------------------------------------------------------------------

    async def _initialize_orchestrators(self):
        """Create and start one orchestration client per orchestration target in the configuration.

        One local Docker client is created if any backend uses ``local_docker``.
        Remote HTTP clients are created once per referenced global API, or once
        per ``base_url`` for inline configurations.
        """
        logger.info("Initializing orchestration clients...")
        if any(b.orchestration_mode == "local_docker" for b in self.app_config.backends):
            self._local_docker_orchestrator = LocalDockerOrchestrationClient(self.app_config)
            await self._local_docker_orchestrator.startup()
            logger.info("LocalDockerOrchestrationClient initialized and started.")

        unique_remote_api_refs = set()
        for backend_cfg in self.app_config.backends:
            if backend_cfg.orchestration_mode != "remote_http_api":
                continue
            if backend_cfg.remote_api_config_ref:
                unique_remote_api_refs.add(backend_cfg.remote_api_config_ref)
            elif backend_cfg.remote_api_config_inline:
                logger.warning(
                    f"Inline remote_api_config for {backend_cfg.backend_name_ref}. Consider using global_remote_apis."
                )
                key = backend_cfg.remote_api_config_inline.base_url
                if key not in self._remote_http_orchestrators:
                    temp_app_config_for_inline = AppConfig(
                        global_remote_apis={key: backend_cfg.remote_api_config_inline}
                    )
                    client = RemoteHttpOrchestrationClient(temp_app_config_for_inline)
                    await client.startup()
                    self._remote_http_orchestrators[key] = client
                    logger.info(f"RemoteHttpOrchestrationClient for inline config {key} initialized.")

        for ref_name in unique_remote_api_refs:
            if ref_name not in self.app_config.global_remote_apis:
                logger.error(f"Remote API ref '{ref_name}' not in global_remote_apis.")
                continue
            if ref_name in self._remote_http_orchestrators:
                continue
            # Each client receives a config that contains only its own remote API entry.
            isolated_app_cfg = AppConfig(
                global_remote_apis={ref_name: self.app_config.global_remote_apis[ref_name]},
                global_remote_api_defaults=self.app_config.global_remote_api_defaults,
            )
            client = RemoteHttpOrchestrationClient(isolated_app_cfg)
            await client.startup()
            self._remote_http_orchestrators[ref_name] = client
            logger.info(f"RemoteHttpOrchestrationClient for '{ref_name}' initialized.")
        logger.info("Orchestration clients initialization complete.")

    def _get_orchestration_client(self, backend_cfg: BackendServerConfig) -> AbstractOrchestrationClient:
        """Return the started orchestration client responsible for ``backend_cfg``.

        Raises:
            RuntimeError: If the required client has not been initialized.
            ValueError: If the mode is unsupported or a remote backend has
                neither a config ref nor an inline config.
        """
        mode = backend_cfg.orchestration_mode
        if mode == "local_docker":
            if not self._local_docker_orchestrator:
                raise RuntimeError("Local Docker orchestrator not initialized.")
            return self._local_docker_orchestrator

        if mode == "remote_http_api":
            key = backend_cfg.remote_api_config_ref
            if not key:
                if not backend_cfg.remote_api_config_inline:
                    raise ValueError(f"Remote API config missing for {backend_cfg.backend_name_ref}")
                key = backend_cfg.remote_api_config_inline.base_url
            client = self._remote_http_orchestrators.get(key)
            if not client:
                raise RuntimeError(f"Remote HTTP orchestrator for '{key}' not initialized.")
            return client

        raise ValueError(f"Unsupported orchestration mode: {backend_cfg.orchestration_mode}")

    def _get_orchestration_client_key_for_instance(self, instance_info: ManagedInstanceInfo) -> Optional[str]:
        """Return the ``_remote_http_orchestrators`` key for a remote instance, or None if it cannot be determined."""
        if instance_info.orchestration_mode != "remote_http_api":
            return None
        backend_cfg = self._find_backend_config(instance_info.backend_name_ref)
        if not backend_cfg:
            return None
        if backend_cfg.remote_api_config_ref:
            return backend_cfg.remote_api_config_ref
        if backend_cfg.remote_api_config_inline:
            return backend_cfg.remote_api_config_inline.base_url
        return None

    # ------------------------------------------------------------------
    # Shared global instances
    # ------------------------------------------------------------------

    async def _get_or_provision_shared_global_instance(self, backend_name_ref: str) -> ManagedInstanceInfo:
        """Return the shared instance for ``backend_name_ref``, provisioning it on first use.

        A per-backend ``asyncio.Lock`` is held while checking and provisioning,
        so concurrent callers on the same event loop provision at most once.

        Raises:
            ValueError: If the backend is unknown or not ``shared_global``.
            RuntimeError: If the orchestrator returned no instance.
        """
        if backend_name_ref not in self._shared_instance_locks:
            self._shared_instance_locks[backend_name_ref] = asyncio.Lock()

        async with self._shared_instance_locks[backend_name_ref]:
            existing = self._shared_global_instances.get(backend_name_ref)
            if backend_name_ref in self._shared_global_instances:
                logger.info(f"Returning existing shared global instance for '{backend_name_ref}'.")
                return existing

            logger.info(f"Provisioning new shared global instance for '{backend_name_ref}'.")
            backend_cfg = self._find_backend_config(backend_name_ref)
            if not backend_cfg or backend_cfg.instance_scoping != "shared_global":
                raise ValueError(f"Backend '{backend_name_ref}' not for shared_global scoping.")

            orchestration_client = self._get_orchestration_client(backend_cfg)
            provisioned_list = await orchestration_client.provision_instances(
                backend_config=backend_cfg,
                num_instances=1,
                session_id="global_shared_session",
                template_details=backend_cfg.template_data_path_host,
            )
            if not provisioned_list:
                raise RuntimeError(f"Failed to provision shared global for '{backend_name_ref}'.")

            instance_info = provisioned_list[0]
            self._shared_global_instances[backend_name_ref] = instance_info
            logger.info(f"Provisioned shared global for '{backend_name_ref}': {instance_info.instance_id}")
            return instance_info

    async def _provision_shared_global_instances(self):
        """Provision every ``shared_global`` backend up front; a failure is logged and does not stop the rest."""
        logger.info("Pre-provisioning all shared_global instances...")
        for backend_cfg in self.app_config.backends:
            if backend_cfg.instance_scoping != "shared_global":
                continue
            try:
                await self._get_or_provision_shared_global_instance(backend_cfg.backend_name_ref)
            except Exception as e:
                logger.error(
                    f"Failed to pre-provision for '{backend_cfg.backend_name_ref}': {e}",
                    exc_info=True,
                )
        logger.info("Shared_global instances pre-provisioning complete.")

    # ------------------------------------------------------------------
    # Tools
    # ------------------------------------------------------------------

    async def _initialize_session_actual(self, mcp_ctx: FastMCPContext, args: InitializeSessionArgs) -> Dict[str, Any]:
        """Tool ``initialize_session``: create a session and provision the requested backends.

        The transport session ID from ``mcp_ctx`` is used as the session ID
        when available; otherwise a random hex ID is generated. If a session
        with that ID already exists, its instances are released before it is
        replaced. Each backend is handled independently: a failure produces a
        result with no instances and an ``error_message`` rather than aborting
        the whole call.

        Returns:
            A dict with ``rk_session_id`` and ``initialized_backends`` (one
            dumped ``BackendInitResult`` per requested backend).
        """
        logger.debug(
            f"ENTERING _initialize_session_actual: task='{self._current_task_name()}', mcp_ctx type: {type(mcp_ctx)}, args: {args}"
        )

        transport_session_id = self._transport_session_id(mcp_ctx)
        if transport_session_id:
            logger.info(f"Retrieved transport_session_id: {transport_session_id}")
            rk_session_id = transport_session_id
            logger.info(f"Using transport_session_id as rk_session_id: {rk_session_id}")
        else:
            rk_session_id = uuid.uuid4().hex
            logger.warning(f"Transport session ID not found. Generated new rk_session_id: {rk_session_id}")

        if rk_session_id in self.intermediary_session_data:
            logger.warning(f"rk_session_id '{rk_session_id}' already exists. Overwriting.")
            # The replaced session becomes unreachable, so release what it
            # provisioned instead of leaking those instances.
            replaced_session = self.intermediary_session_data.pop(rk_session_id)
            try:
                await self.cleanup_session_internal(replaced_session, rk_session_id)
            except Exception as e:
                logger.error(
                    f"Session {rk_session_id}: Error cleaning up replaced session data: {e}",
                    exc_info=True,
                )

        session_data = IntermediarySessionData(session_id=rk_session_id)
        self.intermediary_session_data[rk_session_id] = session_data

        logger.info(
            f"Initializing IntermediarySessionData for rk_session_id '{rk_session_id}' with {len(args.backends)} backend requests."
        )
        initialized_backends_results: List[BackendInitResult] = []

        for backend_req in args.backends:
            backend_cfg = self._find_backend_config(backend_req.backend_name_ref)
            if not backend_cfg:
                logger.error(f"Session {rk_session_id}: Config for '{backend_req.backend_name_ref}' not found.")
                initialized_backends_results.append(
                    BackendInitResult(
                        backend_name_ref=backend_req.backend_name_ref,
                        instances=[],
                        error_message=f"Backend config '{backend_req.backend_name_ref}' not found.",
                    )
                )
                continue

            try:
                if backend_cfg.instance_scoping == "shared_global":
                    shared_instance_info = await self._get_or_provision_shared_global_instance(
                        backend_req.backend_name_ref
                    )
                    # The single shared instance is reported once per requested slot.
                    instances_for_this_backend = [shared_instance_info] * backend_req.num_instances
                else:
                    orchestration_client = self._get_orchestration_client(backend_cfg)
                    instances_for_this_backend = await orchestration_client.provision_instances(
                        backend_config=backend_cfg,
                        num_instances=backend_req.num_instances,
                        session_id=session_data.session_id,
                        template_details=backend_req.template_details,
                    )
                session_data.add_managed_instances(backend_req.backend_name_ref, instances_for_this_backend)
                initialized_backends_results.append(
                    BackendInitResult(
                        backend_name_ref=backend_req.backend_name_ref,
                        instances=instances_for_this_backend,
                    )
                )
            except Exception as e:
                logger.error(
                    f"Session {rk_session_id}: Error initializing '{backend_req.backend_name_ref}': {e}",
                    exc_info=True,
                )
                initialized_backends_results.append(
                    BackendInitResult(
                        backend_name_ref=backend_req.backend_name_ref,
                        instances=[],
                        error_message=str(e),
                    )
                )

        logger.debug(f"EXITING _initialize_session_actual: task='{self._current_task_name()}'")
        return {
            "rk_session_id": rk_session_id,
            "initialized_backends": [res.model_dump(exclude_none=True) for res in initialized_backends_results],
        }

    async def _call_backend_tool_actual(self, mcp_ctx: FastMCPContext, args: CallBackendToolArgs) -> Dict[str, Any]:
        """Tool ``call_backend_tool``: forward a tool call to one instance of the session.

        Returns:
            Whatever the orchestration client's ``call_tool_on_instance`` returns.

        Raises:
            ValueError: If the session, instance or backend config is unknown.
            Exception: Any error from the backend call is logged and re-raised.
        """
        logger.debug(
            f"ENTERING _call_backend_tool_actual: task='{self._current_task_name()}', mcp_ctx type: {type(mcp_ctx)}, args: {args}"
        )

        session_data = self.intermediary_session_data.get(args.rk_session_id)
        if not session_data:
            logger.error(
                f"ERROR in _call_backend_tool_actual (session not found): task='{self._current_task_name()}', rk_session_id='{args.rk_session_id}'"
            )
            raise ValueError(f"IntermediarySessionData for rk_session_id '{args.rk_session_id}' not found.")

        managed_instance_info, orchestration_client = self._resolve_instance_and_client(
            session_data, args.rk_session_id, args.backend_name_ref, args.instance_id
        )

        logger.debug(
            f"BEFORE orchestrator.call_tool_on_instance in _call_backend_tool_actual: task='{self._current_task_name()}'"
        )
        try:
            result = await orchestration_client.call_tool_on_instance(
                instance=managed_instance_info,
                tool_name=args.tool_name,
                tool_args=args.tool_args,
            )
        except Exception as e:
            logger.error(
                f"EXCEPTION in _call_backend_tool_actual: task='{self._current_task_name()}'. Session {args.rk_session_id}: Error calling tool '{args.tool_name}' on instance '{args.instance_id}': {e}",
                exc_info=True,
            )
            raise

        logger.debug(
            f"AFTER orchestrator.call_tool_on_instance in _call_backend_tool_actual: task='{self._current_task_name()}'"
        )
        logger.debug(f"EXITING _call_backend_tool_actual (SUCCESS): task='{self._current_task_name()}'")
        return result

    async def _list_backend_tools_actual(
        self, mcp_ctx: FastMCPContext, args: ListBackendToolsArgs
    ) -> Dict[str, Any]:
        """Tool ``list_backend_tools``: list the tools offered by one instance of the session.

        Returns:
            The orchestrator's ``ListToolsResult`` dumped to a dict with None
            values excluded.

        Raises:
            ValueError: If the session, instance or backend config is unknown.
            Exception: Any error from the backend query is logged and re-raised.
        """
        logger.debug(f"ENTERING _list_backend_tools_actual: task='{self._current_task_name()}', args: {args}")

        session_data = self.intermediary_session_data.get(args.rk_session_id)
        if not session_data:
            logger.error(
                f"ERROR in _list_backend_tools_actual (session not found): rk_session_id='{args.rk_session_id}'"
            )
            raise ValueError(f"IntermediarySessionData for rk_session_id '{args.rk_session_id}' not found.")

        managed_instance_info, orchestration_client = self._resolve_instance_and_client(
            session_data, args.rk_session_id, args.backend_name_ref, args.instance_id
        )

        logger.debug(
            f"Calling orchestrator.list_tools_on_instance for backend '{args.backend_name_ref}', instance '{args.instance_id}'"
        )
        try:
            list_tools_result: mcp_types.ListToolsResult = await orchestration_client.list_tools_on_instance(
                instance=managed_instance_info
            )
            # The result is a Pydantic model; dump it so the tool returns plain data.
            return list_tools_result.model_dump(exclude_none=True)
        except Exception as e:
            logger.error(
                f"EXCEPTION in _list_backend_tools_actual for session {args.rk_session_id}, backend {args.backend_name_ref}, instance {args.instance_id}: {e}",
                exc_info=True,
            )
            raise  # Re-raise so the failure is reported to the client

    async def cleanup_session_internal(self, session_data_to_clean: IntermediarySessionData, rk_session_id: str):
        """Deprovision the instances owned by one session; deprovisioning errors are logged, not raised.

        Shared-global instances are recorded in every session that requested
        them but stay registered in ``_shared_global_instances`` for reuse, so
        they are skipped here and only released by :meth:`shutdown`.

        Args:
            session_data_to_clean: Session record whose instances are released.
            rk_session_id: Session ID, used for log messages only.
        """
        logger.info(f"Starting internal cleanup for IntermediarySessionData (rk_session_id: '{rk_session_id}').")
        shared_instance_ids = {inst.instance_id for inst in self._shared_global_instances.values()}
        session_owned_instances = [
            inst
            for inst in session_data_to_clean.get_all_managed_instances()
            if inst.instance_id not in shared_instance_ids
        ]
        await self._deprovision_best_effort(session_owned_instances, f"Session {rk_session_id}")
        logger.info(f"Internal cleanup for session data (rk_session_id: '{rk_session_id}') complete.")

    async def _cleanup_session_actual(self, mcp_ctx: FastMCPContext, args: CleanupSessionArgs) -> Dict[str, str]:
        """Tool ``cleanup_session``: forget a session and release its instances.

        Returns:
            A dict with ``status`` (``"cleaned"``, or a not-found marker when
            the session is unknown) and the ``rk_session_id``.
        """
        logger.debug(f"_cleanup_session_actual called. mcp_ctx type: {type(mcp_ctx)}, args: {args}")
        session_data_obj = self.intermediary_session_data.pop(args.rk_session_id, None)
        if not session_data_obj:
            logger.warning(
                f"IntermediarySessionData for rk_session_id '{args.rk_session_id}' not found or already cleaned."
            )
            return {
                "status": "custom_session_data_not_found_or_already_cleaned",
                "rk_session_id": args.rk_session_id,
            }
        await self.cleanup_session_internal(session_data_obj, args.rk_session_id)
        logger.info(f"IntermediarySessionData for rk_session_id '{args.rk_session_id}' fully cleaned up.")
        return {"status": "cleaned", "rk_session_id": args.rk_session_id}

    async def _ping_actual(self, mcp_ctx: FastMCPContext, args: PingArgs) -> Dict[str, str]:
        """Tool ``ping``: reply ``pong`` together with the best-known session ID.

        The ID is taken from ``args`` first, then from the transport session in
        ``mcp_ctx``, and finally falls back to a fixed placeholder string.
        """
        logger.debug(f"_ping_actual called. mcp_ctx type: {type(mcp_ctx)}, args: {args}")
        ping_session_id: Optional[str]
        if args.rk_session_id:  # If client provides its known rk_session_id
            ping_session_id = args.rk_session_id
            logger.info(f"Ping using rk_session_id from args: {ping_session_id}")
        else:
            ping_session_id = self._transport_session_id(mcp_ctx)
            if ping_session_id:
                logger.info(f"Ping using transport_session_id from mcp_ctx: {ping_session_id}")
            else:
                ping_session_id = "unknown_session_for_ping"
                logger.warning(f"Session ID for ping not found in args or mcp_ctx, using fallback: {ping_session_id}")
        return {"reply": "pong", "session_id": ping_session_id or ""}

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def startup(self):
        """Start the orchestration clients and pre-provision shared instances; errors are logged and re-raised."""
        logger.info("RewardKitIntermediaryServer performing custom startup tasks...")
        try:
            await self._initialize_orchestrators()
            await self._provision_shared_global_instances()
            logger.info("RewardKitIntermediaryServer custom startup tasks complete.")
        except Exception as e:
            logger.error(
                f"Error during RewardKitIntermediaryServer custom startup: {e}",
                exc_info=True,
            )
            raise

    async def shutdown(self):
        """Release all sessions, shared instances and orchestration clients.

        Every step is attempted even if an earlier one fails; failures are
        logged rather than raised so that one broken backend cannot prevent
        the remaining resources from being released.
        """
        logger.info("RewardKitIntermediaryServer (FastMCP based) performing custom shutdown tasks...")
        logger.info(f"Cleaning up {len(self.intermediary_session_data)} IntermediarySessionData entries...")
        for session_id_key in list(self.intermediary_session_data.keys()):
            session_data_obj = self.intermediary_session_data.pop(session_id_key, None)
            if not session_data_obj:
                continue
            try:
                await self.cleanup_session_internal(session_data_obj, session_id_key)
            except Exception as e:
                logger.error(
                    f"Session {session_id_key}: Error during shutdown cleanup: {e}",
                    exc_info=True,
                )

        # Sessions skip shared instances, so release them here and forget
        # them so a repeated shutdown does not deprovision them twice.
        shared_instances = list(self._shared_global_instances.values())
        self._shared_global_instances.clear()
        if shared_instances:
            logger.info(f"Deprovisioning {len(shared_instances)} shared global instances.")
            await self._deprovision_best_effort(shared_instances, "Shutdown (shared global instances)")

        orchestrators: List[AbstractOrchestrationClient] = []
        if self._local_docker_orchestrator:
            orchestrators.append(self._local_docker_orchestrator)
        orchestrators.extend(self._remote_http_orchestrators.values())
        for orch in orchestrators:
            try:
                await orch.shutdown()
            except Exception as e:
                logger.error(
                    f"Error shutting down orchestration client {type(orch).__name__}: {e}",
                    exc_info=True,
                )
        logger.info("RewardKitIntermediaryServer custom shutdown tasks complete.")