cua-agent 0.1.0 (cua_agent-0.1.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic; see the registry listing for more details.
- agent/README.md +63 -0
- agent/__init__.py +10 -0
- agent/core/README.md +101 -0
- agent/core/__init__.py +34 -0
- agent/core/agent.py +284 -0
- agent/core/base_agent.py +164 -0
- agent/core/callbacks.py +147 -0
- agent/core/computer_agent.py +69 -0
- agent/core/experiment.py +222 -0
- agent/core/factory.py +102 -0
- agent/core/loop.py +244 -0
- agent/core/messages.py +230 -0
- agent/core/tools/__init__.py +21 -0
- agent/core/tools/base.py +74 -0
- agent/core/tools/bash.py +52 -0
- agent/core/tools/collection.py +46 -0
- agent/core/tools/computer.py +113 -0
- agent/core/tools/edit.py +67 -0
- agent/core/tools/manager.py +56 -0
- agent/providers/__init__.py +4 -0
- agent/providers/anthropic/__init__.py +6 -0
- agent/providers/anthropic/api/client.py +222 -0
- agent/providers/anthropic/api/logging.py +150 -0
- agent/providers/anthropic/callbacks/manager.py +55 -0
- agent/providers/anthropic/loop.py +521 -0
- agent/providers/anthropic/messages/manager.py +110 -0
- agent/providers/anthropic/prompts.py +20 -0
- agent/providers/anthropic/tools/__init__.py +33 -0
- agent/providers/anthropic/tools/base.py +88 -0
- agent/providers/anthropic/tools/bash.py +163 -0
- agent/providers/anthropic/tools/collection.py +34 -0
- agent/providers/anthropic/tools/computer.py +550 -0
- agent/providers/anthropic/tools/edit.py +326 -0
- agent/providers/anthropic/tools/manager.py +54 -0
- agent/providers/anthropic/tools/run.py +42 -0
- agent/providers/anthropic/types.py +16 -0
- agent/providers/omni/__init__.py +27 -0
- agent/providers/omni/callbacks.py +78 -0
- agent/providers/omni/clients/anthropic.py +99 -0
- agent/providers/omni/clients/base.py +44 -0
- agent/providers/omni/clients/groq.py +101 -0
- agent/providers/omni/clients/openai.py +159 -0
- agent/providers/omni/clients/utils.py +25 -0
- agent/providers/omni/experiment.py +273 -0
- agent/providers/omni/image_utils.py +106 -0
- agent/providers/omni/loop.py +961 -0
- agent/providers/omni/messages.py +168 -0
- agent/providers/omni/parser.py +252 -0
- agent/providers/omni/prompts.py +78 -0
- agent/providers/omni/tool_manager.py +91 -0
- agent/providers/omni/tools/__init__.py +13 -0
- agent/providers/omni/tools/bash.py +69 -0
- agent/providers/omni/tools/computer.py +216 -0
- agent/providers/omni/tools/manager.py +83 -0
- agent/providers/omni/types.py +30 -0
- agent/providers/omni/utils.py +155 -0
- agent/providers/omni/visualization.py +130 -0
- agent/types/__init__.py +26 -0
- agent/types/base.py +52 -0
- agent/types/messages.py +36 -0
- agent/types/tools.py +32 -0
- cua_agent-0.1.0.dist-info/METADATA +44 -0
- cua_agent-0.1.0.dist-info/RECORD +65 -0
- cua_agent-0.1.0.dist-info/WHEEL +4 -0
- cua_agent-0.1.0.dist-info/entry_points.txt +4 -0
agent/providers/omni/tools/computer.py
ADDED

@@ -0,0 +1,216 @@
"""Provider-agnostic implementation of the ComputerTool."""

import logging
import base64
import io
from typing import Any, Dict

from PIL import Image
from computer.computer import Computer

from ....core.tools.computer import BaseComputerTool
from ....core.tools import ToolResult, ToolError


class OmniComputerTool(BaseComputerTool):
    """A provider-agnostic implementation of the computer tool."""

    name = "computer"
    logger = logging.getLogger(__name__)

    def __init__(self, computer: Computer):
        """Initialize the ComputerTool.

        Args:
            computer: Computer instance for screen interactions
        """
        super().__init__(computer)
        # Initialize dimensions to None, will be set in initialize_dimensions
        self.width = None
        self.height = None
        self.display_num = None

    def to_params(self) -> Dict[str, Any]:
        """Convert tool to provider-agnostic parameters.

        Returns:
            Dictionary with tool parameters
        """
        return {
            "name": self.name,
            "description": "A tool that allows the agent to interact with the screen, keyboard, and mouse",
            "parameters": {
                "action": {
                    "type": "string",
                    "enum": [
                        "key",
                        "type",
                        "mouse_move",
                        "left_click",
                        "left_click_drag",
                        "right_click",
                        "middle_click",
                        "double_click",
                        "screenshot",
                        "cursor_position",
                        "scroll",
                    ],
                    "description": "The action to perform on the computer",
                },
                "text": {
                    "type": "string",
                    "description": "Text to type or key to press, required for 'key' and 'type' actions",
                },
                "coordinate": {
                    "type": "array",
                    "items": {"type": "integer"},
                    "description": "X,Y coordinates for mouse actions like click and move",
                },
                "direction": {
                    "type": "string",
                    "enum": ["up", "down"],
                    "description": "Direction to scroll, used with the 'scroll' action",
                },
                "amount": {
                    "type": "integer",
                    "description": "Amount to scroll, used with the 'scroll' action",
                },
            },
            **self.options,
        }

    async def __call__(self, **kwargs) -> ToolResult:
        """Execute the computer tool with the provided arguments.

        Args:
            action: The action to perform
            text: Text to type or key to press (for key/type actions)
            coordinate: X,Y coordinates (for mouse actions)
            direction: Direction to scroll (for scroll action)
            amount: Amount to scroll (for scroll action)

        Returns:
            ToolResult with the action output and optional screenshot
        """
        # Ensure dimensions are initialized
        if self.width is None or self.height is None:
            await self.initialize_dimensions()

        action = kwargs.get("action")
        text = kwargs.get("text")
        coordinate = kwargs.get("coordinate")
        direction = kwargs.get("direction", "down")
        amount = kwargs.get("amount", 10)

        self.logger.info(f"Executing computer action: {action}")

        try:
            if action == "screenshot":
                return await self.screenshot()
            elif action == "left_click" and coordinate:
                x, y = coordinate
                self.logger.info(f"Clicking at ({x}, {y})")
                await self.computer.interface.move_cursor(x, y)
                await self.computer.interface.left_click()

                # Take screenshot after action
                screenshot = await self.computer.interface.screenshot()
                screenshot = await self.resize_screenshot_if_needed(screenshot)
                return ToolResult(
                    output=f"Performed left click at ({x}, {y})",
                    base64_image=base64.b64encode(screenshot).decode(),
                )
            elif action == "right_click" and coordinate:
                x, y = coordinate
                self.logger.info(f"Right clicking at ({x}, {y})")
                await self.computer.interface.move_cursor(x, y)
                await self.computer.interface.right_click()

                # Take screenshot after action
                screenshot = await self.computer.interface.screenshot()
                screenshot = await self.resize_screenshot_if_needed(screenshot)
                return ToolResult(
                    output=f"Performed right click at ({x}, {y})",
                    base64_image=base64.b64encode(screenshot).decode(),
                )
            elif action == "double_click" and coordinate:
                x, y = coordinate
                self.logger.info(f"Double clicking at ({x}, {y})")
                await self.computer.interface.move_cursor(x, y)
                await self.computer.interface.double_click()

                # Take screenshot after action
                screenshot = await self.computer.interface.screenshot()
                screenshot = await self.resize_screenshot_if_needed(screenshot)
                return ToolResult(
                    output=f"Performed double click at ({x}, {y})",
                    base64_image=base64.b64encode(screenshot).decode(),
                )
            elif action == "mouse_move" and coordinate:
                x, y = coordinate
                self.logger.info(f"Moving cursor to ({x}, {y})")
                await self.computer.interface.move_cursor(x, y)

                # Take screenshot after action
                screenshot = await self.computer.interface.screenshot()
                screenshot = await self.resize_screenshot_if_needed(screenshot)
                return ToolResult(
                    output=f"Moved cursor to ({x}, {y})",
                    base64_image=base64.b64encode(screenshot).decode(),
                )
            elif action == "type" and text:
                self.logger.info(f"Typing text: {text}")
                await self.computer.interface.type_text(text)

                # Take screenshot after action
                screenshot = await self.computer.interface.screenshot()
                screenshot = await self.resize_screenshot_if_needed(screenshot)
                return ToolResult(
                    output=f"Typed text: {text}",
                    base64_image=base64.b64encode(screenshot).decode(),
                )
            elif action == "key" and text:
                self.logger.info(f"Pressing key: {text}")

                # Handle special key combinations
                if "+" in text:
                    keys = text.split("+")
                    await self.computer.interface.hotkey(*keys)
                else:
                    await self.computer.interface.press(text)

                # Take screenshot after action
                screenshot = await self.computer.interface.screenshot()
                screenshot = await self.resize_screenshot_if_needed(screenshot)
                return ToolResult(
                    output=f"Pressed key: {text}",
                    base64_image=base64.b64encode(screenshot).decode(),
                )
            elif action == "cursor_position":
                pos = await self.computer.interface.get_cursor_position()
                return ToolResult(output=f"X={int(pos[0])},Y={int(pos[1])}")
            elif action == "scroll":
                if direction == "down":
                    self.logger.info(f"Scrolling down, amount: {amount}")
                    for _ in range(amount):
                        await self.computer.interface.hotkey("fn", "down")
                else:
                    self.logger.info(f"Scrolling up, amount: {amount}")
                    for _ in range(amount):
                        await self.computer.interface.hotkey("fn", "up")

                # Take screenshot after action
                screenshot = await self.computer.interface.screenshot()
                screenshot = await self.resize_screenshot_if_needed(screenshot)
                return ToolResult(
                    output=f"Scrolled {direction} by {amount} steps",
                    base64_image=base64.b64encode(screenshot).decode(),
                )

            # Default to screenshot for unimplemented actions
            self.logger.warning(f"Action {action} not fully implemented, taking screenshot")
            return await self.screenshot()

        except Exception as e:
            self.logger.error(f"Error during computer action: {str(e)}")
            return ToolResult(error=f"Failed to perform {action}: {str(e)}")
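For orientation, here is a minimal sketch of how this tool might be driven from an agent loop. It is not part of the package: it assumes the `computer` package's `Computer` object is already constructed and connected, and that the `ToolResult` returned exposes the `output` and `base64_image` fields set above.

```python
from computer.computer import Computer
from agent.providers.omni.tools.computer import OmniComputerTool


async def click_and_capture(computer: Computer) -> None:
    # Hypothetical driver: one tool call per agent action, as the loop would do.
    tool = OmniComputerTool(computer)
    await tool.initialize_dimensions()  # provided by BaseComputerTool per this diff

    # Returns a ToolResult carrying output text plus a post-action screenshot.
    result = await tool(action="left_click", coordinate=[100, 200])
    print(result.output)

    # Actions without a dedicated branch (e.g. middle_click) fall back to a screenshot.
    fallback = await tool(action="middle_click", coordinate=[100, 200])
    print(fallback.base64_image is not None)
```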
agent/providers/omni/tools/manager.py
ADDED

@@ -0,0 +1,83 @@
"""Omni tool manager implementation."""

from typing import Dict, List, Any
from enum import Enum

from computer.computer import Computer

from ....core.tools import BaseToolManager
from ....core.tools.collection import ToolCollection

from .bash import OmniBashTool
from .computer import OmniComputerTool
from .edit import OmniEditTool


class ProviderType(Enum):
    """Supported provider types."""

    ANTHROPIC = "anthropic"
    OPENAI = "openai"
    CLAUDE = "claude"  # Alias for Anthropic
    GPT = "gpt"  # Alias for OpenAI


class OmniToolManager(BaseToolManager):
    """Tool manager for multi-provider support."""

    def __init__(self, computer: Computer):
        """Initialize Omni tool manager.

        Args:
            computer: Computer instance for tools
        """
        super().__init__(computer)
        # Initialize tools
        self.computer_tool = OmniComputerTool(self.computer)
        self.bash_tool = OmniBashTool(self.computer)
        self.edit_tool = OmniEditTool(self.computer)

    def _initialize_tools(self) -> ToolCollection:
        """Initialize all available tools."""
        return ToolCollection(self.computer_tool, self.bash_tool, self.edit_tool)

    async def _initialize_tools_specific(self) -> None:
        """Initialize provider-specific tool requirements."""
        await self.computer_tool.initialize_dimensions()

    def get_tool_params(self) -> List[Dict[str, Any]]:
        """Get tool parameters for API calls.

        Returns:
            List of tool parameters in default format
        """
        if self.tools is None:
            raise RuntimeError("Tools not initialized. Call initialize() first.")
        return self.tools.to_params()

    def get_provider_tools(self, provider: ProviderType) -> List[Dict[str, Any]]:
        """Get tools formatted for a specific provider.

        Args:
            provider: Provider type to format tools for

        Returns:
            List of tool parameters in provider-specific format
        """
        if self.tools is None:
            raise RuntimeError("Tools not initialized. Call initialize() first.")

        # Default is the base implementation
        tools = self.tools.to_params()

        # Customize for each provider if needed
        if provider in [ProviderType.ANTHROPIC, ProviderType.CLAUDE]:
            # Format for Anthropic API
            # Additional adjustments can be made here
            pass
        elif provider in [ProviderType.OPENAI, ProviderType.GPT]:
            # Format for OpenAI API
            # Future implementation
            pass

        return tools
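A short, hypothetical usage sketch for the manager follows (not part of the package). It assumes `BaseToolManager` exposes an async `initialize()` that populates `self.tools`, which is what the RuntimeError messages above imply.

```python
from computer.computer import Computer
from agent.providers.omni.tools.manager import OmniToolManager, ProviderType


async def describe_tools(computer: Computer) -> None:
    manager = OmniToolManager(computer)
    await manager.initialize()  # assumed: populates manager.tools (see RuntimeError above)

    default_params = manager.get_tool_params()
    anthropic_params = manager.get_provider_tools(ProviderType.ANTHROPIC)

    # In this release both calls return the same provider-agnostic schema,
    # because the provider-specific branches above are still placeholders.
    assert default_params == anthropic_params
```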
agent/providers/omni/types.py
ADDED

@@ -0,0 +1,30 @@
"""Type definitions for the Omni provider."""

from enum import StrEnum
from typing import Dict


class APIProvider(StrEnum):
    """Supported API providers."""

    ANTHROPIC = "anthropic"
    OPENAI = "openai"
    GROQ = "groq"
    QWEN = "qwen"


# Default models for each provider
PROVIDER_TO_DEFAULT_MODEL: Dict[APIProvider, str] = {
    APIProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
    APIProvider.OPENAI: "gpt-4o",
    APIProvider.GROQ: "deepseek-r1-distill-llama-70b",
    APIProvider.QWEN: "qwen2.5-vl-72b-instruct",
}

# Environment variable names for each provider
PROVIDER_TO_ENV_VAR: Dict[APIProvider, str] = {
    APIProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
    APIProvider.OPENAI: "OPENAI_API_KEY",
    APIProvider.GROQ: "GROQ_API_KEY",
    APIProvider.QWEN: "QWEN_API_KEY",
}
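These two mappings pair each provider with a default model and with the environment variable expected to hold its API key. A small illustrative helper (hypothetical, not shipped in the wheel) could resolve both in one step:

```python
import os
from typing import Optional, Tuple

from agent.providers.omni.types import (
    APIProvider,
    PROVIDER_TO_DEFAULT_MODEL,
    PROVIDER_TO_ENV_VAR,
)


def resolve_provider(provider: APIProvider) -> Tuple[str, Optional[str]]:
    """Return (default model, API key from the environment, if set)."""
    model = PROVIDER_TO_DEFAULT_MODEL[provider]
    api_key = os.environ.get(PROVIDER_TO_ENV_VAR[provider])
    return model, api_key


model, key = resolve_provider(APIProvider.OPENAI)
print(model)        # "gpt-4o"
print(key is None)  # True unless OPENAI_API_KEY is set
```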
agent/providers/omni/utils.py
ADDED

@@ -0,0 +1,155 @@
"""Utility functions for Omni provider."""

import base64
import io
import logging
from typing import Tuple
from PIL import Image

logger = logging.getLogger(__name__)


def compress_image_base64(
    base64_str: str, max_size_bytes: int = 5 * 1024 * 1024, quality: int = 90
) -> tuple[str, str]:
    """Compress a base64 encoded image to ensure it's below a certain size.

    Args:
        base64_str: Base64 encoded image string (with or without data URL prefix)
        max_size_bytes: Maximum size in bytes (default: 5MB)
        quality: Initial JPEG quality (0-100)

    Returns:
        tuple[str, str]: (Compressed base64 encoded image, media_type)
    """
    # Handle data URL prefix if present (e.g., "data:image/png;base64,...")
    original_prefix = ""
    media_type = "image/png"  # Default media type

    if base64_str.startswith("data:"):
        parts = base64_str.split(",", 1)
        if len(parts) == 2:
            original_prefix = parts[0] + ","
            base64_str = parts[1]
            # Try to extract media type from the prefix
            if "image/jpeg" in original_prefix.lower():
                media_type = "image/jpeg"
            elif "image/png" in original_prefix.lower():
                media_type = "image/png"

    # Check if the base64 string is small enough already
    if len(base64_str) <= max_size_bytes:
        logger.info(f"Image already within size limit: {len(base64_str)} bytes")
        return original_prefix + base64_str, media_type

    try:
        # Decode base64
        img_data = base64.b64decode(base64_str)
        img_size = len(img_data)
        logger.info(f"Original image size: {img_size} bytes")

        # Open image
        img = Image.open(io.BytesIO(img_data))

        # First, try to compress as PNG (maintains transparency if present)
        buffer = io.BytesIO()
        img.save(buffer, format="PNG", optimize=True)
        buffer.seek(0)
        compressed_data = buffer.getvalue()
        compressed_b64 = base64.b64encode(compressed_data).decode("utf-8")

        if len(compressed_b64) <= max_size_bytes:
            logger.info(f"Compressed to {len(compressed_data)} bytes as PNG")
            return compressed_b64, "image/png"

        # Strategy 1: Try reducing quality with JPEG format
        current_quality = quality
        while current_quality > 20:
            buffer = io.BytesIO()
            # Convert to RGB if image has alpha channel (JPEG doesn't support transparency)
            if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
                logger.info("Converting transparent image to RGB for JPEG compression")
                rgb_img = Image.new("RGB", img.size, (255, 255, 255))
                rgb_img.paste(img, mask=img.split()[3] if img.mode == "RGBA" else None)
                rgb_img.save(buffer, format="JPEG", quality=current_quality, optimize=True)
            else:
                img.save(buffer, format="JPEG", quality=current_quality, optimize=True)

            buffer.seek(0)
            compressed_data = buffer.getvalue()
            compressed_b64 = base64.b64encode(compressed_data).decode("utf-8")

            if len(compressed_b64) <= max_size_bytes:
                logger.info(
                    f"Compressed to {len(compressed_data)} bytes with JPEG quality {current_quality}"
                )
                return compressed_b64, "image/jpeg"

            # Reduce quality and try again
            current_quality -= 10

        # Strategy 2: If quality reduction isn't enough, reduce dimensions
        scale_factor = 0.8
        current_img = img

        while scale_factor > 0.3:
            # Resize image
            new_width = int(img.width * scale_factor)
            new_height = int(img.height * scale_factor)
            current_img = img.resize((new_width, new_height), Image.LANCZOS)

            # Try with reduced size and quality
            buffer = io.BytesIO()
            # Convert to RGB if necessary for JPEG
            if current_img.mode in ("RGBA", "LA") or (
                current_img.mode == "P" and "transparency" in current_img.info
            ):
                rgb_img = Image.new("RGB", current_img.size, (255, 255, 255))
                rgb_img.paste(
                    current_img, mask=current_img.split()[3] if current_img.mode == "RGBA" else None
                )
                rgb_img.save(buffer, format="JPEG", quality=70, optimize=True)
            else:
                current_img.save(buffer, format="JPEG", quality=70, optimize=True)

            buffer.seek(0)
            compressed_data = buffer.getvalue()
            compressed_b64 = base64.b64encode(compressed_data).decode("utf-8")

            if len(compressed_b64) <= max_size_bytes:
                logger.info(
                    f"Compressed to {len(compressed_data)} bytes with scale {scale_factor} and JPEG quality 70"
                )
                return compressed_b64, "image/jpeg"

            # Reduce scale factor and try again
            scale_factor -= 0.1

        # If we get here, we couldn't compress enough
        logger.warning("Could not compress image below required size with quality preservation")

        # Last resort: Use minimum quality and size
        buffer = io.BytesIO()
        smallest_img = img.resize((int(img.width * 0.5), int(img.height * 0.5)), Image.LANCZOS)
        # Convert to RGB if necessary
        if smallest_img.mode in ("RGBA", "LA") or (
            smallest_img.mode == "P" and "transparency" in smallest_img.info
        ):
            rgb_img = Image.new("RGB", smallest_img.size, (255, 255, 255))
            rgb_img.paste(
                smallest_img, mask=smallest_img.split()[3] if smallest_img.mode == "RGBA" else None
            )
            rgb_img.save(buffer, format="JPEG", quality=20, optimize=True)
        else:
            smallest_img.save(buffer, format="JPEG", quality=20, optimize=True)

        buffer.seek(0)
        final_data = buffer.getvalue()
        final_b64 = base64.b64encode(final_data).decode("utf-8")

        logger.warning(f"Final compressed size: {len(final_b64)} bytes (may still exceed limit)")
        return final_b64, "image/jpeg"

    except Exception as e:
        logger.error(f"Error compressing image: {str(e)}")
        raise
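A self-contained sketch of calling `compress_image_base64` (illustration only; the synthetic image stands in for a real screenshot). Inputs already under the 5 MB default budget are returned unchanged; larger ones walk through the PNG, JPEG-quality, and downscaling strategies above.

```python
import base64
import io

from PIL import Image

from agent.providers.omni.utils import compress_image_base64

# Synthetic 1920x1080 "screenshot" so the example needs no real capture.
img = Image.new("RGB", (1920, 1080), color="white")
buffer = io.BytesIO()
img.save(buffer, format="PNG")
raw_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

compressed_b64, media_type = compress_image_base64(raw_b64)
print(media_type, len(compressed_b64))  # small inputs come back untouched as image/png
```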
agent/providers/omni/visualization.py
ADDED

@@ -0,0 +1,130 @@
"""Visualization utilities for the Cua provider."""

import base64
import logging
from io import BytesIO
from typing import Tuple
from PIL import Image, ImageDraw

logger = logging.getLogger(__name__)


def visualize_click(x: int, y: int, img_base64: str) -> Image.Image:
    """Visualize a click action by drawing on the screenshot.

    Args:
        x: X coordinate of the click
        y: Y coordinate of the click
        img_base64: Base64 encoded image to draw on

    Returns:
        PIL Image with visualization
    """
    try:
        # Decode the base64 image
        img_data = base64.b64decode(img_base64)
        img = Image.open(BytesIO(img_data))

        # Create a drawing context
        draw = ImageDraw.Draw(img)

        # Draw concentric circles at the click position
        small_radius = 10
        large_radius = 30

        # Draw filled inner circle
        draw.ellipse(
            [(x - small_radius, y - small_radius), (x + small_radius, y + small_radius)],
            fill="red",
        )

        # Draw outlined outer circle
        draw.ellipse(
            [(x - large_radius, y - large_radius), (x + large_radius, y + large_radius)],
            outline="red",
            width=3,
        )

        return img

    except Exception as e:
        logger.error(f"Error visualizing click: {str(e)}")
        # Return a blank image in case of error
        return Image.new("RGB", (800, 600), color="white")


def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
    """Visualize a scroll action by drawing arrows on the screenshot.

    Args:
        direction: 'up' or 'down'
        clicks: Number of scroll clicks
        img_base64: Base64 encoded image to draw on

    Returns:
        PIL Image with visualization
    """
    try:
        # Decode the base64 image
        img_data = base64.b64decode(img_base64)
        img = Image.open(BytesIO(img_data))

        # Get image dimensions
        width, height = img.size

        # Create a drawing context
        draw = ImageDraw.Draw(img)

        # Determine arrow direction and positions
        center_x = width // 2
        arrow_width = 100

        if direction.lower() == "up":
            # Draw up arrow in the middle of the screen
            arrow_y = height // 2
            # Arrow points
            points = [
                (center_x, arrow_y - 50),  # Top point
                (center_x - arrow_width // 2, arrow_y + 50),  # Bottom left
                (center_x + arrow_width // 2, arrow_y + 50),  # Bottom right
            ]
            color = "blue"
        else:  # down
            # Draw down arrow in the middle of the screen
            arrow_y = height // 2
            # Arrow points
            points = [
                (center_x, arrow_y + 50),  # Bottom point
                (center_x - arrow_width // 2, arrow_y - 50),  # Top left
                (center_x + arrow_width // 2, arrow_y - 50),  # Top right
            ]
            color = "green"

        # Draw filled arrow
        draw.polygon(points, fill=color)

        # Add text showing number of clicks
        text_y = arrow_y + 70 if direction.lower() == "down" else arrow_y - 70
        draw.text((center_x - 40, text_y), f"{clicks} clicks", fill="black")

        return img

    except Exception as e:
        logger.error(f"Error visualizing scroll: {str(e)}")
        # Return a blank image in case of error
        return Image.new("RGB", (800, 600), color="white")


def calculate_element_center(box: Tuple[int, int, int, int]) -> Tuple[int, int]:
    """Calculate the center coordinates of a bounding box.

    Args:
        box: Tuple of (left, top, right, bottom) coordinates

    Returns:
        Tuple of (center_x, center_y) coordinates
    """
    left, top, right, bottom = box
    center_x = (left + right) // 2
    center_y = (top + bottom) // 2
    return center_x, center_y
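For illustration (not shipped in the wheel), the helpers combine naturally: take the centre of a detected element's bounding box and overlay the click marker on a screenshot.

```python
import base64
import io

from PIL import Image

from agent.providers.omni.visualization import calculate_element_center, visualize_click

# Synthetic screenshot standing in for a real capture.
screenshot = Image.new("RGB", (800, 600), color="white")
buffer = io.BytesIO()
screenshot.save(buffer, format="PNG")
img_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

# Centre of a hypothetical element bounding box (left, top, right, bottom).
x, y = calculate_element_center((100, 100, 300, 160))  # -> (200, 130)
annotated = visualize_click(x, y, img_b64)
annotated.save("click_overlay.png")
```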
agent/types/__init__.py
ADDED
@@ -0,0 +1,26 @@
"""Type definitions for the agent package."""

from .base import Provider, HostConfig, TaskResult, Annotation
from .messages import Message, Request, Response, StepMessage, DisengageMessage
from .tools import ToolInvocation, ToolInvocationState, ClientAttachment, ToolResult

__all__ = [
    # Base types
    "Provider",
    "HostConfig",
    "TaskResult",
    "Annotation",

    # Message types
    "Message",
    "Request",
    "Response",
    "StepMessage",
    "DisengageMessage",

    # Tool types
    "ToolInvocation",
    "ToolInvocationState",
    "ClientAttachment",
    "ToolResult",
]
agent/types/base.py
ADDED
@@ -0,0 +1,52 @@
"""Base type definitions."""

from enum import Enum, auto
from typing import Dict, Any
from pydantic import BaseModel, ConfigDict


class Provider(str, Enum):
    """Available AI providers."""

    UNKNOWN = "unknown"  # Default provider for base class
    ANTHROPIC = "anthropic"
    OPENAI = "openai"
    OLLAMA = "ollama"
    OMNI = "omni"
    GROQ = "groq"


class HostConfig(BaseModel):
    """Host configuration."""

    model_config = ConfigDict(extra="forbid")
    hostname: str
    port: int

    @property
    def address(self) -> str:
        return f"{self.hostname}:{self.port}"


class TaskResult(BaseModel):
    """Result of a task execution."""

    model_config = ConfigDict(extra="forbid")
    result: str
    vnc_password: str


class Annotation(BaseModel):
    """Annotation metadata."""

    model_config = ConfigDict(extra="forbid")
    id: str
    vm_url: str


class AgenticLoop(Enum):
    """Enumeration of available loop types."""

    ANTHROPIC = auto()  # Anthropic implementation
    OMNI = auto()  # OmniLoop implementation
    # Add more loop types as needed