cua-agent 0.1.35__py3-none-any.whl → 0.1.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


agent/core/factory.py CHANGED
@@ -116,6 +116,7 @@ class LoopFactory:
                 base_dir=trajectory_dir,
                 only_n_most_recent_images=only_n_most_recent_images,
                 provider_base_url=provider_base_url,
+                provider=provider,
             )
         else:
             raise ValueError(f"Unsupported loop type: {loop_type}")
agent/core/provider_config.py CHANGED
@@ -8,6 +8,7 @@ DEFAULT_MODELS = {
     LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
     LLMProvider.OLLAMA: "gemma3:4b-it-q4_K_M",
     LLMProvider.OAICOMPAT: "Qwen2.5-VL-7B-Instruct",
+    LLMProvider.MLXVLM: "mlx-community/UI-TARS-1.5-7B-4bit",
 }
 
 # Map providers to their environment variable names
@@ -16,4 +17,5 @@ ENV_VARS = {
     LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
     LLMProvider.OLLAMA: "none",
     LLMProvider.OAICOMPAT: "none",  # OpenAI-compatible API typically doesn't require an API key
+    LLMProvider.MLXVLM: "none",  # MLX VLM typically doesn't require an API key
 }
agent/core/types.py CHANGED
@@ -23,6 +23,7 @@ class LLMProvider(StrEnum):
     OPENAI = "openai"
     OLLAMA = "ollama"
     OAICOMPAT = "oaicompat"
+    MLXVLM = "mlxvlm"
 
 
 @dataclass
agent/providers/uitars/clients/mlxvlm.py ADDED
@@ -0,0 +1,263 @@
+"""MLX VLM client implementation."""
+
+import io
+import logging
+import base64
+import tempfile
+import os
+import re
+import math
+from typing import Dict, List, Optional, Any, cast, Tuple
+from PIL import Image
+
+from .base import BaseUITarsClient
+import mlx.core as mx
+from mlx_vlm import load, generate
+from mlx_vlm.prompt_utils import apply_chat_template
+from mlx_vlm.utils import load_config
+from transformers.tokenization_utils import PreTrainedTokenizer
+
+logger = logging.getLogger(__name__)
+
+# Constants for smart_resize
+IMAGE_FACTOR = 28
+MIN_PIXELS = 100 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+def round_by_factor(number: float, factor: int) -> int:
+    """Returns the closest integer to 'number' that is divisible by 'factor'."""
+    return round(number / factor) * factor
+
+def ceil_by_factor(number: float, factor: int) -> int:
+    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+    return math.ceil(number / factor) * factor
+
+def floor_by_factor(number: float, factor: int) -> int:
+    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+    return math.floor(number / factor) * factor
+
+def smart_resize(
+    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+) -> tuple[int, int]:
+    """
+    Rescales the image so that the following conditions are met:
+
+    1. Both dimensions (height and width) are divisible by 'factor'.
+    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+    3. The aspect ratio of the image is maintained as closely as possible.
+    """
+    if max(height, width) / min(height, width) > MAX_RATIO:
+        raise ValueError(
+            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+        )
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = floor_by_factor(height / beta, factor)
+        w_bar = floor_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = ceil_by_factor(height * beta, factor)
+        w_bar = ceil_by_factor(width * beta, factor)
+    return h_bar, w_bar
+
+class MLXVLMUITarsClient(BaseUITarsClient):
+    """MLX VLM client implementation class."""
+
+    def __init__(
+        self,
+        model: str = "mlx-community/UI-TARS-1.5-7B-4bit"
+    ):
+        """Initialize MLX VLM client.
+
+        Args:
+            model: Model name or path (defaults to mlx-community/UI-TARS-1.5-7B-4bit)
+        """
+        # Load model and processor
+        model_obj, processor = load(
+            model,
+            processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
+        )
+        self.config = load_config(model)
+        self.model = model_obj
+        self.processor = processor
+        self.model_name = model
+
+    def _process_coordinates(self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]) -> str:
+        """Process coordinates in box tokens based on image resizing using smart_resize approach.
+
+        Args:
+            text: Text containing box tokens
+            original_size: Original image size (width, height)
+            model_size: Model processed image size (width, height)
+
+        Returns:
+            Text with processed coordinates
+        """
+        # Find all box tokens
+        box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"
+
+        def process_coords(match):
+            model_x, model_y = int(match.group(1)), int(match.group(2))
+            # Scale coordinates from model space to original image space
+            # Both original_size and model_size are in (width, height) format
+            new_x = int(model_x * original_size[0] / model_size[0])  # Width
+            new_y = int(model_y * original_size[1] / model_size[1])  # Height
+            return f"<|box_start|>({new_x},{new_y})<|box_end|>"
+
+        return re.sub(box_pattern, process_coords, text)
+
+    async def run_interleaved(
+        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
+    ) -> Dict[str, Any]:
+        """Run interleaved chat completion.
+
+        Args:
+            messages: List of message dicts
+            system: System prompt
+            max_tokens: Optional max tokens override
+
+        Returns:
+            Response dict
+        """
+        # Ensure the system message is included
+        if not any(msg.get("role") == "system" for msg in messages):
+            messages = [{"role": "system", "content": system}] + messages
+
+        # Create a copy of messages to avoid modifying the original
+        processed_messages = messages.copy()
+
+        # Extract images and process messages
+        images = []
+        original_sizes = {}  # Track original sizes of images for coordinate mapping
+        model_sizes = {}  # Track model processed sizes
+        image_index = 0
+
+        for msg_idx, msg in enumerate(messages):
+            content = msg.get("content", [])
+            if not isinstance(content, list):
+                continue
+
+            # Create a copy of the content list to modify
+            processed_content = []
+
+            for item_idx, item in enumerate(content):
+                if item.get("type") == "image_url":
+                    image_url = item.get("image_url", {}).get("url", "")
+                    pil_image = None
+
+                    if image_url.startswith("data:image/"):
+                        # Extract base64 data
+                        base64_data = image_url.split(',')[1]
+                        # Convert base64 to PIL Image
+                        image_data = base64.b64decode(base64_data)
+                        pil_image = Image.open(io.BytesIO(image_data))
+                    else:
+                        # Handle file path or URL
+                        pil_image = Image.open(image_url)
+
+                    # Store original image size for coordinate mapping
+                    original_size = pil_image.size
+                    original_sizes[image_index] = original_size
+
+                    # Use smart_resize to determine model size
+                    # Note: smart_resize expects (height, width) but PIL gives (width, height)
+                    height, width = original_size[1], original_size[0]
+                    new_height, new_width = smart_resize(height, width)
+                    # Store model size in (width, height) format for consistent coordinate processing
+                    model_sizes[image_index] = (new_width, new_height)
+
+                    # Resize the image using the calculated dimensions from smart_resize
+                    resized_image = pil_image.resize((new_width, new_height))
+                    images.append(resized_image)
+                    image_index += 1
+
+                # Copy items to processed content list
+                processed_content.append(item.copy())
+
+            # Update the processed message content
+            processed_messages[msg_idx] = msg.copy()
+            processed_messages[msg_idx]["content"] = processed_content
+
+        logger.info(f"resized {len(images)} from {original_sizes[0]} to {model_sizes[0]}")
+
+        # Process user text input with box coordinates after image processing
+        # Swap original_size and model_size arguments for inverse transformation
+        for msg_idx, msg in enumerate(processed_messages):
+            if msg.get("role") == "user" and isinstance(msg.get("content"), str):
+                if "<|box_start|>" in msg.get("content") and original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
+                    orig_size = original_sizes[0]
+                    model_size = model_sizes[0]
+                    # Swap arguments to perform inverse transformation for user input
+                    processed_messages[msg_idx]["content"] = self._process_coordinates(msg["content"], model_size, orig_size)
+
+        try:
+            # Format prompt according to model requirements using the processor directly
+            prompt = self.processor.apply_chat_template(
+                processed_messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+            tokenizer = cast(PreTrainedTokenizer, self.processor)
+
+            print("generating response...")
+
+            # Generate response
+            text_content, usage = generate(
+                self.model,
+                tokenizer,
+                str(prompt),
+                images,
+                verbose=False,
+                max_tokens=max_tokens
+            )
+
+            from pprint import pprint
+            print("DEBUG - AGENT GENERATION --------")
+            pprint(text_content)
+            print("DEBUG - AGENT GENERATION --------")
+        except Exception as e:
+            logger.error(f"Error generating response: {str(e)}")
+            return {
+                "choices": [
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": f"Error generating response: {str(e)}"
+                        },
+                        "finish_reason": "error"
+                    }
+                ],
+                "model": self.model_name,
+                "error": str(e)
+            }
+
+        # Process coordinates in the response back to original image space
+        if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
+            # Get original image size and model size (using the first image)
+            orig_size = original_sizes[0]
+            model_size = model_sizes[0]
+
+            # Check if output contains box tokens that need processing
+            if "<|box_start|>" in text_content:
+                # Process coordinates from model space back to original image space
+                text_content = self._process_coordinates(text_content, orig_size, model_size)
+
+        # Format response to match OpenAI format
+        response = {
+            "choices": [
+                {
+                    "message": {
+                        "role": "assistant",
+                        "content": text_content
+                    },
+                    "finish_reason": "stop"
+                }
+            ],
+            "model": self.model_name,
+            "usage": usage
+        }
+
+        return response
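
For intuition about the `smart_resize` and box-token logic in the new client above, here is a minimal standalone sketch (the constants are copied from the diff; the 1512×982 screenshot size and the (756, 490) box coordinate are hypothetical examples):

```python
import math

# Constants copied from mlxvlm.py above.
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28


def smart_resize(height: int, width: int, factor: int = IMAGE_FACTOR,
                 min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS) -> tuple[int, int]:
    # Snap both sides to multiples of `factor`, then rescale into the pixel budget.
    h_bar = max(factor, round(height / factor) * factor)
    w_bar = max(factor, round(width / factor) * factor)
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


# A hypothetical 1512x982 (width x height) screenshot: both sides snap to
# multiples of 28 and the pixel count already fits the budget, so only the
# height changes (982 -> 980).
orig_w, orig_h = 1512, 982
model_h, model_w = smart_resize(orig_h, orig_w)
assert (model_w, model_h) == (1512, 980)

# The model emits box coordinates in the resized space; _process_coordinates
# scales them back to the original screenshot, here (756, 490) -> (756, 491).
model_x, model_y = 756, 490
screen_x = int(model_x * orig_w / model_w)  # 756
screen_y = int(model_y * orig_h / model_h)  # 491
print(screen_x, screen_y)
```

Because the resize preserves aspect ratio and only snaps dimensions to multiples of 28, the x/y scale factors stay close to 1 and the round trip loses at most a pixel.
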
agent/providers/uitars/loop.py CHANGED
@@ -23,6 +23,7 @@ from .tools.computer import ToolResult
 from .prompts import COMPUTER_USE, SYSTEM_PROMPT, MAC_SPECIFIC_NOTES
 
 from .clients.oaicompat import OAICompatClient
+from .clients.mlxvlm import MLXVLMUITarsClient
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -44,6 +45,7 @@ class UITARSLoop(BaseLoop):
         computer: Computer,
         api_key: str,
         model: str,
+        provider: Optional[LLMProvider] = None,
         provider_base_url: Optional[str] = "http://localhost:8000/v1",
         only_n_most_recent_images: Optional[int] = 2,
         base_dir: Optional[str] = "trajectories",
@@ -64,9 +66,10 @@ class UITARSLoop(BaseLoop):
             max_retries: Maximum number of retries for API calls
             retry_delay: Delay between retries in seconds
             save_trajectory: Whether to save trajectory data
+            provider: The LLM provider to use (defaults to OAICOMPAT if not specified)
         """
         # Set provider before initializing base class
-        self.provider = LLMProvider.OAICOMPAT
+        self.provider = provider or LLMProvider.OAICOMPAT
         self.provider_base_url = provider_base_url
 
         # Initialize message manager with image retention config
@@ -113,7 +116,7 @@ class UITARSLoop(BaseLoop):
             logger.error(f"Error initializing tool manager: {str(e)}")
             logger.warning("Will attempt to initialize tools on first use.")
 
-        # Initialize client for the OAICompat provider
+        # Initialize client for the selected provider
         try:
             await self.initialize_client()
         except Exception as e:
@@ -128,18 +131,28 @@ class UITARSLoop(BaseLoop):
         """Initialize the appropriate client.
 
         Implements abstract method from BaseLoop to set up the specific
-        provider client (OAICompat for UI-TARS).
+        provider client based on the configured provider.
         """
         try:
-            logger.info(f"Initializing OAICompat client for UI-TARS with model {self.model}...")
-
-            self.client = OAICompatClient(
-                api_key=self.api_key or "EMPTY",  # Local endpoints typically don't require an API key
-                model=self.model,
-                provider_base_url=self.provider_base_url,
-            )
-
-            logger.info(f"Initialized OAICompat client with model {self.model}")
+            if self.provider == LLMProvider.MLXVLM:
+                logger.info(f"Initializing MLX VLM client for UI-TARS with model {self.model}...")
+
+                self.client = MLXVLMUITarsClient(
+                    model=self.model,
+                )
+
+                logger.info(f"Initialized MLX VLM client with model {self.model}")
+            else:
+                # Default to OAICompat client for other providers
+                logger.info(f"Initializing OAICompat client for UI-TARS with model {self.model}...")
+
+                self.client = OAICompatClient(
+                    api_key=self.api_key or "EMPTY",  # Local endpoints typically don't require an API key
+                    model=self.model,
+                    provider_base_url=self.provider_base_url,
+                )
+
+                logger.info(f"Initialized OAICompat client with model {self.model}")
         except Exception as e:
             logger.error(f"Error initializing client: {str(e)}")
             self.client = None
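
A hedged sketch (not the package's code) of the control flow these hunks wire up: `LoopFactory` now forwards `provider` into `UITARSLoop`, and `initialize_client()` dispatches on it, falling back to the OpenAI-compatible HTTP client when no provider is given:

```python
from enum import Enum
from typing import Optional


class LLMProvider(str, Enum):
    OAICOMPAT = "oaicompat"
    MLXVLM = "mlxvlm"


class LoopSketch:
    """Toy stand-in for UITARSLoop's provider handling."""

    def __init__(self, model: str, provider: Optional[LLMProvider] = None,
                 provider_base_url: Optional[str] = "http://localhost:8000/v1"):
        # Mirrors `self.provider = provider or LLMProvider.OAICOMPAT` above.
        self.provider = provider or LLMProvider.OAICOMPAT
        self.model = model
        self.provider_base_url = provider_base_url

    def client_kind(self) -> str:
        # Mirrors the branch added to initialize_client().
        if self.provider == LLMProvider.MLXVLM:
            return f"MLXVLMUITarsClient(model={self.model!r})"
        return f"OAICompatClient(model={self.model!r}, base_url={self.provider_base_url!r})"


print(LoopSketch("mlx-community/UI-TARS-1.5-7B-4bit", LLMProvider.MLXVLM).client_kind())
print(LoopSketch("tgi").client_kind())  # no provider -> OpenAI-compatible client
```
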
agent/providers/uitars/utils.py CHANGED
@@ -105,7 +105,7 @@ async def to_agent_response_format(
             }
         ],
         truncation="auto",
-        usage=response["usage"],
+        usage=response.get("usage", {}),
         user=None,
         metadata={},
         response=response
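
The `.get` guard above matters because the new MLX client's error path returns a response dict with an `"error"` key but no `"usage"` key, so indexing `response["usage"]` would raise `KeyError` (illustrative snippet, not package code):

```python
# Shape of an error response from MLXVLMUITarsClient.run_interleaved (see above).
error_response = {"choices": [], "model": "mlx-community/UI-TARS-1.5-7B-4bit", "error": "boom"}

usage = error_response.get("usage", {})  # -> {} instead of raising KeyError
print(usage)
```
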
agent/ui/gradio/app.py CHANGED
@@ -6,7 +6,7 @@ with an advanced UI for model selection and configuration.
 
 Supported Agent Loops and Models:
 - AgentLoop.OPENAI: Uses OpenAI Operator CUA model
-  • computer_use_preview
+  • computer-use-preview
 
 - AgentLoop.ANTHROPIC: Uses Anthropic Computer-Use models
   • claude-3-5-sonnet-20240620
@@ -133,12 +133,12 @@ class GradioChatScreenshotHandler(DefaultCallbackHandler):
 MODEL_MAPPINGS = {
     "openai": {
         # Default to operator CUA model
-        "default": "computer_use_preview",
+        "default": "computer-use-preview",
         # Map standard OpenAI model names to CUA-specific model names
-        "gpt-4-turbo": "computer_use_preview",
-        "gpt-4o": "computer_use_preview",
-        "gpt-4": "computer_use_preview",
-        "gpt-4.5-preview": "computer_use_preview",
+        "gpt-4-turbo": "computer-use-preview",
+        "gpt-4o": "computer-use-preview",
+        "gpt-4": "computer-use-preview",
+        "gpt-4.5-preview": "computer-use-preview",
         "gpt-4o-mini": "gpt-4o-mini",
     },
     "anthropic": {
@@ -164,8 +164,10 @@ MODEL_MAPPINGS = {
         "claude-3-7-sonnet-20250219": "claude-3-7-sonnet-20250219",
     },
     "uitars": {
-        # UI-TARS models default to custom endpoint
-        "default": "ByteDance-Seed/UI-TARS-1.5-7B",
+        # UI-TARS models using MLXVLM provider
+        "default": "mlx-community/UI-TARS-1.5-7B-4bit",
+        "mlx-community/UI-TARS-1.5-7B-4bit": "mlx-community/UI-TARS-1.5-7B-4bit",
+        "mlx-community/UI-TARS-1.5-7B-6bit": "mlx-community/UI-TARS-1.5-7B-6bit"
     },
     "ollama": {
         # For Ollama models, we keep the original name
@@ -215,7 +217,7 @@ def get_provider_and_model(model_name: str, loop_provider: str) -> tuple:
     # Determine provider and clean model name based on the full string from UI
     cleaned_model_name = model_name  # Default to using the name as-is (for custom)
 
-    if model_name == "Custom model...":
+    if model_name == "Custom model (OpenAI compatible API)":
         # Actual model name comes from custom_model_value via model_to_use.
         # Assume OAICOMPAT for custom models unless overridden by URL/key later?
         # get_provider_and_model determines the *initial* provider/model.
@@ -276,8 +278,8 @@ def get_provider_and_model(model_name: str, loop_provider: str) -> tuple:
             break
         # Note: No fallback needed here as we explicitly check against omni keys
 
-    else:  # Handles unexpected formats or the raw custom name if "Custom model..." selected
-        # Should only happen if user selected "Custom model..."
+    else:  # Handles unexpected formats or the raw custom name if "Custom model (OpenAI compatible API)" selected
+        # Should only happen if user selected "Custom model (OpenAI compatible API)"
         # Or if a model name format isn't caught above
         provider = LLMProvider.OAICOMPAT
         cleaned_model_name = (
@@ -288,8 +290,16 @@ def get_provider_and_model(model_name: str, loop_provider: str) -> tuple:
         model_name_to_use = cleaned_model_name
         # agent_loop remains AgentLoop.OMNI
     elif agent_loop == AgentLoop.UITARS:
-        provider = LLMProvider.OAICOMPAT
-        model_name_to_use = MODEL_MAPPINGS["uitars"]["default"]  # Default
+        # For UITARS, use MLXVLM provider for the MLX models, OAICOMPAT for custom
+        if model_name == "Custom model (OpenAI compatible API)":
+            provider = LLMProvider.OAICOMPAT
+            model_name_to_use = "tgi"
+        else:
+            provider = LLMProvider.MLXVLM
+            # Get the model name from the mappings or use as-is if not found
+            model_name_to_use = MODEL_MAPPINGS["uitars"].get(
+                model_name, model_name if model_name else MODEL_MAPPINGS["uitars"]["default"]
+            )
     else:
         # Default to OpenAI if unrecognized loop
         provider = LLMProvider.OPENAI
@@ -439,8 +449,12 @@ def create_gradio_ui(
     provider_to_models = {
         "OPENAI": openai_models,
         "ANTHROPIC": anthropic_models,
-        "OMNI": omni_models + ["Custom model..."],  # Add custom model option
-        "UITARS": ["Custom model..."],  # UI-TARS options
+        "OMNI": omni_models + ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],  # Add custom model options
+        "UITARS": [
+            "mlx-community/UI-TARS-1.5-7B-4bit",
+            "mlx-community/UI-TARS-1.5-7B-6bit",
+            "Custom model (OpenAI compatible API)"
+        ],  # UI-TARS options with MLX models
     }
 
     # --- Apply Saved Settings (override defaults if available) ---
@@ -460,9 +474,9 @@ def create_gradio_ui(
         initial_model = anthropic_models[0] if anthropic_models else "No models available"
     else:  # OMNI
         initial_model = omni_models[0] if omni_models else "No models available"
-        if "Custom model..." in available_models_for_loop:
+        if "Custom model (OpenAI compatible API)" in available_models_for_loop:
             initial_model = (
-                "Custom model..."  # Default to custom if available and no other default fits
+                "Custom model (OpenAI compatible API)"  # Default to custom if available and no other default fits
             )
 
     initial_custom_model = saved_settings.get("custom_model", "Qwen2.5-VL-7B-Instruct")
@@ -485,7 +499,7 @@
 
     Args:
         agent_loop_choice: The agent loop type (e.g., UITARS, OPENAI, ANTHROPIC, OMNI)
-        provider: The provider type (e.g., OPENAI, ANTHROPIC, OLLAMA, OAICOMPAT)
+        provider: The provider type (e.g., OPENAI, ANTHROPIC, OLLAMA, OAICOMPAT, MLXVLM)
         model_name: The model name
         tasks: List of tasks to execute
         provider_url: The provider base URL for OAICOMPAT providers
@@ -514,14 +528,58 @@ async def main():
             only_n_most_recent_images={recent_images},
             save_trajectory={save_trajectory},'''
 
-    # Add the model configuration based on provider
-    if provider == LLMProvider.OAICOMPAT:
+    # Add the model configuration based on provider and agent loop
+    if agent_loop_choice == "OPENAI":
+        # For OPENAI loop, always use OPENAI provider with computer-use-preview
+        code += f'''
+        model=LLM(
+            provider=LLMProvider.OPENAI,
+            name="computer-use-preview"
+        )'''
+    elif agent_loop_choice == "ANTHROPIC":
+        # For ANTHROPIC loop, always use ANTHROPIC provider
         code += f'''
+        model=LLM(
+            provider=LLMProvider.ANTHROPIC,
+            name="{model_name}"
+        )'''
+    elif agent_loop_choice == "UITARS":
+        # For UITARS, use MLXVLM for mlx-community models, OAICOMPAT for others
+        if provider == LLMProvider.MLXVLM:
+            code += f'''
+        model=LLM(
+            provider=LLMProvider.MLXVLM,
+            name="{model_name}"
+        )'''
+        else:  # OAICOMPAT
+            code += f'''
+        model=LLM(
+            provider=LLMProvider.OAICOMPAT,
+            name="{model_name}",
+            provider_base_url="{provider_url}"
+        )'''
+    elif agent_loop_choice == "OMNI":
+        # For OMNI, provider can be OPENAI, ANTHROPIC, OLLAMA, or OAICOMPAT
+        if provider == LLMProvider.OAICOMPAT:
+            code += f'''
         model=LLM(
             provider=LLMProvider.OAICOMPAT,
             name="{model_name}",
             provider_base_url="{provider_url}"
         )'''
+        else:  # OPENAI, ANTHROPIC, OLLAMA
+            code += f'''
+        model=LLM(
+            provider=LLMProvider.{provider.name},
+            name="{model_name}"
+        )'''
+    else:
+        # Default case - just use the provided provider and model
+        code += f'''
+        model=LLM(
+            provider=LLMProvider.{provider.name},
+            name="{model_name}"
+        )'''
 
     code += """
     )
@@ -547,6 +605,8 @@ async def main():
     print(f"Executing task: {{task}}")
     async for result in agent.run(task):
         print(result)'''
+
+
 
     # Add the main block
     code += '''
@@ -556,62 +616,6 @@ if __name__ == "__main__":
 
     return code
 
-    # Function to update model choices based on agent loop selection
-    def update_model_choices(loop):
-        models = provider_to_models.get(loop, [])
-        if loop == "OMNI":
-            # For OMNI, include the custom model option
-            if not models:
-                models = ["Custom model..."]
-            elif "Custom model..." not in models:
-                models.append("Custom model...")
-
-            # Show both OpenAI and Anthropic key inputs for OMNI if keys aren't set
-            return [
-                gr.update(choices=models, value=models[0] if models else "Custom model...", interactive=True),
-                gr.update(visible=not has_openai_key),
-                gr.update(visible=not has_anthropic_key)
-            ]
-        elif loop == "OPENAI":
-            # Show only OpenAI key input for OPENAI loop if key isn't set
-            if not models:
-                return [
-                    gr.update(choices=["No models available"], value="No models available", interactive=True),
-                    gr.update(visible=not has_openai_key),
-                    gr.update(visible=False)
-                ]
-            return [
-                gr.update(choices=models, value=models[0] if models else None, interactive=True),
-                gr.update(visible=not has_openai_key),
-                gr.update(visible=False)
-            ]
-        elif loop == "ANTHROPIC":
-            # Show only Anthropic key input for ANTHROPIC loop if key isn't set
-            if not models:
-                return [
-                    gr.update(choices=["No models available"], value="No models available", interactive=True),
-                    gr.update(visible=False),
-                    gr.update(visible=not has_anthropic_key)
-                ]
-            return [
-                gr.update(choices=models, value=models[0] if models else None, interactive=True),
-                gr.update(visible=False),
-                gr.update(visible=not has_anthropic_key)
-            ]
-        else:
-            # For other providers (like UITARS), don't show API key inputs
-            if not models:
-                return [
-                    gr.update(choices=["No models available"], value="No models available", interactive=True),
-                    gr.update(visible=False),
-                    gr.update(visible=False)
-                ]
-            return [
-                gr.update(choices=models, value=models[0] if models else None, interactive=True),
-                gr.update(visible=False),
-                gr.update(visible=False)
-            ]
-
     # Create the Gradio interface with advanced UI
     with gr.Blocks(title="Computer-Use Agent") as demo:
         with gr.Row():
@@ -670,14 +674,52 @@ if __name__ == "__main__":
                     info="Select the agent loop provider",
                 )
 
-                # Create model selection dropdown with custom value support for OMNI
-                model_choice = gr.Dropdown(
-                    choices=provider_to_models.get(initial_loop, ["No models available"]),
-                    label="LLM Provider and Model",
-                    value=initial_model,
-                    info="Select model or choose 'Custom model...' to enter a custom name",
-                    interactive=True,
-                )
+
+                # Create separate model selection dropdowns for each provider type
+                # This avoids the Gradio bug with updating choices
+                with gr.Group() as model_selection_group:
+                    # OpenAI models dropdown
+                    openai_model_choice = gr.Dropdown(
+                        choices=openai_models,
+                        label="OpenAI Model",
+                        value=openai_models[0] if openai_models else "No models available",
+                        info="Select OpenAI model",
+                        interactive=True,
+                        visible=(initial_loop == "OPENAI")
+                    )
+
+                    # Anthropic models dropdown
+                    anthropic_model_choice = gr.Dropdown(
+                        choices=anthropic_models,
+                        label="Anthropic Model",
+                        value=anthropic_models[0] if anthropic_models else "No models available",
+                        info="Select Anthropic model",
+                        interactive=True,
+                        visible=(initial_loop == "ANTHROPIC")
+                    )
+
+                    # OMNI models dropdown
+                    omni_model_choice = gr.Dropdown(
+                        choices=omni_models + ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],
+                        label="OMNI Model",
+                        value=omni_models[0] if omni_models else "Custom model (OpenAI compatible API)",
+                        info="Select OMNI model or choose a custom model option",
+                        interactive=True,
+                        visible=(initial_loop == "OMNI")
+                    )
+
+                    # UITARS models dropdown
+                    uitars_model_choice = gr.Dropdown(
+                        choices=provider_to_models.get("UITARS", ["No models available"]),
+                        label="UITARS Model",
+                        value=provider_to_models.get("UITARS", ["No models available"])[0] if provider_to_models.get("UITARS") else "No models available",
+                        info="Select UITARS model",
+                        interactive=True,
+                        visible=(initial_loop == "UITARS")
+                    )
+
+                    # Hidden field to store the selected model (for compatibility with existing code)
+                    model_choice = gr.Textbox(visible=False)
 
                 # Add API key inputs for OpenAI and Anthropic
                 with gr.Group(visible=not has_openai_key and (initial_loop == "OPENAI" or initial_loop == "OMNI")) as openai_key_group:
@@ -699,34 +741,177 @@ if __name__ == "__main__":
                         type="password",
                         info="Required for Anthropic models"
                     )
+
+                # Function to set OpenAI API key environment variable
+                def set_openai_api_key(key):
+                    if key and key.strip():
+                        os.environ["OPENAI_API_KEY"] = key.strip()
+                        print(f"DEBUG - Set OpenAI API key environment variable")
+                    return key
+
+                # Function to set Anthropic API key environment variable
+                def set_anthropic_api_key(key):
+                    if key and key.strip():
+                        os.environ["ANTHROPIC_API_KEY"] = key.strip()
+                        print(f"DEBUG - Set Anthropic API key environment variable")
+                    return key
+
+                # Add change event handlers for API key inputs
+                openai_api_key_input.change(
+                    fn=set_openai_api_key,
+                    inputs=[openai_api_key_input],
+                    outputs=[openai_api_key_input],
+                    queue=False
+                )
+
+                anthropic_api_key_input.change(
+                    fn=set_anthropic_api_key,
+                    inputs=[anthropic_api_key_input],
+                    outputs=[anthropic_api_key_input],
+                    queue=False
+                )
 
-                # Add custom model textbox (only visible when "Custom model..." is selected)
+                # Combined function to update UI based on selections
+                def update_ui(loop=None, openai_model=None, anthropic_model=None, omni_model=None, uitars_model=None):
+                    # Default values if not provided
+                    loop = loop or agent_loop.value
+
+                    # Determine which model value to use for custom model checks
+                    model_value = None
+                    if loop == "OPENAI" and openai_model:
+                        model_value = openai_model
+                    elif loop == "ANTHROPIC" and anthropic_model:
+                        model_value = anthropic_model
+                    elif loop == "OMNI" and omni_model:
+                        model_value = omni_model
+                    elif loop == "UITARS" and uitars_model:
+                        model_value = uitars_model
+
+                    # Show/hide appropriate model dropdown based on loop selection
+                    openai_visible = (loop == "OPENAI")
+                    anthropic_visible = (loop == "ANTHROPIC")
+                    omni_visible = (loop == "OMNI")
+                    uitars_visible = (loop == "UITARS")
+
+                    # Show/hide API key inputs based on loop selection
+                    show_openai_key = not has_openai_key and (loop == "OPENAI" or (loop == "OMNI" and model_value and "OpenAI" in model_value and "Custom" not in model_value))
+                    show_anthropic_key = not has_anthropic_key and (loop == "ANTHROPIC" or (loop == "OMNI" and model_value and "Claude" in model_value and "Custom" not in model_value))
+
+                    # Determine custom model visibility
+                    is_custom_openai_api = model_value == "Custom model (OpenAI compatible API)"
+                    is_custom_ollama = model_value == "Custom model (ollama)"
+                    is_any_custom = is_custom_openai_api or is_custom_ollama
+
+                    # Update the hidden model_choice field based on the visible dropdown
+                    model_choice_value = model_value if model_value else ""
+
+                    # Return all UI updates
+                    return [
+                        # Model dropdowns visibility
+                        gr.update(visible=openai_visible),
+                        gr.update(visible=anthropic_visible),
+                        gr.update(visible=omni_visible),
+                        gr.update(visible=uitars_visible),
+                        # API key inputs visibility
+                        gr.update(visible=show_openai_key),
+                        gr.update(visible=show_anthropic_key),
+                        # Custom model fields visibility
+                        gr.update(visible=is_any_custom),  # Custom model name always visible for any custom option
+                        gr.update(visible=is_custom_openai_api),  # Provider base URL only for OpenAI compatible API
+                        gr.update(visible=is_custom_openai_api),  # Provider API key only for OpenAI compatible API
+                        # Update the hidden model_choice field
+                        gr.update(value=model_choice_value)
+                    ]
+
+                # Add custom model textbox (visible for both custom model options)
                 custom_model = gr.Textbox(
                     label="Custom Model Name",
-                    placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct)",
+                    placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct or llama3)",
                     value=initial_custom_model,
-                    visible=(initial_model == "Custom model..."),
+                    visible=(initial_model == "Custom model (OpenAI compatible API)" or initial_model == "Custom model (ollama)"),
                     interactive=True,
                 )
 
-                # Add custom provider base URL textbox (only visible when "Custom model..." is selected)
+                # Add custom provider base URL textbox (only visible for OpenAI compatible API)
                 provider_base_url = gr.Textbox(
                     label="Provider Base URL",
                     placeholder="Enter provider base URL (e.g., http://localhost:1234/v1)",
                     value=initial_provider_base_url,
-                    visible=(initial_model == "Custom model..."),
+                    visible=(initial_model == "Custom model (OpenAI compatible API)"),
                     interactive=True,
                 )
 
-                # Add custom API key textbox (only visible when "Custom model..." is selected)
+                # Add custom API key textbox (only visible for OpenAI compatible API)
                 provider_api_key = gr.Textbox(
                     label="Provider API Key",
                     placeholder="Enter provider API key (if required)",
                     value="",
-                    visible=(initial_model == "Custom model..."),
+                    visible=(initial_model == "Custom model (OpenAI compatible API)"),
                     interactive=True,
                     type="password",
                 )
+
+                # Connect agent_loop changes to update all UI elements
+                agent_loop.change(
+                    fn=update_ui,
+                    inputs=[agent_loop, openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice],
+                    outputs=[
+                        openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice,
+                        openai_key_group, anthropic_key_group,
+                        custom_model, provider_base_url, provider_api_key,
+                        model_choice  # Add model_choice to outputs
+                    ],
+                    queue=False  # Process immediately without queueing
+                )
+
+                # Connect each model dropdown to update UI
+                omni_model_choice.change(
+                    fn=update_ui,
+                    inputs=[agent_loop, openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice],
+                    outputs=[
+                        openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice,
+                        openai_key_group, anthropic_key_group,
+                        custom_model, provider_base_url, provider_api_key,
+                        model_choice  # Add model_choice to outputs
+                    ],
+                    queue=False
+                )
+
+                uitars_model_choice.change(
+                    fn=update_ui,
+                    inputs=[agent_loop, openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice],
+                    outputs=[
+                        openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice,
+                        openai_key_group, anthropic_key_group,
+                        custom_model, provider_base_url, provider_api_key,
+                        model_choice  # Add model_choice to outputs
+                    ],
+                    queue=False
+                )
+
+                openai_model_choice.change(
+                    fn=update_ui,
+                    inputs=[agent_loop, openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice],
+                    outputs=[
+                        openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice,
+                        openai_key_group, anthropic_key_group,
+                        custom_model, provider_base_url, provider_api_key,
+                        model_choice  # Add model_choice to outputs
+                    ],
+                    queue=False
+                )
+
+                anthropic_model_choice.change(
+                    fn=update_ui,
+                    inputs=[agent_loop, openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice],
+                    outputs=[
+                        openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice,
+                        openai_key_group, anthropic_key_group,
+                        custom_model, provider_base_url, provider_api_key,
+                        model_choice  # Add model_choice to outputs
+                    ],
+                    queue=False
+                )
 
             save_trajectory = gr.Checkbox(
                 label="Save Trajectory",
@@ -758,6 +943,9 @@ if __name__ == "__main__":
                 placeholder="Ask me to perform tasks in a virtual macOS environment"
             )
             clear = gr.Button("Clear")
+
+            # Add cancel button
+            cancel_button = gr.Button("Cancel", variant="stop")
 
             # Add examples
             example_group = gr.Examples(examples=example_messages, inputs=msg)
@@ -768,10 +956,28 @@ if __name__ == "__main__":
             history.append(gr.ChatMessage(role="user", content=message))
             return "", history
 
+        # Function to cancel the running agent
+        async def cancel_agent_task(history):
+            global global_agent
+            if global_agent and hasattr(global_agent, '_loop'):
+                print("DEBUG - Cancelling agent task")
+                # Cancel the agent loop
+                if hasattr(global_agent._loop, 'cancel') and callable(global_agent._loop.cancel):
+                    await global_agent._loop.cancel()
+                    history.append(gr.ChatMessage(role="assistant", content="Task cancelled by user", metadata={"title": "❌ Cancelled"}))
+                else:
+                    history.append(gr.ChatMessage(role="assistant", content="Could not cancel task: cancel method not found", metadata={"title": "⚠️ Warning"}))
+            else:
+                history.append(gr.ChatMessage(role="assistant", content="No active agent task to cancel", metadata={"title": "ℹ️ Info"}))
+            return history
+
         # Function to process agent response after user input
         async def process_response(
             history,
-            model_choice_value,
+            openai_model_value,
+            anthropic_model_value,
+            omni_model_value,
+            uitars_model_value,
             custom_model_value,
             agent_loop_choice,
             save_traj,
@@ -788,21 +994,47 @@ if __name__ == "__main__":
             # Get the last user message
             last_user_message = history[-1]["content"]
 
+            # Get the appropriate model value based on the agent loop
+            if agent_loop_choice == "OPENAI":
+                model_choice_value = openai_model_value
+            elif agent_loop_choice == "ANTHROPIC":
+                model_choice_value = anthropic_model_value
+            elif agent_loop_choice == "OMNI":
+                model_choice_value = omni_model_value
+            elif agent_loop_choice == "UITARS":
+                model_choice_value = uitars_model_value
+            else:
+                model_choice_value = "No models available"
+
+            # Determine if this is a custom model selection and which type
+            is_custom_openai_api = model_choice_value == "Custom model (OpenAI compatible API)"
+            is_custom_ollama = model_choice_value == "Custom model (ollama)"
+            is_custom_model_selected = is_custom_openai_api or is_custom_ollama
+
             # Determine the model name string to analyze: custom or from dropdown
-            model_string_to_analyze = (
-                custom_model_value
-                if model_choice_value == "Custom model..."
-                else model_choice_value  # Use the full UI string initially
-            )
-
-            # Determine if this is a custom model selection
-            is_custom_model_selected = model_choice_value == "Custom model..."
+            if is_custom_model_selected:
+                model_string_to_analyze = custom_model_value
+            else:
+                model_string_to_analyze = model_choice_value  # Use the full UI string initially
 
             try:
-                # Get the provider, *cleaned* model name, and agent loop type
-                provider, cleaned_model_name_from_func, agent_loop_type = (
-                    get_provider_and_model(model_string_to_analyze, agent_loop_choice)
-                )
+                # Special case for UITARS - use MLXVLM provider
+                if agent_loop_choice == "UITARS":
+                    provider = LLMProvider.MLXVLM
+                    cleaned_model_name_from_func = model_string_to_analyze
+                    agent_loop_type = AgentLoop.UITARS
+                    print(f"Using MLXVLM provider for UITARS model: {model_string_to_analyze}")
+                # Special case for Ollama custom model
+                elif is_custom_ollama and agent_loop_choice == "OMNI":
+                    provider = LLMProvider.OLLAMA
+                    cleaned_model_name_from_func = custom_model_value
+                    agent_loop_type = AgentLoop.OMNI
+                    print(f"Using Ollama provider for custom model: {custom_model_value}")
+                else:
+                    # Get the provider, *cleaned* model name, and agent loop type
+                    provider, cleaned_model_name_from_func, agent_loop_type = (
+                        get_provider_and_model(model_string_to_analyze, agent_loop_choice)
+                    )
 
                 print(f"provider={provider} cleaned_model_name_from_func={cleaned_model_name_from_func} agent_loop_type={agent_loop_type} agent_loop_choice={agent_loop_choice}")
@@ -814,26 +1046,34 @@ if __name__ == "__main__":
                     else cleaned_model_name_from_func
                 )
 
-                # Determine if OAICOMPAT should be used (only if custom model explicitly selected)
-                is_oaicompat = is_custom_model_selected
+                # Determine if OAICOMPAT should be used (only for OpenAI compatible API custom model)
+                is_oaicompat = is_custom_openai_api and agent_loop_choice != "UITARS"
 
                 # Get API key based on provider determined by get_provider_and_model
                 if is_oaicompat and custom_api_key:
-                    # Use custom API key if provided for custom model
+                    # Use custom API key if provided for OpenAI compatible API custom model
                     api_key = custom_api_key
                     print(
-                        f"DEBUG - Using custom API key for model: {final_model_name_to_send}"
+                        f"DEBUG - Using custom API key for OpenAI compatible API model: {final_model_name_to_send}"
                     )
+                elif provider == LLMProvider.OLLAMA:
+                    # No API key needed for Ollama
+                    api_key = ""
+                    print(f"DEBUG - No API key needed for Ollama model: {final_model_name_to_send}")
                 elif provider == LLMProvider.OPENAI:
                     # Use OpenAI key from input if provided, otherwise use environment variable
                     api_key = openai_key_input if openai_key_input else (openai_api_key or os.environ.get("OPENAI_API_KEY", ""))
                     if openai_key_input:
-                        print(f"DEBUG - Using provided OpenAI API key from UI")
+                        # Set the environment variable for the OpenAI API key
+                        os.environ["OPENAI_API_KEY"] = openai_key_input
+                        print(f"DEBUG - Using provided OpenAI API key from UI and set as environment variable")
                 elif provider == LLMProvider.ANTHROPIC:
                     # Use Anthropic key from input if provided, otherwise use environment variable
                     api_key = anthropic_key_input if anthropic_key_input else (anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY", ""))
                     if anthropic_key_input:
-                        print(f"DEBUG - Using provided Anthropic API key from UI")
+                        # Set the environment variable for the Anthropic API key
+                        os.environ["ANTHROPIC_API_KEY"] = anthropic_key_input
+                        print(f"DEBUG - Using provided Anthropic API key from UI and set as environment variable")
                 else:
                     # For Ollama or default OAICOMPAT (without custom key), no key needed/expected
                     api_key = ""
@@ -852,8 +1092,8 @@ if __name__ == "__main__":
 
                 # Create or update the agent
                 create_agent(
-                    # Provider determined by get_provider_and_model unless custom model selected
-                    provider=LLMProvider.OAICOMPAT if is_oaicompat else provider,
+                    # Provider determined by special cases and get_provider_and_model
+                    provider=provider,
                     agent_loop=agent_loop_type,
                     # Pass the FINAL determined model name (cleaned or custom)
                     model_name=final_model_name_to_send,
@@ -966,13 +1206,21 @@ if __name__ == "__main__":
                 # Update with error message
                 history.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}"))
                 yield history
-
-        # Connect the components
-        msg.submit(chat_submit, [msg, chatbot_history], [msg, chatbot_history]).then(
-            process_response,
-            [
+
+        # Connect the submit button to the process_response function
+        submit_event = msg.submit(
+            fn=chat_submit,
+            inputs=[msg, chatbot_history],
+            outputs=[msg, chatbot_history],
+            queue=False,
+        ).then(
+            fn=process_response,
+            inputs=[
                 chatbot_history,
-                model_choice,
+                openai_model_choice,
+                anthropic_model_choice,
+                omni_model_choice,
+                uitars_model_choice,
                 custom_model,
                 agent_loop,
                 save_trajectory,
@@ -982,44 +1230,22 @@ if __name__ == "__main__":
                 openai_api_key_input,
                 anthropic_api_key_input,
             ],
-            [chatbot_history],
+            outputs=[chatbot_history],
+            queue=True,
         )
 
         # Clear button functionality
        clear.click(lambda: None, None, chatbot_history, queue=False)
-
-        # Connect agent_loop changes to model selection
-        agent_loop.change(
-            fn=update_model_choices,
-            inputs=[agent_loop],
-            outputs=[model_choice],
-            queue=False,  # Process immediately without queueing
-        )
-
-        # Show/hide custom model, provider base URL, and API key textboxes based on dropdown selection
-        def update_custom_model_visibility(model_value):
-            is_custom = model_value == "Custom model..."
-            return (
-                gr.update(visible=is_custom),
-                gr.update(visible=is_custom),
-                gr.update(visible=is_custom),
-            )
-
-        model_choice.change(
-            fn=update_custom_model_visibility,
-            inputs=[model_choice],
-            outputs=[custom_model, provider_base_url, provider_api_key],
-            queue=False,  # Process immediately without queueing
-        )
 
-        # Connect agent_loop changes to model selection and API key visibility
-        agent_loop.change(
-            fn=update_model_choices,
-            inputs=[agent_loop],
-            outputs=[model_choice, openai_key_group, anthropic_key_group],
-            queue=False,  # Process immediately without queueing
+        # Connect cancel button to cancel function
+        cancel_button.click(
+            cancel_agent_task,
+            [chatbot_history],
+            [chatbot_history],
+            queue=False  # Process immediately without queueing
         )
 
+
        # Function to update the code display based on configuration and chat history
        def update_code_display(agent_loop, model_choice_val, custom_model_val, chat_history, provider_base_url, recent_images_val, save_trajectory_val):
            # Extract messages from chat history
@@ -1029,9 +1255,72 @@ if __name__ == "__main__":
                 if msg.get("role") == "user":
                     messages.append(msg.get("content", ""))
 
-            # Determine provider and model name based on selection
-            model_string = custom_model_val if model_choice_val == "Custom model..." else model_choice_val
-            provider, model_name, _ = get_provider_and_model(model_string, agent_loop)
+            # Determine if this is a custom model selection and which type
+            is_custom_openai_api = model_choice_val == "Custom model (OpenAI compatible API)"
+            is_custom_ollama = model_choice_val == "Custom model (ollama)"
+            is_custom_model_selected = is_custom_openai_api or is_custom_ollama
+
+            # Determine provider and model name based on agent loop
+            if agent_loop == "OPENAI":
+                # For OPENAI loop, always use OPENAI provider with computer-use-preview
+                provider = LLMProvider.OPENAI
+                model_name = "computer-use-preview"
+            elif agent_loop == "ANTHROPIC":
+                # For ANTHROPIC loop, always use ANTHROPIC provider
+                provider = LLMProvider.ANTHROPIC
+                # Extract model name from the UI string
+                if model_choice_val.startswith("Anthropic: Claude "):
+                    # Extract the model name based on the UI string
+                    model_parts = model_choice_val.replace("Anthropic: Claude ", "").split(" (")
+                    version = model_parts[0]  # e.g., "3.7 Sonnet"
+                    date = model_parts[1].replace(")", "") if len(model_parts) > 1 else ""  # e.g., "20250219"
+
+                    # Format as claude-3-7-sonnet-20250219 or claude-3-5-sonnet-20240620
+                    version = version.replace(".", "-").replace(" ", "-").lower()
+                    model_name = f"claude-{version}-{date}"
+                else:
+                    # Use the model_choice_val directly if it doesn't match the expected format
+                    model_name = model_choice_val
+            elif agent_loop == "UITARS":
+                # For UITARS, use MLXVLM for mlx-community models, OAICOMPAT for custom
+                if model_choice_val == "Custom model (OpenAI compatible API)":
+                    provider = LLMProvider.OAICOMPAT
+                    model_name = custom_model_val
+                else:
+                    provider = LLMProvider.MLXVLM
+                    model_name = model_choice_val
+            elif agent_loop == "OMNI":
+                # For OMNI, provider can be OPENAI, ANTHROPIC, OLLAMA, or OAICOMPAT
+                if is_custom_openai_api:
+                    provider = LLMProvider.OAICOMPAT
+                    model_name = custom_model_val
+                elif is_custom_ollama:
+                    provider = LLMProvider.OLLAMA
+                    model_name = custom_model_val
+                elif model_choice_val.startswith("OMNI: OpenAI "):
+                    provider = LLMProvider.OPENAI
+                    # Extract model name from UI string (e.g., "OMNI: OpenAI GPT-4o" -> "gpt-4o")
+                    model_name = model_choice_val.replace("OMNI: OpenAI ", "").lower().replace(" ", "-")
+                elif model_choice_val.startswith("OMNI: Claude "):
+                    provider = LLMProvider.ANTHROPIC
+                    # Extract model name from UI string (similar to ANTHROPIC loop case)
+                    model_parts = model_choice_val.replace("OMNI: Claude ", "").split(" (")
+                    version = model_parts[0]  # e.g., "3.7 Sonnet"
+                    date = model_parts[1].replace(")", "") if len(model_parts) > 1 else ""  # e.g., "20250219"
+
+                    # Format as claude-3-7-sonnet-20250219 or claude-3-5-sonnet-20240620
+                    version = version.replace(".", "-").replace(" ", "-").lower()
+                    model_name = f"claude-{version}-{date}"
+                elif model_choice_val.startswith("OMNI: Ollama "):
+                    provider = LLMProvider.OLLAMA
+                    # Extract model name from UI string (e.g., "OMNI: Ollama llama3" -> "llama3")
+                    model_name = model_choice_val.replace("OMNI: Ollama ", "")
+                else:
+                    # Fallback to get_provider_and_model for any other cases
+                    provider, model_name, _ = get_provider_and_model(model_choice_val, agent_loop)
+            else:
+                # Fallback for any other agent loop
+                provider, model_name, _ = get_provider_and_model(model_choice_val, agent_loop)
 
             # Generate and return the code
             return generate_python_code(
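
The separate-dropdowns-plus-hidden-textbox workaround used throughout this file can be distilled into a few lines. A minimal sketch under the same assumption stated in the diff's own comment (updating a single dropdown's choices is buggy in Gradio, so one dropdown per loop is created and visibility is toggled):

```python
import gradio as gr

with gr.Blocks() as demo:
    loop = gr.Radio(["OPENAI", "UITARS"], value="OPENAI", label="Agent Loop")
    # One dropdown per loop; only the active one is visible.
    openai_dd = gr.Dropdown(["computer-use-preview"], label="OpenAI Model", visible=True)
    uitars_dd = gr.Dropdown(
        ["mlx-community/UI-TARS-1.5-7B-4bit", "mlx-community/UI-TARS-1.5-7B-6bit"],
        label="UITARS Model",
        visible=False,
    )

    def toggle(choice):
        # Toggle visibility instead of mutating one dropdown's choices.
        return (
            gr.update(visible=choice == "OPENAI"),
            gr.update(visible=choice == "UITARS"),
        )

    loop.change(toggle, inputs=loop, outputs=[openai_dd, uitars_dd], queue=False)

if __name__ == "__main__":
    demo.launch()
```
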
cua_agent-0.1.38.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: cua-agent
-Version: 0.1.35
+Version: 0.1.38
 Summary: CUA (Computer Use) Agent for AI-driven computer interaction
 Author-Email: TryCua <gh@trycua.com>
 Requires-Python: >=3.10
@@ -23,6 +23,7 @@ Requires-Dist: openai<2.0.0,>=1.14.0; extra == "openai"
 Requires-Dist: httpx<0.29.0,>=0.27.0; extra == "openai"
 Provides-Extra: uitars
 Requires-Dist: httpx<0.29.0,>=0.27.0; extra == "uitars"
+Provides-Extra: uitars-mlx
 Provides-Extra: ui
 Requires-Dist: gradio<6.0.0,>=5.23.3; extra == "ui"
 Requires-Dist: python-dotenv<2.0.0,>=1.0.1; extra == "ui"
@@ -104,6 +105,10 @@ pip install "cua-agent[anthropic]" # Anthropic Cua Loop
 pip install "cua-agent[uitars]"    # UI-Tars support
 pip install "cua-agent[omni]"      # Cua Loop based on OmniParser (includes Ollama for local models)
 pip install "cua-agent[ui]"        # Gradio UI for the agent
+
+# For local UI-TARS with MLX support, you need to manually install mlx-vlm:
+pip install "cua-agent[uitars-mlx]"
+pip install git+https://github.com/ddupont808/mlx-vlm.git@stable/fix/qwen2-position-id # PR: https://github.com/Blaizzy/mlx-vlm/pull/349
 ```
 
 ## Run
@@ -206,7 +211,32 @@ The Gradio UI provides:
 
 ### Using UI-TARS
 
-You can use UI-TARS by first following the [deployment guide](https://github.com/bytedance/UI-TARS/blob/main/README_deploy.md). This will give you a provider URL like this: `https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1` which you can use in the gradio UI.
+The UI-TARS models are available in two forms:
+
+1. **MLX UI-TARS models** (default): these run locally via the MLXVLM provider
+   - `mlx-community/UI-TARS-1.5-7B-4bit` (default) - 4-bit quantized version
+   - `mlx-community/UI-TARS-1.5-7B-6bit` - 6-bit quantized version for higher quality
+
+   ```python
+   agent = ComputerAgent(
+       computer=macos_computer,
+       loop=AgentLoop.UITARS,
+       model=LLM(provider=LLMProvider.MLXVLM, name="mlx-community/UI-TARS-1.5-7B-4bit")
+   )
+   ```
+
+2. **OpenAI-compatible UI-TARS**: for using the original ByteDance model
+   - To use the original ByteDance UI-TARS model via an OpenAI-compatible API, follow the [deployment guide](https://github.com/bytedance/UI-TARS/blob/main/README_deploy.md)
+   - This gives you a provider URL like `https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1`, which you can use in code or in the Gradio UI:
+
+   ```python
+   agent = ComputerAgent(
+       computer=macos_computer,
+       loop=AgentLoop.UITARS,
+       model=LLM(provider=LLMProvider.OAICOMPAT, name="tgi",
+                 provider_base_url="https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1")
+   )
+   ```
 
 ## Agent Loops
 
@@ -216,7 +246,7 @@ The `cua-agent` package provides three agent loops variations, based on differen
 |:-----------|:-----------------|:------------|:-------------|
 | `AgentLoop.OPENAI` | • `computer_use_preview` | Use OpenAI Operator CUA model | Not Required |
 | `AgentLoop.ANTHROPIC` | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219` | Use Anthropic Computer-Use | Not Required |
-| `AgentLoop.UITARS` | • `ByteDance-Seed/UI-TARS-1.5-7B` | Uses ByteDance's UI-TARS 1.5 model | Not Required |
+| `AgentLoop.UITARS` | • `mlx-community/UI-TARS-1.5-7B-4bit` (default)<br>• `mlx-community/UI-TARS-1.5-7B-6bit`<br>• `ByteDance-Seed/UI-TARS-1.5-7B` (via OpenAI-compatible endpoint) | Uses UI-TARS models with the MLXVLM (default) or OAICOMPAT provider | Not Required |
 | `AgentLoop.OMNI` | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219`<br>• `gpt-4.5-preview`<br>• `gpt-4o`<br>• `gpt-4`<br>• `phi4`<br>• `phi4-mini`<br>• `gemma3`<br>• `...`<br>• `Any Ollama or OpenAI-compatible model` | Use OmniParser for element pixel-detection (SoM) and any VLMs for UI Grounding and Reasoning | OmniParser |
 
 ## AgentResponse
cua_agent-0.1.38.dist-info/RECORD CHANGED
@@ -4,9 +4,9 @@ agent/core/agent.py,sha256=HUfBe7Uam3TObAmf6KH0GDKuNCNunNmmMcuxS7aZg0Q,8332
 agent/core/base.py,sha256=2sg8B2VqUKImRlkLTNj5lx-Oarlu7_GoMR6MbNzSY9Q,8078
 agent/core/callbacks.py,sha256=FKAxyajJ-ZJ5SxNXoupNcrm0GYBgjOjJEsStqst0EAk,6453
 agent/core/experiment.py,sha256=Ywj6q3JZFDKicfPuQsDl0vSN55HS7-Cnk3u3EcUCKe8,8866
-agent/core/factory.py,sha256=LSamFOjq2WhGGp5EVyLfDMAUrgHH1C_K5PKpdo24rhU,4573
+agent/core/factory.py,sha256=zzlCdibctqhf8Uta-SrvE-G7h59wAw-7SGhHiGvS9GY,4608
 agent/core/messages.py,sha256=-OVMDqcxK5MUHPEkHliK29XFJYMRAc1keFvzrUyrOmM,16231
-agent/core/provider_config.py,sha256=Hr9kDFSXdPeqC6hbid3OTykNF0-XVi0wzZyd44a7kww,627
+agent/core/provider_config.py,sha256=jB3fLsEsf806HQZ8jtzfSq4bCYGYONBeuCOoog_Nv_Y,768
 agent/core/telemetry.py,sha256=HElPd32k_w2SJ6t-Cc3j_2-AKdLbFwh2YlM8QViDgRw,4790
 agent/core/tools.py,sha256=Jes2CFCFqC727WWHbO-sG7V03rBHnQe5X7Oi9ZkuScI,877
 agent/core/tools/__init__.py,sha256=xZen-PqUp2dUaMEHJowXCQm33_5Sxhsx9PSoD0rq6tI,489
@@ -16,7 +16,7 @@ agent/core/tools/collection.py,sha256=NuwTn6dXSyznxWodfmFDQwUlxxaGb4oBPym4AEJABS
 agent/core/tools/computer.py,sha256=lT_aW3huoYpcM8kffuokELupSz_WZG_qkaW1gITRC58,3892
 agent/core/tools/edit.py,sha256=kv4jTKCM0VXrnoNErf7mT-xlr81-7T8v49_VA9y_L4Y,2005
 agent/core/tools/manager.py,sha256=IRsCXjGc076nncQuyIjODoafnHTDhrf9sP5B4q5Pcdo,1742
-agent/core/types.py,sha256=lDMtMFoBRW82X559VJBpbnNAzRo4LL7BbhT5r_QZFmg,2421
+agent/core/types.py,sha256=tkT-PqjgjL0oWVBRFkHAGWVwYx2Byp7PlUWSpvw_-h8,2442
 agent/core/visualization.py,sha256=1DuFF5sSeSf5BRSevBMDxml9-ajl7BQLFm5KBUwMbI8,6573
 agent/providers/__init__.py,sha256=b4tIBAaIB1V7p8V0BWipHVnMhfHH_OuVgP4OWGSHdD8,194
 agent/providers/anthropic/__init__.py,sha256=Mj11IZnVshZ2iHkvg4Z5-jrQIaD1WvzDz2Zk_pMwqIA,149
@@ -68,18 +68,19 @@ agent/providers/openai/types.py,sha256=0mFUxeFy23fJhMwc6lAFVXKngg2fJIXkPS5oV284V
 agent/providers/openai/utils.py,sha256=YeCZWIqOFSeugWoqAS0rhxOKAfL-9uN9nrYSBGBgPdc,3175
 agent/providers/uitars/__init__.py,sha256=sq5OMVJP9E_sok9tIiKJreGkjmNWXPMObjPTClYv1es,38
 agent/providers/uitars/clients/base.py,sha256=5w8Ajmq1JiPyUQJUAq1lSkfpA8_Ts80NQiDxPMTtQrI,948
+agent/providers/uitars/clients/mlxvlm.py,sha256=lMnN6ecMmWHf_l7khJ2iJHHvT7PE4XagUjrWhB0zEhc,10893
 agent/providers/uitars/clients/oaicompat.py,sha256=uYjwrGCVpFi8wj4kcaJ905ABiY6ksJZXaLlM61B2DUA,8907
-agent/providers/uitars/loop.py,sha256=CoZDk4ltz5nsw9yDnFKET5skP1uzibl3QDZOUfJQsKQ,22774
+agent/providers/uitars/loop.py,sha256=4-cgQteixPy03vp7xWezd6jWpuPkBmlLS3tizaOmd0U,23494
 agent/providers/uitars/prompts.py,sha256=_pQNd438mFpZKZT0aMl6Bd0_GgQxuy9y08kQAMPi9UM,2536
 agent/providers/uitars/tools/__init__.py,sha256=0hc3W6u5TvcXYztYKIyve_C2G3XMfwt_y7grmH0ZHC0,29
 agent/providers/uitars/tools/computer.py,sha256=TeIg_aCtMroxWOBJEiYY_YI4krW_C3pYu51tgGsVUYU,11808
 agent/providers/uitars/tools/manager.py,sha256=2dK9STtz6NuZG3i0nH7ZuHJpb7vKJ2mOVbxGsb0t8lQ,1945
-agent/providers/uitars/utils.py,sha256=S6FiZ3P-O4B15P1Gdup2o7SyuIu4nSQbspxcektpwmM,8870
+agent/providers/uitars/utils.py,sha256=493STTEEJcVhVbQgR0e8rNTI1DjkxUx8IgIv3wkJ1SU,8878
 agent/telemetry.py,sha256=pVGxbj0ewnvq4EGj28CydN4a1iOfvZR_XKL3vIOqhOM,390
 agent/ui/__init__.py,sha256=ohhxJLBin6k1hl5sKcmBST8mgh23WXgAXz3pN4f470E,45
 agent/ui/gradio/__init__.py,sha256=ANKZhv1HqsLheWbLVBlyRQ7Q5qGeXuPi5jDs8vu-ZMo,579
-agent/ui/gradio/app.py,sha256=q_nS6JJLlu1Y9xu56YHR26l_ypgaK3zR3v6BfpZT4qc,49396
-cua_agent-0.1.35.dist-info/METADATA,sha256=C7b0g8sHR6-3eWEJFQwuelXf6MKPeD8_Z8Z5aPwoikQ,11335
-cua_agent-0.1.35.dist-info/WHEEL,sha256=tSfRZzRHthuv7vxpI4aehrdN9scLjk-dCJkPLzkHxGg,90
-cua_agent-0.1.35.dist-info/entry_points.txt,sha256=6OYgBcLyFCUgeqLgnvMyOJxPCWzgy7se4rLPKtNonMs,34
-cua_agent-0.1.35.dist-info/RECORD,,
+agent/ui/gradio/app.py,sha256=M_pqSiN40F1u-8luBdqvTJxQFGzqd0WPsSz8APLbPus,67826
+cua_agent-0.1.38.dist-info/METADATA,sha256=z1UB551Vd-9YISHLzAY_gXSJFZTOgHi7lS-wy0XlgV8,12689
+cua_agent-0.1.38.dist-info/WHEEL,sha256=tSfRZzRHthuv7vxpI4aehrdN9scLjk-dCJkPLzkHxGg,90
+cua_agent-0.1.38.dist-info/entry_points.txt,sha256=6OYgBcLyFCUgeqLgnvMyOJxPCWzgy7se4rLPKtNonMs,34
+cua_agent-0.1.38.dist-info/RECORD,,