eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,1064 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Task Manager for the Agent Evaluation Framework V2.
|
|
3
|
+
Coordinates multiple tasks and their associated resources.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import asyncio
|
|
7
|
+
import importlib
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
import shlex
|
|
12
|
+
import socket
|
|
13
|
+
import statistics
|
|
14
|
+
import subprocess
|
|
15
|
+
import time
|
|
16
|
+
from copy import deepcopy
|
|
17
|
+
from datetime import datetime
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
|
20
|
+
|
|
21
|
+
import requests
|
|
22
|
+
|
|
23
|
+
from ..models import TaskDefinitionModel
|
|
24
|
+
from .orchestrator import Orchestrator
|
|
25
|
+
from .resource_abc import ForkableResource
|
|
26
|
+
from .resource_pool import ResourcePool
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class TaskManager:
|
|
30
|
+
"""
|
|
31
|
+
Manages the execution of multiple agent evaluation tasks.
|
|
32
|
+
Coordinates resources, orchestrators, and execution flows.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
def __init__(self):
    """Set up an empty manager: no tasks registered, no orchestrators, no servers running."""
    self.logger = logging.getLogger("TaskManager")

    # Task definitions keyed by task ID (filled by register_task).
    self.tasks: Dict[str, TaskDefinitionModel] = {}
    # One Orchestrator per prepared task ID (filled by prepare_task).
    self.orchestrators: Dict[str, Orchestrator] = {}
    self.resource_pool = ResourcePool()

    # Resource-server subprocess bookkeeping, keyed by task / rollout ID.
    self.server_processes: Dict[str, subprocess.Popen] = {}
    self.server_ports: Dict[str, int] = {}
    # PIDs of every server process that has been started and not yet stopped.
    self.all_server_pids: Set[int] = set()
|
|
44
|
+
|
|
45
|
+
def register_task(self, task_definition_or_name, task_definition=None) -> str:
    """
    Add a task definition to the registry, replacing any existing entry with the same ID.

    Two call forms are accepted:
      * ``register_task(task_definition)`` (legacy): the ID is ``task_definition.name``.
      * ``register_task(task_name, task_definition)``: the ID is ``task_name``.

    Args:
        task_definition_or_name: The TaskDefinitionModel (legacy form) or the task name.
        task_definition: The TaskDefinitionModel when the first argument is a name.

    Returns:
        task_id: The ID under which the definition was stored.
    """
    if task_definition is not None:
        # Current form: explicit name plus definition.
        task_id, task_def = task_definition_or_name, task_definition
    else:
        # Legacy form: the definition carries its own name.
        task_def = task_definition_or_name
        task_id = task_def.name

    already_present = task_id in self.tasks
    if already_present:
        self.logger.warning(f"Task '{task_id}' is already registered. Overwriting.")

    self.tasks[task_id] = task_def
    self.logger.info(f"Registered task: {task_id}")
    return task_id
|
|
72
|
+
|
|
73
|
+
def register_tasks_from_directory(self, directory_path: str) -> List[str]:
    """
    Register every task definition file found directly inside a directory.

    YAML files (``*.y*ml``) are processed first, then JSON files (``*.json``).
    Within each group files are processed in sorted order so that the
    registration order is deterministic. This also makes deterministic which
    file wins when two files define the same task name. A file that fails to
    load is logged and skipped; it does not abort the scan.

    Args:
        directory_path: Path to directory containing task definition files

    Returns:
        task_ids: List of task IDs that were successfully registered, in processing order
    """
    task_ids: List[str] = []
    dir_path = Path(directory_path)

    if not dir_path.is_dir():
        self.logger.error(f"Directory not found or not a directory: {directory_path}")
        return task_ids

    # glob() yields entries in filesystem order, which is not stable across platforms.
    yaml_files = sorted(dir_path.glob("*.y*ml"))
    json_files = sorted(dir_path.glob("*.json"))

    for file_path in [*yaml_files, *json_files]:
        try:
            task_def = self._load_task_from_file(str(file_path))
            if task_def:
                task_ids.append(self.register_task(task_def))
        except Exception as e:
            self.logger.error(f"Error loading task from {file_path}: {e}")

    self.logger.info(f"Registered {len(task_ids)} tasks from {directory_path}")
    return task_ids
|
|
110
|
+
|
|
111
|
+
def _load_task_from_file(self, file_path: str) -> Optional["TaskDefinitionModel"]:
    """
    Load and validate a task definition from a YAML or JSON file.

    The file is parsed with PyYAML when it is installed (JSON is also valid YAML
    in the common case). If PyYAML is missing or YAML parsing fails, the file is
    parsed as strict JSON. The top-level document must be a mapping. The absolute
    file path is injected as ``task_def_path`` before Pydantic validation.

    Args:
        file_path: Path to the task definition file

    Returns:
        task_def: A validated TaskDefinitionModel instance, or None if the file is
        missing, unparsable, not a mapping, or fails validation (the error is logged).
    """
    file_path_obj = Path(file_path)
    if not file_path_obj.is_file():
        self.logger.error(f"File not found or not a file: {file_path}")
        return None

    try:
        try:
            import yaml

            with open(file_path, "r", encoding="utf-8") as f:
                task_data = yaml.safe_load(f)
        except Exception:
            # Covers both a missing PyYAML (ImportError) and invalid YAML:
            # fall back to strict JSON parsing.
            with open(file_path, "r", encoding="utf-8") as f:
                task_data = json.load(f)

        if not isinstance(task_data, dict):
            raise ValueError(
                f"expected a mapping at the top level, got {type(task_data).__name__}"
            )

        # Store the original file path for downstream use
        task_data["task_def_path"] = str(file_path_obj.resolve())

        return TaskDefinitionModel.model_validate(task_data)
    except Exception as e:
        self.logger.error(f"Error loading task definition from {file_path}: {e}")
        return None
|
|
151
|
+
|
|
152
|
+
def _find_free_port(self) -> int:
    """Return a TCP port number the OS reports as currently unused.

    Binding to port 0 makes the kernel choose an ephemeral port. The probe socket
    is closed before returning, so another process could claim the port before
    the caller binds it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("", 0))
        probe.listen(1)
        _host, free_port = probe.getsockname()
    return free_port
|
|
159
|
+
|
|
160
|
+
def _wait_for_server_health(self, health_url: str, timeout: int = 30) -> bool:
    """Poll *health_url* until it answers HTTP 200 or *timeout* seconds elapse.

    The deadline is measured with a monotonic clock, so wall-clock adjustments
    cannot shorten or extend the wait. Each request timeout (at most 5 s) and
    each pause between polls (at most 1 s) is clamped to the time remaining, so
    the overall deadline is honoured.

    Args:
        health_url: URL of the server's health endpoint.
        timeout: Total number of seconds to keep polling.

    Returns:
        True as soon as the endpoint returns status 200, False if the deadline passes first.
    """
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            response = requests.get(health_url, timeout=min(5, remaining))
            if response.status_code == 200:
                self.logger.info(f"Server is healthy at {health_url}")
                return True
        except requests.exceptions.RequestException:
            pass  # Server not reachable yet; keep polling.
        time.sleep(max(0.0, min(1.0, deadline - time.monotonic())))

    self.logger.error(f"Server failed to become healthy at {health_url} within {timeout} seconds")
    return False
|
|
175
|
+
|
|
176
|
+
def _start_resource_server(self, task_id: str, task_def: "TaskDefinitionModel") -> Optional[int]:
    """Launch the task's resource server on a free port and wait until it is healthy.

    ``{port}`` placeholders in ``resource_server.start_command`` and
    ``resource_server.health_check_url`` are replaced with the allocated port.
    The process is tracked in ``server_processes``, ``server_ports`` and
    ``all_server_pids`` under *task_id*.

    Args:
        task_id: Key under which the server process is tracked.
        task_def: Task definition; nothing is started if it has no ``resource_server``.

    Returns:
        The port of the healthy server, or None if the task needs no server or the
        server could not be started. Any process spawned here is stopped again on failure.
    """
    if not task_def.resource_server:
        return None

    if task_id in self.server_processes:
        # Starting a second server under the same key would orphan the first one.
        self.logger.warning(f"Resource server for task '{task_id}' is already running; restarting it.")
        self._stop_resource_server(task_id)

    # Find a free port
    port = self._find_free_port()

    # Replace {port} placeholder in start command
    start_command = task_def.resource_server.start_command.replace("{port}", str(port))

    try:
        self.logger.info(f"Starting resource server for task '{task_id}' on port {port}: {start_command}")
        process = subprocess.Popen(
            shlex.split(start_command),
            shell=False,
            # NOTE(review): these pipes are only drained when a rollout fails; a very
            # chatty server could block once the OS pipe buffer fills up.
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            # Own session => own process group, which _stop_resource_server signals
            # as a whole. Thread-safe replacement for preexec_fn=os.setsid.
            start_new_session=True,
        )

        # Store the process and port
        self.server_processes[task_id] = process
        self.server_ports[task_id] = port
        self.all_server_pids.add(process.pid)

        # Wait for server to become healthy
        health_url = task_def.resource_server.health_check_url.replace("{port}", str(port))
        if self._wait_for_server_health(health_url):
            self.logger.info(f"Resource server for task '{task_id}' is ready on port {port}")
            return port

        # Server failed to start properly, clean up
        self._stop_resource_server(task_id)
        return None

    except Exception as e:
        self.logger.error(f"Failed to start resource server for task '{task_id}': {e}")
        # A process may already be running (e.g. the health-check URL was invalid);
        # don't leak it.
        if task_id in self.server_processes:
            self._stop_resource_server(task_id)
        return None
|
|
216
|
+
|
|
217
|
+
def _stop_resource_server(self, task_id: str) -> None:
    """Stop the resource server tracked under *task_id*; a no-op if there is none.

    Where process groups are available, SIGTERM is sent to the server's whole
    process group; otherwise ``terminate()`` is used. After up to 5 seconds the
    stop escalates to SIGKILL / ``kill()``. A process that has already exited is
    not treated as an error. The bookkeeping for *task_id* is always removed, and
    the server's stdout/stderr pipes are closed even if signalling fails.
    """
    import signal

    process = self.server_processes.pop(task_id, None)
    self.server_ports.pop(task_id, None)
    if process is None:
        return

    self.all_server_pids.discard(process.pid)
    use_process_group = hasattr(os, "killpg")

    try:
        # Try to terminate gracefully first
        try:
            if use_process_group:
                os.killpg(os.getpgid(process.pid), signal.SIGTERM)
            else:
                process.terminate()
        except ProcessLookupError:
            pass  # Already exited and reaped; wait() below returns immediately.

        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # Force kill if it doesn't shut down gracefully
            try:
                if use_process_group:
                    os.killpg(os.getpgid(process.pid), signal.SIGKILL)
                else:
                    process.kill()
            except ProcessLookupError:
                pass  # Exited between the timeout and the kill.
            process.wait()

        self.logger.info(f"Stopped resource server for task '{task_id}'")
    except Exception as e:
        self.logger.error(f"Error stopping resource server for task '{task_id}': {e}")
    finally:
        for stream in (process.stdout, process.stderr):
            if stream is not None:
                stream.close()
|
|
248
|
+
|
|
249
|
+
async def prepare_task(self, task_id: str) -> bool:
    """
    Prepare a registered task for execution by setting up its resources.

    If the task declares a ``resource_server``, one is started on a free port.
    A deep copy of the task definition is then pointed at it via
    ``base_resource_config["base_url"]``. An Orchestrator is created, stored in
    ``self.orchestrators[task_id]``, and its base resource is set up.

    Args:
        task_id: Identifier of the task to prepare

    Returns:
        success: True if preparation was successful, False otherwise. On failure
        the orchestrator entry is removed and any server started here is stopped.
    """
    if task_id not in self.tasks:
        self.logger.error(f"Task '{task_id}' is not registered.")
        return False

    task_def = self.tasks[task_id]

    # Start resource server if needed
    allocated_port = None
    if task_def.resource_server:
        allocated_port = self._start_resource_server(task_id, task_def)
        if allocated_port is None:
            self.logger.error(f"Failed to start resource server for task '{task_id}'")
            return False

    try:
        effective_task_def = task_def
        if allocated_port is not None:
            # Copy so the registered definition keeps its original config; then
            # set (or overwrite) base_url to the freshly started server.
            effective_task_def = deepcopy(task_def)
            effective_task_def.base_resource_config["base_url"] = f"http://localhost:{allocated_port}"

        orchestrator = Orchestrator(task_definition=effective_task_def)
        self.orchestrators[task_id] = orchestrator

        # Resource setup is handled by the orchestrator
        await orchestrator.setup_base_resource()
        return True
    except Exception as e:
        self.logger.error(f"Error preparing resources for task '{task_id}': {e}")
        # Drop the half-prepared orchestrator so execute_task() re-prepares instead
        # of running against a resource that never finished setting up.
        self.orchestrators.pop(task_id, None)
        # Clean up server if we started one
        if allocated_port is not None:
            self._stop_resource_server(task_id)
        return False
|
|
303
|
+
|
|
304
|
+
async def execute_task(self, task_id: str) -> Optional[Dict[str, Any]]:
    """
    Run one registered task through its orchestrator.

    If no orchestrator exists for the task yet, ``prepare_task`` is awaited first.

    Args:
        task_id: Identifier of the task to execute

    Returns:
        The value returned by the orchestrator's ``execute_task_poc()`` on success;
        ``{"error": <message>}`` if execution raises; None if the task is unknown
        or could not be prepared.
    """
    if task_id not in self.tasks:
        self.logger.error(f"Task '{task_id}' is not registered.")
        return None

    needs_preparation = task_id not in self.orchestrators
    if needs_preparation:
        self.logger.info(f"Task '{task_id}' orchestrator not initialized. Preparing task...")
        prepared_ok = await self.prepare_task(task_id)
        if not prepared_ok:
            self.logger.error(f"Failed to prepare task '{task_id}'.")
            return None

    runner = self.orchestrators[task_id]
    try:
        self.logger.info(f"Executing task '{task_id}'...")
        outcome = await runner.execute_task_poc()
        self.logger.info(f"Task '{task_id}' execution completed.")
    except Exception as exc:
        self.logger.error(f"Error executing task '{task_id}': {exc}", exc_info=True)
        return {"error": str(exc)}
    return outcome
|
|
335
|
+
|
|
336
|
+
async def execute_tasks(
    self,
    task_ids: Optional[List[str]] = None,
    parallel: bool = False,
    max_concurrency: int = 3,
    num_rollouts_override: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Execute multiple tasks sequentially or in parallel.

    Args:
        task_ids: List of task IDs to execute. If None, execute all registered tasks.
        parallel: If True, execute tasks concurrently (at most ``max_concurrency``
            tasks at a time, duplicate IDs run once); otherwise execute sequentially.
        max_concurrency: Maximum number of tasks to execute in parallel. It is also
            passed down as the per-task rollout concurrency limit.
        num_rollouts_override: Override the number of rollouts for each task
            (applies to non-dataset tasks only).

    Returns:
        results: Dictionary mapping task IDs to execution results (aggregated if
        multiple rollouts), in the order the IDs were given.
    """
    task_ids_to_execute = task_ids if task_ids is not None else list(self.tasks.keys())

    # Validate task IDs
    valid_task_ids = [tid for tid in task_ids_to_execute if tid in self.tasks]
    if len(valid_task_ids) != len(task_ids_to_execute):
        invalid_task_ids = set(task_ids_to_execute) - set(valid_task_ids)
        self.logger.warning(f"Some task IDs are not registered: {invalid_task_ids}")

    if not valid_task_ids:
        self.logger.error("No valid tasks to execute.")
        return {}

    async def run_one(task_id: str) -> Any:
        """Run every rollout of a single task and return its (aggregated) result."""
        task_def = self.tasks[task_id]

        if task_def.dataset_path:
            # Data-driven evaluation: load samples from dataset
            samples = self._load_dataset_samples(task_def.dataset_path)
            if not samples:
                return {"error": "Failed to load dataset or dataset is empty"}

            self.logger.info(
                f"Executing data-driven evaluation for task '{task_id}': {len(samples)} samples, {task_def.num_rollouts_per_sample} rollouts per sample"
            )
            rollout_results = await self._execute_data_driven_rollouts(
                task_id, samples, task_def.num_rollouts_per_sample, max_concurrency
            )
        else:
            # Traditional evaluation: fixed number of rollouts
            num_rollouts = num_rollouts_override if num_rollouts_override is not None else task_def.num_rollouts

            if num_rollouts == 1:
                # Single rollout - existing behavior
                if await self.prepare_task(task_id):
                    return await self.execute_task(task_id)
                return {"error": "Task preparation failed"}

            # Multiple rollouts - batch execution
            self.logger.info(f"Executing {num_rollouts} rollouts for task '{task_id}'")
            rollout_results = await self._execute_batch_rollouts(task_id, num_rollouts, max_concurrency)

        if not rollout_results:
            return {"error": "All rollouts failed"}

        # Aggregate results (for both data-driven and traditional batch execution)
        aggregated_result = self._aggregate_results(rollout_results)

        # Always save detailed results to .jsonl file (including failed rollouts for analysis)
        try:
            detailed_file_path = self._save_detailed_results(task_id, aggregated_result)
            self.logger.info(f"Detailed results saved to: {detailed_file_path}")
        except Exception as e:
            self.logger.error(f"Failed to save detailed results for task '{task_id}': {e}")
        return aggregated_result

    if not parallel:
        results: Dict[str, Any] = {}
        for task_id in valid_task_ids:
            results[task_id] = await run_one(task_id)
        return results

    # Running the same task ID twice concurrently would clash on its per-task
    # server/orchestrator bookkeeping, so each ID runs once.
    unique_task_ids = list(dict.fromkeys(valid_task_ids))
    task_semaphore = asyncio.Semaphore(max(1, max_concurrency))

    async def run_bounded(task_id: str) -> Any:
        async with task_semaphore:
            return await run_one(task_id)

    outcomes = await asyncio.gather(*(run_bounded(tid) for tid in unique_task_ids))
    return dict(zip(unique_task_ids, outcomes))
|
|
418
|
+
|
|
419
|
+
async def _execute_batch_rollouts(
    self, task_id: str, num_rollouts: int, max_concurrency: int
) -> List[Dict[str, Any]]:
    """
    Execute multiple rollouts for a single task in parallel.

    Each rollout gets its own resource server (when the task declares one) and
    its own Orchestrator. The orchestrator's base resource and the server are
    always released when the rollout finishes, fails, or is cancelled.

    Args:
        task_id: The base task ID
        num_rollouts: Number of rollouts to execute
        max_concurrency: Maximum number of concurrent rollouts

    Returns:
        List of results from each rollout (failed rollouts appear as
        ``{"error": ...}`` dicts), in rollout-index order.
    """
    task_def = self.tasks[task_id]

    # Create a semaphore to limit concurrency
    semaphore = asyncio.Semaphore(max_concurrency)

    def normalize_result(raw: Any) -> Any:
        """Flatten the orchestrator's output into a plain dict where possible."""
        if raw is None:
            return {"error": "Execution returned None"}

        # New orchestrator format: {"evaluation_result": ..., "reward_function_inputs": ...}
        reward_function_inputs = None
        if isinstance(raw, dict) and "evaluation_result" in raw:
            reward_function_inputs = raw.get("reward_function_inputs")
            raw = raw["evaluation_result"]

        # Convert Pydantic models (v2 model_dump, or v1 dict) to plain dicts
        if hasattr(raw, "model_dump"):
            raw = raw.model_dump()
        elif hasattr(raw, "dict"):
            raw = raw.dict()

        # Keep reward function inputs for JSONL trajectory storage
        if reward_function_inputs is not None and isinstance(raw, dict):
            raw["reward_function_inputs"] = reward_function_inputs
        return raw

    async def execute_single_rollout(rollout_index: int):
        """Execute a single rollout with its own server instance."""
        rollout_task_id = f"{task_id}_rollout_{rollout_index}"

        async with semaphore:
            allocated_port = None
            orchestrator = None
            try:
                # Start resource server if needed for this rollout
                if task_def.resource_server:
                    allocated_port = self._start_resource_server(rollout_task_id, task_def)
                    if allocated_port is None:
                        self.logger.error(
                            f"Failed to start resource server for rollout {rollout_index} of task '{task_id}'"
                        )
                        return {"error": f"Failed to start resource server for rollout {rollout_index}"}

                # Point a private copy of the task definition at this rollout's server
                effective_task_def = task_def
                if allocated_port is not None:
                    effective_task_def = deepcopy(task_def)
                    effective_task_def.base_resource_config["base_url"] = f"http://localhost:{allocated_port}"

                orchestrator = Orchestrator(task_definition=effective_task_def)
                await orchestrator.setup_base_resource()
                result = normalize_result(await orchestrator.execute_task_poc())

                score = result.get("score", "N/A") if isinstance(result, dict) else "N/A"
                self.logger.info(f"Rollout {rollout_index} of task '{task_id}' completed with score: {score}")
                return result

            except Exception as e:
                error_msg = f"Error in rollout {rollout_index} of task '{task_id}': {e}"
                self.logger.error(error_msg, exc_info=True)

                # Capture server logs if available for debugging (before the server is stopped)
                process = self.server_processes.get(rollout_task_id)
                if process is not None:
                    try:
                        stdout, stderr = process.communicate(timeout=1)
                        if stdout:
                            self.logger.error(f"Server stdout for rollout {rollout_index}: {stdout.decode()}")
                        if stderr:
                            self.logger.error(f"Server stderr for rollout {rollout_index}: {stderr.decode()}")
                    except Exception:
                        pass  # Best-effort log capture only.
                return {"error": str(e)}

            finally:
                # Runs on success, error and cancellation alike.
                base_resource = getattr(orchestrator, "base_resource", None) if orchestrator is not None else None
                if base_resource:
                    try:
                        await base_resource.close()
                    except Exception as close_err:
                        self.logger.warning(
                            f"Failed to close base resource for rollout {rollout_index} of task '{task_id}': {close_err}"
                        )
                if allocated_port is not None:
                    self._stop_resource_server(rollout_task_id)

    # Execute all rollouts concurrently (gather preserves rollout order)
    rollout_results = await asyncio.gather(*(execute_single_rollout(i) for i in range(num_rollouts)))

    # Log failed rollouts but return all results for comprehensive analysis
    failed_count = sum(1 for r in rollout_results if isinstance(r, dict) and "error" in r)
    if failed_count > 0:
        self.logger.warning(f"{failed_count} out of {num_rollouts} rollouts failed for task '{task_id}'")

    return list(rollout_results)
|
|
543
|
+
|
|
544
|
+
def _load_dataset_samples(self, dataset_path: str) -> List[Dict[str, Any]]:
    """
    Load samples from a UTF-8 JSONL dataset file.

    Relative paths are resolved against the current working directory. Blank
    lines are ignored. Lines that are not valid JSON, or that hold a JSON value
    other than an object, are logged and skipped.

    Args:
        dataset_path: Path to the JSONL dataset file

    Returns:
        List of sample dictionaries loaded from the dataset; empty if the file is
        missing or cannot be read.
    """
    try:
        # Support both absolute and relative paths
        if not os.path.isabs(dataset_path):
            dataset_path = os.path.abspath(dataset_path)

        if not os.path.exists(dataset_path):
            self.logger.error(f"Dataset file not found: {dataset_path}")
            return []

        samples: List[Dict[str, Any]] = []
        # JSONL is UTF-8 by convention; don't depend on the platform's default encoding.
        with open(dataset_path, "r", encoding="utf-8") as f:
            for line_num, raw_line in enumerate(f, 1):
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    sample = json.loads(line)
                except json.JSONDecodeError as e:
                    self.logger.error(f"Invalid JSON on line {line_num} in {dataset_path}: {e}")
                    continue
                if not isinstance(sample, dict):
                    self.logger.error(
                        f"Expected a JSON object on line {line_num} in {dataset_path}, "
                        f"got {type(sample).__name__}; skipping"
                    )
                    continue
                samples.append(sample)

        self.logger.info(f"Loaded {len(samples)} samples from {dataset_path}")
        return samples

    except Exception as e:
        self.logger.error(f"Error loading dataset from {dataset_path}: {e}")
        return []
|
|
583
|
+
|
|
584
|
+
async def _execute_data_driven_rollouts(
    self,
    task_id: str,
    samples: List[Dict[str, Any]],
    rollouts_per_sample: int,
    max_concurrency: int,
) -> List[Dict[str, Any]]:
    """
    Execute data-driven rollouts where each sample from the dataset is used for multiple rollouts.

    Every (sample, rollout) pair gets its own ``Orchestrator``. When the task
    definition has a truthy ``resource_server``, the pair also gets its own resource
    server (started via ``_start_resource_server`` under the id
    ``"<task_id>_sample_<i>_rollout_<j>"``), whose port is written into a deep copy
    of the task definition as ``base_resource_config["base_url"]``. At most
    ``max_concurrency`` rollouts run at once. The orchestrator's base resource and
    the per-rollout server are always released, even when the rollout raises.

    Args:
        task_id: The base task ID (key into ``self.tasks``)
        samples: List of samples from the dataset
        rollouts_per_sample: Number of rollouts to execute per sample
        max_concurrency: Maximum number of concurrent rollouts

    Returns:
        One entry per rollout, ordered by sample then rollout index. Successful
        entries are the evaluation result (converted to a dict when it is a pydantic
        model), augmented with ``sample_data``, ``sample_index``, ``rollout_index``
        and, when the orchestrator supplied them, ``reward_function_inputs``. Failed
        entries carry a non-None ``"error"`` plus the same sample metadata.
    """
    task_def = self.tasks[task_id]

    # Create a semaphore to limit concurrency
    semaphore = asyncio.Semaphore(max_concurrency)

    def _error_result(message: str, sample_index: int, rollout_index: int, sample_data: Dict[str, Any]) -> Dict[str, Any]:
        """Build the failure record returned for a rollout that did not complete."""
        return {
            "error": message,
            "sample_data": sample_data,
            "sample_index": sample_index,
            "rollout_index": rollout_index,
        }

    def _unwrap_result(raw_result: Any) -> Any:
        """Turn the orchestrator's return value into a plain result (a dict when possible)."""
        reward_function_inputs = None
        result = raw_result
        if isinstance(result, dict) and "evaluation_result" in result:
            # New format with separate evaluation_result and reward_function_inputs
            reward_function_inputs = result.get("reward_function_inputs")
            result = result["evaluation_result"]

        if result is None:
            return {"error": "Execution returned None"}

        # Convert EvaluateResult (pydantic v2 / v1) to dict if needed; dicts pass through.
        if hasattr(result, "model_dump"):
            result = result.model_dump()
        elif hasattr(result, "dict"):
            result = result.dict()

        # Add reward function inputs to the result for JSONL trajectory storage
        if reward_function_inputs is not None and isinstance(result, dict):
            result["reward_function_inputs"] = reward_function_inputs
        return result

    def _log_server_output(rollout_task_id: str, sample_index: int, rollout_index: int) -> None:
        """Best-effort dump of a rollout's resource-server stdout/stderr to the error log."""
        try:
            process = self.server_processes.get(rollout_task_id)
            if not process:
                return
            # NOTE: synchronous call; may block the event loop for up to 1 s.
            stdout, stderr = process.communicate(timeout=1)
            if stdout:
                self.logger.error(
                    f"Server stdout for sample {sample_index}, rollout {rollout_index}: {stdout.decode()}"
                )
            if stderr:
                self.logger.error(
                    f"Server stderr for sample {sample_index}, rollout {rollout_index}: {stderr.decode()}"
                )
        except Exception as capture_err:
            # Log capture must never mask the original rollout error.
            self.logger.debug(f"Could not capture server output for '{rollout_task_id}': {capture_err}")

    async def execute_single_rollout(sample_index: int, rollout_index: int, sample_data: Dict[str, Any]):
        """Execute a single rollout with sample data."""
        rollout_task_id = f"{task_id}_sample_{sample_index}_rollout_{rollout_index}"

        async with semaphore:
            allocated_port = None
            orchestrator = None
            try:
                # Start resource server if needed for this rollout
                if task_def.resource_server:
                    allocated_port = self._start_resource_server(rollout_task_id, task_def)
                    if allocated_port is None:
                        self.logger.error(
                            f"Failed to start resource server for rollout {rollout_index} of sample {sample_index} for task '{task_id}'"
                        )
                        return _error_result(
                            f"Failed to start resource server for sample {sample_index}, rollout {rollout_index}",
                            sample_index,
                            rollout_index,
                            sample_data,
                        )

                # Point a private copy of the task definition at this rollout's server.
                effective_task_def = task_def
                if allocated_port is not None:
                    effective_task_def = deepcopy(task_def)
                    effective_task_def.base_resource_config["base_url"] = f"http://localhost:{allocated_port}"

                orchestrator = Orchestrator(task_definition=effective_task_def)
                await orchestrator.setup_base_resource()
                raw_result = await orchestrator.execute_task_poc(sample_data=sample_data)

                result = _unwrap_result(raw_result)

                # Add sample metadata to the result
                if isinstance(result, dict):
                    result["sample_data"] = sample_data
                    result["sample_index"] = sample_index
                    result["rollout_index"] = rollout_index

                score = result.get("score", "N/A") if isinstance(result, dict) else "N/A"
                self.logger.info(
                    f"Completed rollout {rollout_index} for sample {sample_index} of task '{task_id}' with score: {score}"
                )
                return result

            except Exception as e:
                self.logger.error(
                    f"Error in rollout {rollout_index} for sample {sample_index} of task '{task_id}': {e}",
                    exc_info=True,
                )
                # Capture server logs while the server is still running (finally stops it).
                if allocated_port is not None:
                    _log_server_output(rollout_task_id, sample_index, rollout_index)
                return _error_result(str(e), sample_index, rollout_index, sample_data)

            finally:
                # Always release per-rollout resources, including on error or cancellation.
                if orchestrator is not None and orchestrator.base_resource:
                    try:
                        await orchestrator.base_resource.close()
                    except Exception as close_err:
                        self.logger.warning(
                            f"Error closing base resource for rollout {rollout_index} of sample {sample_index} of task '{task_id}': {close_err}"
                        )
                if allocated_port is not None:
                    try:
                        self._stop_resource_server(rollout_task_id)
                    except Exception as stop_err:
                        self.logger.error(
                            f"Error stopping resource server for rollout {rollout_index} of sample {sample_index} of task '{task_id}': {stop_err}"
                        )

    # One coroutine per (sample, rollout) pair; gather preserves this order.
    rollout_tasks = [
        execute_single_rollout(sample_index, rollout_index, sample_data)
        for sample_index, sample_data in enumerate(samples)
        for rollout_index in range(rollouts_per_sample)
    ]

    # Execute all rollouts concurrently (bounded by the semaphore)
    all_rollout_results = await asyncio.gather(*rollout_tasks)

    # Log summary statistics. A result is a failure only when "error" is set to a
    # non-None value: model_dump() output may carry "error": None on success.
    total_rollouts = len(all_rollout_results)
    successful_results = [
        r for r in all_rollout_results if not (isinstance(r, dict) and r.get("error") is not None)
    ]
    failed_count = total_rollouts - len(successful_results)

    if failed_count > 0:
        self.logger.warning(
            f"{failed_count} out of {total_rollouts} total rollouts failed for task '{task_id}' "
            f"({len(samples)} samples x {rollouts_per_sample} rollouts per sample)"
        )

    self.logger.info(
        f"Completed data-driven evaluation for task '{task_id}': "
        f"{len(successful_results)} successful rollouts out of {total_rollouts} total"
    )

    # Return all results (successful and failed) for comprehensive logging
    return all_rollout_results
|
|
751
|
+
|
|
752
|
+
def _aggregate_results(self, rollout_results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Aggregate results from multiple rollouts into a single summary.

    Each result is classified as follows:
    - Successful: a dict whose ``"error"`` is missing or None and whose ``"score"`` is
      a number.
    - Failed: a dict with a non-None ``"error"``; it is kept as-is in
      ``failed_results``.
    - Invalid: anything else, including a missing or non-numeric score. It is recorded
      as ``{"error": "Invalid result format: ..."}``.

    Args:
        rollout_results: List of individual rollout results

    Returns:
        ``{"error": "No successful rollouts to aggregate"}`` for an empty input.

        Otherwise, a summary dict containing:
        - counts;
        - score statistics: mean, sample standard deviation, min and max;
        - the raw results;
        - ``aggregated_metrics``, when the first successful result carries a
          ``"metrics"`` mapping.

        Some values are duplicated under alternate keys (``total_rollouts``,
        ``average_score``, ``score``, ``detailed_results``) for compatibility. When no
        rollout has a valid score, every statistic is 0.0 and an ``"error"`` key is
        added.
    """
    if not rollout_results:
        return {"error": "No successful rollouts to aggregate"}

    # Separate successful and failed results
    successful_results: List[Dict[str, Any]] = []
    failed_results: List[Dict[str, Any]] = []
    scores: List[float] = []

    for result in rollout_results:
        if isinstance(result, dict) and result.get("error") is not None:
            failed_results.append(result)
        elif isinstance(result, dict) and isinstance(result.get("score"), (int, float)):
            scores.append(result["score"])
            successful_results.append(result)
        else:
            # Non-dicts, dicts without a score, and non-numeric scores (e.g. None),
            # which would otherwise crash the statistics below.
            failed_results.append({"error": f"Invalid result format: {result}"})

    if scores:
        avg_score = sum(scores) / len(scores)
        min_score = min(scores)
        max_score = max(scores)
        std_dev = statistics.stdev(scores) if len(scores) > 1 else 0.0
        success_rate = len(scores) / len(rollout_results)
    else:
        avg_score = min_score = max_score = std_dev = success_rate = 0.0

    aggregated_result: Dict[str, Any] = {
        "aggregated": True,
        "num_rollouts": len(rollout_results),
        "total_rollouts": len(rollout_results),  # For compatibility with tests
        "successful_rollouts": len(scores),
        "failed_rollouts": len(failed_results),
        "success_rate": success_rate,
        "avg_score": avg_score,
        "average_score": avg_score,  # For compatibility with tests
        "std_dev": std_dev,
        "min_score": min_score,
        "max_score": max_score,
        "score": avg_score,  # For compatibility with existing logging
        "individual_scores": scores,
        "individual_results": rollout_results,  # Include all results (successful and failed)
        "detailed_results": rollout_results,  # For compatibility with tests
        "successful_results": successful_results,
        "failed_results": failed_results,
        "timestamp": datetime.now().isoformat(),
    }

    if not scores:
        # Even with no successful rollouts, return the full shape so failed data can be saved.
        aggregated_result["error"] = "No valid scores found in rollout results"
        return aggregated_result

    # Aggregate per-metric scores, using the first successful result's metric names.
    first_metrics = successful_results[0].get("metrics")
    if isinstance(first_metrics, dict):
        aggregated_metrics: Dict[str, Any] = {}
        for metric_name in first_metrics:
            metric_scores = []
            for result in successful_results:
                metrics = result.get("metrics")
                if not isinstance(metrics, dict) or metric_name not in metrics:
                    continue
                metric_result = metrics[metric_name]
                # Metrics are either {"score": x, ...} mappings or bare numbers.
                if isinstance(metric_result, dict):
                    metric_result = metric_result.get("score")
                if isinstance(metric_result, (int, float)):
                    metric_scores.append(metric_result)

            if metric_scores:
                aggregated_metrics[metric_name] = {
                    "avg_score": sum(metric_scores) / len(metric_scores),
                    "min_score": min(metric_scores),
                    "max_score": max(metric_scores),
                    "individual_scores": metric_scores,
                }

        if aggregated_metrics:
            aggregated_result["aggregated_metrics"] = aggregated_metrics

    return aggregated_result
|
|
860
|
+
|
|
861
|
+
def _save_detailed_results(
    self,
    task_id: str,
    aggregated_result: Dict[str, Any],
    output_file: Optional[str] = None,
) -> str:
    """
    Save detailed results to a .jsonl file for analysis.

    The file layout is:
    - one ``{"type": "summary", ...}`` line built from the aggregate statistics;
    - one ``{"type": "individual_result", ...}`` line per entry in
      ``aggregated_result["individual_results"]``.

    Inside each entry's ``reward_function_inputs["state"]``, the ``resource`` object
    is replaced by its type name (``"<TypeName>"``).

    When ``output_file`` is None, the file is named
    ``trajectory_<task_id>_<YYYYmmdd_HHMMSS>.jsonl``. It goes in the first directory
    that can be created/used among:
    1. ``evaluation_logs`` next to the task definition's ``task_def_path`` (if the
       task definition has one);
    2. ``client/evaluation_logs``;
    3. ``evaluation_logs``;
    4. ``logs``;
    5. ``.``.

    Args:
        task_id: The task identifier
        aggregated_result: The aggregated result dictionary. It must contain
            ``num_rollouts``, ``successful_rollouts``, ``success_rate``,
            ``avg_score``, ``std_dev``, ``min_score`` and ``max_score``.
        output_file: Optional custom output file path

    Returns:
        The path to the saved file, or ``""`` if saving failed. Failures are logged,
        not raised.
    """

    def _pick_log_dir() -> Path:
        """Return the directory for an auto-named trajectory file, creating it if needed."""
        # Prefer evaluation_logs relative to the task definition file
        task_def = self.tasks.get(task_id)
        task_def_path = getattr(task_def, "task_def_path", None) if task_def is not None else None
        if task_def_path:
            try:
                eval_dir = Path(task_def_path).parent / "evaluation_logs"
                eval_dir.mkdir(parents=True, exist_ok=True)
                return eval_dir
            except Exception as e:
                self.logger.warning(f"Failed to create evaluation_logs relative to task definition: {e}")

        # Look for or create common evaluation log directories; "." is the last resort.
        for log_dir in (Path("client/evaluation_logs"), Path("evaluation_logs"), Path("logs"), Path(".")):
            try:
                log_dir.mkdir(parents=True, exist_ok=True)
            except Exception:
                continue  # Not creatable here (e.g. permissions); try the next candidate.
            if log_dir.is_dir():
                return log_dir
        return Path(".")

    def _clean_result(result: Any) -> Dict[str, Any]:
        """Return a JSON-friendly shallow copy of one rollout result."""
        if not isinstance(result, dict):
            # Record malformed entries instead of aborting the whole file.
            return {"error": f"Invalid result format: {result}"}
        cleaned = dict(result)
        inputs = result.get("reward_function_inputs")
        if isinstance(inputs, dict):
            cleaned_inputs = dict(inputs)
            state = inputs.get("state")
            if isinstance(state, dict):
                cleaned_state = dict(state)
                if "resource" in state:
                    # Replace the live resource object with a string representation
                    cleaned_state["resource"] = f"<{type(state['resource']).__name__}>"
                cleaned_inputs["state"] = cleaned_state
            cleaned["reward_function_inputs"] = cleaned_inputs
        return cleaned

    if output_file is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = _pick_log_dir() / f"trajectory_{task_id}_{timestamp}.jsonl"

    output_path = Path(output_file)

    try:
        # Build the summary before opening the file so a malformed aggregate does
        # not leave an empty file behind.
        summary = {
            "type": "summary",
            "task_id": task_id,
            "timestamp": aggregated_result.get("timestamp", datetime.now().isoformat()),
            "num_rollouts": aggregated_result["num_rollouts"],
            "successful_rollouts": aggregated_result["successful_rollouts"],
            "failed_rollouts": aggregated_result.get("failed_rollouts", 0),
            "success_rate": aggregated_result["success_rate"],
            "avg_score": aggregated_result["avg_score"],
            "std_dev": aggregated_result["std_dev"],
            "min_score": aggregated_result["min_score"],
            "max_score": aggregated_result["max_score"],
        }
        individual_results = aggregated_result.get("individual_results", [])

        self.logger.debug(
            f"Saving trajectory data to: {output_path} "
            f"(output directory: {output_path.parent}, {len(individual_results)} individual results)"
        )

        # Ensure the directory exists
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, "w", encoding="utf-8") as f:
            # default=str: a single value json cannot encode natively should not abort
            # the whole trajectory log; it is written as its str() form instead.
            f.write(json.dumps(summary, default=str) + "\n")

            for i, result in enumerate(individual_results):
                detailed_result = {
                    "type": "individual_result",
                    "task_id": task_id,
                    "rollout_index": i,  # overridden by the result's own rollout_index, if any
                    "timestamp": datetime.now().isoformat(),
                    **_clean_result(result),
                }
                f.write(json.dumps(detailed_result, default=str) + "\n")
                self.logger.debug(f"Wrote individual result {i} to {output_path}")

            # Force data to disk before reporting success.
            f.flush()
            os.fsync(f.fileno())

        self.logger.info(f"Successfully saved trajectory data to: {output_path}")
        self.logger.debug(f"Trajectory file size: {output_path.stat().st_size} bytes")
        return str(output_path)

    except Exception as e:
        self.logger.error(f"Failed to save detailed results to {output_path}: {e}", exc_info=True)
        return ""
|
|
1004
|
+
|
|
1005
|
+
def cleanup_all_servers(self) -> None:
    """
    Force-kill every server process whose PID is still in ``self.all_server_pids``.

    For each tracked PID, the owning task id is looked up in
    ``self.server_processes`` (used for logging only; ``"unknown_task"`` when absent).
    The PID's process group is then killed with signal 9 via ``os.killpg`` where
    available, and the PID itself via ``os.kill`` otherwise.

    Every PID is removed from the tracking set afterwards, whether the kill
    succeeded, the process was already gone (``ProcessLookupError``, silently
    ignored), or another error occurred (logged).

    NOTE(review): ``os.killpg(os.getpgid(pid), ...)`` kills the whole process group.
    This is only safe if each server was started in its own group/session. That
    happens outside this block; confirm it in ``_start_resource_server``.
    """
    if not self.all_server_pids:
        self.logger.info("No tracked server PIDs to clean up.")
        return

    self.logger.info(f"Performing robust cleanup of all {len(self.all_server_pids)} tracked server PIDs.")

    # Walk a snapshot: the tracking set is mutated inside the loop.
    for pid in tuple(self.all_server_pids):
        try:
            owner = next(
                (tid for tid, proc in self.server_processes.items() if proc.pid == pid),
                "unknown_task",
            )
            self.logger.warning(
                f"Force-cleaning up potentially orphaned server process for task '{owner}' (PID: {pid})."
            )
            if hasattr(os, "killpg"):
                os.killpg(os.getpgid(pid), 9)  # SIGKILL the whole group
            else:
                os.kill(pid, 9)
        except ProcessLookupError:
            pass  # Process already gone, which is fine
        except Exception as e:
            self.logger.error(f"Error during robust cleanup of PID {pid}: {e}")
        self.all_server_pids.discard(pid)
|
|
1038
|
+
|
|
1039
|
+
async def cleanup(self, task_ids: Optional[List[str]] = None) -> None:
    """
    Release servers and orchestrator resources for the given tasks, then sweep all tracked PIDs.

    For each selected task id, the following happens in order:
    1. ``_stop_resource_server`` is called.
    2. If an orchestrator is registered under that id, its ``base_resource`` is
       closed (when set). Close errors are logged, not raised.
    3. The orchestrator entry is removed.

    Finally ``cleanup_all_servers`` runs. Note that it kills every tracked server
    PID, not only those of the selected tasks.

    Args:
        task_ids: List of task IDs to clean up. If None, every task that has a
            registered orchestrator is cleaned up.
    """
    if task_ids is None:
        # Snapshot the keys: entries are deleted while we iterate.
        targets = list(self.orchestrators.keys())
    else:
        targets = task_ids

    for tid in targets:
        self._stop_resource_server(tid)

        if tid not in self.orchestrators:
            continue

        resource = self.orchestrators[tid].base_resource
        if resource:
            try:
                await resource.close()
                self.logger.info(f"Cleaned up resources for task '{tid}'.")
            except Exception as e:
                self.logger.error(f"Error cleaning up resources for task '{tid}': {e}")
        del self.orchestrators[tid]

    # Sweep any remaining (possibly orphaned) server processes.
    self.cleanup_all_servers()
|