cua-agent 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic. Click here for more details.

Files changed (65) hide show
  1. agent/README.md +63 -0
  2. agent/__init__.py +10 -0
  3. agent/core/README.md +101 -0
  4. agent/core/__init__.py +34 -0
  5. agent/core/agent.py +284 -0
  6. agent/core/base_agent.py +164 -0
  7. agent/core/callbacks.py +147 -0
  8. agent/core/computer_agent.py +69 -0
  9. agent/core/experiment.py +222 -0
  10. agent/core/factory.py +102 -0
  11. agent/core/loop.py +244 -0
  12. agent/core/messages.py +230 -0
  13. agent/core/tools/__init__.py +21 -0
  14. agent/core/tools/base.py +74 -0
  15. agent/core/tools/bash.py +52 -0
  16. agent/core/tools/collection.py +46 -0
  17. agent/core/tools/computer.py +113 -0
  18. agent/core/tools/edit.py +67 -0
  19. agent/core/tools/manager.py +56 -0
  20. agent/providers/__init__.py +4 -0
  21. agent/providers/anthropic/__init__.py +6 -0
  22. agent/providers/anthropic/api/client.py +222 -0
  23. agent/providers/anthropic/api/logging.py +150 -0
  24. agent/providers/anthropic/callbacks/manager.py +55 -0
  25. agent/providers/anthropic/loop.py +521 -0
  26. agent/providers/anthropic/messages/manager.py +110 -0
  27. agent/providers/anthropic/prompts.py +20 -0
  28. agent/providers/anthropic/tools/__init__.py +33 -0
  29. agent/providers/anthropic/tools/base.py +88 -0
  30. agent/providers/anthropic/tools/bash.py +163 -0
  31. agent/providers/anthropic/tools/collection.py +34 -0
  32. agent/providers/anthropic/tools/computer.py +550 -0
  33. agent/providers/anthropic/tools/edit.py +326 -0
  34. agent/providers/anthropic/tools/manager.py +54 -0
  35. agent/providers/anthropic/tools/run.py +42 -0
  36. agent/providers/anthropic/types.py +16 -0
  37. agent/providers/omni/__init__.py +27 -0
  38. agent/providers/omni/callbacks.py +78 -0
  39. agent/providers/omni/clients/anthropic.py +99 -0
  40. agent/providers/omni/clients/base.py +44 -0
  41. agent/providers/omni/clients/groq.py +101 -0
  42. agent/providers/omni/clients/openai.py +159 -0
  43. agent/providers/omni/clients/utils.py +25 -0
  44. agent/providers/omni/experiment.py +273 -0
  45. agent/providers/omni/image_utils.py +106 -0
  46. agent/providers/omni/loop.py +961 -0
  47. agent/providers/omni/messages.py +168 -0
  48. agent/providers/omni/parser.py +252 -0
  49. agent/providers/omni/prompts.py +78 -0
  50. agent/providers/omni/tool_manager.py +91 -0
  51. agent/providers/omni/tools/__init__.py +13 -0
  52. agent/providers/omni/tools/bash.py +69 -0
  53. agent/providers/omni/tools/computer.py +216 -0
  54. agent/providers/omni/tools/manager.py +83 -0
  55. agent/providers/omni/types.py +30 -0
  56. agent/providers/omni/utils.py +155 -0
  57. agent/providers/omni/visualization.py +130 -0
  58. agent/types/__init__.py +26 -0
  59. agent/types/base.py +52 -0
  60. agent/types/messages.py +36 -0
  61. agent/types/tools.py +32 -0
  62. cua_agent-0.1.0.dist-info/METADATA +44 -0
  63. cua_agent-0.1.0.dist-info/RECORD +65 -0
  64. cua_agent-0.1.0.dist-info/WHEEL +4 -0
  65. cua_agent-0.1.0.dist-info/entry_points.txt +4 -0
@@ -0,0 +1,961 @@
1
+ """Omni-specific agent loop implementation."""
2
+
3
+ import logging
4
+ from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator, Union
5
+ import base64
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ import json
9
+ import re
10
+ import os
11
+ from datetime import datetime
12
+ import asyncio
13
+ from httpx import ConnectError, ReadTimeout
14
+ import shutil
15
+ import copy
16
+
17
+ from .parser import OmniParser, ParseResult, ParserMetadata, UIElement
18
+ from ...core.loop import BaseLoop
19
+ from computer import Computer
20
+ from .types import APIProvider
21
+ from .clients.base import BaseOmniClient
22
+ from .clients.openai import OpenAIClient
23
+ from .clients.groq import GroqClient
24
+ from .clients.anthropic import AnthropicClient
25
+ from .prompts import SYSTEM_PROMPT
26
+ from .utils import compress_image_base64
27
+ from .visualization import visualize_click, visualize_scroll, calculate_element_center
28
+ from .image_utils import decode_base64_image, clean_base64_data
29
+ from ...core.messages import ImageRetentionConfig
30
+ from .messages import OmniMessageManager
31
+
32
# Configures the root logger at INFO level as an import-time side effect.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
34
+
35
+
36
def extract_data(input_string: str, data_type: str) -> str:
    """Extract the body of the first fenced code block with a given tag.

    Looks for a Markdown fence opened with three backticks followed by
    ``data_type`` and returns the text up to the closing fence, or up to the
    end of the string when the fence is never closed.

    Args:
        input_string: Text that may contain a fenced code block.
        data_type: Fence language tag to look for (for example ``"json"``).
            The tag is matched literally.

    Returns:
        The block body with surrounding whitespace stripped, or
        ``input_string`` unchanged when no fence with that tag is present.
    """
    # Escape the tag so names such as "c++" are not interpreted as regex syntax.
    pattern = f"```{re.escape(data_type)}" + r"(.*?)(```|$)"
    match = re.search(pattern, input_string, re.DOTALL)
    return match.group(1).strip() if match else input_string
41
+
42
+
43
+ class OmniLoop(BaseLoop):
44
+ """Omni-specific implementation of the agent loop."""
45
+
46
def __init__(
    self,
    parser: OmniParser,
    provider: APIProvider,
    api_key: str,
    model: str,
    computer: Computer,
    only_n_most_recent_images: Optional[int] = 2,
    base_dir: Optional[str] = "trajectories",
    max_retries: int = 3,
    retry_delay: float = 1.0,
    save_trajectory: bool = True,
    **kwargs,
):
    """Set up the Omni loop.

    Args:
        parser: Screen parser instance.
        provider: API provider to talk to.
        api_key: Credential forwarded to the base loop.
        model: Model name forwarded to the base loop.
        computer: Computer instance the loop drives.
        only_n_most_recent_images: How many recent screenshots are kept in
            API requests; also forwarded to the base loop.
        base_dir: Base directory for saved experiment data.
        max_retries: Maximum number of attempts per API call.
        retry_delay: Delay between attempts, in seconds.
        save_trajectory: Whether trajectory data is saved.
        **kwargs: Extra keyword arguments passed through to the base loop.
    """
    # Both attributes are assigned before the base-class initializer runs.
    self.provider = provider
    self.parser = parser

    # The message manager enforces the screenshot-retention limit.
    self.message_manager = OmniMessageManager(
        config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
    )

    super().__init__(
        computer=computer,
        model=model,
        api_key=api_key,
        max_retries=max_retries,
        retry_delay=retry_delay,
        base_dir=base_dir,
        save_trajectory=save_trajectory,
        only_n_most_recent_images=only_n_most_recent_images,
        **kwargs,
    )

    # The API client is created lazily by initialize_client().
    self.retry_count = 0
    self.client = None
98
+
99
+ def _should_save_debug_image(self) -> bool:
100
+ """Check if debug images should be saved.
101
+
102
+ Returns:
103
+ bool: Always returns False as debug image saving has been disabled.
104
+ """
105
+ # Debug image saving functionality has been removed
106
+ return False
107
+
108
+ def _extract_and_save_images(self, data: Any, prefix: str) -> None:
109
+ """Extract and save images from API data.
110
+
111
+ This method is now a no-op as image extraction functionality has been removed.
112
+
113
+ Args:
114
+ data: Data to extract images from
115
+ prefix: Prefix for the extracted image filenames
116
+ """
117
+ # Image extraction functionality has been removed
118
+ return
119
+
120
+ def _save_debug_image(self, image_data: str, filename: str) -> None:
121
+ """Save a debug image to the current turn directory.
122
+
123
+ This method is now a no-op as debug image saving functionality has been removed.
124
+
125
+ Args:
126
+ image_data: Base64 encoded image data
127
+ filename: Name to use for the saved image
128
+ """
129
+ # Debug image saving functionality has been removed
130
+ return
131
+
132
+ def _visualize_action(self, x: int, y: int, img_base64: str) -> None:
133
+ """Visualize an action by drawing on the screenshot."""
134
+ if (
135
+ not self.save_trajectory
136
+ or not hasattr(self, "experiment_manager")
137
+ or not self.experiment_manager
138
+ ):
139
+ return
140
+
141
+ try:
142
+ # Use the visualization utility
143
+ img = visualize_click(x, y, img_base64)
144
+
145
+ # Save the visualization
146
+ self.experiment_manager.save_action_visualization(img, "click", f"x{x}_y{y}")
147
+ except Exception as e:
148
+ logger.error(f"Error visualizing action: {str(e)}")
149
+
150
+ def _visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None:
151
+ """Visualize a scroll action by drawing arrows on the screenshot."""
152
+ if (
153
+ not self.save_trajectory
154
+ or not hasattr(self, "experiment_manager")
155
+ or not self.experiment_manager
156
+ ):
157
+ return
158
+
159
+ try:
160
+ # Use the visualization utility
161
+ img = visualize_scroll(direction, clicks, img_base64)
162
+
163
+ # Save the visualization
164
+ self.experiment_manager.save_action_visualization(
165
+ img, "scroll", f"{direction}_{clicks}"
166
+ )
167
+ except Exception as e:
168
+ logger.error(f"Error visualizing scroll: {str(e)}")
169
+
170
+ def _save_action_visualization(
171
+ self, img: Image.Image, action_name: str, details: str = ""
172
+ ) -> str:
173
+ """Save a visualization of an action."""
174
+ if hasattr(self, "experiment_manager") and self.experiment_manager:
175
+ return self.experiment_manager.save_action_visualization(img, action_name, details)
176
+ return ""
177
+
178
async def initialize_client(self) -> None:
    """Create the API client that matches ``self.provider``.

    On success the client is stored on ``self.client``. On any failure
    ``self.client`` is reset to ``None``.

    Raises:
        RuntimeError: If the provider is unsupported or client construction
            fails; the underlying exception is chained as the cause.
    """
    try:
        logger.info(f"Initializing {self.provider} client with model {self.model}...")

        if self.provider == APIProvider.OPENAI:
            self.client = OpenAIClient(api_key=self.api_key, model=self.model)
        elif self.provider == APIProvider.GROQ:
            self.client = GroqClient(api_key=self.api_key, model=self.model)
        elif self.provider == APIProvider.ANTHROPIC:
            self.client = AnthropicClient(
                api_key=self.api_key,
                model=self.model,
                max_retries=self.max_retries,
                retry_delay=self.retry_delay,
            )
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

        logger.info(f"Initialized {self.provider} client with model {self.model}")
    except Exception as e:
        logger.error(f"Error initializing client: {str(e)}")
        self.client = None
        # Chain the original exception so the root cause stays in the traceback.
        raise RuntimeError(f"Failed to initialize client: {str(e)}") from e
202
+
203
async def _make_api_call(self, messages: List[Dict[str, Any]], system_prompt: str) -> Any:
    """Send the conversation to the provider, retrying on failure.

    A new turn directory is created first. Each attempt initializes the
    client if needed, prepares the messages through the message manager,
    logs the request, and calls the client's ``run_interleaved`` (awaited
    when it is a coroutine function).

    Args:
        messages: Conversation history.
        system_prompt: System prompt used for non-Anthropic providers.
            Anthropic requests use ``self._get_system_prompt()`` instead.

    Returns:
        The response returned by the client's ``run_interleaved``.

    Raises:
        RuntimeError: If all ``self.max_retries`` attempts fail. The last
            underlying error is chained as the cause.
    """
    # Create new turn directory for this API call
    self._create_turn_dir()

    request_data = None
    last_error = None

    for attempt in range(self.max_retries):
        try:
            # Ensure client is initialized
            if self.client is None:
                logger.info(
                    f"Client not initialized in _make_api_call (attempt {attempt+1}), initializing now..."
                )
                await self.initialize_client()
                if self.client is None:
                    raise RuntimeError("Failed to initialize client")

            # Apply image retention and prepare messages
            prepared_messages = self.message_manager.prepare_messages(messages.copy())

            if self.provider == APIProvider.ANTHROPIC:
                # System messages are filtered out of the message list for Anthropic.
                logged_messages = [
                    msg for msg in prepared_messages if msg["role"] != "system"
                ]
                call_messages = logged_messages
                system = self._get_system_prompt()
            else:
                logged_messages = prepared_messages
                # NOTE(review): non-Anthropic providers are sent the caller's raw
                # ``messages`` rather than the prepared ones, so the request that
                # is logged differs from the one sent and image retention is not
                # applied. Preserved from the original; confirm intent.
                call_messages = messages
                system = system_prompt

            request_data = {
                "messages": logged_messages,
                "max_tokens": self.max_tokens,
                "system": system,
            }
            self._log_api_call("request", request_data)

            # Clients may expose either a coroutine function or a plain function.
            run_method = self.client.run_interleaved
            call_kwargs = {
                "messages": call_messages,
                "system": system,
                "max_tokens": self.max_tokens,
            }
            if asyncio.iscoroutinefunction(run_method):
                response = await run_method(**call_kwargs)
            else:
                response = run_method(**call_kwargs)

            # Log success response
            self._log_api_call("response", request_data, response)

            return response

        except (ConnectError, ReadTimeout) as e:
            last_error = e
            logger.warning(
                f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
            )
            if attempt < self.max_retries - 1:
                # Linear backoff: the delay grows with the attempt number.
                await asyncio.sleep(self.retry_delay * (attempt + 1))
                # Reset client on connection errors to force re-initialization
                self.client = None

        except RuntimeError as e:
            # Handle client initialization errors specifically
            last_error = e
            self._log_api_call("error", request_data, error=e)
            logger.error(
                f"Client initialization error (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
            )
            if attempt < self.max_retries - 1:
                # Reset client to force re-initialization
                self.client = None
                await asyncio.sleep(self.retry_delay)

        except Exception as e:
            # Log unexpected error
            last_error = e
            self._log_api_call("error", request_data, error=e)
            logger.error(f"Unexpected error in API call: {str(e)}")
            if attempt < self.max_retries - 1:
                await asyncio.sleep(self.retry_delay)

    # If we get here, all retries failed
    error_message = f"API call failed after {self.max_retries} attempts"
    if last_error:
        error_message += f": {str(last_error)}"

    logger.error(error_message)
    raise RuntimeError(error_message) from last_error
326
+
327
async def _handle_response(
    self, response: Any, messages: List[Dict[str, Any]], parsed_screen: ParseResult
) -> Tuple[bool, bool]:
    """Handle an API response: parse the action, record it, and execute it.

    For Anthropic responses with a list of content blocks, the first text
    block that yields a JSON object is used. For other responses, the content
    is read from ``response["choices"][0]["message"]["content"]`` when
    present, otherwise the response itself is treated as the content.

    Args:
        response: Provider response object, dict, or string.
        messages: Conversation history; the assistant reply (and any error
            message) is appended in place.
        parsed_screen: Current parsed screen information, forwarded to
            ``_execute_action``.

    Returns:
        Tuple of (should_continue, action_screenshot_saved).

    Raises:
        Exception: Any unexpected error is logged, appended to ``messages``
            as an error entry, and re-raised.
    """
    action_screenshot_saved = False
    try:
        # Handle Anthropic response format
        if self.provider == APIProvider.ANTHROPIC:
            if hasattr(response, "content") and isinstance(response.content, list):
                for block in response.content:
                    if not (hasattr(block, "type") and block.type == "text"):
                        continue

                    parsed_content = self._parse_action_content(block.text)
                    if parsed_content is None:
                        # Nothing usable in this block; try the next one.
                        continue

                    # Log the parsed content for debugging
                    logger.info(f"Parsed content: {json.dumps(parsed_content, indent=2)}")
                    return await self._record_and_execute(
                        parsed_content, messages, parsed_screen
                    )

                logger.warning("No text block found in Anthropic response")
                return True, action_screenshot_saved

        # Handle other providers' response formats
        if isinstance(response, dict) and "choices" in response:
            content = response["choices"][0]["message"]["content"]
        else:
            content = response

        if isinstance(content, str):
            parsed_content = self._parse_action_content(content)
            if parsed_content is None:
                return True, action_screenshot_saved
            return await self._record_and_execute(parsed_content, messages, parsed_screen)

        if isinstance(content, dict):
            # Content is already a dictionary: use it as-is.
            return await self._record_and_execute(content, messages, parsed_screen)

        return True, action_screenshot_saved

    except Exception as e:
        logger.error(f"Error handling response: {str(e)}")
        messages.append(
            {
                "role": "assistant",
                "content": f"Error: {str(e)}",
                "metadata": {"title": "❌ Error"},
            }
        )
        raise

def _parse_action_content(self, content: str) -> Optional[Dict[str, Any]]:
    """Turn a model reply into an action dictionary.

    Tries, in order: the whole reply as JSON, a fenced ``json`` code block,
    and finally the first JSON object embedded in the text. The result is
    then normalized: a string "Box ID" loses its "Box #" prefix, and text
    preceding the JSON is stored as "Explanation" when that key is absent.

    Returns:
        The parsed dictionary, or ``None`` when no JSON object is found.
    """
    parsed = None
    for candidate in (content, extract_data(content, "json")):
        try:
            value = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Only JSON objects are usable as actions; lists/scalars are skipped.
        if isinstance(value, dict):
            parsed = value
            break

    if parsed is None:
        parsed = self._scan_for_json_object(content)
    if parsed is None:
        logger.error(f"No JSON found in content: {content}")
        return None

    # Clean up Box ID format
    box_id = parsed.get("Box ID")
    if isinstance(box_id, str):
        parsed["Box ID"] = box_id.replace("Box #", "")

    # Add any explanatory text as reasoning if not present
    if "Explanation" not in parsed:
        text_before_json = content.split("{")[0].strip()
        # Drop a trailing opening fence so the marker itself is not kept.
        text_before_json = re.sub(r"```\w*$", "", text_before_json).strip()
        if text_before_json:
            parsed["Explanation"] = text_before_json

    return parsed

def _scan_for_json_object(self, content: str) -> Optional[Dict[str, Any]]:
    """Return the first JSON object embedded in free text, or ``None``."""
    decoder = json.JSONDecoder()
    start = content.find("{")
    while start != -1:
        try:
            # raw_decode handles nested objects and braces inside strings.
            value, _ = decoder.raw_decode(content, start)
        except json.JSONDecodeError:
            value = None
        if isinstance(value, dict):
            return value
        start = content.find("{", start + 1)
    return None

async def _record_and_execute(
    self,
    action: Dict[str, Any],
    messages: List[Dict[str, Any]],
    parsed_screen: ParseResult,
) -> Tuple[bool, bool]:
    """Append the assistant action to the history and execute it.

    Returns:
        ``(False, False)`` when execution raises (an error entry is appended
        to ``messages``), ``(False, True)`` when the action is the string
        "None", and ``(True, True)`` otherwise.
    """
    # Add response to messages with stringified content
    messages.append({"role": "assistant", "content": json.dumps(action)})

    try:
        # Execute action with current parsed screen info
        await self._execute_action(action, parsed_screen)
    except Exception as e:
        logger.error(f"Error executing action: {str(e)}")
        # Add error message to conversation
        messages.append(
            {
                "role": "assistant",
                "content": f"Error executing action: {str(e)}",
                "metadata": {"title": "❌ Error"},
            }
        )
        return False, False

    # An "Action" of "None" marks the task as complete, so the loop stops.
    return action.get("Action") != "None", True
519
+
520
async def _get_parsed_screen_som(self, save_screenshot: bool = True) -> ParseResult:
    """Parse the current screen and optionally save the annotated screenshot.

    Args:
        save_screenshot: When true (and trajectory saving is enabled and an
            annotated image is available), the image is saved with the
            action type "state". Pass False when the screenshot is saved
            elsewhere.

    Returns:
        ParseResult produced by the parser.

    Raises:
        Exception: Parser failures are logged and re-raised. Failures while
            saving the screenshot are only logged.
    """
    try:
        # The parser takes the screenshot itself.
        parsed_screen = await self.parser.parse_screen(computer=self.computer)

        logger.info(
            f"Parsed screen with {len(parsed_screen.elements) if parsed_screen.elements else 0} elements"
        )

        should_save = (
            save_screenshot and self.save_trajectory and parsed_screen.annotated_image_base64
        )
        if should_save:
            try:
                encoded = parsed_screen.annotated_image_base64
                # Keep only the part after the comma of a data-URL prefix.
                payload = encoded.split(",")[1] if "," in encoded else encoded
                # "state" marks this image as the current screen state.
                self._save_screenshot(payload, action_type="state")
            except Exception as e:
                logger.error(f"Error saving screenshot: {str(e)}")

        return parsed_screen

    except Exception as e:
        logger.error(f"Error getting parsed screen: {str(e)}")
        raise
555
+
556
async def _process_screen(
    self, parsed_screen: ParseResult, messages: List[Dict[str, Any]]
) -> None:
    """Append the annotated screenshot to the conversation as a user message.

    Only the OpenAI and Anthropic providers receive an image message, and
    only when an annotated image is available. When a turn directory is set,
    the parsed elements are also written to ``elements.json`` inside it.

    Args:
        parsed_screen: Parsed screen with the annotated image and elements.
        messages: Conversation history; the image message is appended in place.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        if self.provider not in [APIProvider.OPENAI, APIProvider.ANTHROPIC]:
            return

        image = parsed_screen.annotated_image_base64 or None
        if not image:
            return

        # Save screen info to current turn directory
        if self.current_turn_dir:
            elements_path = os.path.join(self.current_turn_dir, "elements.json")
            # Serialize first so a failure does not leave an empty file behind.
            elements_json = [elem.model_dump() for elem in parsed_screen.elements]
            with open(elements_path, "w", encoding="utf-8") as f:
                json.dump(elements_json, f, indent=2)
            logger.info(f"Saved elements to {elements_path}")

        # Split a "data:<type>;base64," prefix from the payload once, so the
        # prefix is neither sent as image data nor duplicated in a data URL.
        media_type = "image/png"
        if image.startswith("data:"):
            header, separator, payload = image.partition(",")
            if separator:
                declared = header[len("data:"):].split(";", 1)[0].strip().lower()
                if declared.startswith("image/"):
                    media_type = declared
                image = payload

        if self.provider == APIProvider.ANTHROPIC:
            image_size = len(image)
            logger.info(f"Image base64 is present, length: {image_size}")

            # Anthropic has a 5MB limit, checked here against the base64
            # string length; 4.9MB leaves room for request overhead.
            max_size = int(4.9 * 1024 * 1024)

            if image_size > max_size:
                logger.info(
                    f"Image size ({image_size} bytes) exceeds Anthropic limit ({max_size} bytes), compressing..."
                )
                image, media_type = compress_image_base64(image, max_size)
                logger.info(
                    f"Image compressed to {len(image)} bytes with media_type {media_type}"
                )

            # Anthropic uses "type": "image"
            screen_info_msg = {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": image,
                        },
                    }
                ],
            }
        else:
            # OpenAI and others use "type": "image_url"
            screen_info_msg = {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{media_type};base64,{image}"},
                    }
                ],
            }
        messages.append(screen_info_msg)

    except Exception as e:
        logger.error(f"Error processing screen info: {str(e)}")
        raise
636
+
637
def _get_system_prompt(self) -> str:
    """Return the system prompt sent to the model.

    Returns:
        The module-level ``SYSTEM_PROMPT`` constant.
    """
    prompt = SYSTEM_PROMPT
    return prompt
640
+
641
async def _execute_action(self, content: Dict[str, Any], parsed_screen: ParseResult) -> None:
    """Execute the action described by ``content`` on ``self.computer``.

    The lower-cased "Action" value selects a method of the same name on
    ``self.computer``. Click-like and drag actions resolve "Box ID" to
    coordinates; typing, key presses and hotkeys read "Value"; scrolling
    reads "amount". Every failure is logged and swallowed: this method
    never raises.

    Args:
        content: Dictionary containing the action details.
        parsed_screen: Current parsed screen information.
    """
    try:
        action = content.get("Action", "").lower()
        if not action:
            return

        # Set once a visualization of this action has been requested, which
        # suppresses the post-action screenshot.
        action_screenshot_saved = False
        # Label for the post-action screenshot; defaults to the action name.
        action_type = action

        try:
            # Prepare kwargs based on action type
            kwargs = {}

            if action in ["left_click", "right_click", "double_click", "move_cursor"]:
                try:
                    box_id = int(content["Box ID"])
                    logger.info(f"Processing Box ID: {box_id}")

                    x, y = await self._calculate_click_coordinates(box_id, parsed_screen)
                    logger.info(f"Calculated coordinates: x={x}, y={y}")

                    kwargs["x"] = x
                    kwargs["y"] = y

                    # Visualize action if screenshot is available
                    if parsed_screen.annotated_image_base64:
                        self._visualize_action(
                            x, y, self._strip_data_url_prefix(parsed_screen.annotated_image_base64)
                        )
                        action_screenshot_saved = True

                except ValueError as e:
                    logger.error(f"Error processing Box ID: {str(e)}")
                    return

            elif action == "drag_to":
                try:
                    box_id = int(content["Box ID"])
                    x, y = await self._calculate_click_coordinates(box_id, parsed_screen)
                    kwargs.update(
                        {
                            "x": x,
                            "y": y,
                            "button": content.get("button", "left"),
                            "duration": float(content.get("duration", 0.5)),
                        }
                    )

                    # Visualize drag destination if screenshot is available
                    if parsed_screen.annotated_image_base64:
                        self._visualize_action(
                            x, y, self._strip_data_url_prefix(parsed_screen.annotated_image_base64)
                        )
                        action_screenshot_saved = True

                except ValueError as e:
                    logger.error(f"Error processing drag coordinates: {str(e)}")
                    return

            elif action == "type_text":
                kwargs["text"] = content["Value"]
                # NOTE(review): the typed text (first 20 characters) becomes part
                # of the screenshot label; confirm _save_screenshot tolerates
                # arbitrary characters.
                action_type = f"type_{content['Value'][:20]}"

            elif action == "press_key":
                kwargs["key"] = content["Value"]
                action_type = f"press_{content['Value']}"

            elif action == "hotkey":
                if isinstance(content.get("Value"), list):
                    keys = content["Value"]
                    action_type = f"hotkey_{'_'.join(keys)}"
                else:
                    # Split string format like "command+space" into a list
                    keys = [k.strip() for k in content["Value"].lower().split("+")]
                    action_type = f"hotkey_{content['Value'].replace('+', '_')}"
                logger.info(f"Preparing hotkey with keys: {keys}")
                # Hotkeys are passed as positional arguments, not kwargs.
                method = getattr(self.computer, action)
                await method(*keys)
                logger.info(f"Tool execution completed successfully: {action}")

                # For hotkeys, take a screenshot after the action
                await self._save_post_action_screenshot(action_type, "hotkey")
                return

            elif action in ["scroll_down", "scroll_up"]:
                clicks = int(content.get("amount", 1))
                kwargs["clicks"] = clicks
                action_type = f"scroll_{action.split('_')[1]}_{clicks}"

                # Visualize scrolling if screenshot is available
                if parsed_screen.annotated_image_base64:
                    direction = "down" if action == "scroll_down" else "up"
                    self._visualize_scroll(
                        direction,
                        clicks,
                        self._strip_data_url_prefix(parsed_screen.annotated_image_base64),
                    )
                    action_screenshot_saved = True

            else:
                logger.warning(f"Unknown action: {action}")
                return

            # Look the method up separately so an AttributeError raised inside
            # the tool is not misreported as a missing method.
            try:
                method = getattr(self.computer, action)
            except AttributeError as e:
                logger.error(f"Method not found for action '{action}': {str(e)}")
                return

            try:
                logger.info(f"Found method for action '{action}': {method}")
                await method(**kwargs)
                logger.info(f"Tool execution completed successfully: {action}")

                # Actions without a visualization get a fresh screenshot instead.
                if not action_screenshot_saved:
                    await self._save_post_action_screenshot(action_type, "action")
            except Exception as tool_error:
                logger.error(f"Tool execution failed: {str(tool_error)}")
                return

        except Exception as e:
            logger.error(f"Error executing action {action}: {str(e)}")
            return

    except Exception as e:
        logger.error(f"Error in _execute_action: {str(e)}")
        return

def _strip_data_url_prefix(self, img_data: str) -> str:
    """Return ``img_data`` without a leading ``data:image...,`` prefix."""
    if img_data.startswith("data:image"):
        return img_data.split(",")[1]
    return img_data

async def _save_post_action_screenshot(self, action_type: str, label: str) -> bool:
    """Parse the screen again and save the annotated image under ``action_type``.

    Args:
        action_type: Label passed to ``_save_screenshot``.
        label: Word used in the error log message ("hotkey" or "action").

    Returns:
        True when a screenshot was saved; False otherwise. Errors are logged,
        never raised.
    """
    try:
        new_parsed_screen = await self._get_parsed_screen_som(save_screenshot=False)
        if new_parsed_screen and new_parsed_screen.annotated_image_base64:
            img_data = self._strip_data_url_prefix(new_parsed_screen.annotated_image_base64)
            self._save_screenshot(img_data, action_type=action_type)
            return True
    except Exception as screenshot_error:
        logger.error(f"Error taking post-{label} screenshot: {str(screenshot_error)}")
    return False
820
+
821
async def _calculate_click_coordinates(
    self, box_id: int, parsed_screen: ParseResult
) -> Tuple[int, int]:
    """Calculate click coordinates based on box ID.

    The first element of ``parsed_screen.elements`` whose ``id`` equals
    ``box_id`` is located, and the centre of its bounding box is scaled by
    the screen width and height. An unknown ``box_id`` does not raise: the
    centre of the screen is returned instead, and the miss is logged as an
    error.

    Args:
        box_id: The ID of the box to click
        parsed_screen: The parsed screen information

    Returns:
        Tuple of (x, y) coordinates. The centre of the screen is returned
        when no element matches ``box_id`` or when the computed centre is
        exactly (0, 0).
    """
    logger.info(f"Elements count: {len(parsed_screen.elements)}")

    # Screen dimensions come from the metadata when it is present; otherwise
    # 1920x1080 is assumed. Both the match path and the fallback path need
    # them, so they are looked up once here.
    metadata = parsed_screen.metadata
    width = metadata.width if metadata else 1920
    height = metadata.height if metadata else 1080

    # Only the first element with a matching ID is considered.
    element = next((el for el in parsed_screen.elements if el.id == box_id), None)

    if element is None:
        # No such box: report it and click the centre of the screen.
        logger.error(
            f"Box ID {box_id} not found in structured elements (count={len(parsed_screen.elements)})"
        )
        logger.warning(f"Using fallback position in center of screen ({width//2}, {height//2})")
        return width // 2, height // 2

    logger.info(f"Found element with ID {box_id}: {element}")
    bbox = element.bbox
    logger.info(f"Screen dimensions: width={width}, height={height}")

    # The bbox midpoint is multiplied by the screen size, so bbox values are
    # presumably normalised fractions of the screen -- TODO confirm against
    # the parser that produces ParseResult.
    center_x = int((bbox.x1 + bbox.x2) / 2 * width)
    center_y = int((bbox.y1 + bbox.y2) / 2 * height)
    logger.info(f"Calculated center: ({center_x}, {center_y})")

    # Only an exact (0, 0) result is treated as invalid; any other small
    # coordinate is passed through unchanged.
    if center_x == 0 and center_y == 0:
        logger.warning("Got (0,0) coordinates, using fallback position")
        center_x = width // 2
        center_y = height // 2
        logger.info(f"Using fallback center: ({center_x}, {center_y})")

    return center_x, center_y
875
+
876
+ async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
877
+ """Run the agent loop with provided messages.
878
+
879
+ Args:
880
+ messages: List of message objects
881
+
882
+ Yields:
883
+ Dict containing response data
884
+ """
885
+ # Keep track of conversation history
886
+ conversation_history = messages.copy()
887
+
888
+ # Continue running until explicitly told to stop
889
+ running = True
890
+ turn_created = False
891
+ # Track if an action-specific screenshot has been saved this turn
892
+ action_screenshot_saved = False
893
+
894
+ attempt = 0
895
+ max_attempts = 3
896
+
897
+ while running and attempt < max_attempts:
898
+ try:
899
+ # Create a new turn directory if it's not already created
900
+ if not turn_created:
901
+ self._create_turn_dir()
902
+ turn_created = True
903
+
904
+ # Ensure client is initialized
905
+ if self.client is None:
906
+ logger.info("Initializing client...")
907
+ await self.initialize_client()
908
+ if self.client is None:
909
+ raise RuntimeError("Failed to initialize client")
910
+ logger.info("Client initialized successfully")
911
+
912
+ # Get up-to-date screen information
913
+ parsed_screen = await self._get_parsed_screen_som()
914
+
915
+ # Process screen info and update messages
916
+ await self._process_screen(parsed_screen, conversation_history)
917
+
918
+ # Get system prompt
919
+ system_prompt = self._get_system_prompt()
920
+
921
+ # Make API call with retries
922
+ response = await self._make_api_call(conversation_history, system_prompt)
923
+
924
+ # Handle the response (may execute actions)
925
+ # Returns: (should_continue, action_screenshot_saved)
926
+ should_continue, new_screenshot_saved = await self._handle_response(
927
+ response, conversation_history, parsed_screen
928
+ )
929
+
930
+ # Update whether an action screenshot was saved this turn
931
+ action_screenshot_saved = action_screenshot_saved or new_screenshot_saved
932
+
933
+ # Yield the response to the caller
934
+ yield {"response": response}
935
+
936
+ # Check if we should continue this conversation
937
+ running = should_continue
938
+
939
+ # Create a new turn directory if we're continuing
940
+ if running:
941
+ turn_created = False
942
+
943
+ # Reset attempt counter on success
944
+ attempt = 0
945
+
946
+ except Exception as e:
947
+ attempt += 1
948
+ error_msg = f"Error in run method (attempt {attempt}/{max_attempts}): {str(e)}"
949
+ logger.error(error_msg)
950
+
951
+ # If this is our last attempt, provide more info about the error
952
+ if attempt >= max_attempts:
953
+ logger.error(f"Maximum retry attempts reached. Last error was: {str(e)}")
954
+
955
+ yield {
956
+ "error": str(e),
957
+ "metadata": {"title": "❌ Error"},
958
+ }
959
+
960
+ # Create a brief delay before retrying
961
+ await asyncio.sleep(1)