cua-agent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/README.md +63 -0
- agent/__init__.py +10 -0
- agent/core/README.md +101 -0
- agent/core/__init__.py +34 -0
- agent/core/agent.py +284 -0
- agent/core/base_agent.py +164 -0
- agent/core/callbacks.py +147 -0
- agent/core/computer_agent.py +69 -0
- agent/core/experiment.py +222 -0
- agent/core/factory.py +102 -0
- agent/core/loop.py +244 -0
- agent/core/messages.py +230 -0
- agent/core/tools/__init__.py +21 -0
- agent/core/tools/base.py +74 -0
- agent/core/tools/bash.py +52 -0
- agent/core/tools/collection.py +46 -0
- agent/core/tools/computer.py +113 -0
- agent/core/tools/edit.py +67 -0
- agent/core/tools/manager.py +56 -0
- agent/providers/__init__.py +4 -0
- agent/providers/anthropic/__init__.py +6 -0
- agent/providers/anthropic/api/client.py +222 -0
- agent/providers/anthropic/api/logging.py +150 -0
- agent/providers/anthropic/callbacks/manager.py +55 -0
- agent/providers/anthropic/loop.py +521 -0
- agent/providers/anthropic/messages/manager.py +110 -0
- agent/providers/anthropic/prompts.py +20 -0
- agent/providers/anthropic/tools/__init__.py +33 -0
- agent/providers/anthropic/tools/base.py +88 -0
- agent/providers/anthropic/tools/bash.py +163 -0
- agent/providers/anthropic/tools/collection.py +34 -0
- agent/providers/anthropic/tools/computer.py +550 -0
- agent/providers/anthropic/tools/edit.py +326 -0
- agent/providers/anthropic/tools/manager.py +54 -0
- agent/providers/anthropic/tools/run.py +42 -0
- agent/providers/anthropic/types.py +16 -0
- agent/providers/omni/__init__.py +27 -0
- agent/providers/omni/callbacks.py +78 -0
- agent/providers/omni/clients/anthropic.py +99 -0
- agent/providers/omni/clients/base.py +44 -0
- agent/providers/omni/clients/groq.py +101 -0
- agent/providers/omni/clients/openai.py +159 -0
- agent/providers/omni/clients/utils.py +25 -0
- agent/providers/omni/experiment.py +273 -0
- agent/providers/omni/image_utils.py +106 -0
- agent/providers/omni/loop.py +961 -0
- agent/providers/omni/messages.py +168 -0
- agent/providers/omni/parser.py +252 -0
- agent/providers/omni/prompts.py +78 -0
- agent/providers/omni/tool_manager.py +91 -0
- agent/providers/omni/tools/__init__.py +13 -0
- agent/providers/omni/tools/bash.py +69 -0
- agent/providers/omni/tools/computer.py +216 -0
- agent/providers/omni/tools/manager.py +83 -0
- agent/providers/omni/types.py +30 -0
- agent/providers/omni/utils.py +155 -0
- agent/providers/omni/visualization.py +130 -0
- agent/types/__init__.py +26 -0
- agent/types/base.py +52 -0
- agent/types/messages.py +36 -0
- agent/types/tools.py +32 -0
- cua_agent-0.1.0.dist-info/METADATA +44 -0
- cua_agent-0.1.0.dist-info/RECORD +65 -0
- cua_agent-0.1.0.dist-info/WHEEL +4 -0
- cua_agent-0.1.0.dist-info/entry_points.txt +4 -0
|
@@ -0,0 +1,961 @@
|
|
|
1
|
+
"""Omni-specific agent loop implementation."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator, Union
|
|
5
|
+
import base64
|
|
6
|
+
from PIL import Image
|
|
7
|
+
from io import BytesIO
|
|
8
|
+
import json
|
|
9
|
+
import re
|
|
10
|
+
import os
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
import asyncio
|
|
13
|
+
from httpx import ConnectError, ReadTimeout
|
|
14
|
+
import shutil
|
|
15
|
+
import copy
|
|
16
|
+
|
|
17
|
+
from .parser import OmniParser, ParseResult, ParserMetadata, UIElement
|
|
18
|
+
from ...core.loop import BaseLoop
|
|
19
|
+
from computer import Computer
|
|
20
|
+
from .types import APIProvider
|
|
21
|
+
from .clients.base import BaseOmniClient
|
|
22
|
+
from .clients.openai import OpenAIClient
|
|
23
|
+
from .clients.groq import GroqClient
|
|
24
|
+
from .clients.anthropic import AnthropicClient
|
|
25
|
+
from .prompts import SYSTEM_PROMPT
|
|
26
|
+
from .utils import compress_image_base64
|
|
27
|
+
from .visualization import visualize_click, visualize_scroll, calculate_element_center
|
|
28
|
+
from .image_utils import decode_base64_image, clean_base64_data
|
|
29
|
+
from ...core.messages import ImageRetentionConfig
|
|
30
|
+
from .messages import OmniMessageManager
|
|
31
|
+
|
|
32
|
+
logging.basicConfig(level=logging.INFO)
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def extract_data(input_string: str, data_type: str) -> str:
    """Extract the contents of the first fenced code block with a given tag.

    Looks for three backticks immediately followed by ``data_type`` (for
    example ``json``) and returns the text up to the closing fence. An
    unterminated block extends to the end of the string.

    Args:
        input_string: Text that may contain a fenced code block.
        data_type: Language tag of the fence to look for; matched literally.

    Returns:
        The stripped block contents, or ``input_string`` unchanged when no
        matching block is found.
    """
    # re.escape so tags containing regex metacharacters (e.g. "c++") are
    # matched literally instead of being interpreted as a pattern.
    pattern = "```" + re.escape(data_type) + r"(.*?)(?:```|$)"
    match = re.search(pattern, input_string, re.DOTALL)
    return match.group(1).strip() if match else input_string
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class OmniLoop(BaseLoop):
|
|
44
|
+
"""Omni-specific implementation of the agent loop."""
|
|
45
|
+
|
|
46
|
+
def __init__(
    self,
    parser: OmniParser,
    provider: APIProvider,
    api_key: str,
    model: str,
    computer: Computer,
    only_n_most_recent_images: Optional[int] = 2,
    base_dir: Optional[str] = "trajectories",
    max_retries: int = 3,
    retry_delay: float = 1.0,
    save_trajectory: bool = True,
    **kwargs,
):
    """Set up an Omni agent loop.

    Args:
        parser: OmniParser used to parse screenshots into UI elements.
        provider: API provider whose client will be created.
        api_key: Credential passed to the provider client.
        model: Model name passed to the provider client.
        computer: Computer instance the agent operates on.
        only_n_most_recent_images: Number of most recent screenshots to keep
            in API requests (forwarded to the image retention config).
        base_dir: Base directory for saving experiment data.
        max_retries: Maximum number of attempts per API call.
        retry_delay: Base delay in seconds between retries.
        save_trajectory: Whether trajectory data is saved.
        **kwargs: Forwarded unchanged to the base loop initializer.
    """
    # Parser, provider and message manager are assigned before the base
    # initializer (which sets up the experiment manager) is invoked.
    self.parser = parser
    self.provider = provider

    retention = ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
    self.message_manager = OmniMessageManager(config=retention)

    base_kwargs = dict(
        computer=computer,
        model=model,
        api_key=api_key,
        max_retries=max_retries,
        retry_delay=retry_delay,
        base_dir=base_dir,
        save_trajectory=save_trajectory,
        only_n_most_recent_images=only_n_most_recent_images,
    )
    super().__init__(**base_kwargs, **kwargs)

    # The provider client is created on demand (see initialize_client).
    self.client = None
    self.retry_count = 0
|
|
98
|
+
|
|
99
|
+
def _should_save_debug_image(self) -> bool:
    """Report whether debug screenshots should be written.

    Debug image saving is disabled in this loop, so the answer is fixed.

    Returns:
        bool: Always False.
    """
    debug_images_enabled = False
    return debug_images_enabled
|
|
107
|
+
|
|
108
|
+
def _extract_and_save_images(self, data: Any, prefix: str) -> None:
    """No-op kept so existing callers keep working.

    Extracting images from API payloads is no longer supported; both
    arguments are accepted and ignored.

    Args:
        data: Payload images would have been extracted from.
        prefix: Filename prefix extracted images would have used.
    """
    pass
|
|
119
|
+
|
|
120
|
+
def _save_debug_image(self, image_data: str, filename: str) -> None:
    """No-op kept so existing callers keep working.

    Writing debug images to the turn directory is no longer supported; both
    arguments are accepted and ignored.

    Args:
        image_data: Base64-encoded image that would have been saved.
        filename: File name that would have been used.
    """
    pass
|
|
131
|
+
|
|
132
|
+
def _visualize_action(self, x: int, y: int, img_base64: str) -> None:
    """Mark a click at (x, y) on a screenshot and save it as a visualization.

    Skipped unless trajectory saving is enabled and an experiment manager is
    present. Any failure while drawing or saving is logged, not raised.

    Args:
        x: Click x coordinate.
        y: Click y coordinate.
        img_base64: Base64 screenshot to draw on.
    """
    can_record = self.save_trajectory and getattr(self, "experiment_manager", None)
    if not can_record:
        return

    try:
        marked = visualize_click(x, y, img_base64)
        self.experiment_manager.save_action_visualization(marked, "click", f"x{x}_y{y}")
    except Exception as e:
        logger.error(f"Error visualizing action: {str(e)}")
|
|
149
|
+
|
|
150
|
+
def _visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None:
    """Draw scroll arrows on a screenshot and save it as a visualization.

    Skipped unless trajectory saving is enabled and an experiment manager is
    present. Any failure while drawing or saving is logged, not raised.

    Args:
        direction: Scroll direction passed to the drawing helper.
        clicks: Number of scroll clicks.
        img_base64: Base64 screenshot to draw on.
    """
    can_record = self.save_trajectory and getattr(self, "experiment_manager", None)
    if not can_record:
        return

    try:
        arrows = visualize_scroll(direction, clicks, img_base64)
        label = f"{direction}_{clicks}"
        self.experiment_manager.save_action_visualization(arrows, "scroll", label)
    except Exception as e:
        logger.error(f"Error visualizing scroll: {str(e)}")
|
|
169
|
+
|
|
170
|
+
def _save_action_visualization(
    self, img: Image.Image, action_name: str, details: str = ""
) -> str:
    """Save an action visualization through the experiment manager.

    Args:
        img: Rendered visualization image.
        action_name: Action label passed to the experiment manager.
        details: Extra detail string passed to the experiment manager.

    Returns:
        The experiment manager's return value, or "" when no experiment
        manager is available.
    """
    manager = getattr(self, "experiment_manager", None)
    if not manager:
        return ""
    return manager.save_action_visualization(img, action_name, details)
|
|
177
|
+
|
|
178
|
+
async def initialize_client(self) -> None:
    """Create the client for ``self.provider`` and store it on ``self.client``.

    Raises:
        RuntimeError: If the provider is unsupported or the client constructor
            fails. ``self.client`` is reset to None and the original exception
            is chained as ``__cause__``.
    """
    try:
        logger.info(f"Initializing {self.provider} client with model {self.model}...")

        if self.provider == APIProvider.OPENAI:
            self.client = OpenAIClient(api_key=self.api_key, model=self.model)
        elif self.provider == APIProvider.GROQ:
            self.client = GroqClient(api_key=self.api_key, model=self.model)
        elif self.provider == APIProvider.ANTHROPIC:
            # Only the Anthropic client receives the loop's retry settings.
            self.client = AnthropicClient(
                api_key=self.api_key,
                model=self.model,
                max_retries=self.max_retries,
                retry_delay=self.retry_delay,
            )
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

        logger.info(f"Initialized {self.provider} client with model {self.model}")
    except Exception as e:
        logger.error(f"Error initializing client: {str(e)}")
        self.client = None
        # Chain the cause so the underlying failure's traceback is preserved.
        raise RuntimeError(f"Failed to initialize client: {str(e)}") from e
|
|
202
|
+
|
|
203
|
+
async def _make_api_call(self, messages: List[Dict[str, Any]], system_prompt: str) -> Any:
    """Send the conversation to the provider, retrying on failure.

    One turn directory is created per call (not per attempt). On each attempt
    the client is initialized if needed, image retention is applied through
    the message manager, and the prepared messages that are logged are the
    same messages sent to the provider.

    Args:
        messages: Conversation history. Only a shallow copy is handed to the
            message manager; the list itself is not modified here.
        system_prompt: System prompt for non-Anthropic providers. Anthropic
            requests use ``self._get_system_prompt()`` instead.

    Returns:
        The response returned by the client's ``run_interleaved``.

    Raises:
        RuntimeError: After ``self.max_retries`` failed attempts; the last
            underlying exception is chained as ``__cause__``.
    """
    # Create new turn directory for this API call
    self._create_turn_dir()

    request_data = None
    last_error: Optional[Exception] = None

    for attempt in range(self.max_retries):
        try:
            if self.client is None:
                logger.info(
                    f"Client not initialized in _make_api_call (attempt {attempt+1}), initializing now..."
                )
                await self.initialize_client()
                if self.client is None:
                    raise RuntimeError("Failed to initialize client")

            # Apply image retention (limits screenshots to only_n_most_recent_images).
            prepared_messages = self.message_manager.prepare_messages(messages.copy())

            if self.provider == APIProvider.ANTHROPIC:
                # Anthropic takes the system prompt as a parameter, so drop
                # any system-role messages from the list.
                filtered_messages = [
                    msg for msg in prepared_messages if msg["role"] != "system"
                ]
                system = self._get_system_prompt()
            else:
                filtered_messages = prepared_messages
                system = system_prompt

            request_data = {
                "messages": filtered_messages,
                "max_tokens": self.max_tokens,
                "system": system,
            }
            self._log_api_call("request", request_data)

            # Clients implement run_interleaved either synchronously or as a
            # coroutine function; await only in the latter case.
            run_method = self.client.run_interleaved
            response = run_method(
                messages=filtered_messages,
                system=system,
                max_tokens=self.max_tokens,
            )
            if asyncio.iscoroutinefunction(run_method):
                response = await response

            self._log_api_call("response", request_data, response)
            return response

        except (ConnectError, ReadTimeout) as e:
            last_error = e
            logger.warning(
                f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
            )
            if attempt < self.max_retries - 1:
                # Linear backoff: retry_delay, 2 * retry_delay, 3 * retry_delay, ...
                await asyncio.sleep(self.retry_delay * (attempt + 1))
                # Reset client on connection errors to force re-initialization
                self.client = None

        except RuntimeError as e:
            # Client initialization failures surface as RuntimeError.
            last_error = e
            self._log_api_call("error", request_data, error=e)
            logger.error(
                f"Client initialization error (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
            )
            if attempt < self.max_retries - 1:
                # Reset client to force re-initialization
                self.client = None
                await asyncio.sleep(self.retry_delay)

        except Exception as e:
            last_error = e
            self._log_api_call("error", request_data, error=e)
            logger.error(f"Unexpected error in API call: {str(e)}")
            if attempt < self.max_retries - 1:
                await asyncio.sleep(self.retry_delay)

    # If we get here, all retries failed
    error_message = f"API call failed after {self.max_retries} attempts"
    if last_error is not None:
        error_message += f": {str(last_error)}"

    logger.error(error_message)
    raise RuntimeError(error_message) from last_error
|
|
326
|
+
|
|
327
|
+
async def _handle_response(
    self, response: Any, messages: List[Dict[str, Any]], parsed_screen: Dict[str, Any]
) -> Tuple[bool, bool]:
    """Parse the model reply, record it in ``messages`` and execute its action.

    The reply is expected to contain a JSON object describing the next
    action. The object may be the whole text, sit inside a fenced json block,
    or be embedded in surrounding text.

    Args:
        response: Raw provider response. For Anthropic, an object whose
            ``content`` is a list of blocks. For other providers, a dict with
            ``choices`` (OpenAI-style), a string, or an already-parsed dict.
        messages: Conversation history; assistant and error messages are
            appended in place.
        parsed_screen: Current parsed screen, forwarded to ``_execute_action``.

    Returns:
        Tuple of (should_continue, action_screenshot_saved).
        should_continue is False when the action is "None" (task complete)
        or when executing the action failed. It is True when no usable JSON
        was found.

    Raises:
        Re-raises unexpected errors after appending an error message to
        ``messages``.
    """
    try:
        # Handle Anthropic response format
        if self.provider == APIProvider.ANTHROPIC:
            blocks = getattr(response, "content", None)
            if isinstance(blocks, list):
                for block in blocks:
                    if getattr(block, "type", None) != "text":
                        continue
                    text = block.text
                    parsed_content = self._parse_response_json(text)
                    if parsed_content is None:
                        continue  # try the next text block
                    return await self._apply_parsed_action(
                        parsed_content, text, messages, parsed_screen
                    )

                logger.warning("No text block found in Anthropic response")
                return True, False

        # Handle other providers' response formats
        if isinstance(response, dict) and "choices" in response:
            content = response["choices"][0]["message"]["content"]
        else:
            content = response

        if isinstance(content, str):
            parsed_content = self._parse_response_json(content)
            if parsed_content is None:
                return True, False
            return await self._apply_parsed_action(
                parsed_content, content, messages, parsed_screen
            )

        if isinstance(content, dict):
            # Copy so the Box ID cleanup does not mutate the caller's response.
            return await self._apply_parsed_action(
                dict(content), None, messages, parsed_screen
            )

        return True, False

    except Exception as e:
        logger.error(f"Error handling response: {str(e)}")
        messages.append(
            {
                "role": "assistant",
                "content": f"Error: {str(e)}",
                "metadata": {"title": "❌ Error"},
            }
        )
        raise

@staticmethod
def _parse_response_json(content: str) -> Optional[Dict[str, Any]]:
    """Return the first JSON object found in ``content``, or None.

    Tried in order:
        1. The whole string.
        2. The first fenced json block.
        3. Each '{' position in turn, decoding a full JSON value from there.
           Unlike a flat regex, this handles nested objects.
    Non-object JSON values (lists, strings, numbers) are ignored.
    """
    for candidate in (content, extract_data(content, "json")):
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed

    decoder = json.JSONDecoder()
    start = content.find("{")
    while start != -1:
        try:
            parsed, _ = decoder.raw_decode(content[start:])
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(parsed, dict):
                return parsed
        start = content.find("{", start + 1)

    logger.error(f"No JSON found in content: {content}")
    return None

async def _apply_parsed_action(
    self,
    parsed_content: Dict[str, Any],
    raw_text: Optional[str],
    messages: List[Dict[str, Any]],
    parsed_screen: Any,
) -> Tuple[bool, bool]:
    """Normalize a parsed action, append it to ``messages`` and execute it.

    Args:
        parsed_content: Parsed action dict; normalized in place.
        raw_text: Original reply text. Any text before the first '{' is
            used as "Explanation" when the dict lacks one. None skips this.
        messages: Conversation history to append to.
        parsed_screen: Forwarded to ``_execute_action``.

    Returns:
        Tuple of (should_continue, action_screenshot_saved).
    """
    # Clean up Box ID format ("Box #3" -> "3") so it can be int()-converted.
    box_id = parsed_content.get("Box ID")
    if isinstance(box_id, str):
        parsed_content["Box ID"] = box_id.replace("Box #", "")

    # Add any explanatory text as reasoning if not present
    if raw_text is not None and "Explanation" not in parsed_content:
        text_before_json = raw_text.split("{")[0].strip()
        if text_before_json:
            parsed_content["Explanation"] = text_before_json

    logger.info(f"Parsed content: {json.dumps(parsed_content, indent=2)}")

    # Add response to messages with stringified content
    messages.append({"role": "assistant", "content": json.dumps(parsed_content)})

    try:
        await self._execute_action(parsed_content, parsed_screen)
    except Exception as e:
        logger.error(f"Error executing action: {str(e)}")
        messages.append(
            {
                "role": "assistant",
                "content": f"Error executing action: {str(e)}",
                "metadata": {"title": "❌ Error"},
            }
        )
        return False, False

    # An "Action" of "None" signals that the task is complete.
    return parsed_content.get("Action") != "None", True
|
|
519
|
+
|
|
520
|
+
async def _get_parsed_screen_som(self, save_screenshot: bool = True) -> ParseResult:
    """Capture and parse the current screen with set-of-marks annotation.

    Args:
        save_screenshot: When True, and trajectory saving is enabled, the
            annotated screenshot is saved with action type "state". Pass
            False when screenshots will be saved elsewhere.

    Returns:
        ParseResult from ``self.parser.parse_screen``.

    Raises:
        Re-raises parser errors after logging them. Errors while saving the
        screenshot are logged and ignored.
    """
    try:
        parsed_screen = await self.parser.parse_screen(computer=self.computer)

        element_count = len(parsed_screen.elements) if parsed_screen.elements else 0
        logger.info(f"Parsed screen with {element_count} elements")

        if save_screenshot and self.save_trajectory and parsed_screen.annotated_image_base64:
            try:
                annotated = parsed_screen.annotated_image_base64
                # Keep only the payload after a "data:image/...;base64," prefix.
                payload = annotated.split(",")[1] if "," in annotated else annotated
                self._save_screenshot(payload, action_type="state")
            except Exception as e:
                logger.error(f"Error saving screenshot: {str(e)}")

        return parsed_screen
    except Exception as e:
        logger.error(f"Error getting parsed screen: {str(e)}")
        raise
|
|
555
|
+
|
|
556
|
+
async def _process_screen(
    self, parsed_screen: ParseResult, messages: List[Dict[str, Any]]
) -> None:
    """Append the annotated screenshot to ``messages`` as a user message.

    Only OpenAI and Anthropic receive a screenshot; for other providers, or
    when there is no annotated image, nothing is appended. When a turn
    directory exists, the parsed elements are also written to
    ``elements.json`` there; this is best effort.

    Any "data:<type>;base64," prefix is stripped so both payload formats get
    raw base64:
        - Anthropic: sent as a base64 ``image`` block, compressed first when
          the base64 length exceeds ~4.9MB.
        - OpenAI: sent as an ``image_url`` data URL.

    Raises:
        Re-raises unexpected errors after logging them.
    """
    try:
        # Only add message if we have an image and provider supports it
        if self.provider not in [APIProvider.OPENAI, APIProvider.ANTHROPIC]:
            return
        image = parsed_screen.annotated_image_base64 or None
        if not image:
            return

        if self.current_turn_dir:
            self._save_elements_json(parsed_screen)

        # Split off a data-URL prefix: Anthropic's base64 source and the data
        # URL built below both need the raw base64 payload.
        media_type = "image/png"
        if image.startswith("data:"):
            header, sep, payload = image.partition(",")
            if sep:
                if "image/jpeg" in header.lower():
                    media_type = "image/jpeg"
                image = payload

        if self.provider == APIProvider.ANTHROPIC:
            image_size = len(image)
            logger.info(f"Image base64 is present, length: {image_size}")

            # Anthropic has a 5MB limit; compare the base64 length against a
            # slightly smaller 4.9MB to leave room for request overhead.
            max_size = int(4.9 * 1024 * 1024)
            if image_size > max_size:
                logger.info(
                    f"Image size ({image_size} bytes) exceeds Anthropic limit ({max_size} bytes), compressing..."
                )
                image, media_type = compress_image_base64(image, max_size)
                logger.info(
                    f"Image compressed to {len(image)} bytes with media_type {media_type}"
                )

            image_block = {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": media_type,
                    "data": image,
                },
            }
        else:
            # OpenAI uses "type": "image_url" with a data URL
            image_block = {
                "type": "image_url",
                "image_url": {"url": f"data:{media_type};base64,{image}"},
            }

        messages.append({"role": "user", "content": [image_block]})

    except Exception as e:
        logger.error(f"Error processing screen info: {str(e)}")
        raise

def _save_elements_json(self, parsed_screen: ParseResult) -> None:
    """Write parsed UI elements to elements.json in the current turn directory.

    This is best effort: I/O and serialization errors are logged and do not
    interrupt screen processing. The JSON text is built before the file is
    opened, so a serialization error does not leave a truncated file behind.
    """
    elements_path = os.path.join(self.current_turn_dir, "elements.json")
    try:
        elements_json = [elem.model_dump() for elem in (parsed_screen.elements or [])]
        serialized = json.dumps(elements_json, indent=2)
        with open(elements_path, "w", encoding="utf-8") as f:
            f.write(serialized)
        logger.info(f"Saved elements to {elements_path}")
    except (OSError, TypeError, ValueError) as e:
        logger.error(f"Error saving elements to {elements_path}: {str(e)}")
|
|
636
|
+
|
|
637
|
+
def _get_system_prompt(self) -> str:
    """Return the Omni system prompt (``SYSTEM_PROMPT`` from ``.prompts``)."""
    prompt = SYSTEM_PROMPT
    return prompt
|
|
640
|
+
|
|
641
|
+
async def _execute_action(self, content: Dict[str, Any], parsed_screen: ParseResult) -> None:
    """Execute the action specified in the content using the computer interface.

    The action name is read from ``content["Action"]`` (case-insensitive) and
    dispatched to the method of the same name on ``self.computer``:

    * ``left_click`` / ``right_click`` / ``double_click`` / ``move_cursor`` and
      ``drag_to`` resolve ``content["Box ID"]`` to screen coordinates via
      ``self._calculate_click_coordinates``; ``drag_to`` also reads the optional
      ``"button"`` (default ``"left"``) and ``"duration"`` (default ``0.5``).
    * ``type_text`` passes ``text=content["Value"]``; ``press_key`` passes
      ``key=content["Value"]``.
    * ``hotkey`` accepts ``content["Value"]`` as a list of keys or as a
      ``"+"``-separated string (lower-cased), and passes the keys positionally.
    * ``scroll_down`` / ``scroll_up`` pass ``clicks=int(content.get("amount", 1))``.

    When the annotated screenshot is available, click/drag/scroll actions save
    a visualization of the action. If no visualization was saved, a fresh
    screenshot is captured after the action and saved with a label derived from
    the action (hotkeys always get a post-action screenshot).

    Errors are logged rather than raised; this method always returns ``None``.

    Args:
        content: Dictionary containing the action details
        parsed_screen: Current parsed screen information
    """

    def strip_data_url(img_data):
        # Keep only the payload of a "data:image/...;base64,<payload>" URL.
        if img_data.startswith("data:image"):
            return img_data.split(",")[1]
        return img_data

    def visualize(draw):
        # Best-effort debug visualization on the current annotated screenshot.
        # Returns True only if ``draw`` succeeded. A failure here (e.g. an
        # undecodable image) is logged and must not block the action itself.
        if not parsed_screen.annotated_image_base64:
            return False
        try:
            draw(strip_data_url(parsed_screen.annotated_image_base64))
        except Exception as viz_error:
            logger.warning(f"Failed to save action visualization: {str(viz_error)}")
            return False
        return True

    async def save_post_action_screenshot(label, context):
        # Capture the screen after the action and save it tagged with ``label``.
        # ``context`` only names this step in the error log.
        try:
            new_parsed_screen = await self._get_parsed_screen_som(save_screenshot=False)
            if new_parsed_screen and new_parsed_screen.annotated_image_base64:
                self._save_screenshot(
                    strip_data_url(new_parsed_screen.annotated_image_base64),
                    action_type=label,
                )
        except Exception as screenshot_error:
            logger.error(f"Error taking {context} screenshot: {str(screenshot_error)}")

    try:
        action = content.get("Action", "").lower()
    except Exception as e:
        logger.error(f"Error in _execute_action: {str(e)}")
        return
    if not action:
        return

    # Label for the post-action screenshot; refined below for some actions.
    action_type = action
    # True once an action-specific visualization was saved, in which case no
    # extra post-action screenshot is taken.
    action_screenshot_saved = False
    kwargs = {}

    try:
        if action in ("left_click", "right_click", "double_click", "move_cursor"):
            try:
                box_id = int(content["Box ID"])
                logger.info(f"Processing Box ID: {box_id}")
                x, y = await self._calculate_click_coordinates(box_id, parsed_screen)
                logger.info(f"Calculated coordinates: x={x}, y={y}")
            except ValueError as e:
                logger.error(f"Error processing Box ID: {str(e)}")
                return
            kwargs.update(x=x, y=y)
            action_screenshot_saved = visualize(
                lambda img: self._visualize_action(x, y, img)
            )

        elif action == "drag_to":
            try:
                box_id = int(content["Box ID"])
                x, y = await self._calculate_click_coordinates(box_id, parsed_screen)
                duration = float(content.get("duration", 0.5))
            except ValueError as e:
                logger.error(f"Error processing drag coordinates: {str(e)}")
                return
            kwargs.update(
                x=x, y=y, button=content.get("button", "left"), duration=duration
            )
            action_screenshot_saved = visualize(
                lambda img: self._visualize_action(x, y, img)
            )

        elif action == "type_text":
            kwargs["text"] = content["Value"]
            # Include a truncated prefix of the typed text in the screenshot label.
            action_type = f"type_{content['Value'][:20]}"

        elif action == "press_key":
            kwargs["key"] = content["Value"]
            action_type = f"press_{content['Value']}"

        elif action == "hotkey":
            if isinstance(content.get("Value"), list):
                keys = content["Value"]
                action_type = f"hotkey_{'_'.join(keys)}"
            else:
                # Split a string like "command+space" into ["command", "space"].
                keys = [k.strip() for k in content["Value"].lower().split("+")]
                action_type = f"hotkey_{content['Value'].replace('+', '_')}"
            logger.info(f"Preparing hotkey with keys: {keys}")
            # hotkey takes the keys as positional arguments, not keyword arguments.
            method = getattr(self.computer, action)
            await method(*keys)
            logger.info(f"Tool execution completed successfully: {action}")
            # Hotkeys have no visualization, so always capture the resulting screen.
            await save_post_action_screenshot(action_type, "post-hotkey")
            return

        elif action in ("scroll_down", "scroll_up"):
            clicks = int(content.get("amount", 1))
            kwargs["clicks"] = clicks
            direction = "down" if action == "scroll_down" else "up"
            action_type = f"scroll_{direction}_{clicks}"
            action_screenshot_saved = visualize(
                lambda img: self._visualize_scroll(direction, clicks, img)
            )

        else:
            logger.warning(f"Unknown action: {action}")
            return

        # Look the method up explicitly so that an AttributeError raised *inside*
        # the tool is reported as a tool failure, not as a missing method.
        method = getattr(self.computer, action, None)
        if method is None:
            logger.error(f"Method not found for action '{action}'")
            return
        logger.info(f"Found method for action '{action}': {method}")
        try:
            await method(**kwargs)
        except Exception as tool_error:
            logger.error(f"Tool execution failed: {str(tool_error)}")
            return
        logger.info(f"Tool execution completed successfully: {action}")

        # Actions without a saved visualization get a fresh post-action screenshot.
        if not action_screenshot_saved:
            await save_post_action_screenshot(action_type, "post-action")

    except Exception as e:
        logger.error(f"Error executing action {action}: {str(e)}")
|
|
820
|
+
|
|
821
|
+
async def _calculate_click_coordinates(
    self, box_id: int, parsed_screen: ParseResult
) -> Tuple[int, int]:
    """Calculate click coordinates based on box ID.

    Finds the element in ``parsed_screen.elements`` whose ``id`` equals
    ``box_id`` and returns the pixel centre of its bounding box. The bbox
    coordinates are multiplied by the screen size taken from
    ``parsed_screen.metadata`` (1920x1080 when metadata is missing), so they are
    presumably normalized to [0, 1] (TODO: confirm against the screen parser).

    Args:
        box_id: The ID of the box to click
        parsed_screen: The parsed screen information

    Returns:
        Tuple of (x, y) pixel coordinates

    Raises:
        ValueError: If no element with ``box_id`` exists in the parsed screen.
    """
    logger.info(f"Elements count: {len(parsed_screen.elements)}")

    element = next((el for el in parsed_screen.elements if el.id == box_id), None)
    if element is None:
        logger.error(
            f"Box ID {box_id} not found in structured elements (count={len(parsed_screen.elements)})"
        )
        # Do not guess a position: clicking an unrelated spot on the screen is
        # worse than skipping the action, and callers handle ValueError.
        raise ValueError(f"Box ID {box_id} not found in parsed screen elements")

    logger.info(f"Found element with ID {box_id}: {element}")
    bbox = element.bbox

    # Screen dimensions from the metadata if available, otherwise a default.
    width = parsed_screen.metadata.width if parsed_screen.metadata else 1920
    height = parsed_screen.metadata.height if parsed_screen.metadata else 1080
    logger.info(f"Screen dimensions: width={width}, height={height}")

    center_x = int((bbox.x1 + bbox.x2) / 2 * width)
    center_y = int((bbox.y1 + bbox.y2) / 2 * height)
    logger.info(f"Calculated center: ({center_x}, {center_y})")

    # A centre of exactly (0, 0) presumably means the bbox was never populated;
    # use the middle of the screen rather than clicking the corner.
    if center_x == 0 and center_y == 0:
        logger.warning("Got (0,0) coordinates, using fallback position")
        center_x, center_y = width // 2, height // 2
        logger.info(f"Using fallback center: ({center_x}, {center_y})")

    return center_x, center_y
|
|
875
|
+
|
|
876
|
+
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
|
|
877
|
+
"""Run the agent loop with provided messages.
|
|
878
|
+
|
|
879
|
+
Args:
|
|
880
|
+
messages: List of message objects
|
|
881
|
+
|
|
882
|
+
Yields:
|
|
883
|
+
Dict containing response data
|
|
884
|
+
"""
|
|
885
|
+
# Keep track of conversation history
|
|
886
|
+
conversation_history = messages.copy()
|
|
887
|
+
|
|
888
|
+
# Continue running until explicitly told to stop
|
|
889
|
+
running = True
|
|
890
|
+
turn_created = False
|
|
891
|
+
# Track if an action-specific screenshot has been saved this turn
|
|
892
|
+
action_screenshot_saved = False
|
|
893
|
+
|
|
894
|
+
attempt = 0
|
|
895
|
+
max_attempts = 3
|
|
896
|
+
|
|
897
|
+
while running and attempt < max_attempts:
|
|
898
|
+
try:
|
|
899
|
+
# Create a new turn directory if it's not already created
|
|
900
|
+
if not turn_created:
|
|
901
|
+
self._create_turn_dir()
|
|
902
|
+
turn_created = True
|
|
903
|
+
|
|
904
|
+
# Ensure client is initialized
|
|
905
|
+
if self.client is None:
|
|
906
|
+
logger.info("Initializing client...")
|
|
907
|
+
await self.initialize_client()
|
|
908
|
+
if self.client is None:
|
|
909
|
+
raise RuntimeError("Failed to initialize client")
|
|
910
|
+
logger.info("Client initialized successfully")
|
|
911
|
+
|
|
912
|
+
# Get up-to-date screen information
|
|
913
|
+
parsed_screen = await self._get_parsed_screen_som()
|
|
914
|
+
|
|
915
|
+
# Process screen info and update messages
|
|
916
|
+
await self._process_screen(parsed_screen, conversation_history)
|
|
917
|
+
|
|
918
|
+
# Get system prompt
|
|
919
|
+
system_prompt = self._get_system_prompt()
|
|
920
|
+
|
|
921
|
+
# Make API call with retries
|
|
922
|
+
response = await self._make_api_call(conversation_history, system_prompt)
|
|
923
|
+
|
|
924
|
+
# Handle the response (may execute actions)
|
|
925
|
+
# Returns: (should_continue, action_screenshot_saved)
|
|
926
|
+
should_continue, new_screenshot_saved = await self._handle_response(
|
|
927
|
+
response, conversation_history, parsed_screen
|
|
928
|
+
)
|
|
929
|
+
|
|
930
|
+
# Update whether an action screenshot was saved this turn
|
|
931
|
+
action_screenshot_saved = action_screenshot_saved or new_screenshot_saved
|
|
932
|
+
|
|
933
|
+
# Yield the response to the caller
|
|
934
|
+
yield {"response": response}
|
|
935
|
+
|
|
936
|
+
# Check if we should continue this conversation
|
|
937
|
+
running = should_continue
|
|
938
|
+
|
|
939
|
+
# Create a new turn directory if we're continuing
|
|
940
|
+
if running:
|
|
941
|
+
turn_created = False
|
|
942
|
+
|
|
943
|
+
# Reset attempt counter on success
|
|
944
|
+
attempt = 0
|
|
945
|
+
|
|
946
|
+
except Exception as e:
|
|
947
|
+
attempt += 1
|
|
948
|
+
error_msg = f"Error in run method (attempt {attempt}/{max_attempts}): {str(e)}"
|
|
949
|
+
logger.error(error_msg)
|
|
950
|
+
|
|
951
|
+
# If this is our last attempt, provide more info about the error
|
|
952
|
+
if attempt >= max_attempts:
|
|
953
|
+
logger.error(f"Maximum retry attempts reached. Last error was: {str(e)}")
|
|
954
|
+
|
|
955
|
+
yield {
|
|
956
|
+
"error": str(e),
|
|
957
|
+
"metadata": {"title": "❌ Error"},
|
|
958
|
+
}
|
|
959
|
+
|
|
960
|
+
# Create a brief delay before retrying
|
|
961
|
+
await asyncio.sleep(1)
|