cua-agent 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.

Note: this release of cua-agent has been flagged as potentially problematic.

Files changed (112)
  1. agent/__init__.py +21 -12
  2. agent/__main__.py +21 -0
  3. agent/adapters/__init__.py +9 -0
  4. agent/adapters/huggingfacelocal_adapter.py +229 -0
  5. agent/agent.py +594 -0
  6. agent/callbacks/__init__.py +19 -0
  7. agent/callbacks/base.py +153 -0
  8. agent/callbacks/budget_manager.py +44 -0
  9. agent/callbacks/image_retention.py +139 -0
  10. agent/callbacks/logging.py +247 -0
  11. agent/callbacks/pii_anonymization.py +259 -0
  12. agent/callbacks/telemetry.py +210 -0
  13. agent/callbacks/trajectory_saver.py +305 -0
  14. agent/cli.py +297 -0
  15. agent/computer_handler.py +107 -0
  16. agent/decorators.py +90 -0
  17. agent/loops/__init__.py +11 -0
  18. agent/loops/anthropic.py +728 -0
  19. agent/loops/omniparser.py +339 -0
  20. agent/loops/openai.py +95 -0
  21. agent/loops/uitars.py +688 -0
  22. agent/responses.py +207 -0
  23. agent/telemetry.py +135 -14
  24. agent/types.py +79 -0
  25. agent/ui/__init__.py +7 -1
  26. agent/ui/__main__.py +2 -13
  27. agent/ui/gradio/__init__.py +6 -19
  28. agent/ui/gradio/app.py +94 -1313
  29. agent/ui/gradio/ui_components.py +721 -0
  30. cua_agent-0.4.0.dist-info/METADATA +424 -0
  31. cua_agent-0.4.0.dist-info/RECORD +33 -0
  32. {cua_agent-0.3.1.dist-info → cua_agent-0.4.0.dist-info}/WHEEL +1 -1
  33. agent/core/__init__.py +0 -27
  34. agent/core/agent.py +0 -210
  35. agent/core/base.py +0 -217
  36. agent/core/callbacks.py +0 -200
  37. agent/core/experiment.py +0 -249
  38. agent/core/factory.py +0 -122
  39. agent/core/messages.py +0 -332
  40. agent/core/provider_config.py +0 -21
  41. agent/core/telemetry.py +0 -142
  42. agent/core/tools/__init__.py +0 -21
  43. agent/core/tools/base.py +0 -74
  44. agent/core/tools/bash.py +0 -52
  45. agent/core/tools/collection.py +0 -46
  46. agent/core/tools/computer.py +0 -113
  47. agent/core/tools/edit.py +0 -67
  48. agent/core/tools/manager.py +0 -56
  49. agent/core/tools.py +0 -32
  50. agent/core/types.py +0 -88
  51. agent/core/visualization.py +0 -197
  52. agent/providers/__init__.py +0 -4
  53. agent/providers/anthropic/__init__.py +0 -6
  54. agent/providers/anthropic/api/client.py +0 -360
  55. agent/providers/anthropic/api/logging.py +0 -150
  56. agent/providers/anthropic/api_handler.py +0 -140
  57. agent/providers/anthropic/callbacks/__init__.py +0 -5
  58. agent/providers/anthropic/callbacks/manager.py +0 -65
  59. agent/providers/anthropic/loop.py +0 -568
  60. agent/providers/anthropic/prompts.py +0 -23
  61. agent/providers/anthropic/response_handler.py +0 -226
  62. agent/providers/anthropic/tools/__init__.py +0 -33
  63. agent/providers/anthropic/tools/base.py +0 -88
  64. agent/providers/anthropic/tools/bash.py +0 -66
  65. agent/providers/anthropic/tools/collection.py +0 -34
  66. agent/providers/anthropic/tools/computer.py +0 -396
  67. agent/providers/anthropic/tools/edit.py +0 -326
  68. agent/providers/anthropic/tools/manager.py +0 -54
  69. agent/providers/anthropic/tools/run.py +0 -42
  70. agent/providers/anthropic/types.py +0 -16
  71. agent/providers/anthropic/utils.py +0 -367
  72. agent/providers/omni/__init__.py +0 -8
  73. agent/providers/omni/api_handler.py +0 -42
  74. agent/providers/omni/clients/anthropic.py +0 -103
  75. agent/providers/omni/clients/base.py +0 -35
  76. agent/providers/omni/clients/oaicompat.py +0 -195
  77. agent/providers/omni/clients/ollama.py +0 -122
  78. agent/providers/omni/clients/openai.py +0 -155
  79. agent/providers/omni/clients/utils.py +0 -25
  80. agent/providers/omni/image_utils.py +0 -34
  81. agent/providers/omni/loop.py +0 -990
  82. agent/providers/omni/parser.py +0 -307
  83. agent/providers/omni/prompts.py +0 -64
  84. agent/providers/omni/tools/__init__.py +0 -30
  85. agent/providers/omni/tools/base.py +0 -29
  86. agent/providers/omni/tools/bash.py +0 -74
  87. agent/providers/omni/tools/computer.py +0 -179
  88. agent/providers/omni/tools/manager.py +0 -61
  89. agent/providers/omni/utils.py +0 -236
  90. agent/providers/openai/__init__.py +0 -6
  91. agent/providers/openai/api_handler.py +0 -456
  92. agent/providers/openai/loop.py +0 -472
  93. agent/providers/openai/response_handler.py +0 -205
  94. agent/providers/openai/tools/__init__.py +0 -15
  95. agent/providers/openai/tools/base.py +0 -79
  96. agent/providers/openai/tools/computer.py +0 -326
  97. agent/providers/openai/tools/manager.py +0 -106
  98. agent/providers/openai/types.py +0 -36
  99. agent/providers/openai/utils.py +0 -98
  100. agent/providers/uitars/__init__.py +0 -1
  101. agent/providers/uitars/clients/base.py +0 -35
  102. agent/providers/uitars/clients/mlxvlm.py +0 -263
  103. agent/providers/uitars/clients/oaicompat.py +0 -214
  104. agent/providers/uitars/loop.py +0 -660
  105. agent/providers/uitars/prompts.py +0 -63
  106. agent/providers/uitars/tools/__init__.py +0 -1
  107. agent/providers/uitars/tools/computer.py +0 -283
  108. agent/providers/uitars/tools/manager.py +0 -60
  109. agent/providers/uitars/utils.py +0 -264
  110. cua_agent-0.3.1.dist-info/METADATA +0 -295
  111. cua_agent-0.3.1.dist-info/RECORD +0 -87
  112. {cua_agent-0.3.1.dist-info → cua_agent-0.4.0.dist-info}/entry_points.txt +0 -0
agent/providers/omni/loop.py
@@ -1,990 +0,0 @@
- """Omni-specific agent loop implementation."""
-
- import logging
- from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator
- import json
- import re
- import os
- import asyncio
- from httpx import ConnectError, ReadTimeout
- from typing import cast
-
- from .parser import OmniParser, ParseResult
- from ...core.base import BaseLoop
- from ...core.visualization import VisualizationHelper
- from ...core.messages import StandardMessageManager, ImageRetentionConfig
- from .utils import to_openai_agent_response_format
- from ...core.types import AgentResponse
- from computer import Computer
- from ...core.types import LLMProvider
- from .clients.openai import OpenAIClient
- from .clients.anthropic import AnthropicClient
- from .clients.ollama import OllamaClient
- from .clients.oaicompat import OAICompatClient
- from .prompts import SYSTEM_PROMPT
- from .api_handler import OmniAPIHandler
- from .tools.manager import ToolManager
- from .tools import ToolResult
-
- logger = logging.getLogger(__name__)
-
- def extract_data(input_string: str, data_type: str) -> str:
- """Extract content from code blocks."""
- pattern = f"```{data_type}" + r"(.*?)(```|$)"
- matches = re.findall(pattern, input_string, re.DOTALL)
- return matches[0][0].strip() if matches else input_string
-
-
- class OmniLoop(BaseLoop):
- """Omni-specific implementation of the agent loop.
-
- This class extends BaseLoop to provide support for multimodal models
- from various providers (OpenAI, Anthropic, etc.) with UI parsing
- and desktop automation capabilities.
- """
-
- ###########################################
- # INITIALIZATION AND CONFIGURATION
- ###########################################
-
- def __init__(
- self,
- parser: OmniParser,
- provider: LLMProvider,
- api_key: str,
- model: str,
- computer: Computer,
- only_n_most_recent_images: Optional[int] = 2,
- base_dir: Optional[str] = "trajectories",
- max_retries: int = 3,
- retry_delay: float = 1.0,
- save_trajectory: bool = True,
- provider_base_url: Optional[str] = None,
- **kwargs,
- ):
- """Initialize the loop.
-
- Args:
- parser: Parser instance
- provider: API provider
- api_key: API key
- model: Model name
- computer: Computer instance
- only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
- base_dir: Base directory for saving experiment data
- max_retries: Maximum number of retries for API calls
- retry_delay: Delay between retries in seconds
- save_trajectory: Whether to save trajectory data
- provider_base_url: Base URL for the API provider (used for OAICOMPAT)
- """
- # Set parser and provider before initializing base class
- self.parser = parser
- self.provider = provider
- self.provider_base_url = provider_base_url
-
- # Initialize message manager with image retention config
- self.message_manager = StandardMessageManager(
- config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
- )
-
- # Initialize base class (which will set up experiment manager)
- super().__init__(
- computer=computer,
- model=model,
- api_key=api_key,
- max_retries=max_retries,
- retry_delay=retry_delay,
- base_dir=base_dir,
- save_trajectory=save_trajectory,
- only_n_most_recent_images=only_n_most_recent_images,
- **kwargs,
- )
-
- # Set API client attributes
- self.client = None
- self.retry_count = 0
- self.loop_task = None # Store the loop task for cancellation
-
- # Initialize handlers
- self.api_handler = OmniAPIHandler(loop=self)
- self.viz_helper = VisualizationHelper(agent=self)
-
- # Initialize tool manager
- self.tool_manager = ToolManager(computer=computer, provider=provider)
-
- logger.info("OmniLoop initialized with StandardMessageManager")
-
- async def initialize(self) -> None:
- """Initialize the loop by setting up tools and clients."""
- # Initialize base class
- await super().initialize()
-
- # Initialize tool manager with error handling
- try:
- logger.info("Initializing tool manager...")
- await self.tool_manager.initialize()
- logger.info("Tool manager initialized successfully.")
- except Exception as e:
- logger.error(f"Error initializing tool manager: {str(e)}")
- logger.warning("Will attempt to initialize tools on first use.")
-
- # Initialize API clients based on provider
- if self.provider == LLMProvider.ANTHROPIC:
- self.client = AnthropicClient(
- api_key=self.api_key,
- model=self.model,
- )
- elif self.provider == LLMProvider.OPENAI:
- self.client = OpenAIClient(
- api_key=self.api_key,
- model=self.model,
- )
- elif self.provider == LLMProvider.OLLAMA:
- self.client = OllamaClient(
- api_key=self.api_key,
- model=self.model,
- )
- elif self.provider == LLMProvider.OAICOMPAT:
- self.client = OAICompatClient(
- api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
- model=self.model,
- provider_base_url=self.provider_base_url,
- )
- else:
- raise ValueError(f"Unsupported provider: {self.provider}")
-
- ###########################################
- # CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
- ###########################################
-
- async def initialize_client(self) -> None:
- """Initialize the appropriate client based on provider.
-
- Implements abstract method from BaseLoop to set up the specific
- provider client (OpenAI, Anthropic, etc.).
- """
- try:
- logger.info(f"Initializing {self.provider} client with model {self.model}...")
-
- if self.provider == LLMProvider.OPENAI:
- self.client = OpenAIClient(api_key=self.api_key, model=self.model)
- elif self.provider == LLMProvider.ANTHROPIC:
- self.client = AnthropicClient(
- api_key=self.api_key,
- model=self.model,
- max_retries=self.max_retries,
- retry_delay=self.retry_delay,
- )
- elif self.provider == LLMProvider.OLLAMA:
- self.client = OllamaClient(
- api_key=self.api_key,
- model=self.model,
- )
- elif self.provider == LLMProvider.OAICOMPAT:
- self.client = OAICompatClient(
- api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
- model=self.model,
- provider_base_url=self.provider_base_url,
- )
- else:
- raise ValueError(f"Unsupported provider: {self.provider}")
-
- logger.info(f"Initialized {self.provider} client with model {self.model}")
- except Exception as e:
- logger.error(f"Error initializing client: {str(e)}")
- self.client = None
- raise RuntimeError(f"Failed to initialize client: {str(e)}")
-
- ###########################################
- # API CALL HANDLING
- ###########################################
-
- async def _make_api_call(self, messages: List[Dict[str, Any]], system_prompt: str) -> Any:
- """Make API call to provider with retry logic."""
- # Create new turn directory for this API call
- self._create_turn_dir()
-
- request_data = None
- last_error = None
-
- for attempt in range(self.max_retries):
- try:
- # Ensure client is initialized
- if self.client is None:
- logger.info(
- f"Client not initialized in _make_api_call (attempt {attempt+1}), initializing now..."
- )
- await self.initialize_client()
- if self.client is None:
- raise RuntimeError("Failed to initialize client")
-
- # Get messages in standard format from the message manager
- self.message_manager.messages = messages.copy()
- prepared_messages = self.message_manager.get_messages()
-
- # Special handling for Anthropic
- if self.provider == LLMProvider.ANTHROPIC:
- # Convert to Anthropic format
- anthropic_messages, anthropic_system = self.message_manager.to_anthropic_format(
- prepared_messages
- )
-
- # Filter out any empty/invalid messages
- filtered_messages = [
- msg
- for msg in anthropic_messages
- if msg.get("role") in ["user", "assistant"]
- ]
-
- # Ensure there's at least one message for Anthropic
- if not filtered_messages:
- logger.warning(
- "No valid messages found for Anthropic API call. Adding a default user message."
- )
- filtered_messages = [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Please help with this task."}
- ],
- }
- ]
-
- # Combine system prompts if needed
- final_system_prompt = anthropic_system or system_prompt
-
- # Log request
- request_data = {
- "messages": filtered_messages,
- "max_tokens": self.max_tokens,
- "system": final_system_prompt,
- }
-
- self._log_api_call("request", request_data)
-
- # Make API call
- response = await self.client.run_interleaved(
- messages=filtered_messages,
- system=final_system_prompt,
- max_tokens=self.max_tokens,
- )
- else:
- # For OpenAI and others, use standard format directly
- # Log request
- request_data = {
- "messages": prepared_messages,
- "max_tokens": self.max_tokens,
- "system": system_prompt,
- }
-
- self._log_api_call("request", request_data)
-
- # Make API call
- response = await self.client.run_interleaved(
- messages=prepared_messages,
- system=system_prompt,
- max_tokens=self.max_tokens,
- )
-
- # Log success response
- self._log_api_call("response", request_data, response)
-
- return response
-
- except (ConnectError, ReadTimeout) as e:
- last_error = e
- logger.warning(
- f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
- )
- if attempt < self.max_retries - 1:
- await asyncio.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff
- # Reset client on connection errors to force re-initialization
- self.client = None
- continue
-
- except RuntimeError as e:
- # Handle client initialization errors specifically
- last_error = e
- self._log_api_call("error", request_data, error=e)
- logger.error(
- f"Client initialization error (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
- )
- if attempt < self.max_retries - 1:
- # Reset client to force re-initialization
- self.client = None
- await asyncio.sleep(self.retry_delay)
- continue
-
- except Exception as e:
- # Log unexpected error
- last_error = e
- self._log_api_call("error", request_data, error=e)
- logger.error(f"Unexpected error in API call: {str(e)}")
- if attempt < self.max_retries - 1:
- await asyncio.sleep(self.retry_delay)
- continue
-
- # If we get here, all retries failed
- error_message = f"API call failed after {self.max_retries} attempts"
- if last_error:
- error_message += f": {str(last_error)}"
-
- logger.error(error_message)
- raise RuntimeError(error_message)
-
- ###########################################
- # RESPONSE AND ACTION HANDLING
- ###########################################
-
- async def _handle_response(
- self, response: Any, messages: List[Dict[str, Any]], parsed_screen: ParseResult
- ) -> Tuple[bool, bool]:
- """Handle API response.
-
- Args:
- response: API response
- messages: List of messages to update
- parsed_screen: Current parsed screen information
-
- Returns:
- Tuple of (should_continue, action_screenshot_saved)
- """
- action_screenshot_saved = False
-
- # Helper function to safely add assistant messages using the message manager
- def add_assistant_message(content):
- if isinstance(content, str):
- # Convert string to proper format
- formatted_content = [{"type": "text", "text": content}]
- self.message_manager.add_assistant_message(formatted_content)
- logger.info("Added formatted text assistant message")
- elif isinstance(content, list):
- # Already in proper format
- self.message_manager.add_assistant_message(content)
- logger.info("Added structured assistant message")
- else:
- # Default case - convert to string
- formatted_content = [{"type": "text", "text": str(content)}]
- self.message_manager.add_assistant_message(formatted_content)
- logger.info("Added converted assistant message")
-
- try:
- # Step 1: Normalize response to standard format based on provider
- standard_content = []
- raw_text = None
-
- # Convert response to standardized content based on provider
- if self.provider == LLMProvider.ANTHROPIC:
- if hasattr(response, "content") and isinstance(response.content, list):
- # Convert Anthropic response to standard format
- for block in response.content:
- if hasattr(block, "type"):
- if block.type == "text":
- standard_content.append({"type": "text", "text": block.text})
- # Store raw text for JSON parsing
- if raw_text is None:
- raw_text = block.text
- else:
- raw_text += "\n" + block.text
- else:
- # Add other block types
- block_dict = {}
- for key, value in vars(block).items():
- if not key.startswith("_"):
- block_dict[key] = value
- standard_content.append(block_dict)
- else:
- logger.warning("Invalid Anthropic response format")
- return True, action_screenshot_saved
- elif self.provider == LLMProvider.OLLAMA:
- try:
- raw_text = response["message"]["content"]
- standard_content = [{"type": "text", "text": raw_text}]
- except (KeyError, TypeError, IndexError) as e:
- logger.error(f"Invalid response format: {str(e)}")
- return True, action_screenshot_saved
- elif self.provider == LLMProvider.OAICOMPAT:
- try:
- # OpenAI-compatible response format
- raw_text = response["choices"][0]["message"]["content"]
- standard_content = [{"type": "text", "text": raw_text}]
- except (KeyError, TypeError, IndexError) as e:
- logger.error(f"Invalid response format: {str(e)}")
- return True, action_screenshot_saved
- else:
- # Assume OpenAI or compatible format
- try:
- raw_text = response["choices"][0]["message"]["content"]
- standard_content = [{"type": "text", "text": raw_text}]
- except (KeyError, TypeError, IndexError) as e:
- logger.error(f"Invalid response format: {str(e)}")
- return True, action_screenshot_saved
-
- # Step 2: Add the normalized response to message history
- add_assistant_message(standard_content)
-
- # Step 3: Extract JSON from the content for action execution
- parsed_content = None
-
- # If we have raw text, try to extract JSON from it
- if raw_text:
- # Try different approaches to extract JSON
- try:
- # First try to parse the whole content as JSON
- parsed_content = json.loads(raw_text)
- logger.info("Successfully parsed whole content as JSON")
- except json.JSONDecodeError:
- try:
- # Try to find JSON block
- json_content = extract_data(raw_text, "json")
- parsed_content = json.loads(json_content)
- logger.info("Successfully parsed JSON from code block")
- except (json.JSONDecodeError, IndexError):
- try:
- # Look for JSON object pattern
- import re # Local import to ensure availability
-
- json_pattern = r"\{[^}]+\}"
- json_match = re.search(json_pattern, raw_text)
- if json_match:
- json_str = json_match.group(0)
- parsed_content = json.loads(json_str)
- logger.info("Successfully parsed JSON from text")
- else:
- logger.error(f"No JSON found in content")
- return True, action_screenshot_saved
- except json.JSONDecodeError as e:
- # Try to sanitize the JSON string and retry
- try:
- # Remove or replace invalid control characters
- import re # Local import to ensure availability
-
- sanitized_text = re.sub(r"[\x00-\x1F\x7F]", "", raw_text)
- # Try parsing again with sanitized text
- parsed_content = json.loads(sanitized_text)
- logger.info(
- "Successfully parsed JSON after sanitizing control characters"
- )
- except json.JSONDecodeError:
- logger.error(f"Failed to parse JSON from text: {str(e)}")
- return True, action_screenshot_saved
-
- # Step 4: Process the parsed content if available
- if parsed_content:
- # Clean up Box ID format
- if "Box ID" in parsed_content and isinstance(parsed_content["Box ID"], str):
- parsed_content["Box ID"] = parsed_content["Box ID"].replace("Box #", "")
-
- # Add any explanatory text as reasoning if not present
- if "Explanation" not in parsed_content and raw_text:
- # Extract any text before the JSON as reasoning
- text_before_json = raw_text.split("{")[0].strip()
- if text_before_json:
- parsed_content["Explanation"] = text_before_json
-
- # Log the parsed content for debugging
- logger.info(f"Parsed content: {json.dumps(parsed_content, indent=2)}")
-
- # Step 5: Execute the action
- try:
- # Execute action using the common helper method
- should_continue, action_screenshot_saved = (
- await self._execute_action_with_tools(
- parsed_content, cast(ParseResult, parsed_screen)
- )
- )
-
- # Check if task is complete
- if parsed_content.get("Action") == "None":
- return False, action_screenshot_saved
- return should_continue, action_screenshot_saved
- except Exception as e:
- logger.error(f"Error executing action: {str(e)}")
- # Update the last assistant message with error
- error_message = [{"type": "text", "text": f"Error executing action: {str(e)}"}]
- # Replace the last assistant message with the error
- self.message_manager.add_assistant_message(error_message)
- return False, action_screenshot_saved
-
- return True, action_screenshot_saved
-
- except Exception as e:
- logger.error(f"Error handling response: {str(e)}")
- # Add error message using the message manager
- error_message = [{"type": "text", "text": f"Error: {str(e)}"}]
- self.message_manager.add_assistant_message(error_message)
- raise
-
- ###########################################
- # SCREEN PARSING - IMPLEMENTING ABSTRACT METHOD
- ###########################################
-
- async def _get_parsed_screen_som(self, save_screenshot: bool = True) -> ParseResult:
- """Get parsed screen information with Screen Object Model.
-
- Extends the base class method to use the OmniParser to parse the screen
- and extract UI elements.
-
- Args:
- save_screenshot: Whether to save the screenshot (set to False when screenshots will be saved elsewhere)
-
- Returns:
- ParseResult containing screen information and elements
- """
- try:
- # Use the parser's parse_screen method which handles the screenshot internally
- parsed_screen = await self.parser.parse_screen(computer=self.computer)
-
- # Log information about the parsed results
- logger.info(
- f"Parsed screen with {len(parsed_screen.elements) if parsed_screen.elements else 0} elements"
- )
-
- # Save screenshot if requested and if we have image data
- if save_screenshot and self.save_trajectory and parsed_screen.annotated_image_base64:
- try:
- # Extract just the image data (remove data:image/png;base64, prefix)
- img_data = parsed_screen.annotated_image_base64
- if "," in img_data:
- img_data = img_data.split(",")[1]
-
- # Process screenshot through hooks and save if needed
- await self.handle_screenshot(img_data, action_type="state", parsed_screen=parsed_screen)
-
- # Save with a generic "state" action type to indicate this is the current screen state
- self._save_screenshot(img_data, action_type="state")
- except Exception as e:
- logger.error(f"Error saving screenshot: {str(e)}")
-
- return parsed_screen
-
- except Exception as e:
- logger.error(f"Error getting parsed screen: {str(e)}")
- raise
-
- def _get_system_prompt(self) -> str:
- """Get the system prompt for the model."""
- return SYSTEM_PROMPT
-
- ###########################################
- # MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
- ###########################################
-
- async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
- """Run the agent loop with provided messages.
-
- Args:
- messages: List of messages in standard OpenAI format
-
- Yields:
- Agent response format
- """
- try:
- logger.info(f"Starting OmniLoop run with {len(messages)} messages")
-
- # Initialize the message manager with the provided messages
- self.message_manager.messages = messages.copy()
-
- # Create queue for response streaming
- queue = asyncio.Queue()
-
- # Start loop in background task
- self.loop_task = asyncio.create_task(self._run_loop(queue, messages))
-
- # Process and yield messages as they arrive
- while True:
- try:
- item = await queue.get()
- if item is None: # Stop signal
- break
- yield item
- queue.task_done()
- except Exception as e:
- logger.error(f"Error processing queue item: {str(e)}")
- continue
-
- # Wait for loop to complete
- await self.loop_task
-
- # Send completion message
- yield {
- "role": "assistant",
- "content": "Task completed successfully.",
- "metadata": {"title": "✅ Complete"},
- }
-
- except Exception as e:
- logger.error(f"Error in run method: {str(e)}")
- yield {
- "role": "assistant",
- "content": f"Error: {str(e)}",
- "metadata": {"title": "❌ Error"},
- }
-
- async def _run_loop(self, queue: asyncio.Queue, messages: List[Dict[str, Any]]) -> None:
- """Internal method to run the agent loop with provided messages.
-
- Args:
- queue: Queue to put responses into
- messages: List of messages in standard OpenAI format
- """
- # Continue running until explicitly told to stop
- running = True
- turn_created = False
- # Track if an action-specific screenshot has been saved this turn
- action_screenshot_saved = False
-
- attempt = 0
- max_attempts = 3
-
- while running and attempt < max_attempts:
- try:
- # Create a new turn directory if it's not already created
- if not turn_created:
- self._create_turn_dir()
- turn_created = True
-
- # Ensure client is initialized
- if self.client is None:
- logger.info("Initializing client...")
- await self.initialize_client()
- if self.client is None:
- raise RuntimeError("Failed to initialize client")
- logger.info("Client initialized successfully")
-
- # Get up-to-date screen information
- parsed_screen = await self._get_parsed_screen_som()
-
- # Process screen info and update messages in standard format
- try:
- # Get image from parsed screen
- image = parsed_screen.annotated_image_base64 or None
- if image:
- # Save elements as JSON if we have a turn directory
- if self.current_turn_dir and hasattr(parsed_screen, "elements"):
- elements_path = os.path.join(self.current_turn_dir, "elements.json")
- with open(elements_path, "w") as f:
- # Convert elements to dicts for JSON serialization
- elements_json = [
- elem.model_dump() for elem in parsed_screen.elements
- ]
- json.dump(elements_json, f, indent=2)
- logger.info(f"Saved elements to {elements_path}")
-
- # Remove data URL prefix if present
- if "," in image:
- image = image.split(",")[1]
-
- # Add screenshot to message history using message manager
- self.message_manager.add_user_message(
- [
- {
- "type": "image_url",
- "image_url": {"url": f"data:image/png;base64,{image}"},
- }
- ]
- )
- logger.info("Added screenshot to message history")
- except Exception as e:
- logger.error(f"Error processing screen info: {str(e)}")
- raise
-
- # Get system prompt
- system_prompt = self._get_system_prompt()
-
- # Make API call with retries using the APIHandler
- response = await self.api_handler.make_api_call(
- self.message_manager.messages, system_prompt
- )
-
- # Handle the response (may execute actions)
- # Returns: (should_continue, action_screenshot_saved)
- should_continue, new_screenshot_saved = await self._handle_response(
- response, self.message_manager.messages, parsed_screen
- )
-
- # Update whether an action screenshot was saved this turn
- action_screenshot_saved = action_screenshot_saved or new_screenshot_saved
-
- # Create OpenAI-compatible response format using utility function
- openai_compatible_response = await to_openai_agent_response_format(
- response=response,
- messages=self.message_manager.messages,
- model=self.model,
- parsed_screen=parsed_screen,
- parser=self.parser
- )
- # Log standardized response for ease of parsing
- self._log_api_call("agent_response", request=None, response=openai_compatible_response)
-
- # Put the response in the queue
- await queue.put(openai_compatible_response)
-
- # Check if we should continue this conversation
- running = should_continue
-
- # Create a new turn directory if we're continuing
- if running:
- turn_created = False
-
- # Reset attempt counter on success
- attempt = 0
-
- except Exception as e:
- attempt += 1
- error_msg = f"Error in _run_loop method (attempt {attempt}/{max_attempts}): {str(e)}"
- logger.error(error_msg)
-
- # If this is our last attempt, provide more info about the error
- if attempt >= max_attempts:
- logger.error(f"Maximum retry attempts reached. Last error was: {str(e)}")
-
- await queue.put({
- "role": "assistant",
- "content": f"Error: {str(e)}",
- "metadata": {"title": "❌ Error"},
- })
-
- # Create a brief delay before retrying
- await asyncio.sleep(1)
- finally:
- # Signal that we're done
- await queue.put(None)
-
- async def cancel(self) -> None:
- """Cancel the currently running agent loop task.
-
- This method stops the ongoing processing in the agent loop
- by cancelling the loop_task if it exists and is running.
- """
- if self.loop_task and not self.loop_task.done():
- logger.info("Cancelling Omni loop task")
- self.loop_task.cancel()
- try:
- # Wait for the task to be cancelled with a timeout
- await asyncio.wait_for(self.loop_task, timeout=2.0)
- except asyncio.TimeoutError:
- logger.warning("Timeout while waiting for loop task to cancel")
- except asyncio.CancelledError:
- logger.info("Loop task cancelled successfully")
- except Exception as e:
- logger.error(f"Error while cancelling loop task: {str(e)}")
- finally:
- logger.info("Omni loop task cancelled")
- else:
- logger.info("No active Omni loop task to cancel")
-
- async def process_model_response(self, response_text: str) -> Optional[Dict[str, Any]]:
- """Process model response to extract tool calls.
-
- Args:
- response_text: Model response text
-
- Returns:
- Extracted tool information, or None if no tool call was found
- """
- try:
- # Ensure tools are initialized before use
- await self._ensure_tools_initialized()
-
- # Look for tool use in the response
- if "function_call" in response_text or "tool_use" in response_text:
- # The extract_tool_call method should be implemented in the OmniAPIHandler
- # For now, we'll just use a simple approach
- # This will be replaced with the proper implementation
- tool_info = None
- if "function_call" in response_text:
- # Extract function call params
- try:
- # Simple extraction - in real code this would be more robust
- import json
- import re
-
- match = re.search(r'"function_call"\s*:\s*{([^}]+)}', response_text)
- if match:
- function_text = "{" + match.group(1) + "}"
- tool_info = json.loads(function_text)
- except Exception as e:
- logger.error(f"Error extracting function call: {str(e)}")
-
- if tool_info:
- try:
- # Execute the tool
- result = await self.tool_manager.execute_tool(
- name=tool_info.get("name"), tool_input=tool_info.get("arguments", {})
- )
- # Handle the result
- return {"tool_result": result}
- except Exception as e:
- error_msg = (
- f"Error executing tool '{tool_info.get('name', 'unknown')}': {str(e)}"
- )
- logger.error(error_msg)
- return {"tool_result": ToolResult(error=error_msg)}
- except Exception as e:
- logger.error(f"Error processing tool call: {str(e)}")
-
- return None
-
- async def process_response_with_tools(
- self, response_text: str, parsed_screen: Optional[ParseResult] = None
- ) -> Tuple[bool, str]:
- """Process model response and execute tools.
-
- Args:
- response_text: Model response text
- parsed_screen: Current parsed screen information (optional)
-
- Returns:
- Tuple of (action_taken, observation)
- """
- logger.info("Processing response with tools")
-
- # Process the response to extract tool calls
- tool_result = await self.process_model_response(response_text)
-
- if tool_result and "tool_result" in tool_result:
- # A tool was executed
- result = tool_result["tool_result"]
- if result.error:
- return False, f"ERROR: {result.error}"
- else:
- return True, result.output or "Tool executed successfully"
-
- # No action or tool call found
- return False, "No action taken - no tool call detected in response"
-
- ###########################################
- # UTILITY METHODS
- ###########################################
-
- async def _ensure_tools_initialized(self) -> None:
- """Ensure the tool manager and tools are initialized before use."""
- if not hasattr(self.tool_manager, "tools") or self.tool_manager.tools is None:
- logger.info("Tools not initialized. Initializing now...")
- await self.tool_manager.initialize()
- logger.info("Tools initialized successfully.")
-
- async def _execute_action_with_tools(
- self, action_data: Dict[str, Any], parsed_screen: ParseResult
- ) -> Tuple[bool, bool]:
- """Execute an action using the tools-based approach.
-
- Args:
- action_data: Dictionary containing action details
- parsed_screen: Current parsed screen information
-
- Returns:
- Tuple of (should_continue, action_screenshot_saved)
- """
- action_screenshot_saved = False
- action_type = None # Initialize for possible use in post-action screenshot
-
- try:
- # Extract the action
- parsed_action = action_data.get("Action", "").lower()
-
- # Only process if we have a valid action
- if not parsed_action or parsed_action == "none":
- return False, action_screenshot_saved
-
- # Convert the parsed content to a format suitable for the tools system
- tool_name = "computer" # Default to computer tool
- tool_args = {"action": parsed_action}
-
- # Add specific arguments based on action type
- if parsed_action in ["left_click", "right_click", "double_click", "move_cursor"]:
- # Calculate coordinates from Box ID using parser
- try:
- box_id = int(action_data["Box ID"])
- x, y = await self.parser.calculate_click_coordinates(
- box_id, cast(ParseResult, parsed_screen)
- )
- tool_args["x"] = x
- tool_args["y"] = y
-
- # Visualize action if screenshot is available
- if parsed_screen and parsed_screen.annotated_image_base64:
- img_data = parsed_screen.annotated_image_base64
- # Remove data URL prefix if present
- if img_data.startswith("data:image"):
- img_data = img_data.split(",")[1]
- # Save visualization for coordinate-based actions
- self.viz_helper.visualize_action(x, y, img_data)
- action_screenshot_saved = True
-
- except (ValueError, KeyError) as e:
- logger.error(f"Error processing Box ID: {str(e)}")
- return False, action_screenshot_saved
-
- elif parsed_action == "type_text":
- tool_args["text"] = action_data.get("Value", "")
- # For type_text, store the value in the action type for screenshot naming
- action_type = f"type_{tool_args['text'][:20]}" # Truncate if too long
-
- elif parsed_action == "press_key":
- tool_args["key"] = action_data.get("Value", "")
- action_type = f"press_{tool_args['key']}"
-
- elif parsed_action == "hotkey":
- value = action_data.get("Value", "")
- if isinstance(value, list):
- tool_args["keys"] = value
- action_type = f"hotkey_{'_'.join(value)}"
- else:
- # Split string format like "command+space" into a list
- keys = [k.strip() for k in value.lower().split("+")]
- tool_args["keys"] = keys
- action_type = f"hotkey_{value.replace('+', '_')}"
-
- elif parsed_action in ["scroll_down", "scroll_up"]:
- clicks = int(action_data.get("amount", 1))
- tool_args["amount"] = clicks
- action_type = f"scroll_{parsed_action.split('_')[1]}_{clicks}"
-
- # Visualize scrolling if screenshot is available
- if parsed_screen and parsed_screen.annotated_image_base64:
- img_data = parsed_screen.annotated_image_base64
- # Remove data URL prefix if present
- if img_data.startswith("data:image"):
- img_data = img_data.split(",")[1]
- direction = "down" if parsed_action == "scroll_down" else "up"
- # For scrolling, we save the visualization
- self.viz_helper.visualize_scroll(direction, clicks, img_data)
- action_screenshot_saved = True
-
- # Ensure tools are initialized before use
- await self._ensure_tools_initialized()
-
- # Execute tool with prepared arguments
- result = await self.tool_manager.execute_tool(name=tool_name, tool_input=tool_args)
-
- # Take a new screenshot after the action if we haven't already saved one
- if not action_screenshot_saved:
- try:
- # Get a new screenshot after the action
- new_parsed_screen = await self._get_parsed_screen_som(save_screenshot=False)
- if new_parsed_screen and new_parsed_screen.annotated_image_base64:
- img_data = new_parsed_screen.annotated_image_base64
- # Remove data URL prefix if present
- if img_data.startswith("data:image"):
- img_data = img_data.split(",")[1]
- # Save with action type if defined, otherwise use the action name
- if action_type:
- self._save_screenshot(img_data, action_type=action_type)
- else:
- self._save_screenshot(img_data, action_type=parsed_action)
- action_screenshot_saved = True
- except Exception as screenshot_error:
- logger.error(f"Error taking post-action screenshot: {str(screenshot_error)}")
-
- # Continue the loop if the action is not "None"
- return True, action_screenshot_saved
-
- except Exception as e:
- logger.error(f"Error executing action: {str(e)}")
- # Update the last assistant message with error
- error_message = [{"type": "text", "text": f"Error executing action: {str(e)}"}]
- # Replace the last assistant message with the error
- self.message_manager.add_assistant_message(error_message)
- return False, action_screenshot_saved