eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,893 @@
|
|
|
1
|
+
# mypy: ignore-errors
|
|
2
|
+
"""
|
|
3
|
+
Orchestrator for the Agent Evaluation Framework V2.
|
|
4
|
+
Manages the lifecycle of a task using ForkableResources.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
import importlib
|
|
9
|
+
import inspect
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
import os
|
|
13
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, cast
|
|
14
|
+
|
|
15
|
+
# Attempt to import OpenAI client
|
|
16
|
+
# The openai package is optional. When it imports, the real client and chat
# types are used and OPENAI_AVAILABLE is True; otherwise the same names are bound
# to minimal stand-ins so this module still imports, and the client initializers
# check OPENAI_AVAILABLE before using them.
try:
    from openai import AsyncOpenAI, OpenAI
    from openai.types.chat import ChatCompletionMessage, ChatCompletionToolParam
    from openai.types.chat.chat_completion_message_tool_call import (
        ChatCompletionMessageToolCall,
    )

    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False
    from typing import Any, Dict, List, Optional, Union

    class OpenAI:
        """Stand-in for ``openai.OpenAI``; accepts and ignores keyword arguments."""

        def __init__(self, **kwargs: Any) -> None:
            return None

    class AsyncOpenAI:
        """Stand-in for ``openai.AsyncOpenAI``; accepts and ignores keyword arguments."""

        def __init__(self, **kwargs: Any) -> None:
            return None

    class ChatCompletionMessage:
        """Stand-in message type: empty content, assistant role."""

        content: str = ""
        role: str = "assistant"

    class ChatCompletionToolParam:
        """Empty stand-in for ``openai.types.chat.ChatCompletionToolParam``."""

    class ChatCompletionMessageToolCall:
        """Empty stand-in for ``ChatCompletionMessageToolCall``."""


# Maximum number of inner-loop steps allowed within a single user turn.
MAX_STEPS_PER_USER_TURN = 10
|
|
51
|
+
|
|
52
|
+
from ..models import Message, TaskDefinitionModel
|
|
53
|
+
from .resource_abc import ForkableResource
|
|
54
|
+
|
|
55
|
+
# Import specific resource types for type checking if needed, or handle dynamically
|
|
56
|
+
from .resources import (
|
|
57
|
+
BFCLSimAPIResource,
|
|
58
|
+
DockerResource,
|
|
59
|
+
FileSystemResource,
|
|
60
|
+
HttpRolloutResource,
|
|
61
|
+
PythonStateResource,
|
|
62
|
+
SQLResource,
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class Orchestrator:
|
|
67
|
+
def __init__(self, task_definition: TaskDefinitionModel):
    """Bind the orchestrator to one task definition.

    Nothing is loaded here: the tools module and reward function are filled in
    by ``_load_task_components()``, the base resource by
    ``setup_base_resource()``, and the model client lazily by the
    ``_initialize_*_client()`` helpers.
    """
    self.task_definition = task_definition
    task_name = task_definition.name

    # Populated later; None means "not loaded yet".
    self.base_resource: Optional[ForkableResource] = None
    self.tools_module: Optional[Any] = None
    self.reward_function: Optional[Callable[..., Any]] = None

    task_logger = logging.getLogger(f"Orchestrator.{task_name}")
    # Force DEBUG on this per-task logger so debug messages are not filtered at this level.
    task_logger.setLevel(logging.DEBUG)
    self.logger = task_logger
    task_logger.info(f"Orchestrator initialized for task: {task_name}")

    # Shared slot for the OpenAI-compatible client (OpenAI or Fireworks).
    self._openai_client: Optional[AsyncOpenAI] = None
|
|
76
|
+
|
|
77
|
+
def _initialize_openai_client(self):
    """Create the shared AsyncOpenAI client on first use.

    Logs a warning and returns when the openai package is unavailable, and
    returns immediately when ``self._openai_client`` is already set. The API key
    is read from ``OPENAI_API_KEY``. If constructing the client raises, the error
    is logged and ``self._openai_client`` is left as ``None``.
    """
    if not OPENAI_AVAILABLE:
        self.logger.warning("OpenAI library not available. Cannot use OpenAI models.")
        return
    if self._openai_client is not None:
        # NOTE: this slot is shared with _initialize_fireworks_client(); an
        # existing client is reused whichever provider created it.
        return

    api_key = os.environ.get("OPENAI_API_KEY")
    try:
        client = AsyncOpenAI(api_key=api_key)
        self._openai_client = client
        self.logger.info("AsyncOpenAI client initialized.")
    except Exception as e:
        self.logger.error(f"Failed to initialize AsyncOpenAI client: {e}")
        self._openai_client = None
|
|
90
|
+
|
|
91
|
+
def _initialize_fireworks_client(self):
    """Create the shared client against Fireworks' OpenAI-compatible endpoint.

    Uses the AsyncOpenAI class with ``FIREWORKS_API_KEY`` and the Fireworks
    inference base URL. Logs a warning and returns when the openai package is
    unavailable; returns immediately when ``self._openai_client`` is already
    set. If constructing the client raises, the error is logged and the client
    is left as ``None``.
    """
    if not OPENAI_AVAILABLE:
        self.logger.warning("OpenAI library not available. Cannot use Fireworks models.")
        return
    if self._openai_client is not None:
        # NOTE: shared with _initialize_openai_client(); an existing client is reused as-is.
        return

    client_kwargs = {
        "api_key": os.environ.get("FIREWORKS_API_KEY"),
        "base_url": "https://api.fireworks.ai/inference/v1",
    }
    try:
        client = AsyncOpenAI(**client_kwargs)
        self._openai_client = client
        self.logger.info("Fireworks client initialized.")
    except Exception as e:
        self.logger.error(f"Failed to initialize Fireworks client: {e}")
        self._openai_client = None
|
|
106
|
+
|
|
107
|
+
def _validate_conversation_messages(self, conversation_messages: List[Dict[str, Any]]) -> None:
    """
    Validate and fix conversation messages (in place) for OpenAI API compliance.

    OpenAI requires every ``tool`` message to answer a preceding ``assistant``
    message that carries ``tool_calls``. One assistant turn with several tool
    calls is followed by several consecutive tool messages, so a tool message
    is valid when the nearest preceding non-tool message is an assistant
    message with non-empty ``tool_calls``. Tool messages without such an anchor
    are logged and removed.

    Args:
        conversation_messages: Chat messages as dicts; fixed in place (the
            method returns None, so the caller's list object is updated).

    Raises:
        ValueError: If the first message has role ``tool``.
    """
    if not conversation_messages:
        return

    first = conversation_messages[0]
    if first.get("role") == "tool":
        # Tool message at start - this is always invalid
        self.logger.error(f"Found orphaned tool message at start of conversation: {first}")
        raise ValueError("Tool message cannot be the first message in conversation")

    kept: List[Dict[str, Any]] = []
    # True while inside the run of tool responses that follows an assistant
    # message with tool_calls; consecutive tool messages share that anchor.
    anchored = False
    for i, msg in enumerate(conversation_messages):
        role = msg.get("role")
        if role != "tool":
            kept.append(msg)
            anchored = role == "assistant" and bool(msg.get("tool_calls"))
            continue
        if anchored:
            kept.append(msg)
            continue
        self.logger.warning(
            f"Found orphaned tool message without preceding assistant tool_calls at index {i}: {msg}"
        )
        self.logger.warning(
            "This suggests a bug in conversation history management - removing invalid tool message"
        )

    if len(kept) != len(conversation_messages):
        # Slice assignment mutates the caller's list instead of rebinding a local.
        conversation_messages[:] = kept
|
|
137
|
+
|
|
138
|
+
def _load_module_and_function(self, full_path: str) -> Optional[Callable[..., Any]]:
    """
    Import ``package.module.attr`` and return the named callable.

    Args:
        full_path: Dotted path whose last segment is the attribute name,
            e.g. ``"package.module.func"``.

    Returns:
        The attribute when ``callable()`` accepts it; the attribute's own
        ``__call__`` when it only exposes one via ``hasattr`` (``callable()``
        consults the type, so e.g. a ``__call__`` set on an instance is missed);
        otherwise ``None``. Failures are logged rather than raised.
    """
    try:
        module_path, function_name = full_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
    except (ImportError, AttributeError, ValueError) as e:
        # ValueError covers paths without a dot and empty/relative module names.
        self.logger.error(f"Failed to load function from '{full_path}': {e}")
        return None

    # Check presence first: an unconditional getattr() here would raise and
    # make the "not found" diagnostic unreachable.
    if not hasattr(module, function_name):
        self.logger.error(f"Attribute '{function_name}' not found in module '{module_path}'.")
        return None

    # Decorated functions (e.g. wrapped with @reward_function) are returned as-is.
    attr = getattr(module, function_name)
    if callable(attr):
        self.logger.info(f"Successfully loaded function '{function_name}' from module '{module_path}'.")
        return attr
    if hasattr(attr, "__call__"):
        self.logger.info(
            f"Successfully loaded callable object '{function_name}' from module '{module_path}'."
        )
        return attr.__call__

    self.logger.error(f"Loaded attribute '{function_name}' from '{module_path}' is not callable.")
    return None
|
|
166
|
+
|
|
167
|
+
async def _load_task_components(self) -> bool:
    """
    Load the optional tools module and the mandatory reward function.

    Sets ``self.tools_module`` (when ``tools_module_path`` is configured) and
    ``self.reward_function``. The reward function is resolved by trying, in
    order:

    1. the configured path as given;
    2. for a bare name (no dot), the attribute of that name on the
       ``eval_protocol.rewards`` package;
    3. for paths containing ``eval_protocol.rewards``, the last path segment
       looked up directly on the ``eval_protocol.rewards`` package.

    Returns:
        True when the reward function was loaded, False otherwise. Failures
        are logged, not raised.
    """
    task_def = self.task_definition

    # --- Tools module (optional) ---
    tools_path = task_def.tools_module_path
    if tools_path:
        try:
            self.tools_module = importlib.import_module(tools_path)
        except ImportError as e:
            self.logger.error(f"Failed to import tools module '{tools_path}': {e}")
            return False
        except Exception as e:
            # Errors raised while the module body executes (SyntaxError,
            # NameError, RuntimeError, ...) are not ImportErrors; treat them as a
            # load failure, like the reward-function branch below does.
            self.logger.error(f"Failed to import tools module '{tools_path}': {e}", exc_info=True)
            return False
        self.logger.info(f"Successfully loaded tools module: {tools_path}")
    else:
        self.logger.info("No 'tools_module_path' specified. Tools may only come from resource.get_tools_spec().")

    # --- Reward function (mandatory) ---
    reward_path = task_def.reward_function_path
    if not reward_path:
        self.logger.error("Reward function path is mandatory but missing.")
        return False

    try:
        # 1. The path exactly as configured.
        self.reward_function = self._load_module_and_function(reward_path)

        # 2. A bare name is looked up on the eval_protocol.rewards package.
        if not self.reward_function and "." not in reward_path:
            fallback_path = f"eval_protocol.rewards.{reward_path}"
            self.logger.info(f"Attempting fallback import from: {fallback_path}")
            self.reward_function = self._load_module_and_function(fallback_path)

        # 3. For eval_protocol.rewards paths, retry with only the final segment,
        #    taken from the package's own namespace.
        if not self.reward_function and "eval_protocol.rewards" in reward_path:
            func_name = reward_path.split(".")[-1]
            self.logger.debug(f"Attempting to get function by name: {func_name}")
            try:
                import eval_protocol.rewards

                self.logger.debug(f"Available in rewards module: {dir(eval_protocol.rewards)}")
                if hasattr(eval_protocol.rewards, func_name):
                    self.reward_function = getattr(eval_protocol.rewards, func_name)
                    self.logger.info(f"Found reward function {func_name} in eval_protocol.rewards")
                    self.logger.debug(f"Loaded function type: {type(self.reward_function)}")
                    self.logger.debug(f"Is callable: {callable(self.reward_function)}")
                else:
                    self.logger.error(f"Function {func_name} not found in eval_protocol.rewards")
            except (ImportError, AttributeError) as e:
                self.logger.error(f"Error importing from rewards module: {e}")

        if self.reward_function:
            self.logger.info(f"Successfully loaded reward function: {reward_path}")
            return True
        self.logger.error(f"Failed to load reward function from '{reward_path}'")
        return False
    except Exception as e:
        self.logger.error(f"Error loading reward function: {e}", exc_info=True)
        return False
|
|
231
|
+
|
|
232
|
+
def _get_resource_class(self, resource_type_name: str) -> Type[ForkableResource]:
    """
    Translate a task's ``resource_type`` string into a resource class.

    Accepted names are the class names of the resource types imported at the
    top of this module, plus ``"http_rollout"`` as a lowercase alias for
    HttpRolloutResource. New resource types must be imported and added here.

    Raises:
        ValueError: If ``resource_type_name`` is not one of the accepted names.
    """
    registry = dict(
        PythonStateResource=PythonStateResource,
        SQLResource=SQLResource,
        FileSystemResource=FileSystemResource,
        DockerResource=DockerResource,
        BFCLSimAPIResource=BFCLSimAPIResource,
        HttpRolloutResource=HttpRolloutResource,
        http_rollout=HttpRolloutResource,  # lowercase alias
    )

    resource_class = registry.get(resource_type_name)
    if resource_class is not None:
        # Every mapped value is a ForkableResource subclass, so no runtime check is needed.
        return cast(Type[ForkableResource], resource_class)

    raise ValueError(
        f"Resource class '{resource_type_name}' not found or not mapped in Orchestrator._get_resource_class."
    )
|
|
258
|
+
|
|
259
|
+
async def setup_base_resource(self) -> None:
    """
    Instantiate and set up the task's base resource.

    Looks up the class for ``task_definition.resource_type``, instantiates it
    with no arguments, and awaits ``setup(task_definition.base_resource_config)``.
    On success ``self.base_resource`` holds the ready resource. On any failure
    it is reset to ``None`` and the error is logged; nothing is raised.
    """
    resource_type = self.task_definition.resource_type
    base_config = self.task_definition.base_resource_config

    self.logger.info(f"Attempting to set up base resource of type '{resource_type}'...")
    try:
        # Only the class lookup's ValueError means "unknown resource type". A
        # ValueError from the constructor or setup() (e.g. bad config) must reach
        # the general handler below so it is reported correctly, with a traceback.
        try:
            ResourceClass = self._get_resource_class(resource_type)
        except ValueError as e_val:
            self.logger.error(f"Could not get resource class '{resource_type}'. {e_val}")
            self.base_resource = None
            return

        self.base_resource = ResourceClass()
        await self.base_resource.setup(base_config)
        self.logger.info(f"Base resource '{resource_type}' setup complete.")
    except Exception as e_setup:
        self.logger.error(
            f"Failed to setup base resource '{resource_type}'. {e_setup}",
            exc_info=True,
        )
        self.base_resource = None
|
|
278
|
+
|
|
279
|
+
async def _get_available_tools(self, episode_resource: "ForkableResource") -> Dict[str, Callable[..., Any]]:
    """
    Build the ``name -> async adapter`` map of tools for one episode.

    Tools are collected in this order; on a name clash a later source
    overwrites an earlier one:

    1. ``episode_resource.get_tools_spec()``: each spec with a ``name`` becomes
       an adapter forwarding to
       ``episode_resource.step(action_name=..., action_params=...)``.
    2. The first public, non-class member of ``self.tools_module`` exposing a
       callable ``get_tools`` (a ToolRegistry instance): each of its tools is
       called with ``resource=episode_resource`` plus the tool params.
    3. Only when step 2 produced no tools: public callables in
       ``self.tools_module`` whose signature has a ``resource`` or
       ``db_resource`` parameter; that parameter receives the resource.

    Every adapter is ``async adapter(params: dict)`` and awaits the underlying
    tool when it is a coroutine function.
    """
    available_tools: Dict[str, Callable[..., Any]] = {}

    # 1. Tools advertised by the resource itself.
    if episode_resource:
        resource_tool_specs = await episode_resource.get_tools_spec()
        self.logger.debug(f"Raw tool specs from resource.get_tools_spec(): {resource_tool_specs}")
        for tool_spec in resource_tool_specs:
            tool_name = tool_spec.get("name")
            if not tool_name:
                self.logger.warning(f"Skipping resource tool spec due to missing 'name': {tool_spec}")
                continue

            # Default arguments freeze the current loop values (no late binding).
            async def resource_tool_adapter(
                params: Dict[str, Any],
                bound_tool_name=tool_name,
                bound_resource=episode_resource,
            ):
                return await bound_resource.step(action_name=bound_tool_name, action_params=params)

            available_tools[tool_name] = resource_tool_adapter
            self.logger.debug(f"Added tool '{tool_name}' from resource spec.")

    if self.tools_module:
        self.logger.debug(f"Inspecting tools_module: {self.tools_module} (type: {type(self.tools_module)})")

        # 2. ToolRegistry instances. Classes are skipped: an imported class such
        # as ``ToolRegistry`` also has a callable ``get_tools`` attribute, and
        # getmembers() sorts capitalized names first, so it would be picked and
        # then called unbound.
        registry_instances = []
        for name, member in inspect.getmembers(self.tools_module):
            if name.startswith("_") or inspect.isclass(member):
                continue
            if hasattr(member, "get_tools") and callable(member.get_tools):
                registry_instances.append((name, member))
                self.logger.debug(f"Found ToolRegistry instance: {name}")

        registry_tools = {}
        if registry_instances:
            # Use the first registry instance found
            registry_name, registry = registry_instances[0]
            self.logger.info(f"Using ToolRegistry '{registry_name}' from module")
            registry_tools = registry.get_tools()

            def create_tool_adapter(tool_func):
                async def adapter(params: Dict[str, Any], bound_resource=episode_resource):
                    # Registry tools may be sync or async.
                    if asyncio.iscoroutinefunction(tool_func):
                        return await tool_func(resource=bound_resource, **params)
                    return tool_func(resource=bound_resource, **params)

                return adapter

            for tool_name, tool_func in registry_tools.items():
                available_tools[tool_name] = create_tool_adapter(tool_func)
                self.logger.debug(f"Added tool '{tool_name}' from registry {registry_name}")

            if registry_tools:
                self.logger.info(f"Found {len(registry_tools)} tools from ToolRegistry")
                self.logger.debug(f"Tool names: {list(available_tools.keys())}")

        # 3. Module inspection fallback. Decided on registry tools alone, so
        # tools supplied by the resource do not suppress module-level tools.
        if not registry_tools:
            self.logger.debug("No ToolRegistry found or no tools in registry. Falling back to module inspection.")

            if inspect.ismodule(self.tools_module):
                self.logger.debug("tools_module is a module. Using inspect.getmembers.")
                members_to_inspect = inspect.getmembers(self.tools_module)
            elif hasattr(self.tools_module, "__dict__"):
                self.logger.debug("tools_module is an object with __dict__. Iterating __dict__.items().")
                members_to_inspect = list(self.tools_module.__dict__.items())
            else:
                self.logger.debug("Falling back to inspect.getmembers.")
                members_to_inspect = inspect.getmembers(self.tools_module)

            for name, member in members_to_inspect:
                self.logger.debug(
                    f"Found member in tools_module: '{name}', type: {type(member)}, callable: {callable(member)}"
                )
                if name.startswith("_") or not callable(member):
                    self.logger.debug(f"Skipping member '{name}' (startswith_ or not callable).")
                    continue

                is_async = asyncio.iscoroutinefunction(member)
                self.logger.debug(f"Member '{name}' is {'async' if is_async else 'sync'} function.")

                try:
                    sig = inspect.signature(member)
                except (TypeError, ValueError) as e_sig:
                    # inspect.signature raises ValueError (no signature) or TypeError (unsupported object).
                    self.logger.debug(f"Skipping module tool '{name}': could not get signature. Error: {e_sig}")
                    continue

                resource_param_name = next(
                    (pname for pname in ("resource", "db_resource") if pname in sig.parameters),
                    None,
                )
                if not resource_param_name:
                    self.logger.debug(
                        f"Skipping module tool '{name}': no 'resource' or 'db_resource' parameter in signature '{sig}'."
                    )
                    continue

                async def module_tool_adapter(
                    params: Dict[str, Any],
                    bound_tool_func=member,
                    bound_resource=episode_resource,
                    res_param_name=resource_param_name,
                    is_async=is_async,
                ):
                    tool_kwargs = {res_param_name: bound_resource, **params}
                    if is_async:
                        return await bound_tool_func(**tool_kwargs)
                    return bound_tool_func(**tool_kwargs)

                available_tools[name] = module_tool_adapter
                self.logger.debug(f"Added tool '{name}' from tools_module directly.")

    self.logger.info(f"Combined available tools: {list(available_tools.keys())}")
    return available_tools
|
|
404
|
+
|
|
405
|
+
async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
    """Run one agent episode for the loaded task and score it with the reward function.

    Flow: load task components and the base resource, resolve the agent model from the
    ``MODEL_AGENT`` environment variable, fork an episode resource, replay the task's user
    turns (allowing up to ``MAX_STEPS_PER_USER_TURN`` model/tool steps per turn), then call
    ``self.reward_function`` with the final conversation and state.

    Args:
        sample_data: Optional keyword arguments forwarded to ``episode_resource.initialize``
            when the forked resource defines that method.

    Returns:
        * ``None`` when setup fails (task components, base resource, reward function,
          model configuration, or no user turns in the task definition).
        * ``{"error": ...}`` when the chat-completion request fails.
        * ``{"evaluation_result": ..., "reward_function_inputs": ...}`` on completion.
        * ``{"evaluation_result": {"error": ...}, "reward_function_inputs": None}`` when an
          unexpected exception escapes the episode.

    The episode resource and ``self.base_resource`` are closed (best effort) on every path,
    and ``self.base_resource`` is reset to ``None``.
    """
    episode_resource: Optional[ForkableResource] = None
    conversation_messages: List[Dict[str, Any]] = []  # plain dicts for API compatibility
    # One list of successful {"name", "args"} calls per user turn that made any; fed to the reward fn.
    all_user_turns_successful_function_calls: List[List[Dict[str, Any]]] = []

    try:
        # --- Task setup ---
        # Everything (including model resolution) runs inside the try so the finally block
        # releases the base resource even when a later precondition fails.
        if not await self._load_task_components():
            self.logger.error("Failed to load task components.")
            return None
        if not self.base_resource:
            await self.setup_base_resource()
        if not self.base_resource:
            self.logger.error("Base resource setup failed or not performed.")
            return None
        if not self.reward_function:
            self.logger.error("Reward function not loaded.")
            return None

        self.logger.info(f"Starting execution for task '{self.task_definition.name}'...")

        # --- Agent model setup ---
        agent_model_name = self._poc_resolve_agent_model()
        if agent_model_name is None:
            return None

        self.logger.info("Forking base resource for episode...")
        episode_resource = await self.base_resource.fork()
        self.logger.info(f"Episode resource forked: {type(episode_resource).__name__}")

        # Initialize the episode resource with sample data if provided
        if sample_data:
            self.logger.info(f"Initializing episode resource with sample data: {sample_data}")
            if hasattr(episode_resource, "initialize"):
                await episode_resource.initialize(**sample_data)
            else:
                self.logger.warning(
                    f"Episode resource {type(episode_resource).__name__} does not have initialize method"
                )

        # Optional description of the starting state, injected into the first user prompt
        # (for HTTP rollout resources). Failure to fetch it is non-fatal.
        initial_state_description = None
        if hasattr(episode_resource, "get_initial_state_description"):
            try:
                initial_state_description = await episode_resource.get_initial_state_description()
                self.logger.info("Retrieved initial state description for first prompt")
            except Exception as e:
                self.logger.warning(f"Failed to get initial state description: {e}")

        user_turns_from_task = self._poc_collect_user_turns()
        if not user_turns_from_task:
            self.logger.error("No user turns found in task definition's messages. Cannot proceed.")
            return None

        # --- Interaction loop: one iteration per user turn, capped by poc_max_turns ---
        num_defined_user_turns = len(user_turns_from_task)
        max_interaction_turns = min(self.task_definition.poc_max_turns, num_defined_user_turns)

        for turn_num in range(1, max_interaction_turns + 1):
            self.logger.info(
                f"--- User Turn {turn_num}/{max_interaction_turns} (Overall Index {turn_num}/{num_defined_user_turns}) ---"
            )

            # Shallow copy so the injection below never mutates the task definition's message.
            user_message = user_turns_from_task[turn_num - 1].copy()
            if turn_num == 1 and initial_state_description:
                original_content = user_message.get("content")
                if isinstance(original_content, list):
                    # Multi-part content: add the state as an extra text part.
                    user_message["content"] = [
                        *original_content,
                        {"type": "text", "text": initial_state_description},
                    ]
                else:
                    prefix = "" if original_content is None else original_content
                    user_message["content"] = f"{prefix}\n\n{initial_state_description}"
                self.logger.info("Injected initial state into first user prompt")
            self._poc_append_user_turn_messages(conversation_messages, user_message)

            # Tools are re-read every turn in case they depend on the resource's current state.
            resource_tool_specs = await episode_resource.get_tools_spec()
            available_tools_adapters = await self._get_available_tools(episode_resource)
            openai_tools = self._poc_build_openai_tools(resource_tool_specs)
            if not available_tools_adapters and not openai_tools:
                # The agent may still answer in text; the inner loop still makes one LLM call.
                self.logger.info(
                    "No tools available from resource or module for this turn. Agent cannot make tool calls."
                )

            # Inner loop: let the model call tools and react to their results within this turn.
            turn_successful_calls: List[Dict[str, Any]] = []
            for step_num in range(1, MAX_STEPS_PER_USER_TURN + 1):
                self.logger.info(f"--- User Turn {turn_num}, Inner Step {step_num}/{MAX_STEPS_PER_USER_TURN} ---")

                try:
                    response_message = await self._poc_call_agent_model(
                        agent_model_name, conversation_messages, openai_tools
                    )
                except Exception as e_openai:
                    self.logger.error(f"Error calling OpenAI API: {e_openai}", exc_info=True)
                    # NOTE(review): this path returns a bare {"error": ...} dict rather than the
                    # {"evaluation_result", "reward_function_inputs"} shape used elsewhere; kept for
                    # caller compatibility. Resources are released by the finally block below.
                    return {"error": f"OpenAI API error: {e_openai}"}

                # Record the assistant reply (text and any tool calls) in the history.
                conversation_messages.append(response_message.model_dump(exclude_none=True))

                tool_calls = response_message.tool_calls
                if not tool_calls:
                    self.logger.info(
                        "Assistant did not request tool calls in this step. Ending inner loop for this user turn."
                    )
                    break

                self.logger.info(f"Assistant requested {len(tool_calls)} tool calls in this step.")
                turn_successful_calls.extend(
                    await self._poc_execute_tool_calls(tool_calls, available_tools_adapters, conversation_messages)
                )

                if not openai_tools and not available_tools_adapters:
                    self.logger.info("No tools were available, but LLM hallucinated tool calls. Breaking inner loop.")
                    break
            else:
                # Loop exhausted without a break: the model kept calling tools.
                self.logger.warning(
                    f"Reached max steps ({MAX_STEPS_PER_USER_TURN}) for user turn {turn_num}. Ending inner loop."
                )

            if turn_successful_calls:
                all_user_turns_successful_function_calls.append(turn_successful_calls)

        # --- Evaluation ---
        self.logger.info("Evaluating task outcome...")
        eval_criteria = self.task_definition.evaluation_criteria
        self.logger.debug(f"Evaluation criteria object: {eval_criteria}")

        ground_truth_for_reward: Optional[Dict[str, Any]] = None
        if eval_criteria:
            self.logger.debug(
                f"Evaluation criteria ground_truth_function_calls: {getattr(eval_criteria, 'ground_truth_function_calls', 'AttributeError or None')}"
            )
            self.logger.debug(
                f"Evaluation criteria ground_truth_comparable_state: {getattr(eval_criteria, 'ground_truth_comparable_state', 'AttributeError or None')}"
            )
            ground_truth_for_reward = {
                "function_calls": getattr(eval_criteria, "ground_truth_function_calls", None),
                "comparable_state": getattr(eval_criteria, "ground_truth_comparable_state", None),
            }

        # TODO: Re-evaluate how task_achieved should be determined beyond the final-state query.
        task_achieved = await self._poc_evaluate_final_state(episode_resource, eval_criteria)

        state_for_reward = {
            "resource": episode_resource,
            "successful_func_calls": all_user_turns_successful_function_calls,
        }
        eval_args = {
            "messages": conversation_messages,
            "state": state_for_reward,
            "task_achieved": task_achieved,
            "task_definition_name": self.task_definition.name,
        }
        # ground_truth is passed as a single dict parameter (not unpacked).
        if ground_truth_for_reward:
            eval_args["ground_truth"] = ground_truth_for_reward

        self.logger.debug(
            f"Calling reward function {type(self.reward_function)} with args {list(eval_args.keys())}, "
            f"task_achieved={task_achieved}, messages={len(conversation_messages)}"
        )
        evaluation_result = self.reward_function(**eval_args)
        self.logger.info(f"Reward function result: {evaluation_result}")
        self.logger.debug(f"Result type: {type(evaluation_result)}")

        # Return the evaluation result together with its inputs for trajectory capture.
        return {
            "evaluation_result": evaluation_result,
            "reward_function_inputs": {
                "messages": conversation_messages,
                "state": state_for_reward,
                "task_achieved": task_achieved,
                "task_definition_name": self.task_definition.name,
                "ground_truth": ground_truth_for_reward,
            },
        }

    except Exception as e_lifecycle:
        self.logger.error(f"Exception during task lifecycle: {e_lifecycle}", exc_info=True)
        return {
            "evaluation_result": {"error": str(e_lifecycle)},
            "reward_function_inputs": None,
        }
    finally:
        # Best-effort cleanup: a failing episode close must not prevent the base close.
        if episode_resource:
            try:
                await episode_resource.close()
                self.logger.info("Episode resource closed.")
            except Exception as e_close:
                self.logger.error(f"Failed to close episode resource: {e_close}", exc_info=True)
        if self.base_resource:
            base_resource, self.base_resource = self.base_resource, None
            try:
                await base_resource.close()
                self.logger.info("Base resource closed.")
            except Exception as e_close:
                self.logger.error(f"Failed to close base resource: {e_close}", exc_info=True)
        self.logger.info(f"Execution for task '{self.task_definition.name}' finished.")


def _poc_resolve_agent_model(self) -> Optional[str]:
    """Read ``MODEL_AGENT``, initialize the matching client, and return the model id to request.

    ``openai/<model>`` initializes the OpenAI client and strips the prefix. ``fireworks/<model>``
    and ``accounts/fireworks...`` initialize the Fireworks client; only the ``fireworks/``
    prefix is stripped. Returns None (after logging) when the variable is unset, the provider
    is unsupported, or the client did not initialize.
    """
    agent_model_name = os.environ.get("MODEL_AGENT")
    if not agent_model_name:
        self.logger.error("MODEL_AGENT environment variable not set.")
        return None

    if agent_model_name.startswith("openai/"):
        self._initialize_openai_client()
        if not self._openai_client:
            self.logger.error("OpenAI client failed to initialize. Cannot proceed.")
            return None
        agent_model_name = agent_model_name.split("openai/", 1)[1]
        self.logger.info(f"Using OpenAI model: {agent_model_name}")
        return agent_model_name

    if agent_model_name.startswith(("fireworks/", "accounts/fireworks")):
        self._initialize_fireworks_client()
        if not self._openai_client:
            self.logger.error("Fireworks client failed to initialize. Cannot proceed.")
            return None
        # "accounts/fireworks/..." ids are sent verbatim.
        if agent_model_name.startswith("fireworks/"):
            agent_model_name = agent_model_name.split("fireworks/", 1)[1]
        self.logger.info(f"Using Fireworks model: {agent_model_name}")
        return agent_model_name

    self.logger.error(f"Unsupported model provider for MODEL_AGENT: {agent_model_name}")
    return None


def _poc_collect_user_turns(self) -> List[Dict[str, Any]]:
    """Return the task definition's user messages as dicts, warning on and skipping anything else."""
    user_turns: List[Dict[str, Any]] = []
    for msg_data in self.task_definition.messages or []:
        if isinstance(msg_data, dict) and msg_data.get("role") == "user":
            user_turns.append(msg_data)
        elif isinstance(msg_data, Message) and msg_data.role == "user":
            user_turns.append(msg_data.model_dump(exclude_none=True))
        else:
            self.logger.warning(
                f"Skipping non-user message or invalid message type in task definition's messages: {msg_data}"
            )
    return user_turns


def _poc_append_user_turn_messages(
    self, conversation_messages: List[Dict[str, Any]], user_message: Dict[str, Any]
) -> None:
    """Append one user turn to ``conversation_messages``.

    String content that parses as a JSON list of ``{"role", "content", ...}`` dicts (BFCL
    stores multi-message turns this way) is expanded into those messages; an empty list adds
    nothing. If any list entry is malformed, only the original message is appended, so the
    turn is never partially duplicated. Non-string content, invalid JSON, or JSON that is not
    a list is appended as the original message.
    """
    content = user_message.get("content")
    parsed: Any = None
    if isinstance(content, str):
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            parsed = None

    if not isinstance(parsed, list):
        conversation_messages.append(user_message)
        return

    invalid = [m for m in parsed if not (isinstance(m, dict) and "role" in m and "content" in m)]
    if not invalid:
        conversation_messages.extend(parsed)
        return

    self.logger.warning(
        f"Invalid sub-message format in user turn ({invalid[0]}); appending the original message instead."
    )
    conversation_messages.append(user_message)


def _poc_tool_param_from_spec(self, spec: Any) -> Optional[Any]:
    """Convert one tool spec to a ``ChatCompletionToolParam``; return None if it lacks name/parameters.

    Accepts the flat ``{"name", "parameters", "description"?}`` form. As a fallback, it also
    unwraps an OpenAI-style ``{"type": "function", "function": {...}}`` spec.
    NOTE(review): confirm which shape ``tools_module.R.get_openai_tools()`` returns.
    """
    if isinstance(spec, dict) and "name" not in spec and isinstance(spec.get("function"), dict):
        spec = spec["function"]
    if "name" not in spec or "parameters" not in spec:
        return None
    return ChatCompletionToolParam(
        type="function",
        function={
            "name": spec["name"],
            "description": spec.get("description", ""),
            "parameters": spec["parameters"],  # assumed to already match the OpenAI JSON schema
        },
    )


def _poc_build_openai_tools(self, resource_tool_specs: List[Dict[str, Any]]) -> List[Any]:
    """Format resource tool specs plus any ``tools_module.R`` registry tools for the chat API.

    Specs without a name/parameters are skipped with a warning. Returns an empty list (with
    a warning) when the OpenAI package is unavailable.
    """
    if not OPENAI_AVAILABLE:
        self.logger.warning("OpenAI not available, cannot format tools for API.")
        return []

    all_specs = list(resource_tool_specs)
    registry = getattr(self.tools_module, "R", None) if self.tools_module else None
    if registry is not None and hasattr(registry, "get_openai_tools"):
        all_specs.extend(registry.get_openai_tools())

    openai_tools: List[Any] = []
    for spec in all_specs:
        tool_param = self._poc_tool_param_from_spec(spec)
        if tool_param is None:
            self.logger.warning(f"Skipping tool spec due to missing name/parameters: {spec}")
        else:
            openai_tools.append(tool_param)
    return openai_tools


async def _poc_call_agent_model(
    self,
    agent_model_name: str,
    conversation_messages: List[Dict[str, Any]],
    openai_tools: List[Any],
) -> Any:
    """Validate the history, request one chat completion, and return ``choices[0].message``.

    Any failure (validation, missing client, API or network error) propagates to the caller.
    """
    self._validate_conversation_messages(conversation_messages)
    # default=str: a non-JSON value in the history must not turn this log line into a failure.
    self.logger.debug(
        f"Calling OpenAI: model={agent_model_name}, messages_FULL_HISTORY={json.dumps(conversation_messages, indent=2, default=str)}, tools={openai_tools}"
    )
    if not self._openai_client:
        raise Exception("OpenAI client not initialized")

    request_kwargs: Dict[str, Any] = {
        "model": agent_model_name,
        "messages": conversation_messages,
        "max_tokens": 4096,
        "temperature": 0.0,
    }
    # Omit tool parameters entirely (rather than sending null) when there are no tools.
    if openai_tools:
        request_kwargs["tools"] = openai_tools
        request_kwargs["tool_choice"] = "auto"

    response = await self._openai_client.chat.completions.create(**request_kwargs)
    response_message = response.choices[0].message
    self.logger.debug(f"OpenAI response message: {response_message}")
    return response_message


async def _poc_execute_tool_calls(
    self,
    tool_calls: Any,
    available_tools_adapters: Dict[str, Any],
    conversation_messages: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Run the model's tool calls, appending one ``role="tool"`` message per call to the history.

    Unknown tools, unparsable arguments, and tool exceptions are reported back to the model as
    a JSON ``{"error": ...}`` payload instead of aborting the episode.

    Returns:
        ``{"name", "args"}`` records, in request order, for calls that ran and produced a
        JSON-serializable result.
    """
    successful_calls: List[Dict[str, Any]] = []
    for tool_call in tool_calls:
        function_name = tool_call.function.name
        function_args_str = tool_call.function.arguments
        self.logger.info(f"Attempting tool call: {function_name}({function_args_str})")

        tool_adapter = available_tools_adapters.get(function_name)
        if not tool_adapter:
            self.logger.error(f"Tool '{function_name}' requested by model but not found in available tools.")
            tool_content = json.dumps({"error": "Tool not found"})
        else:
            try:
                function_args = json.loads(function_args_str)
            except (json.JSONDecodeError, TypeError):
                # TypeError covers a missing (None) arguments payload.
                self.logger.error(f"Failed to parse arguments for tool '{function_name}': {function_args_str}")
                tool_content = json.dumps({"error": "Invalid JSON arguments"})
            else:
                self.logger.debug(f"Tool '{function_name}' parsed arguments: {function_args}")
                try:
                    function_response = await tool_adapter(function_args)
                    self.logger.info(f"Tool '{function_name}' result: {str(function_response)[:200]}...")
                    # Serialized inside the try: a non-JSON result is reported as an execution failure.
                    tool_content = json.dumps(function_response)
                except Exception as e_tool_exec:
                    self.logger.error(
                        f"Error executing tool '{function_name}': {e_tool_exec}",
                        exc_info=True,
                    )
                    tool_content = json.dumps({"error": f"Execution failed: {e_tool_exec}"})
                else:
                    successful_calls.append({"name": function_name, "args": function_args})

        conversation_messages.append(
            {
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": function_name,
                "content": tool_content,
            }
        )
    return successful_calls


async def _poc_evaluate_final_state(self, episode_resource: Any, eval_criteria: Any) -> bool:
    """Return whether the task's ``final_state_query`` (if configured) reports success.

    Runs the query via ``episode_resource.step("fetch_val_sql", {"query": ...})``. The query
    result is passed through ``expected_query_result_transform`` when one is set (it must
    evaluate to a callable); otherwise its truthiness is used. Returns False when there is no
    query, the resource has no ``step`` method, the query does not report ``"success"``, or
    the transform raises.
    """
    if not (eval_criteria and eval_criteria.final_state_query):
        return False
    if not hasattr(episode_resource, "step"):
        return False

    query_res_step = await episode_resource.step("fetch_val_sql", {"query": eval_criteria.final_state_query})
    if query_res_step.get("status") != "success":
        self.logger.error(f"Failed to execute final_state_query: {query_res_step.get('message')}")
        return False

    outcome = query_res_step.get("result")
    task_achieved = False
    if eval_criteria.expected_query_result_transform:
        try:
            # SECURITY: eval() executes arbitrary code taken from the task definition.
            # Only run task definitions from trusted sources.
            transform_func = eval(eval_criteria.expected_query_result_transform)
            task_achieved = bool(transform_func(outcome))
        except Exception as e_tf:
            self.logger.error(f"Error applying transform: {e_tf}")
    else:
        task_achieved = bool(outcome)
    self.logger.info(f"Final state query outcome: {outcome}, Task achieved: {task_achieved}")
    return task_achieved
|