cua-agent 0.1.25__py3-none-any.whl → 0.1.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic. Click here for more details.

agent/core/base.py CHANGED
@@ -5,10 +5,12 @@ import asyncio
5
5
  from abc import ABC, abstractmethod
6
6
  from typing import Any, AsyncGenerator, Dict, List, Optional
7
7
 
8
+ from agent.providers.omni.parser import ParseResult
8
9
  from computer import Computer
9
10
  from .messages import StandardMessageManager, ImageRetentionConfig
10
11
  from .types import AgentResponse
11
12
  from .experiment import ExperimentManager
13
+ from .callbacks import CallbackManager, CallbackHandler
12
14
 
13
15
  logger = logging.getLogger(__name__)
14
16
 
@@ -27,6 +29,7 @@ class BaseLoop(ABC):
27
29
  base_dir: Optional[str] = "trajectories",
28
30
  save_trajectory: bool = True,
29
31
  only_n_most_recent_images: Optional[int] = 2,
32
+ callback_handlers: Optional[List[CallbackHandler]] = None,
30
33
  **kwargs,
31
34
  ):
32
35
  """Initialize base agent loop.
@@ -75,6 +78,9 @@ class BaseLoop(ABC):
75
78
 
76
79
  # Initialize basic tracking
77
80
  self.turn_count = 0
81
+
82
+ # Initialize callback manager
83
+ self.callback_manager = CallbackManager(handlers=callback_handlers or [])
78
84
 
79
85
  async def initialize(self) -> None:
80
86
  """Initialize both the API client and computer interface with retries."""
@@ -187,3 +193,17 @@ class BaseLoop(ABC):
187
193
  """
188
194
  if self.experiment_manager:
189
195
  self.experiment_manager.save_screenshot(img_base64, action_type)
196
+
197
+ ###########################################
198
+ # EVENT HOOKS / CALLBACKS
199
+ ###########################################
200
+
201
+ async def handle_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
202
+ """Forward a screenshot to the callback manager, if one has been set up.
203
+
204
+ Args:
205
+ screenshot_base64: Base64 encoded screenshot
206
+ action_type: Type of action that triggered the screenshot
207
+ """
208
+ if hasattr(self, 'callback_manager'):
209
+ await self.callback_manager.on_screenshot(screenshot_base64, action_type, parsed_screen)
agent/core/callbacks.py CHANGED
@@ -6,6 +6,8 @@ from abc import ABC, abstractmethod
6
6
  from datetime import datetime
7
7
  from typing import Any, Dict, List, Optional, Protocol
8
8
 
9
+ from agent.providers.omni.parser import ParseResult
10
+
9
11
  logger = logging.getLogger(__name__)
10
12
 
11
13
  class ContentCallback(Protocol):
@@ -20,6 +22,10 @@ class APICallback(Protocol):
20
22
  """Protocol for API callbacks."""
21
23
  def __call__(self, request: Any, response: Any, error: Optional[Exception] = None) -> None: ...
22
24
 
25
+ class ScreenshotCallback(Protocol):
26
+ """Protocol for screenshot callbacks."""
27
+ def __call__(self, screenshot_base64: str, action_type: str = "") -> Optional[str]: ...
28
+
23
29
  class BaseCallbackManager(ABC):
24
30
  """Base class for callback managers."""
25
31
 
@@ -110,7 +116,20 @@ class CallbackManager:
110
116
  """
111
117
  for handler in self.handlers:
112
118
  await handler.on_error(error, **kwargs)
113
-
119
+
120
+ async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
121
+ """Called when a screenshot is taken.
122
+
123
+ Args:
124
+ screenshot_base64: Base64 encoded screenshot
125
+ action_type: Type of action that triggered the screenshot
126
+ parsed_screen: Optional output from parsing the screenshot
127
+
128
+ Returns:
129
+ None. Each registered handler is awaited in turn; nothing is returned to the caller.
130
+ """
131
+ for handler in self.handlers:
132
+ await handler.on_screenshot(screenshot_base64, action_type, parsed_screen)
114
133
 
115
134
  class CallbackHandler(ABC):
116
135
  """Base class for callback handlers."""
@@ -144,4 +163,40 @@ class CallbackHandler(ABC):
144
163
  error: Exception that occurred
145
164
  **kwargs: Additional data
146
165
  """
147
- pass
166
+ pass
167
+
168
+ @abstractmethod
169
+ async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
170
+ """Called when a screenshot is taken.
171
+
172
+ Args:
173
+ screenshot_base64: Base64 encoded screenshot
174
+ action_type: Type of action that triggered the screenshot
175
+
176
+ Returns:
177
+ None
178
+ """
179
+ pass
180
+
181
+ class DefaultCallbackHandler(CallbackHandler):
182
+ """Default implementation of CallbackHandler with no-op methods.
183
+
184
+ This class implements all abstract methods from CallbackHandler,
185
+ allowing subclasses to override only the methods they need.
186
+ """
187
+
188
+ async def on_action_start(self, action: str, **kwargs) -> None:
189
+ """Default no-op implementation."""
190
+ pass
191
+
192
+ async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
193
+ """Default no-op implementation."""
194
+ pass
195
+
196
+ async def on_error(self, error: Exception, **kwargs) -> None:
197
+ """Default no-op implementation."""
198
+ pass
199
+
200
+ async def on_screenshot(self, screenshot_base64: str, action_type: str = "") -> None:
201
+ """Default no-op implementation."""
202
+ pass
@@ -3,23 +3,33 @@ import httpx
3
3
  from anthropic.types.beta import BetaContentBlockParam
4
4
  from ..tools import ToolResult
5
5
 
6
+
6
7
  class APICallback(Protocol):
7
8
  """Protocol for API callbacks."""
8
- def __call__(self, request: httpx.Request | None,
9
- response: httpx.Response | object | None,
10
- error: Exception | None) -> None: ...
9
+
10
+ def __call__(
11
+ self,
12
+ request: httpx.Request | None,
13
+ response: httpx.Response | object | None,
14
+ error: Exception | None,
15
+ ) -> None: ...
16
+
11
17
 
12
18
  class ContentCallback(Protocol):
13
19
  """Protocol for content callbacks."""
20
+
14
21
  def __call__(self, content: BetaContentBlockParam) -> None: ...
15
22
 
23
+
16
24
  class ToolCallback(Protocol):
17
25
  """Protocol for tool callbacks."""
26
+
18
27
  def __call__(self, result: ToolResult, tool_id: str) -> None: ...
19
28
 
29
+
20
30
  class CallbackManager:
21
31
  """Manages various callbacks for the agent system."""
22
-
32
+
23
33
  def __init__(
24
34
  self,
25
35
  content_callback: ContentCallback,
@@ -27,7 +37,7 @@ class CallbackManager:
27
37
  api_callback: APICallback,
28
38
  ):
29
39
  """Initialize the callback manager.
30
-
40
+
31
41
  Args:
32
42
  content_callback: Callback for content updates
33
43
  tool_callback: Callback for tool execution results
@@ -36,20 +46,20 @@ class CallbackManager:
36
46
  self.content_callback = content_callback
37
47
  self.tool_callback = tool_callback
38
48
  self.api_callback = api_callback
39
-
49
+
40
50
  def on_content(self, content: BetaContentBlockParam) -> None:
41
51
  """Handle content updates."""
42
52
  self.content_callback(content)
43
-
53
+
44
54
  def on_tool_result(self, result: ToolResult, tool_id: str) -> None:
45
55
  """Handle tool execution results."""
46
56
  self.tool_callback(result, tool_id)
47
-
57
+
48
58
  def on_api_interaction(
49
59
  self,
50
60
  request: httpx.Request | None,
51
61
  response: httpx.Response | object | None,
52
- error: Exception | None
62
+ error: Exception | None,
53
63
  ) -> None:
54
64
  """Handle API interactions."""
55
- self.api_callback(request, response, error)
65
+ self.api_callback(request, response, error)
@@ -45,8 +45,8 @@ class OAICompatClient(BaseOmniClient):
45
45
  max_tokens: Maximum tokens to generate
46
46
  temperature: Generation temperature
47
47
  """
48
- super().__init__(api_key="EMPTY", model=model)
49
- self.api_key = "EMPTY" # Local endpoints typically don't require an API key
48
+ super().__init__(api_key=api_key or "EMPTY", model=model)
49
+ self.api_key = api_key or "EMPTY" # Local endpoints typically don't require an API key
50
50
  self.model = model
51
51
  self.provider_base_url = (
52
52
  provider_base_url or "http://localhost:8000/v1"
@@ -146,10 +146,18 @@ class OAICompatClient(BaseOmniClient):
146
146
  base_url = self.provider_base_url or "http://localhost:8000/v1"
147
147
 
148
148
  # Check if the base URL already includes the chat/completions endpoint
149
+
149
150
  endpoint_url = base_url
150
151
  if not endpoint_url.endswith("/chat/completions"):
152
+ # If URL is RunPod format, make it OpenAI compatible
153
+ if endpoint_url.startswith("https://api.runpod.ai/v2/"):
154
+ # Extract RunPod endpoint ID
155
+ parts = endpoint_url.split("/")
156
+ if len(parts) >= 5:
157
+ runpod_id = parts[4]
158
+ endpoint_url = f"https://api.runpod.ai/v2/{runpod_id}/openai/v1/chat/completions"
151
159
  # If the URL ends with /v1, append /chat/completions
152
- if endpoint_url.endswith("/v1"):
160
+ elif endpoint_url.endswith("/v1"):
153
161
  endpoint_url = f"{endpoint_url}/chat/completions"
154
162
  # If the URL doesn't end with /v1, make sure it has a proper structure
155
163
  elif not endpoint_url.endswith("/"):
@@ -147,7 +147,7 @@ class OmniLoop(BaseLoop):
147
147
  )
148
148
  elif self.provider == LLMProvider.OAICOMPAT:
149
149
  self.client = OAICompatClient(
150
- api_key="EMPTY", # Local endpoints typically don't require an API key
150
+ api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
151
151
  model=self.model,
152
152
  provider_base_url=self.provider_base_url,
153
153
  )
@@ -183,7 +183,7 @@ class OmniLoop(BaseLoop):
183
183
  )
184
184
  elif self.provider == LLMProvider.OAICOMPAT:
185
185
  self.client = OAICompatClient(
186
- api_key="EMPTY", # Local endpoints typically don't require an API key
186
+ api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
187
187
  model=self.model,
188
188
  provider_base_url=self.provider_base_url,
189
189
  )
@@ -548,6 +548,10 @@ class OmniLoop(BaseLoop):
548
548
  img_data = parsed_screen.annotated_image_base64
549
549
  if "," in img_data:
550
550
  img_data = img_data.split(",")[1]
551
+
552
+ # Pass the annotated screenshot and its parse result to the screenshot callbacks
553
+ await self.handle_screenshot(img_data, action_type="state", parsed_screen=parsed_screen)
554
+
551
555
  # Save with a generic "state" action type to indicate this is the current screen state
552
556
  self._save_screenshot(img_data, action_type="state")
553
557
  except Exception as e:
@@ -663,6 +667,8 @@ class OmniLoop(BaseLoop):
663
667
  response=response,
664
668
  messages=self.message_manager.messages,
665
669
  model=self.model,
670
+ parsed_screen=parsed_screen,
671
+ parser=self.parser
666
672
  )
667
673
 
668
674
  # Yield the response to the caller
@@ -194,8 +194,13 @@ class OpenAILoop(BaseLoop):
194
194
  # Convert to base64 if needed
195
195
  if isinstance(screenshot, bytes):
196
196
  screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
197
+ elif isinstance(screenshot, (bytearray, memoryview)):
198
+ screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
197
199
  else:
198
- screenshot_base64 = screenshot
200
+ screenshot_base64 = str(screenshot)
201
+
202
+ # Emit screenshot callbacks
203
+ await self.handle_screenshot(screenshot_base64, action_type="initial_state")
199
204
 
200
205
  # Save screenshot if requested
201
206
  if self.save_trajectory:
@@ -204,8 +209,6 @@ class OpenAILoop(BaseLoop):
204
209
  logger.warning(
205
210
  "Converting non-string screenshot_base64 to string for _save_screenshot"
206
211
  )
207
- if isinstance(screenshot_base64, (bytearray, memoryview)):
208
- screenshot_base64 = base64.b64encode(screenshot_base64).decode("utf-8")
209
212
  self._save_screenshot(screenshot_base64, action_type="state")
210
213
  logger.info("Screenshot saved to trajectory")
211
214
 
@@ -336,8 +339,14 @@ class OpenAILoop(BaseLoop):
336
339
  screenshot = await self.computer.interface.screenshot()
337
340
  if isinstance(screenshot, bytes):
338
341
  screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
342
+ elif isinstance(screenshot, (bytearray, memoryview)):
343
+ screenshot_base64 = base64.b64encode(bytes(screenshot)).decode("utf-8")
339
344
  else:
340
- screenshot_base64 = screenshot
345
+ screenshot_base64 = str(screenshot)
346
+
347
+ # Process screenshot through hooks
348
+ action_type = f"after_{action.get('type', 'action')}"
349
+ await self.handle_screenshot(screenshot_base64, action_type=action_type)
341
350
 
342
351
  # Create computer_call_output
343
352
  computer_call_output = {