cua-agent 0.1.28__py3-none-any.whl → 0.1.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic. Click here for more details.

agent/core/factory.py CHANGED
@@ -98,5 +98,24 @@ class LoopFactory:
98
98
  parser=OmniParser(),
99
99
  provider_base_url=provider_base_url,
100
100
  )
101
+ elif loop_type == AgentLoop.UITARS:
102
+ # Lazy import UITARSLoop only when needed
103
+ try:
104
+ from ..providers.uitars.loop import UITARSLoop
105
+ except ImportError:
106
+ raise ImportError(
107
+ "The 'uitars' provider is not installed. "
108
+ "Install it with 'pip install cua-agent[all]'"
109
+ )
110
+
111
+ return UITARSLoop(
112
+ api_key=api_key,
113
+ model=model_name,
114
+ computer=computer,
115
+ save_trajectory=save_trajectory,
116
+ base_dir=trajectory_dir,
117
+ only_n_most_recent_images=only_n_most_recent_images,
118
+ provider_base_url=provider_base_url,
119
+ )
101
120
  else:
102
121
  raise ValueError(f"Unsupported loop type: {loop_type}")
agent/core/types.py CHANGED
@@ -12,6 +12,7 @@ class AgentLoop(Enum):
12
12
  OMNI = auto() # OmniLoop implementation
13
13
  OPENAI = auto() # OpenAI implementation
14
14
  OLLAMA = auto() # OLLAMA implementation
15
+ UITARS = auto() # UI-TARS implementation
15
16
  # Add more loop types as needed
16
17
 
17
18
 
@@ -162,8 +162,8 @@ class ComputerTool(BaseComputerTool, BaseOpenAITool):
162
162
  y = kwargs.get("y")
163
163
  if x is None or y is None:
164
164
  raise ToolError("x and y coordinates are required for scroll action")
165
- scroll_x = kwargs.get("scroll_x", 0)
166
- scroll_y = kwargs.get("scroll_y", 0)
165
+ scroll_x = kwargs.get("scroll_x", 0) // 20
166
+ scroll_y = kwargs.get("scroll_y", 0) // 20
167
167
  return await self.handle_scroll(x, y, scroll_x, scroll_y)
168
168
  elif type == "screenshot":
169
169
  return await self.screenshot()
@@ -0,0 +1 @@
1
+ """UI-TARS Agent provider package."""
@@ -0,0 +1,35 @@
1
+ """Base client implementation for Omni providers."""
2
+
3
+ import logging
4
+ from typing import Dict, List, Optional, Any, Tuple
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
class BaseUITarsClient:
    """Shared interface for UI-TARS model clients.

    Concrete clients (for example an OpenAI-compatible HTTP client) provide the
    transport by overriding :meth:`run_interleaved`. This base class only
    stores the configuration every client has in common.
    """

    def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
        """Store the client configuration.

        Args:
            api_key: API key for the backing service, if one is required.
            model: Name of the model to query, if known at construction time.
        """
        self.api_key, self.model = api_key, model

    async def run_interleaved(
        self,
        messages: List[Dict[str, Any]],
        system: str,
        max_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Send an interleaved (text and image) chat request.

        Args:
            messages: Conversation messages as dicts.
            system: System prompt to prepend to the conversation.
            max_tokens: Optional per-call override of the generation limit.

        Returns:
            The provider's response as a dict.

        Raises:
            NotImplementedError: Always; concrete clients must override this.
        """
        raise NotImplementedError
@@ -0,0 +1,204 @@
1
+ """OpenAI-compatible client implementation."""
2
+
3
+ import os
4
+ import logging
5
+ from typing import Dict, List, Optional, Any
6
+ import aiohttp
7
+ import re
8
+ from .base import BaseUITarsClient
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
# OpenAI-compatible client for UI-TARS models.
class OAICompatClient(BaseUITarsClient):
    """OpenAI-compatible API client implementation.

    This client can be used with any service that implements the OpenAI API protocol, including:
    - Huggingface Text Generation Interface endpoints
    - vLLM
    - LM Studio
    - LocalAI
    - Ollama (with OpenAI compatibility)
    - Text Generation WebUI
    - Any other service with OpenAI API compatibility

    Base URLs starting with ``https://api.runpod.ai/v2/<endpoint-id>`` are
    rewritten to RunPod's ``/openai/v1/chat/completions`` route.
    """

    DEFAULT_BASE_URL = "http://localhost:8000/v1"

    # Finds the first inline image data URL (bare, or inside an HTML attribute).
    # Group 1 is the MIME type and group 2 the base64 payload. The payload is
    # limited to the base64 alphabet so that a closing quote (single or double),
    # ``>`` or following text is never captured as image data.
    _DATA_URL_PATTERN = re.compile(r"data:(image/[\w.+-]+);base64,([A-Za-z0-9+/=]+)")

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "Qwen2.5-VL-7B-Instruct",
        provider_base_url: Optional[str] = DEFAULT_BASE_URL,
        max_tokens: int = 4096,
        temperature: float = 0.0,
    ):
        """Initialize the OpenAI-compatible client.

        Args:
            api_key: API key sent as a Bearer token. Defaults to "EMPTY", since
                local endpoints typically don't require a key.
            model: Model name to use
            provider_base_url: API base URL. Typically in the format "http://localhost:PORT/v1".
                ``None`` falls back to ``DEFAULT_BASE_URL``.
                Examples:
                - vLLM: "http://localhost:8000/v1"
                - LM Studio: "http://localhost:1234/v1"
                - LocalAI: "http://localhost:8080/v1"
                - Ollama: "http://localhost:11434/v1"
            max_tokens: Default maximum tokens to generate (overridable per call)
            temperature: Generation temperature
        """
        # An Authorization header is always sent, so substitute a placeholder key.
        super().__init__(api_key=api_key or "EMPTY", model=model)
        self.provider_base_url = provider_base_url or self.DEFAULT_BASE_URL
        self.max_tokens = max_tokens
        self.temperature = temperature

    def _extract_base64_image(self, text: str) -> Optional[str]:
        """Return the base64 payload of the first inline image data URL in ``text``.

        Works for a bare data URL or one embedded in an HTML ``<img>`` tag.
        Returns None when ``text`` contains no image data URL.
        """
        match = self._DATA_URL_PATTERN.search(text)
        return match.group(2) if match else None

    def _to_chat_message(self, role: str, text: str) -> Dict[str, Any]:
        """Wrap string content as an OpenAI-style message with list content.

        If ``text`` contains an inline image data URL, only that image is sent
        as an ``image_url`` part; any surrounding text is NOT included.
        The image's original MIME type is preserved in the rebuilt data URL.
        Otherwise the whole string is sent as a single ``text`` part.
        """
        match = self._DATA_URL_PATTERN.search(text)
        if match is None:
            return {"role": role, "content": [{"type": "text", "text": text}]}
        mime_type, b64_data = match.groups()
        return {
            "role": role,
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:{mime_type};base64,{b64_data}"},
                }
            ],
        }

    def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return a copy of ``messages`` with image content parts replaced by a placeholder.

        Both ``"image"`` and ``"image_url"`` parts are redacted so that base64
        screenshots do not flood the logs. Messages whose content is not a list
        are returned as-is. Other message keys are kept (shallow copy).
        """
        loggable_messages = []
        for msg in messages:
            content = msg.get("content")
            if not isinstance(content, list):
                loggable_messages.append(msg)
                continue
            redacted = [
                {"type": part["type"], "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
                if isinstance(part, dict) and part.get("type") in ("image", "image_url")
                else part
                for part in content
            ]
            loggable_messages.append({**msg, "content": redacted})
        return loggable_messages

    @staticmethod
    def _resolve_endpoint_url(base_url: str) -> str:
        """Map a configured base URL to the full chat-completions endpoint URL.

        - URLs already ending in ``/chat/completions`` are used unchanged.
        - RunPod URLs (``https://api.runpod.ai/v2/<id>...``) become
          ``https://api.runpod.ai/v2/<id>/openai/v1/chat/completions``.
        - Otherwise ``chat/completions`` is appended as a new path segment.
        """
        if base_url.endswith("/chat/completions"):
            return base_url
        runpod_prefix = "https://api.runpod.ai/v2/"
        if base_url.startswith(runpod_prefix):
            # split("/") -> ["https:", "", "api.runpod.ai", "v2", "<id>", ...];
            # the prefix check guarantees index 4 exists.
            runpod_id = base_url.split("/")[4]
            return f"{runpod_prefix}{runpod_id}/openai/v1/chat/completions"
        if base_url.endswith("/"):
            return f"{base_url}chat/completions"
        return f"{base_url}/chat/completions"

    @staticmethod
    def _extract_error_message(response_json: Any) -> str:
        """Pull a readable error message out of a JSON error body.

        Handles ``{"error": {"message": ...}}`` and ``{"error": "..."}`` shapes.
        Falls back to the string form of the whole body.
        """
        if isinstance(response_json, dict):
            error = response_json.get("error")
            if isinstance(error, dict) and "message" in error:
                return str(error["message"])
            if isinstance(error, str) and error:
                return error
        return str(response_json)

    async def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """Run interleaved chat completion.

        Args:
            messages: Conversation items. Dicts whose ``content`` is a list are
                forwarded unchanged. Dicts with string content are converted
                with ``_to_chat_message``. Non-dict items are treated as string
                content for a ``"user"`` turn.
            system: System prompt, sent as the first message
            max_tokens: Optional max tokens override (falls back to ``self.max_tokens``)

        Returns:
            The parsed JSON body of a successful (HTTP 200) response.

        Raises:
            Exception: If the response is not JSON or the status is not 200.
                Transport errors from aiohttp propagate unchanged. Every error
                is logged once before being re-raised.
        """
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

        final_messages: List[Dict[str, Any]] = [{"role": "system", "content": system}]
        for item in messages:
            if not isinstance(item, dict):
                final_messages.append(self._to_chat_message("user", item))
            elif isinstance(item["content"], list):
                # Content is already in OpenAI list format.
                final_messages.append(item)
            else:
                final_messages.append(self._to_chat_message(item["role"], item["content"]))

        payload = {
            "model": self.model,
            "messages": final_messages,
            "temperature": self.temperature,
            "max_tokens": max_tokens or self.max_tokens,
        }

        try:
            endpoint_url = self._resolve_endpoint_url(
                self.provider_base_url or self.DEFAULT_BASE_URL
            )
            logger.debug("Using endpoint URL: %s", endpoint_url)
            if logger.isEnabledFor(logging.DEBUG):
                # Only build the redacted copy when it will actually be logged.
                logger.debug("Request messages: %s", self._get_loggable_messages(final_messages))

            async with aiohttp.ClientSession() as session:
                async with session.post(endpoint_url, headers=headers, json=payload) as response:
                    content_type = response.headers.get("Content-Type", "")
                    logger.debug("Status: %s", response.status)
                    logger.debug("Content-Type: %s", content_type)

                    response_text = await response.text()
                    logger.debug("Response content: %s", response_text)

                    if "application/json" not in content_type:
                        # Include status and a body excerpt so HTML/plain-text
                        # error pages remain diagnosable.
                        raise Exception(
                            f"Response is not JSON format (HTTP {response.status}, "
                            f"Content-Type: {content_type!r}): {response_text[:500]}"
                        )
                    response_json = await response.json()

                    if response.status != 200:
                        error_msg = self._extract_error_message(response_json)
                        raise Exception(f"API error: {error_msg}")

                    return response_json

        except Exception as e:
            logger.error("Error in API call: %s", e)
            raise