cua-agent 0.1.6__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/__init__.py +3 -2
- agent/core/__init__.py +1 -6
- agent/core/{computer_agent.py → agent.py} +31 -76
- agent/core/{loop.py → base.py} +68 -127
- agent/core/factory.py +104 -0
- agent/core/messages.py +279 -125
- agent/core/provider_config.py +15 -0
- agent/core/types.py +45 -0
- agent/core/visualization.py +197 -0
- agent/providers/anthropic/api/client.py +142 -1
- agent/providers/anthropic/api_handler.py +140 -0
- agent/providers/anthropic/callbacks/__init__.py +5 -0
- agent/providers/anthropic/loop.py +207 -221
- agent/providers/anthropic/response_handler.py +226 -0
- agent/providers/anthropic/tools/bash.py +0 -97
- agent/providers/anthropic/utils.py +368 -0
- agent/providers/omni/__init__.py +1 -20
- agent/providers/omni/api_handler.py +42 -0
- agent/providers/omni/clients/anthropic.py +4 -0
- agent/providers/omni/image_utils.py +0 -72
- agent/providers/omni/loop.py +491 -607
- agent/providers/omni/parser.py +58 -4
- agent/providers/omni/tools/__init__.py +25 -7
- agent/providers/omni/tools/base.py +29 -0
- agent/providers/omni/tools/bash.py +43 -38
- agent/providers/omni/tools/computer.py +144 -182
- agent/providers/omni/tools/manager.py +25 -45
- agent/providers/omni/types.py +1 -3
- agent/providers/omni/utils.py +224 -145
- agent/providers/openai/__init__.py +6 -0
- agent/providers/openai/api_handler.py +453 -0
- agent/providers/openai/loop.py +440 -0
- agent/providers/openai/response_handler.py +205 -0
- agent/providers/openai/tools/__init__.py +15 -0
- agent/providers/openai/tools/base.py +79 -0
- agent/providers/openai/tools/computer.py +319 -0
- agent/providers/openai/tools/manager.py +106 -0
- agent/providers/openai/types.py +36 -0
- agent/providers/openai/utils.py +98 -0
- cua_agent-0.1.18.dist-info/METADATA +165 -0
- cua_agent-0.1.18.dist-info/RECORD +73 -0
- agent/README.md +0 -63
- agent/providers/anthropic/messages/manager.py +0 -112
- agent/providers/omni/callbacks.py +0 -78
- agent/providers/omni/clients/groq.py +0 -101
- agent/providers/omni/experiment.py +0 -276
- agent/providers/omni/messages.py +0 -171
- agent/providers/omni/tool_manager.py +0 -91
- agent/providers/omni/visualization.py +0 -130
- agent/types/__init__.py +0 -23
- agent/types/base.py +0 -41
- agent/types/messages.py +0 -36
- cua_agent-0.1.6.dist-info/METADATA +0 -120
- cua_agent-0.1.6.dist-info/RECORD +0 -64
- /agent/{types → core}/tools.py +0 -0
- {cua_agent-0.1.6.dist-info → cua_agent-0.1.18.dist-info}/WHEEL +0 -0
- {cua_agent-0.1.6.dist-info → cua_agent-0.1.18.dist-info}/entry_points.txt +0 -0
agent/core/messages.py
CHANGED
|
@@ -1,12 +1,11 @@
|
|
|
1
1
|
"""Message handling utilities for agent."""
|
|
2
2
|
|
|
3
|
-
import base64
|
|
4
|
-
from datetime import datetime
|
|
5
|
-
from io import BytesIO
|
|
6
3
|
import logging
|
|
7
|
-
|
|
8
|
-
from
|
|
4
|
+
import json
|
|
5
|
+
from typing import Any, Dict, List, Optional, Union, Tuple
|
|
9
6
|
from dataclasses import dataclass
|
|
7
|
+
import re
|
|
8
|
+
from ..providers.omni.parser import ParseResult
|
|
10
9
|
|
|
11
10
|
logger = logging.getLogger(__name__)
|
|
12
11
|
|
|
@@ -123,123 +122,278 @@ class BaseMessageManager:
|
|
|
123
122
|
break
|
|
124
123
|
|
|
125
124
|
|
|
126
|
-
|
|
127
|
-
"""
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
"
|
|
152
|
-
"content":
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
"
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
125
|
+
class StandardMessageManager:
    """Manages messages in a standardized OpenAI format across different providers.

    Each stored message is a dict ``{"role": ..., "content": ...}`` where
    ``content`` is either a plain string or a list of content-part dicts.
    The manager can drop older image-bearing user messages according to
    ``config.num_images_to_keep``. It can also convert the history to and
    from the Anthropic Messages format.
    """

    def __init__(self, config: Optional["ImageRetentionConfig"] = None):
        """Initialize message manager.

        Args:
            config: Configuration for image retention. When omitted (or falsy),
                a default ``ImageRetentionConfig()`` is used.
        """
        self.messages: List[Dict[str, Any]] = []
        self.config = config or ImageRetentionConfig()

    def add_user_message(self, content: Union[str, List[Dict[str, Any]]]) -> None:
        """Append a user message.

        Args:
            content: Message content (text or multimodal content parts)
        """
        self.messages.append({"role": "user", "content": content})

    def add_assistant_message(self, content: Union[str, List[Dict[str, Any]]]) -> None:
        """Append an assistant message.

        Args:
            content: Message content (text or multimodal content parts)
        """
        self.messages.append({"role": "assistant", "content": content})

    def add_system_message(self, content: str) -> None:
        """Append a system message.

        Args:
            content: System message content
        """
        self.messages.append({"role": "system", "content": content})

    def get_messages(self) -> List[Dict[str, Any]]:
        """Get all messages in standard format.

        Returns:
            The stored messages, with the image retention policy applied when
            ``config.num_images_to_keep`` is not None. When nothing is dropped,
            the internal list itself (not a copy) is returned.
        """
        if self.config.num_images_to_keep is not None:
            return self._apply_image_retention(self.messages)
        return self.messages

    @staticmethod
    def _is_image_message(msg: Dict[str, Any]) -> bool:
        """Return True for a user message whose list content holds an image part."""
        content = msg.get("content")
        if msg.get("role") != "user" or not isinstance(content, list):
            return False
        return any(
            isinstance(item, dict) and item.get("type") in ("image_url", "image")
            for item in content
        )

    def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Apply the image retention policy to messages.

        Whole user messages that carry at least one image are counted, not
        individual images. The oldest such messages, text included, are
        removed until at most ``num_images_to_keep`` remain. A falsy limit
        (None or 0) disables the policy.

        Args:
            messages: List of messages

        Returns:
            List of messages with image retention applied. This is the input
            list itself when nothing is removed.
        """
        keep = self.config.num_images_to_keep
        if not keep:
            return messages

        image_msg_ids = [id(msg) for msg in messages if self._is_image_message(msg)]
        if len(image_msg_ids) <= keep:
            return messages

        # Compare by identity, not equality. Two screenshot messages with equal
        # content are still distinct turns, and an `==`-based membership test
        # would drop the newer one together with the older duplicate. A set
        # also keeps the filtering linear.
        ids_to_drop = set(image_msg_ids[:-keep])
        return [msg for msg in messages if id(msg) not in ids_to_drop]

    def to_anthropic_format(
        self, messages: List[Dict[str, Any]]
    ) -> Tuple[List[Dict[str, Any]], str]:
        """Convert standard OpenAI format messages to Anthropic format.

        System messages are removed from the list. Their text is concatenated,
        each followed by a newline, into the returned system string. User and
        assistant messages are converted part by part. Messages with any other
        role are dropped.

        A ``tool_result`` part is forwarded unchanged only when it sits in a
        user message and references a ``tool_use`` id from the most recent
        preceding assistant message. Otherwise it is flattened into a text
        block so its information is preserved.

        Args:
            messages: List of messages in OpenAI format

        Returns:
            Tuple containing (anthropic_messages, system_content)
        """
        result: List[Dict[str, Any]] = []
        system_content = ""
        # tool_use ids declared by the most recent assistant message seen so far.
        previous_assistant_tool_use_ids: set = set()

        for i, msg in enumerate(messages):
            role = msg.get("role", "")
            content = msg.get("content", "")

            if role == "system":
                system_content += self._flatten_text(content) + "\n"
                continue

            if role not in ("user", "assistant"):
                continue

            if role == "assistant":
                parts = content if isinstance(content, list) else []
                previous_assistant_tool_use_ids = {
                    item["id"]
                    for item in parts
                    if isinstance(item, dict) and item.get("type") == "tool_use" and "id" in item
                }
                logger.info(
                    "Tool use IDs in assistant message #%s: %s",
                    i,
                    previous_assistant_tool_use_ids,
                )

            anthropic_msg: Dict[str, Any] = {"role": role}
            if isinstance(content, str):
                anthropic_msg["content"] = [{"type": "text", "text": content}]
            elif isinstance(content, list):
                converted = (
                    self._content_item_to_anthropic(item, role, previous_assistant_tool_use_ids)
                    for item in content
                )
                anthropic_msg["content"] = [block for block in converted if block is not None]
            result.append(anthropic_msg)

        return result, system_content

    @staticmethod
    def _flatten_text(content: Any) -> str:
        """Return the text of string content, or newline-joined text parts of list content."""
        if isinstance(content, list):
            return "\n".join(
                part.get("text", "")
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )
        return content if isinstance(content, str) else ""

    def _content_item_to_anthropic(
        self, item: Any, role: str, allowed_tool_use_ids: set
    ) -> Optional[Dict[str, Any]]:
        """Convert one OpenAI-style content part into an Anthropic content block.

        Returns None when the part has no Anthropic equivalent and is dropped.
        """
        if isinstance(item, str):
            return {"type": "text", "text": item}
        if not isinstance(item, dict):
            logger.warning(
                "Dropping content part of unsupported type %s", type(item).__name__
            )
            return None

        item_type = item.get("type", "")
        if item_type == "text":
            return {"type": "text", "text": item.get("text", "")}
        if item_type == "image_url":
            return self._image_url_to_anthropic(item.get("image_url", {}))
        if item_type in ("image", "tool_use"):
            # These parts are already Anthropic blocks; tool_use is always kept.
            return item
        if item_type == "tool_result":
            return self._tool_result_to_anthropic(item, role, allowed_tool_use_ids)
        return None

    @staticmethod
    def _image_url_to_anthropic(image_url: Any) -> Optional[Dict[str, Any]]:
        """Convert an OpenAI ``image_url`` payload into an Anthropic image block."""
        # OpenAI nests the URL as {"url": ...}; a bare string is tolerated too.
        if isinstance(image_url, dict):
            url = image_url.get("url", "") or ""
        elif isinstance(image_url, str):
            url = image_url
        else:
            url = ""

        if not url.startswith("data:"):
            return {"type": "image", "source": {"type": "url", "url": url}}

        match = re.match(r"data:(.+);base64,(.+)", url)
        if match is None:
            logger.warning("Dropping image with a data URL that is not base64-encoded")
            return None
        media_type, data = match.groups()
        return {
            "type": "image",
            "source": {"type": "base64", "media_type": media_type, "data": data},
        }

    @staticmethod
    def _tool_result_to_anthropic(
        item: Dict[str, Any], role: str, allowed_tool_use_ids: set
    ) -> Dict[str, Any]:
        """Keep a matching tool_result as-is, otherwise flatten it into a text block."""
        tool_use_id = item.get("tool_use_id")
        if role == "user" and tool_use_id and tool_use_id in allowed_tool_use_ids:
            logger.info("Including tool_result with tool_use_id: %s", tool_use_id)
            return item

        logger.warning(
            "Converting tool_result to text. Tool use ID %s not found in previous assistant message",
            tool_use_id,
        )
        text = "Tool Result: "
        inner = item.get("content")
        if isinstance(inner, list):
            text += "".join(
                part.get("text", "")
                for part in inner
                if isinstance(part, dict) and part.get("type") == "text"
            )
        elif isinstance(inner, str):
            text += inner
        return {"type": "text", "text": text}

    def from_anthropic_format(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert Anthropic format messages to standard OpenAI format.

        A message whose content is a single text block (or a plain string)
        becomes string content. Any other content becomes a list of OpenAI
        parts: text, ``image_url``, and tool_use/tool_result parts passed
        through unchanged. Other block types and roles other than
        user/assistant are dropped.

        Args:
            messages: List of messages in Anthropic format

        Returns:
            List of messages in OpenAI format
        """
        result: List[Dict[str, Any]] = []

        for msg in messages:
            role = msg.get("role", "")
            if role not in ("user", "assistant"):
                continue

            content = msg.get("content")
            if content is None:
                content = []

            if isinstance(content, str):
                # The Anthropic API also accepts plain-string content.
                result.append({"role": role, "content": content})
                continue

            if (
                len(content) == 1
                and isinstance(content[0], dict)
                and content[0].get("type") == "text"
            ):
                result.append({"role": role, "content": content[0].get("text", "")})
                continue

            openai_content: List[Dict[str, Any]] = []
            for item in content:
                if not isinstance(item, dict):
                    continue
                item_type = item.get("type", "")
                if item_type == "text":
                    openai_content.append({"type": "text", "text": item.get("text", "")})
                elif item_type == "image":
                    source = item.get("source", {})
                    if source.get("type") == "base64":
                        media_type = source.get("media_type", "image/png")
                        data = source.get("data", "")
                        url = f"data:{media_type};base64,{data}"
                    else:
                        url = source.get("url", "")
                    openai_content.append({"type": "image_url", "image_url": {"url": url}})
                elif item_type in ("tool_use", "tool_result"):
                    # Tool blocks are passed through untouched.
                    openai_content.append(item)

            result.append({"role": role, "content": openai_content})

        return result
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Provider-specific configurations and constants."""
|
|
2
|
+
|
|
3
|
+
from ..providers.omni.types import LLMProvider
|
|
4
|
+
|
|
5
|
+
# Default model name for each supported provider.
DEFAULT_MODELS = {
    LLMProvider.OPENAI: "gpt-4o",  # OpenAI default
    LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",  # Anthropic default
}

# Name of the environment variable that holds each provider's API key.
ENV_VARS = {
    LLMProvider.OPENAI: "OPENAI_API_KEY",  # OpenAI key variable
    LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",  # Anthropic key variable
}
|
agent/core/types.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Core type definitions."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List, Optional, TypedDict, Union
|
|
4
|
+
from enum import Enum, auto
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AgentLoop(Enum):
    """Identifies which agent loop implementation to run.

    Member values come from ``auto()`` and so follow declaration order.
    """

    ANTHROPIC = auto()  # Anthropic loop implementation
    OMNI = auto()  # OmniLoop implementation
    OPENAI = auto()  # OpenAI loop implementation
    # New loop types can be appended below.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AgentResponse(TypedDict, total=False):
    """Shape of a response dict produced by an agent.

    Declared with ``total=False``, so every key is optional.
    """

    # Identification and status
    id: str
    object: str
    created_at: int
    status: str
    error: Optional[str]
    incomplete_details: Optional[Any]

    # Request / model settings
    instructions: Optional[Any]
    max_output_tokens: Optional[int]
    model: str

    # Produced output
    output: List[Dict[str, Any]]

    # Generation options
    parallel_tool_calls: bool
    previous_response_id: Optional[str]
    reasoning: Dict[str, str]
    store: bool
    temperature: float
    text: Dict[str, Dict[str, str]]
    tool_choice: str
    tools: List[Dict[str, Union[str, int]]]
    top_p: float
    truncation: str

    # Accounting and extra data
    usage: Dict[str, Any]
    user: Optional[str]
    metadata: Dict[str, Any]
    response: Dict[str, List[Dict[str, Any]]]

    # Additional fields for error responses
    role: str
    content: Union[str, List[Dict[str, Any]]]
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
"""Core visualization utilities for agents."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import base64
|
|
5
|
+
from typing import Dict, Tuple
|
|
6
|
+
from PIL import Image, ImageDraw
|
|
7
|
+
from io import BytesIO
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def visualize_click(x: int, y: int, img_base64: str) -> Image.Image:
    """Mark a click on a screenshot with a red circle and crosshair.

    Args:
        x: X pixel coordinate of the click.
        y: Y pixel coordinate of the click.
        img_base64: Base64-encoded screenshot. A ``data:<mime>;base64,``
            prefix is tolerated and stripped before decoding.

    Returns:
        A copy of the screenshot with the marker drawn on it. If decoding or
        drawing fails, the error is logged and a blank 800x600 white RGB
        image is returned instead.
    """
    try:
        # base64.b64decode discards non-alphabet characters instead of failing,
        # so a data-URL prefix would be decoded as garbage bytes ahead of the image.
        if isinstance(img_base64, str) and img_base64.startswith("data:") and "," in img_base64:
            img_base64 = img_base64.split(",", 1)[1]
        image_data = base64.b64decode(img_base64)

        with Image.open(BytesIO(image_data)) as img:
            # Draw on a copy so the returned image does not depend on the closed source.
            draw_img = img.copy()
        draw = ImageDraw.Draw(draw_img)

        # Circle around the click point.
        radius = 15
        draw.ellipse([(x - radius, y - radius), (x + radius, y + radius)], outline="red", width=3)

        # Crosshair through the click point.
        line_length = 20
        draw.line([(x - line_length, y), (x + line_length, y)], fill="red", width=3)
        draw.line([(x, y - line_length), (x, y + line_length)], fill="red", width=3)

        return draw_img
    except Exception as e:
        # Best-effort: visualization failures must not break the agent loop.
        logger.error("Error visualizing click: %s", e)
        return Image.new("RGB", (800, 600), "white")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
    """Visualize a scroll action by drawing arrows on the screenshot.

    ``min(clicks, 3)`` arrows are drawn, each pointing in the scroll direction.
    Vertical arrows run along the horizontal centre of the image. Horizontal
    arrows ("left"/"right") run along the vertical centre.

    Args:
        direction: Scroll direction: 'up', 'down', 'left' or 'right'. Any
            other value is drawn as 'up'.
        clicks: Number of scroll clicks.
        img_base64: Base64-encoded screenshot. A ``data:<mime>;base64,``
            prefix is tolerated and stripped before decoding.

    Returns:
        A copy of the screenshot with arrows drawn on it. If decoding or
        drawing fails, the error is logged and a blank 800x600 white RGB
        image is returned instead.
    """
    try:
        # base64.b64decode discards non-alphabet characters instead of failing,
        # so a data-URL prefix would be decoded as garbage bytes ahead of the image.
        if isinstance(img_base64, str) and img_base64.startswith("data:") and "," in img_base64:
            img_base64 = img_base64.split(",", 1)[1]
        image_data = base64.b64decode(img_base64)

        with Image.open(BytesIO(image_data)) as img:
            draw_img = img.copy()
        draw = ImageDraw.Draw(draw_img)

        width, height = draw_img.size
        center_x = width // 2
        center_y = height // 2

        arrow_width = 30
        arrowhead_size = 20
        num_arrows = min(clicks, 3)  # Don't draw too many arrows

        horizontal = direction in ("left", "right")
        if horizontal:
            arrow_length = min(100, width // 4)
            forward = direction == "right"
            extent = width
        else:
            arrow_length = min(100, height // 4)
            forward = direction == "down"
            extent = height

        # Forward arrows (down/right) start a third of the way in; backward
        # arrows (up/left) start two thirds in, so both travel towards the middle.
        start = extent // 3 if forward else extent * 2 // 3
        sign = 1 if forward else -1

        for i in range(num_arrows):
            tail = start + (i * arrow_length * sign * 0.7)
            tip = tail + arrow_length * sign
            # The arrowhead's open end sits behind the tip along the shaft.
            head_base = tip - arrowhead_size * sign
            if horizontal:
                shaft = [(tail, center_y), (tip, center_y)]
                head = [
                    (head_base, center_y - arrow_width // 2),
                    (tip, center_y),
                    (head_base, center_y + arrow_width // 2),
                ]
            else:
                shaft = [(center_x, tail), (center_x, tip)]
                head = [
                    (center_x - arrow_width // 2, head_base),
                    (center_x, tip),
                    (center_x + arrow_width // 2, head_base),
                ]
            draw.line(shaft, fill="red", width=5)
            draw.line(head, fill="red", width=5)

        return draw_img
    except Exception as e:
        # Best-effort: visualization failures must not break the agent loop.
        logger.error("Error visualizing scroll: %s", e)
        return Image.new("RGB", (800, 600), "white")
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def calculate_element_center(bbox: Dict[str, float], width: int, height: int) -> Tuple[int, int]:
    """Return the pixel centre of a normalized bounding box.

    Args:
        bbox: Mapping with ``x1``, ``y1``, ``x2``, ``y2`` coordinates in the
            0-1 range.
        width: Screen width in pixels.
        height: Screen height in pixels.

    Returns:
        ``(x, y)`` pixel coordinates, truncated toward zero by ``int``.
    """
    mid_x = (bbox["x1"] + bbox["x2"]) / 2
    mid_y = (bbox["y1"] + bbox["y2"]) / 2
    return int(mid_x * width), int(mid_y * height)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class VisualizationHelper:
    """Helper class for visualizing agent actions.

    Overlays are only produced when the agent has a truthy ``save_trajectory``
    and a truthy ``experiment_manager``. Saving is delegated to
    ``experiment_manager.save_action_visualization``.
    """

    def __init__(self, agent):
        """Initialize visualization helper.

        Args:
            agent: Reference to the agent that will use this helper
        """
        self.agent = agent

    def _should_save(self) -> bool:
        """Return True when the agent wants trajectories saved and has an experiment manager.

        Missing attributes are treated as falsy rather than raising AttributeError.
        """
        return bool(getattr(self.agent, "save_trajectory", False)) and bool(
            getattr(self.agent, "experiment_manager", None)
        )

    def visualize_action(self, x: int, y: int, img_base64: str) -> None:
        """Draw a click marker on the screenshot and save it, if saving is enabled."""
        if not self._should_save():
            return

        try:
            # Module-level helper from this file.
            img = visualize_click(x, y, img_base64)
            self.agent.experiment_manager.save_action_visualization(img, "click", f"x{x}_y{y}")
        except Exception as e:
            # Best-effort: visualization must not interrupt the agent.
            logger.error("Error visualizing action: %s", e)

    def visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None:
        """Draw scroll arrows on the screenshot and save it, if saving is enabled."""
        if not self._should_save():
            return

        try:
            # Inside a method body this name resolves to the module-level
            # visualize_scroll function, not to this method.
            img = visualize_scroll(direction, clicks, img_base64)
            self.agent.experiment_manager.save_action_visualization(
                img, "scroll", f"{direction}_{clicks}"
            )
        except Exception as e:
            # Best-effort: visualization must not interrupt the agent.
            logger.error("Error visualizing scroll: %s", e)

    def save_action_visualization(
        self, img: "Image.Image", action_name: str, details: str = ""
    ) -> str:
        """Save a visualization of an action.

        Returns:
            Whatever the experiment manager's ``save_action_visualization``
            returns, or ``""`` when the agent has no experiment manager.
        """
        experiment_manager = getattr(self.agent, "experiment_manager", None)
        if experiment_manager:
            return experiment_manager.save_action_visualization(img, action_name, details)
        return ""
|