cua-agent 0.1.29__py3-none-any.whl → 0.1.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic. Click here for more details.

@@ -0,0 +1,595 @@
1
+ """UI-TARS-specific agent loop implementation."""
2
+
3
+ import logging
4
+ import asyncio
5
+ import re
6
+ import os
7
+ import json
8
+ import base64
9
+ import copy
10
+ from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator, cast
11
+
12
+ from httpx import ConnectError, ReadTimeout
13
+
14
+ from ...core.base import BaseLoop
15
+ from ...core.messages import StandardMessageManager, ImageRetentionConfig
16
+ from ...core.types import AgentResponse, LLMProvider
17
+ from ...core.visualization import VisualizationHelper
18
+ from computer import Computer
19
+
20
+ from .utils import add_box_token, parse_actions, parse_action_parameters
21
+ from .tools.manager import ToolManager
22
+ from .tools.computer import ToolResult
23
+ from .prompts import COMPUTER_USE, SYSTEM_PROMPT
24
+
25
+ from .clients.oaicompat import OAICompatClient
26
+
27
# NOTE(review): configuring the root logger from library code at import time
# is usually left to the host application; per the stdlib docs basicConfig is
# a no-op once the root logger has handlers, so this only takes effect when the
# application has not configured logging yet. Kept for backward compatibility.
logging.basicConfig(level="INFO")

# Module-scoped logger shared by every method of the loop below.
logger = logging.getLogger(name=__name__)
29
+
30
+
31
class UITARSLoop(BaseLoop):
    """UI-TARS-specific implementation of the agent loop.

    This class extends BaseLoop to drive a UI-TARS model served from an
    OpenAI-compatible endpoint. Every turn captures a screenshot, sends the
    conversation to the model, parses its ``Thought:`` / ``Action:`` reply and
    executes the parsed actions through the ``computer`` tool.
    """

    # Action names that ``_handle_response`` reports as screenshot-producing.
    _SCREENSHOT_ACTIONS = ("click", "left_double", "right_single", "drag", "scroll")

    # Captures the reasoning after "Thought: " up to the "Action:" section, or
    # up to the end of the text when the reply contains no action.
    _THOUGHT_PATTERN = re.compile(r"Thought: (.*?)(?=\s*Action:|$)", re.DOTALL)

    ###########################################
    # INITIALIZATION AND CONFIGURATION
    ###########################################

    def __init__(
        self,
        computer: Computer,
        api_key: str,
        model: str,
        provider_base_url: Optional[str] = "http://localhost:8000/v1",
        only_n_most_recent_images: Optional[int] = 2,
        base_dir: Optional[str] = "trajectories",
        max_retries: int = 3,
        retry_delay: float = 1.0,
        save_trajectory: bool = True,
        **kwargs,
    ):
        """Initialize the loop.

        Args:
            computer: Computer instance
            api_key: API key (may not be needed for local endpoints)
            model: Model name (e.g., "ui-tars")
            provider_base_url: Base URL for the API provider
            only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
            base_dir: Base directory for saving experiment data
            max_retries: Maximum number of retries for API calls
            retry_delay: Delay between retries in seconds
            save_trajectory: Whether to save trajectory data
            **kwargs: Forwarded unchanged to ``BaseLoop.__init__``.
        """
        # Provider settings are assigned before BaseLoop.__init__ runs,
        # presumably because the base initializer reads them — TODO confirm.
        self.provider = LLMProvider.OAICOMPAT
        self.provider_base_url = provider_base_url

        # Message history that keeps only the N most recent screenshots.
        self.message_manager = StandardMessageManager(
            config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
        )

        # Initialize base class (per the original note, this sets up the
        # experiment manager).
        super().__init__(
            computer=computer,
            model=model,
            api_key=api_key,
            max_retries=max_retries,
            retry_delay=retry_delay,
            base_dir=base_dir,
            save_trajectory=save_trajectory,
            only_n_most_recent_images=only_n_most_recent_images,
            **kwargs,
        )

        # The API client is created lazily by ``initialize_client``.
        self.client = None
        # Not read anywhere in this class; kept for attribute compatibility.
        self.retry_count = 0

        self.viz_helper = VisualizationHelper(agent=self)
        self.tool_manager = ToolManager(computer=computer)

        logger.info("UITARSLoop initialized with StandardMessageManager")

    async def initialize(self) -> None:
        """Initialize the base loop, the tool manager and the API client.

        Tool-manager failures are logged and tolerated
        (``_ensure_tools_initialized`` retries when ``tools`` is missing or
        None). Client failures are fatal.

        Raises:
            RuntimeError: If the OAICompat client cannot be created.
        """
        await super().initialize()

        try:
            logger.info("Initializing tool manager...")
            await self.tool_manager.initialize()
            logger.info("Tool manager initialized successfully.")
        except Exception as e:
            logger.error(f"Error initializing tool manager: {str(e)}")
            logger.warning("Will attempt to initialize tools on first use.")

        try:
            await self.initialize_client()
        except Exception as e:
            logger.error(f"Error initializing client: {str(e)}")
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to initialize client: {str(e)}") from e

    ###########################################
    # CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
    ###########################################

    async def initialize_client(self) -> None:
        """Create the OAICompat client used to talk to the UI-TARS endpoint.

        Implements the abstract method from BaseLoop.

        Raises:
            RuntimeError: If the client cannot be constructed; ``self.client``
                is reset to None in that case.
        """
        try:
            logger.info(f"Initializing OAICompat client for UI-TARS with model {self.model}...")

            self.client = OAICompatClient(
                api_key=self.api_key or "EMPTY",  # Local endpoints typically don't require an API key
                model=self.model,
                provider_base_url=self.provider_base_url,
            )

            logger.info(f"Initialized OAICompat client with model {self.model}")
        except Exception as e:
            logger.error(f"Error initializing client: {str(e)}")
            self.client = None
            raise RuntimeError(f"Failed to initialize client: {str(e)}") from e

    ###########################################
    # MESSAGE FORMATTING
    ###########################################

    def to_uitars_format(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert messages to the UI-TARS prompt format.

        The first user message that carries non-empty text is treated as the
        task instruction. Its text is replaced by the ``COMPUTER_USE`` template
        filled with that instruction, and any image parts are kept.
        ``add_box_token`` is applied to every text part of assistant messages.
        The input list is deep-copied and never mutated.

        Args:
            messages: List of messages in standard format

        Returns:
            A new list of messages formatted for UI-TARS
        """
        uitars_messages = copy.deepcopy(messages)

        # Locate the task instruction: first user message with non-empty text.
        first_user_idx = None
        instruction = ""
        for idx, msg in enumerate(uitars_messages):
            if msg.get("role") != "user":
                continue
            content = msg.get("content", "")
            if isinstance(content, str):
                instruction = content
            elif isinstance(content, list):
                # Use the first text part (non-dict parts are skipped).
                instruction = next(
                    (
                        item.get("text", "")
                        for item in content
                        if isinstance(item, dict) and item.get("type") == "text"
                    ),
                    "",
                )
            else:
                continue
            if instruction:
                first_user_idx = idx
                break

        if first_user_idx is not None:
            user_prompt = COMPUTER_USE.format(
                instruction=instruction,
                language="English",
            )
            target = uitars_messages[first_user_idx]
            if isinstance(target["content"], str):
                target["content"] = [{"type": "text", "text": user_prompt}]
            else:
                # Replace only the first text part, keeping images.
                for part in target["content"]:
                    if isinstance(part, dict) and part.get("type") == "text":
                        part["text"] = user_prompt
                        break

        # Add box tokens to assistant responses.
        for msg in uitars_messages:
            if msg.get("role") != "assistant":
                continue
            content = msg.get("content", "")
            if content and isinstance(content, list):
                for part in content:
                    if isinstance(part, dict) and part.get("type") == "text":
                        part["text"] = add_box_token(part["text"])

        return uitars_messages

    ###########################################
    # API CALL HANDLING
    ###########################################

    async def _make_api_call(self, messages: List[Dict[str, Any]], system_prompt: str) -> Any:
        """Call the provider with retry logic.

        Args:
            messages: Not used; the request is always built from
                ``self.message_manager.get_messages()``. Kept for interface
                compatibility.
            system_prompt: System prompt sent with the request.

        Returns:
            The response returned by ``self.client.run_interleaved``.

        Raises:
            RuntimeError: If all ``self.max_retries`` attempts fail; the last
                underlying error is attached as ``__cause__``.
        """
        # Create new turn directory for this API call
        self._create_turn_dir()

        request_data = None
        last_error = None

        for attempt in range(self.max_retries):
            is_last_attempt = attempt >= self.max_retries - 1
            try:
                # Ensure client is initialized
                if self.client is None:
                    logger.info(
                        f"Client not initialized in _make_api_call (attempt {attempt+1}), initializing now..."
                    )
                    await self.initialize_client()
                    if self.client is None:
                        raise RuntimeError("Failed to initialize client")

                # Convert messages to UI-TARS format
                prepared_messages = self.message_manager.get_messages()
                uitars_messages = self.to_uitars_format(prepared_messages)

                request_data = {
                    "messages": uitars_messages,
                    "max_tokens": self.max_tokens,
                    "system": system_prompt,
                }
                self._log_api_call("request", request_data)

                response = await self.client.run_interleaved(
                    messages=uitars_messages,
                    system=system_prompt,
                    max_tokens=self.max_tokens,
                )

                self._log_api_call("response", request_data, response)
                return response

            except (ConnectError, ReadTimeout) as e:
                last_error = e
                logger.warning(
                    f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
                )
                if not is_last_attempt:
                    # Linear backoff: retry_delay, 2*retry_delay, 3*retry_delay, ...
                    await asyncio.sleep(self.retry_delay * (attempt + 1))
                    # Reset client on connection errors to force re-initialization
                    self.client = None

            except RuntimeError as e:
                # Raised by client initialization above (or by the client itself).
                last_error = e
                self._log_api_call("error", request_data, error=e)
                logger.error(
                    f"Client initialization error (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
                )
                if not is_last_attempt:
                    # Reset client to force re-initialization
                    self.client = None
                    await asyncio.sleep(self.retry_delay)

            except Exception as e:
                last_error = e
                self._log_api_call("error", request_data, error=e)
                logger.error(f"Unexpected error in API call: {str(e)}")
                if not is_last_attempt:
                    await asyncio.sleep(self.retry_delay)

        # If we get here, all retries failed
        error_message = f"API call failed after {self.max_retries} attempts"
        if last_error:
            error_message += f": {str(last_error)}"

        logger.error(error_message)
        raise RuntimeError(error_message) from last_error

    ###########################################
    # RESPONSE AND ACTION HANDLING
    ###########################################

    async def _handle_response(
        self, response: Any, messages: List[Dict[str, Any]]
    ) -> Tuple[bool, bool]:
        """Record the model reply and execute the actions it contains.

        Args:
            response: OpenAI-style response dict (``choices[0].message.content``)
            messages: Not used; history is updated through ``self.message_manager``.

        Returns:
            Tuple of (should_continue, action_screenshot_saved). should_continue
            is False only when a ``finished`` action is reached.

        Raises:
            Exception: Any error outside per-action execution is re-raised after
                an error message is appended to the history.
        """
        action_screenshot_saved = False

        try:
            # Step 1: Extract the raw response text
            try:
                raw_text = response["choices"][0]["message"]["content"]
            except (KeyError, TypeError, IndexError) as e:
                logger.error(f"Invalid response format: {str(e)}")
                return True, action_screenshot_saved

            # Step 2: Add the response to message history
            self.message_manager.add_assistant_message([{"type": "text", "text": raw_text}])

            # Step 3: Parse actions from the response
            parsed_actions = parse_actions(raw_text)
            if not parsed_actions:
                logger.warning("No action found in the response")
                return True, action_screenshot_saved

            # Step 4: Execute each action
            for action in parsed_actions:
                if action.startswith("finished"):
                    logger.info("Agent completed the task")
                    return False, action_screenshot_saved

                # A failing action is logged and the remaining actions still run.
                try:
                    action_name, tool_args = parse_action_parameters(action)
                    if not action_name:
                        logger.warning(f"Could not parse action: {action}")
                        continue

                    if action_name in self._SCREENSHOT_ACTIONS:
                        action_screenshot_saved = True

                    await self._ensure_tools_initialized()

                    logger.info(f"Executing computer tool with arguments: {tool_args}")
                    result = await self.tool_manager.execute_tool(name="computer", tool_input=tool_args)

                    error = getattr(result, "error", None)
                    if error:
                        logger.error(f"Error executing tool: {error}")
                    else:
                        logger.info(f"Successfully executed {action_name}")

                        # Persist the screenshot returned by the tool, if any.
                        image = getattr(result, "base64_image", None)
                        if image:
                            self._save_screenshot(image, action_type=action_name)
                            action_screenshot_saved = True

                except Exception as e:
                    logger.error(f"Error executing action {action}: {str(e)}")

            return True, action_screenshot_saved

        except Exception as e:
            logger.error(f"Error handling response: {str(e)}")
            error_message = [{"type": "text", "text": f"Error: {str(e)}"}]
            self.message_manager.add_assistant_message(error_message)
            raise

    ###########################################
    # SCREEN HANDLING
    ###########################################

    async def _get_current_screen(self, save_screenshot: bool = True) -> str:
        """Capture the screen and return it base64-encoded.

        Args:
            save_screenshot: Also save it via ``_save_screenshot`` when
                ``self.save_trajectory`` is enabled.

        Returns:
            Base64 encoded screenshot (UTF-8 string)
        """
        try:
            screenshot = await self.computer.interface.screenshot()
            img_base64 = base64.b64encode(screenshot).decode("utf-8")

            # Hand the screenshot to BaseLoop.handle_screenshot (hook processing).
            await self.handle_screenshot(img_base64, action_type="state")

            if save_screenshot and self.save_trajectory:
                self._save_screenshot(img_base64, action_type="state")

            return img_base64

        except Exception as e:
            logger.error(f"Error getting current screen: {str(e)}")
            raise

    ###########################################
    # SYSTEM PROMPT
    ###########################################

    def _get_system_prompt(self) -> str:
        """Get the system prompt for the model."""
        return SYSTEM_PROMPT

    ###########################################
    # MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
    ###########################################

    @classmethod
    def _extract_thought(cls, raw_response: str) -> str:
        """Return the text following ``Thought:`` up to ``Action:``, or ``""``."""
        if "Thought:" not in raw_response:
            return ""
        match = cls._THOUGHT_PATTERN.search(raw_response)
        return match.group(1).strip() if match else ""

    async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
        """Run the agent loop with provided messages.

        Each iteration captures the screen, appends it to the history as a user
        message, queries the model and executes the returned actions. The loop
        stops when the model emits ``finished(...)`` or after 3 consecutive
        failed iterations.

        Args:
            messages: List of messages in standard OpenAI format

        Yields:
            Per turn, a thought message and (when actions were parsed) an action
            message, both with ``role``/``content``/``metadata`` keys. On
            failure, a dict with ``error`` and ``metadata`` keys.
        """
        # Shallow copy so the caller's list is not appended to.
        self.message_manager.messages = messages.copy()
        logger.info(f"Starting UITARSLoop run with {len(self.message_manager.messages)} messages")

        running = True
        turn_created = False

        attempt = 0
        max_attempts = 3

        while running and attempt < max_attempts:
            try:
                if not turn_created:
                    self._create_turn_dir()
                    turn_created = True

                if self.client is None:
                    logger.info("Initializing client...")
                    await self.initialize_client()
                    if self.client is None:
                        raise RuntimeError("Failed to initialize client")
                    logger.info("Client initialized successfully")

                base64_screenshot = await self._get_current_screen()
                self.message_manager.add_user_message(
                    [
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{base64_screenshot}"},
                        }
                    ]
                )
                logger.info("Added screenshot to message history")

                system_prompt = self._get_system_prompt()
                response = await self._make_api_call(
                    self.message_manager.messages, system_prompt
                )

                # Executes the actions; returns (should_continue, screenshot_saved).
                should_continue, _ = await self._handle_response(
                    response, self.message_manager.messages
                )

                # A malformed response raises here and counts as a failed attempt.
                raw_response = response["choices"][0]["message"]["content"]
                parsed_actions = parse_actions(raw_response)
                thought = self._extract_thought(raw_response)

                # Thoughts first, then actions.
                yield {
                    "role": "assistant",
                    "content": thought or raw_response,
                    "metadata": {
                        "title": "🧠 UI-TARS Thoughts"
                    },
                }
                if parsed_actions:
                    yield {
                        "role": "assistant",
                        "content": str(parsed_actions),
                        "metadata": {
                            "title": "🖱️ UI-TARS Actions",
                        },
                    }

                running = should_continue
                if running:
                    # The next turn gets a fresh turn directory.
                    turn_created = False

                # Reset attempt counter on success
                attempt = 0

            except Exception as e:
                attempt += 1
                error_msg = f"Error in run method (attempt {attempt}/{max_attempts}): {str(e)}"
                logger.error(error_msg)

                if attempt >= max_attempts:
                    logger.error(f"Maximum retry attempts reached. Last error was: {str(e)}")

                yield {
                    "error": str(e),
                    "metadata": {"title": "❌ Error"},
                }

                # Pause only when another attempt will actually follow.
                if attempt < max_attempts:
                    await asyncio.sleep(1)

    ###########################################
    # UTILITY METHODS
    ###########################################

    async def _ensure_tools_initialized(self) -> None:
        """Initialize the tool manager if its ``tools`` are missing or None."""
        if getattr(self.tool_manager, "tools", None) is None:
            logger.info("Tools not initialized. Initializing now...")
            await self.tool_manager.initialize()
            logger.info("Tools initialized successfully.")

    async def process_model_response(self, response_text: str) -> Optional[Dict[str, Any]]:
        """Process model response to extract tool calls.

        UI-TARS does not use the standard tool-call format, so actions are
        parsed from the text with ``parse_actions``.

        Args:
            response_text: Model response text

        Returns:
            ``{"actions": [...]}``, or None if no action was found
        """
        parsed_actions = parse_actions(response_text)
        return {"actions": parsed_actions} if parsed_actions else None
@@ -0,0 +1,59 @@
1
+ """Prompts for UI-TARS agent."""
2
+
3
# System prompt sent with every request (returned by UITARSLoop._get_system_prompt).
SYSTEM_PROMPT = "You are a helpful assistant."

# Desktop prompt template that replaces the text of the first user message.
# It is filled with str.format(instruction=..., language=...), so it must not
# contain any other literal braces. Backslashes are doubled so the rendered
# prompt shows single-backslash escapes (e.g. \' and \n) to the model.
COMPUTER_USE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.

## Output Format
```
Thought: ...
Action: ...
```

## Action Space

click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.


## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.

## User Instruction
{instruction}
"""

# Mobile variant of the template (long_press / open_app / press_home /
# press_back action space); not referenced by the loop code shown here.
MOBILE_USE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space

click(start_box='<|box_start|>(x1,y1)<|box_end|>')
long_press(start_box='<|box_start|>(x1,y1)<|box_end|>')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
open_app(app_name=\'\')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
press_home()
press_back()
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.


## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.

## User Instruction
{instruction}
"""
@@ -0,0 +1 @@
1
+ """UI-TARS tools package."""