eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,920 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core evaluation execution pipeline for reward-kit.
|
|
3
|
+
This module orchestrates dataset loading, model response generation (optional),
|
|
4
|
+
and evaluation using specified reward functions.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
from typing import Any, Dict, List, Optional, Union
|
|
12
|
+
|
|
13
|
+
import aiohttp
|
|
14
|
+
import hydra
|
|
15
|
+
from datasets import Dataset, DatasetDict
|
|
16
|
+
from hydra.errors import InstantiationException
|
|
17
|
+
from omegaconf import DictConfig, OmegaConf
|
|
18
|
+
|
|
19
|
+
from eval_protocol.auth import get_fireworks_api_key
|
|
20
|
+
from eval_protocol.generation.cache import ResponseCache
|
|
21
|
+
from eval_protocol.generation.clients import FireworksModelClient, ModelClient
|
|
22
|
+
from eval_protocol.mcp.clients import IntermediaryMCPClient
|
|
23
|
+
from eval_protocol.models import Message
|
|
24
|
+
from eval_protocol.utils.module_loader import load_function as load_reward_function
|
|
25
|
+
from eval_protocol.utils.packaging_utils import install_requirements
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class EvaluationPipeline:
|
|
31
|
+
def __init__(self, pipeline_cfg: DictConfig):
    """Wire up the pipeline's collaborators from its root config.

    In order, this:
      * creates a ``FireworksModelClient`` when ``generation.enabled`` is set;
      * creates a ``ResponseCache`` from ``generation.cache``;
      * loads the reward function from ``reward.function_path`` and, if it
        carries a non-empty ``_reward_function_requirements`` list, installs
        those requirements (best-effort: failures are logged, not raised);
      * validates MCP-agent settings when ``agent.type == "mcp_agent"``.
        ``self.mcp_intermediary_client`` is left as ``None`` here; the log
        message states it is initialized in ``run()``.

    Args:
        pipeline_cfg: Root config for this pipeline run.

    Raises:
        ValueError: If generation is enabled but no Fireworks API key is
            available, or if an ``mcp_agent`` is configured without
            ``agent.intermediary_server_url``.
    """
    self.cfg = pipeline_cfg  # Root config for this pipeline run
    logger.info("Initializing EvaluationPipeline...")

    generation_cfg = self.cfg.generation

    # --- Model client (only needed when we generate responses ourselves) ---
    self.model_client: Optional[ModelClient] = None
    if generation_cfg.enabled:
        fireworks_key = get_fireworks_api_key()
        if not fireworks_key:
            logger.error("Fireworks API key not found, but generation is enabled.")
            raise ValueError("API key required for Fireworks model client when generation is enabled.")
        self.model_client = FireworksModelClient(client_config=generation_cfg, api_key=fireworks_key)
        logger.info(f"Initialized FireworksModelClient for model: {generation_cfg.model_name}")

    # --- Response cache ---
    self.cache = ResponseCache(generation_cfg.cache)
    logger.info("ResponseCache initialized.")

    # --- Reward function ---
    reward_path = self.cfg.reward.function_path
    self.reward_function = load_reward_function(reward_path)
    logger.info(f"Loaded reward function from: {reward_path}")

    # The reward-function decorator may attach pip requirements to the function.
    if hasattr(self.reward_function, "_reward_function_requirements"):
        reqs = self.reward_function._reward_function_requirements
        if isinstance(reqs, list) and reqs:
            logger.info(f"Found requirements for reward function {reward_path}: {reqs}")
            try:
                # Assumes install_requirements targets the current environment's pip.
                install_requirements(requirements_list=reqs)
                logger.info(f"Successfully processed requirements for {reward_path}.")
            except Exception as install_err:
                # Deliberately best-effort: log and keep going. A stricter policy
                # could raise RuntimeError here instead.
                logger.error(
                    f"Failed to install requirements for {reward_path}: {install_err}",
                    exc_info=True,
                )
        elif reqs:
            # Truthy but not a list: nothing sensible to install.
            logger.warning(
                f"_reward_function_requirements for {reward_path} is not a non-empty list: {reqs}"
            )

    # --- MCP agent settings (client itself is created later) ---
    self.mcp_intermediary_client: Optional[IntermediaryMCPClient] = None
    agent_cfg = self.cfg.get("agent")
    if agent_cfg and agent_cfg.get("type") == "mcp_agent":
        if not agent_cfg.get("intermediary_server_url"):
            raise ValueError("agent.intermediary_server_url must be configured for mcp_agent type.")
        logger.info("Pipeline configured for mcp_agent. IntermediaryMCPClient will be initialized in run().")
|
|
82
|
+
|
|
83
|
+
async def _discover_tools_for_sample(self, sample_id: str, mcp_backend_ref: str) -> List[Dict[str, Any]]:
    """Open a short-lived MCP session and collect the tools its backend exposes.

    A session with one instance of ``mcp_backend_ref`` is initialized through
    ``self.mcp_intermediary_client``. ``list_backend_tools`` is then queried for
    every instance listed in the init response, and each tool is serialized with
    ``model_dump(exclude_none=True)``. The session is cleaned up afterwards
    whenever a session id was obtained.

    Discovery is best-effort: any error is logged and yields an empty list
    instead of propagating. Errors during cleanup are logged only.

    Args:
        sample_id: Identifier used in log messages.
        mcp_backend_ref: Backend name reference to request from the intermediary.

    Returns:
        The discovered tool definitions as dicts; empty on any discovery error.
    """
    tools: List[Dict[str, Any]] = []
    session_id = None

    try:
        init_response = await self.mcp_intermediary_client.initialize_session(
            [{"backend_name_ref": mcp_backend_ref, "num_instances": 1}]
        )

        if init_response.get("error"):
            detail = init_response.get("error_details", init_response["error"])
            raise RuntimeError(f"MCP session for tool discovery failed: {detail}")

        session_id = init_response.get("rk_session_id")
        backends = init_response.get("initialized_backends", [])
        if not (session_id and backends):
            raise RuntimeError(f"Malformed init response for tool discovery: {init_response}")

        for backend in backends:
            backend_ref = backend.get("backend_name_ref")
            instances = backend.get("instances", [])
            if not (backend_ref and instances):
                continue
            for instance in instances:
                instance_id = instance.get("instance_id")
                if not instance_id:
                    continue
                listing = await self.mcp_intermediary_client.list_backend_tools(
                    rk_session_id=session_id,
                    instance_id=instance_id,
                    backend_name_ref=backend_ref,
                )
                if listing and listing.tools:
                    tools.extend(tool.model_dump(exclude_none=True) for tool in listing.tools)

        logger.info(f"Sample {sample_id}: Discovered {len(tools)} tools.")

    except Exception as discovery_err:
        logger.error(
            f"Sample {sample_id}: Error during tool discovery: {discovery_err}",
            exc_info=True,
        )
        # Partial results are discarded on failure.
        tools = []
    finally:
        if session_id and self.mcp_intermediary_client:
            logger.info(f"Sample {sample_id}: Cleaning up tool discovery session '{session_id}'.")
            try:
                await self.mcp_intermediary_client.cleanup_session(session_id)
            except Exception as cleanup_err:
                logger.error(
                    f"Error cleaning up discovery session '{session_id}': {cleanup_err}",
                    exc_info=True,
                )

    return tools
|
|
141
|
+
|
|
142
|
+
async def _process_n_variants(
    self,
    sample: Dict[str, Any],
    sample_id: str,
    user_query: Optional[str],
    ground_truth_for_eval: Optional[str],
    existing_messages: Optional[List[Dict[str, Any]]],
    http_session: Optional[aiohttp.ClientSession],
    n_variants: int,
    original_index: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Run ``n_variants`` independent single-response passes over one sample.

    Each pass uses a shallow copy of ``sample`` whose ``"id"`` is rewritten to
    ``"<sample_id>_v<k>"`` and is handled by ``_process_single_sample_internal``.
    During each pass ``self.cfg.generation.n`` is temporarily set to 1 and then
    restored.

    Non-``None`` results are tagged with ``request_id`` (the original
    ``sample_id``, for grouping) and ``response_id`` (the variant index). A pass
    that raises is logged and replaced with an error record scored 0.0. A pass
    that returns ``None`` contributes nothing.

    ``user_query``, ``ground_truth_for_eval`` and ``existing_messages`` are part
    of the signature but are not read here.

    NOTE(review): ``self.cfg.generation.n`` is shared state mutated across an
    ``await``; if samples are processed concurrently by the caller, other
    tasks may observe the temporary value — confirm against ``run()``.
    """
    collected: List[Dict[str, Any]] = []

    for k in range(n_variants):
        variant_id = f"{sample_id}_v{k}"
        variant = sample.copy()
        variant["id"] = variant_id

        # Force single-response processing for this pass; restored in ``finally``.
        saved_n = self.cfg.generation.get("n", 1)
        self.cfg.generation.n = 1

        try:
            outcome = await self._process_single_sample_internal(
                sample=variant,
                http_session=http_session,
                original_index=original_index,
            )
            if outcome is not None:
                outcome["request_id"] = sample_id  # groups all variants of one prompt
                outcome["response_id"] = k  # position of this variant within the group
                collected.append(outcome)
        except Exception as err:
            logger.error(
                f"Error processing variant {k} for sample {sample_id}: {err}",
                exc_info=True,
            )
            collected.append(
                {
                    "id": variant_id,
                    "request_id": sample_id,
                    "response_id": k,
                    "error": f"Variant processing failed: {str(err)}",
                    "evaluation_score": 0.0,
                }
            )
        finally:
            self.cfg.generation.n = saved_n

    return collected
|
|
202
|
+
|
|
203
|
+
async def _execute_standard_generation(
    self,
    sample_id: str,
    user_query: str,
    system_prompt_content: Optional[str],
    http_session: aiohttp.ClientSession,
) -> Dict[str, Any]:
    """Generate one assistant reply without tools (the non-agent path).

    Builds a ``[system?, user]`` message list and calls
    ``self.model_client.generate`` with ``tools=None``.

    If the model returned content, the reply is stored in ``self.cache``, but
    only when ``generation.cache.enabled`` is set and the client's temperature
    is 0.0.

    If the model returned no content, the placeholder
    ``"LLM provided no content."`` is used as the output. The placeholder is
    never written to the cache, so a later run regenerates instead of replaying
    an empty result.

    Args:
        sample_id: Identifier used for logging and as part of the cache key.
        user_query: The user message content.
        system_prompt_content: Optional system message content.
        http_session: Session forwarded to the model client.

    Returns:
        ``{"success": True, "final_assistant_output": <str>,
        "conversation_history": <prompt messages + assistant reply>}``.
    """
    messages: List[Dict[str, Any]] = []
    if system_prompt_content:
        messages.append({"role": "system", "content": system_prompt_content})
    messages.append({"role": "user", "content": user_query})

    generation_output = await self.model_client.generate(
        messages=messages,
        session=http_session,
        tools=None,  # No tools for non-agent
    )
    generated_content = generation_output.content

    if generated_content:
        final_output = generated_content
        # Cache only real, deterministic (temperature 0) responses.
        if self.cfg.generation.cache.enabled and self.model_client.temperature == 0.0:
            self.cache.put(
                sample_id=sample_id,
                system_prompt=system_prompt_content,
                user_query=user_query,
                model_name=self.model_client.model_name,
                temperature=self.model_client.temperature,
                response=final_output,
                top_p=self.model_client.top_p,
                top_k=self.model_client.top_k,
                min_p=self.model_client.min_p,
                max_tokens=self.model_client.max_tokens,
                reasoning_effort=self.model_client.reasoning_effort,
            )
    else:
        logger.warning(f"Sample {sample_id}: Standard generation resulted in no content.")
        # Placeholder for downstream consumers only; deliberately not cached.
        final_output = "LLM provided no content."

    return {
        "success": True,
        "final_assistant_output": final_output,
        "conversation_history": messages + [{"role": "assistant", "content": final_output}],
    }
|
|
253
|
+
|
|
254
|
+
async def _execute_mcp_agent_rollout(
    self,
    sample_id: str,
    user_query: str,
    system_prompt_content: Optional[str],
    openai_formatted_tools: Optional[List[Dict[str, Any]]],
    http_session: aiohttp.ClientSession,
    discovered_tools_for_llm_prompt: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Run a multi-turn tool-calling rollout against an MCP backend.

    Opens a session for one instance of ``agent.mcp_backend_ref`` through
    ``self.mcp_intermediary_client``. It then runs up to
    ``agent.max_rollout_turns`` turns (default 5), each time asking
    ``self.model_client`` for a reply with ``openai_formatted_tools`` available:

    * Tool calls are executed on the first instance of that backend. Each
      result, or a JSON error payload if parsing/execution failed, is appended
      as a ``"tool"`` message.
    * A text reply, or a reply with neither text nor tool calls, ends the
      rollout.

    If ``agent.state_capture_tool`` is configured, it is called once after the
    loop with ``agent.state_capture_args`` (converted to plain containers). Its
    result is returned as ``final_filesystem_state``.

    The MCP session is always cleaned up. A cleanup failure is logged and does
    not replace the rollout's result or error.

    ``discovered_tools_for_llm_prompt`` is part of the signature but is not
    read here.

    Returns:
        On success, a dict with ``"success": True`` and the keys
        ``final_assistant_output``, ``conversation_history``,
        ``executed_tool_calls`` and ``final_filesystem_state``.
        On any error, ``{"success": False, "error": "MCP agent processing
        failed: ..."}``.
    """
    mcp_backend_ref = self.cfg.agent.get("mcp_backend_ref")
    rk_session_id = None
    all_executed_tool_calls_for_sample: List[Dict[str, Any]] = []
    final_llm_text_response = None
    final_filesystem_state_from_mcp = None

    # Initial messages for the rollout
    current_messages_for_rollout: List[Dict[str, Any]] = []
    if system_prompt_content:
        current_messages_for_rollout.append({"role": "system", "content": system_prompt_content})
    current_messages_for_rollout.append({"role": "user", "content": user_query})

    try:
        backend_requests = [{"backend_name_ref": mcp_backend_ref, "num_instances": 1}]
        init_response = await self.mcp_intermediary_client.initialize_session(backend_requests)
        if init_response.get("error"):
            raise RuntimeError(
                f"Main MCP session init failed: {init_response.get('error_details', init_response['error'])}"
            )
        rk_session_id = init_response.get("rk_session_id")

        primary_instance_id_for_agent_actions = None
        initialized_backends = init_response.get("initialized_backends", [])
        if not rk_session_id or not initialized_backends:
            raise RuntimeError(f"Malformed main MCP init response: {init_response}")
        # All agent actions target the first instance of the requested backend.
        for be_info in initialized_backends:
            if be_info.get("backend_name_ref") == mcp_backend_ref and be_info.get("instances"):
                primary_instance_id_for_agent_actions = be_info["instances"][0].get("instance_id")
                break
        if not primary_instance_id_for_agent_actions:
            raise RuntimeError(f"Primary instance ID for agent actions not found for {mcp_backend_ref}")

        logger.info(
            f"Sample {sample_id}: Main MCP session for agent execution. rk_session_id='{rk_session_id}', instance='{primary_instance_id_for_agent_actions}'."
        )

        max_rollout_turns = self.cfg.agent.get("max_rollout_turns", 5)
        # Last text reply, or JSON of the latest tool-call request, for top-level logging.
        final_assistant_output_for_log = None

        for turn_num in range(max_rollout_turns):
            logger.info(
                f"Sample {sample_id}: Agent Rollout Turn {turn_num + 1}/{max_rollout_turns}. History size: {len(current_messages_for_rollout)}"
            )

            generation_output_turn = await self.model_client.generate(
                messages=current_messages_for_rollout,
                session=http_session,
                tools=openai_formatted_tools,
            )

            assistant_msg_for_history: Dict[str, Any] = {"role": "assistant"}

            if generation_output_turn.tool_calls:
                assistant_msg_for_history["tool_calls"] = [
                    tc.model_dump() for tc in generation_output_turn.tool_calls
                ]
                current_messages_for_rollout.append(assistant_msg_for_history)
                final_assistant_output_for_log = json.dumps(assistant_msg_for_history["tool_calls"])

                for tool_call in generation_output_turn.tool_calls:
                    tool_name = tool_call.function.name
                    tool_call_id = tool_call.id
                    tool_args_dict = None
                    tool_result_content_str = ""
                    try:
                        tool_args_dict = json.loads(tool_call.function.arguments)
                        if not isinstance(tool_args_dict, dict):
                            raise ValueError("Args not dict")

                        exec_result = await self.mcp_intermediary_client.call_backend_tool(
                            rk_session_id=rk_session_id,
                            instance_id=primary_instance_id_for_agent_actions,
                            backend_name_ref=mcp_backend_ref,
                            tool_name=tool_name,
                            tool_args=tool_args_dict,
                        )
                        tool_result_content_str = json.dumps(exec_result)
                        all_executed_tool_calls_for_sample.append(
                            {
                                "tool_call_id": tool_call_id,
                                "name": tool_name,
                                "arguments": tool_args_dict,
                                "result": exec_result,
                            }
                        )
                    except Exception as e_tool_exec:
                        # A failing tool call is reported back to the LLM as a JSON
                        # error payload instead of aborting the whole rollout.
                        logger.error(
                            f"Sample {sample_id}, Turn {turn_num+1}: Error executing/parsing tool '{tool_name}': {e_tool_exec}",
                            exc_info=True,
                        )
                        error_payload = {"error": str(e_tool_exec)}
                        if isinstance(e_tool_exec, json.JSONDecodeError):
                            error_payload["detail"] = "Failed to parse arguments string from LLM."
                        tool_result_content_str = json.dumps(error_payload)
                        all_executed_tool_calls_for_sample.append(
                            {
                                "tool_call_id": tool_call_id,
                                "name": tool_name,
                                "arguments_str": tool_call.function.arguments,
                                "error": str(e_tool_exec),
                            }
                        )

                    current_messages_for_rollout.append(
                        {
                            "role": "tool",
                            "tool_call_id": tool_call_id,
                            "name": tool_name,
                            "content": tool_result_content_str,
                        }
                    )

                if turn_num == max_rollout_turns - 1:
                    logger.warning(f"Sample {sample_id}: Max rollout turns reached after tool call(s).")

            elif generation_output_turn.content:
                final_llm_text_response = generation_output_turn.content
                assistant_msg_for_history["content"] = final_llm_text_response
                current_messages_for_rollout.append(assistant_msg_for_history)
                final_assistant_output_for_log = final_llm_text_response
                logger.info(f"Sample {sample_id}, Turn {turn_num+1}: LLM responded with text. Ending rollout.")
                break
            else:
                logger.warning(
                    f"Sample {sample_id}, Turn {turn_num+1}: LLM provided no content or tool calls. Ending rollout."
                )
                final_llm_text_response = "LLM provided no actionable response in this turn."
                assistant_msg_for_history["content"] = final_llm_text_response
                current_messages_for_rollout.append(assistant_msg_for_history)
                final_assistant_output_for_log = final_llm_text_response
                break

        # Reached only when the loop never ran (e.g. max_rollout_turns <= 0).
        if (
            not final_llm_text_response
            and not all_executed_tool_calls_for_sample
            and not final_assistant_output_for_log
        ):
            final_assistant_output_for_log = "Agent did not produce text or tool calls within max turns."

        # State capture: snapshot backend state after the rollout, if configured.
        state_capture_tool = self.cfg.agent.get("state_capture_tool")
        if state_capture_tool:
            raw_capture_args = self.cfg.agent.get("state_capture_args")
            if raw_capture_args is None:
                state_capture_args: Dict[str, Any] = {}
            elif isinstance(raw_capture_args, DictConfig):
                # to_container converts nested DictConfig/ListConfig values into
                # plain dicts/lists; dict() would leave them as config nodes.
                state_capture_args = OmegaConf.to_container(raw_capture_args, resolve=True)
            else:
                state_capture_args = dict(raw_capture_args)
            final_filesystem_state_from_mcp = await self.mcp_intermediary_client.call_backend_tool(
                rk_session_id=rk_session_id,
                instance_id=primary_instance_id_for_agent_actions,
                backend_name_ref=mcp_backend_ref,
                tool_name=state_capture_tool,
                tool_args=state_capture_args,
            )

        return {
            "success": True,
            "final_assistant_output": final_assistant_output_for_log,
            "conversation_history": current_messages_for_rollout,
            "executed_tool_calls": all_executed_tool_calls_for_sample,
            "final_filesystem_state": final_filesystem_state_from_mcp,
        }

    except Exception as e_mcp_main:
        logger.error(
            f"Error during MCP agent main processing for sample {sample_id}: {e_mcp_main}",
            exc_info=True,
        )
        return {
            "success": False,
            "error": f"MCP agent processing failed: {str(e_mcp_main)}",
        }
    finally:
        if rk_session_id and self.mcp_intermediary_client:
            try:
                await self.mcp_intermediary_client.cleanup_session(rk_session_id)
            except Exception as e_cleanup:
                # An exception escaping ``finally`` would discard the value being
                # returned above, so cleanup failures are only logged.
                logger.error(
                    f"Sample {sample_id}: Error cleaning up MCP session '{rk_session_id}': {e_cleanup}",
                    exc_info=True,
                )
|
|
435
|
+
|
|
436
|
+
async def _process_single_sample(
    self,
    sample: Dict[str, Any],
    http_session: Optional[aiohttp.ClientSession],  # For model_client, not mcp_client
    original_index: Optional[int] = None,
) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]:
    """Process one dataset sample, fanning out to N variants when configured.

    Reads ``generation.n``; a non-int value or a value below 1 is treated as 1.

    * If ``n > 1``, the sample is passed to ``_process_n_variants`` and its list
      of per-variant results is returned.
    * Otherwise the result of ``_process_single_sample_internal`` is returned.
      That method returns ``None`` for samples it skips.

    The sample id is ``sample["id"]``. When that key is absent or ``None``, the
    id falls back to ``"idx_<original_index>"``, or to a random
    ``"unknown_id_<hex>"`` when no index is given.

    Args:
        sample: The dataset row to process.
        http_session: Session forwarded to the model client.
        original_index: Position of the sample in the dataset, if known.

    Returns:
        A list of variant results when ``n > 1``; otherwise a single result dict,
        or ``None`` if the sample was skipped.
    """
    sample_id = sample.get("id")
    if sample_id is None:
        # Treat an explicit ``"id": None`` like a missing id so variant ids never
        # become ``"None_v<k>"``.
        sample_id = (
            f"idx_{original_index}" if original_index is not None else "unknown_id_" + os.urandom(4).hex()
        )

    # Check for N-variant generation
    n_variants = self.cfg.generation.get("n", 1)
    if not isinstance(n_variants, int) or n_variants < 1:
        n_variants = 1

    if n_variants > 1:
        return await self._process_n_variants(
            sample=sample,
            sample_id=sample_id,
            user_query=sample.get("user_query"),
            ground_truth_for_eval=sample.get("ground_truth_for_eval"),
            existing_messages=sample.get("messages"),
            http_session=http_session,
            n_variants=n_variants,
            original_index=original_index,
        )

    # Single variant processing
    return await self._process_single_sample_internal(
        sample=sample,
        http_session=http_session,
        original_index=original_index,
    )
|
|
475
|
+
|
|
476
|
+
async def _process_single_sample_internal(
    self,
    sample: Dict[str, Any],
    http_session: Optional[aiohttp.ClientSession],  # For model_client, not mcp_client
    original_index: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
    """Process one sample a single time: obtain a response if needed, then evaluate it.

    Two sample shapes are accepted:

    * Generation format (``user_query`` and ``ground_truth_for_eval`` present).
      A response comes from one of three sources:
      - the MCP agent rollout, when an MCP intermediary client is configured and
        ``agent.type == "mcp_agent"``;
      - standard generation;
      - when ``generation.enabled`` is false, the sample's assistant-response
        column or the response cache.
      The resulting conversation is then scored with ``self.reward_function``.
    * Evaluation format (only ``messages`` present). The messages are scored
      as-is; this requires ``generation.enabled`` to be false.

    Args:
        sample: One dataset row.
        http_session: Session forwarded to the generation helpers; required when
            generation is enabled.
        original_index: Position of the sample in the dataset; used to build a
            fallback id (``idx_<n>``) when the sample has no ``id``.

    Returns:
        A result dict with the sample id, prompt/response fields and
        ``evaluation_score`` / ``evaluation_reason`` / ``evaluation_metrics``;
        a dict with an ``error`` key when a response cannot be obtained; or
        ``None`` when the sample is skipped.

    Raises:
        ValueError: If the MCP agent is selected but ``agent.mcp_backend_ref``
            is not configured.
    """
    sample_id_fallback = (
        f"idx_{original_index}" if original_index is not None else "unknown_id_" + os.urandom(4).hex()
    )
    sample_id = sample.get("id", sample_id_fallback)
    user_query = sample.get("user_query")
    ground_truth_for_eval = sample.get("ground_truth_for_eval")
    existing_messages = sample.get("messages")

    def _reward_params() -> Dict[str, Any]:
        # A fresh dict on every call, because the evaluation-mode path adds per-sample keys.
        return dict(self.cfg.reward.get("params", OmegaConf.create({})))

    def _evaluation_fields(eval_result_obj: Any) -> Dict[str, Any]:
        # Score/reason/metrics fields shared by every successfully evaluated result.
        return {
            "evaluation_score": eval_result_obj.score,
            "evaluation_reason": eval_result_obj.reason,
            "evaluation_metrics": (
                {k: v.model_dump() for k, v in eval_result_obj.metrics.items()} if eval_result_obj.metrics else {}
            ),
        }

    # Accept either the generation format (user_query + ground_truth)
    # or the evaluation format (existing messages).
    has_generation_format = user_query is not None and ground_truth_for_eval is not None
    has_evaluation_format = existing_messages is not None

    if not has_generation_format and not has_evaluation_format:
        logger.warning(
            f"Skipping sample {sample_id}: needs either ('user_query' + 'ground_truth_for_eval') for generation or 'messages' for evaluation."
        )
        return None

    original_system_prompt = sample.get("system_prompt") or self.cfg.get("system_prompt")
    discovered_tools_for_llm_prompt: List[Dict[str, Any]] = []
    openai_formatted_tools: Optional[List[Dict[str, Any]]] = None
    # Short-circuits so cfg.agent is only read when an MCP client exists.
    is_mcp_agent = bool(self.mcp_intermediary_client) and self.cfg.agent.type == "mcp_agent"

    # --- Pre-generation: Tool Discovery (if MCP agent) ---
    if is_mcp_agent:
        mcp_backend_ref_for_tools = self.cfg.agent.get("mcp_backend_ref")
        if not mcp_backend_ref_for_tools:
            raise ValueError("agent.mcp_backend_ref must be configured for mcp_agent tool discovery.")
        discovered_tools_for_llm_prompt = await self._discover_tools_for_sample(
            sample_id, mcp_backend_ref_for_tools
        )

    # --- Construct System Prompt and Format Tools for LLM ---
    system_prompt_content = original_system_prompt
    if is_mcp_agent and discovered_tools_for_llm_prompt:
        # Convert the discovered tool descriptions into OpenAI function-tool entries.
        openai_formatted_tools = [
            {
                "type": "function",
                "function": {
                    "name": mcp_tool_dict.get("name", "unknown"),
                    "description": mcp_tool_dict.get("description", ""),
                    "parameters": mcp_tool_dict.get("inputSchema", {}),
                },
            }
            for mcp_tool_dict in discovered_tools_for_llm_prompt
        ]
        if original_system_prompt:
            system_prompt_content = f"{original_system_prompt}\n\nYou have access to tools. Use them if appropriate."
        else:
            system_prompt_content = "You are a helpful assistant with access to tools. Use them if appropriate."

    # --- Evaluation mode: score existing messages without generating ---
    if has_evaluation_format and not has_generation_format:
        if self.cfg.generation.enabled:
            logger.warning(f"Sample {sample_id}: Evaluation mode requires generation.enabled=false")
            return None

        # Raw messages are passed as-is; the @reward_function decorator
        # handles conversion based on the function's type annotations.
        eval_params = _reward_params()
        sample_agent_id = sample.get("agent_id")
        sample_test_id = sample.get("test_id")
        if sample_agent_id is not None:
            eval_params["agent_id"] = sample_agent_id
        if sample_test_id is not None:
            eval_params["test_id"] = sample_test_id

        eval_result_obj = self.reward_function(
            messages=existing_messages,
            ground_truth=ground_truth_for_eval,
            final_filesystem_state=None,
            **eval_params,
        )
        return {
            "id": sample_id,
            "user_query": None,  # Not applicable for evaluation mode
            "system_prompt": system_prompt_content,
            "assistant_response": None,  # Not applicable for evaluation mode
            "ground_truth_for_eval": ground_truth_for_eval,
            **_evaluation_fields(eval_result_obj),
            "full_conversation_history": existing_messages,  # Store original messages
        }

    # --- Generation mode: initial messages for the rollout / single generation ---
    current_messages_for_rollout: List[Dict[str, Any]] = []
    if system_prompt_content:
        current_messages_for_rollout.append({"role": "system", "content": system_prompt_content})
    current_messages_for_rollout.append({"role": "user", "content": user_query})

    # --- Generation disabled: reuse a response from the sample or the cache ---
    if not self.cfg.generation.enabled:
        assistant_response_col_name = self.cfg.dataset.get("column_mapping", {}).get(
            "assistant_response_column", "assistant_response"
        )
        final_assistant_output_for_log = sample.get(assistant_response_col_name)
        if not final_assistant_output_for_log:  # Try cache if no direct column value
            gen_cfg = self.cfg.generation
            final_assistant_output_for_log = self.cache.get(
                sample_id=sample_id,
                system_prompt=original_system_prompt,
                user_query=user_query,
                model_name=gen_cfg.get("model_name", "unknown_model"),
                temperature=gen_cfg.get("temperature", 0.0),
                # NOTE(review): confirm these key parameters match those used when
                # cache entries are written; other parameters may be required.
            )
        if not final_assistant_output_for_log:
            return {
                "id": sample_id,
                "error": "No response (gen disabled, not in sample/cache)",
                "evaluation_score": 0.0,
            }

        # Score the stored response exactly like a freshly generated one; without
        # this the sample would never be evaluated.
        messages_with_response = current_messages_for_rollout + [
            {"role": "assistant", "content": final_assistant_output_for_log}
        ]
        eval_result_obj = self.reward_function(
            messages=[Message(**msg) for msg in messages_with_response],
            ground_truth=ground_truth_for_eval,
            final_filesystem_state=None,
            **_reward_params(),
        )
        return {
            "id": sample_id,
            "user_query": user_query,
            "system_prompt": system_prompt_content,
            "assistant_response": final_assistant_output_for_log,
            "ground_truth_for_eval": ground_truth_for_eval,
            **_evaluation_fields(eval_result_obj),
        }

    # Generation enabled but client/session missing.
    if not self.model_client or not http_session:
        return {
            "id": sample_id,
            "error": "Generation client/session not configured",
        }

    # --- MCP Agent Rollout Loop ---
    if is_mcp_agent:
        mcp_result = await self._execute_mcp_agent_rollout(
            sample_id=sample_id,
            user_query=user_query,
            system_prompt_content=system_prompt_content,
            openai_formatted_tools=openai_formatted_tools,
            http_session=http_session,
            discovered_tools_for_llm_prompt=discovered_tools_for_llm_prompt,
        )
        if not mcp_result["success"]:
            return {
                "id": sample_id,
                "error": mcp_result["error"],
                "discovered_tools": discovered_tools_for_llm_prompt,
            }

        final_assistant_output_for_log = mcp_result["final_assistant_output"]
        current_messages_for_rollout = mcp_result["conversation_history"]
        all_executed_tool_calls_for_sample = mcp_result["executed_tool_calls"]
        final_filesystem_state_from_mcp = mcp_result["final_filesystem_state"]

        # Evaluation based on the final state and conversation.
        eval_result_obj = self.reward_function(
            messages=[Message(**msg) for msg in current_messages_for_rollout],
            ground_truth=ground_truth_for_eval,
            final_filesystem_state=final_filesystem_state_from_mcp,
            **_reward_params(),
        )
        return {
            "id": sample_id,
            "user_query": user_query,
            "system_prompt": system_prompt_content,
            "assistant_response": final_assistant_output_for_log,
            "full_conversation_history": current_messages_for_rollout,
            "ground_truth_for_eval": ground_truth_for_eval,
            "discovered_tools": discovered_tools_for_llm_prompt,
            "executed_tool_calls": all_executed_tool_calls_for_sample,
            "final_mcp_state_captured": final_filesystem_state_from_mcp or "Not captured",
            **_evaluation_fields(eval_result_obj),
        }

    # --- Standard LLM Generation (Non-Agent) ---
    generation_result = await self._execute_standard_generation(
        sample_id=sample_id,
        user_query=user_query,
        system_prompt_content=system_prompt_content,
        http_session=http_session,
    )
    if not generation_result["success"]:
        return {
            "id": sample_id,
            "error": "Standard generation failed",
            "evaluation_score": 0.0,
        }

    final_assistant_output_for_log = generation_result["final_assistant_output"]
    final_messages_for_eval = [Message(**msg) for msg in generation_result["conversation_history"]]

    eval_result_obj = self.reward_function(
        messages=final_messages_for_eval,
        ground_truth=ground_truth_for_eval,
        final_filesystem_state=None,
        **_reward_params(),
    )
    return {
        "id": sample_id,
        "user_query": user_query,
        "system_prompt": system_prompt_content,
        "assistant_response": final_assistant_output_for_log,
        "ground_truth_for_eval": ground_truth_for_eval,
        **_evaluation_fields(eval_result_obj),
    }
|
|
722
|
+
|
|
723
|
+
async def run(self) -> List[Dict[str, Any]]:
    """Run the evaluation pipeline over the configured prompt dataset.

    The dataset is instantiated from ``cfg.dataset`` via Hydra. A ``DatasetDict``
    is narrowed to ``cfg.dataset.split`` (default ``"train"``).

    Up to ``cfg.evaluation_params.limit_samples`` samples are processed with
    ``_process_single_sample``. Concurrency is capped by
    ``cfg.generation.api_params.max_concurrent_requests``, and tasks are awaited
    in batches of ``cfg.logging_params.batch_log_interval`` for progress logging.

    Outputs:
    - All results are written as JSONL to ``cfg.output.results_file``.
    - A messages/ground-truth preview is written to
      ``cfg.output.preview_pairs_file``.
    - Relative output paths are placed under ``cfg.hydra_output_dir`` when it
      is set.

    Returns:
        The per-sample results. List results (N-variant) are flattened in, and
        ``None`` results are dropped. A task that raised is recorded as
        ``{"id": ..., "error": ...}``. An empty list is returned if the dataset
        cannot be loaded.
    """
    logger.info("Starting evaluation pipeline run...")

    try:
        prompt_dataset_config = self.cfg.dataset
        prompt_dataset = hydra.utils.instantiate(prompt_dataset_config)

        if isinstance(prompt_dataset, DatasetDict):
            split_name = prompt_dataset_config.get("split", "train")
            if split_name in prompt_dataset:
                prompt_dataset = prompt_dataset[split_name]
            else:
                logger.error(f"Split '{split_name}' not found. Available: {list(prompt_dataset.keys())}")
                return []
        elif not isinstance(prompt_dataset, Dataset):
            logger.error(f"Loaded dataset is not a Hugging Face Dataset. Type: {type(prompt_dataset)}")
            return []

        dataset_source = getattr(
            prompt_dataset_config,
            "path_or_name",
            getattr(prompt_dataset_config, "base_dataset", "dataset"),
        )
        logger.info(f"Loaded {len(prompt_dataset)} samples from {dataset_source}.")
    except InstantiationException as ie:
        # Walk to the innermost cause to recognise a known fsspec/datasets failure.
        final_cause = ie
        while final_cause.__cause__ is not None:
            final_cause = final_cause.__cause__
        if (
            isinstance(final_cause, ValueError)
            and str(final_cause) == "Invalid pattern: '**' can only be an entire path component"
        ):
            dataset_display_name = prompt_dataset_config.get("base_dataset", "UnknownBaseDatasetConfig")
            helpful_message = (
                f"Failed to load the base dataset specified as '{dataset_display_name}' in your derived dataset configuration. "
                f"This occurred due to an internal error in the 'datasets' library (via fsspec): '{str(final_cause)}'.\n"
                "The error message \"Invalid pattern: '**' can only be an entire path component\" often indicates issues with "
                "how the 'datasets' library is resolving the path to the dataset, potential Hugging Face Hub connectivity/authentication problems, or a corrupted local cache.\n\n"
                "Please try the following troubleshooting steps:\n"
                "1. Verify Hugging Face Hub Token: Ensure your token is correctly configured (e.g., run `huggingface-cli login`).\n"
                "2. Clear Datasets Cache: Try removing the subdirectory related to the actual Hugging Face dataset path/name from `~/.cache/huggingface/datasets/`.\n"
                "3. Update Libraries: `pip install --upgrade datasets huggingface_hub fsspec`.\n"
                "4. Test Direct Loading: (See previous detailed instructions for direct loading test script).\n"
                f"Original InstantiationException details: {ie}"
            )
            logger.error(helpful_message, exc_info=False)
        else:
            logger.error(f"Failed to load prompt dataset: {ie}", exc_info=True)
        return []
    except Exception as e:
        logger.error(
            f"An unexpected error occurred during dataset loading: {e}",
            exc_info=True,
        )
        return []

    def _resolve_output_path(path: str) -> str:
        # Relative output paths go inside the Hydra run directory when one is known.
        if not os.path.isabs(path) and self.cfg.hydra_output_dir:
            return os.path.join(self.cfg.hydra_output_dir, path)
        return path

    def _write_jsonl(path: str, items: List[Dict[str, Any]]) -> None:
        # os.makedirs("") raises, so only create a directory when the path names one
        # (e.g. a bare file name with no hydra_output_dir has no directory part).
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            for item in items:
                f.write(json.dumps(item) + "\n")

    all_results: List[Dict[str, Any]] = []
    limit_samples = self.cfg.evaluation_params.get("limit_samples", None)
    samples_to_process_count = len(prompt_dataset)
    if limit_samples is not None and limit_samples > 0:
        samples_to_process_count = min(len(prompt_dataset), limit_samples)

    logger.info(f"Processing {samples_to_process_count} samples.")

    if self.cfg.get("agent") and self.cfg.agent.get("type") == "mcp_agent":
        self.mcp_intermediary_client = IntermediaryMCPClient(
            intermediary_server_url=self.cfg.agent.intermediary_server_url
        )
        logger.info(f"Created IntermediaryMCPClient instance with URL: {self.cfg.agent.intermediary_server_url}")

    # Opened only after the MCP client is constructed, so a failure there cannot
    # leave this session unclosed; the finally block below closes it.
    http_session_for_model_client: Optional[aiohttp.ClientSession] = None
    if self.cfg.generation.enabled and self.model_client:
        http_session_for_model_client = aiohttp.ClientSession()

    async def execute_tasks():
        max_concurrent = self.cfg.generation.api_params.get("max_concurrent_requests", 5)
        if not isinstance(max_concurrent, int) or max_concurrent <= 0:
            logger.warning(f"Invalid max_concurrent_requests value ({max_concurrent}), defaulting to 5.")
            max_concurrent = 5
        semaphore = asyncio.Semaphore(max_concurrent)

        async def process_with_semaphore_wrapper(sample_idx: int, sample_data: Dict[str, Any]):
            prelim_sample_id = sample_data.get("id", f"idx_{sample_idx}")
            async with semaphore:
                logger.info(f"Concurrency slot acquired for sample '{prelim_sample_id}', attempting to process.")
                return await self._process_single_sample(
                    sample_data,
                    http_session_for_model_client,
                    original_index=sample_idx,
                )

        tasks = [process_with_semaphore_wrapper(i, prompt_dataset[i]) for i in range(samples_to_process_count)]

        batch_size_for_logging = self.cfg.logging_params.get("batch_log_interval", 10)
        if not isinstance(batch_size_for_logging, int) or batch_size_for_logging <= 0:
            logger.warning(f"Invalid batch_log_interval ({batch_size_for_logging}), defaulting to 10.")
            batch_size_for_logging = 10

        for i_outer in range(0, len(tasks), batch_size_for_logging):
            batch_tasks = tasks[i_outer : i_outer + batch_size_for_logging]
            batch_results_values = await asyncio.gather(*batch_tasks, return_exceptions=True)
            for res_idx, res_or_exc in enumerate(batch_results_values):
                sample_idx = i_outer + res_idx
                # gather(return_exceptions=True) can also return BaseException subclasses
                # such as asyncio.CancelledError; record them as failures instead of
                # appending the exception object to the JSON-serialised results.
                if isinstance(res_or_exc, BaseException):
                    # Not inside an except block, so pass the exception itself to
                    # get its traceback logged.
                    logger.error(
                        f"Task for sample index {sample_idx} failed: {res_or_exc}",
                        exc_info=res_or_exc,
                    )
                    all_results.append(
                        {
                            "id": prompt_dataset[sample_idx].get("id", "unknown"),
                            "error": str(res_or_exc),
                        }
                    )
                elif isinstance(res_or_exc, list):
                    # N-variant processing returns a list of results.
                    all_results.extend(res_or_exc)
                elif res_or_exc is not None:
                    all_results.append(res_or_exc)
            logger.info(
                f"Completed batch up to sample {i_outer + len(batch_tasks)}. Total results/errors: {len(all_results)}"
            )

    try:
        if self.mcp_intermediary_client:
            async with self.mcp_intermediary_client:
                await execute_tasks()
        else:
            await execute_tasks()
    finally:
        if http_session_for_model_client:
            await http_session_for_model_client.close()
            logger.debug("Closed aiohttp.ClientSession for model_client in main run() finally block.")

    output_file_path = self.cfg.output.get("results_file", None)
    if output_file_path:
        output_file_path = _resolve_output_path(output_file_path)
        try:
            _write_jsonl(output_file_path, all_results)
            logger.info(f"Detailed results saved to: {os.path.abspath(output_file_path)}")
        except Exception as e:
            logger.error(f"Failed to save results to {output_file_path}: {e}")

    preview_pairs_file_path = self.cfg.output.get("preview_pairs_file", "preview_input_output_pairs.jsonl")
    if preview_pairs_file_path:
        preview_pairs_file_path = _resolve_output_path(preview_pairs_file_path)
        preview_data_to_save = []
        for result_item in all_results:
            # Prefer the full conversation; otherwise rebuild system/user/assistant turns.
            if "full_conversation_history" in result_item:
                messages_to_save = result_item["full_conversation_history"]
            elif (
                "error" in result_item
                or not result_item.get("user_query")
                or not result_item.get("assistant_response")
            ):
                continue  # Skip if essential parts for basic preview are missing and no history
            else:
                messages_to_save = []
                if result_item.get("system_prompt"):
                    messages_to_save.append({"role": "system", "content": result_item["system_prompt"]})
                messages_to_save.append({"role": "user", "content": result_item["user_query"]})
                messages_to_save.append({"role": "assistant", "content": result_item["assistant_response"]})

            pair_item = {"messages": messages_to_save}
            if result_item.get("ground_truth_for_eval"):
                pair_item["ground_truth"] = result_item["ground_truth_for_eval"]
            if result_item.get("id"):
                pair_item["id"] = result_item["id"]
            preview_data_to_save.append(pair_item)

        if preview_data_to_save:
            try:
                _write_jsonl(preview_pairs_file_path, preview_data_to_save)
                logger.info(f"Input/output pairs for preview saved to: {os.path.abspath(preview_pairs_file_path)}")
            except Exception as e:
                logger.error(f"Failed to save preview pairs to {preview_pairs_file_path}: {e}")

    logger.info("Evaluation pipeline run finished.")
    return all_results
|