cua-agent 0.1.6__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic.

Files changed (57)
  1. agent/__init__.py +3 -2
  2. agent/core/__init__.py +1 -6
  3. agent/core/{computer_agent.py → agent.py} +31 -76
  4. agent/core/{loop.py → base.py} +68 -127
  5. agent/core/factory.py +104 -0
  6. agent/core/messages.py +279 -125
  7. agent/core/provider_config.py +15 -0
  8. agent/core/types.py +45 -0
  9. agent/core/visualization.py +197 -0
  10. agent/providers/anthropic/api/client.py +142 -1
  11. agent/providers/anthropic/api_handler.py +140 -0
  12. agent/providers/anthropic/callbacks/__init__.py +5 -0
  13. agent/providers/anthropic/loop.py +207 -221
  14. agent/providers/anthropic/response_handler.py +226 -0
  15. agent/providers/anthropic/tools/bash.py +0 -97
  16. agent/providers/anthropic/utils.py +368 -0
  17. agent/providers/omni/__init__.py +1 -20
  18. agent/providers/omni/api_handler.py +42 -0
  19. agent/providers/omni/clients/anthropic.py +4 -0
  20. agent/providers/omni/image_utils.py +0 -72
  21. agent/providers/omni/loop.py +491 -607
  22. agent/providers/omni/parser.py +58 -4
  23. agent/providers/omni/tools/__init__.py +25 -7
  24. agent/providers/omni/tools/base.py +29 -0
  25. agent/providers/omni/tools/bash.py +43 -38
  26. agent/providers/omni/tools/computer.py +144 -182
  27. agent/providers/omni/tools/manager.py +25 -45
  28. agent/providers/omni/types.py +1 -3
  29. agent/providers/omni/utils.py +224 -145
  30. agent/providers/openai/__init__.py +6 -0
  31. agent/providers/openai/api_handler.py +453 -0
  32. agent/providers/openai/loop.py +440 -0
  33. agent/providers/openai/response_handler.py +205 -0
  34. agent/providers/openai/tools/__init__.py +15 -0
  35. agent/providers/openai/tools/base.py +79 -0
  36. agent/providers/openai/tools/computer.py +319 -0
  37. agent/providers/openai/tools/manager.py +106 -0
  38. agent/providers/openai/types.py +36 -0
  39. agent/providers/openai/utils.py +98 -0
  40. cua_agent-0.1.18.dist-info/METADATA +165 -0
  41. cua_agent-0.1.18.dist-info/RECORD +73 -0
  42. agent/README.md +0 -63
  43. agent/providers/anthropic/messages/manager.py +0 -112
  44. agent/providers/omni/callbacks.py +0 -78
  45. agent/providers/omni/clients/groq.py +0 -101
  46. agent/providers/omni/experiment.py +0 -276
  47. agent/providers/omni/messages.py +0 -171
  48. agent/providers/omni/tool_manager.py +0 -91
  49. agent/providers/omni/visualization.py +0 -130
  50. agent/types/__init__.py +0 -23
  51. agent/types/base.py +0 -41
  52. agent/types/messages.py +0 -36
  53. cua_agent-0.1.6.dist-info/METADATA +0 -120
  54. cua_agent-0.1.6.dist-info/RECORD +0 -64
  55. /agent/{types → core}/tools.py +0 -0
  56. {cua_agent-0.1.6.dist-info → cua_agent-0.1.18.dist-info}/WHEEL +0 -0
  57. {cua_agent-0.1.6.dist-info → cua_agent-0.1.18.dist-info}/entry_points.txt +0 -0
agent/core/messages.py CHANGED
@@ -1,12 +1,11 @@
 """Message handling utilities for agent."""
 
-import base64
-from datetime import datetime
-from io import BytesIO
 import logging
-from typing import Any, Dict, List, Optional, Union
-from PIL import Image
+import json
+from typing import Any, Dict, List, Optional, Union, Tuple
 from dataclasses import dataclass
+import re
+from ..providers.omni.parser import ParseResult
 
 logger = logging.getLogger(__name__)
 
@@ -123,123 +122,278 @@ class BaseMessageManager:
                 break
 
 
-def create_user_message(text: str) -> Dict[str, str]:
-    """Create a user message.
-
-    Args:
-        text: The message text
-
-    Returns:
-        Message dictionary
-    """
-    return {
-        "role": "user",
-        "content": text,
-    }
-
-
-def create_assistant_message(text: str) -> Dict[str, str]:
-    """Create an assistant message.
-
-    Args:
-        text: The message text
-
-    Returns:
-        Message dictionary
-    """
-    return {
-        "role": "assistant",
-        "content": text,
-    }
-
-
-def create_system_message(text: str) -> Dict[str, str]:
-    """Create a system message.
-
-    Args:
-        text: The message text
-
-    Returns:
-        Message dictionary
-    """
-    return {
-        "role": "system",
-        "content": text,
-    }
-
-
-def create_image_message(
-    image_base64: Optional[str] = None,
-    image_path: Optional[str] = None,
-    image_obj: Optional[Image.Image] = None,
-) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
-    """Create a message with an image.
-
-    Args:
-        image_base64: Base64 encoded image
-        image_path: Path to image file
-        image_obj: PIL Image object
-
-    Returns:
-        Message dictionary with content list
-
-    Raises:
-        ValueError: If no image source is provided
-    """
-    if not any([image_base64, image_path, image_obj]):
-        raise ValueError("Must provide one of image_base64, image_path, or image_obj")
-
-    # Convert to base64 if needed
-    if image_path and not image_base64:
-        with open(image_path, "rb") as f:
-            image_bytes = f.read()
-        image_base64 = base64.b64encode(image_bytes).decode("utf-8")
-    elif image_obj and not image_base64:
-        buffer = BytesIO()
-        image_obj.save(buffer, format="PNG")
-        image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
-
-    return {
-        "role": "user",
-        "content": [
-            {"type": "image", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
-        ],
-    }
-
-
-def create_screen_message(
-    parsed_screen: Dict[str, Any],
-    include_raw: bool = False,
-) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
-    """Create a message with screen information.
-
-    Args:
-        parsed_screen: Dictionary containing parsed screen info
-        include_raw: Whether to include raw screenshot base64
-
-    Returns:
-        Message dictionary with content
-    """
-    if include_raw and "screenshot_base64" in parsed_screen:
-        # Create content list with both image and text
-        return {
-            "role": "user",
-            "content": [
-                {
-                    "type": "image",
-                    "image_url": {
-                        "url": f"data:image/png;base64,{parsed_screen['screenshot_base64']}"
-                    },
-                },
-                {
-                    "type": "text",
-                    "text": f"Screen dimensions: {parsed_screen['width']}x{parsed_screen['height']}",
-                },
-            ],
-        }
-    else:
-        # Create text-only message with screen info
-        return {
-            "role": "user",
-            "content": f"Screen dimensions: {parsed_screen['width']}x{parsed_screen['height']}",
-        }
+class StandardMessageManager:
+    """Manages messages in a standardized OpenAI format across different providers."""
+
+    def __init__(self, config: Optional[ImageRetentionConfig] = None):
+        """Initialize message manager.
+
+        Args:
+            config: Configuration for image retention
+        """
+        self.messages: List[Dict[str, Any]] = []
+        self.config = config or ImageRetentionConfig()
+
+    def add_user_message(self, content: Union[str, List[Dict[str, Any]]]) -> None:
+        """Add a user message.
+
+        Args:
+            content: Message content (text or multimodal content)
+        """
+        self.messages.append({"role": "user", "content": content})
+
+    def add_assistant_message(self, content: Union[str, List[Dict[str, Any]]]) -> None:
+        """Add an assistant message.
+
+        Args:
+            content: Message content (text or multimodal content)
+        """
+        self.messages.append({"role": "assistant", "content": content})
+
+    def add_system_message(self, content: str) -> None:
+        """Add a system message.
+
+        Args:
+            content: System message content
+        """
+        self.messages.append({"role": "system", "content": content})
+
+    def get_messages(self) -> List[Dict[str, Any]]:
+        """Get all messages in standard format.
+
+        Returns:
+            List of messages
+        """
+        # If image retention is configured, apply it
+        if self.config.num_images_to_keep is not None:
+            return self._apply_image_retention(self.messages)
+        return self.messages
+
+    def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Apply image retention policy to messages.
+
+        Args:
+            messages: List of messages
+
+        Returns:
+            List of messages with image retention applied
+        """
+        if not self.config.num_images_to_keep:
+            return messages
+
+        # Find user messages with images
+        image_messages = []
+        for msg in messages:
+            if msg["role"] == "user" and isinstance(msg["content"], list):
+                has_image = any(
+                    item.get("type") == "image_url" or item.get("type") == "image"
+                    for item in msg["content"]
+                )
+                if has_image:
+                    image_messages.append(msg)
+
+        # If we don't have more images than the limit, return all messages
+        if len(image_messages) <= self.config.num_images_to_keep:
+            return messages
+
+        # Get the most recent N images to keep
+        images_to_keep = image_messages[-self.config.num_images_to_keep :]
+        images_to_remove = image_messages[: -self.config.num_images_to_keep]
+
+        # Create a new message list without the older images
+        result = []
+        for msg in messages:
+            if msg in images_to_remove:
+                # Skip this message
+                continue
+            result.append(msg)
+
+        return result
+
+    def to_anthropic_format(
+        self, messages: List[Dict[str, Any]]
+    ) -> Tuple[List[Dict[str, Any]], str]:
+        """Convert standard OpenAI format messages to Anthropic format.
+
+        Args:
+            messages: List of messages in OpenAI format
+
+        Returns:
+            Tuple containing (anthropic_messages, system_content)
+        """
+        result = []
+        system_content = ""
+
+        # Process messages in order to maintain conversation flow
+        previous_assistant_tool_use_ids = (
+            set()
+        )  # Track tool_use_ids in the previous assistant message
+
+        for i, msg in enumerate(messages):
+            role = msg.get("role", "")
+            content = msg.get("content", "")
+
+            if role == "system":
+                # Collect system messages for later use
+                system_content += content + "\n"
+                continue
+
+            if role == "assistant":
+                # Track tool_use_ids in this assistant message for the next user message
+                previous_assistant_tool_use_ids = set()
+                if isinstance(content, list):
+                    for item in content:
+                        if (
+                            isinstance(item, dict)
+                            and item.get("type") == "tool_use"
+                            and "id" in item
+                        ):
+                            previous_assistant_tool_use_ids.add(item["id"])
+
+                logger.info(
+                    f"Tool use IDs in assistant message #{i}: {previous_assistant_tool_use_ids}"
+                )
+
+            if role in ["user", "assistant"]:
+                anthropic_msg = {"role": role}
+
+                # Convert content based on type
+                if isinstance(content, str):
+                    # Simple text content
+                    anthropic_msg["content"] = [{"type": "text", "text": content}]
+                elif isinstance(content, list):
+                    # Convert complex content
+                    anthropic_content = []
+                    for item in content:
+                        item_type = item.get("type", "")
+
+                        if item_type == "text":
+                            anthropic_content.append({"type": "text", "text": item.get("text", "")})
+                        elif item_type == "image_url":
+                            # Convert OpenAI image format to Anthropic
+                            image_url = item.get("image_url", {}).get("url", "")
+                            if image_url.startswith("data:"):
+                                # Extract base64 data and media type
+                                match = re.match(r"data:(.+);base64,(.+)", image_url)
+                                if match:
+                                    media_type, data = match.groups()
+                                    anthropic_content.append(
+                                        {
+                                            "type": "image",
+                                            "source": {
+                                                "type": "base64",
+                                                "media_type": media_type,
+                                                "data": data,
+                                            },
+                                        }
+                                    )
+                            else:
+                                # Regular URL
+                                anthropic_content.append(
+                                    {
+                                        "type": "image",
+                                        "source": {
+                                            "type": "url",
+                                            "url": image_url,
+                                        },
+                                    }
+                                )
+                        elif item_type == "tool_use":
+                            # Always include tool_use blocks
+                            anthropic_content.append(item)
+                        elif item_type == "tool_result":
+                            # Check if this is a user message AND if the tool_use_id exists in the previous assistant message
+                            tool_use_id = item.get("tool_use_id")
+
+                            # Only include tool_result if it references a tool_use from the immediately preceding assistant message
+                            if (
+                                role == "user"
+                                and tool_use_id
+                                and tool_use_id in previous_assistant_tool_use_ids
+                            ):
+                                anthropic_content.append(item)
+                                logger.info(
+                                    f"Including tool_result with tool_use_id: {tool_use_id}"
+                                )
+                            else:
+                                # Convert to text to preserve information
+                                logger.warning(
+                                    f"Converting tool_result to text. Tool use ID {tool_use_id} not found in previous assistant message"
+                                )
+                                content_text = "Tool Result: "
+                                if "content" in item:
+                                    if isinstance(item["content"], list):
+                                        for content_item in item["content"]:
+                                            if (
+                                                isinstance(content_item, dict)
+                                                and content_item.get("type") == "text"
+                                            ):
+                                                content_text += content_item.get("text", "")
+                                    elif isinstance(item["content"], str):
+                                        content_text += item["content"]
+                                anthropic_content.append({"type": "text", "text": content_text})
+
+                    anthropic_msg["content"] = anthropic_content
+
+                result.append(anthropic_msg)
+
+        return result, system_content
+
+    def from_anthropic_format(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Convert Anthropic format messages to standard OpenAI format.
+
+        Args:
+            messages: List of messages in Anthropic format
+
+        Returns:
+            List of messages in OpenAI format
+        """
+        result = []
+
+        for msg in messages:
+            role = msg.get("role", "")
+            content = msg.get("content", [])
+
+            if role in ["user", "assistant"]:
+                openai_msg = {"role": role}
+
+                # Simple case: single text block
+                if len(content) == 1 and content[0].get("type") == "text":
+                    openai_msg["content"] = content[0].get("text", "")
+                else:
+                    # Complex case: multiple blocks or non-text
+                    openai_content = []
+                    for item in content:
+                        item_type = item.get("type", "")
+
+                        if item_type == "text":
+                            openai_content.append({"type": "text", "text": item.get("text", "")})
+                        elif item_type == "image":
+                            # Convert Anthropic image to OpenAI format
+                            source = item.get("source", {})
+                            if source.get("type") == "base64":
+                                media_type = source.get("media_type", "image/png")
+                                data = source.get("data", "")
+                                openai_content.append(
+                                    {
+                                        "type": "image_url",
+                                        "image_url": {"url": f"data:{media_type};base64,{data}"},
+                                    }
+                                )
+                            else:
+                                # URL
+                                openai_content.append(
+                                    {
+                                        "type": "image_url",
+                                        "image_url": {"url": source.get("url", "")},
+                                    }
+                                )
+                        elif item_type in ["tool_use", "tool_result"]:
+                            # Pass through tool-related content
+                            openai_content.append(item)
+
+                    openai_msg["content"] = openai_content
+
+                result.append(openai_msg)
+
+        return result
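Taken together, the new class replaces the deleted stand-alone `create_*_message` helpers with a single stateful manager: history is stored in OpenAI chat format and converted to or from Anthropic's block format on demand. A minimal usage sketch based on the methods above; it assumes `ImageRetentionConfig` accepts `num_images_to_keep` as a constructor argument, which the diff implies (the attribute is read in `get_messages`) but never shows:

```python
from agent.core.messages import ImageRetentionConfig, StandardMessageManager

# Assumed constructor signature; the diff only shows the attribute being read
manager = StandardMessageManager(ImageRetentionConfig(num_images_to_keep=2))
manager.add_system_message("You are a computer-use agent.")
manager.add_user_message("Open the browser.")
manager.add_assistant_message("Opening the browser now.")

# Anthropic expects system text separately from the message list
anthropic_messages, system_content = manager.to_anthropic_format(manager.get_messages())
```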
agent/core/provider_config.py ADDED
@@ -0,0 +1,15 @@
+"""Provider-specific configurations and constants."""
+
+from ..providers.omni.types import LLMProvider
+
+# Default models for different providers
+DEFAULT_MODELS = {
+    LLMProvider.OPENAI: "gpt-4o",
+    LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
+}
+
+# Map providers to their environment variable names
+ENV_VARS = {
+    LLMProvider.OPENAI: "OPENAI_API_KEY",
+    LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
+}
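These two tables give callers one place to resolve a provider's default model and the environment variable holding its API key. A minimal sketch of that lookup (the surrounding script is illustrative, not code from the package):

```python
import os

from agent.core.provider_config import DEFAULT_MODELS, ENV_VARS
from agent.providers.omni.types import LLMProvider

provider = LLMProvider.OPENAI
model = DEFAULT_MODELS[provider]  # "gpt-4o"
api_key = os.environ.get(ENV_VARS[provider])  # reads OPENAI_API_KEY
if api_key is None:
    raise RuntimeError(f"Set {ENV_VARS[provider]} before running the agent")
```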
agent/core/types.py ADDED
@@ -0,0 +1,45 @@
+"""Core type definitions."""
+
+from typing import Any, Dict, List, Optional, TypedDict, Union
+from enum import Enum, auto
+
+
+class AgentLoop(Enum):
+    """Enumeration of available loop types."""
+
+    ANTHROPIC = auto()  # Anthropic implementation
+    OMNI = auto()  # OmniLoop implementation
+    OPENAI = auto()  # OpenAI implementation
+    # Add more loop types as needed
+
+
+class AgentResponse(TypedDict, total=False):
+    """Agent response format."""
+
+    id: str
+    object: str
+    created_at: int
+    status: str
+    error: Optional[str]
+    incomplete_details: Optional[Any]
+    instructions: Optional[Any]
+    max_output_tokens: Optional[int]
+    model: str
+    output: List[Dict[str, Any]]
+    parallel_tool_calls: bool
+    previous_response_id: Optional[str]
+    reasoning: Dict[str, str]
+    store: bool
+    temperature: float
+    text: Dict[str, Dict[str, str]]
+    tool_choice: str
+    tools: List[Dict[str, Union[str, int]]]
+    top_p: float
+    truncation: str
+    usage: Dict[str, Any]
+    user: Optional[str]
+    metadata: Dict[str, Any]
+    response: Dict[str, List[Dict[str, Any]]]
+    # Additional fields for error responses
+    role: str
+    content: Union[str, List[Dict[str, Any]]]
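Because `AgentResponse` is declared with `total=False`, every field is optional and a partial dict type-checks, which is how error responses can carry only `role` and `content`. An illustrative sketch with made-up values:

```python
from agent.core.types import AgentLoop, AgentResponse

loop = AgentLoop.OPENAI  # select the OpenAI loop implementation

# total=False: only the fields we actually have need to appear
response: AgentResponse = {
    "id": "resp_123",
    "status": "completed",
    "model": "gpt-4o",
    "output": [{"type": "message", "content": "Done."}],
}
```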
agent/core/visualization.py ADDED
@@ -0,0 +1,197 @@
+"""Core visualization utilities for agents."""
+
+import logging
+import base64
+from typing import Dict, Tuple
+from PIL import Image, ImageDraw
+from io import BytesIO
+
+logger = logging.getLogger(__name__)
+
+
+def visualize_click(x: int, y: int, img_base64: str) -> Image.Image:
+    """Visualize a click action by drawing a circle on the screenshot.
+
+    Args:
+        x: X coordinate of the click
+        y: Y coordinate of the click
+        img_base64: Base64-encoded screenshot
+
+    Returns:
+        PIL Image with visualization
+    """
+    try:
+        # Decode the base64 image
+        image_data = base64.b64decode(img_base64)
+        img = Image.open(BytesIO(image_data))
+
+        # Create a copy to draw on
+        draw_img = img.copy()
+        draw = ImageDraw.Draw(draw_img)
+
+        # Draw a circle at the click location
+        radius = 15
+        draw.ellipse([(x - radius, y - radius), (x + radius, y + radius)], outline="red", width=3)
+
+        # Draw crosshairs
+        line_length = 20
+        draw.line([(x - line_length, y), (x + line_length, y)], fill="red", width=3)
+        draw.line([(x, y - line_length), (x, y + line_length)], fill="red", width=3)
+
+        return draw_img
+    except Exception as e:
+        logger.error(f"Error visualizing click: {str(e)}")
+        # Return a blank image as fallback
+        return Image.new("RGB", (800, 600), "white")
+
+
+def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
+    """Visualize a scroll action by drawing arrows on the screenshot.
+
+    Args:
+        direction: Direction of scroll ('up' or 'down')
+        clicks: Number of scroll clicks
+        img_base64: Base64-encoded screenshot
+
+    Returns:
+        PIL Image with visualization
+    """
+    try:
+        # Decode the base64 image
+        image_data = base64.b64decode(img_base64)
+        img = Image.open(BytesIO(image_data))
+
+        # Create a copy to draw on
+        draw_img = img.copy()
+        draw = ImageDraw.Draw(draw_img)
+
+        # Calculate parameters for visualization
+        width, height = img.size
+        center_x = width // 2
+
+        # Draw arrows to indicate scrolling
+        arrow_length = min(100, height // 4)
+        arrow_width = 30
+        num_arrows = min(clicks, 3)  # Don't draw too many arrows
+
+        # Calculate starting position
+        if direction == "down":
+            start_y = height // 3
+            arrow_dir = 1  # Down
+        else:
+            start_y = height * 2 // 3
+            arrow_dir = -1  # Up
+
+        # Draw the arrows
+        for i in range(num_arrows):
+            y_pos = start_y + (i * arrow_length * arrow_dir * 0.7)
+            arrow_top = (center_x, y_pos)
+            arrow_bottom = (center_x, y_pos + arrow_length * arrow_dir)
+
+            # Draw the main line
+            draw.line([arrow_top, arrow_bottom], fill="red", width=5)
+
+            # Draw the arrowhead
+            arrowhead_size = 20
+            if direction == "down":
+                draw.line(
+                    [
+                        (center_x - arrow_width // 2, arrow_bottom[1] - arrowhead_size),
+                        arrow_bottom,
+                        (center_x + arrow_width // 2, arrow_bottom[1] - arrowhead_size),
+                    ],
+                    fill="red",
+                    width=5,
+                )
+            else:
+                draw.line(
+                    [
+                        (center_x - arrow_width // 2, arrow_bottom[1] + arrowhead_size),
+                        arrow_bottom,
+                        (center_x + arrow_width // 2, arrow_bottom[1] + arrowhead_size),
+                    ],
+                    fill="red",
+                    width=5,
+                )
+
+        return draw_img
+    except Exception as e:
+        logger.error(f"Error visualizing scroll: {str(e)}")
+        # Return a blank image as fallback
+        return Image.new("RGB", (800, 600), "white")
+
+
+def calculate_element_center(bbox: Dict[str, float], width: int, height: int) -> Tuple[int, int]:
+    """Calculate the center point of a UI element.
+
+    Args:
+        bbox: Bounding box dictionary with x1, y1, x2, y2 coordinates (0-1 normalized)
+        width: Screen width in pixels
+        height: Screen height in pixels
+
+    Returns:
+        (x, y) tuple with pixel coordinates
+    """
+    center_x = int((bbox["x1"] + bbox["x2"]) / 2 * width)
+    center_y = int((bbox["y1"] + bbox["y2"]) / 2 * height)
+    return center_x, center_y
+
+
+class VisualizationHelper:
+    """Helper class for visualizing agent actions."""
+
+    def __init__(self, agent):
+        """Initialize visualization helper.
+
+        Args:
+            agent: Reference to the agent that will use this helper
+        """
+        self.agent = agent
+
+    def visualize_action(self, x: int, y: int, img_base64: str) -> None:
+        """Visualize a click action by drawing on the screenshot."""
+        if (
+            not self.agent.save_trajectory
+            or not hasattr(self.agent, "experiment_manager")
+            or not self.agent.experiment_manager
+        ):
+            return
+
+        try:
+            # Use the visualization utility
+            img = visualize_click(x, y, img_base64)
+
+            # Save the visualization
+            self.agent.experiment_manager.save_action_visualization(img, "click", f"x{x}_y{y}")
+        except Exception as e:
+            logger.error(f"Error visualizing action: {str(e)}")
+
+    def visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None:
+        """Visualize a scroll action by drawing arrows on the screenshot."""
+        if (
+            not self.agent.save_trajectory
+            or not hasattr(self.agent, "experiment_manager")
+            or not self.agent.experiment_manager
+        ):
+            return
+
+        try:
+            # Use the visualization utility
+            img = visualize_scroll(direction, clicks, img_base64)
+
+            # Save the visualization
+            self.agent.experiment_manager.save_action_visualization(
+                img, "scroll", f"{direction}_{clicks}"
+            )
+        except Exception as e:
+            logger.error(f"Error visualizing scroll: {str(e)}")
+
+    def save_action_visualization(
+        self, img: Image.Image, action_name: str, details: str = ""
+    ) -> str:
+        """Save a visualization of an action."""
+        if hasattr(self.agent, "experiment_manager") and self.agent.experiment_manager:
+            return self.agent.experiment_manager.save_action_visualization(
+                img, action_name, details
+            )
+        return ""
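A short sketch of how these utilities compose: convert a parser's normalized bounding box to pixel coordinates, then draw the click marker on a screenshot. The bounding box values and file names below are made-up examples:

```python
import base64

from agent.core.visualization import calculate_element_center, visualize_click

# Hypothetical parsed element, normalized to [0, 1] on a 1920x1080 screen
bbox = {"x1": 0.25, "y1": 0.40, "x2": 0.35, "y2": 0.45}
x, y = calculate_element_center(bbox, width=1920, height=1080)  # (576, 459)

with open("screenshot.png", "rb") as f:  # any PNG screenshot
    img_base64 = base64.b64encode(f.read()).decode("utf-8")

annotated = visualize_click(x, y, img_base64)
annotated.save("click_visualization.png")
```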