cua-agent 0.4.16__tar.gz → 0.4.18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cua-agent has been flagged as potentially problematic.
- {cua_agent-0.4.16 → cua_agent-0.4.18}/PKG-INFO +3 -3
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/adapters/__init__.py +2 -0
- cua_agent-0.4.18/agent/adapters/mlxvlm_adapter.py +358 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/agent.py +10 -4
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/callbacks/__init__.py +2 -0
- cua_agent-0.4.18/agent/callbacks/operator_validator.py +138 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/callbacks/trajectory_saver.py +5 -1
- cua_agent-0.4.18/agent/integrations/hud/__init__.py +228 -0
- cua_agent-0.4.18/agent/integrations/hud/proxy.py +183 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/loops/anthropic.py +12 -1
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/loops/composed_grounded.py +26 -14
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/loops/openai.py +15 -7
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/loops/uitars.py +17 -8
- cua_agent-0.4.18/agent/proxy/examples.py +192 -0
- cua_agent-0.4.18/agent/proxy/handlers.py +248 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/pyproject.toml +3 -3
- cua_agent-0.4.16/agent/integrations/hud/__init__.py +0 -77
- cua_agent-0.4.16/agent/integrations/hud/adapter.py +0 -121
- cua_agent-0.4.16/agent/integrations/hud/agent.py +0 -373
- cua_agent-0.4.16/agent/integrations/hud/computer_handler.py +0 -187
- {cua_agent-0.4.16 → cua_agent-0.4.18}/README.md +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/__init__.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/__main__.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/adapters/huggingfacelocal_adapter.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/adapters/human_adapter.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/callbacks/base.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/callbacks/budget_manager.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/callbacks/image_retention.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/callbacks/logging.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/callbacks/pii_anonymization.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/callbacks/telemetry.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/cli.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/computers/__init__.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/computers/base.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/computers/cua.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/computers/custom.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/decorators.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/human_tool/__init__.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/human_tool/__main__.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/human_tool/server.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/human_tool/ui.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/loops/__init__.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/loops/base.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/loops/glm45v.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/loops/gta1.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/loops/model_types.csv +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/loops/omniparser.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/responses.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/types.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/ui/__init__.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/ui/__main__.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/ui/gradio/__init__.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/ui/gradio/app.py +0 -0
- {cua_agent-0.4.16 → cua_agent-0.4.18}/agent/ui/gradio/ui_components.py +0 -0

{cua_agent-0.4.16 → cua_agent-0.4.18}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: cua-agent
-Version: 0.4.16
+Version: 0.4.18
 Summary: CUA (Computer Use) Agent for AI-driven computer interaction
 Author-Email: TryCua <gh@trycua.com>
 Requires-Python: >=3.12
@@ -38,7 +38,7 @@ Requires-Dist: python-dotenv>=1.0.1; extra == "ui"
 Provides-Extra: cli
 Requires-Dist: yaspin>=3.1.0; extra == "cli"
 Provides-Extra: hud
-Requires-Dist: hud-python
+Requires-Dist: hud-python<0.5.0,>=0.4.12; extra == "hud"
 Provides-Extra: all
 Requires-Dist: ultralytics>=8.0.0; extra == "all"
 Requires-Dist: cua-som<0.2.0,>=0.1.0; extra == "all"
@@ -49,7 +49,7 @@ Requires-Dist: transformers>=4.54.0; extra == "all"
 Requires-Dist: gradio>=5.23.3; extra == "all"
 Requires-Dist: python-dotenv>=1.0.1; extra == "all"
 Requires-Dist: yaspin>=3.1.0; extra == "all"
-Requires-Dist: hud-python
+Requires-Dist: hud-python<0.5.0,>=0.4.12; extra == "all"
 Description-Content-Type: text/markdown
 
 <div align="center">

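Aside from the version bump, the only metadata change is the tightened hud-python pin on the hud and all extras. A quick way to check what an installed build declares (a minimal sketch; assumes cua-agent is installed in the current environment):

import importlib.metadata as im

# Print the declared hud-python requirement(s) of the installed cua-agent
# distribution; on 0.4.18 this should show the <0.5.0,>=0.4.12 pin for the
# "hud" and "all" extras.
for req in im.requires("cua-agent") or []:
    if req.startswith("hud-python"):
        print(req)
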
{cua_agent-0.4.16 → cua_agent-0.4.18}/agent/adapters/__init__.py

@@ -4,8 +4,10 @@ Adapters package for agent - Custom LLM adapters for LiteLLM
 
 from .huggingfacelocal_adapter import HuggingFaceLocalAdapter
 from .human_adapter import HumanAdapter
+from .mlxvlm_adapter import MLXVLMAdapter
 
 __all__ = [
     "HuggingFaceLocalAdapter",
     "HumanAdapter",
+    "MLXVLMAdapter",
 ]

cua_agent-0.4.18/agent/adapters/mlxvlm_adapter.py (new file)

@@ -0,0 +1,358 @@
+import asyncio
+import functools
+import warnings
+import io
+import base64
+import math
+import re
+from concurrent.futures import ThreadPoolExecutor
+from typing import Iterator, AsyncIterator, Dict, List, Any, Optional, Tuple, cast
+from PIL import Image
+from litellm.types.utils import GenericStreamingChunk, ModelResponse
+from litellm.llms.custom_llm import CustomLLM
+from litellm import completion, acompletion
+
+# Try to import MLX dependencies
+try:
+    import mlx.core as mx
+    from mlx_vlm import load, generate
+    from mlx_vlm.prompt_utils import apply_chat_template
+    from mlx_vlm.utils import load_config
+    from transformers.tokenization_utils import PreTrainedTokenizer
+    MLX_AVAILABLE = True
+except ImportError:
+    MLX_AVAILABLE = False
+
+# Constants for smart_resize
+IMAGE_FACTOR = 28
+MIN_PIXELS = 100 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+def round_by_factor(number: float, factor: int) -> int:
+    """Returns the closest integer to 'number' that is divisible by 'factor'."""
+    return round(number / factor) * factor
+
+def ceil_by_factor(number: float, factor: int) -> int:
+    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+    return math.ceil(number / factor) * factor
+
+def floor_by_factor(number: float, factor: int) -> int:
+    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+    return math.floor(number / factor) * factor
+
+def smart_resize(
+    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+) -> tuple[int, int]:
+    """
+    Rescales the image so that the following conditions are met:
+
+    1. Both dimensions (height and width) are divisible by 'factor'.
+    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+    3. The aspect ratio of the image is maintained as closely as possible.
+    """
+    if max(height, width) / min(height, width) > MAX_RATIO:
+        raise ValueError(
+            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+        )
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = floor_by_factor(height / beta, factor)
+        w_bar = floor_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = ceil_by_factor(height * beta, factor)
+        w_bar = ceil_by_factor(width * beta, factor)
+    return h_bar, w_bar
+
+
+class MLXVLMAdapter(CustomLLM):
+    """MLX VLM Adapter for running vision-language models locally using MLX."""
+
+    def __init__(self, **kwargs):
+        """Initialize the adapter.
+
+        Args:
+            **kwargs: Additional arguments
+        """
+        super().__init__()
+        if not MLX_AVAILABLE:
+            raise ImportError("MLX VLM dependencies not available. Please install mlx-vlm.")
+
+        self.models = {}  # Cache for loaded models
+        self.processors = {}  # Cache for loaded processors
+        self.configs = {}  # Cache for loaded configs
+        self._executor = ThreadPoolExecutor(max_workers=1)  # Single thread pool
+
+    def _load_model_and_processor(self, model_name: str):
+        """Load model and processor if not already cached.
+
+        Args:
+            model_name: Name of the model to load
+
+        Returns:
+            Tuple of (model, processor, config)
+        """
+        if model_name not in self.models:
+            # Load model and processor
+            model_obj, processor = load(
+                model_name,
+                processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
+            )
+            config = load_config(model_name)
+
+            # Cache them
+            self.models[model_name] = model_obj
+            self.processors[model_name] = processor
+            self.configs[model_name] = config
+
+        return self.models[model_name], self.processors[model_name], self.configs[model_name]
+
+    def _process_coordinates(self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]) -> str:
+        """Process coordinates in box tokens based on image resizing using smart_resize approach.
+
+        Args:
+            text: Text containing box tokens
+            original_size: Original image size (width, height)
+            model_size: Model processed image size (width, height)
+
+        Returns:
+            Text with processed coordinates
+        """
+        # Find all box tokens
+        box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"
+
+        def process_coords(match):
+            model_x, model_y = int(match.group(1)), int(match.group(2))
+            # Scale coordinates from model space to original image space
+            # Both original_size and model_size are in (width, height) format
+            new_x = int(model_x * original_size[0] / model_size[0])  # Width
+            new_y = int(model_y * original_size[1] / model_size[1])  # Height
+            return f"<|box_start|>({new_x},{new_y})<|box_end|>"
+
+        return re.sub(box_pattern, process_coords, text)
+
+    def _convert_messages(self, messages: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Image.Image], Dict[int, Tuple[int, int]], Dict[int, Tuple[int, int]]]:
+        """Convert OpenAI format messages to MLX VLM format and extract images.
+
+        Args:
+            messages: Messages in OpenAI format
+
+        Returns:
+            Tuple of (processed_messages, images, original_sizes, model_sizes)
+        """
+        processed_messages = []
+        images = []
+        original_sizes = {}  # Track original sizes of images for coordinate mapping
+        model_sizes = {}  # Track model processed sizes
+        image_index = 0
+
+        for message in messages:
+            processed_message = {
+                "role": message["role"],
+                "content": []
+            }
+
+            content = message.get("content", [])
+            if isinstance(content, str):
+                # Simple text content
+                processed_message["content"] = content
+            elif isinstance(content, list):
+                # Multi-modal content
+                processed_content = []
+                for item in content:
+                    if item.get("type") == "text":
+                        processed_content.append({
+                            "type": "text",
+                            "text": item.get("text", "")
+                        })
+                    elif item.get("type") == "image_url":
+                        image_url = item.get("image_url", {}).get("url", "")
+                        pil_image = None
+
+                        if image_url.startswith("data:image/"):
+                            # Extract base64 data
+                            base64_data = image_url.split(',')[1]
+                            # Convert base64 to PIL Image
+                            image_data = base64.b64decode(base64_data)
+                            pil_image = Image.open(io.BytesIO(image_data))
+                        else:
+                            # Handle file path or URL
+                            pil_image = Image.open(image_url)
+
+                        # Store original image size for coordinate mapping
+                        original_size = pil_image.size
+                        original_sizes[image_index] = original_size
+
+                        # Use smart_resize to determine model size
+                        # Note: smart_resize expects (height, width) but PIL gives (width, height)
+                        height, width = original_size[1], original_size[0]
+                        new_height, new_width = smart_resize(height, width)
+                        # Store model size in (width, height) format for consistent coordinate processing
+                        model_sizes[image_index] = (new_width, new_height)
+
+                        # Resize the image using the calculated dimensions from smart_resize
+                        resized_image = pil_image.resize((new_width, new_height))
+                        images.append(resized_image)
+
+                        # Add image placeholder to content
+                        processed_content.append({
+                            "type": "image"
+                        })
+
+                        image_index += 1
+
+                processed_message["content"] = processed_content
+
+            processed_messages.append(processed_message)
+
+        return processed_messages, images, original_sizes, model_sizes
+
+    def _generate(self, **kwargs) -> str:
+        """Generate response using the local MLX VLM model.
+
+        Args:
+            **kwargs: Keyword arguments containing messages and model info
+
+        Returns:
+            Generated text response
+        """
+        messages = kwargs.get('messages', [])
+        model_name = kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')
+        max_tokens = kwargs.get('max_tokens', 128)
+
+        # Warn about ignored kwargs
+        ignored_kwargs = set(kwargs.keys()) - {'messages', 'model', 'max_tokens'}
+        if ignored_kwargs:
+            warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")
+
+        # Load model and processor
+        model, processor, config = self._load_model_and_processor(model_name)
+
+        # Convert messages and extract images
+        processed_messages, images, original_sizes, model_sizes = self._convert_messages(messages)
+
+        # Process user text input with box coordinates after image processing
+        # Swap original_size and model_size arguments for inverse transformation
+        for msg_idx, msg in enumerate(processed_messages):
+            if msg.get("role") == "user" and isinstance(msg.get("content"), str):
+                content = msg.get("content", "")
+                if "<|box_start|>" in content and original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
+                    orig_size = original_sizes[0]
+                    model_size = model_sizes[0]
+                    # Swap arguments to perform inverse transformation for user input
+                    processed_messages[msg_idx]["content"] = self._process_coordinates(content, model_size, orig_size)
+
+        try:
+            # Format prompt according to model requirements using the processor directly
+            prompt = processor.apply_chat_template(
+                processed_messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                return_tensors='pt'
+            )
+            tokenizer = cast(PreTrainedTokenizer, processor)
+
+            # Generate response
+            text_content, usage = generate(
+                model,
+                tokenizer,
+                str(prompt),
+                images,  # type: ignore
+                verbose=False,
+                max_tokens=max_tokens
+            )
+
+        except Exception as e:
+            raise RuntimeError(f"Error generating response: {str(e)}") from e
+
+        # Process coordinates in the response back to original image space
+        if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
+            # Get original image size and model size (using the first image)
+            orig_size = original_sizes[0]
+            model_size = model_sizes[0]
+
+            # Check if output contains box tokens that need processing
+            if "<|box_start|>" in text_content:
+                # Process coordinates from model space back to original image space
+                text_content = self._process_coordinates(text_content, orig_size, model_size)
+
+        return text_content
+
+    def completion(self, *args, **kwargs) -> ModelResponse:
+        """Synchronous completion method.
+
+        Returns:
+            ModelResponse with generated text
+        """
+        generated_text = self._generate(**kwargs)
+
+        result = completion(
+            model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
+            mock_response=generated_text,
+        )
+        return cast(ModelResponse, result)
+
+    async def acompletion(self, *args, **kwargs) -> ModelResponse:
+        """Asynchronous completion method.
+
+        Returns:
+            ModelResponse with generated text
+        """
+        # Run _generate in thread pool to avoid blocking
+        loop = asyncio.get_event_loop()
+        generated_text = await loop.run_in_executor(
+            self._executor,
+            functools.partial(self._generate, **kwargs)
+        )
+
+        result = await acompletion(
+            model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
+            mock_response=generated_text,
+        )
+        return cast(ModelResponse, result)
+
+    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+        """Synchronous streaming method.
+
+        Returns:
+            Iterator of GenericStreamingChunk
+        """
+        generated_text = self._generate(**kwargs)
+
+        generic_streaming_chunk: GenericStreamingChunk = {
+            "finish_reason": "stop",
+            "index": 0,
+            "is_finished": True,
+            "text": generated_text,
+            "tool_use": None,
+            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
+        }
+
+        yield generic_streaming_chunk
+
+    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
+        """Asynchronous streaming method.
+
+        Returns:
+            AsyncIterator of GenericStreamingChunk
+        """
+        # Run _generate in thread pool to avoid blocking
+        loop = asyncio.get_event_loop()
+        generated_text = await loop.run_in_executor(
+            self._executor,
+            functools.partial(self._generate, **kwargs)
+        )
+
+        generic_streaming_chunk: GenericStreamingChunk = {
+            "finish_reason": "stop",
+            "index": 0,
+            "is_finished": True,
+            "text": generated_text,
+            "tool_use": None,
+            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
+        }
+
+        yield generic_streaming_chunk

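To make the data flow above concrete, here is a minimal usage sketch for the new adapter outside of ComputerAgent. It assumes an Apple Silicon machine with mlx-vlm installed and a local screenshot.png (hypothetical path), and mirrors the "mlx" provider registration that agent.py performs below; the adapter resizes the screenshot via smart_resize and maps any <|box_start|>(x,y)<|box_end|> coordinates in the reply back to the original resolution.

import base64
import litellm
from agent.adapters import MLXVLMAdapter

# Register the adapter under the "mlx" provider prefix, as ComputerAgent does below.
litellm.custom_provider_map = [
    {"provider": "mlx", "custom_handler": MLXVLMAdapter()},
]

with open("screenshot.png", "rb") as f:  # hypothetical input image
    image_b64 = base64.b64encode(f.read()).decode()

response = litellm.completion(
    model="mlx/mlx-community/UI-TARS-1.5-7B-4bit",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Click the Submit button."},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
        ],
    }],
    max_tokens=128,
)
# Box coordinates in the reply are already scaled back to the original image size.
print(response.choices[0].message.content)
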
{cua_agent-0.4.16 → cua_agent-0.4.18}/agent/agent.py

@@ -22,6 +22,7 @@ import inspect
 from .adapters import (
     HuggingFaceLocalAdapter,
     HumanAdapter,
+    MLXVLMAdapter,
 )
 from .callbacks import (
     ImageRetentionCallback,
@@ -29,6 +30,7 @@ from .callbacks import (
     TrajectorySaverCallback,
     BudgetManagerCallback,
     TelemetryCallback,
+    OperatorNormalizerCallback
 )
 from .computers import (
     AsyncComputerHandler,
@@ -201,6 +203,9 @@ class ComputerAgent:
 
         # == Add built-in callbacks ==
 
+        # Prepend operator normalizer callback
+        self.callbacks.insert(0, OperatorNormalizerCallback())
+
         # Add telemetry callback if telemetry_enabled is set
         if self.telemetry_enabled:
             if isinstance(self.telemetry_enabled, bool):
@@ -234,9 +239,11 @@ class ComputerAgent:
             device="auto"
         )
         human_adapter = HumanAdapter()
+        mlx_adapter = MLXVLMAdapter()
         litellm.custom_provider_map = [
             {"provider": "huggingface-local", "custom_handler": hf_adapter},
-            {"provider": "human", "custom_handler": human_adapter}
+            {"provider": "human", "custom_handler": human_adapter},
+            {"provider": "mlx", "custom_handler": mlx_adapter}
         ]
         litellm.suppress_debug_info = True
 
@@ -459,8 +466,7 @@ class ComputerAgent:
             assert_callable_with(computer_method, **action_args)
             await computer_method(**action_args)
         else:
-
-            return []
+            raise ToolError(f"Unknown computer action: {action_type}")
 
         # Take screenshot after action
         if self.screenshot_delay and self.screenshot_delay > 0:
@@ -507,7 +513,7 @@ class ComputerAgent:
            # Perform function call
            function = self._get_tool(item.get("name"))
            if not function:
-                raise
+                raise ToolError(f"Function {item.get("name")} not found")
 
            args = json.loads(item.get("arguments"))
 

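With the "mlx" handler in custom_provider_map, a local MLX model is selected purely through the model string, and unknown computer actions now raise ToolError instead of silently returning an empty result. A rough end-to-end sketch follows; the ComputerAgent/Computer call pattern here is assumed from the project's README rather than shown in this diff, and the connection arguments are placeholders.

import asyncio
from agent import ComputerAgent
from computer import Computer  # cua-computer package; assumed available

async def main():
    # Connection/provider arguments are omitted and depend on your setup.
    async with Computer(os_type="macos") as computer:
        agent = ComputerAgent(
            model="mlx/mlx-community/UI-TARS-1.5-7B-4bit",  # routed to MLXVLMAdapter
            tools=[computer],
        )
        async for result in agent.run("Open the browser and search for MLX"):
            for item in result.get("output", []):
                print(item.get("type"))

asyncio.run(main())
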
{cua_agent-0.4.16 → cua_agent-0.4.18}/agent/callbacks/__init__.py

@@ -8,6 +8,7 @@ from .logging import LoggingCallback
 from .trajectory_saver import TrajectorySaverCallback
 from .budget_manager import BudgetManagerCallback
 from .telemetry import TelemetryCallback
+from .operator_validator import OperatorNormalizerCallback
 
 __all__ = [
     "AsyncCallbackHandler",
@@ -16,4 +17,5 @@ __all__ = [
     "TrajectorySaverCallback",
     "BudgetManagerCallback",
     "TelemetryCallback",
+    "OperatorNormalizerCallback",
 ]

cua_agent-0.4.18/agent/callbacks/operator_validator.py (new file)

@@ -0,0 +1,138 @@
+"""
+OperatorValidatorCallback
+
+Ensures agent output actions conform to expected schemas by fixing common issues:
+- click: add default button='left' if missing
+- keypress: wrap keys string into a list
+- etc.
+
+This runs in on_llm_end, which receives the output array (AgentMessage[] as dicts).
+The purpose is to avoid spending another LLM call to fix broken computer call syntax when possible.
+"""
+from __future__ import annotations
+
+from typing import Any, Dict, List
+
+from .base import AsyncCallbackHandler
+
+
+class OperatorNormalizerCallback(AsyncCallbackHandler):
+    """Normalizes common computer call hallucinations / errors in computer call syntax."""
+
+    async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        # Mutate in-place as requested, but still return the list for chaining
+        for item in output or []:
+            if item.get("type") != "computer_call":
+                continue
+            action = item.get("action")
+            if not isinstance(action, dict):
+                continue
+
+            # rename mouse click actions to "click"
+            for mouse_btn in ["left", "right", "wheel", "back", "forward"]:
+                if action.get("type", "") == f"{mouse_btn}_click":
+                    action["type"] = "click"
+                    action["button"] = mouse_btn
+            # rename hotkey actions to "keypress"
+            for alias in ["hotkey", "key", "press", "key_press"]:
+                if action.get("type", "") == alias:
+                    action["type"] = "keypress"
+            # assume click actions
+            if "button" in action and "type" not in action:
+                action["type"] = "click"
+            if "click" in action and "type" not in action:
+                action["type"] = "click"
+            if ("scroll_x" in action or "scroll_y" in action) and "type" not in action:
+                action["type"] = "scroll"
+            if "text" in action and "type" not in action:
+                action["type"] = "type"
+
+            action_type = action.get("type")
+            def _keep_keys(action: Dict[str, Any], keys_to_keep: List[str]):
+                """Keep only the provided keys on action; delete everything else.
+                Always ensures required 'type' is present if listed in keys_to_keep.
+                """
+                for key in list(action.keys()):
+                    if key not in keys_to_keep:
+                        del action[key]
+            # rename "coordinate" to "x", "y"
+            if "coordinate" in action:
+                action["x"] = action["coordinate"][0]
+                action["y"] = action["coordinate"][1]
+                del action["coordinate"]
+            if action_type == "click":
+                # convert "click" to "button"
+                if "button" not in action and "click" in action:
+                    action["button"] = action["click"]
+                    del action["click"]
+                # default button to "left"
+                action["button"] = action.get("button", "left")
+            # add default scroll x, y if missing
+            if action_type == "scroll":
+                action["scroll_x"] = action.get("scroll_x", 0)
+                action["scroll_y"] = action.get("scroll_y", 0)
+            # ensure keys arg is a list (normalize aliases first)
+            if action_type == "keypress":
+                keys = action.get("keys")
+                for keys_alias in ["keypress", "key", "press", "key_press", "text"]:
+                    if keys_alias in action:
+                        action["keys"] = action[keys_alias]
+                        del action[keys_alias]
+                keys = action.get("keys")
+                if isinstance(keys, str):
+                    action["keys"] = keys.replace("-", "+").split("+") if len(keys) > 1 else [keys]
+            required_keys_by_type = {
+                # OpenAI actions
+                "click": ["type", "button", "x", "y"],
+                "double_click": ["type", "x", "y"],
+                "drag": ["type", "path"],
+                "keypress": ["type", "keys"],
+                "move": ["type", "x", "y"],
+                "screenshot": ["type"],
+                "scroll": ["type", "scroll_x", "scroll_y", "x", "y"],
+                "type": ["type", "text"],
+                "wait": ["type"],
+                # Anthropic actions
+                "left_mouse_down": ["type", "x", "y"],
+                "left_mouse_up": ["type", "x", "y"],
+                "triple_click": ["type", "button", "x", "y"],
+            }
+            keep = required_keys_by_type.get(action_type or "")
+            if keep:
+                _keep_keys(action, keep)
+
+
+        # Second pass: if an assistant message is immediately followed by a computer_call,
+        # replace the assistant message itself with a reasoning message with summary text.
+        if isinstance(output, list):
+            for i, item in enumerate(output):
+                # AssistantMessage shape: { type: 'message', role: 'assistant', content: OutputContent[] }
+                if item.get("type") == "message" and item.get("role") == "assistant":
+                    next_idx = i + 1
+                    if next_idx >= len(output):
+                        continue
+                    next_item = output[next_idx]
+                    if not isinstance(next_item, dict):
+                        continue
+                    if next_item.get("type") != "computer_call":
+                        continue
+                    contents = item.get("content") or []
+                    # Extract text from OutputContent[]
+                    text_parts: List[str] = []
+                    if isinstance(contents, list):
+                        for c in contents:
+                            if isinstance(c, dict) and c.get("type") == "output_text" and isinstance(c.get("text"), str):
+                                text_parts.append(c["text"])
+                    text_content = "\n".join(text_parts).strip()
+                    # Replace assistant message with reasoning message
+                    output[i] = {
+                        "type": "reasoning",
+                        "summary": [
+                            {
+                                "type": "summary_text",
+                                "text": text_content,
+                            }
+                        ],
+                    }
+
+        return output

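As a concrete illustration of what the normalizer does, the sketch below feeds it two malformed computer_call items of the kind its docstring describes (the import path follows the callbacks package export above; the call_id values are made up):

import asyncio
from agent.callbacks import OperatorNormalizerCallback

calls = [
    {"type": "computer_call", "call_id": "c1",
     "action": {"type": "left_click", "coordinate": [100, 200]}},
    {"type": "computer_call", "call_id": "c2",
     "action": {"type": "hotkey", "keys": "ctrl+c"}},
]

fixed = asyncio.run(OperatorNormalizerCallback().on_llm_end(calls))
print(fixed[0]["action"])  # {'type': 'click', 'button': 'left', 'x': 100, 'y': 200}
print(fixed[1]["action"])  # {'type': 'keypress', 'keys': ['ctrl', 'c']}
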
{cua_agent-0.4.16 → cua_agent-0.4.18}/agent/callbacks/trajectory_saver.py

@@ -94,6 +94,10 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
         # format: turn_000/0000_name.json
         artifact_filename = f"{self.current_artifact:04d}_{name}"
         artifact_path = turn_dir / f"{artifact_filename}.json"
+        # add created_at
+        if isinstance(artifact, dict):
+            artifact = artifact.copy()
+            artifact["created_at"] = str(uuid.uuid1().time)
         with open(artifact_path, "w") as f:
             json.dump(sanitize_image_urls(artifact), f, indent=2)
         self.current_artifact += 1
@@ -171,7 +175,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
             "status": "completed",
             "completed_at": str(uuid.uuid1().time),
             "total_usage": self.total_usage,
-            "new_items":
+            "new_items": new_items,
             "total_turns": self.current_turn
         })
 

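A note on the timestamps written here: created_at and completed_at are str(uuid.uuid1().time), i.e. counts of 100-nanosecond ticks since the Gregorian epoch (1582-10-15 UTC), not Unix time. A small helper sketch for turning such a value back into a datetime when inspecting saved trajectories:

import uuid
from datetime import datetime, timedelta, timezone

GREGORIAN_EPOCH = datetime(1582, 10, 15, tzinfo=timezone.utc)

def uuid1_time_to_datetime(ticks: str) -> datetime:
    # uuid.uuid1().time counts 100 ns intervals since 1582-10-15 00:00:00 UTC.
    return GREGORIAN_EPOCH + timedelta(microseconds=int(ticks) / 10)

print(uuid1_time_to_datetime(str(uuid.uuid1().time)))  # approximately "now" in UTC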