cua-agent 0.4.17__tar.gz → 0.4.19__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cua-agent might be problematic.

Files changed (54)
  1. {cua_agent-0.4.17 → cua_agent-0.4.19}/PKG-INFO +3 -3
  2. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/adapters/__init__.py +2 -0
  3. cua_agent-0.4.19/agent/adapters/mlxvlm_adapter.py +359 -0
  4. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/agent.py +14 -3
  5. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/callbacks/__init__.py +2 -0
  6. cua_agent-0.4.19/agent/callbacks/operator_validator.py +138 -0
  7. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/callbacks/trajectory_saver.py +87 -5
  8. cua_agent-0.4.19/agent/integrations/hud/__init__.py +228 -0
  9. cua_agent-0.4.19/agent/integrations/hud/proxy.py +183 -0
  10. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/loops/anthropic.py +12 -1
  11. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/loops/composed_grounded.py +26 -14
  12. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/loops/openai.py +15 -7
  13. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/loops/uitars.py +17 -8
  14. cua_agent-0.4.19/agent/proxy/examples.py +192 -0
  15. cua_agent-0.4.19/agent/proxy/handlers.py +248 -0
  16. {cua_agent-0.4.17 → cua_agent-0.4.19}/pyproject.toml +3 -3
  17. cua_agent-0.4.17/agent/integrations/hud/__init__.py +0 -77
  18. cua_agent-0.4.17/agent/integrations/hud/adapter.py +0 -121
  19. cua_agent-0.4.17/agent/integrations/hud/agent.py +0 -373
  20. cua_agent-0.4.17/agent/integrations/hud/computer_handler.py +0 -187
  21. {cua_agent-0.4.17 → cua_agent-0.4.19}/README.md +0 -0
  22. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/__init__.py +0 -0
  23. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/__main__.py +0 -0
  24. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/adapters/huggingfacelocal_adapter.py +0 -0
  25. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/adapters/human_adapter.py +0 -0
  26. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/callbacks/base.py +0 -0
  27. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/callbacks/budget_manager.py +0 -0
  28. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/callbacks/image_retention.py +0 -0
  29. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/callbacks/logging.py +0 -0
  30. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/callbacks/pii_anonymization.py +0 -0
  31. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/callbacks/telemetry.py +0 -0
  32. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/cli.py +0 -0
  33. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/computers/__init__.py +0 -0
  34. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/computers/base.py +0 -0
  35. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/computers/cua.py +0 -0
  36. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/computers/custom.py +0 -0
  37. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/decorators.py +0 -0
  38. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/human_tool/__init__.py +0 -0
  39. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/human_tool/__main__.py +0 -0
  40. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/human_tool/server.py +0 -0
  41. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/human_tool/ui.py +0 -0
  42. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/loops/__init__.py +0 -0
  43. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/loops/base.py +0 -0
  44. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/loops/glm45v.py +0 -0
  45. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/loops/gta1.py +0 -0
  46. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/loops/model_types.csv +0 -0
  47. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/loops/omniparser.py +0 -0
  48. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/responses.py +0 -0
  49. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/types.py +0 -0
  50. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/ui/__init__.py +0 -0
  51. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/ui/__main__.py +0 -0
  52. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/ui/gradio/__init__.py +0 -0
  53. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/ui/gradio/app.py +0 -0
  54. {cua_agent-0.4.17 → cua_agent-0.4.19}/agent/ui/gradio/ui_components.py +0 -0
{cua_agent-0.4.17 → cua_agent-0.4.19}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: cua-agent
- Version: 0.4.17
+ Version: 0.4.19
  Summary: CUA (Computer Use) Agent for AI-driven computer interaction
  Author-Email: TryCua <gh@trycua.com>
  Requires-Python: >=3.12
@@ -38,7 +38,7 @@ Requires-Dist: python-dotenv>=1.0.1; extra == "ui"
  Provides-Extra: cli
  Requires-Dist: yaspin>=3.1.0; extra == "cli"
  Provides-Extra: hud
- Requires-Dist: hud-python==0.2.10; extra == "hud"
+ Requires-Dist: hud-python<0.5.0,>=0.4.12; extra == "hud"
  Provides-Extra: all
  Requires-Dist: ultralytics>=8.0.0; extra == "all"
  Requires-Dist: cua-som<0.2.0,>=0.1.0; extra == "all"
@@ -49,7 +49,7 @@ Requires-Dist: transformers>=4.54.0; extra == "all"
  Requires-Dist: gradio>=5.23.3; extra == "all"
  Requires-Dist: python-dotenv>=1.0.1; extra == "all"
  Requires-Dist: yaspin>=3.1.0; extra == "all"
- Requires-Dist: hud-python==0.2.10; extra == "all"
+ Requires-Dist: hud-python<0.5.0,>=0.4.12; extra == "all"
  Description-Content-Type: text/markdown

  <div align="center">
{cua_agent-0.4.17 → cua_agent-0.4.19}/agent/adapters/__init__.py

@@ -4,8 +4,10 @@ Adapters package for agent - Custom LLM adapters for LiteLLM

  from .huggingfacelocal_adapter import HuggingFaceLocalAdapter
  from .human_adapter import HumanAdapter
+ from .mlxvlm_adapter import MLXVLMAdapter

  __all__ = [
      "HuggingFaceLocalAdapter",
      "HumanAdapter",
+     "MLXVLMAdapter",
  ]
cua_agent-0.4.19/agent/adapters/mlxvlm_adapter.py (new file)

@@ -0,0 +1,359 @@
+ import asyncio
+ import functools
+ import warnings
+ import io
+ import base64
+ import math
+ import re
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import Iterator, AsyncIterator, Dict, List, Any, Optional, Tuple, cast
+ from PIL import Image
+ from litellm.types.utils import GenericStreamingChunk, ModelResponse
+ from litellm.llms.custom_llm import CustomLLM
+ from litellm import completion, acompletion
+
+ # Try to import MLX dependencies
+ try:
+     import mlx.core as mx
+     from mlx_vlm import load, generate
+     from mlx_vlm.prompt_utils import apply_chat_template
+     from mlx_vlm.utils import load_config
+     from transformers.tokenization_utils import PreTrainedTokenizer
+     MLX_AVAILABLE = True
+ except ImportError:
+     MLX_AVAILABLE = False
+
+ # Constants for smart_resize
+ IMAGE_FACTOR = 28
+ MIN_PIXELS = 100 * 28 * 28
+ MAX_PIXELS = 16384 * 28 * 28
+ MAX_RATIO = 200
+
+ def round_by_factor(number: float, factor: int) -> int:
+     """Returns the closest integer to 'number' that is divisible by 'factor'."""
+     return round(number / factor) * factor
+
+ def ceil_by_factor(number: float, factor: int) -> int:
+     """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+     return math.ceil(number / factor) * factor
+
+ def floor_by_factor(number: float, factor: int) -> int:
+     """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+     return math.floor(number / factor) * factor
+
+ def smart_resize(
+     height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+ ) -> tuple[int, int]:
+     """
+     Rescales the image so that the following conditions are met:
+
+     1. Both dimensions (height and width) are divisible by 'factor'.
+     2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+     3. The aspect ratio of the image is maintained as closely as possible.
+     """
+     if max(height, width) / min(height, width) > MAX_RATIO:
+         raise ValueError(
+             f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+         )
+     h_bar = max(factor, round_by_factor(height, factor))
+     w_bar = max(factor, round_by_factor(width, factor))
+     if h_bar * w_bar > max_pixels:
+         beta = math.sqrt((height * width) / max_pixels)
+         h_bar = floor_by_factor(height / beta, factor)
+         w_bar = floor_by_factor(width / beta, factor)
+     elif h_bar * w_bar < min_pixels:
+         beta = math.sqrt(min_pixels / (height * width))
+         h_bar = ceil_by_factor(height * beta, factor)
+         w_bar = ceil_by_factor(width * beta, factor)
+     return h_bar, w_bar
+
+
+ class MLXVLMAdapter(CustomLLM):
+     """MLX VLM Adapter for running vision-language models locally using MLX."""
+
+     def __init__(self, **kwargs):
+         """Initialize the adapter.
+
+         Args:
+             **kwargs: Additional arguments
+         """
+         super().__init__()
+
+         self.models = {}  # Cache for loaded models
+         self.processors = {}  # Cache for loaded processors
+         self.configs = {}  # Cache for loaded configs
+         self._executor = ThreadPoolExecutor(max_workers=1)  # Single thread pool
+
+     def _load_model_and_processor(self, model_name: str):
+         """Load model and processor if not already cached.
+
+         Args:
+             model_name: Name of the model to load
+
+         Returns:
+             Tuple of (model, processor, config)
+         """
+         if not MLX_AVAILABLE:
+             raise ImportError("MLX VLM dependencies not available. Please install mlx-vlm.")
+
+         if model_name not in self.models:
+             # Load model and processor
+             model_obj, processor = load(
+                 model_name,
+                 processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
+             )
+             config = load_config(model_name)
+
+             # Cache them
+             self.models[model_name] = model_obj
+             self.processors[model_name] = processor
+             self.configs[model_name] = config
+
+         return self.models[model_name], self.processors[model_name], self.configs[model_name]
+
+     def _process_coordinates(self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]) -> str:
+         """Process coordinates in box tokens based on image resizing using smart_resize approach.
+
+         Args:
+             text: Text containing box tokens
+             original_size: Original image size (width, height)
+             model_size: Model processed image size (width, height)
+
+         Returns:
+             Text with processed coordinates
+         """
+         # Find all box tokens
+         box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"
+
+         def process_coords(match):
+             model_x, model_y = int(match.group(1)), int(match.group(2))
+             # Scale coordinates from model space to original image space
+             # Both original_size and model_size are in (width, height) format
+             new_x = int(model_x * original_size[0] / model_size[0])  # Width
+             new_y = int(model_y * original_size[1] / model_size[1])  # Height
+             return f"<|box_start|>({new_x},{new_y})<|box_end|>"
+
+         return re.sub(box_pattern, process_coords, text)
+
+     def _convert_messages(self, messages: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Image.Image], Dict[int, Tuple[int, int]], Dict[int, Tuple[int, int]]]:
+         """Convert OpenAI format messages to MLX VLM format and extract images.
+
+         Args:
+             messages: Messages in OpenAI format
+
+         Returns:
+             Tuple of (processed_messages, images, original_sizes, model_sizes)
+         """
+         processed_messages = []
+         images = []
+         original_sizes = {}  # Track original sizes of images for coordinate mapping
+         model_sizes = {}  # Track model processed sizes
+         image_index = 0
+
+         for message in messages:
+             processed_message = {
+                 "role": message["role"],
+                 "content": []
+             }
+
+             content = message.get("content", [])
+             if isinstance(content, str):
+                 # Simple text content
+                 processed_message["content"] = content
+             elif isinstance(content, list):
+                 # Multi-modal content
+                 processed_content = []
+                 for item in content:
+                     if item.get("type") == "text":
+                         processed_content.append({
+                             "type": "text",
+                             "text": item.get("text", "")
+                         })
+                     elif item.get("type") == "image_url":
+                         image_url = item.get("image_url", {}).get("url", "")
+                         pil_image = None
+
+                         if image_url.startswith("data:image/"):
+                             # Extract base64 data
+                             base64_data = image_url.split(',')[1]
+                             # Convert base64 to PIL Image
+                             image_data = base64.b64decode(base64_data)
+                             pil_image = Image.open(io.BytesIO(image_data))
+                         else:
+                             # Handle file path or URL
+                             pil_image = Image.open(image_url)
+
+                         # Store original image size for coordinate mapping
+                         original_size = pil_image.size
+                         original_sizes[image_index] = original_size
+
+                         # Use smart_resize to determine model size
+                         # Note: smart_resize expects (height, width) but PIL gives (width, height)
+                         height, width = original_size[1], original_size[0]
+                         new_height, new_width = smart_resize(height, width)
+                         # Store model size in (width, height) format for consistent coordinate processing
+                         model_sizes[image_index] = (new_width, new_height)
+
+                         # Resize the image using the calculated dimensions from smart_resize
+                         resized_image = pil_image.resize((new_width, new_height))
+                         images.append(resized_image)
+
+                         # Add image placeholder to content
+                         processed_content.append({
+                             "type": "image"
+                         })
+
+                         image_index += 1
+
+                 processed_message["content"] = processed_content
+
+             processed_messages.append(processed_message)
+
+         return processed_messages, images, original_sizes, model_sizes
+
+     def _generate(self, **kwargs) -> str:
+         """Generate response using the local MLX VLM model.
+
+         Args:
+             **kwargs: Keyword arguments containing messages and model info
+
+         Returns:
+             Generated text response
+         """
+         messages = kwargs.get('messages', [])
+         model_name = kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')
+         max_tokens = kwargs.get('max_tokens', 128)
+
+         # Warn about ignored kwargs
+         ignored_kwargs = set(kwargs.keys()) - {'messages', 'model', 'max_tokens'}
+         if ignored_kwargs:
+             warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")
+
+         # Load model and processor
+         model, processor, config = self._load_model_and_processor(model_name)
+
+         # Convert messages and extract images
+         processed_messages, images, original_sizes, model_sizes = self._convert_messages(messages)
+
+         # Process user text input with box coordinates after image processing
+         # Swap original_size and model_size arguments for inverse transformation
+         for msg_idx, msg in enumerate(processed_messages):
+             if msg.get("role") == "user" and isinstance(msg.get("content"), str):
+                 content = msg.get("content", "")
+                 if "<|box_start|>" in content and original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
+                     orig_size = original_sizes[0]
+                     model_size = model_sizes[0]
+                     # Swap arguments to perform inverse transformation for user input
+                     processed_messages[msg_idx]["content"] = self._process_coordinates(content, model_size, orig_size)
+
+         try:
+             # Format prompt according to model requirements using the processor directly
+             prompt = processor.apply_chat_template(
+                 processed_messages,
+                 tokenize=False,
+                 add_generation_prompt=True,
+                 return_tensors='pt'
+             )
+             tokenizer = cast(PreTrainedTokenizer, processor)
+
+             # Generate response
+             text_content, usage = generate(
+                 model,
+                 tokenizer,
+                 str(prompt),
+                 images,  # type: ignore
+                 verbose=False,
+                 max_tokens=max_tokens
+             )
+
+         except Exception as e:
+             raise RuntimeError(f"Error generating response: {str(e)}") from e
+
+         # Process coordinates in the response back to original image space
+         if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
+             # Get original image size and model size (using the first image)
+             orig_size = original_sizes[0]
+             model_size = model_sizes[0]
+
+             # Check if output contains box tokens that need processing
+             if "<|box_start|>" in text_content:
+                 # Process coordinates from model space back to original image space
+                 text_content = self._process_coordinates(text_content, orig_size, model_size)
+
+         return text_content
+
+     def completion(self, *args, **kwargs) -> ModelResponse:
+         """Synchronous completion method.
+
+         Returns:
+             ModelResponse with generated text
+         """
+         generated_text = self._generate(**kwargs)
+
+         result = completion(
+             model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
+             mock_response=generated_text,
+         )
+         return cast(ModelResponse, result)
+
+     async def acompletion(self, *args, **kwargs) -> ModelResponse:
+         """Asynchronous completion method.
+
+         Returns:
+             ModelResponse with generated text
+         """
+         # Run _generate in thread pool to avoid blocking
+         loop = asyncio.get_event_loop()
+         generated_text = await loop.run_in_executor(
+             self._executor,
+             functools.partial(self._generate, **kwargs)
+         )
+
+         result = await acompletion(
+             model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
+             mock_response=generated_text,
+         )
+         return cast(ModelResponse, result)
+
+     def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+         """Synchronous streaming method.
+
+         Returns:
+             Iterator of GenericStreamingChunk
+         """
+         generated_text = self._generate(**kwargs)
+
+         generic_streaming_chunk: GenericStreamingChunk = {
+             "finish_reason": "stop",
+             "index": 0,
+             "is_finished": True,
+             "text": generated_text,
+             "tool_use": None,
+             "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
+         }
+
+         yield generic_streaming_chunk
+
+     async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
+         """Asynchronous streaming method.
+
+         Returns:
+             AsyncIterator of GenericStreamingChunk
+         """
+         # Run _generate in thread pool to avoid blocking
+         loop = asyncio.get_event_loop()
+         generated_text = await loop.run_in_executor(
+             self._executor,
+             functools.partial(self._generate, **kwargs)
+         )
+
+         generic_streaming_chunk: GenericStreamingChunk = {
+             "finish_reason": "stop",
+             "index": 0,
+             "is_finished": True,
+             "text": generated_text,
+             "tool_use": None,
+             "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
+         }
+
+         yield generic_streaming_chunk
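The adapter above plugs into litellm's custom-provider mechanism; agent.py (next hunk) registers it under the "mlx" provider name alongside the existing huggingface-local and human handlers. As a rough, hypothetical usage sketch (not part of the package), the adapter could also be exercised directly through litellm, assuming mlx-vlm is installed and the default UI-TARS checkpoint can be downloaded:

```python
import base64, io

import litellm
from PIL import Image

from agent.adapters import MLXVLMAdapter

# Register the adapter under the "mlx" provider, mirroring what agent/agent.py now does.
litellm.custom_provider_map = [
    {"provider": "mlx", "custom_handler": MLXVLMAdapter()},
]

# Build a small screenshot-like image and encode it as a data URL
# (any real screenshot would work the same way).
buf = io.BytesIO()
Image.new("RGB", (1280, 800), "white").save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

response = litellm.completion(
    model="mlx/mlx-community/UI-TARS-1.5-7B-4bit",  # the adapter's default checkpoint
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Click the Submit button."},
            {"type": "image_url", "image_url": {"url": data_url}},
        ],
    }],
    max_tokens=128,
)
print(response.choices[0].message.content)
```

Any `<|box_start|>(x,y)<|box_end|>` coordinates in the reply are already rescaled by the adapter from the smart_resize model space back to the original image size.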
{cua_agent-0.4.17 → cua_agent-0.4.19}/agent/agent.py

@@ -3,6 +3,7 @@ ComputerAgent - Main agent class that selects and runs agent loops
  """

  import asyncio
+ from pathlib import Path
  from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set, Tuple

  from litellm.responses.utils import Usage
@@ -22,6 +23,7 @@ import inspect
  from .adapters import (
      HuggingFaceLocalAdapter,
      HumanAdapter,
+     MLXVLMAdapter,
  )
  from .callbacks import (
      ImageRetentionCallback,
@@ -29,6 +31,7 @@ from .callbacks import (
      TrajectorySaverCallback,
      BudgetManagerCallback,
      TelemetryCallback,
+     OperatorNormalizerCallback
  )
  from .computers import (
      AsyncComputerHandler,
@@ -160,7 +163,7 @@ class ComputerAgent:
          only_n_most_recent_images: Optional[int] = None,
          callbacks: Optional[List[Any]] = None,
          verbosity: Optional[int] = None,
-         trajectory_dir: Optional[str] = None,
+         trajectory_dir: Optional[str | Path | dict] = None,
          max_retries: Optional[int] = 3,
          screenshot_delay: Optional[float | int] = 0.5,
          use_prompt_caching: Optional[bool] = False,
@@ -201,6 +204,9 @@ class ComputerAgent:

          # == Add built-in callbacks ==

+         # Prepend operator normalizer callback
+         self.callbacks.insert(0, OperatorNormalizerCallback())
+
          # Add telemetry callback if telemetry_enabled is set
          if self.telemetry_enabled:
              if isinstance(self.telemetry_enabled, bool):
@@ -218,7 +224,10 @@ class ComputerAgent:

          # Add trajectory saver callback if trajectory_dir is set
          if self.trajectory_dir:
-             self.callbacks.append(TrajectorySaverCallback(self.trajectory_dir))
+             if isinstance(self.trajectory_dir, dict):
+                 self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir))
+             elif isinstance(self.trajectory_dir, (str, Path)):
+                 self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir)))

          # Add budget manager if max_trajectory_budget is set
          if max_trajectory_budget:
@@ -234,9 +243,11 @@ class ComputerAgent:
              device="auto"
          )
          human_adapter = HumanAdapter()
+         mlx_adapter = MLXVLMAdapter()
          litellm.custom_provider_map = [
              {"provider": "huggingface-local", "custom_handler": hf_adapter},
-             {"provider": "human", "custom_handler": human_adapter}
+             {"provider": "human", "custom_handler": human_adapter},
+             {"provider": "mlx", "custom_handler": mlx_adapter}
          ]
          litellm.suppress_debug_info = True

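With this change, trajectory_dir accepts a plain string (as before), a pathlib.Path, or a dict that is unpacked into TrajectorySaverCallback. A hedged sketch of the three call styles follows; the ComputerAgent import path, model string, and dict key are illustrative assumptions, since the expanded trajectory_saver.py signature is not shown in this excerpt:

```python
from pathlib import Path

from agent import ComputerAgent  # import path assumed for illustration

# 1) Plain string, unchanged behaviour.
agent = ComputerAgent(model="openai/computer-use-preview", trajectory_dir="trajectories")

# 2) pathlib.Path is now accepted; it is coerced with str() before the callback is built.
agent = ComputerAgent(model="openai/computer-use-preview", trajectory_dir=Path("trajectories"))

# 3) A dict is forwarded as TrajectorySaverCallback(**trajectory_dir); the accepted
#    keyword names live in trajectory_saver.py, which this diff does not show, so the
#    key used here is a guess.
agent = ComputerAgent(model="openai/computer-use-preview",
                      trajectory_dir={"trajectory_dir": "trajectories"})
```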
{cua_agent-0.4.17 → cua_agent-0.4.19}/agent/callbacks/__init__.py

@@ -8,6 +8,7 @@ from .logging import LoggingCallback
  from .trajectory_saver import TrajectorySaverCallback
  from .budget_manager import BudgetManagerCallback
  from .telemetry import TelemetryCallback
+ from .operator_validator import OperatorNormalizerCallback

  __all__ = [
      "AsyncCallbackHandler",
@@ -16,4 +17,5 @@ __all__ = [
      "TrajectorySaverCallback",
      "BudgetManagerCallback",
      "TelemetryCallback",
+     "OperatorNormalizerCallback",
  ]
cua_agent-0.4.19/agent/callbacks/operator_validator.py (new file)

@@ -0,0 +1,138 @@
+ """
+ OperatorValidatorCallback
+
+ Ensures agent output actions conform to expected schemas by fixing common issues:
+ - click: add default button='left' if missing
+ - keypress: wrap keys string into a list
+ - etc.
+
+ This runs in on_llm_end, which receives the output array (AgentMessage[] as dicts).
+ The purpose is to avoid spending another LLM call to fix broken computer call syntax when possible.
+ """
+ from __future__ import annotations
+
+ from typing import Any, Dict, List
+
+ from .base import AsyncCallbackHandler
+
+
+ class OperatorNormalizerCallback(AsyncCallbackHandler):
+     """Normalizes common computer call hallucinations / errors in computer call syntax."""
+
+     async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+         # Mutate in-place as requested, but still return the list for chaining
+         for item in output or []:
+             if item.get("type") != "computer_call":
+                 continue
+             action = item.get("action")
+             if not isinstance(action, dict):
+                 continue
+
+             # rename mouse click actions to "click"
+             for mouse_btn in ["left", "right", "wheel", "back", "forward"]:
+                 if action.get("type", "") == f"{mouse_btn}_click":
+                     action["type"] = "click"
+                     action["button"] = mouse_btn
+             # rename hotkey actions to "keypress"
+             for alias in ["hotkey", "key", "press", "key_press"]:
+                 if action.get("type", "") == alias:
+                     action["type"] = "keypress"
+             # assume click actions
+             if "button" in action and "type" not in action:
+                 action["type"] = "click"
+             if "click" in action and "type" not in action:
+                 action["type"] = "click"
+             if ("scroll_x" in action or "scroll_y" in action) and "type" not in action:
+                 action["type"] = "scroll"
+             if "text" in action and "type" not in action:
+                 action["type"] = "type"
+
+             action_type = action.get("type")
+             def _keep_keys(action: Dict[str, Any], keys_to_keep: List[str]):
+                 """Keep only the provided keys on action; delete everything else.
+                 Always ensures required 'type' is present if listed in keys_to_keep.
+                 """
+                 for key in list(action.keys()):
+                     if key not in keys_to_keep:
+                         del action[key]
+             # rename "coordinate" to "x", "y"
+             if "coordinate" in action:
+                 action["x"] = action["coordinate"][0]
+                 action["y"] = action["coordinate"][1]
+                 del action["coordinate"]
+             if action_type == "click":
+                 # convert "click" to "button"
+                 if "button" not in action and "click" in action:
+                     action["button"] = action["click"]
+                     del action["click"]
+                 # default button to "left"
+                 action["button"] = action.get("button", "left")
+             # add default scroll x, y if missing
+             if action_type == "scroll":
+                 action["scroll_x"] = action.get("scroll_x", 0)
+                 action["scroll_y"] = action.get("scroll_y", 0)
+             # ensure keys arg is a list (normalize aliases first)
+             if action_type == "keypress":
+                 keys = action.get("keys")
+                 for keys_alias in ["keypress", "key", "press", "key_press", "text"]:
+                     if keys_alias in action:
+                         action["keys"] = action[keys_alias]
+                         del action[keys_alias]
+                 keys = action.get("keys")
+                 if isinstance(keys, str):
+                     action["keys"] = keys.replace("-", "+").split("+") if len(keys) > 1 else [keys]
+             required_keys_by_type = {
+                 # OpenAI actions
+                 "click": ["type", "button", "x", "y"],
+                 "double_click": ["type", "x", "y"],
+                 "drag": ["type", "path"],
+                 "keypress": ["type", "keys"],
+                 "move": ["type", "x", "y"],
+                 "screenshot": ["type"],
+                 "scroll": ["type", "scroll_x", "scroll_y", "x", "y"],
+                 "type": ["type", "text"],
+                 "wait": ["type"],
+                 # Anthropic actions
+                 "left_mouse_down": ["type", "x", "y"],
+                 "left_mouse_up": ["type", "x", "y"],
+                 "triple_click": ["type", "button", "x", "y"],
+             }
+             keep = required_keys_by_type.get(action_type or "")
+             if keep:
+                 _keep_keys(action, keep)
+
+
+         # Second pass: if an assistant message is immediately followed by a computer_call,
+         # replace the assistant message itself with a reasoning message with summary text.
+         if isinstance(output, list):
+             for i, item in enumerate(output):
+                 # AssistantMessage shape: { type: 'message', role: 'assistant', content: OutputContent[] }
+                 if item.get("type") == "message" and item.get("role") == "assistant":
+                     next_idx = i + 1
+                     if next_idx >= len(output):
+                         continue
+                     next_item = output[next_idx]
+                     if not isinstance(next_item, dict):
+                         continue
+                     if next_item.get("type") != "computer_call":
+                         continue
+                     contents = item.get("content") or []
+                     # Extract text from OutputContent[]
+                     text_parts: List[str] = []
+                     if isinstance(contents, list):
+                         for c in contents:
+                             if isinstance(c, dict) and c.get("type") == "output_text" and isinstance(c.get("text"), str):
+                                 text_parts.append(c["text"])
+                     text_content = "\n".join(text_parts).strip()
+                     # Replace assistant message with reasoning message
+                     output[i] = {
+                         "type": "reasoning",
+                         "summary": [
+                             {
+                                 "type": "summary_text",
+                                 "text": text_content,
+                             }
+                         ],
+                     }
+
+         return output
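To illustrate what the normalizer does, here is a small hypothetical before/after: a computer_call that uses the "left_click" and "coordinate" aliases is rewritten into the canonical click schema. The example drives the callback directly; the field values are made up.

```python
import asyncio

from agent.callbacks import OperatorNormalizerCallback

# A sloppy computer_call of the kind the callback repairs:
# "left_click" instead of "click", "coordinate" instead of x/y, no "button".
output = [
    {"type": "computer_call", "action": {"type": "left_click", "coordinate": [120, 240]}},
]

fixed = asyncio.run(OperatorNormalizerCallback().on_llm_end(output))
print(fixed[0]["action"])
# -> {'type': 'click', 'button': 'left', 'x': 120, 'y': 240}
```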