cua-agent 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic; see the package's registry page for more details.

agent/__init__.py CHANGED
@@ -4,7 +4,7 @@ __version__ = "0.1.0"
4
4
 
5
5
  from .core.factory import AgentFactory
6
6
  from .core.agent import ComputerAgent
7
- from .types.base import Provider, AgenticLoop
8
- from .providers.omni.types import APIProvider
7
+ from .providers.omni.types import LLMProvider, LLM
8
+ from .types.base import Provider, AgentLoop
9
9
 
10
- __all__ = ["AgentFactory", "Provider", "ComputerAgent", "AgenticLoop", "APIProvider"]
10
+ __all__ = ["AgentFactory", "Provider", "ComputerAgent", "AgentLoop", "LLMProvider", "LLM"]
agent/core/README.md CHANGED
@@ -34,7 +34,7 @@ Here's how to use the unified ComputerAgent:
34
34
  ```python
35
35
  from agent.core.agent import ComputerAgent
36
36
  from agent.types.base import AgenticLoop
37
- from agent.providers.omni.types import APIProvider
37
+ from agent.providers.omni.types import LLMProvider
38
38
  from computer import Computer
39
39
 
40
40
  # Create a Computer instance
@@ -44,7 +44,7 @@ computer = Computer()
44
44
  agent = ComputerAgent(
45
45
  computer=computer,
46
46
  loop_type=AgenticLoop.OMNI,
47
- provider=APIProvider.OPENAI,
47
+ provider=LLMProvider.OPENAI,
48
48
  model="gpt-4o",
49
49
  api_key="your_api_key_here", # Can also use OPENAI_API_KEY environment variable
50
50
  save_trajectory=True,
agent/core/agent.py CHANGED
@@ -3,12 +3,12 @@
3
3
  import os
4
4
  import logging
5
5
  import asyncio
6
- from typing import Any, AsyncGenerator, Dict, List, Optional, TYPE_CHECKING
6
+ from typing import Any, AsyncGenerator, Dict, List, Optional, TYPE_CHECKING, Union, cast
7
7
  from datetime import datetime
8
8
 
9
9
  from computer import Computer
10
10
 
11
- from ..types.base import Provider, AgenticLoop
11
+ from ..types.base import Provider, AgentLoop
12
12
  from .base_agent import BaseComputerAgent
13
13
 
14
14
  # Only import types for type checking to avoid circular imports
@@ -17,23 +17,23 @@ if TYPE_CHECKING:
17
17
  from ..providers.omni.loop import OmniLoop
18
18
  from ..providers.omni.parser import OmniParser
19
19
 
20
- # Import the APIProvider enum without importing the whole module
21
- from ..providers.omni.types import APIProvider
20
+ # Import the provider types
21
+ from ..providers.omni.types import LLMProvider, LLM, Model, LLMModel
22
22
 
23
23
  logger = logging.getLogger(__name__)
24
24
 
25
25
  # Default models for different providers
26
26
  DEFAULT_MODELS = {
27
- APIProvider.OPENAI: "gpt-4o",
28
- APIProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
29
- APIProvider.GROQ: "llama3-70b-8192",
27
+ LLMProvider.OPENAI: "gpt-4o",
28
+ LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
29
+ LLMProvider.GROQ: "llama3-70b-8192",
30
30
  }
31
31
 
32
32
  # Map providers to their environment variable names
33
33
  ENV_VARS = {
34
- APIProvider.OPENAI: "OPENAI_API_KEY",
35
- APIProvider.GROQ: "GROQ_API_KEY",
36
- APIProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
34
+ LLMProvider.OPENAI: "OPENAI_API_KEY",
35
+ LLMProvider.GROQ: "GROQ_API_KEY",
36
+ LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
37
37
  }
38
38
 
39
39
 
@@ -47,10 +47,9 @@ class ComputerAgent(BaseComputerAgent):
47
47
  def __init__(
48
48
  self,
49
49
  computer: Computer,
50
- loop_type: AgenticLoop = AgenticLoop.OMNI,
51
- ai_provider: APIProvider = APIProvider.OPENAI,
50
+ loop: AgentLoop = AgentLoop.OMNI,
51
+ model: Optional[Union[LLM, Dict[str, str], str]] = None,
52
52
  api_key: Optional[str] = None,
53
- model: Optional[str] = None,
54
53
  save_trajectory: bool = True,
55
54
  trajectory_dir: Optional[str] = "trajectories",
56
55
  only_n_most_recent_images: Optional[int] = None,
@@ -62,10 +61,13 @@ class ComputerAgent(BaseComputerAgent):
62
61
 
63
62
  Args:
64
63
  computer: Computer instance to control
65
- loop_type: The type of loop to use (Anthropic or Omni)
66
- ai_provider: AI provider to use (required for Cua loop)
64
+ loop: The type of loop to use (Anthropic or Omni)
65
+ model: LLM configuration. Can be:
66
+ - LLM object with provider and name
67
+ - Dict with 'provider' and 'name' keys
68
+ - String with model name (defaults to OpenAI provider)
69
+ - None (defaults based on loop)
67
70
  api_key: Optional API key (will use environment variable if not provided)
68
- model: Optional model name (will use provider default if not specified)
69
71
  save_trajectory: Whether to save screenshots and logs
70
72
  trajectory_dir: Directory to save trajectories (defaults to "trajectories")
71
73
  only_n_most_recent_images: Limit history to N most recent images
@@ -87,8 +89,7 @@ class ComputerAgent(BaseComputerAgent):
87
89
  **kwargs,
88
90
  )
89
91
 
90
- self.loop_type = loop_type
91
- self.provider = ai_provider
92
+ self.loop_type = loop
92
93
  self.save_trajectory = save_trajectory
93
94
  self.trajectory_dir = trajectory_dir
94
95
  self.only_n_most_recent_images = only_n_most_recent_images
@@ -98,14 +99,19 @@ class ComputerAgent(BaseComputerAgent):
98
99
  # Configure logging based on verbosity
99
100
  self._configure_logging(verbosity)
100
101
 
102
+ # Process model configuration
103
+ self.model_config = self._process_model_config(model, loop)
104
+
101
105
  # Get API key from environment if not provided
102
106
  if api_key is None:
103
107
  env_var = (
104
- ENV_VARS.get(ai_provider) if loop_type == AgenticLoop.OMNI else "ANTHROPIC_API_KEY"
108
+ ENV_VARS.get(self.model_config.provider)
109
+ if loop == AgentLoop.OMNI
110
+ else "ANTHROPIC_API_KEY"
105
111
  )
106
112
  if not env_var:
107
113
  raise ValueError(
108
- f"Unsupported provider: {ai_provider}. Please use one of: {list(ENV_VARS.keys())}"
114
+ f"Unsupported provider: {self.model_config.provider}. Please use one of: {list(ENV_VARS.keys())}"
109
115
  )
110
116
 
111
117
  api_key = os.environ.get(env_var)
@@ -119,18 +125,49 @@ class ComputerAgent(BaseComputerAgent):
119
125
  )
120
126
  self.api_key = api_key
121
127
 
122
- # Set model based on provider if not specified
123
- if model is None:
124
- if loop_type == AgenticLoop.OMNI:
125
- self.model = DEFAULT_MODELS[ai_provider]
126
- else: # Anthropic loop
127
- self.model = DEFAULT_MODELS[APIProvider.ANTHROPIC]
128
- else:
129
- self.model = model
130
-
131
128
  # Initialize the appropriate loop based on loop_type
132
129
  self.loop = self._init_loop()
133
130
 
131
+ def _process_model_config(
132
+ self, model_input: Optional[Union[LLM, Dict[str, str], str]], loop: AgentLoop
133
+ ) -> LLM:
134
+ """Process and normalize model configuration.
135
+
136
+ Args:
137
+ model_input: Input model configuration (LLM, dict, string, or None)
138
+ loop: The loop type being used
139
+
140
+ Returns:
141
+ Normalized LLM instance
142
+ """
143
+ # Handle case where model_input is None
144
+ if model_input is None:
145
+ # Use Anthropic for Anthropic loop, OpenAI for Omni loop
146
+ default_provider = (
147
+ LLMProvider.ANTHROPIC if loop == AgentLoop.ANTHROPIC else LLMProvider.OPENAI
148
+ )
149
+ return LLM(provider=default_provider)
150
+
151
+ # Handle case where model_input is already a LLM or one of its aliases
152
+ if isinstance(model_input, (LLM, Model, LLMModel)):
153
+ return model_input
154
+
155
+ # Handle case where model_input is a dict
156
+ if isinstance(model_input, dict):
157
+ provider = model_input.get("provider", LLMProvider.OPENAI)
158
+ if isinstance(provider, str):
159
+ provider = LLMProvider(provider)
160
+ return LLM(provider=provider, name=model_input.get("name"))
161
+
162
+ # Handle case where model_input is a string (model name)
163
+ if isinstance(model_input, str):
164
+ default_provider = (
165
+ LLMProvider.ANTHROPIC if loop == AgentLoop.ANTHROPIC else LLMProvider.OPENAI
166
+ )
167
+ return LLM(provider=default_provider, name=model_input)
168
+
169
+ raise ValueError(f"Unsupported model configuration: {model_input}")
170
+
134
171
  def _configure_logging(self, verbosity: int):
135
172
  """Configure logging based on verbosity level."""
136
173
  # Use the logging level directly without mapping
@@ -159,12 +196,15 @@ class ComputerAgent(BaseComputerAgent):
159
196
  from ..providers.omni.loop import OmniLoop
160
197
  from ..providers.omni.parser import OmniParser
161
198
 
162
- if self.loop_type == AgenticLoop.ANTHROPIC:
199
+ if self.loop_type == AgentLoop.ANTHROPIC:
163
200
  from ..providers.anthropic.loop import AnthropicLoop
164
201
 
202
+ # Ensure we always have a valid model name
203
+ model_name = self.model_config.name or DEFAULT_MODELS[LLMProvider.ANTHROPIC]
204
+
165
205
  return AnthropicLoop(
166
206
  api_key=self.api_key,
167
- model=self.model,
207
+ model=model_name,
168
208
  computer=self.computer,
169
209
  save_trajectory=self.save_trajectory,
170
210
  base_dir=self.trajectory_dir,
@@ -176,10 +216,13 @@ class ComputerAgent(BaseComputerAgent):
176
216
  if "parser" not in self._kwargs:
177
217
  self._kwargs["parser"] = OmniParser()
178
218
 
219
+ # Ensure we always have a valid model name
220
+ model_name = self.model_config.name or DEFAULT_MODELS[self.model_config.provider]
221
+
179
222
  return OmniLoop(
180
- provider=self.provider,
223
+ provider=self.model_config.provider,
181
224
  api_key=self.api_key,
182
- model=self.model,
225
+ model=model_name,
183
226
  computer=self.computer,
184
227
  save_trajectory=self.save_trajectory,
185
228
  base_dir=self.trajectory_dir,
@@ -198,7 +241,7 @@ class ComputerAgent(BaseComputerAgent):
198
241
  """
199
242
  try:
200
243
  # Format the messages based on loop type
201
- if self.loop_type == AgenticLoop.ANTHROPIC:
244
+ if self.loop_type == AgentLoop.ANTHROPIC:
202
245
  # Anthropic format
203
246
  messages = [{"role": "user", "content": [{"type": "text", "text": task}]}]
204
247
  else:
@@ -221,7 +264,7 @@ class ComputerAgent(BaseComputerAgent):
221
264
  continue
222
265
 
223
266
  # Extract content and metadata based on loop type
224
- if self.loop_type == AgenticLoop.ANTHROPIC:
267
+ if self.loop_type == AgentLoop.ANTHROPIC:
225
268
  # Handle Anthropic format
226
269
  if "content" in result:
227
270
  content_text = ""
agent/core/messages.py CHANGED
@@ -37,6 +37,17 @@ class BaseMessageManager:
37
37
  if self.image_retention_config.min_removal_threshold < 1:
38
38
  raise ValueError("min_removal_threshold must be at least 1")
39
39
 
40
+ # Track provider for message formatting
41
+ self.provider = "openai" # Default provider
42
+
43
+ def set_provider(self, provider: str) -> None:
44
+ """Set the current provider to format messages for.
45
+
46
+ Args:
47
+ provider: Provider name (e.g., 'openai', 'anthropic')
48
+ """
49
+ self.provider = provider.lower()
50
+
40
51
  def prepare_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
41
52
  """Prepare messages by applying image retention and caching as configured.
42
53
 
@@ -96,6 +107,10 @@ class BaseMessageManager:
96
107
  Args:
97
108
  messages: Messages to inject caching into
98
109
  """
110
+ # Only apply cache_control for Anthropic API, not OpenAI
111
+ if self.provider != "anthropic":
112
+ return
113
+
99
114
  # Default to caching last 3 turns
100
115
  turns_to_cache = 3
101
116
  for message in reversed(messages):
@@ -1,6 +1,6 @@
1
1
  """Anthropic provider implementation."""
2
2
 
3
3
  from .loop import AnthropicLoop
4
- from .types import APIProvider
4
+ from .types import LLMProvider
5
5
 
6
- __all__ = ["AnthropicLoop", "APIProvider"]
6
+ __all__ = ["AnthropicLoop", "LLMProvider"]
@@ -3,25 +3,28 @@ import httpx
3
3
  import asyncio
4
4
  from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
5
5
  from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaToolUnionParam
6
- from ..types import APIProvider
6
+ from ..types import LLMProvider
7
7
  from .logging import log_api_interaction
8
8
  import random
9
9
  import logging
10
10
 
11
11
  logger = logging.getLogger(__name__)
12
12
 
13
+
13
14
  class APIConnectionError(Exception):
14
15
  """Error raised when there are connection issues with the API."""
16
+
15
17
  pass
16
18
 
19
+
17
20
  class BaseAnthropicClient:
18
21
  """Base class for Anthropic API clients."""
19
-
22
+
20
23
  MAX_RETRIES = 10
21
24
  INITIAL_RETRY_DELAY = 1.0
22
25
  MAX_RETRY_DELAY = 60.0
23
26
  JITTER_FACTOR = 0.1
24
-
27
+
25
28
  async def create_message(
26
29
  self,
27
30
  *,
@@ -36,79 +39,67 @@ class BaseAnthropicClient:
36
39
 
37
40
  async def _make_api_call_with_retries(self, api_call):
38
41
  """Make an API call with exponential backoff retry logic.
39
-
42
+
40
43
  Args:
41
44
  api_call: Async function that makes the actual API call
42
-
45
+
43
46
  Returns:
44
47
  API response
45
-
48
+
46
49
  Raises:
47
50
  APIConnectionError: If all retries fail
48
51
  """
49
52
  retry_count = 0
50
53
  last_error = None
51
-
54
+
52
55
  while retry_count < self.MAX_RETRIES:
53
56
  try:
54
57
  return await api_call()
55
58
  except Exception as e:
56
59
  last_error = e
57
60
  retry_count += 1
58
-
61
+
59
62
  if retry_count == self.MAX_RETRIES:
60
63
  break
61
-
64
+
62
65
  # Calculate delay with exponential backoff and jitter
63
66
  delay = min(
64
- self.INITIAL_RETRY_DELAY * (2 ** (retry_count - 1)),
65
- self.MAX_RETRY_DELAY
67
+ self.INITIAL_RETRY_DELAY * (2 ** (retry_count - 1)), self.MAX_RETRY_DELAY
66
68
  )
67
69
  # Add jitter to avoid thundering herd
68
70
  jitter = delay * self.JITTER_FACTOR * (2 * random.random() - 1)
69
71
  final_delay = delay + jitter
70
-
72
+
71
73
  logger.info(
72
74
  f"Retrying request (attempt {retry_count}/{self.MAX_RETRIES}) "
73
75
  f"in {final_delay:.2f} seconds after error: {str(e)}"
74
76
  )
75
77
  await asyncio.sleep(final_delay)
76
-
78
+
77
79
  raise APIConnectionError(
78
- f"Failed after {self.MAX_RETRIES} retries. "
79
- f"Last error: {str(last_error)}"
80
+ f"Failed after {self.MAX_RETRIES} retries. " f"Last error: {str(last_error)}"
80
81
  )
81
82
 
83
+
82
84
  class AnthropicDirectClient(BaseAnthropicClient):
83
85
  """Direct Anthropic API client implementation."""
84
-
86
+
85
87
  def __init__(self, api_key: str, model: str):
86
88
  self.model = model
87
- self.client = Anthropic(
88
- api_key=api_key,
89
- http_client=self._create_http_client()
90
- )
91
-
89
+ self.client = Anthropic(api_key=api_key, http_client=self._create_http_client())
90
+
92
91
  def _create_http_client(self) -> httpx.Client:
93
92
  """Create an HTTP client with appropriate settings."""
94
93
  return httpx.Client(
95
94
  verify=True,
96
- timeout=httpx.Timeout(
97
- connect=30.0,
98
- read=300.0,
99
- write=30.0,
100
- pool=30.0
101
- ),
95
+ timeout=httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0),
102
96
  transport=httpx.HTTPTransport(
103
97
  retries=3,
104
98
  verify=True,
105
- limits=httpx.Limits(
106
- max_keepalive_connections=5,
107
- max_connections=10
108
- )
109
- )
99
+ limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
100
+ ),
110
101
  )
111
-
102
+
112
103
  async def create_message(
113
104
  self,
114
105
  *,
@@ -119,6 +110,7 @@ class AnthropicDirectClient(BaseAnthropicClient):
119
110
  betas: list[str],
120
111
  ) -> BetaMessage:
121
112
  """Create a message using the direct Anthropic API with retry logic."""
113
+
122
114
  async def api_call():
123
115
  response = self.client.beta.messages.with_raw_response.create(
124
116
  max_tokens=max_tokens,
@@ -130,20 +122,21 @@ class AnthropicDirectClient(BaseAnthropicClient):
130
122
  )
131
123
  log_api_interaction(response.http_response.request, response.http_response, None)
132
124
  return response.parse()
133
-
125
+
134
126
  try:
135
127
  return await self._make_api_call_with_retries(api_call)
136
128
  except Exception as e:
137
129
  log_api_interaction(None, None, e)
138
130
  raise
139
131
 
132
+
140
133
  class AnthropicVertexClient(BaseAnthropicClient):
141
134
  """Google Cloud Vertex AI implementation of Anthropic client."""
142
-
135
+
143
136
  def __init__(self, model: str):
144
137
  self.model = model
145
138
  self.client = AnthropicVertex()
146
-
139
+
147
140
  async def create_message(
148
141
  self,
149
142
  *,
@@ -154,6 +147,7 @@ class AnthropicVertexClient(BaseAnthropicClient):
154
147
  betas: list[str],
155
148
  ) -> BetaMessage:
156
149
  """Create a message using Vertex AI with retry logic."""
150
+
157
151
  async def api_call():
158
152
  response = self.client.beta.messages.with_raw_response.create(
159
153
  max_tokens=max_tokens,
@@ -165,20 +159,21 @@ class AnthropicVertexClient(BaseAnthropicClient):
165
159
  )
166
160
  log_api_interaction(response.http_response.request, response.http_response, None)
167
161
  return response.parse()
168
-
162
+
169
163
  try:
170
164
  return await self._make_api_call_with_retries(api_call)
171
165
  except Exception as e:
172
166
  log_api_interaction(None, None, e)
173
167
  raise
174
168
 
169
+
175
170
  class AnthropicBedrockClient(BaseAnthropicClient):
176
171
  """AWS Bedrock implementation of Anthropic client."""
177
-
172
+
178
173
  def __init__(self, model: str):
179
174
  self.model = model
180
175
  self.client = AnthropicBedrock()
181
-
176
+
182
177
  async def create_message(
183
178
  self,
184
179
  *,
@@ -189,6 +184,7 @@ class AnthropicBedrockClient(BaseAnthropicClient):
189
184
  betas: list[str],
190
185
  ) -> BetaMessage:
191
186
  """Create a message using AWS Bedrock with retry logic."""
187
+
192
188
  async def api_call():
193
189
  response = self.client.beta.messages.with_raw_response.create(
194
190
  max_tokens=max_tokens,
@@ -200,23 +196,24 @@ class AnthropicBedrockClient(BaseAnthropicClient):
200
196
  )
201
197
  log_api_interaction(response.http_response.request, response.http_response, None)
202
198
  return response.parse()
203
-
199
+
204
200
  try:
205
201
  return await self._make_api_call_with_retries(api_call)
206
202
  except Exception as e:
207
203
  log_api_interaction(None, None, e)
208
204
  raise
209
205
 
206
+
210
207
  class AnthropicClientFactory:
211
208
  """Factory for creating appropriate Anthropic client implementations."""
212
-
209
+
213
210
  @staticmethod
214
- def create_client(provider: APIProvider, api_key: str, model: str) -> BaseAnthropicClient:
211
+ def create_client(provider: LLMProvider, api_key: str, model: str) -> BaseAnthropicClient:
215
212
  """Create an appropriate client based on the provider."""
216
- if provider == APIProvider.ANTHROPIC:
213
+ if provider == LLMProvider.ANTHROPIC:
217
214
  return AnthropicDirectClient(api_key, model)
218
- elif provider == APIProvider.VERTEX:
215
+ elif provider == LLMProvider.VERTEX:
219
216
  return AnthropicVertexClient(model)
220
- elif provider == APIProvider.BEDROCK:
217
+ elif provider == LLMProvider.BEDROCK:
221
218
  return AnthropicBedrockClient(model)
222
- raise ValueError(f"Unsupported provider: {provider}")
219
+ raise ValueError(f"Unsupported provider: {provider}")
@@ -32,7 +32,7 @@ from .tools.manager import ToolManager
32
32
  from .messages.manager import MessageManager
33
33
  from .callbacks.manager import CallbackManager
34
34
  from .prompts import SYSTEM_PROMPT
35
- from .types import APIProvider
35
+ from .types import LLMProvider
36
36
  from .tools import ToolResult
37
37
 
38
38
  # Constants
@@ -86,7 +86,7 @@ class AnthropicLoop(BaseLoop):
86
86
  self.model = "claude-3-7-sonnet-20250219"
87
87
 
88
88
  # Anthropic-specific attributes
89
- self.provider = APIProvider.ANTHROPIC
89
+ self.provider = LLMProvider.ANTHROPIC
90
90
  self.client = None
91
91
  self.retry_count = 0
92
92
  self.tool_manager = None
@@ -1,7 +1,7 @@
1
1
  from enum import StrEnum
2
2
 
3
3
 
4
- class APIProvider(StrEnum):
4
+ class LLMProvider(StrEnum):
5
5
  """Enum for supported API providers."""
6
6
 
7
7
  ANTHROPIC = "anthropic"
@@ -9,8 +9,8 @@ class APIProvider(StrEnum):
9
9
  VERTEX = "vertex"
10
10
 
11
11
 
12
- PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = {
13
- APIProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
14
- APIProvider.BEDROCK: "anthropic.claude-3-7-sonnet-20250219-v2:0",
15
- APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
12
+ PROVIDER_TO_DEFAULT_MODEL_NAME: dict[LLMProvider, str] = {
13
+ LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
14
+ LLMProvider.BEDROCK: "anthropic.claude-3-7-sonnet-20250219-v2:0",
15
+ LLMProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
16
16
  }
@@ -2,7 +2,7 @@
2
2
 
3
3
  # The OmniComputerAgent has been replaced by the unified ComputerAgent
4
4
  # which can be found in agent.core.agent
5
- from .types import APIProvider
5
+ from .types import LLMProvider
6
6
  from .experiment import ExperimentManager
7
7
  from .visualization import visualize_click, visualize_scroll, calculate_element_center
8
8
  from .image_utils import (
@@ -14,7 +14,7 @@ from .image_utils import (
14
14
  )
15
15
 
16
16
  __all__ = [
17
- "APIProvider",
17
+ "LLMProvider",
18
18
  "ExperimentManager",
19
19
  "visualize_click",
20
20
  "visualize_scroll",
@@ -17,7 +17,7 @@ import copy
17
17
  from .parser import OmniParser, ParseResult, ParserMetadata, UIElement
18
18
  from ...core.loop import BaseLoop
19
19
  from computer import Computer
20
- from .types import APIProvider
20
+ from .types import LLMProvider
21
21
  from .clients.base import BaseOmniClient
22
22
  from .clients.openai import OpenAIClient
23
23
  from .clients.groq import GroqClient
@@ -46,7 +46,7 @@ class OmniLoop(BaseLoop):
46
46
  def __init__(
47
47
  self,
48
48
  parser: OmniParser,
49
- provider: APIProvider,
49
+ provider: LLMProvider,
50
50
  api_key: str,
51
51
  model: str,
52
52
  computer: Computer,
@@ -180,11 +180,11 @@ class OmniLoop(BaseLoop):
180
180
  try:
181
181
  logger.info(f"Initializing {self.provider} client with model {self.model}...")
182
182
 
183
- if self.provider == APIProvider.OPENAI:
183
+ if self.provider == LLMProvider.OPENAI:
184
184
  self.client = OpenAIClient(api_key=self.api_key, model=self.model)
185
- elif self.provider == APIProvider.GROQ:
185
+ elif self.provider == LLMProvider.GROQ:
186
186
  self.client = GroqClient(api_key=self.api_key, model=self.model)
187
- elif self.provider == APIProvider.ANTHROPIC:
187
+ elif self.provider == LLMProvider.ANTHROPIC:
188
188
  self.client = AnthropicClient(
189
189
  api_key=self.api_key,
190
190
  model=self.model,
@@ -219,12 +219,16 @@ class OmniLoop(BaseLoop):
219
219
  if self.client is None:
220
220
  raise RuntimeError("Failed to initialize client")
221
221
 
222
+ # Set the provider in message manager based on current provider
223
+ provider_name = str(self.provider).split(".")[-1].lower() # Extract name from enum
224
+ self.message_manager.set_provider(provider_name)
225
+
222
226
  # Apply image retention and prepare messages
223
227
  # This will limit the number of images based on only_n_most_recent_images
224
- prepared_messages = self.message_manager.prepare_messages(messages.copy())
228
+ prepared_messages = self.message_manager.get_formatted_messages(provider_name)
225
229
 
226
230
  # Filter out system messages for Anthropic
227
- if self.provider == APIProvider.ANTHROPIC:
231
+ if self.provider == LLMProvider.ANTHROPIC:
228
232
  filtered_messages = [
229
233
  msg for msg in prepared_messages if msg["role"] != "system"
230
234
  ]
@@ -234,7 +238,7 @@ class OmniLoop(BaseLoop):
234
238
  # Log request
235
239
  request_data = {"messages": filtered_messages, "max_tokens": self.max_tokens}
236
240
 
237
- if self.provider == APIProvider.ANTHROPIC:
241
+ if self.provider == LLMProvider.ANTHROPIC:
238
242
  request_data["system"] = self._get_system_prompt()
239
243
  else:
240
244
  request_data["system"] = system_prompt
@@ -251,7 +255,7 @@ class OmniLoop(BaseLoop):
251
255
 
252
256
  if is_async:
253
257
  # For async implementations (AnthropicClient)
254
- if self.provider == APIProvider.ANTHROPIC:
258
+ if self.provider == LLMProvider.ANTHROPIC:
255
259
  response = await run_method(
256
260
  messages=filtered_messages,
257
261
  system=self._get_system_prompt(),
@@ -265,7 +269,7 @@ class OmniLoop(BaseLoop):
265
269
  )
266
270
  else:
267
271
  # For non-async implementations (GroqClient, etc.)
268
- if self.provider == APIProvider.ANTHROPIC:
272
+ if self.provider == LLMProvider.ANTHROPIC:
269
273
  response = run_method(
270
274
  messages=filtered_messages,
271
275
  system=self._get_system_prompt(),
@@ -335,7 +339,7 @@ class OmniLoop(BaseLoop):
335
339
  action_screenshot_saved = False
336
340
  try:
337
341
  # Handle Anthropic response format
338
- if self.provider == APIProvider.ANTHROPIC:
342
+ if self.provider == LLMProvider.ANTHROPIC:
339
343
  if hasattr(response, "content") and isinstance(response.content, list):
340
344
  # Extract text from content blocks
341
345
  for block in response.content:
@@ -559,7 +563,7 @@ class OmniLoop(BaseLoop):
559
563
  """Process and add screen info to messages."""
560
564
  try:
561
565
  # Only add message if we have an image and provider supports it
562
- if self.provider in [APIProvider.OPENAI, APIProvider.ANTHROPIC]:
566
+ if self.provider in [LLMProvider.OPENAI, LLMProvider.ANTHROPIC]:
563
567
  image = parsed_screen.annotated_image_base64 or None
564
568
  if image:
565
569
  # Save screen info to current turn directory
@@ -573,7 +577,7 @@ class OmniLoop(BaseLoop):
573
577
  logger.info(f"Saved elements to {elements_path}")
574
578
 
575
579
  # Format the image content based on the provider
576
- if self.provider == APIProvider.ANTHROPIC:
580
+ if self.provider == LLMProvider.ANTHROPIC:
577
581
  # Compress the image before sending to Anthropic (5MB limit)
578
582
  image_size = len(image)
579
583
  logger.info(f"Image base64 is present, length: {image_size}")
@@ -103,6 +103,9 @@ class OmniMessageManager(BaseMessageManager):
103
103
  Returns:
104
104
  List of formatted messages
105
105
  """
106
+ # Set the provider for message formatting
107
+ self.set_provider(provider)
108
+
106
109
  if provider == "anthropic":
107
110
  return self._format_for_anthropic()
108
111
  elif provider == "openai":
@@ -62,17 +62,3 @@ IMPORTANT NOTES:
62
62
  9. Reflect whether the element is clickable or not, for example reflect if it is an hyperlink or a button or a normal text.
63
63
  10. If you are prompted with login information page or captcha page, or you think it need user's permission to do the next action, you should say "Action": "None" in the json field.
64
64
  """
65
-
66
- # SYSTEM_PROMPT1 = """You are an AI assistant helping users interact with their computer.
67
- # Analyze the screen information and respond with JSON containing:
68
- # {
69
- # "Box ID": "Numeric ID of the relevant UI element",
70
- # "Action": "One of: left_click, right_click, double_click, move_cursor, drag_to, type_text, press_key, hotkey, scroll_down, scroll_up, wait",
71
- # "Value": "Text to type, key to press",
72
- # "Explanation": "Why this action was chosen"
73
- # }
74
-
75
- # Notes:
76
- # - For starting applications, use the "hotkey" action with command+space for starting a Spotlight search.
77
- # - Each UI element is highlighted with a colored bounding box, and its Box ID appears nearby in the same color for easy identification.
78
- # """
@@ -1,11 +1,12 @@
1
1
  """Type definitions for the Omni provider."""
2
2
 
3
3
  from enum import StrEnum
4
- from typing import Dict
4
+ from typing import Dict, Optional
5
+ from dataclasses import dataclass
5
6
 
6
7
 
7
- class APIProvider(StrEnum):
8
- """Supported API providers."""
8
+ class LLMProvider(StrEnum):
9
+ """Supported LLM providers."""
9
10
 
10
11
  ANTHROPIC = "anthropic"
11
12
  OPENAI = "openai"
@@ -13,18 +14,39 @@ class APIProvider(StrEnum):
13
14
  QWEN = "qwen"
14
15
 
15
16
 
17
+ LLMProvider
18
+
19
+
20
+ @dataclass
21
+ class LLM:
22
+ """Configuration for LLM model and provider."""
23
+
24
+ provider: LLMProvider
25
+ name: Optional[str] = None
26
+
27
+ def __post_init__(self):
28
+ """Set default model name if not provided."""
29
+ if self.name is None:
30
+ self.name = PROVIDER_TO_DEFAULT_MODEL.get(self.provider)
31
+
32
+
33
+ # For backward compatibility
34
+ LLMModel = LLM
35
+ Model = LLM
36
+
37
+
16
38
  # Default models for each provider
17
- PROVIDER_TO_DEFAULT_MODEL: Dict[APIProvider, str] = {
18
- APIProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
19
- APIProvider.OPENAI: "gpt-4o",
20
- APIProvider.GROQ: "deepseek-r1-distill-llama-70b",
21
- APIProvider.QWEN: "qwen2.5-vl-72b-instruct",
39
+ PROVIDER_TO_DEFAULT_MODEL: Dict[LLMProvider, str] = {
40
+ LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
41
+ LLMProvider.OPENAI: "gpt-4o",
42
+ LLMProvider.GROQ: "deepseek-r1-distill-llama-70b",
43
+ LLMProvider.QWEN: "qwen2.5-vl-72b-instruct",
22
44
  }
23
45
 
24
46
  # Environment variable names for each provider
25
- PROVIDER_TO_ENV_VAR: Dict[APIProvider, str] = {
26
- APIProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
27
- APIProvider.OPENAI: "OPENAI_API_KEY",
28
- APIProvider.GROQ: "GROQ_API_KEY",
29
- APIProvider.QWEN: "QWEN_API_KEY",
47
+ PROVIDER_TO_ENV_VAR: Dict[LLMProvider, str] = {
48
+ LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
49
+ LLMProvider.OPENAI: "OPENAI_API_KEY",
50
+ LLMProvider.GROQ: "GROQ_API_KEY",
51
+ LLMProvider.QWEN: "QWEN_API_KEY",
30
52
  }
agent/types/base.py CHANGED
@@ -44,9 +44,10 @@ class Annotation(BaseModel):
44
44
  vm_url: str
45
45
 
46
46
 
47
- class AgenticLoop(Enum):
47
+ class AgentLoop(Enum):
48
48
  """Enumeration of available loop types."""
49
49
 
50
50
  ANTHROPIC = auto() # Anthropic implementation
51
+ OPENAI = auto() # OpenAI implementation
51
52
  OMNI = auto() # OmniLoop implementation
52
53
  # Add more loop types as needed
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: cua-agent
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: CUA (Computer Use) Agent for AI-driven computer interaction
5
5
  Author-Email: TryCua <gh@trycua.com>
6
6
  Requires-Python: <3.13,>=3.10
@@ -1,15 +1,15 @@
1
1
  agent/README.md,sha256=8EFnLrKejthEcL9bZflQSbvA-KwpiPanBz8TEEwRub8,2153
2
- agent/__init__.py,sha256=16Q828puFb7Ucq_-de49moVCzl1-iDO8Uo5dzFwX0Ag,347
3
- agent/core/README.md,sha256=RY4kKEjm_-_Ul2xgY7ntzsXdPe0Tg1wvtOSZ4xp4DN0,3559
2
+ agent/__init__.py,sha256=BRIunVPG0T5CdAiNJyElKxUZN8Mngg2_TmtLwaupG4I,355
3
+ agent/core/README.md,sha256=VOXNVbR0ugxf9gCXYmZtUU2kngZhfi29haT_oSxK0Lk,3559
4
4
  agent/core/__init__.py,sha256=0htZ-VfsH9ixHB8j_SXu_uv6r3XXsq5TrghFNd-yRNE,709
5
- agent/core/agent.py,sha256=q2x0vFykIavX_FBi4Eq222QCSFmuuekAin4FPrtSGbY,11711
5
+ agent/core/agent.py,sha256=AQ-S2wVD82RFnD_HmR-zjA7Jj09CUKGp7KreWX1j6Fg,13495
6
6
  agent/core/base_agent.py,sha256=MgaMKTwgqNJ1-TgS_mxALoC9COzc7Acg9y7Q8HAFX2c,6266
7
7
  agent/core/callbacks.py,sha256=VbGIf5QkHh3Q0KsLM6wv7hRdIA5WExTVYLm64bckyUA,4306
8
8
  agent/core/computer_agent.py,sha256=JGLMl_PwImUttmQh2amdLlXHS9CUyZ9MW20J1Xid7dM,2417
9
9
  agent/core/experiment.py,sha256=AST1t83eqaGzjoW6KvrhfVIs3ELAR_I70VHq2NsMmNk,7446
10
10
  agent/core/factory.py,sha256=WraOEHWPXBSN4R3DO7M2ctyadodeA8tzHM3dUjdQ_3A,3441
11
11
  agent/core/loop.py,sha256=E-0pz7MaguZQrHs5GP98Oc8C_Iz8ier0vXrD9Ny2HL8,8999
12
- agent/core/messages.py,sha256=Ou0lLEwa2EQCartcTszsvNjCP6sHUxmr2_C9PGzbASg,7163
12
+ agent/core/messages.py,sha256=N8pV8Eh-AJpMuDPRI5OGWUIOU6DRr-pQjK9XU0go9Hk,7637
13
13
  agent/core/tools/__init__.py,sha256=xZen-PqUp2dUaMEHJowXCQm33_5Sxhsx9PSoD0rq6tI,489
14
14
  agent/core/tools/base.py,sha256=CdzRFNuOjNfzgyTUN4ZoCGkUDR5HI0ECQVpvrUdEij8,2295
15
15
  agent/core/tools/bash.py,sha256=jnJKVlHn8np8e0gWd8EO0_qqjMkfQzutSugA_Iol4jE,1585
@@ -18,11 +18,11 @@ agent/core/tools/computer.py,sha256=lT_aW3huoYpcM8kffuokELupSz_WZG_qkaW1gITRC58,
18
18
  agent/core/tools/edit.py,sha256=kv4jTKCM0VXrnoNErf7mT-xlr81-7T8v49_VA9y_L4Y,2005
19
19
  agent/core/tools/manager.py,sha256=IRsCXjGc076nncQuyIjODoafnHTDhrf9sP5B4q5Pcdo,1742
20
20
  agent/providers/__init__.py,sha256=b4tIBAaIB1V7p8V0BWipHVnMhfHH_OuVgP4OWGSHdD8,194
21
- agent/providers/anthropic/__init__.py,sha256=vEqLDkYXZoXg9A64bOtWfv9hoJlJCXbTpQGcmQ9eec8,149
22
- agent/providers/anthropic/api/client.py,sha256=_DeCn6bYgVG0LcQYDO6VCjTPrt6U-PO5vr4GWmhCPH8,7404
21
+ agent/providers/anthropic/__init__.py,sha256=Mj11IZnVshZ2iHkvg4Z5-jrQIaD1WvzDz2Zk_pMwqIA,149
22
+ agent/providers/anthropic/api/client.py,sha256=Y_g4Xg8Ko4tCqjipVm0GBMw-86vw0KQVXS5aWzJinzw,7038
23
23
  agent/providers/anthropic/api/logging.py,sha256=vHpwkIyOZdkSTVIH4ycbBPd4a_rzhP7Osu1I-Ayouwc,5154
24
24
  agent/providers/anthropic/callbacks/manager.py,sha256=dRKN7MuBze2dLal0iHDxCKYqMdh_KShSphuwn7zC-c4,1878
25
- agent/providers/anthropic/loop.py,sha256=GfUU_0erZgaM8oENSbrKEepsYsYTfuOiygcjHK0pefY,17904
25
+ agent/providers/anthropic/loop.py,sha256=-g-OUpdVPSTO5kFJSZ5AmnjoWSEs2niHZFSR6B_KKvU,17904
26
26
  agent/providers/anthropic/messages/manager.py,sha256=atD41v6bjC1STxRB-jLBty9wHlMwacH9cwsL4tBz3uo,4891
27
27
  agent/providers/anthropic/prompts.py,sha256=nHFfgPrfvnWrEdVP7EUBGUHAI85D2X9HeZirk9EwncU,1941
28
28
  agent/providers/anthropic/tools/__init__.py,sha256=JyZwuVtPUnZwRSZBSCdQv9yxbLCsygm3l8Ywjjt9qTQ,661
@@ -33,8 +33,8 @@ agent/providers/anthropic/tools/computer.py,sha256=WnQS2rIIDz1juwoQMun2ODJjOV134
33
33
  agent/providers/anthropic/tools/edit.py,sha256=EGRP61MDA4Oue1D7Q-_vLpd6LdGbdBA1Z4HSZ66DbmI,13465
34
34
  agent/providers/anthropic/tools/manager.py,sha256=zW-biqO_MV3fb1nDEOl3EmCXD1leoglFj6LDRSM3djs,1982
35
35
  agent/providers/anthropic/tools/run.py,sha256=xhXdnBK1di9muaO44CEirL9hpGy3NmKbjfMpyeVmn8Y,1595
36
- agent/providers/anthropic/types.py,sha256=kKc4XvSuKfumv4KLpJOwyY4t5deBsLgZTSAP4raZGvg,421
37
- agent/providers/omni/__init__.py,sha256=wKOVVWHkD-p4QUz0TIEENkMb7Iq2LRSh88KUGBW1XQA,744
36
+ agent/providers/anthropic/types.py,sha256=SF00kOMC1ui8j9Ah56KaeiR2cL394qCHjFIsBpXxt5w,421
37
+ agent/providers/omni/__init__.py,sha256=eTUh4Pmh4zO-RLnP-wAFm8EkJBMImT-G2xnVIYWRti0,744
38
38
  agent/providers/omni/callbacks.py,sha256=ZG9NCgsHWt6y5jKsfcGLaoLxTpmKnIhCArDdeP4q9sA,2369
39
39
  agent/providers/omni/clients/anthropic.py,sha256=X_QRVxqwA_ExdUqgBEwo1aHOfZQxVIBDmDugNHF97OM,3554
40
40
  agent/providers/omni/clients/base.py,sha256=zAAgPi0jl3SWPC730R9l79E8bfYPSo39UtCSE-mrK6I,1076
@@ -43,23 +43,23 @@ agent/providers/omni/clients/openai.py,sha256=E4TAXMUFoYTunJETCWCNx5XAc6xutiN4rB
43
43
  agent/providers/omni/clients/utils.py,sha256=Ani9CVVBm_J2Dl51WG6p1GVuoI6cq8scISrG0pmQ37o,688
44
44
  agent/providers/omni/experiment.py,sha256=JGAdHi7Nf73I48c9k3TY1Xpr_i6D2VG1wurOzw5cNGk,9888
45
45
  agent/providers/omni/image_utils.py,sha256=qIFuNi5cIMVwrqYBXG1T6PxUlbxz7gIngFFP39bZIlU,2782
46
- agent/providers/omni/loop.py,sha256=Xr2QeedAVJ_jHn3KMopRuH3mrm2Qn4ncxKjqj9hWxAw,43577
47
- agent/providers/omni/messages.py,sha256=6LkQfzYDWq2FvIHpqhs5pc0l6AmFx_xKCjj1R5czMPo,6047
46
+ agent/providers/omni/loop.py,sha256=mHCs13in3mrLizF1x8OeCXECp4bL9-CYS_XOJOUZqu8,43827
47
+ agent/providers/omni/messages.py,sha256=zdjQCAMH-hOyrQQesHhTiIsQbw43KqVSmVIzS8JOIFA,6134
48
48
  agent/providers/omni/parser.py,sha256=Iv-cXWG2qzdYjyZJH5pGUzfv6nOaiHQ2OXdQSe00Ydw,9151
49
- agent/providers/omni/prompts.py,sha256=29qy8ppbLOjLil3aiqryjaiBf8CQx-xXHN44O-85Q00,4503
49
+ agent/providers/omni/prompts.py,sha256=Mupjy0bUwBjcAeLXpE1r1jisYPSlhwsp-IXJKEKrEtw,3779
50
50
  agent/providers/omni/tool_manager.py,sha256=O6DxyEI-Vg6jt99phh011o4q4me_vNhH2YffIxkO4GM,2585
51
51
  agent/providers/omni/tools/__init__.py,sha256=l636hx9Q5z9eaFdPanPwPENUE-w-Xm8kAZhPUq0ZQF4,309
52
52
  agent/providers/omni/tools/bash.py,sha256=y_ibfP9iRcbiU_E0faAoa4DCP_BlkMlKOOURdBBIGZE,2030
53
53
  agent/providers/omni/tools/computer.py,sha256=xkMmAR0e_kbf0Zs2mggCDyWrQOJZyXOKPFjkutaQb94,9108
54
54
  agent/providers/omni/tools/manager.py,sha256=V_tav2yU92PyQnFlxNXG1wvNEaJoEYudtKx5sRjj06Q,2619
55
- agent/providers/omni/types.py,sha256=cEH6M5fcRN8ZIv_jfcYkTYboGBM4EzglLZo1_Xk7Ip8,800
55
+ agent/providers/omni/types.py,sha256=G7Zqm-nWMa3K2klj-D3KUVWc2r8NJB7sYZCwwl0m9Ic,1233
56
56
  agent/providers/omni/utils.py,sha256=JqSye1bEp4wxhUgmaMyZi172fTlgXtygJ7XlnvKdUtE,6337
57
57
  agent/providers/omni/visualization.py,sha256=N3qVQLxYmia3iSVC5oCt5YRlMPuVfylCOyB99R33u8U,3924
58
58
  agent/types/__init__.py,sha256=61UFJT-w0CT4YRn0LiTx4A7fsMdVQjlXO9vnmbI1A7Y,604
59
- agent/types/base.py,sha256=rVb4mPWp1SOHfrzOCDqx0pfCV5bgIsdrIzgM_kX_xVs,1090
59
+ agent/types/base.py,sha256=Iy_Q2DIBMLtwWdLyfvHw_6E2ltYu3bIv8GUNy3LYkGs,1133
60
60
  agent/types/messages.py,sha256=4-hwtxeAhto90_EZpHFducddtsHUsHauvXzYrpKG4RE,953
61
61
  agent/types/tools.py,sha256=Jes2CFCFqC727WWHbO-sG7V03rBHnQe5X7Oi9ZkuScI,877
62
- cua_agent-0.1.0.dist-info/METADATA,sha256=Q4nPzYL_UQwx82vuaRLBUFmA_Sgd37TVoGA9FNYDRmU,1890
63
- cua_agent-0.1.0.dist-info/WHEEL,sha256=thaaA2w1JzcGC48WYufAs8nrYZjJm8LqNfnXFOFyCC4,90
64
- cua_agent-0.1.0.dist-info/entry_points.txt,sha256=6OYgBcLyFCUgeqLgnvMyOJxPCWzgy7se4rLPKtNonMs,34
65
- cua_agent-0.1.0.dist-info/RECORD,,
62
+ cua_agent-0.1.2.dist-info/METADATA,sha256=bXSToJpS_e5KRzyRELUzCuOkozsDUD29pBMj3DKzF7U,1890
63
+ cua_agent-0.1.2.dist-info/WHEEL,sha256=thaaA2w1JzcGC48WYufAs8nrYZjJm8LqNfnXFOFyCC4,90
64
+ cua_agent-0.1.2.dist-info/entry_points.txt,sha256=6OYgBcLyFCUgeqLgnvMyOJxPCWzgy7se4rLPKtNonMs,34
65
+ cua_agent-0.1.2.dist-info/RECORD,,