lybic_guiagents-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lybic-guiagents might be problematic.
- desktop_env/__init__.py +1 -0
- desktop_env/actions.py +203 -0
- desktop_env/controllers/__init__.py +0 -0
- desktop_env/controllers/python.py +471 -0
- desktop_env/controllers/setup.py +882 -0
- desktop_env/desktop_env.py +509 -0
- desktop_env/evaluators/__init__.py +5 -0
- desktop_env/evaluators/getters/__init__.py +41 -0
- desktop_env/evaluators/getters/calc.py +15 -0
- desktop_env/evaluators/getters/chrome.py +1774 -0
- desktop_env/evaluators/getters/file.py +154 -0
- desktop_env/evaluators/getters/general.py +42 -0
- desktop_env/evaluators/getters/gimp.py +38 -0
- desktop_env/evaluators/getters/impress.py +126 -0
- desktop_env/evaluators/getters/info.py +24 -0
- desktop_env/evaluators/getters/misc.py +406 -0
- desktop_env/evaluators/getters/replay.py +20 -0
- desktop_env/evaluators/getters/vlc.py +86 -0
- desktop_env/evaluators/getters/vscode.py +35 -0
- desktop_env/evaluators/metrics/__init__.py +160 -0
- desktop_env/evaluators/metrics/basic_os.py +68 -0
- desktop_env/evaluators/metrics/chrome.py +493 -0
- desktop_env/evaluators/metrics/docs.py +1011 -0
- desktop_env/evaluators/metrics/general.py +665 -0
- desktop_env/evaluators/metrics/gimp.py +637 -0
- desktop_env/evaluators/metrics/libreoffice.py +28 -0
- desktop_env/evaluators/metrics/others.py +92 -0
- desktop_env/evaluators/metrics/pdf.py +31 -0
- desktop_env/evaluators/metrics/slides.py +957 -0
- desktop_env/evaluators/metrics/table.py +585 -0
- desktop_env/evaluators/metrics/thunderbird.py +176 -0
- desktop_env/evaluators/metrics/utils.py +719 -0
- desktop_env/evaluators/metrics/vlc.py +524 -0
- desktop_env/evaluators/metrics/vscode.py +283 -0
- desktop_env/providers/__init__.py +35 -0
- desktop_env/providers/aws/__init__.py +0 -0
- desktop_env/providers/aws/manager.py +278 -0
- desktop_env/providers/aws/provider.py +186 -0
- desktop_env/providers/aws/provider_with_proxy.py +315 -0
- desktop_env/providers/aws/proxy_pool.py +193 -0
- desktop_env/providers/azure/__init__.py +0 -0
- desktop_env/providers/azure/manager.py +87 -0
- desktop_env/providers/azure/provider.py +207 -0
- desktop_env/providers/base.py +97 -0
- desktop_env/providers/gcp/__init__.py +0 -0
- desktop_env/providers/gcp/manager.py +0 -0
- desktop_env/providers/gcp/provider.py +0 -0
- desktop_env/providers/virtualbox/__init__.py +0 -0
- desktop_env/providers/virtualbox/manager.py +463 -0
- desktop_env/providers/virtualbox/provider.py +124 -0
- desktop_env/providers/vmware/__init__.py +0 -0
- desktop_env/providers/vmware/manager.py +455 -0
- desktop_env/providers/vmware/provider.py +105 -0
- gui_agents/__init__.py +0 -0
- gui_agents/agents/Action.py +209 -0
- gui_agents/agents/__init__.py +0 -0
- gui_agents/agents/agent_s.py +832 -0
- gui_agents/agents/global_state.py +610 -0
- gui_agents/agents/grounding.py +651 -0
- gui_agents/agents/hardware_interface.py +129 -0
- gui_agents/agents/manager.py +568 -0
- gui_agents/agents/translator.py +132 -0
- gui_agents/agents/worker.py +355 -0
- gui_agents/cli_app.py +560 -0
- gui_agents/core/__init__.py +0 -0
- gui_agents/core/engine.py +1496 -0
- gui_agents/core/knowledge.py +449 -0
- gui_agents/core/mllm.py +555 -0
- gui_agents/tools/__init__.py +0 -0
- gui_agents/tools/tools.py +727 -0
- gui_agents/unit_test/__init__.py +0 -0
- gui_agents/unit_test/run_tests.py +65 -0
- gui_agents/unit_test/test_manager.py +330 -0
- gui_agents/unit_test/test_worker.py +269 -0
- gui_agents/utils/__init__.py +0 -0
- gui_agents/utils/analyze_display.py +301 -0
- gui_agents/utils/common_utils.py +263 -0
- gui_agents/utils/display_viewer.py +281 -0
- gui_agents/utils/embedding_manager.py +53 -0
- gui_agents/utils/image_axis_utils.py +27 -0
- lybic_guiagents-0.1.0.dist-info/METADATA +416 -0
- lybic_guiagents-0.1.0.dist-info/RECORD +85 -0
- lybic_guiagents-0.1.0.dist-info/WHEEL +5 -0
- lybic_guiagents-0.1.0.dist-info/licenses/LICENSE +201 -0
- lybic_guiagents-0.1.0.dist-info/top_level.txt +2 -0
gui_agents/core/engine.py

@@ -0,0 +1,1496 @@
import os
import json
import backoff
import requests
from typing import List, Dict, Any, Optional, Union
import numpy as np
from anthropic import Anthropic
from openai import (
    AzureOpenAI,
    APIConnectionError,
    APIError,
    AzureOpenAI,
    OpenAI,
    RateLimitError,
)
from google import genai
from google.genai import types
from zhipuai import ZhipuAI
from groq import Groq
import boto3
import exa_py
from typing import List, Dict, Any, Optional, Union, Tuple

class ModelPricing:
    def __init__(self, pricing_file: str = "model_pricing.json"):
        self.pricing_file = pricing_file
        self.pricing_data = self._load_pricing()

    def _load_pricing(self) -> Dict:
        if os.path.exists(self.pricing_file):
            try:
                with open(self.pricing_file, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except Exception as e:
                print(f"Warning: Failed to load pricing file {self.pricing_file}: {e}")

        return {
            "default": {"input": 0, "output": 0}
        }

    def get_price(self, model: str) -> Dict[str, float]:
        # Handle nested pricing data structure
        if "llm_models" in self.pricing_data:
            # Iterate through all LLM model categories
            for category, models in self.pricing_data["llm_models"].items():
                # Direct model name matching
                if model in models:
                    pricing = models[model]
                    return self._parse_pricing(pricing)

                # Fuzzy matching for model names
                for model_name in models:
                    if model_name in model or model in model_name:
                        pricing = models[model_name]
                        return self._parse_pricing(pricing)

        # Handle embedding models
        if "embedding_models" in self.pricing_data:
            for category, models in self.pricing_data["embedding_models"].items():
                if model in models:
                    pricing = models[model]
                    return self._parse_pricing(pricing)

                for model_name in models:
                    if model_name in model or model in model_name:
                        pricing = models[model_name]
                        return self._parse_pricing(pricing)

        # Default pricing
        return {"input": 0, "output": 0}

    def _parse_pricing(self, pricing: Dict[str, str]) -> Dict[str, float]:
        """Parse pricing strings and convert to numeric values"""
        result = {}

        for key, value in pricing.items():
            if isinstance(value, str):
                # Remove currency symbols and units, convert to float
                clean_value = value.replace('$', '').replace('¥', '').replace(',', '')
                try:
                    result[key] = float(clean_value)
                except ValueError:
                    result[key] = 0.0
            else:
                result[key] = float(value) if value else 0.0

        return result

    def calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
        pricing = self.get_price(model)
        input_cost = (input_tokens / 1000000) * pricing["input"]
        output_cost = (output_tokens / 1000000) * pricing["output"]
        return input_cost + output_cost

# Initialize pricing manager with correct pricing file path
pricing_file = os.path.join(os.path.dirname(__file__), 'model_pricing.json')
pricing_manager = ModelPricing(pricing_file)
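(Annotation, not part of the diff.) A minimal sketch of how this pricing manager is meant to be driven; the model name and per-million-token prices below are made-up placeholders, and real entries live in the bundled model_pricing.json under "llm_models" / "embedding_models":

    pm = ModelPricing("model_pricing.json")
    price = pm.get_price("gpt-4o")  # hypothetical entry, e.g. {"input": 2.5, "output": 10.0} USD per million tokens
    cost = pm.calculate_cost("gpt-4o", input_tokens=1200, output_tokens=300)
    # cost == (1200 / 1_000_000) * price["input"] + (300 / 1_000_000) * price["output"]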
def extract_token_usage(response, provider: str) -> Tuple[int, int]:
    if "-" in provider:
        api_type, vendor = provider.split("-", 1)
    else:
        api_type, vendor = "llm", provider

    if api_type == "llm":
        if vendor in ["openai", "qwen", "deepseek", "doubao", "siliconflow", "monica", "vllm", "groq", "zhipu", "gemini", "openrouter", "azureopenai", "huggingface", "exa"]:
            if hasattr(response, 'usage') and response.usage:
                return response.usage.prompt_tokens, response.usage.completion_tokens

        elif vendor == "anthropic":
            if hasattr(response, 'usage') and response.usage:
                return response.usage.input_tokens, response.usage.output_tokens

        elif vendor == "bedrock":
            if isinstance(response, dict) and "usage" in response:
                usage = response["usage"]
                return usage.get("input_tokens", 0), usage.get("output_tokens", 0)

    elif api_type == "embedding":
        if vendor in ["openai", "azureopenai", "qwen", "doubao"]:
            if hasattr(response, 'usage') and response.usage:
                return response.usage.prompt_tokens, 0

        elif vendor == "jina":
            if isinstance(response, dict) and "usage" in response:
                total_tokens = response["usage"].get("total_tokens", 0)
                return total_tokens, 0

        elif vendor == "gemini":
            if hasattr(response, 'usage') and response.usage:
                return response.usage.prompt_tokens, 0

    return 0, 0

def calculate_tokens_and_cost(response, provider: str, model: str) -> Tuple[List[int], float]:
    input_tokens, output_tokens = extract_token_usage(response, provider)
    total_tokens = input_tokens + output_tokens
    cost = pricing_manager.calculate_cost(model, input_tokens, output_tokens)

    return [input_tokens, output_tokens, total_tokens], cost
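(Annotation, not part of the diff.) Provider strings throughout this module follow an "<api_type>-<vendor>" convention ("llm-openai", "embedding-jina", ...) and are split once on "-". A hedged sketch with a stand-in response object:

    from types import SimpleNamespace

    fake_response = SimpleNamespace(usage=SimpleNamespace(prompt_tokens=900, completion_tokens=100))
    tokens, cost = calculate_tokens_and_cost(fake_response, "llm-openai", "gpt-4o")
    # tokens == [900, 100, 1000]; cost falls back to 0 if "gpt-4o" has no entry in the pricing file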
class LMMEngine:
    pass

# ==================== LLM ====================

class LMMEngineOpenAI(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.provider = "llm-openai"

        api_key = api_key or os.getenv("OPENAI_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENAI_API_KEY"
            )

        self.base_url = base_url

        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit

        if not self.base_url:
            self.llm_client = OpenAI(api_key=self.api_key)
        else:
            self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        """Generate the next message based on previous messages"""
        response = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temperature,
            **kwargs,
        )

        content = response.choices[0].message.content
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)

        return content, total_tokens, cost
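(Annotation, not part of the diff.) The engines below all share the same contract as LMMEngineOpenAI: generate(messages, ...) returns a (content, [input_tokens, output_tokens, total_tokens], cost) triple. A usage sketch with placeholder credentials and model name:

    engine = LMMEngineOpenAI(api_key="sk-placeholder", model="gpt-4o-mini")
    messages = [
        {"role": "system", "content": "You are a GUI agent."},
        {"role": "user", "content": "Describe the next click."},
    ]
    content, (inp, out, total), cost = engine.generate(messages, temperature=0.0)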
class LMMEngineQwen(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, enable_thinking=False, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.enable_thinking = enable_thinking
        self.provider = "llm-qwen"

        api_key = api_key or os.getenv("DASHSCOPE_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named DASHSCOPE_API_KEY"
            )

        self.base_url = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit

        self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        """Generate the next message based on previous messages"""
        # For Qwen3 models, we need to handle thinking mode
        extra_body = {}
        if self.model.startswith("qwen3") and not self.enable_thinking:
            extra_body["enable_thinking"] = False

        response = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temperature,
            **extra_body,
            **kwargs,
        )

        content = response.choices[0].message.content
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)

        return content, total_tokens, cost


class LMMEngineDoubao(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.provider = "llm-doubao"

        api_key = api_key or os.getenv("ARK_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named ARK_API_KEY"
            )

        self.base_url = base_url or "https://ark.cn-beijing.volces.com/api/v3"
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit

        self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        """Generate the next message based on previous messages"""
        response = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temperature,
            extra_body={
                "thinking": {
                    "type": "disabled",
                    # "type": "enabled",
                    # "type": "auto",
                }
            },
            **kwargs,
        )

        content = response.choices[0].message.content
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)

        return content, total_tokens, cost


class LMMEngineAnthropic(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, thinking=False, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.thinking = thinking
        self.provider = "llm-anthropic"

        api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named ANTHROPIC_API_KEY"
            )

        self.api_key = api_key

        self.llm_client = Anthropic(api_key=self.api_key)

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        """Generate the next message based on previous messages"""
        if self.thinking:
            response = self.llm_client.messages.create(
                system=messages[0]["content"][0]["text"],
                model=self.model,
                messages=messages[1:],
                max_tokens=8192,
                thinking={"type": "enabled", "budget_tokens": 4096},
                **kwargs,
            )
            thoughts = response.content[0].thinking
            print("CLAUDE 3.7 THOUGHTS:", thoughts)
            content = response.content[1].text
        else:
            response = self.llm_client.messages.create(
                system=messages[0]["content"][0]["text"],
                model=self.model,
                messages=messages[1:],
                max_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temperature,
                **kwargs,
            )
            content = response.content[0].text

        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
        return content, total_tokens, cost
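(Annotation, not part of the diff.) LMMEngineAnthropic assumes a specific message shape: messages[0] is the system turn and each turn's content is a list of text blocks, since it reads messages[0]["content"][0]["text"] directly. A sketch of that layout with placeholder text:

    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a GUI agent."}]},
        {"role": "user", "content": [{"type": "text", "text": "Open the settings panel."}]},
    ]
    # engine = LMMEngineAnthropic(model="claude-3-5-sonnet-20241022")
    # content, tokens, cost = engine.generate(messages)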
class LMMEngineGemini(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.provider = "llm-gemini"

        api_key = api_key or os.getenv("GEMINI_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named GEMINI_API_KEY"
            )

        self.base_url = base_url or os.getenv("GEMINI_ENDPOINT_URL")
        if self.base_url is None:
            raise ValueError(
                "An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named GEMINI_ENDPOINT_URL"
            )

        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit

        self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        """Generate the next message based on previous messages"""
        response = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temperature,
            # reasoning_effort="low",
            extra_body={
                'extra_body': {
                    "google": {
                        "thinking_config": {
                            "thinking_budget": 128,
                            "include_thoughts": True
                        }
                    }
                }
            },
            **kwargs,
        )

        content = response.choices[0].message.content
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)

        return content, total_tokens, cost


class LMMEngineOpenRouter(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.provider = "llm-openrouter"

        api_key = api_key or os.getenv("OPENROUTER_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENROUTER_API_KEY"
            )

        self.base_url = base_url or os.getenv("OPEN_ROUTER_ENDPOINT_URL")
        if self.base_url is None:
            raise ValueError(
                "An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named OPEN_ROUTER_ENDPOINT_URL"
            )

        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit

        self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        """Generate the next message based on previous messages"""
        response = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temperature,
            **kwargs,
        )

        content = response.choices[0].message.content
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)

        return content, total_tokens, cost


class LMMEngineAzureOpenAI(LMMEngine):
    def __init__(
        self,
        base_url=None,
        api_key=None,
        azure_endpoint=None,
        model=None,
        api_version=None,
        rate_limit=-1,
        **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.provider = "llm-azureopenai"

        assert api_version is not None, "api_version must be provided"
        self.api_version = api_version

        api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named AZURE_OPENAI_API_KEY"
            )

        self.api_key = api_key

        azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
        if azure_endpoint is None:
            raise ValueError(
                "An Azure API endpoint needs to be provided in either the azure_endpoint parameter or as an environment variable named AZURE_OPENAI_ENDPOINT"
            )

        self.azure_endpoint = azure_endpoint
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit

        self.llm_client = AzureOpenAI(
            azure_endpoint=self.azure_endpoint,
            api_key=self.api_key,
            api_version=self.api_version,
        )
        self.cost = 0.0

    # @backoff.on_exception(backoff.expo, (APIConnectionError, APIError, RateLimitError), max_tries=10)
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        """Generate the next message based on previous messages"""
        response = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temperature,
            **kwargs,
        )
        content = response.choices[0].message.content
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
        return content, total_tokens, cost
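(Annotation, not part of the diff.) Unlike the other OpenAI-compatible engines, LMMEngineAzureOpenAI also requires an explicit api_version and an Azure endpoint (parameter or AZURE_OPENAI_ENDPOINT). A construction sketch; the deployment name, version string, and endpoint are placeholders:

    azure_engine = LMMEngineAzureOpenAI(
        model="my-gpt4o-deployment",           # Azure deployment name (placeholder)
        api_version="2024-02-01",              # placeholder API version
        azure_endpoint="https://example.openai.azure.com/",
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    )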
class LMMEnginevLLM(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.api_key = api_key
        self.provider = "llm-vllm"

        self.base_url = base_url or os.getenv("vLLM_ENDPOINT_URL")
        if self.base_url is None:
            raise ValueError(
                "An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named vLLM_ENDPOINT_URL"
            )

        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit

        self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)

    # @backoff.on_exception(backoff.expo, (APIConnectionError, APIError, RateLimitError), max_tries=10)
    # TODO: Default params chosen for the Qwen model
    def generate(
        self,
        messages,
        temperature=0.0,
        top_p=0.8,
        repetition_penalty=1.05,
        max_new_tokens=512,
        **kwargs
    ):
        """Generate the next message based on previous messages"""
        response = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temperature,
            top_p=top_p,
            extra_body={"repetition_penalty": repetition_penalty},
        )
        content = response.choices[0].message.content
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
        return content, total_tokens, cost


class LMMEngineHuggingFace(LMMEngine):
    def __init__(self, base_url=None, api_key=None, rate_limit=-1, **kwargs):
        assert base_url is not None, "HuggingFace endpoint must be provided"
        self.base_url = base_url
        self.model = base_url.split('/')[-1] if base_url else "huggingface-tgi"
        self.provider = "llm-huggingface"

        api_key = api_key or os.getenv("HF_TOKEN")
        if api_key is None:
            raise ValueError(
                "A HuggingFace token needs to be provided in either the api_key parameter or as an environment variable named HF_TOKEN"
            )

        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit

        self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        """Generate the next message based on previous messages"""
        response = self.llm_client.chat.completions.create(
            model="tgi",
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temperature,
            **kwargs,
        )

        content = response.choices[0].message.content
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)

        return content, total_tokens, cost


class LMMEngineDeepSeek(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.provider = "llm-deepseek"

        api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named DEEPSEEK_API_KEY"
            )

        self.base_url = base_url or "https://api.deepseek.com"
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit

        self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        """Generate the next message based on previous messages"""
        response = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temperature,
            **kwargs,
        )

        content = response.choices[0].message.content
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
        return content, total_tokens, cost


class LMMEngineZhipu(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.provider = "llm-zhipu"

        api_key = api_key or os.getenv("ZHIPU_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named ZHIPU_API_KEY"
            )

        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit

        # Use ZhipuAI client directly instead of OpenAI compatibility layer
        self.llm_client = ZhipuAI(api_key=self.api_key)

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        """Generate the next message based on previous messages"""
        response = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temperature,
            **kwargs,
        )

        content = response.choices[0].message.content  # type: ignore
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
        return content, total_tokens, cost


class LMMEngineGroq(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.provider = "llm-groq"

        api_key = api_key or os.getenv("GROQ_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named GROQ_API_KEY"
            )

        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit

        # Use Groq client directly
        self.llm_client = Groq(api_key=self.api_key)

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        """Generate the next message based on previous messages"""
        response = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temperature,
            **kwargs,
        )

        content = response.choices[0].message.content
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
        return content, total_tokens, cost


class LMMEngineSiliconflow(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.provider = "llm-siliconflow"

        api_key = api_key or os.getenv("SILICONFLOW_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named SILICONFLOW_API_KEY"
            )

        self.base_url = base_url or "https://api.siliconflow.cn/v1"
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit

        self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        """Generate the next message based on previous messages"""
        response = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temperature,
            **kwargs,
        )

        content = response.choices[0].message.content
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
        return content, total_tokens, cost


class LMMEngineMonica(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.provider = "llm-monica"

        api_key = api_key or os.getenv("MONICA_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named MONICA_API_KEY"
            )

        self.base_url = base_url or "https://openapi.monica.im/v1"
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit

        self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        """Generate the next message based on previous messages"""
        response = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temperature,
            **kwargs,
        )

        content = response.choices[0].message.content
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
        return content, total_tokens, cost


class LMMEngineAWSBedrock(LMMEngine):
    def __init__(
        self,
        aws_access_key=None,
        aws_secret_key=None,
        aws_region=None,
        model=None,
        rate_limit=-1,
        **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.provider = "llm-bedrock"

        # Claude model mapping for AWS Bedrock
        self.claude_model_map = {
            "claude-opus-4": "anthropic.claude-opus-4-20250514-v1:0",
            "claude-sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
            "claude-3-7-sonnet": "anthropic.claude-3-7-sonnet-20250219-v1:0",
            "claude-3-5-sonnet": "anthropic.claude-3-5-sonnet-20241022-v2:0",
            "claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
            "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
            "claude-3-5-haiku": "anthropic.claude-3-5-haiku-20241022-v1:0",
            "claude-3-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
            "claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
            "claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0",
        }

        # Get the actual Bedrock model ID
        self.bedrock_model_id = self.claude_model_map.get(model, model)

        # AWS credentials
        aws_access_key = aws_access_key or os.getenv("AWS_ACCESS_KEY_ID")
        aws_secret_key = aws_secret_key or os.getenv("AWS_SECRET_ACCESS_KEY")
        aws_region = aws_region or os.getenv("AWS_DEFAULT_REGION") or "us-west-2"

        if aws_access_key is None:
            raise ValueError(
                "AWS Access Key needs to be provided in either the aws_access_key parameter or as an environment variable named AWS_ACCESS_KEY_ID"
            )
        if aws_secret_key is None:
            raise ValueError(
                "AWS Secret Key needs to be provided in either the aws_secret_key parameter or as an environment variable named AWS_SECRET_ACCESS_KEY"
            )

        self.aws_region = aws_region
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit

        # Initialize Bedrock client
        self.bedrock_client = boto3.client(
            service_name="bedrock-runtime",
            region_name=aws_region,
            aws_access_key_id=aws_access_key,
            aws_secret_access_key=aws_secret_key
        )

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        """Generate the next message based on previous messages"""

        # Convert messages to Bedrock format
        # Extract system message if present
        system_message = None
        user_messages = []

        for message in messages:
            if message["role"] == "system":
                if isinstance(message["content"], list):
                    system_message = message["content"][0]["text"]
                else:
                    system_message = message["content"]
            else:
                # Handle both list and string content formats
                if isinstance(message["content"], list):
                    content = message["content"][0]["text"] if message["content"] else ""
                else:
                    content = message["content"]

                user_messages.append({
                    "role": message["role"],
                    "content": content
                })

        # Prepare the body for Bedrock
        body = {
            "max_tokens": max_new_tokens if max_new_tokens else 4096,
            "messages": user_messages,
            "anthropic_version": "bedrock-2023-05-31"
        }

        if temperature > 0:
            body["temperature"] = temperature

        if system_message:
            body["system"] = system_message

        try:
            response = self.bedrock_client.invoke_model(
                body=json.dumps(body),
                modelId=self.bedrock_model_id
            )

            response_body = json.loads(response.get("body").read())

            if "content" in response_body and response_body["content"]:
                content = response_body["content"][0]["text"]
            else:
                raise ValueError("No content in response")

            total_tokens, cost = calculate_tokens_and_cost(response_body, self.provider, self.model)
            return content, total_tokens, cost

        except Exception as e:
            print(f"AWS Bedrock error: {e}")
            raise
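(Annotation, not part of the diff.) For reference, the JSON body that LMMEngineAWSBedrock.generate() ends up passing to invoke_model is the Anthropic-on-Bedrock shape; roughly, with illustrative values:

    body = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 4096,
        "system": "You are a GUI agent.",   # included only when a system turn was present
        "messages": [{"role": "user", "content": "Open the settings panel."}],
        # "temperature" is added only when temperature > 0
    }
    # self.bedrock_client.invoke_model(body=json.dumps(body), modelId="anthropic.claude-3-5-sonnet-20241022-v2:0")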
# ==================== Embedding ====================

class OpenAIEmbeddingEngine(LMMEngine):
    def __init__(
        self,
        embedding_model: str = "text-embedding-3-small",
        api_key=None,
        **kwargs
    ):
        """Init an OpenAI Embedding engine

        Args:
            embedding_model (str, optional): Model name. Defaults to "text-embedding-3-small".
            api_key (_type_, optional): Auth key from OpenAI. Defaults to None.
        """
        self.model = embedding_model
        self.provider = "embedding-openai"

        api_key = api_key or os.getenv("OPENAI_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENAI_API_KEY"
            )
        self.api_key = api_key

    @backoff.on_exception(
        backoff.expo,
        (
            APIError,
            RateLimitError,
            APIConnectionError,
        ),
    )
    def get_embeddings(self, text: str) -> Tuple[np.ndarray, List[int], float]:
        client = OpenAI(api_key=self.api_key)
        response = client.embeddings.create(model=self.model, input=text)

        embeddings = np.array([data.embedding for data in response.data])
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)

        return embeddings, total_tokens, cost
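(Annotation, not part of the diff.) The embedding engines mirror the generate() contract: get_embeddings(text) returns (vectors, [input_tokens, 0, input_tokens], cost), where vectors is a numpy array with one row per input. A placeholder sketch:

    embedder = OpenAIEmbeddingEngine(api_key=os.getenv("OPENAI_API_KEY"))
    vectors, tokens, cost = embedder.get_embeddings("close the dialog")
    # vectors.shape == (1, 1536) for text-embedding-3-small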
class GeminiEmbeddingEngine(LMMEngine):
    def __init__(
        self,
        embedding_model: str = "text-embedding-004",
        api_key=None,
        **kwargs
    ):
        """Init an Gemini Embedding engine

        Args:
            embedding_model (str, optional): Model name. Defaults to "text-embedding-004".
            api_key (_type_, optional): Auth key from Gemini. Defaults to None.
        """
        self.model = embedding_model
        self.provider = "embedding-gemini"

        api_key = api_key or os.getenv("GEMINI_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named GEMINI_API_KEY"
            )
        self.api_key = api_key

    @backoff.on_exception(
        backoff.expo,
        (
            APIError,
            RateLimitError,
            APIConnectionError,
        ),
    )
    def get_embeddings(self, text: str) -> Tuple[np.ndarray, List[int], float]:
        client = genai.Client(api_key=self.api_key)

        result = client.models.embed_content(
            model=self.model,
            contents=text,
            config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
        )

        embeddings = np.array([i.values for i in result.embeddings])  # type: ignore
        total_tokens, cost = calculate_tokens_and_cost(result, self.provider, self.model)

        return embeddings, total_tokens, cost


class AzureOpenAIEmbeddingEngine(LMMEngine):
    def __init__(
        self,
        embedding_model: str = "text-embedding-3-small",
        api_key=None,
        api_version=None,
        endpoint_url=None,
        **kwargs
    ):
        """Init an Azure OpenAI Embedding engine

        Args:
            embedding_model (str, optional): Model name. Defaults to "text-embedding-3-small".
            api_key (_type_, optional): Auth key from Azure OpenAI. Defaults to None.
            api_version (_type_, optional): API version. Defaults to None.
            endpoint_url (_type_, optional): Endpoint URL. Defaults to None.
        """
        self.model = embedding_model
        self.provider = "embedding-azureopenai"

        api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named AZURE_OPENAI_API_KEY"
            )
        self.api_key = api_key

        api_version = api_version or os.getenv("OPENAI_API_VERSION")
        if api_version is None:
            raise ValueError(
                "An API Version needs to be provided in either the api_version parameter or as an environment variable named OPENAI_API_VERSION"
            )
        self.api_version = api_version

        endpoint_url = endpoint_url or os.getenv("AZURE_OPENAI_ENDPOINT")
        if endpoint_url is None:
            raise ValueError(
                "An Endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named AZURE_OPENAI_ENDPOINT"
            )
        self.endpoint_url = endpoint_url

    @backoff.on_exception(
        backoff.expo,
        (
            APIError,
            RateLimitError,
            APIConnectionError,
        ),
    )
    def get_embeddings(self, text: str) -> Tuple[np.ndarray, List[int], float]:
        client = AzureOpenAI(
            api_key=self.api_key,
            api_version=self.api_version,
            azure_endpoint=self.endpoint_url,
        )
        response = client.embeddings.create(input=text, model=self.model)

        embeddings = np.array([data.embedding for data in response.data])
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)

        return embeddings, total_tokens, cost


class DashScopeEmbeddingEngine(LMMEngine):
    def __init__(
        self,
        embedding_model: str = "text-embedding-v4",
        api_key=None,
        dimensions: int = 1024,
        **kwargs
    ):
        """Init a DashScope (阿里云百炼) Embedding engine

        Args:
            embedding_model (str, optional): Model name. Defaults to "text-embedding-v4".
            api_key (_type_, optional): Auth key from DashScope. Defaults to None.
            dimensions (int, optional): Embedding dimensions. Defaults to 1024.
        """
        self.model = embedding_model
        self.dimensions = dimensions
        self.provider = "embedding-qwen"

        api_key = api_key or os.getenv("DASHSCOPE_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named DASHSCOPE_API_KEY"
            )
        self.api_key = api_key

        # Initialize OpenAI client with DashScope base URL
        self.client = OpenAI(
            api_key=self.api_key,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
        )

    @backoff.on_exception(
        backoff.expo,
        (
            APIError,
            RateLimitError,
            APIConnectionError,
        ),
    )
    def get_embeddings(self, text: str) -> Tuple[np.ndarray, List[int], float]:
        response = self.client.embeddings.create(
            model=self.model,
            input=text,
            dimensions=self.dimensions,
            encoding_format="float"
        )

        embeddings = np.array([data.embedding for data in response.data])
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)

        return embeddings, total_tokens, cost


class DoubaoEmbeddingEngine(LMMEngine):
    def __init__(
        self,
        embedding_model: str = "doubao-embedding-256",
        api_key=None,
        **kwargs
    ):
        """Init a Doubao (字节跳动豆包) Embedding engine

        Args:
            embedding_model (str, optional): Model name. Defaults to "doubao-embedding-256".
            api_key (_type_, optional): Auth key from Doubao. Defaults to None.
        """
        self.model = embedding_model
        self.provider = "embedding-doubao"

        api_key = api_key or os.getenv("ARK_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named ARK_API_KEY"
            )
        self.api_key = api_key
        self.base_url = "https://ark.cn-beijing.volces.com/api/v3"

        # Use OpenAI-compatible client for text embeddings
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )

    @backoff.on_exception(
        backoff.expo,
        (
            APIError,
            RateLimitError,
            APIConnectionError,
        ),
    )
    def get_embeddings(self, text: str) -> Tuple[np.ndarray, List[int], float]:
        response = self.client.embeddings.create(
            model=self.model,
            input=text,
            encoding_format="float"
        )

        embeddings = np.array([data.embedding for data in response.data])
        total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)

        return embeddings, total_tokens, cost


class JinaEmbeddingEngine(LMMEngine):
    def __init__(
        self,
        embedding_model: str = "jina-embeddings-v4",
        api_key=None,
        task: str = "retrieval.query",
        **kwargs
    ):
        """Init a Jina AI Embedding engine

        Args:
            embedding_model (str, optional): Model name. Defaults to "jina-embeddings-v4".
            api_key (_type_, optional): Auth key from Jina AI. Defaults to None.
            task (str, optional): Task type. Options: "retrieval.query", "retrieval.passage", "text-matching". Defaults to "retrieval.query".
        """
        self.model = embedding_model
        self.task = task
        self.provider = "embedding-jina"

        api_key = api_key or os.getenv("JINA_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named JINA_API_KEY"
            )
        self.api_key = api_key
        self.base_url = "https://api.jina.ai/v1"

    @backoff.on_exception(
        backoff.expo,
        (
            APIError,
            RateLimitError,
            APIConnectionError,
        ),
    )
    def get_embeddings(self, text: str) -> Tuple[np.ndarray, List[int], float]:
        import requests

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        data = {
            "model": self.model,
            "task": self.task,
            "input": [
                {
                    "text": text
                }
            ]
        }

        response = requests.post(
            f"{self.base_url}/embeddings",
            headers=headers,
            json=data
        )

        if response.status_code != 200:
            raise Exception(f"Jina AI API error: {response.text}")

        result = response.json()
        embeddings = np.array([data["embedding"] for data in result["data"]])

        total_tokens, cost = calculate_tokens_and_cost(result, self.provider, self.model)

        return embeddings, total_tokens, cost
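(Annotation, not part of the diff.) Two details worth noting from the embedding engines above: DashScopeEmbeddingEngine exposes a dimensions knob (1024 by default) and reports itself as "embedding-qwen" for pricing, while JinaEmbeddingEngine bypasses the OpenAI client and posts directly to https://api.jina.ai/v1/embeddings. A hedged usage sketch:

    zh_embedder = DashScopeEmbeddingEngine(embedding_model="text-embedding-v4", dimensions=512)
    vectors, tokens, cost = zh_embedder.get_embeddings("close the dialog")  # expected vectors.shape == (1, 512)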
+# ==================== webSearch ====================
+class SearchEngine:
+    """Base class for search engines"""
+    pass
+
+class BochaAISearchEngine(SearchEngine):
+    def __init__(
+        self,
+        api_key: str|None = None,
+        base_url: str = "https://api.bochaai.com/v1",
+        rate_limit: int = -1,
+        **kwargs
+    ):
+        """Init a Bocha AI Search engine
+
+        Args:
+            api_key (str, optional): Auth key from Bocha AI. Defaults to None.
+            base_url (str, optional): Base URL for the API. Defaults to "https://api.bochaai.com/v1".
+            rate_limit (int, optional): Rate limit per minute. Defaults to -1 (no limit).
+        """
+        api_key = api_key or os.getenv("BOCHA_API_KEY")
+        if api_key is None:
+            raise ValueError(
+                "An API Key needs to be provided in either the api_key parameter or as an environment variable named BOCHA_API_KEY"
+            )
+
+        self.api_key = api_key
+        self.base_url = base_url
+        self.endpoint = f"{base_url}/ai-search"
+        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
+
+    @backoff.on_exception(
+        backoff.expo,
+        (
+            APIConnectionError,
+            APIError,
+            RateLimitError,
+            requests.exceptions.RequestException,
+        ),
+        max_time=60
+    )
+    def search(
+        self,
+        query: str,
+        freshness: str = "noLimit",
+        answer: bool = True,
+        stream: bool = False,
+        **kwargs
+    ) -> Union[Dict[str, Any], Any]:
+        """Search with AI and return intelligent answer
+
+        Args:
+            query (str): Search query
+            freshness (str, optional): Freshness filter. Defaults to "noLimit".
+            answer (bool, optional): Whether to return answer. Defaults to True.
+            stream (bool, optional): Whether to stream response. Defaults to False.
+
+        Returns:
+            Union[Dict[str, Any], Any]: AI search results with sources and answer
+        """
+        headers = {
+            'Authorization': f'Bearer {self.api_key}',
+            'Content-Type': 'application/json'
+        }
+
+        payload = {
+            "query": query,
+            "freshness": freshness,
+            "answer": answer,
+            "stream": stream,
+            **kwargs
+        }
+
+        if stream:
+            result = self._stream_search(headers, payload)
+            return result, [0, 0, 0], 0.06
+        else:
+            result = self._regular_search(headers, payload)
+            return result, [0, 0, 0], 0.06
+
+
+    def _regular_search(self, headers: Dict[str, str], payload: Dict[str, Any]) -> Dict[str, Any]:
+        """Regular non-streaming search"""
+        response = requests.post(
+            self.endpoint,
+            headers=headers,
+            json=payload
+        )
+
+        if response.status_code != 200:
+            raise APIError(f"Bocha AI Search API error: {response.text}") # type: ignore
+
+        return response.json()
+
+    def _stream_search(self, headers: Dict[str, str], payload: Dict[str, Any]):
+        """Streaming search response"""
+        response = requests.post(
+            self.endpoint,
+            headers=headers,
+            json=payload,
+            stream=True
+        )
+
+        if response.status_code != 200:
+            raise APIError(f"Bocha AI Search API error: {response.text}") # type: ignore
+
+        for line in response.iter_lines():
+            if line:
+                line = line.decode('utf-8')
+                if line.startswith('data:'):
+                    data = line[5:].strip()
+                    if data and data != '{"event":"done"}':
+                        try:
+                            yield json.loads(data)
+                        except json.JSONDecodeError:
+                            continue
+
+    def get_answer(self, query: str, **kwargs) -> Tuple[str, int, float]:
+        """Get AI generated answer only"""
+        result, _, remaining_balance = self.search(query, answer=True, **kwargs)
+
+        # Extract answer from messages
+        messages = result.get("messages", []) # type: ignore
+        answer = ""
+        for message in messages:
+            if message.get("type") == "answer":
+                answer = message.get("content", "")
+                break
+
+        return answer, [0,0,0], remaining_balance # type: ignore
+
+
+    def get_sources(self, query: str, **kwargs) -> List[Dict[str, Any]]:
+        """Get source materials only"""
+        result, _, remaining_balance = self.search(query, **kwargs)
+
+        # Extract sources from messages
+        sources = []
+        messages = result.get("messages", []) # type: ignore
+        for message in messages:
+            if message.get("type") == "source":
+                content_type = message.get("content_type", "")
+                if content_type in ["webpage", "image", "video", "baike_pro", "medical_common"]:
+                    sources.append({
+                        "type": content_type,
+                        "content": json.loads(message.get("content", "{}"))
+                    })
+
+        return sources, 0, remaining_balance # type: ignore
+
+
+    def get_follow_up_questions(self, query: str, **kwargs) -> List[str]:
+        """Get follow-up questions"""
+        result, _, remaining_balance = self.search(query, **kwargs)
+
+        # Extract follow-up questions from messages
+        follow_ups = []
+        messages = result.get("messages", []) # type: ignore
+        for message in messages:
+            if message.get("type") == "follow_up":
+                follow_ups.append(message.get("content", ""))
+
+        return follow_ups, 0, remaining_balance # type: ignore
+
+
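BochaAISearchEngine, added above, wraps Bocha's /ai-search endpoint: search() posts the query (optionally as a server-sent-event stream), and the helper methods pull the answer, sources, and follow-up questions out of the messages array by type, each returning a (value, token_counts, cost) triple. A minimal usage sketch, assuming BOCHA_API_KEY is exported and the class is imported from the module this diff adds (import path assumed):

# Reads BOCHA_API_KEY from the environment.
engine = BochaAISearchEngine()

# Each call returns (value, token_counts, cost); the cost is the flat 0.06 from search().
answer, _tokens, cost = engine.get_answer("What is the capital of France?")
sources, _, _ = engine.get_sources("What is the capital of France?")
follow_ups, _, _ = engine.get_follow_up_questions("What is the capital of France?")

print(answer)
print([s["type"] for s in sources])
print(follow_ups)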
+class ExaResearchEngine(SearchEngine):
+    def __init__(
+        self,
+        api_key: str|None = None,
+        base_url: str = "https://api.exa.ai",
+        rate_limit: int = -1,
+        **kwargs
+    ):
+        """Init an Exa Research engine
+
+        Args:
+            api_key (str, optional): Auth key from Exa AI. Defaults to None.
+            base_url (str, optional): Base URL for the API. Defaults to "https://api.exa.ai".
+            rate_limit (int, optional): Rate limit per minute. Defaults to -1 (no limit).
+        """
+        api_key = api_key or os.getenv("EXA_API_KEY")
+        if api_key is None:
+            raise ValueError(
+                "An API Key needs to be provided in either the api_key parameter or as an environment variable named EXA_API_KEY"
+            )
+
+        self.api_key = api_key
+        self.base_url = base_url
+        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
+
+        # Initialize OpenAI-compatible client for chat completions
+        self.chat_client = OpenAI(
+            base_url=base_url,
+            api_key=api_key
+        )
+
+        # Initialize Exa client for research tasks
+        try:
+            from exa_py import Exa
+            self.exa_client = Exa(api_key=api_key)
+        except ImportError:
+            self.exa_client = None
+            print("Warning: exa_py not installed. Research tasks will not be available.")
+
+    @backoff.on_exception(
+        backoff.expo,
+        (
+            APIConnectionError,
+            APIError,
+            RateLimitError,
+        ),
+        max_time=60
+    )
+    def search(self, query: str, **kwargs):
+        """Standard Exa search with direct cost from API
+
+        Args:
+            query (str): Search query
+            **kwargs: Additional search parameters
+
+        Returns:
+            tuple: (result, tokens, cost) where cost is actual API cost
+        """
+        headers = {
+            'x-api-key': self.api_key,
+            'Content-Type': 'application/json'
+        }
+
+        payload = {
+            "query": query,
+            **kwargs
+        }
+
+        response = requests.post(
+            f"{self.base_url}/search",
+            headers=headers,
+            json=payload
+        )
+
+        if response.status_code != 200:
+            raise APIError(f"Exa Search API error: {response.text}") # type: ignore
+
+        result = response.json()
+
+        cost = 0.0
+        if "costDollars" in result:
+            cost = result["costDollars"].get("total", 0.0)
+
+        return result, [0, 0, 0], cost
+
+    def chat_research(
+        self,
+        query: str,
+        model: str = "exa",
+        stream: bool = False,
+        **kwargs
+    ) -> Union[str, Any]:
+        """Research using chat completions interface
+
+        Args:
+            query (str): Research query
+            model (str, optional): Model name. Defaults to "exa".
+            stream (bool, optional): Whether to stream response. Defaults to False.
+
+        Returns:
+            Union[str, Any]: Research result or stream
+        """
+        messages = [
+            {"role": "user", "content": query}
+        ]
+
+        if stream:
+            completion = self.chat_client.chat.completions.create(
+                model=model,
+                messages=messages, # type: ignore
+                stream=True,
+                **kwargs
+            )
+            return completion
+        else:
+            completion = self.chat_client.chat.completions.create(
+                model=model,
+                messages=messages, # type: ignore
+                **kwargs
+            )
+            result = completion.choices[0].message.content # type: ignore
+            return result,[0,0,0],0.005
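ExaResearchEngine, added above, offers two paths: a raw POST to /search whose cost is read back from the response's costDollars.total, and chat_research(), which goes through an OpenAI-compatible client pointed at the exa model. A minimal usage sketch, assuming EXA_API_KEY is exported; numResults and the results/title/url fields follow Exa's public search schema and are assumptions here, not values taken from the package:

# Reads EXA_API_KEY from the environment; import path for the class assumed.
engine = ExaResearchEngine()

# Extra keyword arguments are forwarded into the search payload.
result, _tokens, cost = engine.search("GUI agent benchmarks", numResults=5)
for hit in result.get("results", []):
    print(hit.get("title"), hit.get("url"))
print(f"reported cost: ${cost}")

# Non-streaming chat_research returns (text, token_counts, cost).
summary, _, _ = engine.chat_research("Summarize recent work on GUI agents")
print(summary)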