cua-agent 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic. Click here for more details.

agent/__init__.py CHANGED
@@ -48,9 +48,7 @@ except Exception as e:
48
48
  # Other issues with telemetry
49
49
  logger.warning(f"Error initializing telemetry: {e}")
50
50
 
51
- from .core.factory import AgentFactory
52
- from .core.agent import ComputerAgent
53
51
  from .providers.omni.types import LLMProvider, LLM
54
- from .types.base import Provider, AgentLoop
52
+ from .types.base import AgentLoop
55
53
 
56
- __all__ = ["AgentFactory", "Provider", "ComputerAgent", "AgentLoop", "LLMProvider", "LLM"]
54
+ __all__ = ["AgentLoop", "LLMProvider", "LLM"]
agent/core/__init__.py CHANGED
@@ -1,6 +1,5 @@
1
1
  """Core agent components."""
2
2
 
3
- from .base_agent import BaseComputerAgent
4
3
  from .loop import BaseLoop
5
4
  from .messages import (
6
5
  create_user_message,
@@ -12,7 +11,7 @@ from .messages import (
12
11
  ImageRetentionConfig,
13
12
  )
14
13
  from .callbacks import (
15
- CallbackManager,
14
+ CallbackManager,
16
15
  CallbackHandler,
17
16
  BaseCallbackManager,
18
17
  ContentCallback,
@@ -21,9 +20,8 @@ from .callbacks import (
21
20
  )
22
21
 
23
22
  __all__ = [
24
- "BaseComputerAgent",
25
- "BaseLoop",
26
- "CallbackManager",
23
+ "BaseLoop",
24
+ "CallbackManager",
27
25
  "CallbackHandler",
28
26
  "BaseMessageManager",
29
27
  "ImageRetentionConfig",
@@ -1,69 +1,251 @@
1
1
  """Main entry point for computer agents."""
2
2
 
3
+ import asyncio
3
4
  import logging
4
- from typing import Any, AsyncGenerator, Dict, Optional
5
+ import os
6
+ from typing import Any, AsyncGenerator, Dict, Optional, cast
7
+ from dataclasses import dataclass
5
8
 
6
9
  from computer import Computer
7
- from ..types.base import Provider
8
- from .factory import AgentFactory
10
+ from ..providers.anthropic.loop import AnthropicLoop
11
+ from ..providers.omni.loop import OmniLoop
12
+ from ..providers.omni.parser import OmniParser
13
+ from ..providers.omni.types import LLMProvider, LLM
14
+ from .. import AgentLoop
9
15
 
10
16
  logging.basicConfig(level=logging.INFO)
11
17
  logger = logging.getLogger(__name__)
12
18
 
19
+ # Default models for different providers
20
+ DEFAULT_MODELS = {
21
+ LLMProvider.OPENAI: "gpt-4o",
22
+ LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
23
+ }
24
+
25
+ # Map providers to their environment variable names
26
+ ENV_VARS = {
27
+ LLMProvider.OPENAI: "OPENAI_API_KEY",
28
+ LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
29
+ }
30
+
13
31
 
14
32
  class ComputerAgent:
15
33
  """A computer agent that can perform automated tasks using natural language instructions."""
16
34
 
17
- def __init__(self, provider: Provider, computer: Optional[Computer] = None, **kwargs):
35
+ def __init__(
36
+ self,
37
+ computer: Computer,
38
+ model: LLM,
39
+ loop: AgentLoop,
40
+ max_retries: int = 3,
41
+ screenshot_dir: Optional[str] = None,
42
+ log_dir: Optional[str] = None,
43
+ api_key: Optional[str] = None,
44
+ save_trajectory: bool = True,
45
+ trajectory_dir: str = "trajectories",
46
+ only_n_most_recent_images: Optional[int] = None,
47
+ parser: Optional[OmniParser] = None,
48
+ verbosity: int = logging.INFO,
49
+ ):
18
50
  """Initialize the ComputerAgent.
19
51
 
20
52
  Args:
21
- provider: The AI provider to use (e.g., Provider.ANTHROPIC)
22
- computer: Optional Computer instance. If not provided, one will be created with default settings.
23
- **kwargs: Additional provider-specific arguments
53
+ computer: Computer instance. If not provided, one will be created with default settings.
54
+ max_retries: Maximum number of retry attempts.
55
+ screenshot_dir: Directory to save screenshots.
56
+ log_dir: Directory to save logs (set to None to disable logging to files).
57
+ model: LLM object containing provider and model name. Takes precedence over provider/model_name.
58
+ provider: The AI provider to use (e.g., LLMProvider.ANTHROPIC). Only used if model is None.
59
+ api_key: The API key for the provider. If not provided, will look for environment variable.
60
+ model_name: The model name to use. Only used if model is None.
61
+ save_trajectory: Whether to save the trajectory.
62
+ trajectory_dir: Directory to save the trajectory.
63
+ only_n_most_recent_images: Maximum number of recent screenshots to include in API requests.
64
+ parser: Parser instance for the OmniLoop. Only used if provider is not ANTHROPIC.
65
+ verbosity: Logging level.
24
66
  """
25
- self.provider = provider
26
- self._computer = computer
27
- self._kwargs = kwargs
28
- self._agent = None
67
+ # Basic agent configuration
68
+ self.max_retries = max_retries
69
+ self.computer = computer or Computer()
70
+ self.queue = asyncio.Queue()
71
+ self.screenshot_dir = screenshot_dir
72
+ self.log_dir = log_dir
73
+ self._retry_count = 0
29
74
  self._initialized = False
30
75
  self._in_context = False
31
76
 
32
- # Create provider-specific agent using factory
33
- self._agent = AgentFactory.create(provider=provider, computer=computer, **kwargs)
77
+ # Set logging level
78
+ logger.setLevel(verbosity)
79
+
80
+ # Setup logging
81
+ if self.log_dir:
82
+ os.makedirs(self.log_dir, exist_ok=True)
83
+ logger.info(f"Created logs directory: {self.log_dir}")
84
+
85
+ # Setup screenshots directory
86
+ if self.screenshot_dir:
87
+ os.makedirs(self.screenshot_dir, exist_ok=True)
88
+ logger.info(f"Created screenshots directory: {self.screenshot_dir}")
89
+
90
+ # Use the provided LLM object
91
+ self.provider = model.provider
92
+ actual_model_name = model.name or DEFAULT_MODELS.get(self.provider, "")
93
+
94
+ # Ensure we have a valid model name
95
+ if not actual_model_name:
96
+ actual_model_name = DEFAULT_MODELS.get(self.provider, "")
97
+ if not actual_model_name:
98
+ raise ValueError(
99
+ f"No model specified for provider {self.provider} and no default found"
100
+ )
101
+
102
+ # Ensure computer is properly cast for typing purposes
103
+ computer_instance = cast(Computer, self.computer)
104
+
105
+ # Get API key from environment if not provided
106
+ actual_api_key = api_key or os.environ.get(ENV_VARS[self.provider], "")
107
+ if not actual_api_key:
108
+ raise ValueError(f"No API key provided for {self.provider}")
109
+
110
+ # Initialize the appropriate loop based on the loop parameter
111
+ if loop == AgentLoop.ANTHROPIC:
112
+ self._loop = AnthropicLoop(
113
+ api_key=actual_api_key,
114
+ model=actual_model_name,
115
+ computer=computer_instance,
116
+ save_trajectory=save_trajectory,
117
+ base_dir=trajectory_dir,
118
+ only_n_most_recent_images=only_n_most_recent_images,
119
+ )
120
+ else:
121
+ # Default to OmniLoop for other loop types
122
+ # Initialize parser if not provided
123
+ actual_parser = parser or OmniParser()
124
+
125
+ self._loop = OmniLoop(
126
+ provider=self.provider,
127
+ api_key=actual_api_key,
128
+ model=actual_model_name,
129
+ computer=computer_instance,
130
+ save_trajectory=save_trajectory,
131
+ base_dir=trajectory_dir,
132
+ only_n_most_recent_images=only_n_most_recent_images,
133
+ parser=actual_parser,
134
+ )
135
+
136
+ logger.info(
137
+ f"ComputerAgent initialized with provider: {self.provider}, model: {actual_model_name}"
138
+ )
34
139
 
35
140
  async def __aenter__(self):
36
- """Enter the async context manager."""
141
+ """Initialize the agent when used as a context manager."""
142
+ logger.info("Entering ComputerAgent context")
37
143
  self._in_context = True
144
+
145
+ # In case the computer wasn't initialized
146
+ try:
147
+ # Initialize the computer only if not already initialized
148
+ logger.info("Checking if computer is already initialized...")
149
+ if not self.computer._initialized:
150
+ logger.info("Initializing computer in __aenter__...")
151
+ # Use the computer's __aenter__ directly instead of calling run()
152
+ await self.computer.__aenter__()
153
+ logger.info("Computer initialized in __aenter__")
154
+ else:
155
+ logger.info("Computer already initialized, skipping initialization")
156
+
157
+ # Take a test screenshot to verify the computer is working
158
+ logger.info("Testing computer with a screenshot...")
159
+ try:
160
+ test_screenshot = await self.computer.interface.screenshot()
161
+ # Determine the screenshot size based on its type
162
+ if isinstance(test_screenshot, (bytes, bytearray, memoryview)):
163
+ size = len(test_screenshot)
164
+ elif hasattr(test_screenshot, "base64_image"):
165
+ size = len(test_screenshot.base64_image)
166
+ else:
167
+ size = "unknown"
168
+ logger.info(f"Screenshot test successful, size: {size}")
169
+ except Exception as e:
170
+ logger.error(f"Screenshot test failed: {str(e)}")
171
+ # Even though screenshot failed, we continue since some tests might not need it
172
+ except Exception as e:
173
+ logger.error(f"Error initializing computer in __aenter__: {str(e)}")
174
+ raise
175
+
38
176
  await self.initialize()
39
177
  return self
40
178
 
41
179
  async def __aexit__(self, exc_type, exc_val, exc_tb):
42
- """Exit the async context manager."""
180
+ """Cleanup agent resources if needed."""
181
+ logger.info("Cleaning up agent resources")
43
182
  self._in_context = False
44
183
 
184
+ # Do any necessary cleanup
185
+ # We're not shutting down the computer here as it might be shared
186
+ # Just log that we're exiting
187
+ if exc_type:
188
+ logger.error(f"Exiting agent context with error: {exc_type.__name__}: {exc_val}")
189
+ else:
190
+ logger.info("Exiting agent context normally")
191
+
192
+ # If we have a queue, make sure to signal it's done
193
+ if hasattr(self, "queue") and self.queue:
194
+ await self.queue.put(None) # Signal that we're done
195
+
45
196
  async def initialize(self) -> None:
46
197
  """Initialize the agent and its components."""
47
198
  if not self._initialized:
48
- if not self._in_context and self._computer:
49
- # If not in context manager but have a computer, initialize it
50
- await self._computer.run()
199
+ # Always initialize the computer if available
200
+ if self.computer and not self.computer._initialized:
201
+ await self.computer.run()
51
202
  self._initialized = True
52
203
 
204
+ async def _init_if_needed(self):
205
+ """Initialize the computer interface if it hasn't been initialized yet."""
206
+ if not self.computer._initialized:
207
+ logger.info("Computer not initialized, initializing now...")
208
+ try:
209
+ # Call run directly
210
+ await self.computer.run()
211
+ logger.info("Computer interface initialized successfully")
212
+ except Exception as e:
213
+ logger.error(f"Error initializing computer interface: {str(e)}")
214
+ raise
215
+
53
216
  async def run(self, task: str) -> AsyncGenerator[Dict[str, Any], None]:
54
- """Run the agent with a given task."""
55
- if not self._initialized:
56
- await self.initialize()
217
+ """Run a task using the computer agent.
218
+
219
+ Args:
220
+ task: Task description
221
+
222
+ Yields:
223
+ Task execution updates
224
+ """
225
+ try:
226
+ logger.info(f"Running task: {task}")
227
+
228
+ # Initialize the computer if needed
229
+ if not self._initialized:
230
+ await self.initialize()
231
+
232
+ # Format task as a message
233
+ messages = [{"role": "user", "content": task}]
57
234
 
58
- if self._agent is None:
59
- logger.error("Agent not initialized properly")
60
- yield {"error": "Agent not initialized properly"}
61
- return
235
+ # Pass properly formatted messages to the loop
236
+ if self._loop is None:
237
+ logger.error("Loop not initialized properly")
238
+ yield {"error": "Loop not initialized properly"}
239
+ return
62
240
 
63
- async for result in self._agent.run(task):
64
- yield result
241
+ # Execute the task and yield results
242
+ async for result in self._loop.run(messages):
243
+ yield result
65
244
 
66
- @property
67
- def computer(self) -> Optional[Computer]:
68
- """Get the underlying computer instance."""
69
- return self._agent.computer if self._agent else None
245
+ except Exception as e:
246
+ logger.error(f"Error in agent run method: {str(e)}")
247
+ yield {
248
+ "role": "assistant",
249
+ "content": f"Error: {str(e)}",
250
+ "metadata": {"title": "❌ Error"},
251
+ }
agent/core/experiment.py CHANGED
@@ -84,7 +84,21 @@ class ExperimentManager:
84
84
  if isinstance(data, dict):
85
85
  result = {}
86
86
  for k, v in data.items():
87
- result[k] = self.sanitize_log_data(v)
87
+ # Special handling for 'data' field in Anthropic message source
88
+ if k == "data" and isinstance(v, str) and len(v) > 1000:
89
+ result[k] = f"[BASE64_DATA_LENGTH_{len(v)}]"
90
+ # Special handling for the 'media_type' key which indicates we're in an image block
91
+ elif k == "media_type" and "image" in str(v):
92
+ result[k] = v
93
+ # If we're in an image block, look for a sibling 'data' field with base64 content
94
+ if (
95
+ "data" in result
96
+ and isinstance(result["data"], str)
97
+ and len(result["data"]) > 1000
98
+ ):
99
+ result["data"] = f"[BASE64_DATA_LENGTH_{len(result['data'])}]"
100
+ else:
101
+ result[k] = self.sanitize_log_data(v)
88
102
  return result
89
103
  elif isinstance(data, list):
90
104
  return [self.sanitize_log_data(item) for item in data]
@@ -93,15 +107,18 @@ class ExperimentManager:
93
107
  else:
94
108
  return data
95
109
 
96
- def save_screenshot(self, img_base64: str, action_type: str = "") -> None:
110
+ def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
97
111
  """Save a screenshot to the experiment directory.
98
112
 
99
113
  Args:
100
114
  img_base64: Base64 encoded screenshot
101
115
  action_type: Type of action that triggered the screenshot
116
+
117
+ Returns:
118
+ Path to the saved screenshot or None if there was an error
102
119
  """
103
120
  if not self.current_turn_dir:
104
- return
121
+ return None
105
122
 
106
123
  try:
107
124
  # Increment screenshot counter
agent/core/loop.py CHANGED
@@ -141,9 +141,6 @@ class BaseLoop(ABC):
141
141
  # Initialize API client
142
142
  await self.initialize_client()
143
143
 
144
- # Initialize computer
145
- await self.computer.initialize()
146
-
147
144
  logger.info("Initialization complete.")
148
145
  return
149
146
  except Exception as e:
@@ -173,15 +170,22 @@ class BaseLoop(ABC):
173
170
  base64_image = ""
174
171
 
175
172
  # Handle different types of screenshot returns
176
- if isinstance(screenshot, bytes):
173
+ if isinstance(screenshot, (bytes, bytearray, memoryview)):
177
174
  # Raw bytes screenshot
178
175
  base64_image = base64.b64encode(screenshot).decode("utf-8")
179
176
  elif hasattr(screenshot, "base64_image"):
180
177
  # Object-style screenshot with attributes
181
- base64_image = screenshot.base64_image
182
- if hasattr(screenshot, "width") and hasattr(screenshot, "height"):
183
- width = screenshot.width
184
- height = screenshot.height
178
+ # Type checking can't infer these attributes, but they exist at runtime
179
+ # on certain screenshot return types
180
+ base64_image = getattr(screenshot, "base64_image")
181
+ width = (
182
+ getattr(screenshot, "width", width) if hasattr(screenshot, "width") else width
183
+ )
184
+ height = (
185
+ getattr(screenshot, "height", height)
186
+ if hasattr(screenshot, "height")
187
+ else height
188
+ )
185
189
 
186
190
  # Create parsed screen data
187
191
  parsed_screen = {
agent/core/telemetry.py CHANGED
@@ -4,58 +4,70 @@ import logging
4
4
  import os
5
5
  import platform
6
6
  import sys
7
- from typing import Dict, Any
7
+ from typing import Dict, Any, Callable
8
8
 
9
9
  # Import the core telemetry module
10
10
  TELEMETRY_AVAILABLE = False
11
11
 
12
+
13
+ # Local fallbacks in case core telemetry isn't available
14
+ def _noop(*args: Any, **kwargs: Any) -> None:
15
+ """No-op function for when telemetry is not available."""
16
+ pass
17
+
18
+
19
+ # Define default functions with unique names to avoid shadowing
20
+ _default_record_event = _noop
21
+ _default_increment_counter = _noop
22
+ _default_set_dimension = _noop
23
+ _default_get_telemetry_client = lambda: None
24
+ _default_flush = _noop
25
+ _default_is_telemetry_enabled = lambda: False
26
+ _default_is_telemetry_globally_disabled = lambda: True
27
+
28
+ # Set the actual functions to the defaults initially
29
+ record_event = _default_record_event
30
+ increment_counter = _default_increment_counter
31
+ set_dimension = _default_set_dimension
32
+ get_telemetry_client = _default_get_telemetry_client
33
+ flush = _default_flush
34
+ is_telemetry_enabled = _default_is_telemetry_enabled
35
+ is_telemetry_globally_disabled = _default_is_telemetry_globally_disabled
36
+
37
+ logger = logging.getLogger("cua.agent.telemetry")
38
+
12
39
  try:
40
+ # Import from core telemetry
13
41
  from core.telemetry import (
14
- record_event,
15
- increment,
16
- get_telemetry_client,
17
- flush,
18
- is_telemetry_enabled,
19
- is_telemetry_globally_disabled,
42
+ record_event as core_record_event,
43
+ increment as core_increment,
44
+ get_telemetry_client as core_get_telemetry_client,
45
+ flush as core_flush,
46
+ is_telemetry_enabled as core_is_telemetry_enabled,
47
+ is_telemetry_globally_disabled as core_is_telemetry_globally_disabled,
20
48
  )
21
49
 
50
+ # Override the default functions with actual implementations
51
+ record_event = core_record_event
52
+ get_telemetry_client = core_get_telemetry_client
53
+ flush = core_flush
54
+ is_telemetry_enabled = core_is_telemetry_enabled
55
+ is_telemetry_globally_disabled = core_is_telemetry_globally_disabled
56
+
22
57
  def increment_counter(counter_name: str, value: int = 1) -> None:
23
58
  """Wrapper for increment to maintain backward compatibility."""
24
59
  if is_telemetry_enabled():
25
- increment(counter_name, value)
60
+ core_increment(counter_name, value)
26
61
 
27
62
  def set_dimension(name: str, value: Any) -> None:
28
63
  """Set a dimension that will be attached to all events."""
29
- logger = logging.getLogger("cua.agent.telemetry")
30
64
  logger.debug(f"Setting dimension {name}={value}")
31
65
 
32
66
  TELEMETRY_AVAILABLE = True
33
- logger = logging.getLogger("cua.agent.telemetry")
34
67
  logger.info("Successfully imported telemetry")
35
68
  except ImportError as e:
36
- logger = logging.getLogger("cua.agent.telemetry")
37
69
  logger.warning(f"Could not import telemetry: {e}")
38
- TELEMETRY_AVAILABLE = False
39
-
40
-
41
- # Local fallbacks in case core telemetry isn't available
42
- def _noop(*args: Any, **kwargs: Any) -> None:
43
- """No-op function for when telemetry is not available."""
44
- pass
45
-
46
-
47
- logger = logging.getLogger("cua.agent.telemetry")
48
-
49
- # If telemetry isn't available, use no-op functions
50
- if not TELEMETRY_AVAILABLE:
51
70
  logger.debug("Telemetry not available, using no-op functions")
52
- record_event = _noop # type: ignore
53
- increment_counter = _noop # type: ignore
54
- set_dimension = _noop # type: ignore
55
- get_telemetry_client = lambda: None # type: ignore
56
- flush = _noop # type: ignore
57
- is_telemetry_enabled = lambda: False # type: ignore
58
- is_telemetry_globally_disabled = lambda: True # type: ignore
59
71
 
60
72
  # Get system info once to use in telemetry
61
73
  SYSTEM_INFO = {
@@ -71,7 +83,7 @@ def enable_telemetry() -> bool:
71
83
  Returns:
72
84
  bool: True if telemetry was successfully enabled, False otherwise
73
85
  """
74
- global TELEMETRY_AVAILABLE
86
+ global TELEMETRY_AVAILABLE, record_event, increment_counter, get_telemetry_client, flush, is_telemetry_enabled, is_telemetry_globally_disabled
75
87
 
76
88
  # Check if globally disabled using core function
77
89
  if TELEMETRY_AVAILABLE and is_telemetry_globally_disabled():