lybic-guiagents 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of lybic-guiagents might be problematic.

Files changed (85)
  1. desktop_env/__init__.py +1 -0
  2. desktop_env/actions.py +203 -0
  3. desktop_env/controllers/__init__.py +0 -0
  4. desktop_env/controllers/python.py +471 -0
  5. desktop_env/controllers/setup.py +882 -0
  6. desktop_env/desktop_env.py +509 -0
  7. desktop_env/evaluators/__init__.py +5 -0
  8. desktop_env/evaluators/getters/__init__.py +41 -0
  9. desktop_env/evaluators/getters/calc.py +15 -0
  10. desktop_env/evaluators/getters/chrome.py +1774 -0
  11. desktop_env/evaluators/getters/file.py +154 -0
  12. desktop_env/evaluators/getters/general.py +42 -0
  13. desktop_env/evaluators/getters/gimp.py +38 -0
  14. desktop_env/evaluators/getters/impress.py +126 -0
  15. desktop_env/evaluators/getters/info.py +24 -0
  16. desktop_env/evaluators/getters/misc.py +406 -0
  17. desktop_env/evaluators/getters/replay.py +20 -0
  18. desktop_env/evaluators/getters/vlc.py +86 -0
  19. desktop_env/evaluators/getters/vscode.py +35 -0
  20. desktop_env/evaluators/metrics/__init__.py +160 -0
  21. desktop_env/evaluators/metrics/basic_os.py +68 -0
  22. desktop_env/evaluators/metrics/chrome.py +493 -0
  23. desktop_env/evaluators/metrics/docs.py +1011 -0
  24. desktop_env/evaluators/metrics/general.py +665 -0
  25. desktop_env/evaluators/metrics/gimp.py +637 -0
  26. desktop_env/evaluators/metrics/libreoffice.py +28 -0
  27. desktop_env/evaluators/metrics/others.py +92 -0
  28. desktop_env/evaluators/metrics/pdf.py +31 -0
  29. desktop_env/evaluators/metrics/slides.py +957 -0
  30. desktop_env/evaluators/metrics/table.py +585 -0
  31. desktop_env/evaluators/metrics/thunderbird.py +176 -0
  32. desktop_env/evaluators/metrics/utils.py +719 -0
  33. desktop_env/evaluators/metrics/vlc.py +524 -0
  34. desktop_env/evaluators/metrics/vscode.py +283 -0
  35. desktop_env/providers/__init__.py +35 -0
  36. desktop_env/providers/aws/__init__.py +0 -0
  37. desktop_env/providers/aws/manager.py +278 -0
  38. desktop_env/providers/aws/provider.py +186 -0
  39. desktop_env/providers/aws/provider_with_proxy.py +315 -0
  40. desktop_env/providers/aws/proxy_pool.py +193 -0
  41. desktop_env/providers/azure/__init__.py +0 -0
  42. desktop_env/providers/azure/manager.py +87 -0
  43. desktop_env/providers/azure/provider.py +207 -0
  44. desktop_env/providers/base.py +97 -0
  45. desktop_env/providers/gcp/__init__.py +0 -0
  46. desktop_env/providers/gcp/manager.py +0 -0
  47. desktop_env/providers/gcp/provider.py +0 -0
  48. desktop_env/providers/virtualbox/__init__.py +0 -0
  49. desktop_env/providers/virtualbox/manager.py +463 -0
  50. desktop_env/providers/virtualbox/provider.py +124 -0
  51. desktop_env/providers/vmware/__init__.py +0 -0
  52. desktop_env/providers/vmware/manager.py +455 -0
  53. desktop_env/providers/vmware/provider.py +105 -0
  54. gui_agents/__init__.py +0 -0
  55. gui_agents/agents/Action.py +209 -0
  56. gui_agents/agents/__init__.py +0 -0
  57. gui_agents/agents/agent_s.py +832 -0
  58. gui_agents/agents/global_state.py +610 -0
  59. gui_agents/agents/grounding.py +651 -0
  60. gui_agents/agents/hardware_interface.py +129 -0
  61. gui_agents/agents/manager.py +568 -0
  62. gui_agents/agents/translator.py +132 -0
  63. gui_agents/agents/worker.py +355 -0
  64. gui_agents/cli_app.py +560 -0
  65. gui_agents/core/__init__.py +0 -0
  66. gui_agents/core/engine.py +1496 -0
  67. gui_agents/core/knowledge.py +449 -0
  68. gui_agents/core/mllm.py +555 -0
  69. gui_agents/tools/__init__.py +0 -0
  70. gui_agents/tools/tools.py +727 -0
  71. gui_agents/unit_test/__init__.py +0 -0
  72. gui_agents/unit_test/run_tests.py +65 -0
  73. gui_agents/unit_test/test_manager.py +330 -0
  74. gui_agents/unit_test/test_worker.py +269 -0
  75. gui_agents/utils/__init__.py +0 -0
  76. gui_agents/utils/analyze_display.py +301 -0
  77. gui_agents/utils/common_utils.py +263 -0
  78. gui_agents/utils/display_viewer.py +281 -0
  79. gui_agents/utils/embedding_manager.py +53 -0
  80. gui_agents/utils/image_axis_utils.py +27 -0
  81. lybic_guiagents-0.1.0.dist-info/METADATA +416 -0
  82. lybic_guiagents-0.1.0.dist-info/RECORD +85 -0
  83. lybic_guiagents-0.1.0.dist-info/WHEEL +5 -0
  84. lybic_guiagents-0.1.0.dist-info/licenses/LICENSE +201 -0
  85. lybic_guiagents-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1496 @@
+ import os
+ import json
+ import backoff
+ import requests
+ from typing import List, Dict, Any, Optional, Union, Tuple
+ import numpy as np
+ from anthropic import Anthropic
+ from openai import (
+     AzureOpenAI,
+     APIConnectionError,
+     APIError,
+     OpenAI,
+     RateLimitError,
+ )
+ from google import genai
+ from google.genai import types
+ from zhipuai import ZhipuAI
+ from groq import Groq
+ import boto3
+ import exa_py
+
+ class ModelPricing:
+     def __init__(self, pricing_file: str = "model_pricing.json"):
+         self.pricing_file = pricing_file
+         self.pricing_data = self._load_pricing()
+
+     def _load_pricing(self) -> Dict:
+         if os.path.exists(self.pricing_file):
+             try:
+                 with open(self.pricing_file, 'r', encoding='utf-8') as f:
+                     return json.load(f)
+             except Exception as e:
+                 print(f"Warning: Failed to load pricing file {self.pricing_file}: {e}")
+
+         return {
+             "default": {"input": 0, "output": 0}
+         }
+
+     def get_price(self, model: str) -> Dict[str, float]:
+         # Handle nested pricing data structure
+         if "llm_models" in self.pricing_data:
+             # Iterate through all LLM model categories
+             for category, models in self.pricing_data["llm_models"].items():
+                 # Direct model name matching
+                 if model in models:
+                     pricing = models[model]
+                     return self._parse_pricing(pricing)
+
+                 # Fuzzy matching for model names
+                 for model_name in models:
+                     if model_name in model or model in model_name:
+                         pricing = models[model_name]
+                         return self._parse_pricing(pricing)
+
+         # Handle embedding models
+         if "embedding_models" in self.pricing_data:
+             for category, models in self.pricing_data["embedding_models"].items():
+                 if model in models:
+                     pricing = models[model]
+                     return self._parse_pricing(pricing)
+
+                 for model_name in models:
+                     if model_name in model or model in model_name:
+                         pricing = models[model_name]
+                         return self._parse_pricing(pricing)
+
+         # Default pricing
+         return {"input": 0, "output": 0}
+
+     def _parse_pricing(self, pricing: Dict[str, str]) -> Dict[str, float]:
+         """Parse pricing strings and convert to numeric values"""
+         result = {}
+
+         for key, value in pricing.items():
+             if isinstance(value, str):
+                 # Remove currency symbols and units, convert to float
+                 clean_value = value.replace('$', '').replace('¥', '').replace(',', '')
+                 try:
+                     result[key] = float(clean_value)
+                 except ValueError:
+                     result[key] = 0.0
+             else:
+                 result[key] = float(value) if value else 0.0
+
+         return result
+
+     def calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
+         pricing = self.get_price(model)
+         input_cost = (input_tokens / 1000000) * pricing["input"]
+         output_cost = (output_tokens / 1000000) * pricing["output"]
+         return input_cost + output_cost
+
+
+ # Initialize pricing manager with correct pricing file path
+ pricing_file = os.path.join(os.path.dirname(__file__), 'model_pricing.json')
+ pricing_manager = ModelPricing(pricing_file)
+
+
+ def extract_token_usage(response, provider: str) -> Tuple[int, int]:
+     if "-" in provider:
+         api_type, vendor = provider.split("-", 1)
+     else:
+         api_type, vendor = "llm", provider
+
+     if api_type == "llm":
+         if vendor in ["openai", "qwen", "deepseek", "doubao", "siliconflow", "monica", "vllm", "groq", "zhipu", "gemini", "openrouter", "azureopenai", "huggingface", "exa"]:
+             if hasattr(response, 'usage') and response.usage:
+                 return response.usage.prompt_tokens, response.usage.completion_tokens
+
+         elif vendor == "anthropic":
+             if hasattr(response, 'usage') and response.usage:
+                 return response.usage.input_tokens, response.usage.output_tokens
+
+         elif vendor == "bedrock":
+             if isinstance(response, dict) and "usage" in response:
+                 usage = response["usage"]
+                 return usage.get("input_tokens", 0), usage.get("output_tokens", 0)
+
+     elif api_type == "embedding":
+         if vendor in ["openai", "azureopenai", "qwen", "doubao"]:
+             if hasattr(response, 'usage') and response.usage:
+                 return response.usage.prompt_tokens, 0
+
+         elif vendor == "jina":
+             if isinstance(response, dict) and "usage" in response:
+                 total_tokens = response["usage"].get("total_tokens", 0)
+                 return total_tokens, 0
+
+         elif vendor == "gemini":
+             if hasattr(response, 'usage') and response.usage:
+                 return response.usage.prompt_tokens, 0
+
+     return 0, 0
+
+
+ def calculate_tokens_and_cost(response, provider: str, model: str) -> Tuple[List[int], float]:
+     input_tokens, output_tokens = extract_token_usage(response, provider)
+     total_tokens = input_tokens + output_tokens
+     cost = pricing_manager.calculate_cost(model, input_tokens, output_tokens)
+
+     return [input_tokens, output_tokens, total_tokens], cost
+
+ class LMMEngine:
143
+ pass
144
+
145
+ # ==================== LLM ====================
146
+
147
+ class LMMEngineOpenAI(LMMEngine):
148
+ def __init__(
149
+ self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
150
+ ):
151
+ assert model is not None, "model must be provided"
152
+ self.model = model
153
+ self.provider = "llm-openai"
154
+
155
+ api_key = api_key or os.getenv("OPENAI_API_KEY")
156
+ if api_key is None:
157
+ raise ValueError(
158
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENAI_API_KEY"
159
+ )
160
+
161
+ self.base_url = base_url
162
+
163
+ self.api_key = api_key
164
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
165
+
166
+ if not self.base_url:
167
+ self.llm_client = OpenAI(api_key=self.api_key)
168
+ else:
169
+ self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)
170
+
171
+ @backoff.on_exception(
172
+ backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
173
+ )
174
+ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
175
+ """Generate the next message based on previous messages"""
176
+ response = self.llm_client.chat.completions.create(
177
+ model=self.model,
178
+ messages=messages,
179
+ max_tokens=max_new_tokens if max_new_tokens else 4096,
180
+ temperature=temperature,
181
+ **kwargs,
182
+ )
183
+
184
+ content = response.choices[0].message.content
185
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
186
+
187
+ return content, total_tokens, cost
188
+
189
+
190
+ class LMMEngineQwen(LMMEngine):
191
+ def __init__(
192
+ self, base_url=None, api_key=None, model=None, rate_limit=-1, enable_thinking=False, **kwargs
193
+ ):
194
+ assert model is not None, "model must be provided"
195
+ self.model = model
196
+ self.enable_thinking = enable_thinking
197
+ self.provider = "llm-qwen"
198
+
199
+ api_key = api_key or os.getenv("DASHSCOPE_API_KEY")
200
+ if api_key is None:
201
+ raise ValueError(
202
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named DASHSCOPE_API_KEY"
203
+ )
204
+
205
+ self.base_url = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
206
+ self.api_key = api_key
207
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
208
+
209
+ self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)
210
+
211
+ @backoff.on_exception(
212
+ backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
213
+ )
214
+ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
215
+ """Generate the next message based on previous messages"""
216
+ # For Qwen3 models, we need to handle thinking mode
217
+ extra_body = {}
218
+ if self.model.startswith("qwen3") and not self.enable_thinking:
219
+ extra_body["enable_thinking"] = False
220
+
221
+ response = self.llm_client.chat.completions.create(
222
+ model=self.model,
223
+ messages=messages,
224
+ max_tokens=max_new_tokens if max_new_tokens else 4096,
225
+ temperature=temperature,
226
+ **extra_body,
227
+ **kwargs,
228
+ )
229
+
230
+ content = response.choices[0].message.content
231
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
232
+
233
+ return content, total_tokens, cost
234
+
235
+
236
+ class LMMEngineDoubao(LMMEngine):
237
+ def __init__(
238
+ self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
239
+ ):
240
+ assert model is not None, "model must be provided"
241
+ self.model = model
242
+ self.provider = "llm-doubao"
243
+
244
+ api_key = api_key or os.getenv("ARK_API_KEY")
245
+ if api_key is None:
246
+ raise ValueError(
247
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named ARK_API_KEY"
248
+ )
249
+
250
+ self.base_url = base_url or "https://ark.cn-beijing.volces.com/api/v3"
251
+ self.api_key = api_key
252
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
253
+
254
+ self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)
255
+
256
+ @backoff.on_exception(
257
+ backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
258
+ )
259
+ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
260
+ """Generate the next message based on previous messages"""
261
+ response = self.llm_client.chat.completions.create(
262
+ model=self.model,
263
+ messages=messages,
264
+ max_tokens=max_new_tokens if max_new_tokens else 4096,
265
+ temperature=temperature,
266
+ extra_body={
267
+ "thinking": {
268
+ "type": "disabled",
269
+ # "type": "enabled",
270
+ # "type": "auto",
271
+ }
272
+ },
273
+ **kwargs,
274
+ )
275
+
276
+ content = response.choices[0].message.content
277
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
278
+
279
+ return content, total_tokens, cost
280
+
281
+
282
+ class LMMEngineAnthropic(LMMEngine):
283
+ def __init__(
284
+ self, base_url=None, api_key=None, model=None, thinking=False, **kwargs
285
+ ):
286
+ assert model is not None, "model must be provided"
287
+ self.model = model
288
+ self.thinking = thinking
289
+ self.provider = "llm-anthropic"
290
+
291
+ api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
292
+ if api_key is None:
293
+ raise ValueError(
294
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named ANTHROPIC_API_KEY"
295
+ )
296
+
297
+ self.api_key = api_key
298
+
299
+ self.llm_client = Anthropic(api_key=self.api_key)
300
+
301
+ @backoff.on_exception(
302
+ backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
303
+ )
304
+ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
305
+ """Generate the next message based on previous messages"""
306
+ if self.thinking:
307
+ response = self.llm_client.messages.create(
308
+ system=messages[0]["content"][0]["text"],
309
+ model=self.model,
310
+ messages=messages[1:],
311
+ max_tokens=8192,
312
+ thinking={"type": "enabled", "budget_tokens": 4096},
313
+ **kwargs,
314
+ )
315
+ thoughts = response.content[0].thinking
316
+ print("CLAUDE 3.7 THOUGHTS:", thoughts)
317
+ content = response.content[1].text
318
+ else:
319
+ response = self.llm_client.messages.create(
320
+ system=messages[0]["content"][0]["text"],
321
+ model=self.model,
322
+ messages=messages[1:],
323
+ max_tokens=max_new_tokens if max_new_tokens else 4096,
324
+ temperature=temperature,
325
+ **kwargs,
326
+ )
327
+ content = response.content[0].text
328
+
329
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
330
+ return content, total_tokens, cost
331
+
332
+
333
+ class LMMEngineGemini(LMMEngine):
334
+ def __init__(
335
+ self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
336
+ ):
337
+ assert model is not None, "model must be provided"
338
+ self.model = model
339
+ self.provider = "llm-gemini"
340
+
341
+ api_key = api_key or os.getenv("GEMINI_API_KEY")
342
+ if api_key is None:
343
+ raise ValueError(
344
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named GEMINI_API_KEY"
345
+ )
346
+
347
+ self.base_url = base_url or os.getenv("GEMINI_ENDPOINT_URL")
348
+ if self.base_url is None:
349
+ raise ValueError(
350
+ "An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named GEMINI_ENDPOINT_URL"
351
+ )
352
+
353
+ self.api_key = api_key
354
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
355
+
356
+ self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)
357
+
358
+ @backoff.on_exception(
359
+ backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
360
+ )
361
+ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
362
+ """Generate the next message based on previous messages"""
363
+ response = self.llm_client.chat.completions.create(
364
+ model=self.model,
365
+ messages=messages,
366
+ max_tokens=max_new_tokens if max_new_tokens else 4096,
367
+ temperature=temperature,
368
+ # reasoning_effort="low",
369
+ extra_body={
370
+ 'extra_body': {
371
+ "google": {
372
+ "thinking_config": {
373
+ "thinking_budget": 128,
374
+ "include_thoughts": True
375
+ }
376
+ }
377
+ }
378
+ },
379
+ **kwargs,
380
+ )
381
+
382
+ content = response.choices[0].message.content
383
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
384
+
385
+ return content, total_tokens, cost
386
+
387
+
388
+
389
+ class LMMEngineOpenRouter(LMMEngine):
390
+ def __init__(
391
+ self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
392
+ ):
393
+ assert model is not None, "model must be provided"
394
+ self.model = model
395
+ self.provider = "llm-openrouter"
396
+
397
+ api_key = api_key or os.getenv("OPENROUTER_API_KEY")
398
+ if api_key is None:
399
+ raise ValueError(
400
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENROUTER_API_KEY"
401
+ )
402
+
403
+ self.base_url = base_url or os.getenv("OPEN_ROUTER_ENDPOINT_URL")
404
+ if self.base_url is None:
405
+ raise ValueError(
406
+ "An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named OPEN_ROUTER_ENDPOINT_URL"
407
+ )
408
+
409
+ self.api_key = api_key
410
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
411
+
412
+ self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)
413
+
414
+ @backoff.on_exception(
415
+ backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
416
+ )
417
+ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
418
+ """Generate the next message based on previous messages"""
419
+ response = self.llm_client.chat.completions.create(
420
+ model=self.model,
421
+ messages=messages,
422
+ max_tokens=max_new_tokens if max_new_tokens else 4096,
423
+ temperature=temperature,
424
+ **kwargs,
425
+ )
426
+
427
+ content = response.choices[0].message.content
428
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
429
+
430
+ return content, total_tokens, cost
431
+
432
+
433
+ class LMMEngineAzureOpenAI(LMMEngine):
434
+ def __init__(
435
+ self,
436
+ base_url=None,
437
+ api_key=None,
438
+ azure_endpoint=None,
439
+ model=None,
440
+ api_version=None,
441
+ rate_limit=-1,
442
+ **kwargs
443
+ ):
444
+ assert model is not None, "model must be provided"
445
+ self.model = model
446
+ self.provider = "llm-azureopenai"
447
+
448
+ assert api_version is not None, "api_version must be provided"
449
+ self.api_version = api_version
450
+
451
+ api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
452
+ if api_key is None:
453
+ raise ValueError(
454
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named AZURE_OPENAI_API_KEY"
455
+ )
456
+
457
+ self.api_key = api_key
458
+
459
+ azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
460
+ if azure_endpoint is None:
461
+ raise ValueError(
462
+ "An Azure API endpoint needs to be provided in either the azure_endpoint parameter or as an environment variable named AZURE_OPENAI_ENDPOINT"
463
+ )
464
+
465
+ self.azure_endpoint = azure_endpoint
466
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
467
+
468
+ self.llm_client = AzureOpenAI(
469
+ azure_endpoint=self.azure_endpoint,
470
+ api_key=self.api_key,
471
+ api_version=self.api_version,
472
+ )
473
+ self.cost = 0.0
474
+
475
+ # @backoff.on_exception(backoff.expo, (APIConnectionError, APIError, RateLimitError), max_tries=10)
476
+ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
477
+ """Generate the next message based on previous messages"""
478
+ response = self.llm_client.chat.completions.create(
479
+ model=self.model,
480
+ messages=messages,
481
+ max_tokens=max_new_tokens if max_new_tokens else 4096,
482
+ temperature=temperature,
483
+ **kwargs,
484
+ )
485
+ content = response.choices[0].message.content
486
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
487
+ return content, total_tokens, cost
488
+
489
+
490
+ class LMMEnginevLLM(LMMEngine):
491
+ def __init__(
492
+ self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
493
+ ):
494
+ assert model is not None, "model must be provided"
495
+ self.model = model
496
+ self.api_key = api_key
497
+ self.provider = "llm-vllm"
498
+
499
+ self.base_url = base_url or os.getenv("vLLM_ENDPOINT_URL")
500
+ if self.base_url is None:
501
+ raise ValueError(
502
+ "An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named vLLM_ENDPOINT_URL"
503
+ )
504
+
505
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
506
+
507
+ self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)
508
+
509
+ # @backoff.on_exception(backoff.expo, (APIConnectionError, APIError, RateLimitError), max_tries=10)
510
+ # TODO: Default params chosen for the Qwen model
511
+ def generate(
512
+ self,
513
+ messages,
514
+ temperature=0.0,
515
+ top_p=0.8,
516
+ repetition_penalty=1.05,
517
+ max_new_tokens=512,
518
+ **kwargs
519
+ ):
520
+ """Generate the next message based on previous messages"""
521
+ response = self.llm_client.chat.completions.create(
522
+ model=self.model,
523
+ messages=messages,
524
+ max_tokens=max_new_tokens if max_new_tokens else 4096,
525
+ temperature=temperature,
526
+ top_p=top_p,
527
+ extra_body={"repetition_penalty": repetition_penalty},
528
+ )
529
+ content = response.choices[0].message.content
530
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
531
+ return content, total_tokens, cost
532
+
533
+
534
+ class LMMEngineHuggingFace(LMMEngine):
535
+ def __init__(self, base_url=None, api_key=None, rate_limit=-1, **kwargs):
536
+ assert base_url is not None, "HuggingFace endpoint must be provided"
537
+ self.base_url = base_url
538
+ self.model = base_url.split('/')[-1] if base_url else "huggingface-tgi"
539
+ self.provider = "llm-huggingface"
540
+
541
+ api_key = api_key or os.getenv("HF_TOKEN")
542
+ if api_key is None:
543
+ raise ValueError(
544
+ "A HuggingFace token needs to be provided in either the api_key parameter or as an environment variable named HF_TOKEN"
545
+ )
546
+
547
+ self.api_key = api_key
548
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
549
+
550
+ self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)
551
+
552
+ @backoff.on_exception(
553
+ backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
554
+ )
555
+ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
556
+ """Generate the next message based on previous messages"""
557
+ response = self.llm_client.chat.completions.create(
558
+ model="tgi",
559
+ messages=messages,
560
+ max_tokens=max_new_tokens if max_new_tokens else 4096,
561
+ temperature=temperature,
562
+ **kwargs,
563
+ )
564
+
565
+ content = response.choices[0].message.content
566
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
567
+
568
+ return content, total_tokens, cost
569
+
570
+
571
+ class LMMEngineDeepSeek(LMMEngine):
572
+ def __init__(
573
+ self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
574
+ ):
575
+ assert model is not None, "model must be provided"
576
+ self.model = model
577
+ self.provider = "llm-deepseek"
578
+
579
+ api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
580
+ if api_key is None:
581
+ raise ValueError(
582
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named DEEPSEEK_API_KEY"
583
+ )
584
+
585
+ self.base_url = base_url or "https://api.deepseek.com"
586
+ self.api_key = api_key
587
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
588
+
589
+ self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)
590
+
591
+ @backoff.on_exception(
592
+ backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
593
+ )
594
+ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
595
+ """Generate the next message based on previous messages"""
596
+ response = self.llm_client.chat.completions.create(
597
+ model=self.model,
598
+ messages=messages,
599
+ max_tokens=max_new_tokens if max_new_tokens else 4096,
600
+ temperature=temperature,
601
+ **kwargs,
602
+ )
603
+
604
+ content = response.choices[0].message.content
605
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
606
+ return content, total_tokens, cost
607
+
608
+
609
+ class LMMEngineZhipu(LMMEngine):
610
+ def __init__(
611
+ self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
612
+ ):
613
+ assert model is not None, "model must be provided"
614
+ self.model = model
615
+ self.provider = "llm-zhipu"
616
+
617
+ api_key = api_key or os.getenv("ZHIPU_API_KEY")
618
+ if api_key is None:
619
+ raise ValueError(
620
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named ZHIPU_API_KEY"
621
+ )
622
+
623
+ self.api_key = api_key
624
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
625
+
626
+ # Use ZhipuAI client directly instead of OpenAI compatibility layer
627
+ self.llm_client = ZhipuAI(api_key=self.api_key)
628
+
629
+ @backoff.on_exception(
630
+ backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
631
+ )
632
+ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
633
+ """Generate the next message based on previous messages"""
634
+ response = self.llm_client.chat.completions.create(
635
+ model=self.model,
636
+ messages=messages,
637
+ max_tokens=max_new_tokens if max_new_tokens else 4096,
638
+ temperature=temperature,
639
+ **kwargs,
640
+ )
641
+
642
+ content = response.choices[0].message.content # type: ignore
643
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
644
+ return content, total_tokens, cost
645
+
646
+
647
+
648
+ class LMMEngineGroq(LMMEngine):
649
+ def __init__(
650
+ self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
651
+ ):
652
+ assert model is not None, "model must be provided"
653
+ self.model = model
654
+ self.provider = "llm-groq"
655
+
656
+ api_key = api_key or os.getenv("GROQ_API_KEY")
657
+ if api_key is None:
658
+ raise ValueError(
659
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named GROQ_API_KEY"
660
+ )
661
+
662
+ self.api_key = api_key
663
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
664
+
665
+ # Use Groq client directly
666
+ self.llm_client = Groq(api_key=self.api_key)
667
+
668
+ @backoff.on_exception(
669
+ backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
670
+ )
671
+ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
672
+ """Generate the next message based on previous messages"""
673
+ response = self.llm_client.chat.completions.create(
674
+ model=self.model,
675
+ messages=messages,
676
+ max_tokens=max_new_tokens if max_new_tokens else 4096,
677
+ temperature=temperature,
678
+ **kwargs,
679
+ )
680
+
681
+ content = response.choices[0].message.content
682
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
683
+ return content, total_tokens, cost
684
+
685
+
686
+ class LMMEngineSiliconflow(LMMEngine):
687
+ def __init__(
688
+ self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
689
+ ):
690
+ assert model is not None, "model must be provided"
691
+ self.model = model
692
+ self.provider = "llm-siliconflow"
693
+
694
+ api_key = api_key or os.getenv("SILICONFLOW_API_KEY")
695
+ if api_key is None:
696
+ raise ValueError(
697
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named SILICONFLOW_API_KEY"
698
+ )
699
+
700
+ self.base_url = base_url or "https://api.siliconflow.cn/v1"
701
+ self.api_key = api_key
702
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
703
+
704
+ self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)
705
+
706
+ @backoff.on_exception(
707
+ backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
708
+ )
709
+ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
710
+ """Generate the next message based on previous messages"""
711
+ response = self.llm_client.chat.completions.create(
712
+ model=self.model,
713
+ messages=messages,
714
+ max_tokens=max_new_tokens if max_new_tokens else 4096,
715
+ temperature=temperature,
716
+ **kwargs,
717
+ )
718
+
719
+ content = response.choices[0].message.content
720
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
721
+ return content, total_tokens, cost
722
+
723
+
724
+ class LMMEngineMonica(LMMEngine):
725
+ def __init__(
726
+ self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
727
+ ):
728
+ assert model is not None, "model must be provided"
729
+ self.model = model
730
+ self.provider = "llm-monica"
731
+
732
+ api_key = api_key or os.getenv("MONICA_API_KEY")
733
+ if api_key is None:
734
+ raise ValueError(
735
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named MONICA_API_KEY"
736
+ )
737
+
738
+ self.base_url = base_url or "https://openapi.monica.im/v1"
739
+ self.api_key = api_key
740
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
741
+
742
+ self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key)
743
+
744
+ @backoff.on_exception(
745
+ backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
746
+ )
747
+ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
748
+ """Generate the next message based on previous messages"""
749
+ response = self.llm_client.chat.completions.create(
750
+ model=self.model,
751
+ messages=messages,
752
+ max_tokens=max_new_tokens if max_new_tokens else 4096,
753
+ temperature=temperature,
754
+ **kwargs,
755
+ )
756
+
757
+ content = response.choices[0].message.content
758
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
759
+ return content, total_tokens, cost
760
+
761
+
762
+ class LMMEngineAWSBedrock(LMMEngine):
763
+ def __init__(
764
+ self,
765
+ aws_access_key=None,
766
+ aws_secret_key=None,
767
+ aws_region=None,
768
+ model=None,
769
+ rate_limit=-1,
770
+ **kwargs
771
+ ):
772
+ assert model is not None, "model must be provided"
773
+ self.model = model
774
+ self.provider = "llm-bedrock"
775
+
776
+ # Claude model mapping for AWS Bedrock
777
+ self.claude_model_map = {
778
+ "claude-opus-4": "anthropic.claude-opus-4-20250514-v1:0",
779
+ "claude-sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
780
+ "claude-3-7-sonnet": "anthropic.claude-3-7-sonnet-20250219-v1:0",
781
+ "claude-3-5-sonnet": "anthropic.claude-3-5-sonnet-20241022-v2:0",
782
+ "claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
783
+ "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
784
+ "claude-3-5-haiku": "anthropic.claude-3-5-haiku-20241022-v1:0",
785
+ "claude-3-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
786
+ "claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
787
+ "claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0",
788
+ }
789
+
790
+ # Get the actual Bedrock model ID
791
+ self.bedrock_model_id = self.claude_model_map.get(model, model)
792
+
793
+ # AWS credentials
794
+ aws_access_key = aws_access_key or os.getenv("AWS_ACCESS_KEY_ID")
795
+ aws_secret_key = aws_secret_key or os.getenv("AWS_SECRET_ACCESS_KEY")
796
+ aws_region = aws_region or os.getenv("AWS_DEFAULT_REGION") or "us-west-2"
797
+
798
+ if aws_access_key is None:
799
+ raise ValueError(
800
+ "AWS Access Key needs to be provided in either the aws_access_key parameter or as an environment variable named AWS_ACCESS_KEY_ID"
801
+ )
802
+ if aws_secret_key is None:
803
+ raise ValueError(
804
+ "AWS Secret Key needs to be provided in either the aws_secret_key parameter or as an environment variable named AWS_SECRET_ACCESS_KEY"
805
+ )
806
+
807
+ self.aws_region = aws_region
808
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
809
+
810
+ # Initialize Bedrock client
811
+ self.bedrock_client = boto3.client(
812
+ service_name="bedrock-runtime",
813
+ region_name=aws_region,
814
+ aws_access_key_id=aws_access_key,
815
+ aws_secret_access_key=aws_secret_key
816
+ )
817
+
818
+ @backoff.on_exception(
819
+ backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
820
+ )
821
+ def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
822
+ """Generate the next message based on previous messages"""
823
+
824
+ # Convert messages to Bedrock format
825
+ # Extract system message if present
826
+ system_message = None
827
+ user_messages = []
828
+
829
+ for message in messages:
830
+ if message["role"] == "system":
831
+ if isinstance(message["content"], list):
832
+ system_message = message["content"][0]["text"]
833
+ else:
834
+ system_message = message["content"]
835
+ else:
836
+ # Handle both list and string content formats
837
+ if isinstance(message["content"], list):
838
+ content = message["content"][0]["text"] if message["content"] else ""
839
+ else:
840
+ content = message["content"]
841
+
842
+ user_messages.append({
843
+ "role": message["role"],
844
+ "content": content
845
+ })
846
+
847
+ # Prepare the body for Bedrock
848
+ body = {
849
+ "max_tokens": max_new_tokens if max_new_tokens else 4096,
850
+ "messages": user_messages,
851
+ "anthropic_version": "bedrock-2023-05-31"
852
+ }
853
+
854
+ if temperature > 0:
855
+ body["temperature"] = temperature
856
+
857
+ if system_message:
858
+ body["system"] = system_message
859
+
860
+ try:
861
+ response = self.bedrock_client.invoke_model(
862
+ body=json.dumps(body),
863
+ modelId=self.bedrock_model_id
864
+ )
865
+
866
+ response_body = json.loads(response.get("body").read())
867
+
868
+ if "content" in response_body and response_body["content"]:
869
+ content = response_body["content"][0]["text"]
870
+ else:
871
+ raise ValueError("No content in response")
872
+
873
+ total_tokens, cost = calculate_tokens_and_cost(response_body, self.provider, self.model)
874
+ return content, total_tokens, cost
875
+
876
+ except Exception as e:
877
+ print(f"AWS Bedrock error: {e}")
878
+ raise
879
+
880
+ # ==================== Embedding ====================
881
+
882
+ class OpenAIEmbeddingEngine(LMMEngine):
883
+ def __init__(
884
+ self,
885
+ embedding_model: str = "text-embedding-3-small",
886
+ api_key=None,
887
+ **kwargs
888
+ ):
889
+ """Init an OpenAI Embedding engine
890
+
891
+ Args:
892
+ embedding_model (str, optional): Model name. Defaults to "text-embedding-3-small".
893
+ api_key (_type_, optional): Auth key from OpenAI. Defaults to None.
894
+ """
895
+ self.model = embedding_model
896
+ self.provider = "embedding-openai"
897
+
898
+ api_key = api_key or os.getenv("OPENAI_API_KEY")
899
+ if api_key is None:
900
+ raise ValueError(
901
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENAI_API_KEY"
902
+ )
903
+ self.api_key = api_key
904
+
905
+ @backoff.on_exception(
906
+ backoff.expo,
907
+ (
908
+ APIError,
909
+ RateLimitError,
910
+ APIConnectionError,
911
+ ),
912
+ )
913
+ def get_embeddings(self, text: str) -> Tuple[np.ndarray, List[int], float]:
914
+ client = OpenAI(api_key=self.api_key)
915
+ response = client.embeddings.create(model=self.model, input=text)
916
+
917
+ embeddings = np.array([data.embedding for data in response.data])
918
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
919
+
920
+ return embeddings, total_tokens, cost
921
+
922
+
923
+
924
+ class GeminiEmbeddingEngine(LMMEngine):
925
+ def __init__(
926
+ self,
927
+ embedding_model: str = "text-embedding-004",
928
+ api_key=None,
929
+ **kwargs
930
+ ):
931
+ """Init an Gemini Embedding engine
932
+
933
+ Args:
934
+ embedding_model (str, optional): Model name. Defaults to "text-embedding-004".
935
+ api_key (_type_, optional): Auth key from Gemini. Defaults to None.
936
+ """
937
+ self.model = embedding_model
938
+ self.provider = "embedding-gemini"
939
+
940
+ api_key = api_key or os.getenv("GEMINI_API_KEY")
941
+ if api_key is None:
942
+ raise ValueError(
943
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named GEMINI_API_KEY"
944
+ )
945
+ self.api_key = api_key
946
+
947
+ @backoff.on_exception(
948
+ backoff.expo,
949
+ (
950
+ APIError,
951
+ RateLimitError,
952
+ APIConnectionError,
953
+ ),
954
+ )
955
+ def get_embeddings(self, text: str) -> Tuple[np.ndarray, List[int], float]:
956
+ client = genai.Client(api_key=self.api_key)
957
+
958
+ result = client.models.embed_content(
959
+ model=self.model,
960
+ contents=text,
961
+ config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
962
+ )
963
+
964
+ embeddings = np.array([i.values for i in result.embeddings]) # type: ignore
965
+ total_tokens, cost = calculate_tokens_and_cost(result, self.provider, self.model)
966
+
967
+ return embeddings, total_tokens, cost
968
+
969
+
970
+
971
+ class AzureOpenAIEmbeddingEngine(LMMEngine):
972
+ def __init__(
973
+ self,
974
+ embedding_model: str = "text-embedding-3-small",
975
+ api_key=None,
976
+ api_version=None,
977
+ endpoint_url=None,
978
+ **kwargs
979
+ ):
980
+ """Init an Azure OpenAI Embedding engine
981
+
982
+ Args:
983
+ embedding_model (str, optional): Model name. Defaults to "text-embedding-3-small".
984
+ api_key (_type_, optional): Auth key from Azure OpenAI. Defaults to None.
985
+ api_version (_type_, optional): API version. Defaults to None.
986
+ endpoint_url (_type_, optional): Endpoint URL. Defaults to None.
987
+ """
988
+ self.model = embedding_model
989
+ self.provider = "embedding-azureopenai"
990
+
991
+ api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
992
+ if api_key is None:
993
+ raise ValueError(
994
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named AZURE_OPENAI_API_KEY"
995
+ )
996
+ self.api_key = api_key
997
+
998
+ api_version = api_version or os.getenv("OPENAI_API_VERSION")
999
+ if api_version is None:
1000
+ raise ValueError(
1001
+ "An API Version needs to be provided in either the api_version parameter or as an environment variable named OPENAI_API_VERSION"
1002
+ )
1003
+ self.api_version = api_version
1004
+
1005
+ endpoint_url = endpoint_url or os.getenv("AZURE_OPENAI_ENDPOINT")
1006
+ if endpoint_url is None:
1007
+ raise ValueError(
1008
+ "An Endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named AZURE_OPENAI_ENDPOINT"
1009
+ )
1010
+ self.endpoint_url = endpoint_url
1011
+
1012
+ @backoff.on_exception(
1013
+ backoff.expo,
1014
+ (
1015
+ APIError,
1016
+ RateLimitError,
1017
+ APIConnectionError,
1018
+ ),
1019
+ )
1020
+ def get_embeddings(self, text: str) -> Tuple[np.ndarray, List[int], float]:
1021
+ client = AzureOpenAI(
1022
+ api_key=self.api_key,
1023
+ api_version=self.api_version,
1024
+ azure_endpoint=self.endpoint_url,
1025
+ )
1026
+ response = client.embeddings.create(input=text, model=self.model)
1027
+
1028
+ embeddings = np.array([data.embedding for data in response.data])
1029
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
1030
+
1031
+ return embeddings, total_tokens, cost
1032
+
1033
+
1034
+ class DashScopeEmbeddingEngine(LMMEngine):
1035
+ def __init__(
1036
+ self,
1037
+ embedding_model: str = "text-embedding-v4",
1038
+ api_key=None,
1039
+ dimensions: int = 1024,
1040
+ **kwargs
1041
+ ):
1042
+ """Init a DashScope (阿里云百炼) Embedding engine
1043
+
1044
+ Args:
1045
+ embedding_model (str, optional): Model name. Defaults to "text-embedding-v4".
1046
+ api_key (_type_, optional): Auth key from DashScope. Defaults to None.
1047
+ dimensions (int, optional): Embedding dimensions. Defaults to 1024.
1048
+ """
1049
+ self.model = embedding_model
1050
+ self.dimensions = dimensions
1051
+ self.provider = "embedding-qwen"
1052
+
1053
+ api_key = api_key or os.getenv("DASHSCOPE_API_KEY")
1054
+ if api_key is None:
1055
+ raise ValueError(
1056
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named DASHSCOPE_API_KEY"
1057
+ )
1058
+ self.api_key = api_key
1059
+
1060
+ # Initialize OpenAI client with DashScope base URL
1061
+ self.client = OpenAI(
1062
+ api_key=self.api_key,
1063
+ base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
1064
+ )
1065
+
1066
+ @backoff.on_exception(
1067
+ backoff.expo,
1068
+ (
1069
+ APIError,
1070
+ RateLimitError,
1071
+ APIConnectionError,
1072
+ ),
1073
+ )
1074
+ def get_embeddings(self, text: str) -> Tuple[np.ndarray, List[int], float]:
1075
+ response = self.client.embeddings.create(
1076
+ model=self.model,
1077
+ input=text,
1078
+ dimensions=self.dimensions,
1079
+ encoding_format="float"
1080
+ )
1081
+
1082
+ embeddings = np.array([data.embedding for data in response.data])
1083
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
1084
+
1085
+ return embeddings, total_tokens, cost
1086
+
1087
+
1088
+
1089
+ class DoubaoEmbeddingEngine(LMMEngine):
1090
+ def __init__(
1091
+ self,
1092
+ embedding_model: str = "doubao-embedding-256",
1093
+ api_key=None,
1094
+ **kwargs
1095
+ ):
1096
+ """Init a Doubao (字节跳动豆包) Embedding engine
1097
+
1098
+ Args:
1099
+ embedding_model (str, optional): Model name. Defaults to "doubao-embedding-256".
1100
+ api_key (_type_, optional): Auth key from Doubao. Defaults to None.
1101
+ """
1102
+ self.model = embedding_model
1103
+ self.provider = "embedding-doubao"
1104
+
1105
+ api_key = api_key or os.getenv("ARK_API_KEY")
1106
+ if api_key is None:
1107
+ raise ValueError(
1108
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named ARK_API_KEY"
1109
+ )
1110
+ self.api_key = api_key
1111
+ self.base_url = "https://ark.cn-beijing.volces.com/api/v3"
1112
+
1113
+ # Use OpenAI-compatible client for text embeddings
1114
+ self.client = OpenAI(
1115
+ api_key=self.api_key,
1116
+ base_url=self.base_url
1117
+ )
1118
+
1119
+ @backoff.on_exception(
1120
+ backoff.expo,
1121
+ (
1122
+ APIError,
1123
+ RateLimitError,
1124
+ APIConnectionError,
1125
+ ),
1126
+ )
1127
+ def get_embeddings(self, text: str) -> Tuple[np.ndarray, List[int], float]:
1128
+ response = self.client.embeddings.create(
1129
+ model=self.model,
1130
+ input=text,
1131
+ encoding_format="float"
1132
+ )
1133
+
1134
+ embeddings = np.array([data.embedding for data in response.data])
1135
+ total_tokens, cost = calculate_tokens_and_cost(response, self.provider, self.model)
1136
+
1137
+ return embeddings, total_tokens, cost
1138
+
1139
+
1140
+ class JinaEmbeddingEngine(LMMEngine):
1141
+ def __init__(
1142
+ self,
1143
+ embedding_model: str = "jina-embeddings-v4",
1144
+ api_key=None,
1145
+ task: str = "retrieval.query",
1146
+ **kwargs
1147
+ ):
1148
+ """Init a Jina AI Embedding engine
1149
+
1150
+ Args:
1151
+ embedding_model (str, optional): Model name. Defaults to "jina-embeddings-v4".
1152
+ api_key (_type_, optional): Auth key from Jina AI. Defaults to None.
1153
+ task (str, optional): Task type. Options: "retrieval.query", "retrieval.passage", "text-matching". Defaults to "retrieval.query".
1154
+ """
1155
+ self.model = embedding_model
1156
+ self.task = task
1157
+ self.provider = "embedding-jina"
1158
+
1159
+ api_key = api_key or os.getenv("JINA_API_KEY")
1160
+ if api_key is None:
1161
+ raise ValueError(
1162
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named JINA_API_KEY"
1163
+ )
1164
+ self.api_key = api_key
1165
+ self.base_url = "https://api.jina.ai/v1"
1166
+
1167
+ @backoff.on_exception(
1168
+ backoff.expo,
1169
+ (
1170
+ APIError,
1171
+ RateLimitError,
1172
+ APIConnectionError,
1173
+ ),
1174
+ )
1175
+ def get_embeddings(self, text: str) -> Tuple[np.ndarray, List[int], float]:
+
+ headers = {
1179
+ "Content-Type": "application/json",
1180
+ "Authorization": f"Bearer {self.api_key}"
1181
+ }
1182
+
1183
+ data = {
1184
+ "model": self.model,
1185
+ "task": self.task,
1186
+ "input": [
1187
+ {
1188
+ "text": text
1189
+ }
1190
+ ]
1191
+ }
1192
+
1193
+ response = requests.post(
1194
+ f"{self.base_url}/embeddings",
1195
+ headers=headers,
1196
+ json=data
1197
+ )
1198
+
1199
+ if response.status_code != 200:
1200
+ raise Exception(f"Jina AI API error: {response.text}")
1201
+
1202
+ result = response.json()
1203
+ embeddings = np.array([data["embedding"] for data in result["data"]])
1204
+
1205
+ total_tokens, cost = calculate_tokens_and_cost(result, self.provider, self.model)
1206
+
1207
+ return embeddings, total_tokens, cost
1208
+
1209
+
1210
+ # ==================== webSearch ====================
1211
+ class SearchEngine:
1212
+ """Base class for search engines"""
1213
+ pass
1214
+
1215
+ class BochaAISearchEngine(SearchEngine):
1216
+ def __init__(
1217
+ self,
1218
+ api_key: str|None = None,
1219
+ base_url: str = "https://api.bochaai.com/v1",
1220
+ rate_limit: int = -1,
1221
+ **kwargs
1222
+ ):
1223
+ """Init a Bocha AI Search engine
1224
+
1225
+ Args:
1226
+ api_key (str, optional): Auth key from Bocha AI. Defaults to None.
1227
+ base_url (str, optional): Base URL for the API. Defaults to "https://api.bochaai.com/v1".
1228
+ rate_limit (int, optional): Rate limit per minute. Defaults to -1 (no limit).
1229
+ """
1230
+ api_key = api_key or os.getenv("BOCHA_API_KEY")
1231
+ if api_key is None:
1232
+ raise ValueError(
1233
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named BOCHA_API_KEY"
1234
+ )
1235
+
1236
+ self.api_key = api_key
1237
+ self.base_url = base_url
1238
+ self.endpoint = f"{base_url}/ai-search"
1239
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
1240
+
1241
+ @backoff.on_exception(
1242
+ backoff.expo,
1243
+ (
1244
+ APIConnectionError,
1245
+ APIError,
1246
+ RateLimitError,
1247
+ requests.exceptions.RequestException,
1248
+ ),
1249
+ max_time=60
1250
+ )
1251
+ def search(
1252
+ self,
1253
+ query: str,
1254
+ freshness: str = "noLimit",
1255
+ answer: bool = True,
1256
+ stream: bool = False,
1257
+ **kwargs
1258
+ ) -> Tuple[Union[Dict[str, Any], Any], List[int], float]:
1259
+ """Search with AI and return intelligent answer
1260
+
1261
+ Args:
1262
+ query (str): Search query
1263
+ freshness (str, optional): Freshness filter. Defaults to "noLimit".
1264
+ answer (bool, optional): Whether to return answer. Defaults to True.
1265
+ stream (bool, optional): Whether to stream response. Defaults to False.
1266
+
1267
+ Returns:
1268
+ Tuple[Union[Dict[str, Any], Any], List[int], float]: AI search results with sources and answer, a token-count placeholder, and the per-call cost
1269
+ """
1270
+ headers = {
1271
+ 'Authorization': f'Bearer {self.api_key}',
1272
+ 'Content-Type': 'application/json'
1273
+ }
1274
+
1275
+ payload = {
1276
+ "query": query,
1277
+ "freshness": freshness,
1278
+ "answer": answer,
1279
+ "stream": stream,
1280
+ **kwargs
1281
+ }
1282
+
1283
+ if stream:
1284
+ result = self._stream_search(headers, payload)
1285
+ return result, [0, 0, 0], 0.06
1286
+ else:
1287
+ result = self._regular_search(headers, payload)
1288
+ return result, [0, 0, 0], 0.06
1289
+
1290
+
1291
+ def _regular_search(self, headers: Dict[str, str], payload: Dict[str, Any]) -> Dict[str, Any]:
1292
+ """Regular non-streaming search"""
1293
+ response = requests.post(
1294
+ self.endpoint,
1295
+ headers=headers,
1296
+ json=payload
1297
+ )
1298
+
1299
+ if response.status_code != 200:
1300
+ raise APIError(f"Bocha AI Search API error: {response.text}") # type: ignore
1301
+
1302
+ return response.json()
1303
+
1304
+ def _stream_search(self, headers: Dict[str, str], payload: Dict[str, Any]):
1305
+ """Streaming search response"""
1306
+ response = requests.post(
1307
+ self.endpoint,
1308
+ headers=headers,
1309
+ json=payload,
1310
+ stream=True
1311
+ )
1312
+
1313
+ if response.status_code != 200:
1314
+ raise APIError(f"Bocha AI Search API error: {response.text}") # type: ignore
1315
+
1316
+ for line in response.iter_lines():
1317
+ if line:
1318
+ line = line.decode('utf-8')
1319
+ if line.startswith('data:'):
1320
+ data = line[5:].strip()
1321
+ if data and data != '{"event":"done"}':
1322
+ try:
1323
+ yield json.loads(data)
1324
+ except json.JSONDecodeError:
1325
+ continue
1326
+
1327
+     def get_answer(self, query: str, **kwargs) -> Tuple[str, List[int], float]:
+         """Get the AI-generated answer only"""
+         result, _, remaining_balance = self.search(query, answer=True, **kwargs)
+
+         # Extract answer from messages
+         messages = result.get("messages", [])  # type: ignore
+         answer = ""
+         for message in messages:
+             if message.get("type") == "answer":
+                 answer = message.get("content", "")
+                 break
+
+         return answer, [0, 0, 0], remaining_balance  # type: ignore
+
+
+     def get_sources(self, query: str, **kwargs) -> Tuple[List[Dict[str, Any]], int, float]:
+         """Get source materials only"""
+         result, _, remaining_balance = self.search(query, **kwargs)
+
+         # Extract sources from messages
+         sources = []
+         messages = result.get("messages", [])  # type: ignore
+         for message in messages:
+             if message.get("type") == "source":
+                 content_type = message.get("content_type", "")
+                 if content_type in ["webpage", "image", "video", "baike_pro", "medical_common"]:
+                     sources.append({
+                         "type": content_type,
+                         "content": json.loads(message.get("content", "{}"))
+                     })
+
+         return sources, 0, remaining_balance  # type: ignore
+
+
+     def get_follow_up_questions(self, query: str, **kwargs) -> Tuple[List[str], int, float]:
+         """Get follow-up questions"""
+         result, _, remaining_balance = self.search(query, **kwargs)
+
+         # Extract follow-up questions from messages
+         follow_ups = []
+         messages = result.get("messages", [])  # type: ignore
+         for message in messages:
+             if message.get("type") == "follow_up":
+                 follow_ups.append(message.get("content", ""))
+
+         return follow_ups, 0, remaining_balance  # type: ignore
+
+
+ class ExaResearchEngine(SearchEngine):
1376
+ def __init__(
1377
+ self,
1378
+ api_key: str|None = None,
1379
+ base_url: str = "https://api.exa.ai",
1380
+ rate_limit: int = -1,
1381
+ **kwargs
1382
+ ):
1383
+ """Init an Exa Research engine
1384
+
1385
+ Args:
1386
+ api_key (str, optional): Auth key from Exa AI. Defaults to None.
1387
+ base_url (str, optional): Base URL for the API. Defaults to "https://api.exa.ai".
1388
+ rate_limit (int, optional): Rate limit per minute. Defaults to -1 (no limit).
1389
+ """
1390
+ api_key = api_key or os.getenv("EXA_API_KEY")
1391
+ if api_key is None:
1392
+ raise ValueError(
1393
+ "An API Key needs to be provided in either the api_key parameter or as an environment variable named EXA_API_KEY"
1394
+ )
1395
+
1396
+ self.api_key = api_key
1397
+ self.base_url = base_url
1398
+ self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
1399
+
1400
+ # Initialize OpenAI-compatible client for chat completions
1401
+ self.chat_client = OpenAI(
1402
+ base_url=base_url,
1403
+ api_key=api_key
1404
+ )
1405
+
1406
+ # Initialize Exa client for research tasks
1407
+ try:
1408
+ from exa_py import Exa
1409
+ self.exa_client = Exa(api_key=api_key)
1410
+ except ImportError:
1411
+ self.exa_client = None
1412
+ print("Warning: exa_py not installed. Research tasks will not be available.")
1413
+
1414
+ @backoff.on_exception(
1415
+ backoff.expo,
1416
+ (
1417
+ APIConnectionError,
1418
+ APIError,
1419
+ RateLimitError,
1420
+ ),
1421
+ max_time=60
1422
+ )
1423
+ def search(self, query: str, **kwargs):
1424
+ """Standard Exa search with direct cost from API
1425
+
1426
+ Args:
1427
+ query (str): Search query
1428
+ **kwargs: Additional search parameters
1429
+
1430
+ Returns:
1431
+ tuple: (result, tokens, cost) where cost is actual API cost
1432
+ """
1433
+ headers = {
1434
+ 'x-api-key': self.api_key,
1435
+ 'Content-Type': 'application/json'
1436
+ }
1437
+
1438
+ payload = {
1439
+ "query": query,
1440
+ **kwargs
1441
+ }
1442
+
1443
+ response = requests.post(
1444
+ f"{self.base_url}/search",
1445
+ headers=headers,
1446
+ json=payload
1447
+ )
1448
+
1449
+ if response.status_code != 200:
1450
+ raise APIError(f"Exa Search API error: {response.text}") # type: ignore
1451
+
1452
+ result = response.json()
1453
+
1454
+ cost = 0.0
1455
+ if "costDollars" in result:
1456
+ cost = result["costDollars"].get("total", 0.0)
1457
+
1458
+ return result, [0, 0, 0], cost
1459
+
1460
+ def chat_research(
1461
+ self,
1462
+ query: str,
1463
+ model: str = "exa",
1464
+ stream: bool = False,
1465
+ **kwargs
1466
+ ) -> Union[str, Any]:
1467
+ """Research using chat completions interface
1468
+
1469
+ Args:
1470
+ query (str): Research query
1471
+ model (str, optional): Model name. Defaults to "exa".
1472
+ stream (bool, optional): Whether to stream response. Defaults to False.
1473
+
1474
+ Returns:
1475
+ Union[str, Any]: Research result or stream
1476
+ """
1477
+ messages = [
1478
+ {"role": "user", "content": query}
1479
+ ]
1480
+
1481
+ if stream:
1482
+ completion = self.chat_client.chat.completions.create(
1483
+ model=model,
1484
+ messages=messages, # type: ignore
1485
+ stream=True,
1486
+ **kwargs
1487
+ )
1488
+ return completion
1489
+ else:
1490
+ completion = self.chat_client.chat.completions.create(
1491
+ model=model,
1492
+ messages=messages, # type: ignore
1493
+ **kwargs
1494
+ )
1495
+ result = completion.choices[0].message.content # type: ignore
1496
+ return result, [0, 0, 0], 0.005
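
For orientation, a minimal usage sketch of the engine classes this diff adds in gui_agents/core/engine.py, based only on the constructor and generate() signatures visible above; the model name, key placeholder, and prompt are illustrative assumptions, not values shipped with the package:

# Illustrative sketch only; assumes the signatures shown in the diff above.
import os
from gui_agents.core.engine import LMMEngineOpenAI

os.environ.setdefault("OPENAI_API_KEY", "<your-key>")  # placeholder credential

engine = LMMEngineOpenAI(model="gpt-4o")  # hypothetical model choice
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a GUI agent."}]},
    {"role": "user", "content": [{"type": "text", "text": "Summarize the current screen."}]},
]

# generate() returns the text, an [input, output, total] token list, and an estimated cost.
content, token_counts, cost = engine.generate(messages, temperature=0.0, max_new_tokens=512)
print(content, token_counts, cost)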