lybic-guiagents 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lybic-guiagents might be problematic. Click here for more details.
- desktop_env/__init__.py +1 -0
- desktop_env/actions.py +203 -0
- desktop_env/controllers/__init__.py +0 -0
- desktop_env/controllers/python.py +471 -0
- desktop_env/controllers/setup.py +882 -0
- desktop_env/desktop_env.py +509 -0
- desktop_env/evaluators/__init__.py +5 -0
- desktop_env/evaluators/getters/__init__.py +41 -0
- desktop_env/evaluators/getters/calc.py +15 -0
- desktop_env/evaluators/getters/chrome.py +1774 -0
- desktop_env/evaluators/getters/file.py +154 -0
- desktop_env/evaluators/getters/general.py +42 -0
- desktop_env/evaluators/getters/gimp.py +38 -0
- desktop_env/evaluators/getters/impress.py +126 -0
- desktop_env/evaluators/getters/info.py +24 -0
- desktop_env/evaluators/getters/misc.py +406 -0
- desktop_env/evaluators/getters/replay.py +20 -0
- desktop_env/evaluators/getters/vlc.py +86 -0
- desktop_env/evaluators/getters/vscode.py +35 -0
- desktop_env/evaluators/metrics/__init__.py +160 -0
- desktop_env/evaluators/metrics/basic_os.py +68 -0
- desktop_env/evaluators/metrics/chrome.py +493 -0
- desktop_env/evaluators/metrics/docs.py +1011 -0
- desktop_env/evaluators/metrics/general.py +665 -0
- desktop_env/evaluators/metrics/gimp.py +637 -0
- desktop_env/evaluators/metrics/libreoffice.py +28 -0
- desktop_env/evaluators/metrics/others.py +92 -0
- desktop_env/evaluators/metrics/pdf.py +31 -0
- desktop_env/evaluators/metrics/slides.py +957 -0
- desktop_env/evaluators/metrics/table.py +585 -0
- desktop_env/evaluators/metrics/thunderbird.py +176 -0
- desktop_env/evaluators/metrics/utils.py +719 -0
- desktop_env/evaluators/metrics/vlc.py +524 -0
- desktop_env/evaluators/metrics/vscode.py +283 -0
- desktop_env/providers/__init__.py +35 -0
- desktop_env/providers/aws/__init__.py +0 -0
- desktop_env/providers/aws/manager.py +278 -0
- desktop_env/providers/aws/provider.py +186 -0
- desktop_env/providers/aws/provider_with_proxy.py +315 -0
- desktop_env/providers/aws/proxy_pool.py +193 -0
- desktop_env/providers/azure/__init__.py +0 -0
- desktop_env/providers/azure/manager.py +87 -0
- desktop_env/providers/azure/provider.py +207 -0
- desktop_env/providers/base.py +97 -0
- desktop_env/providers/gcp/__init__.py +0 -0
- desktop_env/providers/gcp/manager.py +0 -0
- desktop_env/providers/gcp/provider.py +0 -0
- desktop_env/providers/virtualbox/__init__.py +0 -0
- desktop_env/providers/virtualbox/manager.py +463 -0
- desktop_env/providers/virtualbox/provider.py +124 -0
- desktop_env/providers/vmware/__init__.py +0 -0
- desktop_env/providers/vmware/manager.py +455 -0
- desktop_env/providers/vmware/provider.py +105 -0
- gui_agents/__init__.py +0 -0
- gui_agents/agents/Action.py +209 -0
- gui_agents/agents/__init__.py +0 -0
- gui_agents/agents/agent_s.py +832 -0
- gui_agents/agents/global_state.py +610 -0
- gui_agents/agents/grounding.py +651 -0
- gui_agents/agents/hardware_interface.py +129 -0
- gui_agents/agents/manager.py +568 -0
- gui_agents/agents/translator.py +132 -0
- gui_agents/agents/worker.py +355 -0
- gui_agents/cli_app.py +560 -0
- gui_agents/core/__init__.py +0 -0
- gui_agents/core/engine.py +1496 -0
- gui_agents/core/knowledge.py +449 -0
- gui_agents/core/mllm.py +555 -0
- gui_agents/tools/__init__.py +0 -0
- gui_agents/tools/tools.py +727 -0
- gui_agents/unit_test/__init__.py +0 -0
- gui_agents/unit_test/run_tests.py +65 -0
- gui_agents/unit_test/test_manager.py +330 -0
- gui_agents/unit_test/test_worker.py +269 -0
- gui_agents/utils/__init__.py +0 -0
- gui_agents/utils/analyze_display.py +301 -0
- gui_agents/utils/common_utils.py +263 -0
- gui_agents/utils/display_viewer.py +281 -0
- gui_agents/utils/embedding_manager.py +53 -0
- gui_agents/utils/image_axis_utils.py +27 -0
- lybic_guiagents-0.1.0.dist-info/METADATA +416 -0
- lybic_guiagents-0.1.0.dist-info/RECORD +85 -0
- lybic_guiagents-0.1.0.dist-info/WHEEL +5 -0
- lybic_guiagents-0.1.0.dist-info/licenses/LICENSE +201 -0
- lybic_guiagents-0.1.0.dist-info/top_level.txt +2 -0
gui_agents/core/mllm.py
ADDED
|
@@ -0,0 +1,555 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from gui_agents.core.engine import (
|
|
6
|
+
LMMEngineAnthropic,
|
|
7
|
+
LMMEngineAzureOpenAI,
|
|
8
|
+
LMMEngineHuggingFace,
|
|
9
|
+
LMMEngineOpenAI,
|
|
10
|
+
LMMEngineOpenRouter,
|
|
11
|
+
LMMEnginevLLM,
|
|
12
|
+
LMMEngineGemini,
|
|
13
|
+
LMMEngineQwen,
|
|
14
|
+
LMMEngineDoubao,
|
|
15
|
+
LMMEngineDeepSeek,
|
|
16
|
+
LMMEngineZhipu,
|
|
17
|
+
LMMEngineGroq,
|
|
18
|
+
LMMEngineSiliconflow,
|
|
19
|
+
LMMEngineMonica,
|
|
20
|
+
LMMEngineAWSBedrock,
|
|
21
|
+
OpenAIEmbeddingEngine,
|
|
22
|
+
GeminiEmbeddingEngine,
|
|
23
|
+
AzureOpenAIEmbeddingEngine,
|
|
24
|
+
DashScopeEmbeddingEngine,
|
|
25
|
+
DoubaoEmbeddingEngine,
|
|
26
|
+
JinaEmbeddingEngine,
|
|
27
|
+
BochaAISearchEngine,
|
|
28
|
+
ExaResearchEngine,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
class CostManager:
    """Format and accumulate API costs, tagging each with a currency symbol.

    Engines listed in ``CNY_ENGINES`` are reported in CNY ("¥"); engines in
    ``USD_ENGINES`` and any unlisted engine are reported in USD ("$").
    Costs are formatted with a fixed number of decimals (``_DECIMALS``) so that
    individual and accumulated costs share the same precision.
    """

    # Chinese engines use CNY
    CNY_ENGINES = {
        LMMEngineQwen, LMMEngineDoubao, LMMEngineDeepSeek, LMMEngineZhipu,
        LMMEngineSiliconflow, DashScopeEmbeddingEngine, DoubaoEmbeddingEngine
    }
    # Other engines use USD
    USD_ENGINES = {
        LMMEngineOpenAI, LMMEngineAnthropic, LMMEngineAzureOpenAI, LMMEngineGemini,
        LMMEngineOpenRouter, LMMEnginevLLM, LMMEngineHuggingFace, LMMEngineGroq,
        LMMEngineMonica, LMMEngineAWSBedrock, OpenAIEmbeddingEngine,
        GeminiEmbeddingEngine, AzureOpenAIEmbeddingEngine, JinaEmbeddingEngine
    }

    # Symbols recognised when parsing cost strings. Both the half-width "¥"
    # (U+00A5) and the full-width "￥" (U+FFE5) yuan signs are accepted; the
    # full-width form is normalised to "¥" after parsing.
    _CURRENCY_SYMBOLS = ("$", "¥", "￥", "€", "£")

    # Decimals used for every formatted cost (single or accumulated).
    _DECIMALS = 7

    @classmethod
    def get_currency_symbol(cls, engine) -> str:
        """Return the currency symbol for ``engine``.

        Exact engine types listed in ``CNY_ENGINES`` / ``USD_ENGINES`` map to
        "¥" / "$". For unlisted types, a subclass of a CNY engine inherits "¥";
        everything else defaults to "$".
        """
        engine_type = type(engine)
        if engine_type in cls.CNY_ENGINES:
            return "¥"
        if engine_type in cls.USD_ENGINES:
            return "$"
        # Unlisted engine: fall back to inheritance before defaulting to USD.
        if isinstance(engine, tuple(cls.CNY_ENGINES)):
            return "¥"
        return "$"

    @classmethod
    def format_cost(cls, cost: float, engine) -> str:
        """Format a numeric ``cost`` followed by ``engine``'s currency symbol."""
        currency = cls.get_currency_symbol(engine)
        return f"{cost:.{cls._DECIMALS}f}{currency}"

    @classmethod
    def _parse_cost(cls, cost):
        """Split ``cost`` into ``(value, currency)``.

        ``int``/``float`` values and strings without a recognised symbol carry
        no currency (``None``). A string whose numeric part cannot be parsed
        counts as 0.0, consistently for strings with and without a symbol.
        """
        if isinstance(cost, (int, float)):
            return float(cost), None

        text = str(cost).strip()
        currency = next((s for s in cls._CURRENCY_SYMBOLS if s in text), None)
        if currency is not None:
            text = text.replace(currency, "").strip()
            if currency == "￥":
                currency = "¥"
        try:
            return float(text), currency
        except ValueError:
            return 0.0, currency

    @classmethod
    def add_costs(cls, cost1: str, cost2: str) -> str:
        """Add two costs and return the total as a currency-tagged string.

        Each cost may be a formatted string such as ``"0.0012000$"`` or a
        plain number. A plain number has no currency of its own and adopts the
        other operand's; if neither has one, "$" is used. When both carry
        different currencies a warning is printed and the first one is kept.
        """
        value1, currency1 = cls._parse_cost(cost1)
        value2, currency2 = cls._parse_cost(cost2)

        if currency1 and currency2 and currency1 != currency2:
            print(f"Warning: Different currencies in cost accumulation: {currency1} and {currency2}")

        currency = currency1 or currency2 or "$"
        total_value = value1 + value2
        return f"{total_value:.{cls._DECIMALS}f}{currency}"
|
|
111
|
+
|
|
112
|
+
class LLMAgent:
    """Chat-style wrapper around a multimodal LLM engine.

    The agent keeps an OpenAI-style history in ``self.messages``: a list of
    ``{"role": ..., "content": [parts...]}`` dicts whose first entry is the
    system message. ``get_response`` forwards that history to
    ``self.engine.generate``.
    """

    def __init__(self, engine_params=None, system_prompt=None, engine=None):
        """Create the agent.

        Args:
            engine_params: dict whose ``"engine_type"`` entry selects the engine
                class; the whole dict is passed to that class as keyword
                arguments. Ignored when ``engine`` is given.
            system_prompt: initial system prompt; "You are a helpful assistant."
                is used when it is None or empty.
            engine: an already-constructed engine to use as-is.

        Raises:
            ValueError: if neither ``engine`` nor ``engine_params`` is given, or
                ``engine_params["engine_type"]`` is not supported.
        """
        if engine is not None:
            self.engine = engine
        elif engine_params is not None:
            self.engine = self._create_engine(engine_params)
        else:
            raise ValueError("engine_params must be provided")

        self.messages = []  # Filled with the system message just below.
        self.add_system_prompt(system_prompt or "You are a helpful assistant.")

    @staticmethod
    def _create_engine(engine_params):
        """Instantiate the engine class selected by ``engine_params["engine_type"]``."""
        engine_classes = {
            "openai": LMMEngineOpenAI,
            "anthropic": LMMEngineAnthropic,
            "azure": LMMEngineAzureOpenAI,
            "vllm": LMMEnginevLLM,
            "huggingface": LMMEngineHuggingFace,
            "gemini": LMMEngineGemini,
            "open_router": LMMEngineOpenRouter,
            "dashscope": LMMEngineQwen,
            "doubao": LMMEngineDoubao,
            "deepseek": LMMEngineDeepSeek,
            "zhipu": LMMEngineZhipu,
            "groq": LMMEngineGroq,
            "siliconflow": LMMEngineSiliconflow,
            "monica": LMMEngineMonica,
            "aws_bedrock": LMMEngineAWSBedrock,
        }
        engine_class = engine_classes.get(engine_params.get("engine_type"))
        if engine_class is None:
            raise ValueError("engine_type is not supported")
        return engine_class(**engine_params)

    def encode_image(self, image_content):
        """Return ``image_content`` base64-encoded as a UTF-8 string.

        A ``str`` or ``os.PathLike`` is treated as a path to an image file and
        its bytes are read. Anything else must be a bytes-like object and is
        encoded as-is.
        NOTE(review): a numpy array is encoded from its raw buffer, not as PNG,
        even though callers label the result ``image/png``. Confirm that callers
        only pass encoded image bytes or paths.
        """
        import os

        if isinstance(image_content, (str, os.PathLike)):
            with open(image_content, "rb") as image_file:
                image_content = image_file.read()
        return base64.b64encode(image_content).decode("utf-8")

    def _system_message(self):
        """Build the system message from the current ``self.system_prompt``."""
        return {
            "role": "system",
            "content": [{"type": "text", "text": self.system_prompt}],
        }

    def reset(
        self,
    ):
        """Drop the conversation, keeping only the current system message."""
        self.messages = [self._system_message()]

    def add_system_prompt(self, system_prompt):
        """Set the system prompt.

        The prompt is written into ``messages[0]``, overwriting whatever is
        there, or appended when the history is empty.
        """
        self.system_prompt = system_prompt
        if self.messages:
            self.messages[0] = self._system_message()
        else:
            self.messages.append(self._system_message())

    def _is_valid_index(self, index):
        """Return True if ``index`` (negative allowed) addresses an existing message."""
        return -len(self.messages) <= index < len(self.messages)

    def remove_message_at(self, index):
        """Remove a message at a given index; out-of-range indices are ignored."""
        if self._is_valid_index(index):
            self.messages.pop(index)

    @staticmethod
    def _iter_images(image_content):
        """Normalise ``image_content`` into a list of individual images.

        A numpy array is always a single image; it is checked first because its
        truth value is ambiguous. Other falsy values (None, b"", []) mean no
        images. A list or tuple yields its items, and any other value is a
        single image.
        """
        if isinstance(image_content, np.ndarray):
            return [image_content]
        if not image_content:
            return []
        if isinstance(image_content, (list, tuple)):
            return list(image_content)
        return [image_content]

    def _image_url_part(self, image, image_detail):
        """Build an OpenAI-style ``image_url`` content part holding a PNG data URL."""
        base64_image = self.encode_image(image)
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/png;base64,{base64_image}",
                "detail": image_detail,
            },
        }

    def replace_message_at(
        self, index, text_content, image_content=None, image_detail="high"
    ):
        """Replace a message at a given index, keeping its original role.

        ``image_content`` accepts the same forms as ``add_message``: a single
        image, a numpy array, or a list/tuple of images. Out-of-range indices
        are ignored. Images are encoded before the message is swapped, so a
        failing encode leaves the history unchanged.
        """
        if not self._is_valid_index(index):
            return
        content = [{"type": "text", "text": text_content}]
        content.extend(
            self._image_url_part(image, image_detail)
            for image in self._iter_images(image_content)
        )
        self.messages[index] = {
            "role": self.messages[index]["role"],
            "content": content,
        }

    def _resolve_role(self, role):
        """Return ``role`` if given, otherwise alternate from the previous message.

        After "system" or "assistant" the inferred role is "user"; after "user"
        it is "assistant". An unknown previous role yields None, and an empty
        history yields "user".
        """
        if role is not None:
            return role
        if not self.messages:
            return "user"
        return {"system": "user", "user": "assistant", "assistant": "user"}.get(
            self.messages[-1]["role"]
        )

    def add_message(
        self,
        text_content,
        image_content=None,
        role=None,
        image_detail="high",
        put_text_last=False,
    ):
        """Add a new message to the list of messages.

        Args:
            text_content: text of the message.
            image_content: None, a single image (path, bytes or numpy array), or
                a list/tuple of images.
            role: explicit role. When None, it is inferred by alternating from
                the previous message's role.
            image_detail: value stored in each image part's ``"detail"`` field.
            put_text_last: place the text part after the image parts.

        Raises:
            ValueError: if ``self.engine`` is not a supported LMM engine.
        """
        # Every supported engine, including Anthropic, AWS Bedrock and vLLM,
        # receives OpenAI-style "image_url" parts.
        # NOTE(review): if an engine's API needs a different image format (e.g.
        # Anthropic "image"/"source" blocks), it presumably converts inside its
        # own generate(). Confirm in engine.py.
        supported_engines = (
            LMMEngineAnthropic,
            LMMEngineAzureOpenAI,
            LMMEngineHuggingFace,
            LMMEngineOpenAI,
            LMMEngineOpenRouter,
            LMMEnginevLLM,
            LMMEngineGemini,
            LMMEngineQwen,
            LMMEngineDoubao,
            LMMEngineDeepSeek,
            LMMEngineZhipu,
            LMMEngineGroq,
            LMMEngineSiliconflow,
            LMMEngineMonica,
            LMMEngineAWSBedrock,
        )
        if not isinstance(self.engine, supported_engines):
            raise ValueError("engine_type is not supported")

        content = [{"type": "text", "text": text_content}]
        content.extend(
            self._image_url_part(image, image_detail)
            for image in self._iter_images(image_content)
        )
        if put_text_last:
            # Rotate the text part behind the images.
            content.append(content.pop(0))

        self.messages.append({"role": self._resolve_role(role), "content": content})

    def get_response(
        self,
        user_message=None,
        messages=None,
        temperature=0.0,
        max_new_tokens=None,
        **kwargs,
    ):
        """Generate the next response based on previous messages.

        Args:
            user_message: optional text appended as a user message before the
                call. It is appended in place to ``messages``, or to
                ``self.messages`` when ``messages`` is None, so it stays in
                that history.
            messages: history to send; defaults to ``self.messages``.
            temperature: sampling temperature passed to the engine.
            max_new_tokens: generation limit passed to the engine.
            **kwargs: forwarded to ``self.engine.generate``.

        Returns:
            ``(content, total_tokens, cost_string)``, where ``cost_string`` is
            the engine-reported cost formatted by ``CostManager.format_cost``.
        """
        if messages is None:
            messages = self.messages
        if user_message:
            messages.append(
                {"role": "user", "content": [{"type": "text", "text": user_message}]}
            )

        content, total_tokens, cost = self.engine.generate(
            messages,
            temperature=temperature,
            max_new_tokens=max_new_tokens,  # type: ignore
            **kwargs,
        )

        cost_string = CostManager.format_cost(cost, self.engine)

        return content, total_tokens, cost_string
|
|
418
|
+
|
|
419
|
+
class EmbeddingAgent:
    """Wrapper over an embedding engine that also reports token usage and cost."""

    def __init__(self, engine_params=None, engine=None):
        """Create the agent from a pre-built ``engine`` or from ``engine_params``.

        Args:
            engine_params: dict whose ``"engine_type"`` entry selects the
                embedding engine class; the whole dict is passed to that class
                as keyword arguments. Ignored when ``engine`` is given.
            engine: an already-constructed embedding engine to use as-is.

        Raises:
            ValueError: if neither argument is given or the type is unsupported.
        """
        if engine is not None:
            self.engine = engine
            return
        if engine_params is None:
            raise ValueError("engine_params must be provided")

        engine_type = engine_params.get("engine_type")
        if engine_type == "openai":
            engine_class = OpenAIEmbeddingEngine
        elif engine_type == "gemini":
            engine_class = GeminiEmbeddingEngine
        elif engine_type == "azure":
            engine_class = AzureOpenAIEmbeddingEngine
        elif engine_type == "dashscope":
            engine_class = DashScopeEmbeddingEngine
        elif engine_type == "doubao":
            engine_class = DoubaoEmbeddingEngine
        elif engine_type == "jina":
            engine_class = JinaEmbeddingEngine
        else:
            raise ValueError(f"Embedding engine type '{engine_type}' is not supported")
        self.engine = engine_class(**engine_params)

    def get_embeddings(self, text):
        """Get embeddings for the given text.

        Args:
            text (str): The text to get embeddings for.

        Returns:
            tuple: ``(embeddings, total_tokens, cost_string)``. ``embeddings``
            and ``total_tokens`` are returned as produced by the engine;
            ``total_tokens`` is expected to be a 3-counter sequence, since
            ``batch_get_embeddings`` indexes it ``[0]..[2]``. ``cost_string``
            is the engine cost formatted by ``CostManager.format_cost``.
        """
        embeddings, total_tokens, cost = self.engine.get_embeddings(text)
        cost_string = CostManager.format_cost(cost, self.engine)
        return embeddings, total_tokens, cost_string

    @staticmethod
    def _add_tokens(tokens1, tokens2):
        """Add two token counts: scalars are summed, sequences element-wise."""
        if np.isscalar(tokens1) and np.isscalar(tokens2):
            return tokens1 + tokens2
        return [a + b for a, b in zip(tokens1, tokens2)]

    def get_similarity(self, text1, text2):
        """Calculate the cosine similarity between two texts.

        Args:
            text1 (str): First text.
            text2 (str): Second text.

        Returns:
            tuple: ``(similarity, total_tokens, total_cost)``. ``similarity`` is
            the cosine similarity, or 0.0 when either embedding has zero norm.
            ``total_tokens`` holds the two token counts added together,
            element-wise for sequences. ``total_cost`` is the summed cost
            string.
        """
        embeddings1, tokens1, cost1 = self.get_embeddings(text1)
        embeddings2, tokens2, cost2 = self.get_embeddings(text2)

        norm1 = np.linalg.norm(embeddings1)
        norm2 = np.linalg.norm(embeddings2)
        if norm1 == 0 or norm2 == 0:
            # Cosine similarity is undefined for a zero vector; report 0.0
            # instead of propagating NaN from a division by zero.
            similarity = 0.0
        else:
            similarity = np.dot(embeddings1, embeddings2) / (norm1 * norm2)

        total_tokens = self._add_tokens(tokens1, tokens2)
        total_cost = CostManager.add_costs(cost1, cost2)

        return similarity, total_tokens, total_cost

    def batch_get_embeddings(self, texts):
        """Get embeddings for multiple texts.

        Args:
            texts (List[str]): Texts to get embeddings for; may be empty.

        Returns:
            tuple: ``(embeddings, total_tokens, total_cost)``. ``embeddings`` is
            a list with one entry per text. ``total_tokens`` is the element-wise
            sum of the first three token counters. ``total_cost`` is the summed
            cost string; for no texts it is a zero cost in the engine's currency.
        """
        embeddings = []
        total_tokens = [0, 0, 0]
        total_cost = None

        for text in texts or []:
            embedding, tokens, cost = self.get_embeddings(text)
            embeddings.append(embedding)
            total_tokens = [acc + tokens[i] for i, acc in enumerate(total_tokens)]
            # The first cost is kept verbatim; later ones are accumulated.
            total_cost = cost if total_cost is None else CostManager.add_costs(total_cost, cost)

        if total_cost is None:
            total_cost = CostManager.format_cost(0.0, self.engine)

        return embeddings, total_tokens, total_cost
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
class WebSearchAgent:
    """Answer queries through a web-search engine (Bocha AI or Exa)."""

    def __init__(self, engine_params=None, engine=None):
        """Create the agent from a pre-built ``engine`` or from ``engine_params``.

        Args:
            engine_params: dict whose ``"engine_type"`` entry is ``"bocha"`` or
                ``"exa"``; the whole dict is passed to the engine constructor.
                Ignored when ``engine`` is given.
            engine: an already-constructed search engine to use as-is. In that
                case ``self.engine_type`` is None.

        Raises:
            ValueError: if neither argument is given or the type is unsupported.
        """
        if engine is not None:
            # Always define engine_type so get_answer's error path can use it.
            self.engine_type = None
            self.engine = engine
            return
        if engine_params is None:
            raise ValueError("engine_params must be provided")

        self.engine_type = engine_params.get("engine_type")
        if self.engine_type == "bocha":
            self.engine = BochaAISearchEngine(**engine_params)
        elif self.engine_type == "exa":
            self.engine = ExaResearchEngine(**engine_params)
        else:
            raise ValueError(f"Web search engine type '{self.engine_type}' is not supported")

    def get_answer(self, query, **kwargs):
        """Get a direct answer for the query.

        Args:
            query (str): The search query.
            **kwargs: Additional arguments passed to the search engine.

        Returns:
            tuple: ``(answer, tokens, cost_string)``. For Exa, ``answer`` is the
            content of the first message of type "answer" in the result;
            otherwise it is ``str(results)``.

        Raises:
            ValueError: if ``self.engine`` is neither a Bocha nor an Exa engine.
        """
        if isinstance(self.engine, BochaAISearchEngine):
            answer, tokens, cost = self.engine.get_answer(query, **kwargs)
            return answer, tokens, str(cost)

        if isinstance(self.engine, ExaResearchEngine):
            # chat_research (rather than search) returns a complete answer.
            results, tokens, cost = self.engine.chat_research(query, **kwargs)
            if isinstance(results, dict):
                for message in results.get("messages") or []:
                    if isinstance(message, dict) and message.get("type") == "answer":
                        return message.get("content", ""), tokens, str(cost)
            return str(results), tokens, str(cost)

        engine_type = getattr(self, "engine_type", None) or type(self.engine).__name__
        raise ValueError(f"Web search engine type '{engine_type}' is not supported")
|
|
File without changes
|