hud-python 0.3.0-py3-none-any.whl → 0.3.2-py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of hud-python might be problematic.

Files changed (54)
  1. hud/__init__.py +7 -4
  2. hud/adapters/common/adapter.py +14 -3
  3. hud/adapters/common/tests/test_adapter.py +16 -4
  4. hud/datasets.py +188 -0
  5. hud/env/docker_client.py +14 -2
  6. hud/env/local_docker_client.py +28 -6
  7. hud/gym.py +0 -9
  8. hud/{mcp_agent → mcp}/__init__.py +2 -0
  9. hud/mcp/base.py +631 -0
  10. hud/{mcp_agent → mcp}/claude.py +52 -47
  11. hud/mcp/client.py +312 -0
  12. hud/{mcp_agent → mcp}/langchain.py +52 -33
  13. hud/{mcp_agent → mcp}/openai.py +56 -40
  14. hud/{mcp_agent → mcp}/tests/test_base.py +129 -54
  15. hud/mcp/tests/test_claude.py +294 -0
  16. hud/mcp/tests/test_client.py +324 -0
  17. hud/mcp/tests/test_openai.py +238 -0
  18. hud/settings.py +6 -0
  19. hud/task.py +2 -88
  20. hud/taskset.py +2 -23
  21. hud/telemetry/__init__.py +5 -0
  22. hud/telemetry/_trace.py +180 -17
  23. hud/telemetry/context.py +79 -0
  24. hud/telemetry/exporter.py +165 -6
  25. hud/telemetry/job.py +141 -0
  26. hud/telemetry/tests/test_trace.py +36 -25
  27. hud/tools/__init__.py +14 -1
  28. hud/tools/computer/hud.py +13 -0
  29. hud/tools/executors/__init__.py +19 -2
  30. hud/tools/executors/pyautogui.py +84 -50
  31. hud/tools/executors/tests/test_pyautogui_executor.py +4 -1
  32. hud/tools/playwright_tool.py +73 -67
  33. hud/tools/tests/test_edit.py +8 -1
  34. hud/tools/tests/test_tools.py +3 -0
  35. hud/trajectory.py +5 -1
  36. hud/utils/tests/test_version.py +1 -1
  37. hud/version.py +1 -1
  38. {hud_python-0.3.0.dist-info → hud_python-0.3.2.dist-info}/METADATA +20 -14
  39. {hud_python-0.3.0.dist-info → hud_python-0.3.2.dist-info}/RECORD +42 -47
  40. hud/evaluators/__init__.py +0 -9
  41. hud/evaluators/base.py +0 -32
  42. hud/evaluators/inspect.py +0 -24
  43. hud/evaluators/judge.py +0 -189
  44. hud/evaluators/match.py +0 -156
  45. hud/evaluators/remote.py +0 -65
  46. hud/evaluators/tests/__init__.py +0 -0
  47. hud/evaluators/tests/test_inspect.py +0 -12
  48. hud/evaluators/tests/test_judge.py +0 -231
  49. hud/evaluators/tests/test_match.py +0 -115
  50. hud/evaluators/tests/test_remote.py +0 -98
  51. hud/mcp_agent/base.py +0 -723
  52. hud/{mcp_agent → mcp}/tests/__init__.py +0 -0
  53. {hud_python-0.3.0.dist-info → hud_python-0.3.2.dist-info}/WHEEL +0 -0
  54. {hud_python-0.3.0.dist-info → hud_python-0.3.2.dist-info}/licenses/LICENSE +0 -0
hud/mcp_agent/base.py DELETED
@@ -1,723 +0,0 @@
- """Base MCP Agent implementation."""
-
- from __future__ import annotations
-
- import asyncio
- import logging
- from abc import ABC, abstractmethod
- from typing import TYPE_CHECKING, Any
-
- import mcp.types as types
- from mcp_use import MCPClient
-
- if TYPE_CHECKING:
-     from hud.task import Task
-
- logger = logging.getLogger(__name__)
-
-
- class BaseMCPAgent(ABC):
-     """
-     Base class for MCP-enabled agents.
-
-     This class provides the foundation for agents that interact with MCP servers,
-     handling tool discovery and filtering while leaving provider-specific
-     implementation details to subclasses.
-     """
-
-     def __init__(
-         self,
-         client: MCPClient | None = None,
-         allowed_tools: list[str] | None = None,
-         disallowed_tools: list[str] | None = None,
-         initial_screenshot: bool = False,
-         max_screenshot_history: int = 3,
-         append_tool_system_prompt: bool = True,
-         custom_system_prompt: str | None = None,
-         lifecycle_tools: dict[str, str] | None = None,
-     ) -> None:
-         """
-         Initialize the base MCP agent.
-
-         Args:
-             client: MCPClient instance for server connections
-             allowed_tools: List of tool names to allow (None = all tools)
-             disallowed_tools: List of tool names to disallow
-             initial_screenshot: Whether to capture screenshot before first prompt
-             max_screenshot_history: Maximum number of screenshots to keep in context
-             append_tool_system_prompt: Whether to append available tools to system prompt
-             custom_system_prompt: Custom system prompt to use
-             lifecycle_tools: Dict mapping lifecycle phases to tool names. Default:
-                 {
-                     "setup": "setup",  # Setup phase tool
-                     "evaluate": "evaluate"  # Evaluation phase tool
-                 }
-         """
-         self.client = client
-         self.allowed_tools = allowed_tools
-         self.disallowed_tools = disallowed_tools or []
-         self.initial_screenshot = initial_screenshot
-         self.max_screenshot_history = max_screenshot_history
-         self.append_tool_system_prompt = append_tool_system_prompt
-         self.custom_system_prompt = custom_system_prompt
-
-         # Default lifecycle tool mapping
-         default_lifecycle = {"setup": "setup", "evaluate": "evaluate"}
-         self.lifecycle_tools = {**default_lifecycle, **(lifecycle_tools or {})}
-
-         self._available_tools: list[types.Tool] = []
-         self._tool_map: dict[str, tuple[str, types.Tool]] = {}
-         self._sessions: dict[str, Any] = {}
-
-         if client is None:
-             self.client = MCPClient()
-
-     async def initialize(self) -> None:
-         """Initialize the agent and discover available tools."""
-         # Get existing sessions or create new ones
-         if self.client is None:
-             raise ValueError("Client is not initialized")
-
-         sessions = self.client.get_all_active_sessions()
-
-         if not sessions:
-             logger.info("No active sessions found, creating new ones...")
-             sessions = await self.client.create_all_sessions()
-
-         self._sessions = sessions
-
-         # Discover tools from all servers
-         self._available_tools = []
-         self._tool_map = {}
-
-         for server_name, session in sessions.items():
-             try:
-                 # Ensure session is initialized
-                 if not hasattr(session, "connector") or not hasattr(
-                     session.connector, "client_session"
-                 ):
-                     await session.initialize()
-
-                 if session.connector.client_session is None:
-                     raise ValueError("Client session is not initialized")
-
-                 tools_result = await session.connector.client_session.list_tools()
-
-                 # Log all tools before filtering
-                 logger.info(
-                     "Tools from '%s' (pre-filter): %s",
-                     server_name,
-                     [tool.name for tool in tools_result.tools],
-                 )
-
-                 for tool in tools_result.tools:
-                     # Always include lifecycle tools for framework use
-                     is_lifecycle_tool = tool.name in self.lifecycle_tools.values()
-
-                     # Apply filtering (but always allow lifecycle tools)
-                     if not is_lifecycle_tool:
-                         if self.allowed_tools and tool.name not in self.allowed_tools:
-                             continue
-                         if tool.name in self.disallowed_tools:
-                             continue
-
-                     self._available_tools.append(tool)
-                     # Store tool with server reference for execution
-                     self._tool_map[tool.name] = (server_name, tool)
-
-             except Exception as e:
-                 logger.error("Failed to list tools from server %s: %s", server_name, e)
-
-         # Separate lifecycle tools from regular tools for clearer logging
-         lifecycle_tool_names = list(self.lifecycle_tools.values())
-         regular_tools = [
-             t.name for t in self._available_tools if t.name not in lifecycle_tool_names
-         ]
-         lifecycle_tools_found = [
-             t.name for t in self._available_tools if t.name in lifecycle_tool_names
-         ]
-
-         logger.info(
-             "Agent initialized with %s tools (%s regular, %s lifecycle)",
-             len(self._available_tools),
-             len(regular_tools),
-             len(lifecycle_tools_found),
-         )
-         if regular_tools:
-             logger.info("Regular tools: %s", regular_tools)
-         if lifecycle_tools_found:
-             logger.info("Lifecycle tools: %s", lifecycle_tools_found)
-
-     def get_available_tools(self) -> list[types.Tool]:
-         """Get list of available MCP tools for LLM use (excludes lifecycle tools)."""
-         lifecycle_tool_names = list(self.lifecycle_tools.values())
-         return [tool for tool in self._available_tools if tool.name not in lifecycle_tool_names]
-
-     def get_tool_map(self) -> dict[str, tuple[str, types.Tool]]:
-         """Get mapping of tool names to (server_name, tool) tuples."""
-         return self._tool_map
-
-     def get_sessions(self) -> dict[str, Any]:
-         """Get active MCP sessions."""
-         return self._sessions
-
-     def get_tools_by_server(self) -> dict[str, list[types.Tool]]:
-         """Get tools grouped by server name."""
-         tools_by_server = {}
-         for server_name, tool in self._tool_map.values():
-             if server_name not in tools_by_server:
-                 tools_by_server[server_name] = []
-             tools_by_server[server_name].append(tool)
-         return tools_by_server
-
-     def get_tools_by_connector(self) -> dict[Any, list[types.Tool]]:
-         """Get tools grouped by connector instance."""
-         tools_by_connector = {}
-         for server_name, tool in self._tool_map.values():
-             session = self._sessions[server_name]
-             connector = session.connector
-
-             if connector not in tools_by_connector:
-                 tools_by_connector[connector] = []
-             tools_by_connector[connector].append(tool)
-         return tools_by_connector
-
-     def get_system_prompt(self) -> str:
-         """Generate system prompt with optional tool information."""
-         base_prompt = self.custom_system_prompt or "You are a helpful assistant."
-
-         if self.append_tool_system_prompt and self._available_tools:
-             tool_descriptions = []
-             for tool in self._available_tools:
-                 desc = f"- {tool.name}: {tool.description}"
-                 if tool.inputSchema:
-                     desc += f" (parameters: {tool.inputSchema})"
-                 tool_descriptions.append(desc)
-
-             tools_prompt = "\n\nYou have access to the following tools:\n" + "\n".join(
-                 tool_descriptions
-             )
-             return base_prompt + tools_prompt
-
-         return base_prompt
-
-     async def call_tool(self, tool_call: dict[str, Any]) -> types.CallToolResult:
-         """
-         Call a tool through the MCP client.
-
-         Args:
-             tool_call: Dict with 'name' and optional 'arguments' keys
-
-         Returns:
-             The raw MCP CallToolResult
-         """
-         tool_name = tool_call.get("name")
-         if not tool_name:
-             raise ValueError("Tool call must have a 'name' field")
-
-         tool_args = tool_call.get("arguments", {})
-
-         if tool_name not in self._tool_map:
-             raise ValueError(f"Tool '{tool_name}' not found or not allowed")
-
-         if self.client is None:
-             raise ValueError("Client is not initialized")
-
-         server_name, tool = self._tool_map[tool_name]
-         session = self.client.get_session(server_name)
-
-         logger.info(
-             "Calling tool '%s' on server '%s' with args: %s",
-             tool_name,
-             server_name,
-             tool_args,
-         )
-         if session.connector.client_session is None:
-             raise ValueError("Client session is not initialized")
-
-         result = await session.connector.client_session.call_tool(tool_name, tool_args)
-
-         # Log result for debugging
-         if result.isError:
-             logger.error("Tool '%s' returned error: %s", tool_name, result.content)
-         else:
-             logger.debug("Tool '%s' completed successfully", tool_name)
-
-         return result
-
-     def has_computer_tools(self) -> bool:
-         """Check if any computer control tools are available."""
-         computer_tools = {"computer", "computer_anthropic", "computer_openai", "screenshot"}
-         return any(tool.name in computer_tools for tool in self._available_tools)
-
-     def get_tool_schemas(self) -> list[dict]:
-         """Get tool schemas in a format suitable for the model."""
-         schemas = []
-         for tool in self._available_tools:
-             # Filter out lifecycle tools from LLM conversation
-             if tool.name in self.lifecycle_tools.values():
-                 continue
-
-             schema = {
-                 "name": tool.name,
-                 "description": tool.description,
-             }
-             if tool.inputSchema:
-                 schema["parameters"] = tool.inputSchema
-             schemas.append(schema)
-         return schemas
-
-     async def capture_screenshot(self) -> str | None:
-         """Capture a screenshot using available tools."""
-         if not self.has_computer_tools():
-             return None
-
-         # Try different screenshot tools
-         for tool_name in [
-             "computer",
-             "screenshot",
-             "computer_anthropic",
-             "computer_openai",
-             "anthropic_computer",
-             "openai_computer",
-         ]:
-             if tool_name in self._tool_map:
-                 try:
-                     # Different tools have different APIs
-                     if tool_name == "computer_openai":
-                         tool_call = {"name": tool_name, "arguments": {"type": "screenshot"}}
-                     else:
-                         tool_call = {"name": tool_name, "arguments": {"action": "screenshot"}}
-
-                     result = await self.call_tool(tool_call)
-
-                     # Extract screenshot from result
-                     for content in result.content:
-                         if isinstance(content, types.ImageContent):
-                             logger.info("Captured screenshot")
-                             return content.data
-
-                 except Exception as e:
-                     logger.warning("Failed to capture screenshot with %s: %s", tool_name, e)
-
-         return None
-
-     def process_tool_results(self, tool_results: list[dict[str, Any]]) -> dict[str, Any]:
-         """
-         Process tool results into a standardized format.
-
-         Returns a dict with:
-         - text: Combined text output from all tools
-         - screenshot: Latest screenshot if any tool returned one
-         - errors: List of any errors encountered
-         - results: List of (tool_name, content_blocks) tuples for provider-specific formatting
-         """
-         text_parts = []
-         latest_screenshot = None
-         errors = []
-         results = []
-
-         for tool_result in tool_results:
-             tool_name = tool_result["tool_name"]
-             content_blocks = []
-
-             if tool_result.get("error"):
-                 error_msg = f"{tool_name}: {tool_result.get('error_message', 'Unknown error')}"
-                 errors.append(error_msg)
-                 text_parts.append(f"Error - {error_msg}")
-                 content_blocks.append(
-                     {
-                         "type": "error",
-                         "text": tool_result.get("error_message", "Unknown error"),
-                     }
-                 )
-             else:
-                 result = tool_result["result"]
-                 if result.isError:
-                     # Extract error from content
-                     error_text = "Tool execution failed"
-                     for content in result.content:
-                         if isinstance(content, types.TextContent):
-                             error_text = content.text
-                             break
-                     error_msg = f"{tool_name}: {error_text}"
-                     errors.append(error_msg)
-                     text_parts.append(f"Error - {error_msg}")
-                     content_blocks.append(
-                         {
-                             "type": "error",
-                             "text": error_text,
-                         }
-                     )
-                 else:
-                     # Process success content
-                     tool_output = []
-                     for content in result.content:
-                         if isinstance(content, types.TextContent):
-                             tool_output.append(content.text)
-                             content_blocks.append(
-                                 {
-                                     "type": "text",
-                                     "text": content.text,
-                                 }
-                             )
-                         elif isinstance(content, types.ImageContent):
-                             # Keep the latest screenshot
-                             latest_screenshot = content.data
-                             content_blocks.append(
-                                 {
-                                     "type": "image",
-                                     "data": content.data,
-                                 }
-                             )
-
-                     if tool_output:
-                         text_parts.append(f"{tool_name}: " + " ".join(tool_output))
-
-             results.append((tool_name, content_blocks))
-
-         return {
-             "text": "\n".join(text_parts) if text_parts else "No output from tools",
-             "screenshot": latest_screenshot,
-             "errors": errors,
-             "results": results,  # List of (tool_name, content_blocks) for provider-specific use
-         }
-
-     async def run(
-         self, prompt_or_task: str | Task, max_steps: int = 10, conversation_mode: bool = False
-     ) -> dict[str, Any]:
-         """
-         Run the agent with the given prompt or task.
-
-         Args:
-             prompt_or_task: Either a string prompt for simple execution or a Task object
-             max_steps: Maximum number of steps
-             conversation_mode: If True, continue even when model returns text without tool calls
-
-         Returns:
-             For string prompts: The final response string
-             For Task objects: Evaluation result dict with 'reward', 'done', 'info' keys
-         """
-         # Import here to avoid circular imports
-         from hud.task import Task
-
-         if not self._available_tools:
-             await self.initialize()
-
-         # Handle Task objects with full lifecycle
-         if isinstance(prompt_or_task, Task):
-             return await self._run_task(prompt_or_task, max_steps)
-
-         # Handle simple string prompts (existing behavior)
-         elif isinstance(prompt_or_task, str):
-             return await self._run_prompt(prompt_or_task, max_steps, conversation_mode)
-
-         else:
-             raise TypeError(f"prompt_or_task must be str or Task, got {type(prompt_or_task)}")
-
-     async def _run_task(self, task: Task, max_steps: int = 10) -> dict[str, Any]:
-         """
-         Execute a task with setup and evaluate phases.
-
-         Args:
-             task: Task object with prompt, setup, and evaluate configs
-             max_steps: Maximum steps for task execution
-
-         Returns:
-             Evaluation result dict with 'reward', 'done', 'info' keys
-         """
-         try:
-             # Setup phase
-             if task.setup is not None:
-                 setup_tool = self.lifecycle_tools.get("setup", "setup")
-                 await self._call_tool_safe(setup_tool, task.setup)
-
-             # Execute the task prompt
-             await self._run_prompt(task.prompt, max_steps, conversation_mode=False)
-
-             # Evaluate phase
-             if task.evaluate is not None:
-                 evaluate_tool = self.lifecycle_tools.get("evaluate", "evaluate")
-                 eval_result = await self._call_tool_safe(evaluate_tool, task.evaluate)
-
-                 # Return evaluation result if it's properly formatted
-                 if (
-                     isinstance(eval_result, dict)
-                     and "reward" in eval_result
-                     and "done" in eval_result
-                 ):
-                     return eval_result
-                 elif isinstance(eval_result, dict) and "grade" in eval_result:
-                     return {
-                         "reward": eval_result.get("grade", 0.0),
-                         "done": True,
-                         "info": {
-                             "error": eval_result.get("error"),
-                             "logs": eval_result.get("logs", ""),
-                             "original_result": eval_result,
-                         },
-                     }
-                 else:
-                     # Fallback for invalid evaluation format
-                     return {
-                         "reward": 0.0,
-                         "done": True,
-                         "info": {"error": "Invalid evaluation result", "eval_result": eval_result},
-                     }
-             else:
-                 # No evaluation - assume success
-                 return {
-                     "reward": 0.0,
-                     "done": True,
-                     "info": {"message": "Task completed (no evaluation specified)"},
-                 }
-
-         except Exception as e:
-             return {"reward": 0.0, "done": True, "info": {"error": str(e)}}
-
-     async def _call_tool_safe(self, tool_name: str, arguments: Any) -> Any:
-         """
-         Safely call a tool and return its result.
-
-         Args:
-             tool_name: Name of the tool to call
-             arguments: Arguments to pass to the tool (config from task)
-
-         Returns:
-             Tool result or None if tool not available/failed
-         """
-         try:
-             if tool_name in self._tool_map:
-                 tool_call = {"name": tool_name, "arguments": arguments}
-                 result = await self.call_tool(tool_call)
-
-                 if result.isError:
-                     logger.error("Tool %s returned error: %s", tool_name, result.content)
-                     return {"error": result.content}
-                 else:
-                     # Extract content from MCP result
-                     if hasattr(result, "content") and result.content:
-                         if len(result.content) == 1:
-                             content_item = result.content[0]
-                             # Check if content_item is a text type
-                             if hasattr(content_item, "text") and hasattr(content_item, "type"):
-                                 if getattr(content_item, "type", None) == "text":
-                                     # Try to parse as JSON if it looks like structured data
-                                     text = content_item.text  # type: ignore[reportAttributeAccessIssue]
-                                     if text.strip().startswith("{") and text.strip().endswith("}"):
-                                         try:
-                                             import json
-
-                                             return json.loads(text)
-                                         except json.JSONDecodeError:
-                                             return text
-                                     return text
-                             else:
-                                 return content_item
-                         else:
-                             return result.content
-                     return result
-             else:
-                 logger.warning("Tool %s not available", tool_name)
-                 return None
-         except Exception as e:
-             logger.error("Failed to call tool %s: %s", tool_name, e)
-             return {"error": str(e)}
-
-     async def _run_prompt(
-         self,
-         prompt: str,
-         max_steps: int = 10,
-         conversation_mode: bool = False,
-     ) -> dict[str, Any]:
-         """
-         Run the agent with the given prompt.
-
-         Args:
-             prompt: The task to complete
-             max_steps: Maximum number of steps
-             conversation_mode: If True, continue even when model returns text without tool calls
-
-         Returns:
-             The final response or result
-         """
-         try:
-             latest_screenshot = None
-             if self.initial_screenshot:
-                 latest_screenshot = await self.capture_screenshot()
-
-             messages = await self.create_initial_messages(prompt, latest_screenshot)
-
-             step = 0
-             while step < max_steps:
-                 step += 1
-                 logger.info("step %s/%s", step, max_steps)
-
-                 try:
-                     response = await self.get_model_response(messages, step)
-
-                     # Log the model's response
-                     logger.info("Model response - Content: %s", response.get("content", ""))
-                     logger.info(
-                         "Model response - Tool calls: %s",
-                         [tc.get("name") for tc in response.get("tool_calls", [])],
-                     )
-                     logger.info("Model response - Done: %s", response.get("done", False))
-
-                     # Check if we should stop
-                     if response.get("done", False) and not conversation_mode:
-                         return response.get("content", "Task completed")
-
-                     tool_calls = response.get("tool_calls", [])
-                     if not tool_calls:
-                         if conversation_mode:
-                             # In conversation mode, if model responds without tools,
-                             # show the response and get user input
-                             model_response = response.get("content", "")
-                             if model_response:
-                                 print(f"\n🤖 Agent: {model_response}")  # noqa: T201
-                                 user_input = input("\n👤 You: ").strip()
-                                 if user_input.lower() in ["exit", "quit", "bye"]:
-                                     return {
-                                         "done": True,
-                                         "reward": 0.0,
-                                         "info": {"message": "Conversation ended by user."},
-                                     }
-                                 # Add user's response to the conversation
-                                 # This needs to be handled by subclass-specific format
-                                 user_message = await self.create_user_message(user_input)
-                                 messages.append(user_message)
-                                 continue
-                             else:
-                                 # No content and no tools - something went wrong
-                                 return {
-                                     "done": False,
-                                     "reward": 0.0,
-                                     "info": {"message": "No response generated"},
-                                 }
-                         else:
-                             # In task mode, no tool calls means we're done
-                             logger.info("In task mode with no tool calls - stopping execution")
-                             logger.info(
-                                 "Final message: %s",
-                                 response.get("content", "No response generated"),
-                             )
-                             return {
-                                 "done": True,
-                                 "reward": 0.0,
-                                 "info": {
-                                     "message": response.get("content", "No response generated"),
-                                 },
-                             }
-
-                     # Execute tool calls
-                     tool_results = []
-                     for tool_call in tool_calls:
-                         if not tool_call.get("name"):
-                             continue
-                         try:
-                             result = await self.call_tool(tool_call)
-                             tool_results.append(
-                                 {
-                                     "tool_name": tool_call["name"],
-                                     "result": result,
-                                     "error": False,
-                                 }
-                             )
-                         except Exception as e:
-                             logger.error("Tool execution failed: %s", e)
-                             tool_results.append(
-                                 {
-                                     "tool_name": tool_call["name"],
-                                     "error": True,
-                                     "error_message": str(e),
-                                 }
-                             )
-
-                     # Process results
-                     processed_results = self.process_tool_results(tool_results)
-
-                     # Update screenshot if we got a new one
-                     if processed_results["screenshot"]:
-                         latest_screenshot = processed_results["screenshot"]
-
-                     # Format tool results for the model
-                     tool_messages = await self.format_tool_results(
-                         processed_results,
-                         response.get("tool_calls", []),
-                     )
-                     messages.extend(tool_messages)
-
-                 except Exception as e:
-                     logger.error("Model call failed: %s", e)
-                     return {"done": False, "reward": 0.0, "info": {"message": f"Error: {e}"}}
-
-             return {"done": True, "reward": 0.0, "info": {"message": "Task completed"}}
-
-         except KeyboardInterrupt:
-             logger.info("Agent execution interrupted by user")
-             return {
-                 "done": False,
-                 "reward": 0.0,
-                 "info": {"message": "Execution interrupted by user (Ctrl+C)"},
-             }
-         except asyncio.CancelledError:
-             logger.info("Agent execution cancelled")
-             return {"done": False, "reward": 0.0, "info": {"message": "Execution cancelled"}}
-
-     @abstractmethod
-     async def create_initial_messages(self, prompt: str, screenshot: str | None) -> list[Any]:
-         """
-         Create initial messages for the conversation.
-
-         Args:
-             prompt: The user's prompt
-             screenshot: Optional initial screenshot
-
-         Returns:
-             List of messages in provider-specific format
-         """
-
-     @abstractmethod
-     async def get_model_response(self, messages: list[Any], step: int) -> dict[str, Any]:
-         """
-         Get response from the model including any tool calls.
-
-         Args:
-             messages: List of messages in provider-specific format
-             step: Current step number
-
-         Returns:
-             Dict with 'content', 'tool_calls', and 'done' keys
-         """
-
-     @abstractmethod
-     async def format_tool_results(
-         self, processed_results: dict[str, Any], tool_calls: list[dict[str, Any]]
-     ) -> list[Any]:
-         """
-         Format tool results into messages for the model.
-
-         Args:
-             processed_results: Processed tool results from process_tool_results
-             tool_calls: Original tool calls from the model
-
-         Returns:
-             List of formatted messages to append to conversation
-         """
-         raise NotImplementedError
-
-     async def create_user_message(self, text: str) -> Any:
-         """
-         Create a user message in the format expected by the model.
-
-         Default implementation for text-only messages.
-         Subclasses can override for specific formats.
-
-         Args:
-             text: User's text input
-
-         Returns:
-             Formatted user message
-         """
-         return {"role": "user", "content": text}