cua-agent 0.1.29__tar.gz → 0.1.30__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- {cua_agent-0.1.29 → cua_agent-0.1.30}/PKG-INFO +12 -18
- {cua_agent-0.1.29 → cua_agent-0.1.30}/README.md +9 -17
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/factory.py +19 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/types.py +1 -0
- cua_agent-0.1.30/agent/providers/uitars/__init__.py +1 -0
- cua_agent-0.1.30/agent/providers/uitars/clients/base.py +35 -0
- cua_agent-0.1.30/agent/providers/uitars/clients/oaicompat.py +204 -0
- cua_agent-0.1.30/agent/providers/uitars/loop.py +595 -0
- cua_agent-0.1.30/agent/providers/uitars/prompts.py +59 -0
- cua_agent-0.1.30/agent/providers/uitars/tools/__init__.py +1 -0
- cua_agent-0.1.30/agent/providers/uitars/tools/computer.py +279 -0
- cua_agent-0.1.30/agent/providers/uitars/tools/manager.py +60 -0
- cua_agent-0.1.30/agent/providers/uitars/utils.py +153 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/ui/gradio/app.py +12 -2
- {cua_agent-0.1.29 → cua_agent-0.1.30}/pyproject.toml +6 -3
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/__init__.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/__init__.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/agent.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/base.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/callbacks.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/experiment.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/messages.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/provider_config.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/telemetry.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/tools/__init__.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/tools/base.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/tools/bash.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/tools/collection.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/tools/computer.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/tools/edit.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/tools/manager.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/tools.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/core/visualization.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/__init__.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/__init__.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/api/client.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/api/logging.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/api_handler.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/callbacks/__init__.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/callbacks/manager.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/loop.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/prompts.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/response_handler.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/tools/__init__.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/tools/base.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/tools/bash.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/tools/collection.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/tools/computer.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/tools/edit.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/tools/manager.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/tools/run.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/types.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/anthropic/utils.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/__init__.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/api_handler.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/clients/anthropic.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/clients/base.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/clients/oaicompat.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/clients/ollama.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/clients/openai.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/clients/utils.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/image_utils.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/loop.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/parser.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/prompts.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/tools/__init__.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/tools/base.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/tools/bash.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/tools/computer.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/tools/manager.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/omni/utils.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/openai/__init__.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/openai/api_handler.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/openai/loop.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/openai/response_handler.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/openai/tools/__init__.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/openai/tools/base.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/openai/tools/computer.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/openai/tools/manager.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/openai/types.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/providers/openai/utils.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/telemetry.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/ui/__init__.py +0 -0
- {cua_agent-0.1.29 → cua_agent-0.1.30}/agent/ui/gradio/__init__.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: cua-agent
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.30
|
|
4
4
|
Summary: CUA (Computer Use) Agent for AI-driven computer interaction
|
|
5
5
|
Author-Email: TryCua <gh@trycua.com>
|
|
6
6
|
Requires-Python: >=3.10
|
|
@@ -21,6 +21,8 @@ Requires-Dist: boto3<2.0.0,>=1.35.81; extra == "anthropic"
|
|
|
21
21
|
Provides-Extra: openai
|
|
22
22
|
Requires-Dist: openai<2.0.0,>=1.14.0; extra == "openai"
|
|
23
23
|
Requires-Dist: httpx<0.29.0,>=0.27.0; extra == "openai"
|
|
24
|
+
Provides-Extra: uitars
|
|
25
|
+
Requires-Dist: httpx<0.29.0,>=0.27.0; extra == "uitars"
|
|
24
26
|
Provides-Extra: ui
|
|
25
27
|
Requires-Dist: gradio<6.0.0,>=5.23.3; extra == "ui"
|
|
26
28
|
Requires-Dist: python-dotenv<2.0.0,>=1.0.1; extra == "ui"
|
|
@@ -118,6 +120,9 @@ async with Computer() as macos_computer:
|
|
|
118
120
|
# or
|
|
119
121
|
# loop=AgentLoop.OMNI,
|
|
120
122
|
# model=LLM(provider=LLMProvider.OLLAMA, model="gemma3")
|
|
123
|
+
# or
|
|
124
|
+
# loop=AgentLoop.UITARS,
|
|
125
|
+
# model=LLM(provider=LLMProvider.OAICOMPAT, model="tgi", provider_base_url="https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1")
|
|
121
126
|
)
|
|
122
127
|
|
|
123
128
|
tasks = [
|
|
@@ -192,6 +197,10 @@ The Gradio UI provides:
|
|
|
192
197
|
- Configuration of agent parameters
|
|
193
198
|
- Chat interface for interacting with the agent
|
|
194
199
|
|
|
200
|
+
### Using UI-TARS
|
|
201
|
+
|
|
202
|
+
You can use UI-TARS by first following the [deployment guide](https://github.com/bytedance/UI-TARS/blob/main/README_deploy.md). This will give you a provider URL like this: `https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1` which you can use in the gradio UI.
|
|
203
|
+
|
|
195
204
|
## Agent Loops
|
|
196
205
|
|
|
197
206
|
The `cua-agent` package provides four agent loop variations, based on different CUA model providers and techniques:
|
|
@@ -200,6 +209,7 @@ The `cua-agent` package provides three agent loops variations, based on differen
|
|
|
200
209
|
|:-----------|:-----------------|:------------|:-------------|
|
|
201
210
|
| `AgentLoop.OPENAI` | • `computer_use_preview` | Use OpenAI Operator CUA model | Not Required |
|
|
202
211
|
| `AgentLoop.ANTHROPIC` | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219` | Use Anthropic Computer-Use | Not Required |
|
|
212
|
+
| `AgentLoop.UITARS` | • `ByteDance-Seed/UI-TARS-1.5-7B` | Uses ByteDance's UI-TARS 1.5 model | Not Required |
|
|
203
213
|
| `AgentLoop.OMNI` | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219`<br>• `gpt-4.5-preview`<br>• `gpt-4o`<br>• `gpt-4`<br>• `phi4`<br>• `phi4-mini`<br>• `gemma3`<br>• `...`<br>• `Any Ollama or OpenAI-compatible model` | Use OmniParser for element pixel-detection (SoM) and any VLMs for UI Grounding and Reasoning | OmniParser |
|
|
204
214
|
|
|
205
215
|
## AgentResponse
|
|
@@ -241,25 +251,9 @@ async for result in agent.run(task):
|
|
|
241
251
|
print(output)
|
|
242
252
|
```
|
|
243
253
|
|
|
244
|
-
### Gradio UI
|
|
245
|
-
|
|
246
|
-
You can also interact with the agent using a Gradio interface.
|
|
247
|
-
|
|
248
|
-
```python
|
|
249
|
-
# Ensure environment variables (e.g., API keys) are loaded
|
|
250
|
-
# You might need a helper function like load_dotenv_files() if using .env
|
|
251
|
-
# from utils import load_dotenv_files
|
|
252
|
-
# load_dotenv_files()
|
|
253
|
-
|
|
254
|
-
from agent.ui.gradio.app import create_gradio_ui
|
|
255
|
-
|
|
256
|
-
app = create_gradio_ui()
|
|
257
|
-
app.launch(share=False)
|
|
258
|
-
```
|
|
259
|
-
|
|
260
254
|
**Note on Settings Persistence:**
|
|
261
255
|
|
|
262
256
|
* The Gradio UI automatically saves your configuration (Agent Loop, Model Choice, Custom Base URL, Save Trajectory state, Recent Images count) to a file named `.gradio_settings.json` in the project's root directory when you successfully run a task.
|
|
263
257
|
* This allows your preferences to persist between sessions.
|
|
264
258
|
* API keys entered into the custom provider field are **not** saved in this file for security reasons. Manage API keys using environment variables (e.g., `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`) or a `.env` file.
|
|
265
|
-
* It's recommended to add `.gradio_settings.json` to your `.gitignore` file.
|
|
259
|
+
* It's recommended to add `.gradio_settings.json` to your `.gitignore` file.
|
|
@@ -50,6 +50,9 @@ async with Computer() as macos_computer:
|
|
|
50
50
|
# or
|
|
51
51
|
# loop=AgentLoop.OMNI,
|
|
52
52
|
# model=LLM(provider=LLMProvider.OLLAMA, model="gemma3")
|
|
53
|
+
# or
|
|
54
|
+
# loop=AgentLoop.UITARS,
|
|
55
|
+
# model=LLM(provider=LLMProvider.OAICOMPAT, model="tgi", provider_base_url="https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1")
|
|
53
56
|
)
|
|
54
57
|
|
|
55
58
|
tasks = [
|
|
@@ -124,6 +127,10 @@ The Gradio UI provides:
|
|
|
124
127
|
- Configuration of agent parameters
|
|
125
128
|
- Chat interface for interacting with the agent
|
|
126
129
|
|
|
130
|
+
### Using UI-TARS
|
|
131
|
+
|
|
132
|
+
You can use UI-TARS by first following the [deployment guide](https://github.com/bytedance/UI-TARS/blob/main/README_deploy.md). This will give you a provider URL like this: `https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1` which you can use in the gradio UI.
|
|
133
|
+
|
|
127
134
|
## Agent Loops
|
|
128
135
|
|
|
129
136
|
The `cua-agent` package provides four agent loop variations, based on different CUA model providers and techniques:
|
|
@@ -132,6 +139,7 @@ The `cua-agent` package provides three agent loops variations, based on differen
|
|
|
132
139
|
|:-----------|:-----------------|:------------|:-------------|
|
|
133
140
|
| `AgentLoop.OPENAI` | • `computer_use_preview` | Use OpenAI Operator CUA model | Not Required |
|
|
134
141
|
| `AgentLoop.ANTHROPIC` | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219` | Use Anthropic Computer-Use | Not Required |
|
|
142
|
+
| `AgentLoop.UITARS` | • `ByteDance-Seed/UI-TARS-1.5-7B` | Uses ByteDance's UI-TARS 1.5 model | Not Required |
|
|
135
143
|
| `AgentLoop.OMNI` | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219`<br>• `gpt-4.5-preview`<br>• `gpt-4o`<br>• `gpt-4`<br>• `phi4`<br>• `phi4-mini`<br>• `gemma3`<br>• `...`<br>• `Any Ollama or OpenAI-compatible model` | Use OmniParser for element pixel-detection (SoM) and any VLMs for UI Grounding and Reasoning | OmniParser |
|
|
136
144
|
|
|
137
145
|
## AgentResponse
|
|
@@ -173,25 +181,9 @@ async for result in agent.run(task):
|
|
|
173
181
|
print(output)
|
|
174
182
|
```
|
|
175
183
|
|
|
176
|
-
### Gradio UI
|
|
177
|
-
|
|
178
|
-
You can also interact with the agent using a Gradio interface.
|
|
179
|
-
|
|
180
|
-
```python
|
|
181
|
-
# Ensure environment variables (e.g., API keys) are loaded
|
|
182
|
-
# You might need a helper function like load_dotenv_files() if using .env
|
|
183
|
-
# from utils import load_dotenv_files
|
|
184
|
-
# load_dotenv_files()
|
|
185
|
-
|
|
186
|
-
from agent.ui.gradio.app import create_gradio_ui
|
|
187
|
-
|
|
188
|
-
app = create_gradio_ui()
|
|
189
|
-
app.launch(share=False)
|
|
190
|
-
```
|
|
191
|
-
|
|
192
184
|
**Note on Settings Persistence:**
|
|
193
185
|
|
|
194
186
|
* The Gradio UI automatically saves your configuration (Agent Loop, Model Choice, Custom Base URL, Save Trajectory state, Recent Images count) to a file named `.gradio_settings.json` in the project's root directory when you successfully run a task.
|
|
195
187
|
* This allows your preferences to persist between sessions.
|
|
196
188
|
* API keys entered into the custom provider field are **not** saved in this file for security reasons. Manage API keys using environment variables (e.g., `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`) or a `.env` file.
|
|
197
|
-
* It's recommended to add `.gradio_settings.json` to your `.gitignore` file.
|
|
189
|
+
* It's recommended to add `.gradio_settings.json` to your `.gitignore` file.
|
|
@@ -98,5 +98,24 @@ class LoopFactory:
|
|
|
98
98
|
parser=OmniParser(),
|
|
99
99
|
provider_base_url=provider_base_url,
|
|
100
100
|
)
|
|
101
|
+
elif loop_type == AgentLoop.UITARS:
|
|
102
|
+
# Lazy import UITARSLoop only when needed
|
|
103
|
+
try:
|
|
104
|
+
from ..providers.uitars.loop import UITARSLoop
|
|
105
|
+
except ImportError:
|
|
106
|
+
raise ImportError(
|
|
107
|
+
"The 'uitars' provider is not installed. "
|
|
108
|
+
"Install it with 'pip install cua-agent[all]'"
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
return UITARSLoop(
|
|
112
|
+
api_key=api_key,
|
|
113
|
+
model=model_name,
|
|
114
|
+
computer=computer,
|
|
115
|
+
save_trajectory=save_trajectory,
|
|
116
|
+
base_dir=trajectory_dir,
|
|
117
|
+
only_n_most_recent_images=only_n_most_recent_images,
|
|
118
|
+
provider_base_url=provider_base_url,
|
|
119
|
+
)
|
|
101
120
|
else:
|
|
102
121
|
raise ValueError(f"Unsupported loop type: {loop_type}")
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""UI-TARS Agent provider package."""
|
|
"""Base client implementation for UI-TARS providers."""

import logging
from typing import Dict, List, Optional, Any, Tuple

logger = logging.getLogger(__name__)


class BaseUITarsClient:
    """Abstract base for provider-specific UI-TARS API clients.

    Holds the configuration shared by every backend client; concrete
    subclasses implement :meth:`run_interleaved` against a specific service
    (e.g. an OpenAI-compatible endpoint).
    """

    def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
        """Store the shared client configuration.

        Args:
            api_key: Optional API key for the backend service.
            model: Optional model identifier to request.
        """
        self.api_key = api_key
        self.model = model

    async def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """Run an interleaved (text + image) chat completion.

        Args:
            messages: Conversation messages as dicts.
            system: System prompt text.
            max_tokens: Optional per-call override of the token limit.

        Returns:
            The raw response dict from the backend.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError
"""OpenAI-compatible client implementation."""

import os
import logging
from typing import Dict, List, Optional, Any
import aiohttp
import re

from .base import BaseUITarsClient

logger = logging.getLogger(__name__)


# OpenAI-compatible client for UI-TARS
class OAICompatClient(BaseUITarsClient):
    """OpenAI-compatible API client implementation.

    This client can be used with any service that implements the OpenAI API protocol, including:
    - Huggingface Text Generation Interface endpoints
    - vLLM
    - LM Studio
    - LocalAI
    - Ollama (with OpenAI compatibility)
    - Text Generation WebUI
    - Any other service with OpenAI API compatibility
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "Qwen2.5-VL-7B-Instruct",
        provider_base_url: Optional[str] = "http://localhost:8000/v1",
        max_tokens: int = 4096,
        temperature: float = 0.0,
    ):
        """Initialize the OpenAI-compatible client.

        Args:
            api_key: Not used for local endpoints, usually set to "EMPTY"
            model: Model name to use
            provider_base_url: API base URL. Typically in the format "http://localhost:PORT/v1"
                Examples:
                - vLLM: "http://localhost:8000/v1"
                - LM Studio: "http://localhost:1234/v1"
                - LocalAI: "http://localhost:8080/v1"
                - Ollama: "http://localhost:11434/v1"
            max_tokens: Maximum tokens to generate
            temperature: Generation temperature
        """
        super().__init__(api_key=api_key or "EMPTY", model=model)
        self.api_key = api_key or "EMPTY"  # Local endpoints typically don't require an API key
        self.model = model
        self.provider_base_url = provider_base_url or "http://localhost:8000/v1"  # Use default if None
        self.max_tokens = max_tokens
        self.temperature = temperature

    def _extract_base64_image(self, text: str) -> Optional[str]:
        """Extract base64 image data from a data-URL embedded in *text*, or None."""
        match = re.search(r'data:image/[^;]+;base64,([^"]+)', text)
        return match.group(1) if match else None

    def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Create a loggable version of messages with image data truncated.

        FIX: image content in this module is built with ``"type": "image_url"``
        (OpenAI format); the original only matched ``"image"``, so full base64
        payloads leaked into logs. Both spellings are now truncated.
        """
        loggable_messages = []
        for msg in messages:
            content = msg.get("content")
            if isinstance(content, list):
                new_content = []
                for part in content:
                    if part.get("type") in ("image", "image_url"):
                        new_content.append(
                            {"type": part.get("type"), "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
                        )
                    else:
                        new_content.append(part)
                loggable_messages.append({"role": msg["role"], "content": new_content})
            else:
                loggable_messages.append(msg)
        return loggable_messages

    def _content_from_text(self, text: str) -> List[Dict[str, Any]]:
        """Build OpenAI-format message content from a raw string.

        Embedded base64 images become ``image_url`` parts; everything else
        becomes a single ``text`` part.
        """
        base64_img = self._extract_base64_image(text)
        if base64_img:
            return [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
                }
            ]
        return [{"type": "text", "text": text}]

    def _build_endpoint_url(self) -> str:
        """Derive the chat/completions endpoint URL from the configured base URL."""
        endpoint_url = self.provider_base_url or "http://localhost:8000/v1"
        if not endpoint_url.endswith("/chat/completions"):
            # If URL is RunPod format, make it OpenAI compatible
            if endpoint_url.startswith("https://api.runpod.ai/v2/"):
                parts = endpoint_url.split("/")
                if len(parts) >= 5:
                    runpod_id = parts[4]
                    endpoint_url = f"https://api.runpod.ai/v2/{runpod_id}/openai/v1/chat/completions"
            # If the URL ends with /v1, append /chat/completions
            elif endpoint_url.endswith("/v1"):
                endpoint_url = f"{endpoint_url}/chat/completions"
            # Otherwise normalize, avoiding a double slash
            elif not endpoint_url.endswith("/"):
                endpoint_url = f"{endpoint_url}/chat/completions"
            else:
                endpoint_url = f"{endpoint_url}chat/completions"
        return endpoint_url

    async def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """Run interleaved chat completion.

        Args:
            messages: List of message dicts (or raw strings, treated as user turns)
            system: System prompt
            max_tokens: Optional max tokens override

        Returns:
            Response dict (parsed JSON body from the endpoint)

        Raises:
            Exception: On non-200 responses or non-JSON response bodies.
        """
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

        final_messages = [{"role": "system", "content": system}]

        # Normalize every item into OpenAI message format
        for item in messages:
            if isinstance(item, dict):
                if isinstance(item["content"], list):
                    # Content is already in the correct format
                    final_messages.append(item)
                else:
                    final_messages.append(
                        {"role": item["role"], "content": self._content_from_text(item["content"])}
                    )
            else:
                # Bare string: treat as a user turn
                final_messages.append({"role": "user", "content": self._content_from_text(item)})

        payload = {
            "model": self.model,
            "messages": final_messages,
            "temperature": self.temperature,
            "max_tokens": max_tokens or self.max_tokens,
        }

        try:
            async with aiohttp.ClientSession() as session:
                endpoint_url = self._build_endpoint_url()
                logger.debug(f"Using endpoint URL: {endpoint_url}")

                async with session.post(endpoint_url, headers=headers, json=payload) as response:
                    logger.debug(f"Status: {response.status}")
                    logger.debug(f"Content-Type: {response.headers.get('Content-Type')}")

                    response_text = await response.text()
                    logger.debug(f"Response content: {response_text}")

                    # FIX: check the HTTP status first so a non-JSON error page
                    # (e.g. an HTML 502) reports the real failure instead of a
                    # misleading "not JSON" message.
                    if response.status != 200:
                        if "application/json" in response.headers.get("Content-Type", ""):
                            error_body = await response.json()
                            error_msg = error_body.get("error", {}).get("message", str(error_body))
                        else:
                            error_msg = response_text
                        logger.error(f"Error in API call: {error_msg}")
                        raise Exception(f"API error: {error_msg}")

                    if "application/json" in response.headers.get("Content-Type", ""):
                        response_json = await response.json()
                    else:
                        # Plain string: the original used an f-string with no placeholders
                        raise Exception("Response is not JSON format")

                    return response_json

        except Exception as e:
            logger.error(f"Error in API call: {str(e)}")
            raise