eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,304 @@
1
+ import logging
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import httpx
5
+ from mcp import types as mcp_types
6
+ from mcp.client.session import ClientSession
7
+ from mcp.client.streamable_http import streamablehttp_client
8
+
9
+ from eval_protocol.mcp_agent.config import (
10
+ AppConfig,
11
+ BackendServerConfig,
12
+ RemoteApiConfig,
13
+ )
14
+ from eval_protocol.mcp_agent.orchestration.base_client import (
15
+ AbstractOrchestrationClient,
16
+ ManagedInstanceInfo,
17
+ )
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class RemoteHttpOrchestrationClient(AbstractOrchestrationClient):
    """
    Orchestrates backend MCP server instances by communicating with a remote HTTP API.

    This client translates provisioning, deprovisioning, and tool call requests
    into HTTP requests to a customer-defined remote orchestration service.
    ``startup()`` must be awaited before any request-making method is used, and
    ``shutdown()`` releases the underlying ``httpx.AsyncClient``.
    """

    def __init__(self, app_config: AppConfig):
        self.app_config = app_config
        # Created in startup(); None means "not started" (or already shut down).
        self.http_client: Optional[httpx.AsyncClient] = None

    async def startup(self) -> None:
        """Initialize the shared ``httpx.AsyncClient``.

        Timeouts are read from ``app_config.global_remote_api_defaults``:
        ``timeout`` (default 30.0 seconds) and ``connect_timeout`` (default 5.0
        seconds). If a client already exists it is closed first so repeated
        calls do not leak its connection pool.
        """
        if self.http_client is not None:
            await self.http_client.aclose()
        # Default timeout can be overridden by specific remote_api_config later
        defaults = self.app_config.global_remote_api_defaults or {}
        timeout_config = httpx.Timeout(
            defaults.get("timeout", 30.0),
            connect=defaults.get("connect_timeout", 5.0),
        )
        self.http_client = httpx.AsyncClient(timeout=timeout_config)
        logger.info("RemoteHttpOrchestrationClient started.")

    async def shutdown(self) -> None:
        """Close the ``httpx.AsyncClient`` (if one exists) and mark the client as stopped."""
        if self.http_client:
            await self.http_client.aclose()
            # Reset so later requests fail with the clear "not initialized" error
            # (instead of httpx's "client has been closed") and startup() can run again.
            self.http_client = None
            logger.info("HTTPX client for RemoteHttpOrchestrationClient closed.")
        logger.info("RemoteHttpOrchestrationClient shut down.")

    @staticmethod
    def _join_url(base_url: str, path: str) -> str:
        """Join ``base_url`` and ``path`` with exactly one ``/`` between them."""
        return f"{base_url.rstrip('/')}/{path.lstrip('/')}"

    def _find_backend_config(self, backend_name_ref: str) -> Optional[BackendServerConfig]:
        """Return the first configured backend whose ``backend_name_ref`` matches, or None."""
        return next(
            (b for b in self.app_config.backends if b.backend_name_ref == backend_name_ref),
            None,
        )

    def _get_auth_headers(self, remote_api_config: RemoteApiConfig) -> Dict[str, str]:
        """Construct authentication headers from ``remote_api_config``.

        Supported ``auth_type`` values:
          * ``"bearer_token"``  -> ``Authorization: Bearer <auth_details["token"]>``
          * ``"custom_header"`` -> ``<auth_details["header_name"]>: <auth_details["header_value"]>``
        Any other auth type produces no headers. Missing details are logged as
        warnings rather than raised.
        """
        headers: Dict[str, str] = {}
        auth_details = remote_api_config.auth_details or {}
        if remote_api_config.auth_type == "bearer_token":
            token = auth_details.get("token")
            if token:
                headers["Authorization"] = f"Bearer {token}"
            else:
                logger.warning("Bearer token auth selected but no token provided.")
        elif remote_api_config.auth_type == "custom_header":
            header_name = auth_details.get("header_name")
            header_value = auth_details.get("header_value")
            if header_name and header_value:
                headers[header_name] = header_value
            else:
                logger.warning("Custom header auth selected but header_name or header_value missing.")
        return headers

    async def _make_request(
        self,
        method: str,
        url: str,
        remote_api_config: RemoteApiConfig,
        json_payload: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> httpx.Response:
        """Send an authenticated HTTP request and return the (2xx/3xx) response.

        Raises:
            RuntimeError: if ``startup()`` has not been called, on network-level
                errors, or when the server answers with a 4xx/5xx status (the
                decoded error body is included in the message).
        """
        if not self.http_client:
            raise RuntimeError("HTTP client not initialized. Call startup() first.")

        headers = self._get_auth_headers(remote_api_config)
        if json_payload is not None:
            # Only advertise a JSON body when one is actually sent (e.g. not on DELETE).
            headers["Content-Type"] = "application/json"

        try:
            logger.debug(f"Making {method} request to {url} with payload: {json_payload} and params: {params}")
            response = await self.http_client.request(method, url, headers=headers, json=json_payload, params=params)
            response.raise_for_status()  # Raise an exception for 4xx/5xx responses
            return response
        except httpx.RequestError as e:
            logger.error(f"Request error during {method} to {url}: {e}")
            raise RuntimeError(f"Remote API request failed: Network error calling {url}") from e
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP status error during {method} to {url}: {e.response.status_code} - {e.response.text}")
            try:
                error_details = e.response.json()
            except Exception:
                # Body is not JSON; fall back to the raw text.
                error_details = e.response.text
            raise RuntimeError(
                f"Remote API request failed: Server returned error {e.response.status_code}. Details: {error_details}"
            ) from e

    async def provision_instances(
        self,
        backend_config: BackendServerConfig,
        num_instances: int,
        session_id: str,
        template_details: Optional[Any] = None,
    ) -> List[ManagedInstanceInfo]:
        """Ask the remote orchestration API to create ``num_instances`` instances.

        Sends a single batch POST to ``create_instance_endpoint`` and expects a JSON
        list back, one object per instance. Each object must contain
        ``remote_instance_id`` and ``mcp_endpoint_url``. It may contain ``instance_id``
        (the client-facing id; otherwise ``<session_id>-<backend_name_ref>-<index>``)
        and ``additional_details`` (merged into ``internal_instance_details``).

        Raises:
            ValueError: on a wrong orchestration mode, missing RemoteApiConfig, a
                non-JSON / non-list response, or an incomplete instance entry.
            RuntimeError: propagated from ``_make_request`` on HTTP failures.
        """
        if backend_config.orchestration_mode != "remote_http_api":
            raise ValueError("RemoteHttpOrchestrationClient can only handle 'remote_http_api' mode.")

        remote_api_config = self.app_config.get_remote_api_config(backend_config)
        if not remote_api_config:
            raise ValueError(f"RemoteApiConfig not found for backend {backend_config.backend_name_ref}.")

        create_url = self._join_url(remote_api_config.base_url, remote_api_config.create_instance_endpoint)

        provisioned_instances_info: List[ManagedInstanceInfo] = []

        # The remote API might support batch creation or require individual calls.
        # This assumes the remote API can take num_instances and returns a list.
        # Adjust if the API requires one call per instance.
        payload = {
            "resource_type_identifier": backend_config.remote_resource_type_identifier,
            "num_instances": num_instances,
            "session_id": session_id,
            "instance_scoping": backend_config.instance_scoping,
            "template_details": template_details,  # Pass along any template info
        }

        logger.info(
            f"Requesting {num_instances} instances of type '{backend_config.remote_resource_type_identifier}' from {create_url}"
        )

        response = await self._make_request("POST", create_url, remote_api_config, json_payload=payload)
        try:
            response_data = response.json()  # Expecting a list of instance details
        except ValueError as e:  # json.JSONDecodeError subclasses ValueError
            raise ValueError(f"Remote API at {create_url} returned a non-JSON response: {response.text}") from e

        if not isinstance(response_data, list):
            raise ValueError(
                f"Remote API at {create_url} did not return a list of instances. Response: {response_data}"
            )

        for i, inst_data in enumerate(response_data):
            if not isinstance(inst_data, dict):
                raise ValueError(f"Remote API response for instance creation is incomplete: {inst_data}")

            remote_instance_id = inst_data.get("remote_instance_id")
            mcp_endpoint_url = inst_data.get("mcp_endpoint_url")
            client_facing_instance_id = inst_data.get(
                "instance_id", f"{session_id}-{backend_config.backend_name_ref}-{i}"
            )

            if not remote_instance_id or not mcp_endpoint_url:
                logger.error(
                    f"Remote API response for instance missing 'remote_instance_id' or 'mcp_endpoint_url'. Data: {inst_data}"
                )
                # Fail the whole batch when critical info is missing.
                # NOTE(review): instances already created remotely in this batch are not
                # cleaned up here; confirm whether the caller handles that.
                raise ValueError(f"Remote API response for instance creation is incomplete: {inst_data}")

            # `additional_details` may be absent or explicitly null.
            additional_details = inst_data.get("additional_details") or {}
            provisioned_instances_info.append(
                ManagedInstanceInfo(
                    instance_id=client_facing_instance_id,
                    backend_name_ref=backend_config.backend_name_ref,
                    orchestration_mode="remote_http_api",
                    mcp_endpoint_url=mcp_endpoint_url,
                    internal_instance_details={
                        **additional_details,  # Any other info from remote
                        # Placed last so remote extras can never clobber the id
                        # that deprovisioning and proxied tool calls rely on.
                        "remote_instance_id": remote_instance_id,
                    },
                )
            )
            logger.info(
                f"Instance {client_facing_instance_id} (Remote ID: {remote_instance_id}) provisioned. MCP Endpoint: {mcp_endpoint_url}"
            )

        # Warn on any count mismatch, including an empty list for a non-zero request.
        if len(provisioned_instances_info) != num_instances:
            logger.warning(
                f"Requested {num_instances} but remote API returned details for {len(provisioned_instances_info)} instances."
            )

        return provisioned_instances_info

    async def deprovision_instances(self, instances: List[ManagedInstanceInfo]) -> None:
        """Best-effort deletion of remote instances via ``delete_instance_endpoint_template``.

        Instances that are not ``remote_http_api``, whose backend config or
        RemoteApiConfig cannot be found, or that lack a ``remote_instance_id`` are
        skipped with a log message. Request failures are logged and do not stop
        the remaining instances from being processed; nothing is raised.
        """
        for instance in instances:
            if instance.orchestration_mode != "remote_http_api":
                logger.warning(
                    f"Skipping deprovision for instance {instance.instance_id} as it's not remote_http_api."
                )
                continue

            # The backend config is needed to resolve this instance's RemoteApiConfig.
            backend_cfg = self._find_backend_config(instance.backend_name_ref)
            if not backend_cfg:
                logger.error(
                    f"Could not find BackendServerConfig for {instance.backend_name_ref} during deprovision of {instance.instance_id}"
                )
                continue

            remote_api_config = self.app_config.get_remote_api_config(backend_cfg)
            if not remote_api_config:
                logger.error(
                    f"RemoteApiConfig not found for backend {instance.backend_name_ref} during deprovision of {instance.instance_id}."
                )
                continue

            remote_instance_id = (instance.internal_instance_details or {}).get("remote_instance_id")
            if not remote_instance_id:
                logger.warning(f"No remote_instance_id found for instance {instance.instance_id}. Cannot deprovision.")
                continue

            delete_path = remote_api_config.delete_instance_endpoint_template.format(
                remote_instance_id=remote_instance_id
            )
            delete_url = self._join_url(remote_api_config.base_url, delete_path)

            logger.info(f"Requesting deprovision of remote instance {remote_instance_id} via {delete_url}")
            try:
                await self._make_request("DELETE", delete_url, remote_api_config)
                logger.info(f"Successfully requested deprovision for remote instance {remote_instance_id}.")
            except Exception as e:
                # Log error but continue trying to deprovision other instances
                logger.error(f"Failed to deprovision remote instance {remote_instance_id}: {e}")

    async def call_tool_on_instance(
        self, instance: ManagedInstanceInfo, tool_name: str, tool_args: Dict[str, Any]
    ) -> Dict[str, Any]:
        """POST ``{"tool_name": ..., "arguments": ...}`` for ``instance`` and return the JSON reply.

        If ``call_tool_endpoint_template`` is configured, the call is proxied
        through the orchestration service (the template is formatted with
        ``remote_instance_id``). Otherwise it is sent directly to
        ``instance.mcp_endpoint_url``.

        Raises:
            ValueError: for non-``remote_http_api`` instances, or when the id/URL
                required by the chosen path is missing.
            RuntimeError: if the backend config / RemoteApiConfig cannot be found,
                or propagated from ``_make_request`` on HTTP failures.
        """
        if instance.orchestration_mode != "remote_http_api":
            raise ValueError("This client only handles remote_http_api instances.")

        backend_cfg = self._find_backend_config(instance.backend_name_ref)
        if not backend_cfg:
            raise RuntimeError(f"Could not find BackendServerConfig for {instance.backend_name_ref}")

        remote_api_config = self.app_config.get_remote_api_config(backend_cfg)
        if not remote_api_config:
            raise RuntimeError(f"RemoteApiConfig not found for backend {instance.backend_name_ref}.")

        mcp_payload = {"tool_name": tool_name, "arguments": tool_args}

        target_url: str
        # Tool calls are either proxied through the orchestrator or made directly to the instance.
        if remote_api_config.call_tool_endpoint_template:
            remote_instance_id = (instance.internal_instance_details or {}).get("remote_instance_id")
            if not remote_instance_id:
                raise ValueError(
                    f"Missing remote_instance_id for instance {instance.instance_id} when proxying tool call."
                )

            # Assumes a generic proxy endpoint that takes tool_name in the payload and
            # forwards mcp_payload as-is; adjust if the template also needs tool_name.
            call_path = remote_api_config.call_tool_endpoint_template.format(remote_instance_id=remote_instance_id)
            target_url = self._join_url(remote_api_config.base_url, call_path)
            logger.debug(f"Proxying tool {tool_name} to {target_url} for instance {instance.instance_id}")
        else:
            if not instance.mcp_endpoint_url:
                raise ValueError(
                    f"Missing mcp_endpoint_url for instance {instance.instance_id} when calling tool directly."
                )
            target_url = instance.mcp_endpoint_url
            logger.debug(f"Calling tool {tool_name} directly on {target_url} for instance {instance.instance_id}")

        response = await self._make_request("POST", target_url, remote_api_config, json_payload=mcp_payload)
        return response.json()

    async def list_tools_on_instance(self, instance: ManagedInstanceInfo) -> mcp_types.ListToolsResult:
        """List the tools of a remote instance over the MCP streamable-HTTP transport.

        Opens a short-lived MCP session against ``instance.mcp_endpoint_url``,
        performs the MCP ``initialize`` handshake, and returns the ``list_tools`` result.

        Raises:
            ValueError: if the instance is not ``remote_http_api``, is not configured
                for HTTP MCP transport, or has no ``mcp_endpoint_url``.
            RuntimeError: if connecting to or querying the MCP server fails.
        """
        if instance.orchestration_mode != "remote_http_api":
            raise ValueError("RemoteHttpOrchestrationClient can only list tools for 'remote_http_api' instances.")
        if instance.mcp_transport != "http" or not instance.mcp_endpoint_url:
            raise ValueError(
                f"Instance {instance.instance_id} ({instance.backend_name_ref}) is not configured for HTTP MCP transport or mcp_endpoint_url is missing."
            )

        # Assumes instance.mcp_endpoint_url is the URL of the target MCP server,
        # e.g. "http://localhost:12345".
        target_base_url = instance.mcp_endpoint_url.rstrip("/")

        logger.info(
            f"Listing tools for remote HTTP instance {instance.instance_id} ({instance.backend_name_ref}) at base URL {target_base_url}"
        )

        try:
            # streamablehttp_client(url) yields raw (read, write, get_session_id) streams,
            # not a session: a ClientSession must be layered on top and initialized.
            async with streamablehttp_client(target_base_url) as (read_stream, write_stream, _get_session_id):
                async with ClientSession(read_stream, write_stream) as session:
                    await session.initialize()
                    list_tools_result = await session.list_tools()
            logger.info(
                f"Successfully listed {len(list_tools_result.tools)} tools from {target_base_url} for instance {instance.instance_id} ({instance.backend_name_ref})"
            )
            return list_tools_result
        except Exception as e:
            logger.error(
                f"Error listing tools from {target_base_url} for instance {instance.instance_id} ({instance.backend_name_ref}): {e}",
                exc_info=True,
            )
            raise RuntimeError(
                f"Failed to list tools from backend instance {instance.instance_id} ({instance.backend_name_ref}) at {target_base_url}"
            ) from e
@@ -0,0 +1,3 @@
1
+ # This file is intentionally left empty and can be deleted.
2
+ # The stdio communication will be handled directly by LocalDockerOrchestrationClient
3
+ # using Docker's attach capabilities, not via this helper script.
@@ -0,0 +1,79 @@
1
+ import logging
2
+ from typing import Dict, List, Optional, Set
3
+
4
+ from eval_protocol.mcp_agent.orchestration.base_client import ManagedInstanceInfo
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+ from dataclasses import dataclass, field
9
+
10
+ # Attempting to find ReadStream and WriteStream in a different location
11
+ # from mcp.server.streamable_transport import ReadStream, WriteStream # Original problematic import
12
+ # Option 1: Try mcp.server.transport
13
+ # from mcp.server.transport import ReadStream, WriteStream
14
+ # Option 2: If not found, use typing.Any as a fallback for type hints
15
+ from typing import Any as ReadStream # Fallback if specific types are not found
16
+ from typing import Any as WriteStream
17
+
18
+ from mcp.server.session import ServerSession # Correct base class
19
+
20
+ # Placeholder BaseSession class removed.
21
+ # IntermediarySession class is removed as we are using a separate data class.
22
+
23
+
24
@dataclass
class IntermediarySessionData:
    """
    Data class holding custom state for an intermediary session.

    This state is managed by RewardKitIntermediaryServer and keyed by the
    transport-level session_id. It records the backend instances provisioned
    for the session, grouped by backend reference. It also records the set of
    committed Docker image tags (``ManagedInstanceInfo.committed_image_tag``)
    seen on those instances.
    """

    session_id: str  # This is the transport-level session_id
    # backend_name_ref -> instances provisioned for that backend in this session.
    managed_backends: Dict[str, List[ManagedInstanceInfo]] = field(default_factory=dict)
    # Deduplicated committed image tags collected from added instances.
    temporary_docker_images: Set[str] = field(default_factory=set)

    def add_managed_instances(self, backend_name_ref: str, instances: List[ManagedInstanceInfo]) -> None:
        """Append ``instances`` under ``backend_name_ref`` and track any committed image tags."""
        self.managed_backends.setdefault(backend_name_ref, []).extend(instances)
        logger.info(
            f"SessionData {self.session_id}: Added {len(instances)} instances for backend '{backend_name_ref}'."
        )
        for instance in instances:
            if instance.committed_image_tag:
                self.temporary_docker_images.add(instance.committed_image_tag)
                logger.debug(
                    f"SessionData {self.session_id}: Tracking temporary image '{instance.committed_image_tag}'."
                )

    def get_managed_instances(
        self, backend_name_ref: str, instance_id: Optional[str] = None
    ) -> List[ManagedInstanceInfo]:
        """
        Retrieve managed instances for a backend reference.

        If ``instance_id`` is given (including an empty string), the result is a
        one-element list with the first matching instance, or ``[]`` if there is
        no match. If ``instance_id`` is None, all instances for
        ``backend_name_ref`` are returned. Unknown backends yield ``[]``.

        The returned list is always a new list, so mutating it does not alter the
        session's stored state.
        """
        backend_instances = self.managed_backends.get(backend_name_ref, [])

        # `is None` (not truthiness) so an empty-string id never widens to "all instances".
        if instance_id is None:
            return list(backend_instances)

        match = next((inst for inst in backend_instances if inst.instance_id == instance_id), None)
        return [match] if match is not None else []

    def get_all_managed_instances(self) -> List[ManagedInstanceInfo]:
        """Return a new flat list of every managed instance, across all backends."""
        return [inst for instances in self.managed_backends.values() for inst in instances]
76
+
77
+
78
+ # Note: The IntermediarySession class that inherited from ServerSession has been removed.
79
+ # The RewardKitIntermediaryServer will now manage IntermediarySessionData instances directly.