cua-agent 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of cua-agent might be problematic.

agent/__init__.py CHANGED
@@ -5,6 +5,6 @@ __version__ = "0.1.0"
  from .core.factory import AgentFactory
  from .core.agent import ComputerAgent
  from .types.base import Provider, AgenticLoop
- from .providers.omni.types import APIProvider
+ from .providers.omni.types import LLMProvider, LLM, Model, LLMModel, APIProvider

- __all__ = ["AgentFactory", "Provider", "ComputerAgent", "AgenticLoop", "APIProvider"]
+ __all__ = ["AgentFactory", "Provider", "ComputerAgent", "AgenticLoop", "LLMProvider", "LLM", "Model", "LLMModel", "APIProvider"]
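
For downstream code, the widened export list means the new provider types can be imported directly from the package root. A minimal import sketch, assuming the wheel is installed (names taken from the __all__ above):

    from agent import ComputerAgent, AgenticLoop, LLM, LLMProvider
    from agent import Model, LLMModel, APIProvider  # backward-compatible aliases, still exported
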
agent/core/agent.py CHANGED
@@ -3,7 +3,7 @@
  import os
  import logging
  import asyncio
- from typing import Any, AsyncGenerator, Dict, List, Optional, TYPE_CHECKING
+ from typing import Any, AsyncGenerator, Dict, List, Optional, TYPE_CHECKING, Union, cast
  from datetime import datetime

  from computer import Computer
@@ -17,23 +17,23 @@ if TYPE_CHECKING:
      from ..providers.omni.loop import OmniLoop
      from ..providers.omni.parser import OmniParser

- # Import the APIProvider enum without importing the whole module
- from ..providers.omni.types import APIProvider
+ # Import the provider types
+ from ..providers.omni.types import LLMProvider, LLM, Model, LLMModel, APIProvider

  logger = logging.getLogger(__name__)

  # Default models for different providers
  DEFAULT_MODELS = {
-     APIProvider.OPENAI: "gpt-4o",
-     APIProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
-     APIProvider.GROQ: "llama3-70b-8192",
+     LLMProvider.OPENAI: "gpt-4o",
+     LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
+     LLMProvider.GROQ: "llama3-70b-8192",
  }

  # Map providers to their environment variable names
  ENV_VARS = {
-     APIProvider.OPENAI: "OPENAI_API_KEY",
-     APIProvider.GROQ: "GROQ_API_KEY",
-     APIProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
+     LLMProvider.OPENAI: "OPENAI_API_KEY",
+     LLMProvider.GROQ: "GROQ_API_KEY",
+     LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
  }

@@ -48,9 +48,8 @@ class ComputerAgent(BaseComputerAgent):
      self,
      computer: Computer,
      loop_type: AgenticLoop = AgenticLoop.OMNI,
-     ai_provider: APIProvider = APIProvider.OPENAI,
+     model: Optional[Union[LLM, Dict[str, str], str]] = None,
      api_key: Optional[str] = None,
-     model: Optional[str] = None,
      save_trajectory: bool = True,
      trajectory_dir: Optional[str] = "trajectories",
      only_n_most_recent_images: Optional[int] = None,
@@ -63,9 +62,12 @@ class ComputerAgent(BaseComputerAgent):
      Args:
          computer: Computer instance to control
          loop_type: The type of loop to use (Anthropic or Omni)
-         ai_provider: AI provider to use (required for Cua loop)
+         model: LLM configuration. Can be:
+             - LLM object with provider and name
+             - Dict with 'provider' and 'name' keys
+             - String with model name (defaults to OpenAI provider)
+             - None (defaults based on loop_type)
          api_key: Optional API key (will use environment variable if not provided)
-         model: Optional model name (will use provider default if not specified)
          save_trajectory: Whether to save screenshots and logs
          trajectory_dir: Directory to save trajectories (defaults to "trajectories")
          only_n_most_recent_images: Limit history to N most recent images
@@ -88,7 +90,6 @@ class ComputerAgent(BaseComputerAgent):
      )

      self.loop_type = loop_type
-     self.provider = ai_provider
      self.save_trajectory = save_trajectory
      self.trajectory_dir = trajectory_dir
      self.only_n_most_recent_images = only_n_most_recent_images
@@ -98,14 +99,19 @@ class ComputerAgent(BaseComputerAgent):
      # Configure logging based on verbosity
      self._configure_logging(verbosity)

+     # Process model configuration
+     self.model_config = self._process_model_config(model, loop_type)
+
      # Get API key from environment if not provided
      if api_key is None:
          env_var = (
-             ENV_VARS.get(ai_provider) if loop_type == AgenticLoop.OMNI else "ANTHROPIC_API_KEY"
+             ENV_VARS.get(self.model_config.provider)
+             if loop_type == AgenticLoop.OMNI
+             else "ANTHROPIC_API_KEY"
          )
          if not env_var:
              raise ValueError(
-                 f"Unsupported provider: {ai_provider}. Please use one of: {list(ENV_VARS.keys())}"
+                 f"Unsupported provider: {self.model_config.provider}. Please use one of: {list(ENV_VARS.keys())}"
              )

          api_key = os.environ.get(env_var)
@@ -119,17 +125,51 @@ class ComputerAgent(BaseComputerAgent):
          )
      self.api_key = api_key

-     # Set model based on provider if not specified
-     if model is None:
-         if loop_type == AgenticLoop.OMNI:
-             self.model = DEFAULT_MODELS[ai_provider]
-         else:  # Anthropic loop
-             self.model = DEFAULT_MODELS[APIProvider.ANTHROPIC]
-     else:
-         self.model = model
-
      # Initialize the appropriate loop based on loop_type
      self.loop = self._init_loop()
+
+ def _process_model_config(
+     self, model_input: Optional[Union[LLM, Dict[str, str], str]], loop_type: AgenticLoop
+ ) -> LLM:
+     """Process and normalize model configuration.
+
+     Args:
+         model_input: Input model configuration (LLM, dict, string, or None)
+         loop_type: The loop type being used
+
+     Returns:
+         Normalized LLM instance
+     """
+     # Handle case where model_input is None
+     if model_input is None:
+         # Use Anthropic for Anthropic loop, OpenAI for Omni loop
+         default_provider = (
+             LLMProvider.ANTHROPIC if loop_type == AgenticLoop.ANTHROPIC else LLMProvider.OPENAI
+         )
+         return LLM(provider=default_provider)
+
+     # Handle case where model_input is already a LLM or one of its aliases
+     if isinstance(model_input, (LLM, Model, LLMModel)):
+         return model_input
+
+     # Handle case where model_input is a dict
+     if isinstance(model_input, dict):
+         provider = model_input.get("provider", LLMProvider.OPENAI)
+         if isinstance(provider, str):
+             provider = LLMProvider(provider)
+         return LLM(
+             provider=provider,
+             name=model_input.get("name")
+         )
+
+     # Handle case where model_input is a string (model name)
+     if isinstance(model_input, str):
+         default_provider = (
+             LLMProvider.ANTHROPIC if loop_type == AgenticLoop.ANTHROPIC else LLMProvider.OPENAI
+         )
+         return LLM(provider=default_provider, name=model_input)
+
+     raise ValueError(f"Unsupported model configuration: {model_input}")

  def _configure_logging(self, verbosity: int):
      """Configure logging based on verbosity level."""
@@ -162,9 +202,12 @@ class ComputerAgent(BaseComputerAgent):
      if self.loop_type == AgenticLoop.ANTHROPIC:
          from ..providers.anthropic.loop import AnthropicLoop

+         # Ensure we always have a valid model name
+         model_name = self.model_config.name or DEFAULT_MODELS[LLMProvider.ANTHROPIC]
+
          return AnthropicLoop(
              api_key=self.api_key,
-             model=self.model,
+             model=model_name,
              computer=self.computer,
              save_trajectory=self.save_trajectory,
              base_dir=self.trajectory_dir,
@@ -176,10 +219,13 @@ class ComputerAgent(BaseComputerAgent):
      if "parser" not in self._kwargs:
          self._kwargs["parser"] = OmniParser()

+     # Ensure we always have a valid model name
+     model_name = self.model_config.name or DEFAULT_MODELS[self.model_config.provider]
+
      return OmniLoop(
-         provider=self.provider,
+         provider=self.model_config.provider,
          api_key=self.api_key,
-         model=self.model,
+         model=model_name,
          computer=self.computer,
          save_trajectory=self.save_trajectory,
          base_dir=self.trajectory_dir,
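
Taken together, the new constructor signature and _process_model_config mean a single model argument now replaces the old ai_provider/model pair. A usage sketch, assuming a default Computer() construction and API keys already present in the environment (neither is shown in this diff):

    from computer import Computer
    from agent import ComputerAgent, AgenticLoop, LLM, LLMProvider

    computer = Computer()  # assumed default construction

    # Each of these forms is normalized to an LLM instance by _process_model_config:
    agent = ComputerAgent(computer=computer, model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"))
    agent = ComputerAgent(computer=computer, model={"provider": "openai", "name": "gpt-4o"})
    agent = ComputerAgent(computer=computer, model="gpt-4o")  # bare string: provider chosen from loop_type
    agent = ComputerAgent(computer=computer)                  # None: provider and model both chosen from loop_type
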
agent/core/messages.py CHANGED
@@ -37,6 +37,17 @@ class BaseMessageManager:
      if self.image_retention_config.min_removal_threshold < 1:
          raise ValueError("min_removal_threshold must be at least 1")

+     # Track provider for message formatting
+     self.provider = "openai"  # Default provider
+
+ def set_provider(self, provider: str) -> None:
+     """Set the current provider to format messages for.
+
+     Args:
+         provider: Provider name (e.g., 'openai', 'anthropic')
+     """
+     self.provider = provider.lower()
+
  def prepare_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
      """Prepare messages by applying image retention and caching as configured.

@@ -96,6 +107,10 @@ class BaseMessageManager:
      Args:
          messages: Messages to inject caching into
      """
+     # Only apply cache_control for Anthropic API, not OpenAI
+     if self.provider != "anthropic":
+         return
+
      # Default to caching last 3 turns
      turns_to_cache = 3
      for message in reversed(messages):
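
The practical effect of set_provider plus the early return above is that cache_control markers are only injected when messages are being prepared for Anthropic. A simplified standalone sketch of that gate (not the library's actual BaseMessageManager; the message shape and the ephemeral marker are assumptions based on Anthropic's prompt-caching API):

    def inject_prompt_caching(messages: list, provider: str = "openai") -> None:
        # Skip entirely for non-Anthropic providers, mirroring the early return above
        if provider.lower() != "anthropic":
            return
        # Default to caching the last 3 turns, newest to oldest
        turns_to_cache = 3
        for message in reversed(messages):
            if turns_to_cache == 0:
                break
            for block in message.get("content", []):
                if isinstance(block, dict):
                    block["cache_control"] = {"type": "ephemeral"}  # assumed marker shape
            turns_to_cache -= 1
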
agent/providers/omni/loop.py CHANGED
@@ -219,9 +219,13 @@ class OmniLoop(BaseLoop):
      if self.client is None:
          raise RuntimeError("Failed to initialize client")

+     # Set the provider in message manager based on current provider
+     provider_name = str(self.provider).split(".")[-1].lower()  # Extract name from enum
+     self.message_manager.set_provider(provider_name)
+
      # Apply image retention and prepare messages
      # This will limit the number of images based on only_n_most_recent_images
-     prepared_messages = self.message_manager.prepare_messages(messages.copy())
+     prepared_messages = self.message_manager.get_formatted_messages(provider_name)

      # Filter out system messages for Anthropic
      if self.provider == APIProvider.ANTHROPIC:
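
The provider_name extraction above is written defensively: it yields the same lowercase string whether str() on the provider returns the enum value or the qualified member name. A quick check of that behavior with a local stand-in for the LLMProvider StrEnum (enum.StrEnum requires Python 3.11+):

    from enum import Enum, StrEnum

    class ProviderStr(StrEnum):   # stand-in mirroring types.py
        ANTHROPIC = "anthropic"

    class ProviderPlain(Enum):    # hypothetical plain-Enum variant
        ANTHROPIC = "anthropic"

    def extract(p):
        return str(p).split(".")[-1].lower()

    assert extract(ProviderStr.ANTHROPIC) == "anthropic"    # str() -> "anthropic"
    assert extract(ProviderPlain.ANTHROPIC) == "anthropic"  # str() -> "ProviderPlain.ANTHROPIC"
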
agent/providers/omni/messages.py CHANGED
@@ -103,6 +103,9 @@ class OmniMessageManager(BaseMessageManager):
      Returns:
          List of formatted messages
      """
+     # Set the provider for message formatting
+     self.set_provider(provider)
+
      if provider == "anthropic":
          return self._format_for_anthropic()
      elif provider == "openai":
agent/providers/omni/types.py CHANGED
@@ -1,11 +1,12 @@
  """Type definitions for the Omni provider."""

  from enum import StrEnum
- from typing import Dict
+ from typing import Dict, Optional
+ from dataclasses import dataclass


- class APIProvider(StrEnum):
-     """Supported API providers."""
+ class LLMProvider(StrEnum):
+     """Supported LLM providers."""

      ANTHROPIC = "anthropic"
      OPENAI = "openai"
@@ -13,18 +14,40 @@ class APIProvider(StrEnum):
      QWEN = "qwen"


+ # For backward compatibility
+ APIProvider = LLMProvider
+
+
+ @dataclass
+ class LLM:
+     """Configuration for LLM model and provider."""
+
+     provider: LLMProvider
+     name: Optional[str] = None
+
+     def __post_init__(self):
+         """Set default model name if not provided."""
+         if self.name is None:
+             self.name = PROVIDER_TO_DEFAULT_MODEL.get(self.provider)
+
+
+ # For backward compatibility
+ LLMModel = LLM
+ Model = LLM
+
+
  # Default models for each provider
- PROVIDER_TO_DEFAULT_MODEL: Dict[APIProvider, str] = {
-     APIProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
-     APIProvider.OPENAI: "gpt-4o",
-     APIProvider.GROQ: "deepseek-r1-distill-llama-70b",
-     APIProvider.QWEN: "qwen2.5-vl-72b-instruct",
+ PROVIDER_TO_DEFAULT_MODEL: Dict[LLMProvider, str] = {
+     LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
+     LLMProvider.OPENAI: "gpt-4o",
+     LLMProvider.GROQ: "deepseek-r1-distill-llama-70b",
+     LLMProvider.QWEN: "qwen2.5-vl-72b-instruct",
  }

  # Environment variable names for each provider
- PROVIDER_TO_ENV_VAR: Dict[APIProvider, str] = {
-     APIProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
-     APIProvider.OPENAI: "OPENAI_API_KEY",
-     APIProvider.GROQ: "GROQ_API_KEY",
-     APIProvider.QWEN: "QWEN_API_KEY",
+ PROVIDER_TO_ENV_VAR: Dict[LLMProvider, str] = {
+     LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
+     LLMProvider.OPENAI: "OPENAI_API_KEY",
+     LLMProvider.GROQ: "GROQ_API_KEY",
+     LLMProvider.QWEN: "QWEN_API_KEY",
  }
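
With the dataclass above, omitting the model name falls back to the per-provider default, and the old names keep working as aliases. A small sketch, assuming the module is importable as agent.providers.omni.types (path per the RECORD below):

    from agent.providers.omni.types import LLM, LLMProvider, Model, LLMModel, APIProvider

    m = LLM(provider=LLMProvider.GROQ)
    assert m.name == "deepseek-r1-distill-llama-70b"  # filled in by __post_init__
    assert Model is LLM and LLMModel is LLM           # backward-compatible aliases
    assert APIProvider is LLMProvider
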
cua_agent-0.1.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: cua-agent
- Version: 0.1.0
+ Version: 0.1.1
  Summary: CUA (Computer Use) Agent for AI-driven computer interaction
  Author-Email: TryCua <gh@trycua.com>
  Requires-Python: <3.13,>=3.10
cua_agent-0.1.1.dist-info/RECORD CHANGED
@@ -1,15 +1,15 @@
  agent/README.md,sha256=8EFnLrKejthEcL9bZflQSbvA-KwpiPanBz8TEEwRub8,2153
- agent/__init__.py,sha256=16Q828puFb7Ucq_-de49moVCzl1-iDO8Uo5dzFwX0Ag,347
+ agent/__init__.py,sha256=5IxjivBoXkpBQyPP3uwCrCoMx7gNZbM1rdVaIn2jxZ4,425
  agent/core/README.md,sha256=RY4kKEjm_-_Ul2xgY7ntzsXdPe0Tg1wvtOSZ4xp4DN0,3559
  agent/core/__init__.py,sha256=0htZ-VfsH9ixHB8j_SXu_uv6r3XXsq5TrghFNd-yRNE,709
- agent/core/agent.py,sha256=q2x0vFykIavX_FBi4Eq222QCSFmuuekAin4FPrtSGbY,11711
+ agent/core/agent.py,sha256=HSVNTiEhlDHMGJiA4CKi33TtwEdsMIvkgZ9nHfk2M8E,13730
  agent/core/base_agent.py,sha256=MgaMKTwgqNJ1-TgS_mxALoC9COzc7Acg9y7Q8HAFX2c,6266
  agent/core/callbacks.py,sha256=VbGIf5QkHh3Q0KsLM6wv7hRdIA5WExTVYLm64bckyUA,4306
  agent/core/computer_agent.py,sha256=JGLMl_PwImUttmQh2amdLlXHS9CUyZ9MW20J1Xid7dM,2417
  agent/core/experiment.py,sha256=AST1t83eqaGzjoW6KvrhfVIs3ELAR_I70VHq2NsMmNk,7446
  agent/core/factory.py,sha256=WraOEHWPXBSN4R3DO7M2ctyadodeA8tzHM3dUjdQ_3A,3441
  agent/core/loop.py,sha256=E-0pz7MaguZQrHs5GP98Oc8C_Iz8ier0vXrD9Ny2HL8,8999
- agent/core/messages.py,sha256=Ou0lLEwa2EQCartcTszsvNjCP6sHUxmr2_C9PGzbASg,7163
+ agent/core/messages.py,sha256=N8pV8Eh-AJpMuDPRI5OGWUIOU6DRr-pQjK9XU0go9Hk,7637
  agent/core/tools/__init__.py,sha256=xZen-PqUp2dUaMEHJowXCQm33_5Sxhsx9PSoD0rq6tI,489
  agent/core/tools/base.py,sha256=CdzRFNuOjNfzgyTUN4ZoCGkUDR5HI0ECQVpvrUdEij8,2295
  agent/core/tools/bash.py,sha256=jnJKVlHn8np8e0gWd8EO0_qqjMkfQzutSugA_Iol4jE,1585
@@ -43,8 +43,8 @@ agent/providers/omni/clients/openai.py,sha256=E4TAXMUFoYTunJETCWCNx5XAc6xutiN4rB
  agent/providers/omni/clients/utils.py,sha256=Ani9CVVBm_J2Dl51WG6p1GVuoI6cq8scISrG0pmQ37o,688
  agent/providers/omni/experiment.py,sha256=JGAdHi7Nf73I48c9k3TY1Xpr_i6D2VG1wurOzw5cNGk,9888
  agent/providers/omni/image_utils.py,sha256=qIFuNi5cIMVwrqYBXG1T6PxUlbxz7gIngFFP39bZIlU,2782
- agent/providers/omni/loop.py,sha256=Xr2QeedAVJ_jHn3KMopRuH3mrm2Qn4ncxKjqj9hWxAw,43577
- agent/providers/omni/messages.py,sha256=6LkQfzYDWq2FvIHpqhs5pc0l6AmFx_xKCjj1R5czMPo,6047
+ agent/providers/omni/loop.py,sha256=U1R_ayfN4T25hvbLMp97qeqSrqVtSL-U03G8Sqf4AaM,43827
+ agent/providers/omni/messages.py,sha256=zdjQCAMH-hOyrQQesHhTiIsQbw43KqVSmVIzS8JOIFA,6134
  agent/providers/omni/parser.py,sha256=Iv-cXWG2qzdYjyZJH5pGUzfv6nOaiHQ2OXdQSe00Ydw,9151
  agent/providers/omni/prompts.py,sha256=29qy8ppbLOjLil3aiqryjaiBf8CQx-xXHN44O-85Q00,4503
  agent/providers/omni/tool_manager.py,sha256=O6DxyEI-Vg6jt99phh011o4q4me_vNhH2YffIxkO4GM,2585
@@ -52,14 +52,14 @@ agent/providers/omni/tools/__init__.py,sha256=l636hx9Q5z9eaFdPanPwPENUE-w-Xm8kAZ
  agent/providers/omni/tools/bash.py,sha256=y_ibfP9iRcbiU_E0faAoa4DCP_BlkMlKOOURdBBIGZE,2030
  agent/providers/omni/tools/computer.py,sha256=xkMmAR0e_kbf0Zs2mggCDyWrQOJZyXOKPFjkutaQb94,9108
  agent/providers/omni/tools/manager.py,sha256=V_tav2yU92PyQnFlxNXG1wvNEaJoEYudtKx5sRjj06Q,2619
- agent/providers/omni/types.py,sha256=cEH6M5fcRN8ZIv_jfcYkTYboGBM4EzglLZo1_Xk7Ip8,800
+ agent/providers/omni/types.py,sha256=6x-n3MLvvKOFAdvzYDf6Zzw-i118kvWHXE37qxa_L4o,1284
  agent/providers/omni/utils.py,sha256=JqSye1bEp4wxhUgmaMyZi172fTlgXtygJ7XlnvKdUtE,6337
  agent/providers/omni/visualization.py,sha256=N3qVQLxYmia3iSVC5oCt5YRlMPuVfylCOyB99R33u8U,3924
  agent/types/__init__.py,sha256=61UFJT-w0CT4YRn0LiTx4A7fsMdVQjlXO9vnmbI1A7Y,604
  agent/types/base.py,sha256=rVb4mPWp1SOHfrzOCDqx0pfCV5bgIsdrIzgM_kX_xVs,1090
  agent/types/messages.py,sha256=4-hwtxeAhto90_EZpHFducddtsHUsHauvXzYrpKG4RE,953
  agent/types/tools.py,sha256=Jes2CFCFqC727WWHbO-sG7V03rBHnQe5X7Oi9ZkuScI,877
- cua_agent-0.1.0.dist-info/METADATA,sha256=Q4nPzYL_UQwx82vuaRLBUFmA_Sgd37TVoGA9FNYDRmU,1890
- cua_agent-0.1.0.dist-info/WHEEL,sha256=thaaA2w1JzcGC48WYufAs8nrYZjJm8LqNfnXFOFyCC4,90
- cua_agent-0.1.0.dist-info/entry_points.txt,sha256=6OYgBcLyFCUgeqLgnvMyOJxPCWzgy7se4rLPKtNonMs,34
- cua_agent-0.1.0.dist-info/RECORD,,
+ cua_agent-0.1.1.dist-info/METADATA,sha256=JdbHHQ7uBAlcnLZpZ1eWCiPCDODEOaB50j6XwtIs0Ss,1890
+ cua_agent-0.1.1.dist-info/WHEEL,sha256=thaaA2w1JzcGC48WYufAs8nrYZjJm8LqNfnXFOFyCC4,90
+ cua_agent-0.1.1.dist-info/entry_points.txt,sha256=6OYgBcLyFCUgeqLgnvMyOJxPCWzgy7se4rLPKtNonMs,34
+ cua_agent-0.1.1.dist-info/RECORD,,