cua-agent 0.1.21__py3-none-any.whl → 0.1.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cua-agent might be problematic.
- agent/__init__.py +1 -1
- agent/core/agent.py +9 -3
- agent/core/factory.py +3 -5
- agent/core/provider_config.py +5 -1
- agent/core/types.py +59 -1
- agent/providers/omni/__init__.py +1 -1
- agent/providers/omni/clients/base.py +8 -17
- agent/providers/omni/clients/oaicompat.py +177 -0
- agent/providers/omni/clients/ollama.py +122 -0
- agent/providers/omni/clients/openai.py +0 -4
- agent/providers/omni/loop.py +43 -1
- agent/providers/omni/tools/manager.py +1 -1
- agent/ui/__init__.py +1 -0
- agent/ui/gradio/__init__.py +21 -0
- agent/ui/gradio/app.py +872 -0
- {cua_agent-0.1.21.dist-info → cua_agent-0.1.23.dist-info}/METADATA +67 -3
- {cua_agent-0.1.21.dist-info → cua_agent-0.1.23.dist-info}/RECORD +19 -16
- agent/core/README.md +0 -101
- agent/providers/omni/types.py +0 -44
- {cua_agent-0.1.21.dist-info → cua_agent-0.1.23.dist-info}/WHEEL +0 -0
- {cua_agent-0.1.21.dist-info → cua_agent-0.1.23.dist-info}/entry_points.txt +0 -0
agent/__init__.py
CHANGED
@@ -48,7 +48,7 @@ except Exception as e:
     # Other issues with telemetry
     logger.warning(f"Error initializing telemetry: {e}")

-from .
+from .core.types import LLMProvider, LLM
 from .core.factory import AgentLoop
 from .core.agent import ComputerAgent

agent/core/agent.py
CHANGED
@@ -6,8 +6,7 @@ import os
 from typing import AsyncGenerator, Optional

 from computer import Computer
-from
-from .. import AgentLoop
+from .types import LLM, AgentLoop
 from .types import AgentResponse
 from .factory import LoopFactory
 from .provider_config import DEFAULT_MODELS, ENV_VARS
@@ -75,6 +74,7 @@ class ComputerAgent:
             # Use the provided LLM object
             self.provider = model.provider
             actual_model_name = model.name or DEFAULT_MODELS.get(self.provider, "")
+            self.provider_base_url = getattr(model, "provider_base_url", None)

         # Ensure we have a valid model name
         if not actual_model_name:
@@ -86,7 +86,12 @@ class ComputerAgent:

         # Get API key from environment if not provided
         actual_api_key = api_key or os.environ.get(ENV_VARS[self.provider], "")
-
+        # Ollama and OpenAI-compatible APIs typically don't require an API key
+        if (
+            not actual_api_key
+            and str(self.provider) not in ["ollama", "oaicompat"]
+            and ENV_VARS[self.provider] != "none"
+        ):
             raise ValueError(f"No API key provided for {self.provider}")

         # Create the appropriate loop using the factory
@@ -101,6 +106,7 @@ class ComputerAgent:
                 save_trajectory=save_trajectory,
                 trajectory_dir=trajectory_dir,
                 only_n_most_recent_images=only_n_most_recent_images,
+                provider_base_url=self.provider_base_url,
             )
         except ValueError as e:
             logger.error(f"Failed to create loop: {str(e)}")
agent/core/factory.py
CHANGED
@@ -8,10 +8,6 @@ from computer import Computer
 from .types import AgentLoop
 from .base import BaseLoop

-# For type checking only
-if TYPE_CHECKING:
-    from ..providers.omni.types import LLMProvider
-
 logger = logging.getLogger(__name__)


@@ -33,6 +29,7 @@ class LoopFactory:
         trajectory_dir: str = "trajectories",
         only_n_most_recent_images: Optional[int] = None,
         acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None,
+        provider_base_url: Optional[str] = None,
     ) -> BaseLoop:
         """Create and return an appropriate loop instance based on type."""
         if loop_type == AgentLoop.ANTHROPIC:
@@ -77,7 +74,7 @@ class LoopFactory:
            try:
                from ..providers.omni.loop import OmniLoop
                from ..providers.omni.parser import OmniParser
-               from
+               from .types import LLMProvider
            except ImportError:
                raise ImportError(
                    "The 'omni' provider is not installed. "
@@ -99,6 +96,7 @@ class LoopFactory:
                base_dir=trajectory_dir,
                only_n_most_recent_images=only_n_most_recent_images,
                parser=OmniParser(),
+               provider_base_url=provider_base_url,
            )
        else:
            raise ValueError(f"Unsupported loop type: {loop_type}")
agent/core/provider_config.py
CHANGED
@@ -1,15 +1,19 @@
 """Provider-specific configurations and constants."""

-from
+from .types import LLMProvider

 # Default models for different providers
 DEFAULT_MODELS = {
     LLMProvider.OPENAI: "gpt-4o",
     LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
+    LLMProvider.OLLAMA: "gemma3:4b-it-q4_K_M",
+    LLMProvider.OAICOMPAT: "Qwen2.5-VL-7B-Instruct",
 }

 # Map providers to their environment variable names
 ENV_VARS = {
     LLMProvider.OPENAI: "OPENAI_API_KEY",
     LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
+    LLMProvider.OLLAMA: "none",
+    LLMProvider.OAICOMPAT: "none",  # OpenAI-compatible API typically doesn't require an API key
 }
agent/core/types.py
CHANGED
@@ -1,7 +1,8 @@
 """Core type definitions."""

 from typing import Any, Dict, List, Optional, TypedDict, Union
-from enum import Enum, auto
+from enum import Enum, StrEnum, auto
+from dataclasses import dataclass


 class AgentLoop(Enum):
@@ -10,9 +11,66 @@ class AgentLoop(Enum):
     ANTHROPIC = auto()  # Anthropic implementation
     OMNI = auto()  # OmniLoop implementation
     OPENAI = auto()  # OpenAI implementation
+    OLLAMA = auto()  # OLLAMA implementation
     # Add more loop types as needed


+class LLMProvider(StrEnum):
+    """Supported LLM providers."""
+
+    ANTHROPIC = "anthropic"
+    OPENAI = "openai"
+    OLLAMA = "ollama"
+    OAICOMPAT = "oaicompat"
+
+
+@dataclass
+class LLM:
+    """Configuration for LLM model and provider."""
+
+    provider: LLMProvider
+    name: Optional[str] = None
+    provider_base_url: Optional[str] = None
+
+    def __post_init__(self):
+        """Set default model name if not provided."""
+        if self.name is None:
+            from .provider_config import DEFAULT_MODELS
+
+            self.name = DEFAULT_MODELS.get(self.provider)
+
+        # Set default provider URL if none provided
+        if self.provider_base_url is None and self.provider == LLMProvider.OAICOMPAT:
+            # Default for vLLM
+            self.provider_base_url = "http://localhost:8000/v1"
+            # Common alternatives:
+            # - LM Studio: "http://localhost:1234/v1"
+            # - LocalAI: "http://localhost:8080/v1"
+            # - Ollama with OpenAI compatible API: "http://localhost:11434/v1"
+
+
+# For backward compatibility
+LLMModel = LLM
+Model = LLM
+
+
+# Default models for each provider
+PROVIDER_TO_DEFAULT_MODEL: Dict[LLMProvider, str] = {
+    LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
+    LLMProvider.OPENAI: "gpt-4o",
+    LLMProvider.OLLAMA: "gemma3:4b-it-q4_K_M",
+    LLMProvider.OAICOMPAT: "Qwen2.5-VL-7B-Instruct",
+}
+
+# Environment variable names for each provider
+PROVIDER_TO_ENV_VAR: Dict[LLMProvider, str] = {
+    LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
+    LLMProvider.OPENAI: "OPENAI_API_KEY",
+    LLMProvider.OLLAMA: "none",
+    LLMProvider.OAICOMPAT: "none",
+}
+
+
 class AgentResponse(TypedDict, total=False):
     """Agent response format."""

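A short sketch of how the new LLM dataclass fills in defaults via __post_init__, pulling DEFAULT_MODELS from provider_config.py shown earlier (the expected values follow directly from those tables; note that StrEnum requires Python 3.11+):

    from agent import LLM, LLMProvider

    # Name and base URL omitted: __post_init__ falls back to provider defaults.
    cfg = LLM(provider=LLMProvider.OAICOMPAT)
    print(cfg.name)               # "Qwen2.5-VL-7B-Instruct" (from DEFAULT_MODELS)
    print(cfg.provider_base_url)  # "http://localhost:8000/v1" (vLLM default)

    # Ollama gets a default model name but no base URL override.
    print(LLM(provider=LLMProvider.OLLAMA).name)               # "gemma3:4b-it-q4_K_M"
    print(LLM(provider=LLMProvider.OLLAMA).provider_base_url)  # None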
agent/providers/omni/clients/base.py
CHANGED
@@ -1,43 +1,34 @@
 """Base client implementation for Omni providers."""

-import os
 import logging
 from typing import Dict, List, Optional, Any, Tuple
-import aiohttp
-import json

 logger = logging.getLogger(__name__)

+
 class BaseOmniClient:
     """Base class for provider-specific clients."""
-
-    def __init__(
-        self,
-        api_key: Optional[str] = None,
-        model: Optional[str] = None
-    ):
+
+    def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
         """Initialize base client.
-
+
         Args:
             api_key: Optional API key
             model: Optional model name
         """
         self.api_key = api_key
         self.model = model
-
+
     async def run_interleaved(
-        self,
-        messages: List[Dict[str, Any]],
-        system: str,
-        max_tokens: Optional[int] = None
+        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
     ) -> Dict[str, Any]:
         """Run interleaved chat completion.
-
+
         Args:
             messages: List of message dicts
             system: System prompt
             max_tokens: Optional max tokens override
-
+
         Returns:
             Response dict
         """
agent/providers/omni/clients/oaicompat.py
ADDED
@@ -0,0 +1,177 @@
+"""OpenAI-compatible client implementation."""
+
+import os
+import logging
+from typing import Dict, List, Optional, Any
+import aiohttp
+import re
+from .base import BaseOmniClient
+
+logger = logging.getLogger(__name__)
+
+
+# OpenAI-compatible client for the OmniLoop
+class OAICompatClient(BaseOmniClient):
+    """OpenAI-compatible API client implementation.
+
+    This client can be used with any service that implements the OpenAI API protocol, including:
+    - vLLM
+    - LM Studio
+    - LocalAI
+    - Ollama (with OpenAI compatibility)
+    - Text Generation WebUI
+    - Any other service with OpenAI API compatibility
+    """
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        model: str = "Qwen2.5-VL-7B-Instruct",
+        provider_base_url: Optional[str] = "http://localhost:8000/v1",
+        max_tokens: int = 4096,
+        temperature: float = 0.0,
+    ):
+        """Initialize the OpenAI-compatible client.
+
+        Args:
+            api_key: Not used for local endpoints, usually set to "EMPTY"
+            model: Model name to use
+            provider_base_url: API base URL. Typically in the format "http://localhost:PORT/v1"
+                Examples:
+                - vLLM: "http://localhost:8000/v1"
+                - LM Studio: "http://localhost:1234/v1"
+                - LocalAI: "http://localhost:8080/v1"
+                - Ollama: "http://localhost:11434/v1"
+            max_tokens: Maximum tokens to generate
+            temperature: Generation temperature
+        """
+        super().__init__(api_key="EMPTY", model=model)
+        self.api_key = "EMPTY"  # Local endpoints typically don't require an API key
+        self.model = model
+        self.provider_base_url = (
+            provider_base_url or "http://localhost:8000/v1"
+        )  # Use default if None
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+
+    def _extract_base64_image(self, text: str) -> Optional[str]:
+        """Extract base64 image data from an HTML img tag."""
+        pattern = r'data:image/[^;]+;base64,([^"]+)'
+        match = re.search(pattern, text)
+        return match.group(1) if match else None
+
+    def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Create a loggable version of messages with image data truncated."""
+        loggable_messages = []
+        for msg in messages:
+            if isinstance(msg.get("content"), list):
+                new_content = []
+                for content in msg["content"]:
+                    if content.get("type") == "image":
+                        new_content.append(
+                            {"type": "image", "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
+                        )
+                    else:
+                        new_content.append(content)
+                loggable_messages.append({"role": msg["role"], "content": new_content})
+            else:
+                loggable_messages.append(msg)
+        return loggable_messages
+
+    async def run_interleaved(
+        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
+    ) -> Dict[str, Any]:
+        """Run interleaved chat completion.
+
+        Args:
+            messages: List of message dicts
+            system: System prompt
+            max_tokens: Optional max tokens override
+
+        Returns:
+            Response dict
+        """
+        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
+
+        final_messages = [{"role": "system", "content": system}]
+
+        # Process messages
+        for item in messages:
+            if isinstance(item, dict):
+                if isinstance(item["content"], list):
+                    # Content is already in the correct format
+                    final_messages.append(item)
+                else:
+                    # Single string content, check for image
+                    base64_img = self._extract_base64_image(item["content"])
+                    if base64_img:
+                        message = {
+                            "role": item["role"],
+                            "content": [
+                                {
+                                    "type": "image_url",
+                                    "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
+                                }
+                            ],
+                        }
+                    else:
+                        message = {
+                            "role": item["role"],
+                            "content": [{"type": "text", "text": item["content"]}],
+                        }
+                    final_messages.append(message)
+            else:
+                # String content, check for image
+                base64_img = self._extract_base64_image(item)
+                if base64_img:
+                    message = {
+                        "role": "user",
+                        "content": [
+                            {
+                                "type": "image_url",
+                                "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
+                            }
+                        ],
+                    }
+                else:
+                    message = {"role": "user", "content": [{"type": "text", "text": item}]}
+                final_messages.append(message)
+
+        payload = {"model": self.model, "messages": final_messages, "temperature": self.temperature}
+        payload["max_tokens"] = max_tokens or self.max_tokens
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                # Use default base URL if none provided
+                base_url = self.provider_base_url or "http://localhost:8000/v1"
+
+                # Check if the base URL already includes the chat/completions endpoint
+                endpoint_url = base_url
+                if not endpoint_url.endswith("/chat/completions"):
+                    # If the URL ends with /v1, append /chat/completions
+                    if endpoint_url.endswith("/v1"):
+                        endpoint_url = f"{endpoint_url}/chat/completions"
+                    # If the URL doesn't end with /v1, make sure it has a proper structure
+                    elif not endpoint_url.endswith("/"):
+                        endpoint_url = f"{endpoint_url}/chat/completions"
+                    else:
+                        endpoint_url = f"{endpoint_url}chat/completions"
+
+                # Log the endpoint URL for debugging
+                logger.debug(f"Using endpoint URL: {endpoint_url}")
+
+                async with session.post(endpoint_url, headers=headers, json=payload) as response:
+                    response_json = await response.json()
+
+                    if response.status != 200:
+                        error_msg = response_json.get("error", {}).get(
+                            "message", str(response_json)
+                        )
+                        logger.error(f"Error in API call: {error_msg}")
+                        raise Exception(f"API error: {error_msg}")
+
+                    return response_json
+
+        except Exception as e:
+            logger.error(f"Error in API call: {str(e)}")
+            raise
agent/providers/omni/clients/ollama.py
ADDED
@@ -0,0 +1,122 @@
+"""Ollama API client implementation."""
+
+import logging
+from typing import Any, Dict, List, Optional, Tuple, cast
+import asyncio
+from httpx import ConnectError, ReadTimeout
+
+from ollama import AsyncClient, Options
+from ollama import Message
+from .base import BaseOmniClient
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaClient(BaseOmniClient):
+    """Client for making calls to Ollama API."""
+
+    def __init__(self, api_key: str, model: str, max_retries: int = 3, retry_delay: float = 1.0):
+        """Initialize the Ollama client.
+
+        Args:
+            api_key: Not used
+            model: Ollama model name (e.g. "gemma3:4b-it-q4_K_M")
+            max_retries: Maximum number of retries for API calls
+            retry_delay: Base delay between retries in seconds
+        """
+        if not model:
+            raise ValueError("Model name must be provided")
+
+        self.client = AsyncClient(
+            host="http://localhost:11434",
+        )
+        self.model: str = model  # Add explicit type annotation
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
+
+    def _convert_message_format(self, system: str, messages: List[Dict[str, Any]]) -> List[Any]:
+        """Convert messages from standard format to Ollama format.
+
+        Args:
+            messages: Messages in standard format
+
+        Returns:
+            Messages in Ollama format
+        """
+        ollama_messages = []
+
+        # Add system message
+        ollama_messages.append(
+            {
+                "role": "system",
+                "content": system,
+            }
+        )
+
+        for message in messages:
+            # Skip messages with empty content
+            if not message.get("content"):
+                continue
+            content = message.get("content", [{}])[0]
+            isImage = content.get("type", "") == "image_url"
+            isText = content.get("type", "") == "text"
+            if isText:
+                data = content.get("text", "")
+                ollama_messages.append({"role": message["role"], "content": data})
+            if isImage:
+                data = content.get("image_url", {}).get("url", "")
+                # remove header
+                data = data.removeprefix("data:image/png;base64,")
+                ollama_messages.append(
+                    {"role": message["role"], "content": "Use this image", "images": [data]}
+                )
+
+        # Cast the list to the correct type expected by Ollama
+        return cast(List[Any], ollama_messages)
+
+    async def run_interleaved(
+        self, messages: List[Dict[str, Any]], system: str, max_tokens: int
+    ) -> Any:
+        """Run model with interleaved conversation format.
+
+        Args:
+            messages: List of messages to process
+            system: System prompt
+            max_tokens: Not used
+
+        Returns:
+            Model response
+        """
+        last_error = None
+
+        for attempt in range(self.max_retries):
+            try:
+                # Convert messages to Ollama format
+                ollama_messages = self._convert_message_format(system, messages)
+
+                response = await self.client.chat(
+                    model=self.model,
+                    options=Options(
+                        temperature=0,
+                    ),
+                    messages=ollama_messages,
+                    format="json",
+                )
+
+                return response
+
+            except (ConnectError, ReadTimeout) as e:
+                last_error = e
+                logger.warning(
+                    f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
+                )
+                if attempt < self.max_retries - 1:
+                    await asyncio.sleep(self.retry_delay * (attempt + 1))  # Exponential backoff
+                    continue
+
+            except Exception as e:
+                logger.error(f"Unexpected error in Ollama API call: {str(e)}")
+                raise RuntimeError(f"Ollama API call failed: {str(e)}")
+
+        # If we get here, all retries failed
+        raise RuntimeError(f"Connection error after {self.max_retries} retries: {str(last_error)}")
agent/providers/omni/loop.py
CHANGED
@@ -16,9 +16,11 @@ from ...core.messages import StandardMessageManager, ImageRetentionConfig
 from .utils import to_openai_agent_response_format
 from ...core.types import AgentResponse
 from computer import Computer
-from .types import LLMProvider
+from ...core.types import LLMProvider
 from .clients.openai import OpenAIClient
 from .clients.anthropic import AnthropicClient
+from .clients.ollama import OllamaClient
+from .clients.oaicompat import OAICompatClient
 from .prompts import SYSTEM_PROMPT
 from .api_handler import OmniAPIHandler
 from .tools.manager import ToolManager
@@ -59,6 +61,7 @@ class OmniLoop(BaseLoop):
         max_retries: int = 3,
         retry_delay: float = 1.0,
         save_trajectory: bool = True,
+        provider_base_url: Optional[str] = None,
         **kwargs,
     ):
         """Initialize the loop.
@@ -74,10 +77,12 @@ class OmniLoop(BaseLoop):
             max_retries: Maximum number of retries for API calls
             retry_delay: Delay between retries in seconds
             save_trajectory: Whether to save trajectory data
+            provider_base_url: Base URL for the API provider (used for OAICOMPAT)
         """
         # Set parser and provider before initializing base class
         self.parser = parser
         self.provider = provider
+        self.provider_base_url = provider_base_url

         # Initialize message manager with image retention config
         self.message_manager = StandardMessageManager(
@@ -135,6 +140,17 @@ class OmniLoop(BaseLoop):
                 api_key=self.api_key,
                 model=self.model,
             )
+        elif self.provider == LLMProvider.OLLAMA:
+            self.client = OllamaClient(
+                api_key=self.api_key,
+                model=self.model,
+            )
+        elif self.provider == LLMProvider.OAICOMPAT:
+            self.client = OAICompatClient(
+                api_key="EMPTY",  # Local endpoints typically don't require an API key
+                model=self.model,
+                provider_base_url=self.provider_base_url,
+            )
         else:
             raise ValueError(f"Unsupported provider: {self.provider}")

@@ -160,6 +176,17 @@ class OmniLoop(BaseLoop):
                 max_retries=self.max_retries,
                 retry_delay=self.retry_delay,
             )
+        elif self.provider == LLMProvider.OLLAMA:
+            self.client = OllamaClient(
+                api_key=self.api_key,
+                model=self.model,
+            )
+        elif self.provider == LLMProvider.OAICOMPAT:
+            self.client = OAICompatClient(
+                api_key="EMPTY",  # Local endpoints typically don't require an API key
+                model=self.model,
+                provider_base_url=self.provider_base_url,
+            )
         else:
             raise ValueError(f"Unsupported provider: {self.provider}")

@@ -370,6 +397,21 @@ class OmniLoop(BaseLoop):
            else:
                logger.warning("Invalid Anthropic response format")
                return True, action_screenshot_saved
+        elif self.provider == LLMProvider.OLLAMA:
+            try:
+                raw_text = response["message"]["content"]
+                standard_content = [{"type": "text", "text": raw_text}]
+            except (KeyError, TypeError, IndexError) as e:
+                logger.error(f"Invalid response format: {str(e)}")
+                return True, action_screenshot_saved
+        elif self.provider == LLMProvider.OAICOMPAT:
+            try:
+                # OpenAI-compatible response format
+                raw_text = response["choices"][0]["message"]["content"]
+                standard_content = [{"type": "text", "text": raw_text}]
+            except (KeyError, TypeError, IndexError) as e:
+                logger.error(f"Invalid response format: {str(e)}")
+                return True, action_screenshot_saved
         else:
             # Assume OpenAI or compatible format
             try:
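The two new response branches differ only in where the assistant text lives in the raw response. A small illustrative helper, not part of the package, showing the shapes the loop expects:

    from typing import Any, Dict, List

    def extract_text(provider: str, response: Dict[str, Any]) -> List[Dict[str, str]]:
        """Mirror of the provider branches above, for illustration only."""
        if provider == "ollama":
            # ollama-python chat responses nest the text under "message".
            raw_text = response["message"]["content"]
        else:
            # OpenAI-compatible servers use the standard chat completions shape.
            raw_text = response["choices"][0]["message"]["content"]
        return [{"type": "text", "text": raw_text}]

    print(extract_text("ollama", {"message": {"content": "click(100, 200)"}}))
    print(extract_text("oaicompat", {"choices": [{"message": {"content": "done"}}]}))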
agent/providers/omni/tools/manager.py
CHANGED
@@ -7,7 +7,7 @@ from ....core.tools import BaseToolManager, ToolResult
 from ....core.tools.collection import ToolCollection
 from .computer import ComputerTool
 from .bash import BashTool
-from
+from ....core.types import LLMProvider


 class ToolManager(BaseToolManager):
agent/ui/__init__.py
ADDED
@@ -0,0 +1 @@
+"""UI modules for the Computer-Use Agent."""
agent/ui/gradio/__init__.py
ADDED
@@ -0,0 +1,21 @@
+"""Gradio UI for Computer-Use Agent."""
+
+import gradio as gr
+from typing import Optional
+
+from .app import create_gradio_ui
+
+
+def registry(name: str = "cua:gpt-4o") -> gr.Blocks:
+    """Create and register a Gradio UI for the Computer-Use Agent.
+
+    Args:
+        name: The name to use for the Gradio app, in format 'provider:model'
+
+    Returns:
+        A Gradio Blocks application
+    """
+    provider, model = name.split(":", 1) if ":" in name else ("openai", name)
+
+    # Create and return the Gradio UI
+    return create_gradio_ui(provider_name=provider, model_name=model)