eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1064 @@
1
+ """
2
+ Task Manager for the Agent Evaluation Framework V2.
3
+ Coordinates multiple tasks and their associated resources.
4
+ """
5
+
6
+ import asyncio
7
+ import importlib
8
+ import json
9
+ import logging
10
+ import os
11
+ import shlex
12
+ import socket
13
+ import statistics
14
+ import subprocess
15
+ import time
16
+ from copy import deepcopy
17
+ from datetime import datetime
18
+ from pathlib import Path
19
+ from typing import Any, Dict, List, Optional, Set, Tuple
20
+
21
+ import requests
22
+
23
+ from ..models import TaskDefinitionModel
24
+ from .orchestrator import Orchestrator
25
+ from .resource_abc import ForkableResource
26
+ from .resource_pool import ResourcePool
27
+
28
+
29
+ class TaskManager:
30
+ """
31
+ Manages the execution of multiple agent evaluation tasks.
32
+ Coordinates resources, orchestrators, and execution flows.
33
+ """
34
+
35
def __init__(self):
    """Create a TaskManager with no registered tasks, orchestrators or running servers."""
    self.logger = logging.getLogger("TaskManager")
    self.resource_pool = ResourcePool()
    # Task definitions keyed by task id (filled by register_task).
    self.tasks: Dict[str, TaskDefinitionModel] = {}
    # One orchestrator per prepared task id.
    self.orchestrators: Dict[str, Orchestrator] = {}
    # Resource-server bookkeeping, keyed by task id (or by a per-rollout id).
    self.server_processes: Dict[str, subprocess.Popen] = {}
    self.server_ports: Dict[str, int] = {}
    # Pids of server processes that have been started and not yet stopped.
    self.all_server_pids: Set[int] = set()
45
def register_task(self, task_definition_or_name, task_definition=None) -> str:
    """
    Add a task definition to the registry and return the id it was stored under.

    Two call forms are accepted:
      * ``register_task(definition)`` (legacy): the id is taken from ``definition.name``.
      * ``register_task(name, definition)``: the id is ``name``.

    An existing entry with the same id is replaced, and a warning is logged.

    Args:
        task_definition_or_name: A TaskDefinitionModel (legacy form) or the task name.
        task_definition: The TaskDefinitionModel when the first argument is a name.

    Returns:
        task_id: The key the definition was registered under.
    """
    if task_definition is not None:
        # Two-argument form: explicit name plus definition.
        task_id, task_def = task_definition_or_name, task_definition
    else:
        # Legacy single-argument form: the definition names itself.
        task_def = task_definition_or_name
        task_id = task_def.name

    already_known = task_id in self.tasks
    if already_known:
        self.logger.warning(f"Task '{task_id}' is already registered. Overwriting.")

    self.tasks[task_id] = task_def
    self.logger.info(f"Registered task: {task_id}")
    return task_id
73
def register_tasks_from_directory(self, directory_path: str) -> List[str]:
    """
    Register every task definition file found directly inside a directory.

    YAML files (matching ``*.y*ml``, e.g. ``.yaml`` / ``.yml``) are loaded before JSON files
    (``*.json``). Within each group, files are processed in sorted path order so that
    registration is reproducible across filesystems. This matters when two files define the
    same task name: the later registration overwrites the earlier one. Subdirectories are not
    searched.

    Args:
        directory_path: Path to directory containing task definition files

    Returns:
        task_ids: IDs of the tasks that were registered, in registration order. Files that
        fail to load are logged and skipped.
    """
    task_ids: List[str] = []
    dir_path = Path(directory_path)

    # Path.is_dir() is False for missing paths too, so one check covers both cases.
    if not dir_path.is_dir():
        self.logger.error(f"Directory not found or not a directory: {directory_path}")
        return task_ids

    # glob() order is filesystem-dependent; sort for deterministic overwrite behaviour.
    candidate_files = sorted(dir_path.glob("*.y*ml")) + sorted(dir_path.glob("*.json"))

    for file_path in candidate_files:
        try:
            task_def = self._load_task_from_file(str(file_path))
            if task_def:
                task_ids.append(self.register_task(task_def))
        except Exception as e:
            self.logger.error(f"Error loading task from {file_path}: {e}")

    self.logger.info(f"Registered {len(task_ids)} tasks from {directory_path}")
    return task_ids
111
def _load_task_from_file(self, file_path: str) -> Optional[TaskDefinitionModel]:
    """
    Load a YAML or JSON task definition file and validate it as a TaskDefinitionModel.

    YAML is tried first when PyYAML is installed. If PyYAML is missing, or YAML loading fails,
    the file is parsed as JSON. Files are decoded as UTF-8 rather than the platform's locale
    default. The resolved absolute file path is stored under ``task_def_path`` before
    validation.

    Args:
        file_path: Path to the task definition file

    Returns:
        task_def: The validated TaskDefinitionModel, or None if the file is missing, cannot be
        parsed, does not contain a top-level mapping, or fails validation (the reason is logged).
    """
    file_path_obj = Path(file_path)
    # is_file() is False for missing paths as well, so this covers "not found" too.
    if not file_path_obj.is_file():
        self.logger.error(f"File not found or not a file: {file_path}")
        return None

    try:
        try:
            import yaml

            with open(file_path, "r", encoding="utf-8") as f:
                task_data = yaml.safe_load(f)
        except ImportError:
            # PyYAML is not available: JSON is the only format we can read.
            with open(file_path, "r", encoding="utf-8") as f:
                task_data = json.load(f)
        except Exception:
            # YAML loading failed; fall back to JSON.
            with open(file_path, "r", encoding="utf-8") as f:
                task_data = json.load(f)

        if not isinstance(task_data, dict):
            # e.g. an empty file (None) or a top-level list; report it clearly instead of
            # failing on the item assignment below.
            raise TypeError(f"expected a mapping at the top level, got {type(task_data).__name__}")

        # Store the original file path for downstream use
        task_data["task_def_path"] = str(file_path_obj.resolve())

        return TaskDefinitionModel.model_validate(task_data)
    except Exception as e:
        self.logger.error(f"Error loading task definition from {file_path}: {e}")
        return None
152
def _find_free_port(self) -> int:
    """
    Ask the operating system for an unused TCP port and return its number.

    Binding to port 0 lets the kernel choose a free port. The probe socket is closed before
    returning, so the port is only likely (not guaranteed) to still be free when the caller
    binds it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("", 0))
        probe.listen(1)
        _host, chosen_port = probe.getsockname()
    return chosen_port
160
def _wait_for_server_health(self, health_url: str, timeout: int = 30, poll_interval: float = 1.0) -> bool:
    """
    Poll ``health_url`` until it answers HTTP 200 or ``timeout`` seconds have elapsed.

    Each GET uses a 5 second request timeout. Request errors are treated as "not ready yet".
    The deadline is measured with ``time.monotonic()``, so system clock adjustments cannot cut
    the wait short or stretch it. This call blocks the calling thread: it uses synchronous
    requests and ``time.sleep``.

    Args:
        health_url: URL to poll.
        timeout: Total number of seconds to keep polling.
        poll_interval: Seconds to sleep between attempts (defaults to 1, the previous fixed value).

    Returns:
        True as soon as a 200 response is seen, False if the deadline passes first.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            response = requests.get(health_url, timeout=5)
            if response.status_code == 200:
                self.logger.info(f"Server is healthy at {health_url}")
                return True
        except requests.exceptions.RequestException:
            pass  # server not accepting connections yet; try again after the interval
        time.sleep(poll_interval)

    self.logger.error(f"Server failed to become healthy at {health_url} within {timeout} seconds")
    return False
176
def _start_resource_server(self, task_id: str, task_def: TaskDefinitionModel) -> Optional[int]:
    """
    Launch the task's resource server on a fresh local port and wait until it is healthy.

    The ``{port}`` placeholder in ``resource_server.start_command`` and in
    ``resource_server.health_check_url`` is replaced with the chosen port. The process is
    tracked under ``task_id`` in ``server_processes``, ``server_ports`` and
    ``all_server_pids``.

    Args:
        task_id: Key to register the server under (a task id or a per-rollout id).
        task_def: Task definition whose ``resource_server`` section describes the server.

    Returns:
        The allocated port on success. None if the task has no ``resource_server``, if the
        server never became healthy, or if launching it raised. In the failure cases, any
        process started here is stopped again.
    """
    server_config = task_def.resource_server
    if not server_config:
        return None

    # Find a free port
    port = self._find_free_port()

    # Replace {port} placeholder in start command
    start_command = server_config.start_command.replace("{port}", str(port))

    process = None
    try:
        self.logger.info(f"Starting resource server for task '{task_id}' on port {port}: {start_command}")
        # NOTE(review): stdout/stderr are piped but not drained while the server runs, so a
        # very verbose server may block once the OS pipe buffer fills. They remain pipes
        # because the rollout error paths read them via process.communicate().
        process = subprocess.Popen(
            shlex.split(start_command),
            shell=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            # Run the server in its own session / process group (setsid in the child, POSIX
            # only) so it can be stopped with os.killpg. Unlike preexec_fn=os.setsid, this is
            # safe to use when threads are running.
            start_new_session=hasattr(os, "setsid"),
        )

        # Store the process and port
        self.server_processes[task_id] = process
        self.server_ports[task_id] = port
        self.all_server_pids.add(process.pid)

        # Wait for server to become healthy
        health_url = server_config.health_check_url.replace("{port}", str(port))
        if self._wait_for_server_health(health_url):
            self.logger.info(f"Resource server for task '{task_id}' is ready on port {port}")
            return port

        # Health check failed: note an early exit (e.g. a bad command), then clean up.
        exit_code = process.poll()
        if exit_code is not None:
            self.logger.error(
                f"Resource server for task '{task_id}' exited with code {exit_code} before becoming healthy"
            )
        self._stop_resource_server(task_id)
        return None

    except Exception as e:
        self.logger.error(f"Failed to start resource server for task '{task_id}': {e}")
        # Don't orphan a server that was already launched when the failure happened.
        if process is not None and self.server_processes.get(task_id) is process:
            self._stop_resource_server(task_id)
        return None
217
def _stop_resource_server(self, task_id: str) -> None:
    """
    Stop the resource server registered under ``task_id`` and drop its bookkeeping.

    Stop sequence:
      * Send SIGTERM. Where ``os.killpg`` exists it goes to the server's whole process group,
        but never to this process's own group; otherwise ``Popen.terminate`` is used.
      * Wait up to 5 seconds, then escalate to SIGKILL / ``Popen.kill``.
      * Skip signalling entirely if the child has already been reaped, because its pid may
        since have been reused by an unrelated process.

    Cleanup guarantees:
      * Errors are logged, not raised.
      * The ``server_processes`` and ``server_ports`` entries, and the pid in
        ``all_server_pids``, are always removed.
      * The child's piped stdout/stderr are closed so their file descriptors are not leaked.
      * For an unknown ``task_id``, only a stale ``server_ports`` entry (if any) is removed.
    """
    import signal

    process = self.server_processes.pop(task_id, None)
    self.server_ports.pop(task_id, None)
    if process is None:
        return

    self.all_server_pids.discard(process.pid)

    def _send(graceful: bool) -> None:
        """Deliver SIGTERM (graceful) or SIGKILL to the server, preferring its process group."""
        if hasattr(os, "killpg"):
            try:
                pgid = os.getpgid(process.pid)
                # If the server shares our process group (no new session), signalling the
                # group would also hit this process, so fall back to signalling the child.
                if pgid != os.getpgrp():
                    os.killpg(pgid, signal.SIGTERM if graceful else signal.SIGKILL)
                    return
            except ProcessLookupError:
                return  # the process (group) is already gone
        if graceful:
            process.terminate()
        else:
            process.kill()

    try:
        # returncode is set once Popen has reaped the child; after that the pid may belong to
        # someone else, so only signal while the child is still unreaped.
        if process.returncode is None:
            _send(graceful=True)
            try:
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # Force kill if it doesn't shut down gracefully
                _send(graceful=False)
                process.wait()

        self.logger.info(f"Stopped resource server for task '{task_id}'")
    except Exception as e:
        self.logger.error(f"Error stopping resource server for task '{task_id}': {e}")
    finally:
        for stream in (process.stdout, process.stderr):
            if stream is not None:
                try:
                    stream.close()
                except OSError:
                    pass  # best effort: the pipe is being discarded either way
249
async def prepare_task(self, task_id: str) -> bool:
    """
    Prepare a task for execution: start its resource server (if any) and set up its orchestrator.

    When a resource server is started, the orchestrator receives a deep copy of the task
    definition whose ``base_resource_config["base_url"]`` points at
    ``http://localhost:<port>``.

    Args:
        task_id: Identifier of the task to prepare

    Returns:
        success: True once the orchestrator's base resource is set up and the orchestrator is
        stored in ``self.orchestrators``. False otherwise. On failure, no orchestrator is left
        registered for ``task_id``, and any resource server started here is stopped.
    """
    if task_id not in self.tasks:
        self.logger.error(f"Task '{task_id}' is not registered.")
        return False

    task_def = self.tasks[task_id]

    # Re-preparing would otherwise overwrite (and orphan) a server still tracked under this id.
    if task_id in self.server_processes:
        self._stop_resource_server(task_id)

    # Start resource server if needed
    allocated_port = None
    if task_def.resource_server:
        allocated_port = self._start_resource_server(task_id, task_def)
        if allocated_port is None:
            self.logger.error(f"Failed to start resource server for task '{task_id}'")
            self.orchestrators.pop(task_id, None)
            return False

    orchestrator = None
    try:
        effective_task_def = task_def
        if allocated_port is not None:
            # Point a private copy of the definition at the server we just started.
            effective_task_def = deepcopy(task_def)
            effective_task_def.base_resource_config["base_url"] = f"http://localhost:{allocated_port}"

        orchestrator = Orchestrator(task_definition=effective_task_def)
        # Resource setup is handled by the orchestrator
        await orchestrator.setup_base_resource()
    except Exception as e:
        self.logger.error(f"Error preparing resources for task '{task_id}': {e}")
        # Never leave a half-initialised orchestrator for execute_task to pick up.
        self.orchestrators.pop(task_id, None)
        # Release anything setup_base_resource() created before it failed.
        base_resource = getattr(orchestrator, "base_resource", None)
        if base_resource:
            try:
                await base_resource.close()
            except Exception as close_error:
                self.logger.warning(f"Failed to close base resource for task '{task_id}': {close_error}")
        # Clean up server if we started one
        if allocated_port is not None:
            self._stop_resource_server(task_id)
        return False

    self.orchestrators[task_id] = orchestrator
    return True
304
async def execute_task(self, task_id: str) -> Optional[Dict[str, Any]]:
    """
    Run one registered task through its orchestrator, preparing it first if needed.

    Args:
        task_id: Identifier of the task to execute

    Returns:
        The orchestrator's result on success. ``None`` when the task is not registered or
        preparation fails. ``{"error": <message>}`` when execution raises.
    """
    if task_id not in self.tasks:
        self.logger.error(f"Task '{task_id}' is not registered.")
        return None

    needs_preparation = task_id not in self.orchestrators
    if needs_preparation:
        self.logger.info(f"Task '{task_id}' orchestrator not initialized. Preparing task...")
        prepared = await self.prepare_task(task_id)
        if not prepared:
            self.logger.error(f"Failed to prepare task '{task_id}'.")
            return None

    runner = self.orchestrators[task_id]

    try:
        self.logger.info(f"Executing task '{task_id}'...")
        outcome = await runner.execute_task_poc()
        self.logger.info(f"Task '{task_id}' execution completed.")
    except Exception as exc:
        self.logger.error(f"Error executing task '{task_id}': {exc}", exc_info=True)
        return {"error": str(exc)}
    return outcome
336
async def execute_tasks(
    self,
    task_ids: Optional[List[str]] = None,
    parallel: bool = False,
    max_concurrency: int = 3,
    num_rollouts_override: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Run the given (or all) registered tasks one after another and collect their results.

    Per task:
      * ``dataset_path`` set: every dataset sample is rolled out ``num_rollouts_per_sample`` times.
      * otherwise: ``num_rollouts`` rollouts (or ``num_rollouts_override``) are run. A single
        rollout goes through prepare_task/execute_task; several go through the batch runner.
    Multi-rollout results are aggregated, and the aggregate is also passed to
    ``_save_detailed_results``. A failure to save is only logged.

    Args:
        task_ids: Task IDs to run; None means every registered task. Unknown IDs are logged
            and skipped.
        parallel: Accepted for API compatibility but not read; tasks always run sequentially.
        max_concurrency: Maximum number of concurrent rollouts within a single task.
        num_rollouts_override: Replaces ``num_rollouts`` for tasks without a dataset.

    Returns:
        results: Mapping of task ID to its result, aggregated result, or an ``{"error": ...}`` dict.
    """
    requested = list(self.tasks.keys()) if task_ids is None else task_ids

    known = [tid for tid in requested if tid in self.tasks]
    if len(known) != len(requested):
        unknown = set(requested) - set(known)
        self.logger.warning(f"Some task IDs are not registered: {unknown}")

    if not known:
        self.logger.error("No valid tasks to execute.")
        return {}

    results: Dict[str, Any] = {}

    for task_id in known:
        task_def = self.tasks[task_id]

        if task_def.dataset_path:
            # Data-driven evaluation: every sample gets its own set of rollouts.
            samples = self._load_dataset_samples(task_def.dataset_path)
            if not samples:
                results[task_id] = {"error": "Failed to load dataset or dataset is empty"}
                continue

            self.logger.info(
                f"Executing data-driven evaluation for task '{task_id}': {len(samples)} samples, {task_def.num_rollouts_per_sample} rollouts per sample"
            )
            rollout_results = await self._execute_data_driven_rollouts(
                task_id, samples, task_def.num_rollouts_per_sample, max_concurrency
            )
        else:
            num_rollouts = task_def.num_rollouts if num_rollouts_override is None else num_rollouts_override

            if num_rollouts == 1:
                # A single rollout uses the prepare/execute path and is not aggregated.
                prepared = await self.prepare_task(task_id)
                if prepared:
                    results[task_id] = await self.execute_task(task_id)
                else:
                    results[task_id] = {"error": "Task preparation failed"}
                continue

            self.logger.info(f"Executing {num_rollouts} rollouts for task '{task_id}'")
            rollout_results = await self._execute_batch_rollouts(task_id, num_rollouts, max_concurrency)

        if not rollout_results:
            results[task_id] = {"error": "All rollouts failed"}
            continue

        aggregated_result = self._aggregate_results(rollout_results)
        results[task_id] = aggregated_result

        # Persist per-rollout details (failed rollouts included); failure here is non-fatal.
        try:
            saved_path = self._save_detailed_results(task_id, aggregated_result)
            self.logger.info(f"Detailed results saved to: {saved_path}")
        except Exception as e:
            self.logger.error(f"Failed to save detailed results for task '{task_id}': {e}")

    return results
419
async def _execute_batch_rollouts(
    self, task_id: str, num_rollouts: int, max_concurrency: int
) -> List[Dict[str, Any]]:
    """
    Run ``num_rollouts`` independent rollouts of one task, at most ``max_concurrency`` at a time.

    Each rollout gets its own Orchestrator. When the task defines ``resource_server``, each
    rollout also gets its own server process, registered under ``"{task_id}_rollout_{i}"``.
    The orchestrator's base resource and the server are released in ``finally``, so a failure
    during setup or execution no longer leaks them.

    Args:
        task_id: The base task ID
        num_rollouts: Number of rollouts to execute
        max_concurrency: Maximum number of concurrent rollouts

    Returns:
        One entry per rollout, in rollout order: the evaluation result as a dict, or
        ``{"error": <message>}`` for a failed rollout. Failed rollouts are kept so callers can
        analyse them.
    """
    task_def = self.tasks[task_id]

    # Bounds how many rollouts (and therefore resource servers) are active at once.
    semaphore = asyncio.Semaphore(max_concurrency)

    async def execute_single_rollout(rollout_index: int):
        """Execute a single rollout with its own orchestrator (and server instance, if needed)."""
        rollout_task_id = f"{task_id}_rollout_{rollout_index}"

        async with semaphore:
            allocated_port = None
            orchestrator = None
            try:
                if task_def.resource_server:
                    # NOTE(review): _start_resource_server blocks while polling health, so
                    # server start-up is not overlapped across concurrent rollouts.
                    allocated_port = self._start_resource_server(rollout_task_id, task_def)
                    if allocated_port is None:
                        self.logger.error(
                            f"Failed to start resource server for rollout {rollout_index} of task '{task_id}'"
                        )
                        return {"error": f"Failed to start resource server for rollout {rollout_index}"}

                effective_task_def = task_def
                if allocated_port is not None:
                    # Point this rollout's private copy of the definition at its own server.
                    effective_task_def = deepcopy(task_def)
                    effective_task_def.base_resource_config["base_url"] = f"http://localhost:{allocated_port}"

                orchestrator = Orchestrator(task_definition=effective_task_def)
                await orchestrator.setup_base_resource()
                result = await orchestrator.execute_task_poc()

                if result is None:
                    result = {"error": "Execution returned None"}

                # Newer orchestrators return {"evaluation_result": ..., "reward_function_inputs": ...}.
                reward_function_inputs = None
                if isinstance(result, dict) and "evaluation_result" in result:
                    reward_function_inputs = result.get("reward_function_inputs")
                    result = result["evaluation_result"]

                # Pydantic v2 models expose model_dump(), v1 models dict(); plain dicts pass through.
                if hasattr(result, "model_dump"):
                    result = result.model_dump()
                elif hasattr(result, "dict"):
                    result = result.dict()

                # Keep the reward function inputs alongside the result for trajectory storage.
                if reward_function_inputs is not None and isinstance(result, dict):
                    result["reward_function_inputs"] = reward_function_inputs

                score = result.get("score", "N/A") if isinstance(result, dict) else "N/A"
                self.logger.info(f"Rollout {rollout_index} of task '{task_id}' completed with score: {score}")
                return result

            except Exception as e:
                error_msg = f"Error in rollout {rollout_index} of task '{task_id}': {e}"
                self.logger.error(error_msg, exc_info=True)

                # Best effort: surface the server's output to help debug the failure. This runs
                # before the finally-block stops the server, so the process is still tracked.
                process = self.server_processes.get(rollout_task_id)
                if process is not None:
                    try:
                        stdout, stderr = process.communicate(timeout=1)
                        if stdout:
                            self.logger.error(f"Server stdout for rollout {rollout_index}: {stdout.decode()}")
                        if stderr:
                            self.logger.error(f"Server stderr for rollout {rollout_index}: {stderr.decode()}")
                    except Exception:
                        pass  # diagnostics only; a still-running server times out here

                return {"error": str(e)}

            finally:
                # Release this rollout's resources whether it succeeded or failed.
                base_resource = getattr(orchestrator, "base_resource", None)
                if base_resource:
                    try:
                        await base_resource.close()
                    except Exception as close_error:
                        self.logger.warning(
                            f"Failed to close base resource for rollout {rollout_index} of task '{task_id}': {close_error}"
                        )
                if allocated_port is not None:
                    self._stop_resource_server(rollout_task_id)

    # Execute all rollouts concurrently (bounded by the semaphore); gather preserves order.
    rollout_results = await asyncio.gather(*(execute_single_rollout(i) for i in range(num_rollouts)))

    failed_count = sum(1 for r in rollout_results if isinstance(r, dict) and "error" in r)
    if failed_count > 0:
        self.logger.warning(f"{failed_count} out of {num_rollouts} rollouts failed for task '{task_id}'")

    # Return all results (successful and failed) for comprehensive logging
    return rollout_results
544
def _load_dataset_samples(self, dataset_path: str) -> List[Dict[str, Any]]:
    """
    Load samples from a JSONL dataset file (one JSON value per non-blank line).

    Relative paths are resolved against the current working directory. The file is decoded
    as UTF-8, the encoding JSON Lines uses, instead of the platform's locale default. Lines
    that are not valid JSON are logged and skipped.

    Args:
        dataset_path: Path to the JSONL dataset file

    Returns:
        The parsed samples in file order. An empty list if the file is missing or cannot be
        read.
    """
    try:
        samples = []
        # Support both absolute and relative paths
        if not os.path.isabs(dataset_path):
            # Make relative paths relative to the current working directory
            dataset_path = os.path.abspath(dataset_path)

        if not os.path.exists(dataset_path):
            self.logger.error(f"Dataset file not found: {dataset_path}")
            return []

        with open(dataset_path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    samples.append(json.loads(line))
                except json.JSONDecodeError as e:
                    self.logger.error(f"Invalid JSON on line {line_num} in {dataset_path}: {e}")
                    continue

        self.logger.info(f"Loaded {len(samples)} samples from {dataset_path}")
        return samples

    except Exception as e:
        self.logger.error(f"Error loading dataset from {dataset_path}: {e}")
        return []
584
+ async def _execute_data_driven_rollouts(
585
+ self,
586
+ task_id: str,
587
+ samples: List[Dict[str, Any]],
588
+ rollouts_per_sample: int,
589
+ max_concurrency: int,
590
+ ) -> List[Dict[str, Any]]:
591
+ """
592
+ Execute data-driven rollouts where each sample from the dataset is used for multiple rollouts.
593
+
594
+ Args:
595
+ task_id: The base task ID
596
+ samples: List of samples from the dataset
597
+ rollouts_per_sample: Number of rollouts to execute per sample
598
+ max_concurrency: Maximum number of concurrent rollouts
599
+
600
+ Returns:
601
+ List of results from all rollouts across all samples
602
+ """
603
+ task_def = self.tasks[task_id]
604
+ all_rollout_results = []
605
+
606
+ # Create a semaphore to limit concurrency
607
+ semaphore = asyncio.Semaphore(max_concurrency)
608
+
609
+ async def execute_single_rollout(sample_index: int, rollout_index: int, sample_data: Dict[str, Any]):
610
+ """Execute a single rollout with sample data."""
611
+ rollout_task_id = f"{task_id}_sample_{sample_index}_rollout_{rollout_index}"
612
+
613
+ async with semaphore:
614
+ try:
615
+ # Start resource server if needed for this rollout
616
+ allocated_port = None
617
+ if task_def.resource_server:
618
+ allocated_port = self._start_resource_server(rollout_task_id, task_def)
619
+ if allocated_port is None:
620
+ self.logger.error(
621
+ f"Failed to start resource server for rollout {rollout_index} of sample {sample_index} for task '{task_id}'"
622
+ )
623
+ return {
624
+ "error": f"Failed to start resource server for sample {sample_index}, rollout {rollout_index}",
625
+ "sample_data": sample_data,
626
+ }
627
+
628
+ # Create effective task definition with updated base_url if needed
629
+ effective_task_def = task_def
630
+ if allocated_port is not None:
631
+ effective_task_def = deepcopy(task_def)
632
+ if hasattr(effective_task_def.base_resource_config, "base_url"):
633
+ effective_task_def.base_resource_config["base_url"] = f"http://localhost:{allocated_port}"
634
+ elif "base_url" in effective_task_def.base_resource_config:
635
+ effective_task_def.base_resource_config["base_url"] = f"http://localhost:{allocated_port}"
636
+ else:
637
+ effective_task_def.base_resource_config["base_url"] = f"http://localhost:{allocated_port}"
638
+
639
+ # Create orchestrator for this rollout
640
+ orchestrator = Orchestrator(task_definition=effective_task_def)
641
+
642
+ # Setup and execute with sample data
643
+ await orchestrator.setup_base_resource()
644
+ result = await orchestrator.execute_task_poc(sample_data=sample_data)
645
+
646
+ # Cleanup orchestrator resources
647
+ if orchestrator.base_resource:
648
+ await orchestrator.base_resource.close()
649
+
650
+ # Stop the resource server for this rollout
651
+ if allocated_port is not None:
652
+ self._stop_resource_server(rollout_task_id)
653
+
654
+ # Handle case where result is None
655
+ if result is None:
656
+ result = {"error": "Execution returned None"}
657
+
658
+ # Handle new orchestrator format that includes reward_function_inputs
659
+ reward_function_inputs = None
660
+ if isinstance(result, dict) and "evaluation_result" in result:
661
+ # New format with separate evaluation_result and reward_function_inputs
662
+ reward_function_inputs = result.get("reward_function_inputs")
663
+ result = result["evaluation_result"]
664
+
665
+ # Convert EvaluateResult to dict if needed
666
+ if hasattr(result, "model_dump"):
667
+ # Pydantic model - convert to dict
668
+ result = result.model_dump()
669
+ elif hasattr(result, "dict"):
670
+ # Older pydantic models
671
+ result = result.dict()
672
+ # If it's already a dict, leave it as is
673
+
674
+ # Add reward function inputs to the result for JSONL trajectory storage
675
+ if reward_function_inputs is not None and isinstance(result, dict):
676
+ result["reward_function_inputs"] = reward_function_inputs
677
+
678
+ # Add sample metadata to the result
679
+ if isinstance(result, dict):
680
+ result["sample_data"] = sample_data
681
+ result["sample_index"] = sample_index
682
+ result["rollout_index"] = rollout_index
683
+
684
+ score = result.get("score", "N/A") if isinstance(result, dict) else "N/A"
685
+ self.logger.info(
686
+ f"Completed rollout {rollout_index} for sample {sample_index} of task '{task_id}' with score: {score}"
687
+ )
688
+ return result
689
+
690
+ except Exception as e:
691
+ self.logger.error(
692
+ f"Error in rollout {rollout_index} for sample {sample_index} of task '{task_id}': {e}",
693
+ exc_info=True,
694
+ )
695
+
696
+ # Try to capture server logs on error
697
+ if allocated_port is not None:
698
+ try:
699
+ process = self.server_processes.get(rollout_task_id)
700
+ if process:
701
+ stdout, stderr = process.communicate(timeout=1)
702
+ if stdout:
703
+ self.logger.error(
704
+ f"Server stdout for sample {sample_index}, rollout {rollout_index}: {stdout.decode()}"
705
+ )
706
+ if stderr:
707
+ self.logger.error(
708
+ f"Server stderr for sample {sample_index}, rollout {rollout_index}: {stderr.decode()}"
709
+ )
710
+ except Exception:
711
+ pass # Ignore errors in log capture
712
+
713
+ # Cleanup on error
714
+ if allocated_port is not None:
715
+ self._stop_resource_server(rollout_task_id)
716
+ return {
717
+ "error": str(e),
718
+ "sample_data": sample_data,
719
+ "sample_index": sample_index,
720
+ "rollout_index": rollout_index,
721
+ }
722
+
723
+ # Create rollout tasks for all samples
724
+ rollout_tasks = []
725
+ for sample_index, sample_data in enumerate(samples):
726
+ for rollout_index in range(rollouts_per_sample):
727
+ task = execute_single_rollout(sample_index, rollout_index, sample_data)
728
+ rollout_tasks.append(task)
729
+
730
+ # Execute all rollouts concurrently
731
+ all_rollout_results = await asyncio.gather(*rollout_tasks)
732
+
733
+ # Log summary statistics
734
+ total_rollouts = len(all_rollout_results)
735
+ successful_results = [r for r in all_rollout_results if not (isinstance(r, dict) and "error" in r)]
736
+ failed_count = total_rollouts - len(successful_results)
737
+
738
+ if failed_count > 0:
739
+ self.logger.warning(
740
+ f"{failed_count} out of {total_rollouts} total rollouts failed for task '{task_id}' "
741
+ f"({len(samples)} samples x {rollouts_per_sample} rollouts per sample)"
742
+ )
743
+
744
+ self.logger.info(
745
+ f"Completed data-driven evaluation for task '{task_id}': "
746
+ f"{len(successful_results)} successful rollouts out of {total_rollouts} total"
747
+ )
748
+
749
+ # Return all results (successful and failed) for comprehensive logging
750
+ return all_rollout_results
751
+
752
+ def _aggregate_results(self, rollout_results: List[Dict[str, Any]]) -> Dict[str, Any]:
753
+ """
754
+ Aggregate results from multiple rollouts into a single summary.
755
+
756
+ Args:
757
+ rollout_results: List of individual rollout results
758
+
759
+ Returns:
760
+ Aggregated result dictionary
761
+ """
762
+ if not rollout_results:
763
+ return {"error": "No successful rollouts to aggregate"}
764
+
765
+ # Separate successful and failed results
766
+ successful_results = []
767
+ failed_results = []
768
+ scores = []
769
+
770
+ for result in rollout_results:
771
+ if isinstance(result, dict) and result.get("error") is not None:
772
+ failed_results.append(result)
773
+ elif isinstance(result, dict) and "score" in result:
774
+ scores.append(result["score"])
775
+ successful_results.append(result)
776
+ else:
777
+ # Handle unexpected result format
778
+ failed_results.append({"error": f"Invalid result format: {result}"})
779
+
780
+ if not scores:
781
+ # Even with no successful rollouts, we still want to save failed rollout data
782
+ aggregated_result = {
783
+ "aggregated": True,
784
+ "num_rollouts": len(rollout_results),
785
+ "total_rollouts": len(rollout_results), # For compatibility with tests
786
+ "successful_rollouts": 0,
787
+ "failed_rollouts": len(failed_results),
788
+ "success_rate": 0.0,
789
+ "avg_score": 0.0,
790
+ "average_score": 0.0, # For compatibility with tests
791
+ "std_dev": 0.0,
792
+ "min_score": 0.0,
793
+ "max_score": 0.0,
794
+ "score": 0.0, # For compatibility with existing logging
795
+ "individual_scores": [],
796
+ "individual_results": rollout_results, # Include all results (failed)
797
+ "detailed_results": rollout_results, # For compatibility with tests
798
+ "successful_results": [],
799
+ "failed_results": failed_results,
800
+ "timestamp": datetime.now().isoformat(),
801
+ "error": "No valid scores found in rollout results",
802
+ }
803
+ return aggregated_result
804
+
805
+ # Calculate aggregated statistics
806
+ avg_score = sum(scores) / len(scores)
807
+ min_score = min(scores)
808
+ max_score = max(scores)
809
+ success_rate = len(scores) / len(rollout_results) if rollout_results else 0
810
+
811
+ # Calculate standard deviation
812
+ std_dev = statistics.stdev(scores) if len(scores) > 1 else 0.0
813
+
814
+ aggregated_result = {
815
+ "aggregated": True,
816
+ "num_rollouts": len(rollout_results),
817
+ "total_rollouts": len(rollout_results), # For compatibility with tests
818
+ "successful_rollouts": len(scores),
819
+ "failed_rollouts": len(failed_results),
820
+ "success_rate": success_rate,
821
+ "avg_score": avg_score,
822
+ "average_score": avg_score, # For compatibility with tests
823
+ "std_dev": std_dev,
824
+ "min_score": min_score,
825
+ "max_score": max_score,
826
+ "score": avg_score, # For compatibility with existing logging
827
+ "individual_scores": scores,
828
+ "individual_results": rollout_results, # Include all results (successful and failed)
829
+ "detailed_results": rollout_results, # For compatibility with tests
830
+ "successful_results": successful_results,
831
+ "failed_results": failed_results,
832
+ "timestamp": datetime.now().isoformat(),
833
+ }
834
+
835
+ # Aggregate metrics if available
836
+ if successful_results and "metrics" in successful_results[0]:
837
+ aggregated_metrics = {}
838
+ for metric_name in successful_results[0]["metrics"].keys():
839
+ metric_scores = []
840
+ for result in successful_results:
841
+ if metric_name in result.get("metrics", {}):
842
+ metric_result = result["metrics"][metric_name]
843
+ if isinstance(metric_result, dict) and "score" in metric_result:
844
+ metric_scores.append(metric_result["score"])
845
+ elif isinstance(metric_result, (int, float)):
846
+ metric_scores.append(metric_result)
847
+
848
+ if metric_scores:
849
+ aggregated_metrics[metric_name] = {
850
+ "avg_score": sum(metric_scores) / len(metric_scores),
851
+ "min_score": min(metric_scores),
852
+ "max_score": max(metric_scores),
853
+ "individual_scores": metric_scores,
854
+ }
855
+
856
+ if aggregated_metrics:
857
+ aggregated_result["aggregated_metrics"] = aggregated_metrics
858
+
859
+ return aggregated_result
860
+
861
+ def _save_detailed_results(
862
+ self,
863
+ task_id: str,
864
+ aggregated_result: Dict[str, Any],
865
+ output_file: Optional[str] = None,
866
+ ) -> str:
867
+ """
868
+ Save detailed results to a .jsonl file for analysis.
869
+
870
+ Args:
871
+ task_id: The task identifier
872
+ aggregated_result: The aggregated result dictionary
873
+ output_file: Optional custom output file path
874
+
875
+ Returns:
876
+ The path to the saved file
877
+ """
878
+ if output_file is None:
879
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
880
+
881
+ # Prefer evaluation_logs relative to the task definition file
882
+ chosen_dir = None
883
+
884
+ task_def = self.tasks.get(task_id)
885
+ if task_def is not None and hasattr(task_def, "task_def_path"):
886
+ try:
887
+ task_def_path = Path(getattr(task_def, "task_def_path"))
888
+ base_dir = task_def_path.parent
889
+ eval_dir = base_dir / "evaluation_logs"
890
+ eval_dir.mkdir(parents=True, exist_ok=True)
891
+ chosen_dir = eval_dir
892
+ except Exception as e:
893
+ self.logger.warning(f"Failed to create evaluation_logs relative to task definition: {e}")
894
+
895
+ if chosen_dir is None:
896
+ # Look for or create common evaluation log directories
897
+ possible_log_dirs = [
898
+ Path("client/evaluation_logs"),
899
+ Path("evaluation_logs"),
900
+ Path("logs"),
901
+ Path("."), # Fallback to current directory
902
+ ]
903
+
904
+ for log_dir in possible_log_dirs:
905
+ try:
906
+ log_dir.mkdir(parents=True, exist_ok=True)
907
+ except Exception:
908
+ continue
909
+ if log_dir.exists() and log_dir.is_dir():
910
+ chosen_dir = log_dir
911
+ break
912
+
913
+ if chosen_dir is None:
914
+ chosen_dir = Path(".")
915
+
916
+ output_file = chosen_dir / f"trajectory_{task_id}_{timestamp}.jsonl"
917
+
918
+ output_path = Path(output_file)
919
+
920
+ try:
921
+ self.logger.info(f"=== TRAJECTORY SAVE DEBUG START ===")
922
+ self.logger.info(f"Saving trajectory data to: {output_path}")
923
+ self.logger.info(f"Chosen directory: {chosen_dir}")
924
+ self.logger.info(f"Individual results count: {len(aggregated_result.get('individual_results', []))}")
925
+ self.logger.info(f"Output path parent directory exists: {output_path.parent.exists()}")
926
+
927
+ # Ensure the directory exists
928
+ output_path.parent.mkdir(parents=True, exist_ok=True)
929
+
930
+ with open(output_path, "w") as f:
931
+ # Write summary line
932
+ summary = {
933
+ "type": "summary",
934
+ "task_id": task_id,
935
+ "timestamp": aggregated_result.get("timestamp", datetime.now().isoformat()),
936
+ "num_rollouts": aggregated_result["num_rollouts"],
937
+ "successful_rollouts": aggregated_result["successful_rollouts"],
938
+ "failed_rollouts": aggregated_result.get("failed_rollouts", 0),
939
+ "success_rate": aggregated_result["success_rate"],
940
+ "avg_score": aggregated_result["avg_score"],
941
+ "std_dev": aggregated_result["std_dev"],
942
+ "min_score": aggregated_result["min_score"],
943
+ "max_score": aggregated_result["max_score"],
944
+ }
945
+ f.write(json.dumps(summary) + "\n")
946
+ self.logger.info(f"Wrote summary line to {output_path}")
947
+
948
+ # Write individual results
949
+ individual_results = aggregated_result.get("individual_results", [])
950
+ self.logger.info(f"Processing {len(individual_results)} individual results")
951
+ for i, result in enumerate(individual_results):
952
+ self.logger.info(f"Processing individual result {i}: {type(result)} - {len(str(result))} chars")
953
+
954
+ # Clean the result for JSON serialization
955
+ clean_result = {}
956
+ for key, value in result.items():
957
+ if key == "reward_function_inputs" and isinstance(value, dict):
958
+ # Clean the reward function inputs
959
+ clean_inputs = {}
960
+ for input_key, input_value in value.items():
961
+ if input_key == "state" and isinstance(input_value, dict):
962
+ # Clean the state by removing non-serializable objects
963
+ clean_state = {}
964
+ for state_key, state_value in input_value.items():
965
+ if state_key == "resource":
966
+ # Replace resource object with a string representation
967
+ clean_state[state_key] = f"<{type(state_value).__name__}>"
968
+ else:
969
+ clean_state[state_key] = state_value
970
+ clean_inputs[input_key] = clean_state
971
+ else:
972
+ clean_inputs[input_key] = input_value
973
+ clean_result[key] = clean_inputs
974
+ else:
975
+ clean_result[key] = value
976
+
977
+ detailed_result = {
978
+ "type": "individual_result",
979
+ "task_id": task_id,
980
+ "rollout_index": i,
981
+ "timestamp": datetime.now().isoformat(),
982
+ **clean_result,
983
+ }
984
+ f.write(json.dumps(detailed_result) + "\n")
985
+ self.logger.info(f"Wrote individual result {i} to {output_path}")
986
+
987
+ # Force flush to ensure data is written
988
+ f.flush()
989
+ import os
990
+
991
+ os.fsync(f.fileno())
992
+
993
+ self.logger.info(f"Successfully saved trajectory data to: {output_path}")
994
+ self.logger.info(f"Trajectory file size: {output_path.stat().st_size} bytes")
995
+ self.logger.info(f"=== TRAJECTORY SAVE DEBUG END ===")
996
+ return str(output_path)
997
+
998
+ except Exception as e:
999
+ self.logger.error(f"Failed to save detailed results to {output_path}: {e}")
1000
+ import traceback
1001
+
1002
+ self.logger.error(f"Traceback: {traceback.format_exc()}")
1003
+ return ""
1004
+
1005
def cleanup_all_servers(self) -> None:
    """
    A more robust cleanup that terminates any tracked server process.

    Every PID in ``self.all_server_pids`` is sent signal 9 (SIGKILL on POSIX).
    Where ``os.killpg`` exists the signal goes to the PID's whole process
    group, unless that group is this process's own group, in which case only
    the PID is killed. Otherwise it goes to the PID alone.

    Each PID is removed from the tracking set whether or not the kill
    succeeded. A process that has already exited is ignored silently; any
    other error is logged.
    """
    if not self.all_server_pids:
        self.logger.info("No tracked server PIDs to clean up.")
        return

    self.logger.info(f"Performing robust cleanup of all {len(self.all_server_pids)} tracked server PIDs.")

    # PID -> task_id map used only for log messages. Built once instead of
    # rescanning server_processes for every PID (first match wins, as before).
    pid_to_task: Dict[int, str] = {}
    for tid, proc in self.server_processes.items():
        pid_to_task.setdefault(proc.pid, tid)

    # Iterate over a copy as the set is modified inside the loop
    for pid in list(self.all_server_pids):
        task_id = pid_to_task.get(pid, "unknown_task")
        try:
            self.logger.warning(
                f"Force-cleaning up potentially orphaned server process for task '{task_id}' (PID: {pid})."
            )
            if hasattr(os, "killpg"):
                pgid = os.getpgid(pid)
                if pgid == os.getpgrp():
                    # The server shares our process group (it was not started in
                    # its own session); killpg would SIGKILL this process too.
                    os.kill(pid, 9)
                else:
                    os.killpg(pgid, 9)  # Use SIGKILL for forceful cleanup
            else:
                os.kill(pid, 9)
        except ProcessLookupError:
            # Process already gone, which is fine
            pass
        except Exception as e:
            self.logger.error(f"Error during robust cleanup of PID {pid}: {e}")
        finally:
            self.all_server_pids.discard(pid)
1038
+
1039
async def cleanup(self, task_ids: Optional[List[str]] = None) -> None:
    """
    Release per-task resources, then sweep any leftover server processes.

    Steps for each selected task:
    1. Call ``_stop_resource_server`` for it.
    2. If an orchestrator is registered, close its ``base_resource`` (when
       truthy). Errors from ``close()`` are logged, not raised.
    3. Remove the orchestrator from ``self.orchestrators``.

    Afterwards ``cleanup_all_servers`` is always called.

    Args:
        task_ids: Task IDs to clean up. ``None`` means every task that
            currently has an orchestrator registered.
    """
    if task_ids is None:
        targets = list(self.orchestrators.keys())
    else:
        targets = task_ids

    for tid in targets:
        self._stop_resource_server(tid)

        if tid not in self.orchestrators:
            continue

        orch = self.orchestrators[tid]
        if orch.base_resource:
            try:
                await orch.base_resource.close()
                self.logger.info(f"Cleaned up resources for task '{tid}'.")
            except Exception as exc:
                self.logger.error(f"Error cleaning up resources for task '{tid}': {exc}")
        del self.orchestrators[tid]

    # Kill anything still tracked (e.g. per-rollout servers with no orchestrator entry)
    self.cleanup_all_servers()