cua-agent 0.1.24__py3-none-any.whl → 0.1.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/core/base.py +20 -0
- agent/core/callbacks.py +57 -2
- agent/providers/anthropic/callbacks/manager.py +20 -10
- agent/providers/omni/clients/oaicompat.py +11 -3
- agent/providers/omni/loop.py +24 -4
- agent/providers/openai/loop.py +13 -4
- agent/ui/gradio/app.py +429 -329
- {cua_agent-0.1.24.dist-info → cua_agent-0.1.26.dist-info}/METADATA +37 -23
- {cua_agent-0.1.24.dist-info → cua_agent-0.1.26.dist-info}/RECORD +11 -11
- {cua_agent-0.1.24.dist-info → cua_agent-0.1.26.dist-info}/WHEEL +1 -1
- {cua_agent-0.1.24.dist-info → cua_agent-0.1.26.dist-info}/entry_points.txt +0 -0
agent/core/base.py
CHANGED
|
@@ -5,10 +5,12 @@ import asyncio
|
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
6
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
|
7
7
|
|
|
8
|
+
from agent.providers.omni.parser import ParseResult
|
|
8
9
|
from computer import Computer
|
|
9
10
|
from .messages import StandardMessageManager, ImageRetentionConfig
|
|
10
11
|
from .types import AgentResponse
|
|
11
12
|
from .experiment import ExperimentManager
|
|
13
|
+
from .callbacks import CallbackManager, CallbackHandler
|
|
12
14
|
|
|
13
15
|
logger = logging.getLogger(__name__)
|
|
14
16
|
|
|
@@ -27,6 +29,7 @@ class BaseLoop(ABC):
|
|
|
27
29
|
base_dir: Optional[str] = "trajectories",
|
|
28
30
|
save_trajectory: bool = True,
|
|
29
31
|
only_n_most_recent_images: Optional[int] = 2,
|
|
32
|
+
callback_handlers: Optional[List[CallbackHandler]] = None,
|
|
30
33
|
**kwargs,
|
|
31
34
|
):
|
|
32
35
|
"""Initialize base agent loop.
|
|
@@ -75,6 +78,9 @@ class BaseLoop(ABC):
|
|
|
75
78
|
|
|
76
79
|
# Initialize basic tracking
|
|
77
80
|
self.turn_count = 0
|
|
81
|
+
|
|
82
|
+
# Initialize callback manager
|
|
83
|
+
self.callback_manager = CallbackManager(handlers=callback_handlers or [])
|
|
78
84
|
|
|
79
85
|
async def initialize(self) -> None:
|
|
80
86
|
"""Initialize both the API client and computer interface with retries."""
|
|
@@ -187,3 +193,17 @@ class BaseLoop(ABC):
|
|
|
187
193
|
"""
|
|
188
194
|
if self.experiment_manager:
|
|
189
195
|
self.experiment_manager.save_screenshot(img_base64, action_type)
|
|
196
|
+
|
|
197
|
+
###########################################
|
|
198
|
+
# EVENT HOOKS / CALLBACKS
|
|
199
|
+
###########################################
|
|
200
|
+
|
|
201
|
+
async def handle_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
|
|
202
|
+
"""Process a screenshot through callback managers
|
|
203
|
+
|
|
204
|
+
Args:
|
|
205
|
+
screenshot_base64: Base64 encoded screenshot
|
|
206
|
+
action_type: Type of action that triggered the screenshot
|
|
207
|
+
"""
|
|
208
|
+
if hasattr(self, 'callback_manager'):
|
|
209
|
+
await self.callback_manager.on_screenshot(screenshot_base64, action_type, parsed_screen)
|
agent/core/callbacks.py
CHANGED
|
@@ -6,6 +6,8 @@ from abc import ABC, abstractmethod
|
|
|
6
6
|
from datetime import datetime
|
|
7
7
|
from typing import Any, Dict, List, Optional, Protocol
|
|
8
8
|
|
|
9
|
+
from agent.providers.omni.parser import ParseResult
|
|
10
|
+
|
|
9
11
|
logger = logging.getLogger(__name__)
|
|
10
12
|
|
|
11
13
|
class ContentCallback(Protocol):
|
|
@@ -20,6 +22,10 @@ class APICallback(Protocol):
|
|
|
20
22
|
"""Protocol for API callbacks."""
|
|
21
23
|
def __call__(self, request: Any, response: Any, error: Optional[Exception] = None) -> None: ...
|
|
22
24
|
|
|
25
|
+
class ScreenshotCallback(Protocol):
|
|
26
|
+
"""Protocol for screenshot callbacks."""
|
|
27
|
+
def __call__(self, screenshot_base64: str, action_type: str = "") -> Optional[str]: ...
|
|
28
|
+
|
|
23
29
|
class BaseCallbackManager(ABC):
|
|
24
30
|
"""Base class for callback managers."""
|
|
25
31
|
|
|
@@ -110,7 +116,20 @@ class CallbackManager:
|
|
|
110
116
|
"""
|
|
111
117
|
for handler in self.handlers:
|
|
112
118
|
await handler.on_error(error, **kwargs)
|
|
113
|
-
|
|
119
|
+
|
|
120
|
+
async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
|
|
121
|
+
"""Called when a screenshot is taken.
|
|
122
|
+
|
|
123
|
+
Args:
|
|
124
|
+
screenshot_base64: Base64 encoded screenshot
|
|
125
|
+
action_type: Type of action that triggered the screenshot
|
|
126
|
+
parsed_screen: Optional output from parsing the screenshot
|
|
127
|
+
|
|
128
|
+
Returns:
|
|
129
|
+
Modified screenshot or original if no modifications
|
|
130
|
+
"""
|
|
131
|
+
for handler in self.handlers:
|
|
132
|
+
await handler.on_screenshot(screenshot_base64, action_type, parsed_screen)
|
|
114
133
|
|
|
115
134
|
class CallbackHandler(ABC):
|
|
116
135
|
"""Base class for callback handlers."""
|
|
@@ -144,4 +163,40 @@ class CallbackHandler(ABC):
|
|
|
144
163
|
error: Exception that occurred
|
|
145
164
|
**kwargs: Additional data
|
|
146
165
|
"""
|
|
147
|
-
pass
|
|
166
|
+
pass
|
|
167
|
+
|
|
168
|
+
@abstractmethod
|
|
169
|
+
async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
|
|
170
|
+
"""Called when a screenshot is taken.
|
|
171
|
+
|
|
172
|
+
Args:
|
|
173
|
+
screenshot_base64: Base64 encoded screenshot
|
|
174
|
+
action_type: Type of action that triggered the screenshot
|
|
175
|
+
|
|
176
|
+
Returns:
|
|
177
|
+
Optional modified screenshot
|
|
178
|
+
"""
|
|
179
|
+
pass
|
|
180
|
+
|
|
181
|
+
class DefaultCallbackHandler(CallbackHandler):
|
|
182
|
+
"""Default implementation of CallbackHandler with no-op methods.
|
|
183
|
+
|
|
184
|
+
This class implements all abstract methods from CallbackHandler,
|
|
185
|
+
allowing subclasses to override only the methods they need.
|
|
186
|
+
"""
|
|
187
|
+
|
|
188
|
+
async def on_action_start(self, action: str, **kwargs) -> None:
|
|
189
|
+
"""Default no-op implementation."""
|
|
190
|
+
pass
|
|
191
|
+
|
|
192
|
+
async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
|
|
193
|
+
"""Default no-op implementation."""
|
|
194
|
+
pass
|
|
195
|
+
|
|
196
|
+
async def on_error(self, error: Exception, **kwargs) -> None:
|
|
197
|
+
"""Default no-op implementation."""
|
|
198
|
+
pass
|
|
199
|
+
|
|
200
|
+
async def on_screenshot(self, screenshot_base64: str, action_type: str = "") -> None:
|
|
201
|
+
"""Default no-op implementation."""
|
|
202
|
+
pass
|
|
@@ -3,23 +3,33 @@ import httpx
|
|
|
3
3
|
from anthropic.types.beta import BetaContentBlockParam
|
|
4
4
|
from ..tools import ToolResult
|
|
5
5
|
|
|
6
|
+
|
|
6
7
|
class APICallback(Protocol):
|
|
7
8
|
"""Protocol for API callbacks."""
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
9
|
+
|
|
10
|
+
def __call__(
|
|
11
|
+
self,
|
|
12
|
+
request: httpx.Request | None,
|
|
13
|
+
response: httpx.Response | object | None,
|
|
14
|
+
error: Exception | None,
|
|
15
|
+
) -> None: ...
|
|
16
|
+
|
|
11
17
|
|
|
12
18
|
class ContentCallback(Protocol):
|
|
13
19
|
"""Protocol for content callbacks."""
|
|
20
|
+
|
|
14
21
|
def __call__(self, content: BetaContentBlockParam) -> None: ...
|
|
15
22
|
|
|
23
|
+
|
|
16
24
|
class ToolCallback(Protocol):
|
|
17
25
|
"""Protocol for tool callbacks."""
|
|
26
|
+
|
|
18
27
|
def __call__(self, result: ToolResult, tool_id: str) -> None: ...
|
|
19
28
|
|
|
29
|
+
|
|
20
30
|
class CallbackManager:
|
|
21
31
|
"""Manages various callbacks for the agent system."""
|
|
22
|
-
|
|
32
|
+
|
|
23
33
|
def __init__(
|
|
24
34
|
self,
|
|
25
35
|
content_callback: ContentCallback,
|
|
@@ -27,7 +37,7 @@ class CallbackManager:
|
|
|
27
37
|
api_callback: APICallback,
|
|
28
38
|
):
|
|
29
39
|
"""Initialize the callback manager.
|
|
30
|
-
|
|
40
|
+
|
|
31
41
|
Args:
|
|
32
42
|
content_callback: Callback for content updates
|
|
33
43
|
tool_callback: Callback for tool execution results
|
|
@@ -36,20 +46,20 @@ class CallbackManager:
|
|
|
36
46
|
self.content_callback = content_callback
|
|
37
47
|
self.tool_callback = tool_callback
|
|
38
48
|
self.api_callback = api_callback
|
|
39
|
-
|
|
49
|
+
|
|
40
50
|
def on_content(self, content: BetaContentBlockParam) -> None:
|
|
41
51
|
"""Handle content updates."""
|
|
42
52
|
self.content_callback(content)
|
|
43
|
-
|
|
53
|
+
|
|
44
54
|
def on_tool_result(self, result: ToolResult, tool_id: str) -> None:
|
|
45
55
|
"""Handle tool execution results."""
|
|
46
56
|
self.tool_callback(result, tool_id)
|
|
47
|
-
|
|
57
|
+
|
|
48
58
|
def on_api_interaction(
|
|
49
59
|
self,
|
|
50
60
|
request: httpx.Request | None,
|
|
51
61
|
response: httpx.Response | object | None,
|
|
52
|
-
error: Exception | None
|
|
62
|
+
error: Exception | None,
|
|
53
63
|
) -> None:
|
|
54
64
|
"""Handle API interactions."""
|
|
55
|
-
self.api_callback(request, response, error)
|
|
65
|
+
self.api_callback(request, response, error)
|
|
@@ -45,8 +45,8 @@ class OAICompatClient(BaseOmniClient):
|
|
|
45
45
|
max_tokens: Maximum tokens to generate
|
|
46
46
|
temperature: Generation temperature
|
|
47
47
|
"""
|
|
48
|
-
super().__init__(api_key="EMPTY", model=model)
|
|
49
|
-
self.api_key = "EMPTY"
|
|
48
|
+
super().__init__(api_key=api_key or "EMPTY", model=model)
|
|
49
|
+
self.api_key = api_key or "EMPTY" # Local endpoints typically don't require an API key
|
|
50
50
|
self.model = model
|
|
51
51
|
self.provider_base_url = (
|
|
52
52
|
provider_base_url or "http://localhost:8000/v1"
|
|
@@ -146,10 +146,18 @@ class OAICompatClient(BaseOmniClient):
|
|
|
146
146
|
base_url = self.provider_base_url or "http://localhost:8000/v1"
|
|
147
147
|
|
|
148
148
|
# Check if the base URL already includes the chat/completions endpoint
|
|
149
|
+
|
|
149
150
|
endpoint_url = base_url
|
|
150
151
|
if not endpoint_url.endswith("/chat/completions"):
|
|
152
|
+
# If URL is RunPod format, make it OpenAI compatible
|
|
153
|
+
if endpoint_url.startswith("https://api.runpod.ai/v2/"):
|
|
154
|
+
# Extract RunPod endpoint ID
|
|
155
|
+
parts = endpoint_url.split("/")
|
|
156
|
+
if len(parts) >= 5:
|
|
157
|
+
runpod_id = parts[4]
|
|
158
|
+
endpoint_url = f"https://api.runpod.ai/v2/{runpod_id}/openai/v1/chat/completions"
|
|
151
159
|
# If the URL ends with /v1, append /chat/completions
|
|
152
|
-
|
|
160
|
+
elif endpoint_url.endswith("/v1"):
|
|
153
161
|
endpoint_url = f"{endpoint_url}/chat/completions"
|
|
154
162
|
# If the URL doesn't end with /v1, make sure it has a proper structure
|
|
155
163
|
elif not endpoint_url.endswith("/"):
|
agent/providers/omni/loop.py
CHANGED
|
@@ -147,7 +147,7 @@ class OmniLoop(BaseLoop):
|
|
|
147
147
|
)
|
|
148
148
|
elif self.provider == LLMProvider.OAICOMPAT:
|
|
149
149
|
self.client = OAICompatClient(
|
|
150
|
-
api_key="EMPTY", # Local endpoints typically don't require an API key
|
|
150
|
+
api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
|
|
151
151
|
model=self.model,
|
|
152
152
|
provider_base_url=self.provider_base_url,
|
|
153
153
|
)
|
|
@@ -183,7 +183,7 @@ class OmniLoop(BaseLoop):
|
|
|
183
183
|
)
|
|
184
184
|
elif self.provider == LLMProvider.OAICOMPAT:
|
|
185
185
|
self.client = OAICompatClient(
|
|
186
|
-
api_key="EMPTY", # Local endpoints typically don't require an API key
|
|
186
|
+
api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
|
|
187
187
|
model=self.model,
|
|
188
188
|
provider_base_url=self.provider_base_url,
|
|
189
189
|
)
|
|
@@ -443,6 +443,8 @@ class OmniLoop(BaseLoop):
|
|
|
443
443
|
except (json.JSONDecodeError, IndexError):
|
|
444
444
|
try:
|
|
445
445
|
# Look for JSON object pattern
|
|
446
|
+
import re # Local import to ensure availability
|
|
447
|
+
|
|
446
448
|
json_pattern = r"\{[^}]+\}"
|
|
447
449
|
json_match = re.search(json_pattern, raw_text)
|
|
448
450
|
if json_match:
|
|
@@ -453,8 +455,20 @@ class OmniLoop(BaseLoop):
|
|
|
453
455
|
logger.error(f"No JSON found in content")
|
|
454
456
|
return True, action_screenshot_saved
|
|
455
457
|
except json.JSONDecodeError as e:
|
|
456
|
-
|
|
457
|
-
|
|
458
|
+
# Try to sanitize the JSON string and retry
|
|
459
|
+
try:
|
|
460
|
+
# Remove or replace invalid control characters
|
|
461
|
+
import re # Local import to ensure availability
|
|
462
|
+
|
|
463
|
+
sanitized_text = re.sub(r"[\x00-\x1F\x7F]", "", raw_text)
|
|
464
|
+
# Try parsing again with sanitized text
|
|
465
|
+
parsed_content = json.loads(sanitized_text)
|
|
466
|
+
logger.info(
|
|
467
|
+
"Successfully parsed JSON after sanitizing control characters"
|
|
468
|
+
)
|
|
469
|
+
except json.JSONDecodeError:
|
|
470
|
+
logger.error(f"Failed to parse JSON from text: {str(e)}")
|
|
471
|
+
return True, action_screenshot_saved
|
|
458
472
|
|
|
459
473
|
# Step 4: Process the parsed content if available
|
|
460
474
|
if parsed_content:
|
|
@@ -534,6 +548,10 @@ class OmniLoop(BaseLoop):
|
|
|
534
548
|
img_data = parsed_screen.annotated_image_base64
|
|
535
549
|
if "," in img_data:
|
|
536
550
|
img_data = img_data.split(",")[1]
|
|
551
|
+
|
|
552
|
+
# Process screenshot through hooks and save if needed
|
|
553
|
+
await self.handle_screenshot(img_data, action_type="state", parsed_screen=parsed_screen)
|
|
554
|
+
|
|
537
555
|
# Save with a generic "state" action type to indicate this is the current screen state
|
|
538
556
|
self._save_screenshot(img_data, action_type="state")
|
|
539
557
|
except Exception as e:
|
|
@@ -649,6 +667,8 @@ class OmniLoop(BaseLoop):
|
|
|
649
667
|
response=response,
|
|
650
668
|
messages=self.message_manager.messages,
|
|
651
669
|
model=self.model,
|
|
670
|
+
parsed_screen=parsed_screen,
|
|
671
|
+
parser=self.parser
|
|
652
672
|
)
|
|
653
673
|
|
|
654
674
|
# Yield the response to the caller
|
agent/providers/openai/loop.py
CHANGED
|
@@ -194,8 +194,13 @@ class OpenAILoop(BaseLoop):
|
|
|
194
194
|
# Convert to base64 if needed
|
|
195
195
|
if isinstance(screenshot, bytes):
|
|
196
196
|
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
|
|
197
|
+
elif isinstance(screenshot, (bytearray, memoryview)):
|
|
198
|
+
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
|
|
197
199
|
else:
|
|
198
|
-
screenshot_base64 = screenshot
|
|
200
|
+
screenshot_base64 = str(screenshot)
|
|
201
|
+
|
|
202
|
+
# Emit screenshot callbacks
|
|
203
|
+
await self.handle_screenshot(screenshot_base64, action_type="initial_state")
|
|
199
204
|
|
|
200
205
|
# Save screenshot if requested
|
|
201
206
|
if self.save_trajectory:
|
|
@@ -204,8 +209,6 @@ class OpenAILoop(BaseLoop):
|
|
|
204
209
|
logger.warning(
|
|
205
210
|
"Converting non-string screenshot_base64 to string for _save_screenshot"
|
|
206
211
|
)
|
|
207
|
-
if isinstance(screenshot_base64, (bytearray, memoryview)):
|
|
208
|
-
screenshot_base64 = base64.b64encode(screenshot_base64).decode("utf-8")
|
|
209
212
|
self._save_screenshot(screenshot_base64, action_type="state")
|
|
210
213
|
logger.info("Screenshot saved to trajectory")
|
|
211
214
|
|
|
@@ -336,8 +339,14 @@ class OpenAILoop(BaseLoop):
|
|
|
336
339
|
screenshot = await self.computer.interface.screenshot()
|
|
337
340
|
if isinstance(screenshot, bytes):
|
|
338
341
|
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
|
|
342
|
+
elif isinstance(screenshot, (bytearray, memoryview)):
|
|
343
|
+
screenshot_base64 = base64.b64encode(bytes(screenshot)).decode("utf-8")
|
|
339
344
|
else:
|
|
340
|
-
screenshot_base64 = screenshot
|
|
345
|
+
screenshot_base64 = str(screenshot)
|
|
346
|
+
|
|
347
|
+
# Process screenshot through hooks
|
|
348
|
+
action_type = f"after_{action.get('type', 'action')}"
|
|
349
|
+
await self.handle_screenshot(screenshot_base64, action_type=action_type)
|
|
341
350
|
|
|
342
351
|
# Create computer_call_output
|
|
343
352
|
computer_call_output = {
|