eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130):
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,893 @@
1
+ # mypy: ignore-errors
2
+ """
3
+ Orchestrator for the Agent Evaluation Framework V2.
4
+ Manages the lifecycle of a task using ForkableResources.
5
+ """
6
+
7
+ import asyncio
8
+ import importlib
9
+ import inspect
10
+ import json
11
+ import logging
12
+ import os
13
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, cast
14
+
15
+ # Attempt to import OpenAI client
16
+ try:
17
+ from openai import AsyncOpenAI, OpenAI
18
+ from openai.types.chat import ChatCompletionMessage, ChatCompletionToolParam
19
+ from openai.types.chat.chat_completion_message_tool_call import (
20
+ ChatCompletionMessageToolCall,
21
+ )
22
+
23
+ OPENAI_AVAILABLE = True
24
+ except ImportError:
25
+ OPENAI_AVAILABLE = False
26
+ # Define dummy types if openai is not installed, to avoid runtime errors on load
27
+ from typing import Any, Dict, List, Optional, Union
28
+
29
+ # Use simple class definitions for runtime and type checking
30
+ class OpenAI:
31
+ def __init__(self, **kwargs: Any) -> None:
32
+ pass
33
+
34
+ class AsyncOpenAI:
35
+ def __init__(self, **kwargs: Any) -> None:
36
+ pass
37
+
38
+ class ChatCompletionMessage:
39
+ content: str = ""
40
+ role: str = "assistant"
41
+
42
+ class ChatCompletionToolParam:
43
+ pass
44
+
45
+ class ChatCompletionMessageToolCall:
46
+ pass
47
+
48
+
49
+ # Max steps for the inner loop within a single user turn
50
+ MAX_STEPS_PER_USER_TURN = 10
51
+
52
+ from ..models import Message, TaskDefinitionModel
53
+ from .resource_abc import ForkableResource
54
+
55
+ # Import specific resource types for type checking if needed, or handle dynamically
56
+ from .resources import (
57
+ BFCLSimAPIResource,
58
+ DockerResource,
59
+ FileSystemResource,
60
+ HttpRolloutResource,
61
+ PythonStateResource,
62
+ SQLResource,
63
+ )
64
+
65
+
66
+ class Orchestrator:
67
+ def __init__(self, task_definition: TaskDefinitionModel):
68
+ self.task_definition = task_definition
69
+ self.base_resource: Optional[ForkableResource] = None
70
+ self.tools_module: Optional[Any] = None
71
+ self.reward_function: Optional[Callable[..., Any]] = None
72
+ self.logger = logging.getLogger(f"Orchestrator.{self.task_definition.name}")
73
+ self.logger.setLevel(logging.DEBUG) # Ensure debug logs are processed
74
+ self.logger.info(f"Orchestrator initialized for task: {self.task_definition.name}")
75
+ self._openai_client: Optional[AsyncOpenAI] = None
76
+
77
def _initialize_openai_client(self):
    """
    Create the shared AsyncOpenAI client on first use.

    Does nothing (beyond a warning) when the openai package could not be
    imported, and nothing at all when a client already exists. The API key
    is read from the OPENAI_API_KEY environment variable. Any failure while
    constructing the client is logged and leaves ``_openai_client`` as None.
    """
    if not OPENAI_AVAILABLE:
        self.logger.warning("OpenAI library not available. Cannot use OpenAI models.")
        return

    if self._openai_client is not None:
        # Already initialized; keep the existing client.
        return

    try:
        client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        self._openai_client = client
        self.logger.info("AsyncOpenAI client initialized.")
    except Exception as e:
        self.logger.error(f"Failed to initialize AsyncOpenAI client: {e}")
        # Make sure callers see "no client" rather than a half-built one.
        self._openai_client = None
90
+
91
def _initialize_fireworks_client(self):
    """
    Create the shared client for Fireworks via the OpenAI-compatible API.

    Does nothing (beyond a warning) when the openai package could not be
    imported, and nothing at all when a client already exists. The API key
    is read from the FIREWORKS_API_KEY environment variable. Any failure
    while constructing the client is logged and leaves ``_openai_client``
    as None.
    """
    if not OPENAI_AVAILABLE:
        self.logger.warning("OpenAI library not available. Cannot use Fireworks models.")
        return

    if self._openai_client is not None:
        # Already initialized; keep the existing client.
        return

    client_options = {
        "api_key": os.environ.get("FIREWORKS_API_KEY"),
        "base_url": "https://api.fireworks.ai/inference/v1",
    }
    try:
        self._openai_client = AsyncOpenAI(**client_options)
        self.logger.info("Fireworks client initialized.")
    except Exception as e:
        self.logger.error(f"Failed to initialize Fireworks client: {e}")
        self._openai_client = None
106
+
107
+ def _validate_conversation_messages(self, conversation_messages: List[Dict[str, Any]]) -> None:
108
+ """
109
+ Validate and fix conversation messages to ensure OpenAI API compliance.
110
+
111
+ OpenAI requires that tool messages must be preceded by an assistant message with tool_calls.
112
+ This method detects and fixes cases where tool messages are orphaned.
113
+ """
114
+ if not conversation_messages:
115
+ return
116
+
117
+ for i, msg in enumerate(conversation_messages):
118
+ if msg.get("role") == "tool":
119
+ # Check if previous message is assistant with tool_calls
120
+ if i == 0:
121
+ # Tool message at start - this is always invalid
122
+ self.logger.error(f"Found orphaned tool message at start of conversation: {msg}")
123
+ raise ValueError("Tool message cannot be the first message in conversation")
124
+
125
+ prev_msg = conversation_messages[i - 1]
126
+ if prev_msg.get("role") != "assistant" or not prev_msg.get("tool_calls"):
127
+ # Found orphaned tool message - log and remove it
128
+ self.logger.warning(
129
+ f"Found orphaned tool message without preceding assistant tool_calls at index {i}: {msg}"
130
+ )
131
+ self.logger.warning(
132
+ "This suggests a bug in conversation history management - removing invalid tool message"
133
+ )
134
+ conversation_messages.pop(i)
135
+ # Recursively validate again since we modified the list
136
+ return self._validate_conversation_messages(conversation_messages)
137
+
138
+ def _load_module_and_function(self, full_path: str) -> Optional[Callable[..., Any]]:
139
+ try:
140
+ module_path, function_name = full_path.rsplit(".", 1)
141
+ module = importlib.import_module(module_path)
142
+ func = getattr(module, function_name)
143
+
144
+ # Check if the attribute exists but might not be directly callable due to decoration
145
+ # For example, bfcl_reward is defined in the module but wrapped with @reward_function
146
+ if hasattr(module, function_name):
147
+ # For attributes that are or contain callable objects
148
+ attr = getattr(module, function_name)
149
+ if callable(attr):
150
+ self.logger.info(f"Successfully loaded function '{function_name}' from module '{module_path}'.")
151
+ return attr
152
+ # For module-level objects that might wrap callable functions
153
+ elif hasattr(attr, "__call__"):
154
+ self.logger.info(
155
+ f"Successfully loaded callable object '{function_name}' from module '{module_path}'."
156
+ )
157
+ return attr.__call__
158
+ else:
159
+ self.logger.error(f"Loaded attribute '{function_name}' from '{module_path}' is not callable.")
160
+ else:
161
+ self.logger.error(f"Attribute '{function_name}' not found in module '{module_path}'.")
162
+ return None
163
+ except (ImportError, AttributeError, ValueError) as e:
164
+ self.logger.error(f"Failed to load function from '{full_path}': {e}")
165
+ return None
166
+
167
+ async def _load_task_components(self) -> bool:
168
+ if self.task_definition.tools_module_path:
169
+ try:
170
+ self.tools_module = importlib.import_module(self.task_definition.tools_module_path)
171
+ self.logger.info(f"Successfully loaded tools module: {self.task_definition.tools_module_path}")
172
+ except ImportError as e:
173
+ self.logger.error(f"Failed to import tools module '{self.task_definition.tools_module_path}': {e}")
174
+ return False
175
+ else:
176
+ self.logger.info("No 'tools_module_path' specified. Tools may only come from resource.get_tools_spec().")
177
+
178
+ # Load reward function
179
+ if self.task_definition.reward_function_path:
180
+ try:
181
+ # First try direct import
182
+ self.reward_function = self._load_module_and_function(self.task_definition.reward_function_path)
183
+
184
+ if not self.reward_function:
185
+ # If that failed, check if we need to import from eval_protocol.rewards
186
+ if "." not in self.task_definition.reward_function_path:
187
+ # Try importing from rewards directly as a fallback
188
+ fallback_path = f"eval_protocol.rewards.{self.task_definition.reward_function_path}"
189
+ self.logger.info(f"Attempting fallback import from: {fallback_path}")
190
+ self.reward_function = self._load_module_and_function(fallback_path)
191
+
192
+ # If still no function, try importing from __init__ exports
193
+ if (
194
+ not self.reward_function
195
+ and "eval_protocol.rewards" in self.task_definition.reward_function_path
196
+ ):
197
+ # Extract the function name from the path
198
+ func_name = self.task_definition.reward_function_path.split(".")[-1]
199
+ self.logger.debug(f"Attempting to get function by name: {func_name}")
200
+ try:
201
+ import eval_protocol.rewards
202
+
203
+ self.logger.debug(f"Available in rewards module: {dir(eval_protocol.rewards)}")
204
+ if hasattr(eval_protocol.rewards, func_name):
205
+ self.reward_function = getattr(eval_protocol.rewards, func_name)
206
+ self.logger.info(f"Found reward function {func_name} in eval_protocol.rewards")
207
+ self.logger.debug(f"Loaded function type: {type(self.reward_function)}")
208
+ self.logger.debug(f"Is callable: {callable(self.reward_function)}")
209
+ else:
210
+ self.logger.error(f"Function {func_name} not found in eval_protocol.rewards")
211
+ except (ImportError, AttributeError) as e:
212
+ self.logger.error(f"Error importing from rewards module: {e}")
213
+
214
+ if self.reward_function:
215
+ self.logger.info(
216
+ f"Successfully loaded reward function: {self.task_definition.reward_function_path}"
217
+ )
218
+ return True
219
+ else:
220
+ self.logger.error(
221
+ f"Failed to load reward function from '{self.task_definition.reward_function_path}'"
222
+ )
223
+ return False
224
+ except Exception as e:
225
+ self.logger.error(f"Error loading reward function: {e}", exc_info=True)
226
+ return False
227
+ else:
228
+ self.logger.error("Reward function path is mandatory but missing.")
229
+ return False
230
+ return True
231
+
232
def _get_resource_class(self, resource_type_name: str) -> Type[ForkableResource]:
    """
    Resolve a resource type name to the resource class it denotes.

    Lookup is a direct name-to-class table over the resource classes
    imported by this module; ``"http_rollout"`` is accepted as a lowercase
    alias for ``HttpRolloutResource``.

    Args:
        resource_type_name: Name of the resource type to look up.

    Returns:
        The matching resource class.

    Raises:
        ValueError: If the name is not present in the table.
    """
    known_resources = {
        "PythonStateResource": PythonStateResource,
        "SQLResource": SQLResource,
        "FileSystemResource": FileSystemResource,
        "DockerResource": DockerResource,
        "BFCLSimAPIResource": BFCLSimAPIResource,
        "HttpRolloutResource": HttpRolloutResource,
        # Lowercase alias kept for convenience.
        "http_rollout": HttpRolloutResource,
    }

    if resource_type_name not in known_resources:
        raise ValueError(
            f"Resource class '{resource_type_name}' not found or not mapped in Orchestrator._get_resource_class."
        )

    # The table only holds resource classes, so no issubclass check is done.
    return cast(Type[ForkableResource], known_resources[resource_type_name])
258
+
259
+ async def setup_base_resource(self) -> None:
260
+ resource_type = self.task_definition.resource_type
261
+ base_config = self.task_definition.base_resource_config
262
+
263
+ self.logger.info(f"Attempting to set up base resource of type '{resource_type}'...")
264
+ try:
265
+ ResourceClass = self._get_resource_class(resource_type)
266
+ self.base_resource = ResourceClass()
267
+ await self.base_resource.setup(base_config)
268
+ self.logger.info(f"Base resource '{resource_type}' setup complete.")
269
+ except ValueError as e_val:
270
+ self.logger.error(f"Could not get resource class '{resource_type}'. {e_val}")
271
+ self.base_resource = None
272
+ except Exception as e_setup:
273
+ self.logger.error(
274
+ f"Failed to setup base resource '{resource_type}'. {e_setup}",
275
+ exc_info=True,
276
+ )
277
+ self.base_resource = None
278
+
279
+ async def _get_available_tools(self, episode_resource: ForkableResource) -> Dict[str, Callable[..., Any]]:
280
+ available_tools: Dict[str, Callable[..., Any]] = {}
281
+ if episode_resource:
282
+ resource_tool_specs = await episode_resource.get_tools_spec()
283
+ self.logger.debug(f"Raw tool specs from resource.get_tools_spec(): {resource_tool_specs}")
284
+ for tool_spec in resource_tool_specs:
285
+ # Corrected logic based on BFCLSimAPIResource._infer_schema_from_method output
286
+ tool_name = tool_spec.get("name")
287
+ if tool_name:
288
+ # Create an async adapter function that calls episode_resource.step
289
+ async def resource_tool_adapter(
290
+ params: Dict[str, Any],
291
+ bound_tool_name=tool_name,
292
+ bound_resource=episode_resource,
293
+ ):
294
+ # Ensure params are passed correctly to step
295
+ return await bound_resource.step(action_name=bound_tool_name, action_params=params)
296
+
297
+ available_tools[tool_name] = resource_tool_adapter
298
+ self.logger.debug(f"Added tool '{tool_name}' from resource spec.")
299
+ else:
300
+ self.logger.warning(f"Skipping resource tool spec due to missing 'name': {tool_spec}")
301
+
302
+ # Check for tools defined using ToolRegistry (more common pattern)
303
+ if self.tools_module:
304
+ self.logger.debug(f"Inspecting tools_module: {self.tools_module} (type: {type(self.tools_module)})")
305
+
306
+ # First, try to find a ToolRegistry instance
307
+ registry_instances = []
308
+ for name, member in inspect.getmembers(self.tools_module):
309
+ # Skip if it starts with underscore or is not a ToolRegistry
310
+ if name.startswith("_"):
311
+ continue
312
+
313
+ if hasattr(member, "get_tools") and callable(member.get_tools):
314
+ registry_instances.append((name, member))
315
+ self.logger.debug(f"Found ToolRegistry instance: {name}")
316
+
317
+ if registry_instances:
318
+ # Use the first registry instance found
319
+ registry_name, registry = registry_instances[0]
320
+ self.logger.info(f"Using ToolRegistry '{registry_name}' from module")
321
+
322
+ # Get all tools from the registry
323
+ registry_tools = registry.get_tools()
324
+ for tool_name, tool_func in registry_tools.items():
325
+ # Create an adapter that will pass the resource to the tool
326
+ def create_tool_adapter(tool_func):
327
+ async def adapter(params: Dict[str, Any], bound_resource=episode_resource):
328
+ # Handle both sync and async functions
329
+ if asyncio.iscoroutinefunction(tool_func):
330
+ result = await tool_func(resource=bound_resource, **params)
331
+ else:
332
+ result = tool_func(resource=bound_resource, **params)
333
+ return result
334
+
335
+ return adapter
336
+
337
+ available_tools[tool_name] = create_tool_adapter(tool_func)
338
+ self.logger.debug(f"Added tool '{tool_name}' from registry {registry_name}")
339
+
340
+ # If we found and used a registry, we're done
341
+ if available_tools:
342
+ self.logger.info(f"Found {len(available_tools)} tools from ToolRegistry")
343
+ self.logger.debug(f"Tool names: {list(available_tools.keys())}")
344
+
345
+ # If no registry tools were found, fall back to module inspection
346
+ if not available_tools:
347
+ self.logger.debug("No ToolRegistry found or no tools in registry. Falling back to module inspection.")
348
+
349
+ members_to_inspect = []
350
+ if inspect.ismodule(self.tools_module):
351
+ self.logger.debug("tools_module is a module. Using inspect.getmembers.")
352
+ members_to_inspect = inspect.getmembers(self.tools_module)
353
+ elif hasattr(self.tools_module, "__dict__"):
354
+ self.logger.debug("tools_module is an object with __dict__. Iterating __dict__.items().")
355
+ members_to_inspect = self.tools_module.__dict__.items()
356
+ else:
357
+ self.logger.debug("Falling back to inspect.getmembers.")
358
+ members_to_inspect = inspect.getmembers(self.tools_module)
359
+
360
+ for name, member in members_to_inspect:
361
+ self.logger.debug(
362
+ f"Found member in tools_module: '{name}', type: {type(member)}, callable: {callable(member)}"
363
+ )
364
+ if name.startswith("_") or not callable(member):
365
+ self.logger.debug(f"Skipping member '{name}' (startswith_ or not callable).")
366
+ continue
367
+
368
+ # Check if it's a sync or async function
369
+ is_async = asyncio.iscoroutinefunction(member)
370
+ self.logger.debug(f"Member '{name}' is {'async' if is_async else 'sync'} function.")
371
+
372
+ try:
373
+ sig = inspect.signature(member)
374
+ resource_param_name = next(
375
+ (pname for pname in ["resource", "db_resource"] if pname in sig.parameters),
376
+ None,
377
+ )
378
+
379
+ if resource_param_name:
380
+
381
+ async def module_tool_adapter(
382
+ params: Dict[str, Any],
383
+ bound_tool_func=member,
384
+ bound_resource=episode_resource,
385
+ res_param_name=resource_param_name,
386
+ is_async=is_async,
387
+ ):
388
+ tool_kwargs = {res_param_name: bound_resource, **params}
389
+ if is_async:
390
+ return await bound_tool_func(**tool_kwargs)
391
+ else:
392
+ return bound_tool_func(**tool_kwargs)
393
+
394
+ available_tools[name] = module_tool_adapter
395
+ self.logger.debug(f"Added tool '{name}' from tools_module directly.")
396
+ else:
397
+ self.logger.debug(
398
+ f"Skipping module tool '{name}': no 'resource' or 'db_resource' parameter in signature '{sig}'."
399
+ )
400
+ except ValueError as e_sig:
401
+ self.logger.debug(f"Skipping module tool '{name}': could not get signature. Error: {e_sig}")
402
+ self.logger.info(f"Combined available tools: {list(available_tools.keys())}")
403
+ return available_tools
404
+
405
+ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
406
+ if not await self._load_task_components():
407
+ self.logger.error("Failed to load task components.")
408
+ return None
409
+ if not self.base_resource:
410
+ await self.setup_base_resource()
411
+ if not self.base_resource:
412
+ self.logger.error("Base resource setup failed or not performed.")
413
+ return None
414
+ if not self.reward_function:
415
+ self.logger.error("Reward function not loaded.")
416
+ return None # Should be caught by _load_task_components
417
+
418
+ self.logger.info(f"Starting execution for task '{self.task_definition.name}'...")
419
+ episode_resource: Optional[ForkableResource] = None
420
+ evaluation_result: Optional[Dict[str, Any]] = None
421
+
422
+ all_user_turns_successful_function_calls: List[List[Dict[str, Any]]] = (
423
+ []
424
+ ) # Track successful calls for reward fn, list of lists (per user turn)
425
+ conversation_messages: List[Dict[str, Any]] = [] # Use dicts for API compatibility
426
+
427
+ # --- Agent Model Setup ---
428
+ agent_model_name = os.environ.get("MODEL_AGENT")
429
+ if not agent_model_name:
430
+ self.logger.error("MODEL_AGENT environment variable not set.")
431
+ return None
432
+ if agent_model_name.startswith("openai/"):
433
+ self._initialize_openai_client()
434
+ if not self._openai_client:
435
+ self.logger.error("OpenAI client failed to initialize. Cannot proceed.")
436
+ return None
437
+ agent_model_name = agent_model_name.split("openai/", 1)[1] # Get actual model name
438
+ self.logger.info(f"Using OpenAI model: {agent_model_name}")
439
+ elif agent_model_name.startswith("fireworks/") or agent_model_name.startswith("accounts/fireworks"):
440
+ self._initialize_fireworks_client()
441
+ if not self._openai_client:
442
+ self.logger.error("Fireworks client failed to initialize. Cannot proceed.")
443
+ return None
444
+ # Remove prefix if it exists
445
+ if agent_model_name.startswith("fireworks/"):
446
+ agent_model_name = agent_model_name.split("fireworks/", 1)[1]
447
+ # If it starts with accounts/fireworks, keep the full model name
448
+ self.logger.info(f"Using Fireworks model: {agent_model_name}")
449
+ else:
450
+ # Placeholder for other model providers if needed in the future
451
+ self.logger.error(f"Unsupported model provider for MODEL_AGENT: {agent_model_name}")
452
+ return None
453
+
454
+ try:
455
+ # --- Task Setup ---
456
+ if not await self._load_task_components():
457
+ self.logger.error("Failed to load task components.")
458
+ return None
459
+ if not self.base_resource:
460
+ await self.setup_base_resource()
461
+ if not self.base_resource:
462
+ self.logger.error("Base resource setup failed or not performed.")
463
+ return None
464
+ if not self.reward_function:
465
+ self.logger.error("Reward function not loaded.")
466
+ return None
467
+
468
+ self.logger.info("Forking base resource for episode...")
469
+ episode_resource = await self.base_resource.fork()
470
+ self.logger.info(f"Episode resource forked: {type(episode_resource).__name__}")
471
+
472
+ # Initialize the episode resource with sample data if provided
473
+ if sample_data:
474
+ self.logger.info(f"Initializing episode resource with sample data: {sample_data}")
475
+ if hasattr(episode_resource, "initialize"):
476
+ await episode_resource.initialize(**sample_data)
477
+ else:
478
+ self.logger.warning(
479
+ f"Episode resource {type(episode_resource).__name__} does not have initialize method"
480
+ )
481
+
482
+ # Get initial state for injection into first prompt (for HTTP rollout)
483
+ initial_state_description = None
484
+ if hasattr(episode_resource, "get_initial_state_description"):
485
+ try:
486
+ initial_state_description = await episode_resource.get_initial_state_description()
487
+ self.logger.info("Retrieved initial state description for first prompt")
488
+ except Exception as e:
489
+ self.logger.warning(f"Failed to get initial state description: {e}")
490
+
491
+ # --- Initial Conversation State ---
492
+ # The conversation_messages list will be built turn by turn.
493
+ # We need a copy of the user turns from the task definition.
494
+ user_turns_from_task: List[Dict[str, Any]] = []
495
+ if self.task_definition.messages:
496
+ for msg_data in self.task_definition.messages:
497
+ if isinstance(msg_data, dict) and msg_data.get("role") == "user":
498
+ # Ensure it's a dict and has a role, content can be complex
499
+ user_turns_from_task.append(msg_data)
500
+ elif isinstance(msg_data, Message) and msg_data.role == "user":
501
+ user_turns_from_task.append(msg_data.model_dump(exclude_none=True))
502
+ else:
503
+ self.logger.warning(
504
+ f"Skipping non-user message or invalid message type in task definition's messages: {msg_data}"
505
+ )
506
+
507
+ if not user_turns_from_task:
508
+ self.logger.error("No user turns found in task definition's messages. Cannot proceed.")
509
+ return None
510
+
511
+ # --- Interaction Loop ---
512
+ # Loop through the user turns defined in the task or up to poc_max_turns
513
+ num_defined_user_turns = len(user_turns_from_task)
514
+ max_interaction_turns = min(self.task_definition.poc_max_turns, num_defined_user_turns)
515
+
516
+ current_user_turn_index = 0
517
+
518
+ for turn_num in range(1, max_interaction_turns + 1): # Outer loop for user turns
519
+ self.logger.info(
520
+ f"--- User Turn {turn_num}/{max_interaction_turns} (Overall Index {current_user_turn_index + 1}/{num_defined_user_turns}) ---"
521
+ )
522
+
523
+ current_user_turn_accumulated_successful_calls: List[Dict[str, Any]] = []
524
+
525
+ # Add the current user turn's message(s) to the conversation history
526
+ if current_user_turn_index < num_defined_user_turns:
527
+ current_user_turn_message = user_turns_from_task[
528
+ current_user_turn_index
529
+ ].copy() # Make a copy to avoid modifying the original
530
+
531
+ # Inject initial state into first user message
532
+ if current_user_turn_index == 0 and initial_state_description:
533
+ original_content = current_user_turn_message.get("content", "")
534
+ enhanced_content = f"{original_content}\n\n{initial_state_description}"
535
+ current_user_turn_message["content"] = enhanced_content
536
+ self.logger.info("Injected initial state into first user prompt")
537
+
538
+ # The user message content might be a string or a list of content blocks (e.g. for multi-modal)
539
+ # For BFCL, it's a string that might represent a JSON list of user messages for that turn.
540
+ # We need to parse it if it's a JSON string representing a list of messages.
541
+ try:
542
+ # Attempt to parse content if it's a string that looks like a JSON list
543
+ if isinstance(current_user_turn_message.get("content"), str):
544
+ parsed_content = json.loads(current_user_turn_message["content"])
545
+ if isinstance(parsed_content, list):
546
+ for sub_msg_dict in parsed_content:
547
+ if (
548
+ isinstance(sub_msg_dict, dict)
549
+ and "role" in sub_msg_dict
550
+ and "content" in sub_msg_dict
551
+ ):
552
+ conversation_messages.append(sub_msg_dict)
553
+ else:
554
+ self.logger.warning(
555
+ f"Skipping sub-message in user turn due to invalid format: {sub_msg_dict}"
556
+ )
557
+ conversation_messages.append(
558
+ current_user_turn_message
559
+ ) # Fallback to original if parsing fails partially
560
+ break # Stop processing sub-messages for this turn
561
+ else: # If loop completed without break
562
+ pass # Successfully processed all sub-messages
563
+ else: # Content is a JSON string but not a list
564
+ conversation_messages.append(current_user_turn_message)
565
+ else: # Content is not a string or already a complex object
566
+ conversation_messages.append(current_user_turn_message)
567
+ except json.JSONDecodeError: # Content is a string but not valid JSON
568
+ conversation_messages.append(current_user_turn_message)
569
+
570
+ current_user_turn_index += 1
571
+ else:
572
+ self.logger.info("No more user turns defined by task. Ending interaction.")
573
+ break # Break outer loop if no more user messages from task def
574
+
575
+ # 1. Get available tools for this user turn (can be dynamic based on resource state)
576
+ # For BFCL, tools are generally static for the episode, but good practice to refresh.
577
+ resource_tool_specs = await episode_resource.get_tools_spec()
578
+ available_tools_adapters = await self._get_available_tools(
579
+ episode_resource
580
+ ) # Get adapters for execution
581
+
582
+ # Format tools for OpenAI API (should be done once per user turn, or if tools change)
583
+ openai_tools: List[ChatCompletionToolParam] = []
584
+ if OPENAI_AVAILABLE:
585
+ # First add tools from the resource
586
+ for spec in resource_tool_specs:
587
+ # Ensure spec has the structure with name and parameters
588
+ if "name" in spec and "parameters" in spec:
589
+ openai_tools.append(
590
+ ChatCompletionToolParam(
591
+ type="function",
592
+ function={
593
+ "name": spec["name"],
594
+ "description": spec.get("description", ""),
595
+ "parameters": spec["parameters"], # Assuming this matches OpenAI schema
596
+ },
597
+ )
598
+ )
599
+ else:
600
+ self.logger.warning(f"Skipping tool spec due to missing name/parameters: {spec}")
601
+
602
+ # Now add tools from the registry
603
+ if (
604
+ self.tools_module
605
+ and hasattr(self.tools_module, "R")
606
+ and hasattr(self.tools_module.R, "get_openai_tools")
607
+ ):
608
+ registry_tools = self.tools_module.R.get_openai_tools()
609
+ for tool_spec in registry_tools:
610
+ openai_tools.append(
611
+ ChatCompletionToolParam(
612
+ type="function",
613
+ function={
614
+ "name": tool_spec["name"],
615
+ "description": tool_spec.get("description", ""),
616
+ "parameters": tool_spec["parameters"],
617
+ },
618
+ )
619
+ )
620
+ else:
621
+ self.logger.warning("OpenAI not available, cannot format tools for API.")
622
+
623
+ if not available_tools_adapters and not openai_tools: # If no tools can be formed or executed
624
+ self.logger.info(
625
+ "No tools available from resource or module for this turn. Agent cannot make tool calls."
626
+ )
627
+ # Agent might still respond textually. Let the loop proceed for one LLM call.
628
+
629
+ # Inner loop for multi-step tool use within this single user turn
630
+ current_inner_step = 0
631
+ while current_inner_step < MAX_STEPS_PER_USER_TURN:
632
+ current_inner_step += 1
633
+ self.logger.info(
634
+ f"--- User Turn {turn_num}, Inner Step {current_inner_step}/{MAX_STEPS_PER_USER_TURN} ---"
635
+ )
636
+
637
+ # 2. Call the LLM (OpenAI)
638
+ try:
639
+ # Validate conversation messages for OpenAI API compliance
640
+ self._validate_conversation_messages(conversation_messages)
641
+
642
+ self.logger.debug(
643
+ f"Calling OpenAI: model={agent_model_name}, messages_FULL_HISTORY={json.dumps(conversation_messages, indent=2)}, tools={openai_tools}"
644
+ ) # Log full message history
645
+ if not self._openai_client:
646
+ raise Exception("OpenAI client not initialized")
647
+
648
+ response = await self._openai_client.chat.completions.create(
649
+ model=agent_model_name,
650
+ messages=conversation_messages, # type: ignore
651
+ tools=openai_tools if openai_tools else None,
652
+ tool_choice="auto" if openai_tools else None,
653
+ max_tokens=4096,
654
+ temperature=0.0,
655
+ )
656
+ response_message = response.choices[0].message
657
+ self.logger.debug(f"OpenAI response message: {response_message}")
658
+
659
+ except Exception as e_openai:
660
+ self.logger.error(f"Error calling OpenAI API: {e_openai}", exc_info=True)
661
+ # On an API error, abort the whole episode rather than only the inner loop:
661
+ # record the error, close the resources, and return immediately to prevent cascading errors.
662
+ # TODO: Consider more nuanced error handling (e.g. letting the outer loop decide whether to continue).
664
+ evaluation_result = {"error": f"OpenAI API error: {e_openai}"}
665
+ # Clean up and return
666
+ if episode_resource:
667
+ await episode_resource.close()
668
+ if self.base_resource:
669
+ await self.base_resource.close()
670
+ self.base_resource = None
671
+ return evaluation_result
672
+
673
+ # 3. Process LLM Response
674
+ # Append assistant's response (content and tool calls) to history
675
+ conversation_messages.append(response_message.model_dump(exclude_none=True))
676
+
677
+ tool_calls = response_message.tool_calls
678
+ if tool_calls:
679
+ self.logger.info(f"Assistant requested {len(tool_calls)} tool calls in this step.")
680
+ current_llm_response_successful_calls: List[Dict[str, Any]] = []
681
+ for tool_call in tool_calls:
682
+ function_name = tool_call.function.name
683
+ function_args_str = tool_call.function.arguments
684
+ self.logger.info(f"Attempting tool call: {function_name}({function_args_str})")
685
+
686
+ tool_adapter = available_tools_adapters.get(function_name)
687
+ if tool_adapter:
688
+ try:
689
+ function_args = json.loads(function_args_str)
690
+ print("show function args: ", function_args)
691
+ function_response = await tool_adapter(function_args)
692
+ self.logger.info(
693
+ f"Tool '{function_name}' result: {str(function_response)[:200]}..."
694
+ )
695
+ conversation_messages.append(
696
+ {
697
+ "tool_call_id": tool_call.id,
698
+ "role": "tool",
699
+ "name": function_name,
700
+ "content": json.dumps(function_response),
701
+ }
702
+ )
703
+ current_llm_response_successful_calls.append(
704
+ {
705
+ "name": function_name,
706
+ "args": function_args,
707
+ }
708
+ )
709
+ except json.JSONDecodeError:
710
+ self.logger.error(
711
+ f"Failed to parse arguments for tool '{function_name}': {function_args_str}"
712
+ )
713
+ conversation_messages.append(
714
+ {
715
+ "tool_call_id": tool_call.id,
716
+ "role": "tool",
717
+ "name": function_name,
718
+ "content": json.dumps({"error": "Invalid JSON arguments"}),
719
+ }
720
+ )
721
+ except Exception as e_tool_exec:
722
+ self.logger.error(
723
+ f"Error executing tool '{function_name}': {e_tool_exec}",
724
+ exc_info=True,
725
+ )
726
+ conversation_messages.append(
727
+ {
728
+ "tool_call_id": tool_call.id,
729
+ "role": "tool",
730
+ "name": function_name,
731
+ "content": json.dumps({"error": f"Execution failed: {e_tool_exec}"}),
732
+ }
733
+ )
734
+ else:
735
+ self.logger.error(
736
+ f"Tool '{function_name}' requested by model but not found in available tools."
737
+ )
738
+ conversation_messages.append(
739
+ {
740
+ "tool_call_id": tool_call.id,
741
+ "role": "tool",
742
+ "name": function_name,
743
+ "content": json.dumps({"error": "Tool not found"}),
744
+ }
745
+ )
746
+
747
+ if current_llm_response_successful_calls:
748
+ current_user_turn_accumulated_successful_calls.extend(
749
+ current_llm_response_successful_calls
750
+ )
751
+
752
+ # If tool calls were made, continue the inner loop for the LLM to react to tool results.
753
+ if not openai_tools and not available_tools_adapters: # No tools were ever available
754
+ self.logger.info(
755
+ "No tools were available, but LLM hallucinated tool calls. Breaking inner loop."
756
+ )
757
+ break # Break inner loop
758
+ else:
759
+ # No tool calls from LLM in this step, means assistant provided a final textual response for this user turn.
760
+ self.logger.info(
761
+ "Assistant did not request tool calls in this step. Ending inner loop for this user turn."
762
+ )
763
+ break # Break the inner while loop
764
+ else: # Inner while loop finished due to max_steps_per_user_turn
765
+ self.logger.warning(
766
+ f"Reached max steps ({MAX_STEPS_PER_USER_TURN}) for user turn {turn_num}. Ending inner loop."
767
+ )
768
+ # End of inner while loop for multi-step tool use
769
+
770
+ if current_user_turn_accumulated_successful_calls:
771
+ all_user_turns_successful_function_calls.append(current_user_turn_accumulated_successful_calls)
772
+ # End of outer for loop for user turns
773
+
774
+ # --- Evaluation ---
775
+ self.logger.info("Evaluating task outcome...")
776
+ task_achieved = False # Reset task_achieved, as PoC logic is gone
777
+ eval_criteria = self.task_definition.evaluation_criteria
778
+
779
+ # Log evaluation_criteria and its relevant fields before calling reward function
780
+ self.logger.debug(f"Evaluation criteria object: {eval_criteria}")
781
+ if eval_criteria:
782
+ self.logger.debug(
783
+ f"Evaluation criteria ground_truth_function_calls: {getattr(eval_criteria, 'ground_truth_function_calls', 'AttributeError or None')}"
784
+ )
785
+ self.logger.debug(
786
+ f"Evaluation criteria ground_truth_comparable_state: {getattr(eval_criteria, 'ground_truth_comparable_state', 'AttributeError or None')}"
787
+ )
788
+
789
+ # Check if episode_resource is SQLResource for final_state_query
790
+ # from .resources import SQLResource # Would be needed here for isinstance
791
+ if eval_criteria and eval_criteria.final_state_query: # and isinstance(episode_resource, SQLResource):
792
+ if hasattr(episode_resource, "step"): # Generic check
793
+ query_res_step = await episode_resource.step(
794
+ "fetch_val_sql", {"query": eval_criteria.final_state_query}
795
+ )
796
+ if query_res_step.get("status") == "success":
797
+ outcome = query_res_step.get("result")
798
+ if eval_criteria.expected_query_result_transform:
799
+ try:
800
+ transform_func = eval(eval_criteria.expected_query_result_transform)
801
+ task_achieved = bool(transform_func(outcome))
802
+ except Exception as e_tf:
803
+ self.logger.error(f"Error applying transform: {e_tf}")
804
+ else:
805
+ task_achieved = bool(outcome)
806
+ self.logger.info(f"Final state query outcome: {outcome}, Task achieved: {task_achieved}")
807
+ else:
808
+ self.logger.error(f"Failed to execute final_state_query: {query_res_step.get('message')}")
809
+
810
+ # TODO: Re-evaluate how task_achieved should be determined without PoC logic
811
+ # Maybe based on final observation, specific tool calls, or reward function logic itself?
812
+
813
+ # Log evaluation_criteria and its relevant fields before calling reward function
814
+ self.logger.debug(f"Evaluation criteria object: {eval_criteria}")
815
+ if eval_criteria:
816
+ self.logger.debug(
817
+ f"Evaluation criteria ground_truth_function_calls: {getattr(eval_criteria, 'ground_truth_function_calls', 'AttributeError or None')}"
818
+ )
819
+ self.logger.debug(
820
+ f"Evaluation criteria ground_truth_comparable_state: {getattr(eval_criteria, 'ground_truth_comparable_state', 'AttributeError or None')}"
821
+ )
822
+
823
+ # Prepare ground_truth dictionary for the reward function
824
+ ground_truth_for_reward = None
825
+ if eval_criteria:
826
+ ground_truth_for_reward = {
827
+ "function_calls": getattr(eval_criteria, "ground_truth_function_calls", None),
828
+ "comparable_state": getattr(eval_criteria, "ground_truth_comparable_state", None),
829
+ }
830
+
831
+ # Prepare state dictionary for reward function
832
+ state_for_reward = {
833
+ "resource": episode_resource,
834
+ "successful_func_calls": all_user_turns_successful_function_calls,
835
+ # Add other relevant state info if needed
836
+ }
837
+
838
+ # Prepare eval_args dictionary
839
+ eval_args = {
840
+ "messages": conversation_messages, # Pass final conversation history (as dicts)
841
+ "state": state_for_reward,
842
+ "task_achieved": task_achieved, # Still needs proper determination
843
+ "task_definition_name": self.task_definition.name,
844
+ }
845
+
846
+ # Add ground_truth as a single parameter (not unpacked)
847
+ if ground_truth_for_reward:
848
+ eval_args["ground_truth"] = ground_truth_for_reward
849
+
850
+ # Call the reward function
851
+ self.logger.info(f"=== CALLING REWARD FUNCTION DEBUG ===")
852
+ self.logger.info(f"Reward function type: {type(self.reward_function)}")
853
+ self.logger.info(f"Eval args keys: {list(eval_args.keys())}")
854
+ self.logger.info(f"Task achieved: {eval_args.get('task_achieved', 'NOT_SET')}")
855
+ self.logger.info(f"Messages count: {len(eval_args.get('messages', []))}")
856
+ evaluation_result = self.reward_function(**eval_args)
857
+ self.logger.info(f"=== REWARD FUNCTION RESULT ===")
858
+ self.logger.info(f"Reward function result: {evaluation_result}")
859
+ self.logger.info(f"Result type: {type(evaluation_result)}")
860
+ self.logger.info(f"=== END REWARD FUNCTION DEBUG ===")
861
+
862
+ # Return both the evaluation result and the inputs for trajectory capture
863
+ return {
864
+ "evaluation_result": evaluation_result,
865
+ "reward_function_inputs": {
866
+ "messages": conversation_messages,
867
+ "state": state_for_reward,
868
+ "task_achieved": task_achieved,
869
+ "task_definition_name": self.task_definition.name,
870
+ "ground_truth": ground_truth_for_reward,
871
+ },
872
+ }
873
+
874
+ except Exception as e_lifecycle:
875
+ self.logger.error(f"Exception during task lifecycle: {e_lifecycle}", exc_info=True)
876
+ return {
877
+ "evaluation_result": {"error": str(e_lifecycle)},
878
+ "reward_function_inputs": None,
879
+ }
880
+ finally:
881
+ if episode_resource:
882
+ await episode_resource.close()
883
+ self.logger.info("Episode resource closed.")
884
+ if self.base_resource:
885
+ await self.base_resource.close()
886
+ self.base_resource = None
887
+ self.logger.info("Base resource closed.")
888
+ self.logger.info(f"Execution for task '{self.task_definition.name}' finished.")
889
+ # This should not be reached normally since we return earlier, but handle edge case
890
+ return {
891
+ "evaluation_result": {"error": "Unexpected execution path"},
892
+ "reward_function_inputs": None,
893
+ }