cua-agent 0.4.34__py3-none-any.whl → 0.4.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cua-agent might be problematic.
- agent/__init__.py +4 -10
- agent/__main__.py +2 -1
- agent/adapters/huggingfacelocal_adapter.py +54 -61
- agent/adapters/human_adapter.py +116 -114
- agent/adapters/mlxvlm_adapter.py +110 -99
- agent/adapters/models/__init__.py +14 -6
- agent/adapters/models/generic.py +7 -4
- agent/adapters/models/internvl.py +66 -30
- agent/adapters/models/opencua.py +23 -8
- agent/adapters/models/qwen2_5_vl.py +7 -4
- agent/agent.py +184 -158
- agent/callbacks/__init__.py +4 -4
- agent/callbacks/base.py +45 -31
- agent/callbacks/budget_manager.py +22 -10
- agent/callbacks/image_retention.py +18 -13
- agent/callbacks/logging.py +55 -42
- agent/callbacks/operator_validator.py +3 -1
- agent/callbacks/pii_anonymization.py +19 -16
- agent/callbacks/telemetry.py +67 -61
- agent/callbacks/trajectory_saver.py +90 -70
- agent/cli.py +115 -110
- agent/computers/__init__.py +13 -8
- agent/computers/base.py +32 -19
- agent/computers/cua.py +33 -25
- agent/computers/custom.py +78 -71
- agent/decorators.py +23 -14
- agent/human_tool/__init__.py +2 -7
- agent/human_tool/__main__.py +6 -2
- agent/human_tool/server.py +48 -37
- agent/human_tool/ui.py +235 -185
- agent/integrations/hud/__init__.py +15 -21
- agent/integrations/hud/agent.py +101 -83
- agent/integrations/hud/proxy.py +90 -57
- agent/loops/__init__.py +25 -21
- agent/loops/anthropic.py +537 -483
- agent/loops/base.py +13 -14
- agent/loops/composed_grounded.py +135 -149
- agent/loops/gemini.py +31 -12
- agent/loops/glm45v.py +135 -133
- agent/loops/gta1.py +47 -50
- agent/loops/holo.py +4 -2
- agent/loops/internvl.py +6 -11
- agent/loops/moondream3.py +36 -12
- agent/loops/omniparser.py +215 -210
- agent/loops/openai.py +49 -50
- agent/loops/opencua.py +29 -41
- agent/loops/qwen.py +510 -0
- agent/loops/uitars.py +237 -202
- agent/proxy/examples.py +54 -50
- agent/proxy/handlers.py +27 -34
- agent/responses.py +330 -330
- agent/types.py +11 -5
- agent/ui/__init__.py +1 -1
- agent/ui/__main__.py +1 -1
- agent/ui/gradio/app.py +23 -18
- agent/ui/gradio/ui_components.py +310 -161
- {cua_agent-0.4.34.dist-info → cua_agent-0.4.36.dist-info}/METADATA +18 -10
- cua_agent-0.4.36.dist-info/RECORD +64 -0
- cua_agent-0.4.34.dist-info/RECORD +0 -63
- {cua_agent-0.4.34.dist-info → cua_agent-0.4.36.dist-info}/WHEEL +0 -0
- {cua_agent-0.4.34.dist-info → cua_agent-0.4.36.dist-info}/entry_points.txt +0 -0
agent/adapters/mlxvlm_adapter.py
CHANGED
@@ -1,24 +1,26 @@
 import asyncio
+import base64
 import functools
-import warnings
 import io
-import base64
 import math
 import re
+import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import
-
-from litellm
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, cast
+
+from litellm import acompletion, completion
 from litellm.llms.custom_llm import CustomLLM
-from litellm import
+from litellm.types.utils import GenericStreamingChunk, ModelResponse
+from PIL import Image
 
 # Try to import MLX dependencies
 try:
     import mlx.core as mx
-    from mlx_vlm import
+    from mlx_vlm import generate, load
     from mlx_vlm.prompt_utils import apply_chat_template
     from mlx_vlm.utils import load_config
     from transformers.tokenization_utils import PreTrainedTokenizer
+
     MLX_AVAILABLE = True
 except ImportError:
     MLX_AVAILABLE = False
@@ -29,20 +31,28 @@ MIN_PIXELS = 100 * 28 * 28
 MAX_PIXELS = 16384 * 28 * 28
 MAX_RATIO = 200
 
+
 def round_by_factor(number: float, factor: int) -> int:
     """Returns the closest integer to 'number' that is divisible by 'factor'."""
     return round(number / factor) * factor
 
+
 def ceil_by_factor(number: float, factor: int) -> int:
     """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
     return math.ceil(number / factor) * factor
 
+
 def floor_by_factor(number: float, factor: int) -> int:
     """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
     return math.floor(number / factor) * factor
 
+
 def smart_resize(
-    height: int,
+    height: int,
+    width: int,
+    factor: int = IMAGE_FACTOR,
+    min_pixels: int = MIN_PIXELS,
+    max_pixels: int = MAX_PIXELS,
 ) -> tuple[int, int]:
     """
     Rescales the image so that the following conditions are met:
@@ -70,61 +80,62 @@ def smart_resize(
 
 class MLXVLMAdapter(CustomLLM):
     """MLX VLM Adapter for running vision-language models locally using MLX."""
-
+
     def __init__(self, **kwargs):
         """Initialize the adapter.
-
+
         Args:
             **kwargs: Additional arguments
         """
         super().__init__()
-
+
         self.models = {}  # Cache for loaded models
         self.processors = {}  # Cache for loaded processors
         self.configs = {}  # Cache for loaded configs
         self._executor = ThreadPoolExecutor(max_workers=1)  # Single thread pool
-
+
     def _load_model_and_processor(self, model_name: str):
         """Load model and processor if not already cached.
-
+
         Args:
             model_name: Name of the model to load
-
+
         Returns:
             Tuple of (model, processor, config)
         """
         if not MLX_AVAILABLE:
             raise ImportError("MLX VLM dependencies not available. Please install mlx-vlm.")
-
+
         if model_name not in self.models:
             # Load model and processor
             model_obj, processor = load(
-                model_name,
-                processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
+                model_name, processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
             )
             config = load_config(model_name)
-
+
             # Cache them
             self.models[model_name] = model_obj
             self.processors[model_name] = processor
             self.configs[model_name] = config
-
+
         return self.models[model_name], self.processors[model_name], self.configs[model_name]
-
-    def _process_coordinates(
+
+    def _process_coordinates(
+        self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]
+    ) -> str:
         """Process coordinates in box tokens based on image resizing using smart_resize approach.
-
+
         Args:
             text: Text containing box tokens
            original_size: Original image size (width, height)
            model_size: Model processed image size (width, height)
-
+
         Returns:
             Text with processed coordinates
         """
         # Find all box tokens
         box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"
-
+
         def process_coords(match):
             model_x, model_y = int(match.group(1)), int(match.group(2))
             # Scale coordinates from model space to original image space
@@ -132,15 +143,20 @@ class MLXVLMAdapter(CustomLLM):
             new_x = int(model_x * original_size[0] / model_size[0])  # Width
             new_y = int(model_y * original_size[1] / model_size[1])  # Height
             return f"<|box_start|>({new_x},{new_y})<|box_end|>"
-
+
         return re.sub(box_pattern, process_coords, text)
-
-    def _convert_messages(self, messages: List[Dict[str, Any]]) -> Tuple[
+
+    def _convert_messages(self, messages: List[Dict[str, Any]]) -> Tuple[
+        List[Dict[str, Any]],
+        List[Image.Image],
+        Dict[int, Tuple[int, int]],
+        Dict[int, Tuple[int, int]],
+    ]:
         """Convert OpenAI format messages to MLX VLM format and extract images.
-
+
         Args:
             messages: Messages in OpenAI format
-
+
         Returns:
             Tuple of (processed_messages, images, original_sizes, model_sizes)
         """
@@ -149,13 +165,10 @@ class MLXVLMAdapter(CustomLLM):
         original_sizes = {}  # Track original sizes of images for coordinate mapping
         model_sizes = {}  # Track model processed sizes
         image_index = 0
-
+
         for message in messages:
-            processed_message = {
-                "role": message["role"],
-                "content": []
-            }
-
+            processed_message = {"role": message["role"], "content": []}
+
             content = message.get("content", [])
             if isinstance(content, str):
                 # Simple text content
@@ -165,164 +178,163 @@ class MLXVLMAdapter(CustomLLM):
                 processed_content = []
                 for item in content:
                     if item.get("type") == "text":
-                        processed_content.append({
-                            "type": "text",
-                            "text": item.get("text", "")
-                        })
+                        processed_content.append({"type": "text", "text": item.get("text", "")})
                     elif item.get("type") == "image_url":
                         image_url = item.get("image_url", {}).get("url", "")
                         pil_image = None
-
+
                         if image_url.startswith("data:image/"):
                             # Extract base64 data
-                            base64_data = image_url.split(
+                            base64_data = image_url.split(",")[1]
                             # Convert base64 to PIL Image
                             image_data = base64.b64decode(base64_data)
                             pil_image = Image.open(io.BytesIO(image_data))
                         else:
                             # Handle file path or URL
                             pil_image = Image.open(image_url)
-
+
                         # Store original image size for coordinate mapping
                         original_size = pil_image.size
                         original_sizes[image_index] = original_size
-
+
                         # Use smart_resize to determine model size
                         # Note: smart_resize expects (height, width) but PIL gives (width, height)
                         height, width = original_size[1], original_size[0]
                         new_height, new_width = smart_resize(height, width)
                         # Store model size in (width, height) format for consistent coordinate processing
                         model_sizes[image_index] = (new_width, new_height)
-
+
                         # Resize the image using the calculated dimensions from smart_resize
                         resized_image = pil_image.resize((new_width, new_height))
                         images.append(resized_image)
-
+
                         # Add image placeholder to content
-                        processed_content.append({
-                            "type": "image"
-                        })
-
+                        processed_content.append({"type": "image"})
+
                         image_index += 1
-
+
                 processed_message["content"] = processed_content
-
+
             processed_messages.append(processed_message)
-
+
         return processed_messages, images, original_sizes, model_sizes
-
+
     def _generate(self, **kwargs) -> str:
         """Generate response using the local MLX VLM model.
-
+
         Args:
             **kwargs: Keyword arguments containing messages and model info
-
+
         Returns:
             Generated text response
         """
-        messages = kwargs.get(
-        model_name = kwargs.get(
-        max_tokens = kwargs.get(
-
+        messages = kwargs.get("messages", [])
+        model_name = kwargs.get("model", "mlx-community/UI-TARS-1.5-7B-4bit")
+        max_tokens = kwargs.get("max_tokens", 128)
+
         # Warn about ignored kwargs
-        ignored_kwargs = set(kwargs.keys()) - {
+        ignored_kwargs = set(kwargs.keys()) - {"messages", "model", "max_tokens"}
         if ignored_kwargs:
             warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")
-
+
         # Load model and processor
         model, processor, config = self._load_model_and_processor(model_name)
-
+
         # Convert messages and extract images
         processed_messages, images, original_sizes, model_sizes = self._convert_messages(messages)
-
+
         # Process user text input with box coordinates after image processing
         # Swap original_size and model_size arguments for inverse transformation
         for msg_idx, msg in enumerate(processed_messages):
             if msg.get("role") == "user" and isinstance(msg.get("content"), str):
                 content = msg.get("content", "")
-                if
+                if (
+                    "<|box_start|>" in content
+                    and original_sizes
+                    and model_sizes
+                    and 0 in original_sizes
+                    and 0 in model_sizes
+                ):
                     orig_size = original_sizes[0]
                     model_size = model_sizes[0]
                     # Swap arguments to perform inverse transformation for user input
-                    processed_messages[msg_idx]["content"] = self._process_coordinates(
-                        content, model_size, orig_size)
+                    processed_messages[msg_idx]["content"] = self._process_coordinates(
+                        content, model_size, orig_size
+                    )
+
         try:
             # Format prompt according to model requirements using the processor directly
             prompt = processor.apply_chat_template(
-                processed_messages,
-                tokenize=False,
-                add_generation_prompt=True,
-                return_tensors='pt'
+                processed_messages, tokenize=False, add_generation_prompt=True, return_tensors="pt"
             )
             tokenizer = cast(PreTrainedTokenizer, processor)
-
+
             # Generate response
             text_content, usage = generate(
-                model,
-                tokenizer,
-                str(prompt),
-                images,
+                model,
+                tokenizer,
+                str(prompt),
+                images,  # type: ignore
                 verbose=False,
-                max_tokens=max_tokens
+                max_tokens=max_tokens,
             )
-
+
         except Exception as e:
             raise RuntimeError(f"Error generating response: {str(e)}") from e
-
+
         # Process coordinates in the response back to original image space
         if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
             # Get original image size and model size (using the first image)
             orig_size = original_sizes[0]
             model_size = model_sizes[0]
-
+
             # Check if output contains box tokens that need processing
             if "<|box_start|>" in text_content:
                 # Process coordinates from model space back to original image space
                 text_content = self._process_coordinates(text_content, orig_size, model_size)
-
+
         return text_content
-
+
     def completion(self, *args, **kwargs) -> ModelResponse:
         """Synchronous completion method.
-
+
         Returns:
             ModelResponse with generated text
         """
         generated_text = self._generate(**kwargs)
-
+
         result = completion(
             model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
             mock_response=generated_text,
         )
         return cast(ModelResponse, result)
-
+
     async def acompletion(self, *args, **kwargs) -> ModelResponse:
         """Asynchronous completion method.
-
+
         Returns:
             ModelResponse with generated text
         """
         # Run _generate in thread pool to avoid blocking
         loop = asyncio.get_event_loop()
         generated_text = await loop.run_in_executor(
-            self._executor,
-            functools.partial(self._generate, **kwargs)
+            self._executor, functools.partial(self._generate, **kwargs)
         )
-
+
         result = await acompletion(
             model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
             mock_response=generated_text,
         )
         return cast(ModelResponse, result)
-
+
     def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
         """Synchronous streaming method.
-
+
         Returns:
             Iterator of GenericStreamingChunk
         """
         generated_text = self._generate(**kwargs)
-
+
         generic_streaming_chunk: GenericStreamingChunk = {
             "finish_reason": "stop",
             "index": 0,
@@ -331,22 +343,21 @@ class MLXVLMAdapter(CustomLLM):
             "tool_use": None,
             "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
         }
-
+
         yield generic_streaming_chunk
-
+
     async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
         """Asynchronous streaming method.
-
+
         Returns:
             AsyncIterator of GenericStreamingChunk
         """
         # Run _generate in thread pool to avoid blocking
         loop = asyncio.get_event_loop()
         generated_text = await loop.run_in_executor(
-            self._executor,
-            functools.partial(self._generate, **kwargs)
+            self._executor, functools.partial(self._generate, **kwargs)
        )
-
+
         generic_streaming_chunk: GenericStreamingChunk = {
             "finish_reason": "stop",
             "index": 0,
@@ -355,5 +366,5 @@ class MLXVLMAdapter(CustomLLM):
             "tool_use": None,
             "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
         }
-
-        yield generic_streaming_chunk
+
+        yield generic_streaming_chunk
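The coordinate handling in this adapter comes down to proportional scaling: smart_resize picks the model-space image size, and _process_coordinates multiplies each <|box_start|>(x,y)<|box_end|> point by the ratio between the target and source dimensions. A minimal sketch of that mapping (scale_point is an illustrative helper, not part of the package):

def scale_point(x: int, y: int, src_size: tuple[int, int], dst_size: tuple[int, int]) -> tuple[int, int]:
    # src_size and dst_size are (width, height) pairs, matching _process_coordinates
    return int(x * dst_size[0] / src_size[0]), int(y * dst_size[1] / src_size[1])

# Example: a model-space point (500, 350) on a 1000x700 resized image maps back
# to (960, 540) on the original 1920x1080 screenshot.
print(scale_point(500, 350, (1000, 700), (1920, 1080)))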
agent/adapters/models/__init__.py
CHANGED

@@ -2,32 +2,40 @@ from typing import Optional
 
 try:
     from transformers import AutoConfig
+
     HF_AVAILABLE = True
 except ImportError:
     HF_AVAILABLE = False
 
 from .generic import GenericHFModel
+from .internvl import InternVLModel
 from .opencua import OpenCUAModel
 from .qwen2_5_vl import Qwen2_5_VLModel
-
+
 
 def load_model(model_name: str, device: str = "auto", trust_remote_code: bool = False):
     """Factory function to load and return the right model handler instance.
-
+
     - If the underlying transformers config class matches OpenCUA, return OpenCUAModel
     - Otherwise, return GenericHFModel
     """
     if not HF_AVAILABLE:
         raise ImportError(
-
+            'HuggingFace transformers dependencies not found. Install with: pip install "cua-agent[uitars-hf]"'
         )
     cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
     cls = cfg.__class__.__name__
     print(f"cls: {cls}")
     if "OpenCUA" in cls:
-        return OpenCUAModel(
+        return OpenCUAModel(
+            model_name=model_name, device=device, trust_remote_code=trust_remote_code
+        )
     elif "Qwen2_5_VL" in cls:
-        return Qwen2_5_VLModel(
+        return Qwen2_5_VLModel(
+            model_name=model_name, device=device, trust_remote_code=trust_remote_code
+        )
     elif "InternVL" in cls:
-        return InternVLModel(
+        return InternVLModel(
+            model_name=model_name, device=device, trust_remote_code=trust_remote_code
+        )
     return GenericHFModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
agent/adapters/models/generic.py
CHANGED
@@ -1,9 +1,10 @@
-from typing import
+from typing import Any, Dict, List, Optional
 
 # Hugging Face imports are local to avoid hard dependency at module import
 try:
     import torch  # type: ignore
     from transformers import AutoModel, AutoProcessor  # type: ignore
+
     HF_AVAILABLE = True
 except Exception:
     HF_AVAILABLE = False
@@ -14,10 +15,12 @@ class GenericHFModel:
     Loads an AutoModelForImageTextToText and AutoProcessor and generates text.
     """
 
-    def __init__(
+    def __init__(
+        self, model_name: str, device: str = "auto", trust_remote_code: bool = False
+    ) -> None:
         if not HF_AVAILABLE:
             raise ImportError(
-
+                'HuggingFace transformers dependencies not found. Install with: pip install "cua-agent[uitars-hf]"'
             )
         self.model_name = model_name
         self.device = device
@@ -64,7 +67,7 @@ class GenericHFModel:
         generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
         # Trim prompt tokens from output
         generated_ids_trimmed = [
-            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
         ]
         # Decode
         output_text = self.processor.batch_decode(