cua-agent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/README.md +63 -0
- agent/__init__.py +10 -0
- agent/core/README.md +101 -0
- agent/core/__init__.py +34 -0
- agent/core/agent.py +284 -0
- agent/core/base_agent.py +164 -0
- agent/core/callbacks.py +147 -0
- agent/core/computer_agent.py +69 -0
- agent/core/experiment.py +222 -0
- agent/core/factory.py +102 -0
- agent/core/loop.py +244 -0
- agent/core/messages.py +230 -0
- agent/core/tools/__init__.py +21 -0
- agent/core/tools/base.py +74 -0
- agent/core/tools/bash.py +52 -0
- agent/core/tools/collection.py +46 -0
- agent/core/tools/computer.py +113 -0
- agent/core/tools/edit.py +67 -0
- agent/core/tools/manager.py +56 -0
- agent/providers/__init__.py +4 -0
- agent/providers/anthropic/__init__.py +6 -0
- agent/providers/anthropic/api/client.py +222 -0
- agent/providers/anthropic/api/logging.py +150 -0
- agent/providers/anthropic/callbacks/manager.py +55 -0
- agent/providers/anthropic/loop.py +521 -0
- agent/providers/anthropic/messages/manager.py +110 -0
- agent/providers/anthropic/prompts.py +20 -0
- agent/providers/anthropic/tools/__init__.py +33 -0
- agent/providers/anthropic/tools/base.py +88 -0
- agent/providers/anthropic/tools/bash.py +163 -0
- agent/providers/anthropic/tools/collection.py +34 -0
- agent/providers/anthropic/tools/computer.py +550 -0
- agent/providers/anthropic/tools/edit.py +326 -0
- agent/providers/anthropic/tools/manager.py +54 -0
- agent/providers/anthropic/tools/run.py +42 -0
- agent/providers/anthropic/types.py +16 -0
- agent/providers/omni/__init__.py +27 -0
- agent/providers/omni/callbacks.py +78 -0
- agent/providers/omni/clients/anthropic.py +99 -0
- agent/providers/omni/clients/base.py +44 -0
- agent/providers/omni/clients/groq.py +101 -0
- agent/providers/omni/clients/openai.py +159 -0
- agent/providers/omni/clients/utils.py +25 -0
- agent/providers/omni/experiment.py +273 -0
- agent/providers/omni/image_utils.py +106 -0
- agent/providers/omni/loop.py +961 -0
- agent/providers/omni/messages.py +168 -0
- agent/providers/omni/parser.py +252 -0
- agent/providers/omni/prompts.py +78 -0
- agent/providers/omni/tool_manager.py +91 -0
- agent/providers/omni/tools/__init__.py +13 -0
- agent/providers/omni/tools/bash.py +69 -0
- agent/providers/omni/tools/computer.py +216 -0
- agent/providers/omni/tools/manager.py +83 -0
- agent/providers/omni/types.py +30 -0
- agent/providers/omni/utils.py +155 -0
- agent/providers/omni/visualization.py +130 -0
- agent/types/__init__.py +26 -0
- agent/types/base.py +52 -0
- agent/types/messages.py +36 -0
- agent/types/tools.py +32 -0
- cua_agent-0.1.0.dist-info/METADATA +44 -0
- cua_agent-0.1.0.dist-info/RECORD +65 -0
- cua_agent-0.1.0.dist-info/WHEEL +4 -0
- cua_agent-0.1.0.dist-info/entry_points.txt +4 -0
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
"""Groq client implementation."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Dict, List, Optional, Any, Tuple
|
|
6
|
+
|
|
7
|
+
from groq import Groq
|
|
8
|
+
import re
|
|
9
|
+
from .utils import is_image_path
|
|
10
|
+
from .base import BaseOmniClient
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class GroqClient(BaseOmniClient):
    """Client for making Groq API calls.

    Every message is flattened to a single plain-text user turn before sending;
    image parts (image paths, ``image``/``image_url`` blocks) are dropped.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "deepseek-r1-distill-llama-70b",
        max_tokens: int = 4096,
        temperature: float = 0.6,
    ):
        """Initialize Groq client.

        Args:
            api_key: Groq API key (if not provided, read from the ``GROQ_API_KEY``
                environment variable)
            model: Model name to use
            max_tokens: Maximum tokens to generate
            temperature: Temperature for sampling

        Raises:
            ValueError: If no API key is given and ``GROQ_API_KEY`` is unset.
        """
        super().__init__(api_key=api_key, model=model)
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError("No Groq API key provided")

        self.max_tokens = max_tokens
        self.temperature = temperature
        self.client = Groq(api_key=self.api_key)
        self.model: str = model  # Add explicit type annotation

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten one message's ``content`` into a single space-joined text string.

        Args:
            content: A plain string, a list of parts (strings or dicts), or None.

        Returns:
            The concatenated text. Image paths and ``image``/``image_url`` dict
            parts are skipped, ``text`` dict parts contribute their ``text`` value,
            and anything else is stringified. Returns "" when there is no text.
        """
        if content is None:
            return ""
        # A bare string is one text part; iterating it directly would split it
        # into characters that then get joined with spaces.
        parts = [content] if isinstance(content, str) else content

        texts: List[str] = []
        for part in parts:
            if isinstance(part, str):
                if not is_image_path(part):  # Skip image paths
                    texts.append(part)
            elif isinstance(part, dict):
                part_type = part.get("type")
                if part_type == "text" and "text" in part:
                    texts.append(str(part["text"]))
                elif part_type in ("image", "image_url"):
                    # Image blocks carry base64 payloads; never inline them as text
                    continue
                else:
                    texts.append(str(part))
            else:
                texts.append(str(part))
        return " ".join(texts)

    @staticmethod
    def _strip_reasoning(response: Optional[str]) -> str:
        """Remove ``<think>`` reasoning and ``<output>`` tags from a model reply.

        Args:
            response: Raw completion text; None is treated as "".

        Returns:
            The text after the last ``</think>`` (one directly following newline
            is dropped) with ``<output>``/``</output>`` tags removed.
        """
        text = response or ""
        if "</think>" in text:
            text = text.rpartition("</think>")[2]
            if text.startswith("\n"):
                text = text[1:]
        return text.replace("<output>", "").replace("</output>", "")

    def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
    ) -> tuple[str, int]:
        """Run interleaved chat completion.

        Args:
            messages: List of message dicts (or strings); a single string is also accepted
            system: System prompt, sent as the first user turn
            max_tokens: Optional max tokens override (falsy values fall back to
                ``self.max_tokens``)

        Returns:
            Tuple of (response text with reasoning stripped, total token usage;
            0 when the API reports no usage)

        Raises:
            Exception: Any error from the Groq SDK or response handling is logged
                and re-raised.
        """
        # Avoid using system messages for R1: send the system prompt as a user turn
        final_messages: List[Dict[str, Any]] = [{"role": "user", "content": system}]

        if isinstance(messages, list):
            for item in messages:
                if isinstance(item, dict):
                    text = self._content_to_text(item.get("content"))
                    if text:  # Only add if there's text content
                        final_messages.append({"role": "user", "content": text})
                else:  # str
                    final_messages.append({"role": "user", "content": item})
        elif isinstance(messages, str):
            final_messages.append({"role": "user", "content": messages})

        try:
            completion = self.client.chat.completions.create(  # type: ignore
                model=self.model,
                messages=final_messages,  # type: ignore
                temperature=self.temperature,
                max_tokens=max_tokens or self.max_tokens,
                top_p=0.95,
                stream=False,
            )

            final_answer = self._strip_reasoning(completion.choices[0].message.content)
            # usage may be absent on some responses; report 0 instead of crashing
            usage = getattr(completion, "usage", None)
            token_usage = usage.total_tokens if usage is not None else 0

            return final_answer, token_usage

        except Exception as e:
            logger.error(f"Error in Groq API call: {e}")
            raise
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
"""OpenAI client implementation."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Dict, List, Optional, Any
|
|
6
|
+
import aiohttp
|
|
7
|
+
import base64
|
|
8
|
+
import re
|
|
9
|
+
import json
|
|
10
|
+
import ssl
|
|
11
|
+
import certifi
|
|
12
|
+
from datetime import datetime
|
|
13
|
+
from .base import BaseOmniClient
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
# OpenAI specific client for the OmniLoop
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class OpenAIClient(BaseOmniClient):
    """OpenAI vision API client implementation.

    Sends chat-completion requests over aiohttp to ``{provider_base_url}/chat/completions``.
    """

    # Matches a base64 image data URL, capturing (mime type, base64 payload).
    # The payload is limited to the base64 alphabet so trailing text is not captured.
    _DATA_URL_PATTERN = re.compile(r"data:(image/[^;]+);base64,([A-Za-z0-9+/=]+)")

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "gpt-4o",
        provider_base_url: str = "https://api.openai.com/v1",
        max_tokens: int = 4096,
        temperature: float = 0.0,
    ):
        """Initialize the OpenAI client.

        Args:
            api_key: OpenAI API key (falls back to the ``OPENAI_API_KEY`` env variable)
            model: Model to use
            provider_base_url: API endpoint
            max_tokens: Maximum tokens to generate
            temperature: Generation temperature (not sent for o1/o3-mini models)

        Raises:
            ValueError: If no API key is given and ``OPENAI_API_KEY`` is unset.
        """
        super().__init__(api_key=api_key, model=model)
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("No OpenAI API key provided")

        self.model = model
        self.provider_base_url = provider_base_url
        self.max_tokens = max_tokens
        self.temperature = temperature

    def _extract_base64_image(self, text: str) -> Optional[str]:
        """Extract the base64 payload of the first ``data:image/...;base64,`` URL in *text*.

        Works for bare data URLs as well as ones embedded in an HTML img tag.
        Returns None if no data URL is found.
        """
        match = self._DATA_URL_PATTERN.search(text)
        return match.group(2) if match else None

    def _text_or_image_message(self, role: str, text: str) -> Dict[str, Any]:
        """Build one OpenAI message from a string: an image part if it has a data URL, else text.

        The image URL keeps the mime type found in *text*.
        """
        match = self._DATA_URL_PATTERN.search(text)
        if match:
            mime_type, base64_img = match.groups()
            return {
                "role": role,
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime_type};base64,{base64_img}"},
                    }
                ],
            }
        return {"role": role, "content": [{"type": "text", "text": text}]}

    def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Create a loggable version of messages with image data replaced by a placeholder.

        Both ``image`` and ``image_url`` parts are redacted; the input is not mutated.
        """
        loggable_messages = []
        for msg in messages:
            content = msg.get("content")
            if not isinstance(content, list):
                loggable_messages.append(msg)
                continue
            new_content = []
            for part in content:
                if isinstance(part, dict) and part.get("type") in ("image", "image_url"):
                    new_content.append(
                        {"type": part["type"], "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
                    )
                else:
                    new_content.append(part)
            loggable_messages.append({"role": msg.get("role"), "content": new_content})
        return loggable_messages

    @staticmethod
    def _error_message(response_json: Any, body: str) -> str:
        """Pick the most useful error text from a (possibly non-JSON) error response."""
        if isinstance(response_json, dict):
            error = response_json.get("error")
            if isinstance(error, dict) and error.get("message"):
                return str(error["message"])
            if isinstance(error, str) and error:
                return error
            return str(response_json)
        return body

    async def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """Run interleaved chat completion.

        Args:
            messages: List of message dicts (or strings). Dicts whose content is a
                list are passed through unchanged; string content is converted to
                an image part when it holds a base64 data URL, otherwise to text.
            system: System prompt
            max_tokens: Optional max tokens override

        Returns:
            Response dict (parsed JSON body of a 200 response)

        Raises:
            Exception: On a non-200 status, a non-JSON body, or a transport error.
        """
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

        final_messages: List[Dict[str, Any]] = [{"role": "system", "content": system}]

        for item in messages:
            if isinstance(item, dict):
                if isinstance(item["content"], list):
                    # Content is already in the correct format
                    final_messages.append(item)
                else:
                    final_messages.append(
                        self._text_or_image_message(item["role"], item["content"])
                    )
            else:
                final_messages.append(self._text_or_image_message("user", item))

        output_tokens = max_tokens or self.max_tokens
        payload: Dict[str, Any] = {"model": self.model, "messages": final_messages}

        if "o1" in self.model or "o3-mini" in self.model:
            # Reasoning models take max_completion_tokens and reject a custom temperature
            payload["reasoning_effort"] = "low"
            payload["max_completion_tokens"] = output_tokens
        else:
            payload["temperature"] = self.temperature
            payload["max_tokens"] = output_tokens

        url = f"{self.provider_base_url}/chat/completions"
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers, json=payload) as response:
                    # Read raw text first: gateway/proxy errors are often not JSON and
                    # response.json() would raise before the status could be reported.
                    body = await response.text()
                    status = response.status
        except Exception as e:
            logger.error(f"Error in OpenAI API call: {str(e)}")
            raise

        try:
            response_json = json.loads(body)
        except ValueError:
            response_json = None

        if status != 200:
            error_msg = self._error_message(response_json, body)
            logger.error(f"Error in OpenAI API call: {error_msg}")
            raise Exception(f"OpenAI API error: {error_msg}")

        if not isinstance(response_json, dict):
            error_msg = f"unexpected non-JSON response body: {body[:200]}"
            logger.error(f"Error in OpenAI API call: {error_msg}")
            raise Exception(f"OpenAI API error: {error_msg}")

        return response_json
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
|
|
3
|
+
def is_image_path(text: str) -> bool:
    """Check if a text string is an image file path.

    The check is purely lexical and case-insensitive: it compares the string's
    ending against a fixed set of image extensions and never touches the
    filesystem.

    Args:
        text: Text string to check

    Returns:
        True if text ends with an image extension (any letter case), False otherwise
    """
    image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif", ".webp")
    # Lower-case first so "shot.PNG" or "IMG.JPG" are recognized too
    return text.lower().endswith(image_extensions)
|
|
14
|
+
|
|
15
|
+
def encode_image(image_path: str) -> str:
    """Read a file in binary mode and return its contents base64-encoded.

    Args:
        image_path: Path to the image file on disk

    Returns:
        The file's bytes as a base64 text string (UTF-8 decoded)
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")
|
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
"""Experiment management for the Cua provider."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import logging
|
|
5
|
+
import copy
|
|
6
|
+
import base64
|
|
7
|
+
from io import BytesIO
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from typing import Any, Dict, List, Optional
|
|
10
|
+
from PIL import Image
|
|
11
|
+
import json
|
|
12
|
+
import time
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ExperimentManager:
    """Manages experiment directories and logging for the agent.

    With a ``base_dir``, that directory is used as the run directory and numbered
    turn subdirectories (``turn_001``, ``turn_002``, ...) are created inside it.
    Screenshots, action visualizations and API-call logs are written into the
    current turn directory. Without a ``base_dir`` no directories are created,
    and the save/log methods return without writing anything.
    """

    def __init__(
        self,
        base_dir: Optional[str] = None,
        only_n_most_recent_images: Optional[int] = None,
    ):
        """Initialize the experiment manager.

        Args:
            base_dir: Base directory for saving experiment data; when falsy nothing
                is written to disk
            only_n_most_recent_images: Maximum number of recent screenshots to include
                in API requests (stored here; this class does not read it)
        """
        self.base_dir = base_dir
        self.only_n_most_recent_images = only_n_most_recent_images
        self.run_dir: Optional[str] = None
        self.current_turn_dir: Optional[str] = None
        self.turn_count = 0
        self.screenshot_count = 0
        # Paths of every screenshot and visualization written by this manager
        self.screenshot_paths: List[str] = []

        # Set up experiment directories if base_dir is provided
        if self.base_dir:
            self.setup_experiment_dirs()

    def setup_experiment_dirs(self) -> None:
        """Create the run directory (``base_dir`` itself) and the first turn directory."""
        if not self.base_dir:
            return

        # Create base experiments directory if it doesn't exist
        os.makedirs(self.base_dir, exist_ok=True)

        # Use the base_dir directly as the run_dir
        self.run_dir = self.base_dir
        logger.info(f"Using directory for experiment: {self.run_dir}")

        # Create first turn directory
        self.create_turn_dir()

    def create_turn_dir(self) -> None:
        """Increment the turn counter and create ``turn_NNN`` under the run directory."""
        if not self.run_dir:
            return

        self.turn_count += 1
        self.current_turn_dir = os.path.join(self.run_dir, f"turn_{self.turn_count:03d}")
        os.makedirs(self.current_turn_dir, exist_ok=True)
        logger.info(f"Created turn directory: {self.current_turn_dir}")

    @staticmethod
    def _safe_filename_part(text: str) -> str:
        """Make a caller-supplied label safe to embed in a file name.

        Letters, digits, ``.``, ``_`` and ``-`` are kept; every other character
        (notably path separators and spaces) becomes ``_``, so a label can never
        point outside the turn directory or make ``open`` fail.
        """
        return "".join(ch if ch.isalnum() or ch in "._-" else "_" for ch in text)

    @staticmethod
    def _redact_image_part(part: Any) -> Any:
        """Replace the inline base64 payload of one content part with a length placeholder.

        Mutates and returns *part*. It must be a copy the caller owns
        (``sanitize_log_data`` always passes a freshly rebuilt dict).
        """
        if not isinstance(part, dict):
            return part
        if part.get("type") == "image" and isinstance(part.get("source"), dict):
            # Anthropic format: {"type": "image", "source": {"data": "<base64>"}}
            source = part["source"]
            if isinstance(source.get("data"), str):
                source["data"] = f"[BASE64_IMAGE_DATA_LENGTH_{len(source['data'])}]"
        elif part.get("type") == "image_url" and isinstance(part.get("image_url"), dict):
            # OpenAI format: {"type": "image_url", "image_url": {"url": "data:..."}}
            url_dict = part["image_url"]
            url = url_dict.get("url")
            if isinstance(url, str) and url.startswith("data:"):
                url_dict["url"] = f"[BASE64_IMAGE_URL_LENGTH_{len(url)}]"
        return part

    def sanitize_log_data(self, data: Any) -> Any:
        """Sanitize data for logging by removing large base64 strings.

        Dicts and lists are rebuilt recursively, so the input is never mutated.
        Other values (strings, numbers, arbitrary objects) are returned as-is
        rather than deep-copied, so objects that cannot be copied do not make
        logging fail. Image payloads are redacted only in items of a ``content``
        list, in Anthropic (``image``/``source.data``) and OpenAI
        (``image_url``/``url`` data URL) formats.

        Args:
            data: Data to sanitize (dict, list, or primitive)

        Returns:
            Sanitized copy of the data
        """
        if isinstance(data, dict):
            sanitized: Dict[Any, Any] = {}
            for key, value in data.items():
                if key == "content" and isinstance(value, list):
                    sanitized[key] = [
                        self._redact_image_part(self.sanitize_log_data(part)) for part in value
                    ]
                else:
                    sanitized[key] = self.sanitize_log_data(value)
            return sanitized
        if isinstance(data, list):
            return [self.sanitize_log_data(item) for item in data]
        return data

    def save_debug_image(self, image_data: str, filename: str) -> None:
        """Save a debug image to the experiment directory.

        Currently a no-op: debug images are no longer written.

        Args:
            image_data: Base64 encoded image data
            filename: Filename to save the image as
        """
        # Since we no longer want to use the images/ folder, we'll skip this functionality
        return

    def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
        """Save a screenshot to the current turn directory.

        Args:
            img_base64: Base64 encoded screenshot (a ``data:`` URL prefix is stripped)
            action_type: Type of action that triggered the screenshot; sanitized
                before being used in the file name

        Returns:
            Path of the written file, or None if there is no turn directory or
            saving failed (the error is logged)
        """
        if not self.current_turn_dir:
            return None

        try:
            # Increment screenshot counter
            self.screenshot_count += 1

            # Create a descriptive filename
            timestamp = int(time.time() * 1000)
            action_suffix = f"_{self._safe_filename_part(action_type)}" if action_type else ""
            filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png"

            # Save directly to the turn directory (no screenshots subdirectory)
            filepath = os.path.join(self.current_turn_dir, filename)

            # A data URL header would otherwise be decoded as (garbage) image bytes
            if img_base64.startswith("data:"):
                img_base64 = img_base64.partition(",")[2]
            img_data = base64.b64decode(img_base64)
            with open(filepath, "wb") as f:
                f.write(img_data)

            # Keep track of the file path for reference
            self.screenshot_paths.append(filepath)

            return filepath
        except Exception as e:
            logger.error(f"Error saving screenshot: {str(e)}")
            return None

    def should_save_debug_image(self) -> bool:
        """Determine if debug images should be saved.

        Returns:
            Always False; debug images are no longer saved
        """
        # We no longer need to save debug images, so always return False
        return False

    def save_action_visualization(
        self, img: "Image.Image", action_name: str, details: str = ""
    ) -> str:
        """Save a visualization of an action to the current turn directory.

        Args:
            img: PIL image to save
            action_name: Name of the action (sanitized for the file name)
            details: Additional details about the action (sanitized for the file name)

        Returns:
            Path to the saved image, or "" if there is no turn directory or saving
            failed (the error is logged)
        """
        if not self.current_turn_dir:
            return ""

        try:
            # Create a descriptive filename
            timestamp = int(time.time() * 1000)
            safe_action = self._safe_filename_part(action_name)
            details_suffix = f"_{self._safe_filename_part(details)}" if details else ""
            filename = f"vis_{safe_action}{details_suffix}_{timestamp}.png"

            # Save directly to the turn directory (no visualizations subdirectory)
            filepath = os.path.join(self.current_turn_dir, filename)

            # Save the image
            img.save(filepath)

            # Keep track of the file path for cleanup
            self.screenshot_paths.append(filepath)

            return filepath
        except Exception as e:
            logger.error(f"Error saving action visualization: {str(e)}")
            return ""

    def extract_and_save_images(self, data: Any, prefix: str) -> None:
        """Extract and save images from response data.

        Currently a no-op: extracted images are no longer saved separately.

        Args:
            data: Response data to extract images from
            prefix: Prefix for saved image filenames
        """
        # Since we no longer want to save extracted images separately,
        # we'll skip this functionality entirely
        return

    def log_api_call(
        self,
        call_type: str,
        request: Any,
        provider: str,
        model: str,
        response: Any = None,
        error: Optional[Exception] = None,
    ) -> None:
        """Log API call details to a JSON file in the current turn directory.

        Base64 image payloads are replaced via ``sanitize_log_data``. Failures are
        logged and swallowed so logging never breaks the agent loop.

        Args:
            call_type: Type of API call (e.g., 'request', 'response', 'error')
            request: The API request data
            provider: The AI provider used
            model: The AI model used
            response: Optional API response data
            error: Optional error information
        """
        if not self.current_turn_dir:
            return

        try:
            now = datetime.now()
            timestamp = now.strftime("%Y%m%d_%H%M%S")
            # Microseconds in the file name keep calls of the same type made within
            # the same second from overwriting each other's log file.
            file_stamp = now.strftime("%Y%m%d_%H%M%S_%f")
            filename = f"api_call_{file_stamp}_{self._safe_filename_part(call_type)}.json"
            filepath = os.path.join(self.current_turn_dir, filename)

            # Sanitize data to remove large base64 strings
            sanitized_request = self.sanitize_log_data(request)
            sanitized_response = self.sanitize_log_data(response) if response is not None else None

            # Prepare log data
            log_data = {
                "timestamp": timestamp,
                "provider": provider,
                "model": model,
                "type": call_type,
                "request": sanitized_request,
            }

            if sanitized_response is not None:
                log_data["response"] = sanitized_response
            if error is not None:
                log_data["error"] = str(error)

            # Write to file; default=str stringifies non-JSON objects (e.g. SDK responses)
            with open(filepath, "w") as f:
                json.dump(log_data, f, indent=2, default=str)

            logger.info(f"Logged API {call_type} to {filepath}")

        except Exception as e:
            logger.error(f"Error logging API call: {str(e)}")
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""Image processing utilities for the Cua provider."""
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import logging
|
|
5
|
+
import re
|
|
6
|
+
from io import BytesIO
|
|
7
|
+
from typing import Optional, Tuple
|
|
8
|
+
from PIL import Image
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def decode_base64_image(img_base64: str) -> Optional[Image.Image]:
    """Turn base64 image data into a PIL Image.

    Args:
        img_base64: Base64 image data; a leading ``data:image...`` URL header
            (everything up to the first comma) is removed first

    Returns:
        The opened PIL Image, or None if any step fails (the error is logged)
    """
    try:
        payload = img_base64
        if payload.startswith("data:image"):
            # Keep only the segment after the "data:image/...;base64," header
            payload = payload.split(",")[1]

        raw_bytes = base64.b64decode(payload)
        buffer = BytesIO(raw_bytes)
        return Image.open(buffer)
    except Exception as e:
        logger.error(f"Error decoding base64 image: {str(e)}")
        return None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def encode_image_base64(img: Image.Image, format: str = "PNG") -> str:
    """Serialize a PIL Image in the given format and base64-encode the bytes.

    Args:
        img: PIL Image to encode
        format: Image format passed to ``Image.save`` (PNG, JPEG, etc.)

    Returns:
        Base64 text of the encoded image, or "" on failure (the error is logged)
    """
    try:
        with BytesIO() as buffer:
            img.save(buffer, format=format)
            encoded_bytes = base64.b64encode(buffer.getvalue())
        return encoded_bytes.decode("utf-8")
    except Exception as e:
        logger.error(f"Error encoding image to base64: {str(e)}")
        return ""
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def clean_base64_data(img_base64: str) -> str:
    """Clean base64 image data by removing a data URL prefix.

    Args:
        img_base64: Base64 encoded image, may include a ``data:image...`` prefix

    Returns:
        For a ``data:image`` string, the segment between the first and second
        comma (the payload of a well-formed data URL). Any other input, including
        a malformed ``data:image`` string with no comma, is returned unchanged
        instead of raising IndexError.
    """
    if img_base64.startswith("data:image"):
        segments = img_base64.split(",")
        if len(segments) > 1:
            return segments[1]
    return img_base64
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def extract_base64_from_text(text: str) -> Optional[str]:
    """Find base64 image data inside arbitrary text.

    Two patterns are tried in order: an explicit ``data:image/...;base64,`` URL,
    then (as a rough heuristic) any run of 100 or more base64-alphabet characters.

    Args:
        text: Text potentially containing base64 image data

    Returns:
        The first matching base64 string, or None if neither pattern matches
    """
    candidate_patterns = (
        r"data:image/[^;]+;base64,([a-zA-Z0-9+/=]+)",
        r"([a-zA-Z0-9+/=]{100,})",
    )
    for pattern in candidate_patterns:
        found = re.search(pattern, text)
        if found is not None:
            return found.group(1)
    return None
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def get_image_dimensions(img_base64: str) -> Tuple[int, int]:
    """Return the size of a base64-encoded image.

    Args:
        img_base64: Base64 encoded image (a data URL prefix is accepted)

    Returns:
        ``(width, height)`` from the decoded image, or ``(0, 0)`` when decoding fails
    """
    decoded = decode_base64_image(img_base64)
    return decoded.size if decoded else (0, 0)
|