cua-agent 0.1.22__py3-none-any.whl → 0.1.24__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- agent/__init__.py +1 -1
- agent/core/agent.py +9 -4
- agent/core/factory.py +3 -5
- agent/core/provider_config.py +4 -2
- agent/core/types.py +41 -1
- agent/providers/omni/__init__.py +1 -1
- agent/providers/omni/clients/oaicompat.py +177 -0
- agent/providers/omni/loop.py +25 -1
- agent/providers/omni/tools/manager.py +1 -1
- agent/ui/__init__.py +1 -0
- agent/ui/gradio/__init__.py +21 -0
- agent/ui/gradio/app.py +872 -0
- {cua_agent-0.1.22.dist-info → cua_agent-0.1.24.dist-info}/METADATA +74 -2
- {cua_agent-0.1.22.dist-info → cua_agent-0.1.24.dist-info}/RECORD +16 -14
- agent/core/README.md +0 -101
- agent/providers/omni/types.py +0 -47
- {cua_agent-0.1.22.dist-info → cua_agent-0.1.24.dist-info}/WHEEL +0 -0
- {cua_agent-0.1.22.dist-info → cua_agent-0.1.24.dist-info}/entry_points.txt +0 -0
agent/__init__.py
CHANGED
@@ -48,7 +48,7 @@ except Exception as e:
     # Other issues with telemetry
     logger.warning(f"Error initializing telemetry: {e}")
 
-from .
+from .core.types import LLMProvider, LLM
 from .core.factory import AgentLoop
 from .core.agent import ComputerAgent
 
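The net effect is that the provider types are re-exported from the package root rather than a provider-specific module. A minimal sketch of the import surface this enables, assuming the names imported above are reachable from the top-level `agent` package:

```python
# Sketch only: these names are imported in agent/__init__.py above,
# so they should be importable from the package root.
from agent import ComputerAgent, AgentLoop, LLM, LLMProvider

model = LLM(provider=LLMProvider.OAICOMPAT)  # defaults filled in by __post_init__
```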
agent/core/agent.py
CHANGED
@@ -6,8 +6,7 @@ import os
 from typing import AsyncGenerator, Optional
 
 from computer import Computer
-from
-from .. import AgentLoop
+from .types import LLM, AgentLoop
 from .types import AgentResponse
 from .factory import LoopFactory
 from .provider_config import DEFAULT_MODELS, ENV_VARS
@@ -75,6 +74,7 @@ class ComputerAgent:
             # Use the provided LLM object
             self.provider = model.provider
             actual_model_name = model.name or DEFAULT_MODELS.get(self.provider, "")
+            self.provider_base_url = getattr(model, "provider_base_url", None)
 
         # Ensure we have a valid model name
         if not actual_model_name:
@@ -86,8 +86,12 @@
 
         # Get API key from environment if not provided
         actual_api_key = api_key or os.environ.get(ENV_VARS[self.provider], "")
-        # Ollama
-        if
+        # Ollama and OpenAI-compatible APIs typically don't require an API key
+        if (
+            not actual_api_key
+            and str(self.provider) not in ["ollama", "oaicompat"]
+            and ENV_VARS[self.provider] != "none"
+        ):
             raise ValueError(f"No API key provided for {self.provider}")
 
         # Create the appropriate loop using the factory
@@ -102,6 +106,7 @@
                 save_trajectory=save_trajectory,
                 trajectory_dir=trajectory_dir,
                 only_n_most_recent_images=only_n_most_recent_images,
+                provider_base_url=self.provider_base_url,
             )
         except ValueError as e:
             logger.error(f"Failed to create loop: {str(e)}")
agent/core/factory.py
CHANGED
@@ -8,10 +8,6 @@ from computer import Computer
 from .types import AgentLoop
 from .base import BaseLoop
 
-# For type checking only
-if TYPE_CHECKING:
-    from ..providers.omni.types import LLMProvider
-
 logger = logging.getLogger(__name__)
 
 
@@ -33,6 +29,7 @@ class LoopFactory:
         trajectory_dir: str = "trajectories",
         only_n_most_recent_images: Optional[int] = None,
         acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None,
+        provider_base_url: Optional[str] = None,
     ) -> BaseLoop:
         """Create and return an appropriate loop instance based on type."""
         if loop_type == AgentLoop.ANTHROPIC:
@@ -77,7 +74,7 @@
             try:
                 from ..providers.omni.loop import OmniLoop
                 from ..providers.omni.parser import OmniParser
-                from
+                from .types import LLMProvider
             except ImportError:
                 raise ImportError(
                     "The 'omni' provider is not installed. "
@@ -99,6 +96,7 @@
                 base_dir=trajectory_dir,
                 only_n_most_recent_images=only_n_most_recent_images,
                 parser=OmniParser(),
+                provider_base_url=provider_base_url,
             )
         else:
             raise ValueError(f"Unsupported loop type: {loop_type}")
agent/core/provider_config.py
CHANGED
@@ -1,17 +1,19 @@
 """Provider-specific configurations and constants."""
 
-from
+from .types import LLMProvider
 
 # Default models for different providers
 DEFAULT_MODELS = {
     LLMProvider.OPENAI: "gpt-4o",
     LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
     LLMProvider.OLLAMA: "gemma3:4b-it-q4_K_M",
+    LLMProvider.OAICOMPAT: "Qwen2.5-VL-7B-Instruct",
 }
 
 # Map providers to their environment variable names
 ENV_VARS = {
     LLMProvider.OPENAI: "OPENAI_API_KEY",
     LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
-    LLMProvider.OLLAMA: "
+    LLMProvider.OLLAMA: "none",
+    LLMProvider.OAICOMPAT: "none",  # OpenAI-compatible API typically doesn't require an API key
 }
agent/core/types.py
CHANGED
@@ -1,7 +1,8 @@
 """Core type definitions."""
 
 from typing import Any, Dict, List, Optional, TypedDict, Union
-from enum import Enum, auto
+from enum import Enum, StrEnum, auto
+from dataclasses import dataclass
 
 
 class AgentLoop(Enum):
@@ -14,6 +15,45 @@ class AgentLoop(Enum):
     # Add more loop types as needed
 
 
+class LLMProvider(StrEnum):
+    """Supported LLM providers."""
+
+    ANTHROPIC = "anthropic"
+    OPENAI = "openai"
+    OLLAMA = "ollama"
+    OAICOMPAT = "oaicompat"
+
+
+@dataclass
+class LLM:
+    """Configuration for LLM model and provider."""
+
+    provider: LLMProvider
+    name: Optional[str] = None
+    provider_base_url: Optional[str] = None
+
+    def __post_init__(self):
+        """Set default model name if not provided."""
+        if self.name is None:
+            from .provider_config import DEFAULT_MODELS
+
+            self.name = DEFAULT_MODELS.get(self.provider)
+
+        # Set default provider URL if none provided
+        if self.provider_base_url is None and self.provider == LLMProvider.OAICOMPAT:
+            # Default for vLLM
+            self.provider_base_url = "http://localhost:8000/v1"
+            # Common alternatives:
+            # - LM Studio: "http://localhost:1234/v1"
+            # - LocalAI: "http://localhost:8080/v1"
+            # - Ollama with OpenAI compatible API: "http://localhost:11434/v1"
+
+
+# For backward compatibility
+LLMModel = LLM
+Model = LLM
+
+
 class AgentResponse(TypedDict, total=False):
     """Agent response format."""
 
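Note that `enum.StrEnum` only exists on Python 3.11+, so this import fails on older interpreters. Construction with and without explicit overrides would look like this, with the defaults coming from the `__post_init__` above (the LM Studio model id below is hypothetical):

```python
from agent.core.types import LLM, LLMProvider

local = LLM(provider=LLMProvider.OAICOMPAT)
print(local.name)               # "Qwen2.5-VL-7B-Instruct" (from DEFAULT_MODELS)
print(local.provider_base_url)  # "http://localhost:8000/v1" (vLLM default)

lmstudio = LLM(
    provider=LLMProvider.OAICOMPAT,
    name="qwen2.5-vl-7b-instruct",                 # hypothetical model id
    provider_base_url="http://localhost:1234/v1",  # LM Studio
)
```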
agent/providers/omni/clients/oaicompat.py
ADDED
@@ -0,0 +1,177 @@
+"""OpenAI-compatible client implementation."""
+
+import os
+import logging
+from typing import Dict, List, Optional, Any
+import aiohttp
+import re
+from .base import BaseOmniClient
+
+logger = logging.getLogger(__name__)
+
+
+# OpenAI-compatible client for the OmniLoop
+class OAICompatClient(BaseOmniClient):
+    """OpenAI-compatible API client implementation.
+
+    This client can be used with any service that implements the OpenAI API protocol, including:
+    - vLLM
+    - LM Studio
+    - LocalAI
+    - Ollama (with OpenAI compatibility)
+    - Text Generation WebUI
+    - Any other service with OpenAI API compatibility
+    """
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        model: str = "Qwen2.5-VL-7B-Instruct",
+        provider_base_url: Optional[str] = "http://localhost:8000/v1",
+        max_tokens: int = 4096,
+        temperature: float = 0.0,
+    ):
+        """Initialize the OpenAI-compatible client.
+
+        Args:
+            api_key: Not used for local endpoints, usually set to "EMPTY"
+            model: Model name to use
+            provider_base_url: API base URL. Typically in the format "http://localhost:PORT/v1"
+                Examples:
+                - vLLM: "http://localhost:8000/v1"
+                - LM Studio: "http://localhost:1234/v1"
+                - LocalAI: "http://localhost:8080/v1"
+                - Ollama: "http://localhost:11434/v1"
+            max_tokens: Maximum tokens to generate
+            temperature: Generation temperature
+        """
+        super().__init__(api_key="EMPTY", model=model)
+        self.api_key = "EMPTY"  # Local endpoints typically don't require an API key
+        self.model = model
+        self.provider_base_url = (
+            provider_base_url or "http://localhost:8000/v1"
+        )  # Use default if None
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+
+    def _extract_base64_image(self, text: str) -> Optional[str]:
+        """Extract base64 image data from an HTML img tag."""
+        pattern = r'data:image/[^;]+;base64,([^"]+)'
+        match = re.search(pattern, text)
+        return match.group(1) if match else None
+
+    def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Create a loggable version of messages with image data truncated."""
+        loggable_messages = []
+        for msg in messages:
+            if isinstance(msg.get("content"), list):
+                new_content = []
+                for content in msg["content"]:
+                    if content.get("type") == "image":
+                        new_content.append(
+                            {"type": "image", "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
+                        )
+                    else:
+                        new_content.append(content)
+                loggable_messages.append({"role": msg["role"], "content": new_content})
+            else:
+                loggable_messages.append(msg)
+        return loggable_messages
+
+    async def run_interleaved(
+        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
+    ) -> Dict[str, Any]:
+        """Run interleaved chat completion.
+
+        Args:
+            messages: List of message dicts
+            system: System prompt
+            max_tokens: Optional max tokens override
+
+        Returns:
+            Response dict
+        """
+        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
+
+        final_messages = [{"role": "system", "content": system}]
+
+        # Process messages
+        for item in messages:
+            if isinstance(item, dict):
+                if isinstance(item["content"], list):
+                    # Content is already in the correct format
+                    final_messages.append(item)
+                else:
+                    # Single string content, check for image
+                    base64_img = self._extract_base64_image(item["content"])
+                    if base64_img:
+                        message = {
+                            "role": item["role"],
+                            "content": [
+                                {
+                                    "type": "image_url",
+                                    "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
+                                }
+                            ],
+                        }
+                    else:
+                        message = {
+                            "role": item["role"],
+                            "content": [{"type": "text", "text": item["content"]}],
+                        }
+                    final_messages.append(message)
+            else:
+                # String content, check for image
+                base64_img = self._extract_base64_image(item)
+                if base64_img:
+                    message = {
+                        "role": "user",
+                        "content": [
+                            {
+                                "type": "image_url",
+                                "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
+                            }
+                        ],
+                    }
+                else:
+                    message = {"role": "user", "content": [{"type": "text", "text": item}]}
+                final_messages.append(message)
+
+        payload = {"model": self.model, "messages": final_messages, "temperature": self.temperature}
+        payload["max_tokens"] = max_tokens or self.max_tokens
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                # Use default base URL if none provided
+                base_url = self.provider_base_url or "http://localhost:8000/v1"
+
+                # Check if the base URL already includes the chat/completions endpoint
+                endpoint_url = base_url
+                if not endpoint_url.endswith("/chat/completions"):
+                    # If the URL ends with /v1, append /chat/completions
+                    if endpoint_url.endswith("/v1"):
+                        endpoint_url = f"{endpoint_url}/chat/completions"
+                    # If the URL doesn't end with /v1, make sure it has a proper structure
+                    elif not endpoint_url.endswith("/"):
+                        endpoint_url = f"{endpoint_url}/chat/completions"
+                    else:
+                        endpoint_url = f"{endpoint_url}chat/completions"
+
+                # Log the endpoint URL for debugging
+                logger.debug(f"Using endpoint URL: {endpoint_url}")
+
+                async with session.post(endpoint_url, headers=headers, json=payload) as response:
+                    response_json = await response.json()
+
+                    if response.status != 200:
+                        error_msg = response_json.get("error", {}).get(
+                            "message", str(response_json)
+                        )
+                        logger.error(f"Error in API call: {error_msg}")
+                        raise Exception(f"API error: {error_msg}")
+
+                    return response_json
+
+        except Exception as e:
+            logger.error(f"Error in API call: {str(e)}")
+            raise
agent/providers/omni/loop.py
CHANGED
@@ -16,10 +16,11 @@ from ...core.messages import StandardMessageManager, ImageRetentionConfig
 from .utils import to_openai_agent_response_format
 from ...core.types import AgentResponse
 from computer import Computer
-from .types import LLMProvider
+from ...core.types import LLMProvider
 from .clients.openai import OpenAIClient
 from .clients.anthropic import AnthropicClient
 from .clients.ollama import OllamaClient
+from .clients.oaicompat import OAICompatClient
 from .prompts import SYSTEM_PROMPT
 from .api_handler import OmniAPIHandler
 from .tools.manager import ToolManager
@@ -60,6 +61,7 @@ class OmniLoop(BaseLoop):
         max_retries: int = 3,
         retry_delay: float = 1.0,
         save_trajectory: bool = True,
+        provider_base_url: Optional[str] = None,
         **kwargs,
     ):
         """Initialize the loop.
@@ -75,10 +77,12 @@
             max_retries: Maximum number of retries for API calls
             retry_delay: Delay between retries in seconds
             save_trajectory: Whether to save trajectory data
+            provider_base_url: Base URL for the API provider (used for OAICOMPAT)
         """
         # Set parser and provider before initializing base class
         self.parser = parser
         self.provider = provider
+        self.provider_base_url = provider_base_url
 
         # Initialize message manager with image retention config
         self.message_manager = StandardMessageManager(
@@ -141,6 +145,12 @@
                 api_key=self.api_key,
                 model=self.model,
             )
+        elif self.provider == LLMProvider.OAICOMPAT:
+            self.client = OAICompatClient(
+                api_key="EMPTY",  # Local endpoints typically don't require an API key
+                model=self.model,
+                provider_base_url=self.provider_base_url,
+            )
         else:
             raise ValueError(f"Unsupported provider: {self.provider}")
 
@@ -171,6 +181,12 @@
                 api_key=self.api_key,
                 model=self.model,
             )
+        elif self.provider == LLMProvider.OAICOMPAT:
+            self.client = OAICompatClient(
+                api_key="EMPTY",  # Local endpoints typically don't require an API key
+                model=self.model,
+                provider_base_url=self.provider_base_url,
+            )
         else:
             raise ValueError(f"Unsupported provider: {self.provider}")
 
@@ -388,6 +404,14 @@
             except (KeyError, TypeError, IndexError) as e:
                 logger.error(f"Invalid response format: {str(e)}")
                 return True, action_screenshot_saved
+        elif self.provider == LLMProvider.OAICOMPAT:
+            try:
+                # OpenAI-compatible response format
+                raw_text = response["choices"][0]["message"]["content"]
+                standard_content = [{"type": "text", "text": raw_text}]
+            except (KeyError, TypeError, IndexError) as e:
+                logger.error(f"Invalid response format: {str(e)}")
+                return True, action_screenshot_saved
         else:
             # Assume OpenAI or compatible format
             try:
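The client selection added in the two middle hunks is the same dispatch repeated at two initialization points. A condensed sketch of just the new branch (the real method also constructs the OpenAI/Anthropic/Ollama clients shown in context; `make_client` is an illustrative name, not a method in the package):

```python
from agent.core.types import LLMProvider
from agent.providers.omni.clients.oaicompat import OAICompatClient


def make_client(provider: LLMProvider, model: str, provider_base_url: str | None):
    """Condensed sketch of the OAICOMPAT branch added above."""
    if provider == LLMProvider.OAICOMPAT:
        return OAICompatClient(
            api_key="EMPTY",  # local endpoints typically ignore the key
            model=model,
            provider_base_url=provider_base_url,
        )
    raise ValueError(f"Unsupported provider: {provider}")
```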
agent/providers/omni/tools/manager.py
CHANGED

@@ -7,7 +7,7 @@ from ....core.tools import BaseToolManager, ToolResult
 from ....core.tools.collection import ToolCollection
 from .computer import ComputerTool
 from .bash import BashTool
-from
+from ....core.types import LLMProvider
 
 
 class ToolManager(BaseToolManager):
agent/ui/__init__.py
ADDED
@@ -0,0 +1 @@
+"""UI modules for the Computer-Use Agent."""
agent/ui/gradio/__init__.py
ADDED

@@ -0,0 +1,21 @@
+"""Gradio UI for Computer-Use Agent."""
+
+import gradio as gr
+from typing import Optional
+
+from .app import create_gradio_ui
+
+
+def registry(name: str = "cua:gpt-4o") -> gr.Blocks:
+    """Create and register a Gradio UI for the Computer-Use Agent.
+
+    Args:
+        name: The name to use for the Gradio app, in format 'provider:model'
+
+    Returns:
+        A Gradio Blocks application
+    """
+    provider, model = name.split(":", 1) if ":" in name else ("openai", name)
+
+    # Create and return the Gradio UI
+    return create_gradio_ui(provider_name=provider, model_name=model)
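Launching the UI is then a one-liner; `launch()` is the standard `gr.Blocks` entry point. Which provider strings `create_gradio_ui` accepts is defined in the 872-line `app.py`, which this view does not show, so the provider name below is an assumption based on the `'provider:model'` convention documented above:

```python
from agent.ui.gradio import registry

demo = registry("openai:gpt-4o")  # "provider:model"; bare names default to "openai"
demo.launch()
```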
|