cua-agent 0.4.14__py3-none-any.whl → 0.7.16__py3-none-any.whl
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Potentially problematic release.
This version of cua-agent might be problematic.
- agent/__init__.py +4 -19
- agent/__main__.py +2 -1
- agent/adapters/__init__.py +6 -0
- agent/adapters/azure_ml_adapter.py +283 -0
- agent/adapters/cua_adapter.py +161 -0
- agent/adapters/huggingfacelocal_adapter.py +67 -125
- agent/adapters/human_adapter.py +116 -114
- agent/adapters/mlxvlm_adapter.py +370 -0
- agent/adapters/models/__init__.py +41 -0
- agent/adapters/models/generic.py +78 -0
- agent/adapters/models/internvl.py +290 -0
- agent/adapters/models/opencua.py +115 -0
- agent/adapters/models/qwen2_5_vl.py +78 -0
- agent/agent.py +431 -241
- agent/callbacks/__init__.py +10 -3
- agent/callbacks/base.py +45 -31
- agent/callbacks/budget_manager.py +22 -10
- agent/callbacks/image_retention.py +54 -98
- agent/callbacks/logging.py +55 -42
- agent/callbacks/operator_validator.py +140 -0
- agent/callbacks/otel.py +291 -0
- agent/callbacks/pii_anonymization.py +19 -16
- agent/callbacks/prompt_instructions.py +47 -0
- agent/callbacks/telemetry.py +106 -69
- agent/callbacks/trajectory_saver.py +178 -70
- agent/cli.py +269 -119
- agent/computers/__init__.py +14 -9
- agent/computers/base.py +32 -19
- agent/computers/cua.py +52 -25
- agent/computers/custom.py +78 -71
- agent/decorators.py +23 -14
- agent/human_tool/__init__.py +2 -7
- agent/human_tool/__main__.py +6 -2
- agent/human_tool/server.py +48 -37
- agent/human_tool/ui.py +359 -235
- agent/integrations/hud/__init__.py +164 -74
- agent/integrations/hud/agent.py +338 -342
- agent/integrations/hud/proxy.py +297 -0
- agent/loops/__init__.py +44 -14
- agent/loops/anthropic.py +590 -492
- agent/loops/base.py +19 -15
- agent/loops/composed_grounded.py +142 -144
- agent/loops/fara/__init__.py +8 -0
- agent/loops/fara/config.py +506 -0
- agent/loops/fara/helpers.py +357 -0
- agent/loops/fara/schema.py +143 -0
- agent/loops/gelato.py +183 -0
- agent/loops/gemini.py +935 -0
- agent/loops/generic_vlm.py +601 -0
- agent/loops/glm45v.py +140 -135
- agent/loops/gta1.py +48 -51
- agent/loops/holo.py +218 -0
- agent/loops/internvl.py +180 -0
- agent/loops/moondream3.py +493 -0
- agent/loops/omniparser.py +326 -226
- agent/loops/openai.py +63 -56
- agent/loops/opencua.py +134 -0
- agent/loops/uiins.py +175 -0
- agent/loops/uitars.py +262 -212
- agent/loops/uitars2.py +951 -0
- agent/playground/__init__.py +5 -0
- agent/playground/server.py +301 -0
- agent/proxy/examples.py +196 -0
- agent/proxy/handlers.py +255 -0
- agent/responses.py +486 -339
- agent/tools/__init__.py +24 -0
- agent/tools/base.py +253 -0
- agent/tools/browser_tool.py +423 -0
- agent/types.py +20 -5
- agent/ui/__init__.py +1 -1
- agent/ui/__main__.py +1 -1
- agent/ui/gradio/app.py +25 -22
- agent/ui/gradio/ui_components.py +314 -167
- cua_agent-0.7.16.dist-info/METADATA +85 -0
- cua_agent-0.7.16.dist-info/RECORD +79 -0
- {cua_agent-0.4.14.dist-info → cua_agent-0.7.16.dist-info}/WHEEL +1 -1
- agent/integrations/hud/adapter.py +0 -121
- agent/integrations/hud/computer_handler.py +0 -187
- agent/telemetry.py +0 -142
- cua_agent-0.4.14.dist-info/METADATA +0 -436
- cua_agent-0.4.14.dist-info/RECORD +0 -50
- {cua_agent-0.4.14.dist-info → cua_agent-0.7.16.dist-info}/entry_points.txt +0 -0
agent/__init__.py
CHANGED
@@ -5,19 +5,13 @@ agent - Decorator-based Computer Use Agent with liteLLM integration
 import logging
 import sys
 
-from .decorators import register_agent
-from .agent import ComputerAgent
-from .types import Messages, AgentResponse
-
 # Import loops to register them
 from . import loops
+from .agent import ComputerAgent
+from .decorators import register_agent
+from .types import AgentResponse, Messages
 
-__all__ = [
-    "register_agent",
-    "ComputerAgent",
-    "Messages",
-    "AgentResponse"
-]
+__all__ = ["register_agent", "ComputerAgent", "Messages", "AgentResponse"]
 
 __version__ = "0.4.0"
 
@@ -28,13 +22,9 @@ try:
     # Import from core telemetry for basic functions
     from core.telemetry import (
         is_telemetry_enabled,
-        flush,
         record_event,
     )
 
-    # Import set_dimension from our own telemetry module
-    from .telemetry import set_dimension
-
     # Check if telemetry is enabled
     if is_telemetry_enabled():
         logger.info("Telemetry is enabled")
@@ -49,11 +39,6 @@ try:
             },
         )
 
-        # Set the package version as a dimension
-        set_dimension("agent_version", __version__)
-
-        # Flush events to ensure they're sent
-        flush()
     else:
         logger.info("Telemetry is disabled")
 except ImportError as e:
agent/__main__.py
CHANGED
agent/adapters/__init__.py
CHANGED
@@ -2,10 +2,16 @@
 Adapters package for agent - Custom LLM adapters for LiteLLM
 """
 
+from .azure_ml_adapter import AzureMLAdapter
+from .cua_adapter import CUAAdapter
 from .huggingfacelocal_adapter import HuggingFaceLocalAdapter
 from .human_adapter import HumanAdapter
+from .mlxvlm_adapter import MLXVLMAdapter
 
 __all__ = [
+    "AzureMLAdapter",
     "HuggingFaceLocalAdapter",
     "HumanAdapter",
+    "MLXVLMAdapter",
+    "CUAAdapter",
 ]
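With the new adapters exported from agent/adapters, the usual way to reach them through LiteLLM model strings is the custom provider map. The snippet below is a minimal sketch and not code from this diff: it assumes the "azure_ml" and "cua" provider prefixes seen in the adapter docstrings, and wires the handlers the way LiteLLM's CustomLLM documentation describes.

import litellm

from agent.adapters import AzureMLAdapter, CUAAdapter

# Assumed wiring: map provider prefixes to the new handlers so that model
# strings such as "azure_ml/Fara-7B" or "cua/<model>" resolve to them.
litellm.custom_provider_map = [
    {"provider": "azure_ml", "custom_handler": AzureMLAdapter()},
    {"provider": "cua", "custom_handler": CUAAdapter()},
]

# A normal completion call can then target an adapter-backed model:
# litellm.completion(model="azure_ml/Fara-7B", messages=[...], api_base=..., api_key=...)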
agent/adapters/azure_ml_adapter.py
ADDED
@@ -0,0 +1,283 @@
+"""
+Azure ML Custom Provider Adapter for LiteLLM.
+
+This adapter provides direct OpenAI-compatible API access to Azure ML endpoints
+without message transformation, specifically for models like Fara-7B that require
+exact OpenAI message formatting.
+"""
+
+import json
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
+
+import httpx
+from litellm import acompletion, completion
+from litellm.llms.custom_llm import CustomLLM
+from litellm.types.utils import GenericStreamingChunk, ModelResponse
+
+
+class AzureMLAdapter(CustomLLM):
+    """
+    Azure ML Adapter for OpenAI-compatible endpoints.
+
+    Makes direct HTTP calls to Azure ML foundry inference endpoints
+    using the OpenAI-compatible API format without transforming messages.
+
+    Usage:
+        model = "azure_ml/Fara-7B"
+        api_base = "https://foundry-inference-xxx.centralus.inference.ml.azure.com"
+        api_key = "your-api-key"
+
+        response = litellm.completion(
+            model=model,
+            messages=[...],
+            api_base=api_base,
+            api_key=api_key
+        )
+    """
+
+    def __init__(self, **kwargs):
+        """Initialize the adapter."""
+        super().__init__()
+        self._client: Optional[httpx.Client] = None
+        self._async_client: Optional[httpx.AsyncClient] = None
+
+    def _get_client(self) -> httpx.Client:
+        """Get or create sync HTTP client."""
+        if self._client is None:
+            self._client = httpx.Client(timeout=600.0)
+        return self._client
+
+    def _get_async_client(self) -> httpx.AsyncClient:
+        """Get or create async HTTP client."""
+        if self._async_client is None:
+            self._async_client = httpx.AsyncClient(timeout=600.0)
+        return self._async_client
+
+    def _prepare_request(self, **kwargs) -> tuple[str, dict, dict]:
+        """
+        Prepare the HTTP request without transforming messages.
+
+        Applies Azure ML workaround: double-encodes function arguments to work around
+        Azure ML's bug where it parses arguments before validation.
+
+        Returns:
+            Tuple of (url, headers, json_data)
+        """
+        # Extract required params
+        api_base = kwargs.get("api_base")
+        api_key = kwargs.get("api_key")
+        model = kwargs.get("model", "").replace("azure_ml/", "")
+        messages = kwargs.get("messages", [])
+
+        if not api_base:
+            raise ValueError("api_base is required for azure_ml provider")
+        if not api_key:
+            raise ValueError("api_key is required for azure_ml provider")
+
+        # Build OpenAI-compatible endpoint URL
+        base_url = api_base.rstrip("/")
+        url = f"{base_url}/chat/completions"
+
+        # Prepare headers
+        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+
+        # WORKAROUND for Azure ML bug:
+        # Azure ML incorrectly parses the arguments field before validation,
+        # causing it to reject valid JSON strings. We double-encode arguments
+        # so that after Azure ML's parse, they remain as strings.
+        messages_copy = []
+        for message in messages:
+            msg_copy = message.copy()
+
+            # Check if message has tool_calls that need double-encoding
+            if "tool_calls" in msg_copy:
+                tool_calls_copy = []
+                for tool_call in msg_copy["tool_calls"]:
+                    tc_copy = tool_call.copy()
+
+                    if "function" in tc_copy and "arguments" in tc_copy["function"]:
+                        func_copy = tc_copy["function"].copy()
+                        arguments = func_copy["arguments"]
+
+                        # If arguments is already a string, double-encode it
+                        if isinstance(arguments, str):
+                            func_copy["arguments"] = json.dumps(arguments)
+
+                        tc_copy["function"] = func_copy
+
+                    tool_calls_copy.append(tc_copy)
+
+                msg_copy["tool_calls"] = tool_calls_copy
+
+            messages_copy.append(msg_copy)
+
+        # Prepare request body with double-encoded messages
+        json_data = {"model": model, "messages": messages_copy}
+
+        # Add optional parameters if provided
+        optional_params = [
+            "temperature",
+            "top_p",
+            "n",
+            "stream",
+            "stop",
+            "max_tokens",
+            "presence_penalty",
+            "frequency_penalty",
+            "logit_bias",
+            "user",
+            "response_format",
+            "seed",
+            "tools",
+            "tool_choice",
+        ]
+
+        for param in optional_params:
+            if param in kwargs and kwargs[param] is not None:
+                json_data[param] = kwargs[param]
+
+        return url, headers, json_data
+
+    def completion(self, *args, **kwargs) -> ModelResponse:
+        """
+        Synchronous completion method.
+
+        Makes a direct HTTP POST to Azure ML's OpenAI-compatible endpoint.
+        """
+        url, headers, json_data = self._prepare_request(**kwargs)
+
+        client = self._get_client()
+        response = client.post(url, headers=headers, json=json_data)
+        response.raise_for_status()
+
+        # Parse response
+        response_json = response.json()
+
+        # Return using litellm's completion with the actual response
+        return completion(
+            model=f"azure_ml/{kwargs.get('model', '')}",
+            mock_response=response_json["choices"][0]["message"]["content"],
+            messages=kwargs.get("messages", []),
+        )
+
+    async def acompletion(self, *args, **kwargs) -> ModelResponse:
+        """
+        Asynchronous completion method.
+
+        Makes a direct async HTTP POST to Azure ML's OpenAI-compatible endpoint.
+        """
+        url, headers, json_data = self._prepare_request(**kwargs)
+
+        client = self._get_async_client()
+        response = await client.post(url, headers=headers, json=json_data)
+        response.raise_for_status()
+
+        # Parse response
+        response_json = response.json()
+
+        # Return using litellm's acompletion with the actual response
+        return await acompletion(
+            model=f"azure_ml/{kwargs.get('model', '')}",
+            mock_response=response_json["choices"][0]["message"]["content"],
+            messages=kwargs.get("messages", []),
+        )
+
+    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+        """
+        Synchronous streaming method.
+
+        Makes a streaming HTTP POST to Azure ML's OpenAI-compatible endpoint.
+        """
+        url, headers, json_data = self._prepare_request(**kwargs)
+        json_data["stream"] = True
+
+        client = self._get_client()
+
+        with client.stream("POST", url, headers=headers, json=json_data) as response:
+            response.raise_for_status()
+
+            for line in response.iter_lines():
+                if line.startswith("data: "):
+                    data = line[6:]  # Remove "data: " prefix
+                    if data == "[DONE]":
+                        break
+
+                    try:
+                        chunk_json = json.loads(data)
+                        delta = chunk_json["choices"][0].get("delta", {})
+                        content = delta.get("content", "")
+                        finish_reason = chunk_json["choices"][0].get("finish_reason")
+
+                        generic_streaming_chunk: GenericStreamingChunk = {
+                            "finish_reason": finish_reason,
+                            "index": 0,
+                            "is_finished": finish_reason is not None,
+                            "text": content,
+                            "tool_use": None,
+                            "usage": chunk_json.get(
+                                "usage",
+                                {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
+                            ),
+                        }
+
+                        yield generic_streaming_chunk
+                    except json.JSONDecodeError:
+                        continue
+
+    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
+        """
+        Asynchronous streaming method.
+
+        Makes an async streaming HTTP POST to Azure ML's OpenAI-compatible endpoint.
+        """
+        url, headers, json_data = self._prepare_request(**kwargs)
+        json_data["stream"] = True
+
+        client = self._get_async_client()
+
+        async with client.stream("POST", url, headers=headers, json=json_data) as response:
+            response.raise_for_status()
+
+            async for line in response.aiter_lines():
+                if line.startswith("data: "):
+                    data = line[6:]  # Remove "data: " prefix
+                    if data == "[DONE]":
+                        break
+
+                    try:
+                        chunk_json = json.loads(data)
+                        delta = chunk_json["choices"][0].get("delta", {})
+                        content = delta.get("content", "")
+                        finish_reason = chunk_json["choices"][0].get("finish_reason")
+
+                        generic_streaming_chunk: GenericStreamingChunk = {
+                            "finish_reason": finish_reason,
+                            "index": 0,
+                            "is_finished": finish_reason is not None,
+                            "text": content,
+                            "tool_use": None,
+                            "usage": chunk_json.get(
+                                "usage",
+                                {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
+                            ),
+                        }
+
+                        yield generic_streaming_chunk
+                    except json.JSONDecodeError:
+                        continue
+
+    def __del__(self):
+        """Cleanup HTTP clients."""
+        if self._client is not None:
+            self._client.close()
+        if self._async_client is not None:
+            import asyncio
+
+            try:
+                loop = asyncio.get_event_loop()
+                if loop.is_running():
+                    loop.create_task(self._async_client.aclose())
+                else:
+                    loop.run_until_complete(self._async_client.aclose())
+            except Exception:
+                pass
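The double-encoding workaround in _prepare_request is easiest to see on a concrete value. The snippet below is an illustration only, not part of the package; it shows what json.dumps does to an arguments string that is already serialized JSON.

import json

# A tool call's arguments as OpenAI normally sends them: a JSON-encoded string.
arguments = '{"x": 540, "y": 320}'

# The adapter double-encodes the string before sending it to Azure ML.
double_encoded = json.dumps(arguments)  # value is now "{\"x\": 540, \"y\": 320}"

# When Azure ML parses the request body once, the value decodes back to the
# original arguments string instead of being rejected as malformed.
assert json.loads(double_encoded) == arguments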
agent/adapters/cua_adapter.py
ADDED
@@ -0,0 +1,161 @@
+import os
+from typing import Any, AsyncIterator, Iterator
+
+from litellm import acompletion, completion
+from litellm.llms.custom_llm import CustomLLM
+from litellm.types.utils import GenericStreamingChunk, ModelResponse
+
+
+class CUAAdapter(CustomLLM):
+    def __init__(self, base_url: str | None = None, api_key: str | None = None, **_: Any):
+        super().__init__()
+        self.base_url = base_url or os.environ.get("CUA_BASE_URL") or "https://inference.cua.ai/v1"
+        self.api_key = (
+            api_key or os.environ.get("CUA_INFERENCE_API_KEY") or os.environ.get("CUA_API_KEY")
+        )
+
+    def _normalize_model(self, model: str) -> str:
+        # Accept either "cua/<model>" or raw "<model>"
+        return model.split("/", 1)[1] if model and model.startswith("cua/") else model
+
+    def completion(self, *args, **kwargs) -> ModelResponse:
+        model = kwargs.get("model", "")
+        api_base = kwargs.get("api_base") or self.base_url
+        if "anthropic/" in model:
+            model = f"anthropic/{self._normalize_model(model)}"
+            api_base = api_base.removesuffix("/v1")
+        elif "gemini/" in model or "google/" in model:
+            # Route to Gemini pass-through endpoint
+            model = f"gemini/{self._normalize_model(model)}"
+            api_base = api_base + "/gemini"
+        else:
+            model = f"openai/{self._normalize_model(model)}"
+
+        params = {
+            "model": model,
+            "messages": kwargs.get("messages", []),
+            "api_base": api_base,
+            "api_key": kwargs.get("api_key") or self.api_key,
+            "stream": False,
+        }
+
+        # Forward tools if provided
+        if "tools" in kwargs:
+            params["tools"] = kwargs["tools"]
+
+        if "optional_params" in kwargs:
+            params.update(kwargs["optional_params"])
+            del kwargs["optional_params"]
+
+        if "headers" in kwargs:
+            params["headers"] = kwargs["headers"]
+            del kwargs["headers"]
+
+        # Print dropped parameters
+        original_keys = set(kwargs.keys())
+        used_keys = set(params.keys())  # Only these are extracted from kwargs
+        ignored_keys = {
+            "litellm_params",
+            "client",
+            "print_verbose",
+            "acompletion",
+            "timeout",
+            "logging_obj",
+            "encoding",
+            "custom_prompt_dict",
+            "model_response",
+            "logger_fn",
+        }
+        dropped_keys = original_keys - used_keys - ignored_keys
+        if dropped_keys:
+            dropped_keyvals = {k: kwargs[k] for k in dropped_keys}
+            # print(f"CUAAdapter.completion: Dropped parameters: {dropped_keyvals}")
+
+        return completion(**params)  # type: ignore
+
+    async def acompletion(self, *args, **kwargs) -> ModelResponse:
+        model = kwargs.get("model", "")
+        api_base = kwargs.get("api_base") or self.base_url
+        if "anthropic/" in model:
+            model = f"anthropic/{self._normalize_model(model)}"
+            api_base = api_base.removesuffix("/v1")
+        elif "gemini/" in model or "google/" in model:
+            # Route to Gemini pass-through endpoint
+            model = f"gemini/{self._normalize_model(model)}"
+            api_base = api_base + "/gemini"
+        else:
+            model = f"openai/{self._normalize_model(model)}"
+
+        params = {
+            "model": model,
+            "messages": kwargs.get("messages", []),
+            "api_base": api_base,
+            "api_key": kwargs.get("api_key") or self.api_key,
+            "stream": False,
+        }
+
+        # Forward tools if provided
+        if "tools" in kwargs:
+            params["tools"] = kwargs["tools"]
+
+        if "optional_params" in kwargs:
+            params.update(kwargs["optional_params"])
+            del kwargs["optional_params"]
+
+        if "headers" in kwargs:
+            params["headers"] = kwargs["headers"]
+            del kwargs["headers"]
+
+        # Print dropped parameters
+        original_keys = set(kwargs.keys())
+        used_keys = set(params.keys())  # Only these are extracted from kwargs
+        ignored_keys = {
+            "litellm_params",
+            "client",
+            "print_verbose",
+            "acompletion",
+            "timeout",
+            "logging_obj",
+            "encoding",
+            "custom_prompt_dict",
+            "model_response",
+            "logger_fn",
+        }
+        dropped_keys = original_keys - used_keys - ignored_keys
+        if dropped_keys:
+            dropped_keyvals = {k: kwargs[k] for k in dropped_keys}
+            # print(f"CUAAdapter.acompletion: Dropped parameters: {dropped_keyvals}")
+
+        response = await acompletion(**params)  # type: ignore
+
+        return response
+
+    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+        params = dict(kwargs)
+        inner_model = self._normalize_model(params.get("model", ""))
+        params.update(
+            {
+                "model": f"openai/{inner_model}",
+                "api_base": self.base_url,
+                "api_key": self.api_key,
+                "stream": True,
+            }
+        )
+        # Yield chunks directly from LiteLLM's streaming generator
+        for chunk in completion(**params):  # type: ignore
+            yield chunk  # type: ignore
+
+    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
+        params = dict(kwargs)
+        inner_model = self._normalize_model(params.get("model", ""))
+        params.update(
+            {
+                "model": f"openai/{inner_model}",
+                "api_base": self.base_url,
+                "api_key": self.api_key,
+                "stream": True,
+            }
+        )
+        stream = await acompletion(**params)  # type: ignore
+        async for chunk in stream:  # type: ignore
+            yield chunk  # type: ignore
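As a quick orientation to how the CUA adapter treats model names, the sketch below is an assumption-laden example rather than shipped code, and the model strings are placeholders. It exercises _normalize_model and summarizes the routing branches in completion()/acompletion().

from agent.adapters import CUAAdapter

adapter = CUAAdapter(api_key="sk-placeholder")  # placeholder key, not a real credential

# Both the "cua/<model>" form and a bare "<model>" string are accepted:
print(adapter._normalize_model("cua/gpt-4o"))  # -> "gpt-4o"
print(adapter._normalize_model("gpt-4o"))      # -> "gpt-4o"

# In completion()/acompletion(): models containing "anthropic/" drop the "/v1"
# suffix from api_base, models containing "gemini/" or "google/" are routed to
# the "/gemini" pass-through path, and anything else is sent as an
# OpenAI-compatible request against the configured base URL.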