lybic-guiagents 0.1.0 (lybic_guiagents-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of lybic-guiagents might be problematic.

Files changed (85)
  1. desktop_env/__init__.py +1 -0
  2. desktop_env/actions.py +203 -0
  3. desktop_env/controllers/__init__.py +0 -0
  4. desktop_env/controllers/python.py +471 -0
  5. desktop_env/controllers/setup.py +882 -0
  6. desktop_env/desktop_env.py +509 -0
  7. desktop_env/evaluators/__init__.py +5 -0
  8. desktop_env/evaluators/getters/__init__.py +41 -0
  9. desktop_env/evaluators/getters/calc.py +15 -0
  10. desktop_env/evaluators/getters/chrome.py +1774 -0
  11. desktop_env/evaluators/getters/file.py +154 -0
  12. desktop_env/evaluators/getters/general.py +42 -0
  13. desktop_env/evaluators/getters/gimp.py +38 -0
  14. desktop_env/evaluators/getters/impress.py +126 -0
  15. desktop_env/evaluators/getters/info.py +24 -0
  16. desktop_env/evaluators/getters/misc.py +406 -0
  17. desktop_env/evaluators/getters/replay.py +20 -0
  18. desktop_env/evaluators/getters/vlc.py +86 -0
  19. desktop_env/evaluators/getters/vscode.py +35 -0
  20. desktop_env/evaluators/metrics/__init__.py +160 -0
  21. desktop_env/evaluators/metrics/basic_os.py +68 -0
  22. desktop_env/evaluators/metrics/chrome.py +493 -0
  23. desktop_env/evaluators/metrics/docs.py +1011 -0
  24. desktop_env/evaluators/metrics/general.py +665 -0
  25. desktop_env/evaluators/metrics/gimp.py +637 -0
  26. desktop_env/evaluators/metrics/libreoffice.py +28 -0
  27. desktop_env/evaluators/metrics/others.py +92 -0
  28. desktop_env/evaluators/metrics/pdf.py +31 -0
  29. desktop_env/evaluators/metrics/slides.py +957 -0
  30. desktop_env/evaluators/metrics/table.py +585 -0
  31. desktop_env/evaluators/metrics/thunderbird.py +176 -0
  32. desktop_env/evaluators/metrics/utils.py +719 -0
  33. desktop_env/evaluators/metrics/vlc.py +524 -0
  34. desktop_env/evaluators/metrics/vscode.py +283 -0
  35. desktop_env/providers/__init__.py +35 -0
  36. desktop_env/providers/aws/__init__.py +0 -0
  37. desktop_env/providers/aws/manager.py +278 -0
  38. desktop_env/providers/aws/provider.py +186 -0
  39. desktop_env/providers/aws/provider_with_proxy.py +315 -0
  40. desktop_env/providers/aws/proxy_pool.py +193 -0
  41. desktop_env/providers/azure/__init__.py +0 -0
  42. desktop_env/providers/azure/manager.py +87 -0
  43. desktop_env/providers/azure/provider.py +207 -0
  44. desktop_env/providers/base.py +97 -0
  45. desktop_env/providers/gcp/__init__.py +0 -0
  46. desktop_env/providers/gcp/manager.py +0 -0
  47. desktop_env/providers/gcp/provider.py +0 -0
  48. desktop_env/providers/virtualbox/__init__.py +0 -0
  49. desktop_env/providers/virtualbox/manager.py +463 -0
  50. desktop_env/providers/virtualbox/provider.py +124 -0
  51. desktop_env/providers/vmware/__init__.py +0 -0
  52. desktop_env/providers/vmware/manager.py +455 -0
  53. desktop_env/providers/vmware/provider.py +105 -0
  54. gui_agents/__init__.py +0 -0
  55. gui_agents/agents/Action.py +209 -0
  56. gui_agents/agents/__init__.py +0 -0
  57. gui_agents/agents/agent_s.py +832 -0
  58. gui_agents/agents/global_state.py +610 -0
  59. gui_agents/agents/grounding.py +651 -0
  60. gui_agents/agents/hardware_interface.py +129 -0
  61. gui_agents/agents/manager.py +568 -0
  62. gui_agents/agents/translator.py +132 -0
  63. gui_agents/agents/worker.py +355 -0
  64. gui_agents/cli_app.py +560 -0
  65. gui_agents/core/__init__.py +0 -0
  66. gui_agents/core/engine.py +1496 -0
  67. gui_agents/core/knowledge.py +449 -0
  68. gui_agents/core/mllm.py +555 -0
  69. gui_agents/tools/__init__.py +0 -0
  70. gui_agents/tools/tools.py +727 -0
  71. gui_agents/unit_test/__init__.py +0 -0
  72. gui_agents/unit_test/run_tests.py +65 -0
  73. gui_agents/unit_test/test_manager.py +330 -0
  74. gui_agents/unit_test/test_worker.py +269 -0
  75. gui_agents/utils/__init__.py +0 -0
  76. gui_agents/utils/analyze_display.py +301 -0
  77. gui_agents/utils/common_utils.py +263 -0
  78. gui_agents/utils/display_viewer.py +281 -0
  79. gui_agents/utils/embedding_manager.py +53 -0
  80. gui_agents/utils/image_axis_utils.py +27 -0
  81. lybic_guiagents-0.1.0.dist-info/METADATA +416 -0
  82. lybic_guiagents-0.1.0.dist-info/RECORD +85 -0
  83. lybic_guiagents-0.1.0.dist-info/WHEEL +5 -0
  84. lybic_guiagents-0.1.0.dist-info/licenses/LICENSE +201 -0
  85. lybic_guiagents-0.1.0.dist-info/top_level.txt +2 -0
gui_agents/core/mllm.py
@@ -0,0 +1,555 @@
+import base64
+
+import numpy as np
+
+from gui_agents.core.engine import (
+    LMMEngineAnthropic,
+    LMMEngineAzureOpenAI,
+    LMMEngineHuggingFace,
+    LMMEngineOpenAI,
+    LMMEngineOpenRouter,
+    LMMEnginevLLM,
+    LMMEngineGemini,
+    LMMEngineQwen,
+    LMMEngineDoubao,
+    LMMEngineDeepSeek,
+    LMMEngineZhipu,
+    LMMEngineGroq,
+    LMMEngineSiliconflow,
+    LMMEngineMonica,
+    LMMEngineAWSBedrock,
+    OpenAIEmbeddingEngine,
+    GeminiEmbeddingEngine,
+    AzureOpenAIEmbeddingEngine,
+    DashScopeEmbeddingEngine,
+    DoubaoEmbeddingEngine,
+    JinaEmbeddingEngine,
+    BochaAISearchEngine,
+    ExaResearchEngine,
+)
+
+class CostManager:
+    """Cost manager, responsible for adding currency symbols based on engine type"""
+
+    # Chinese engines use CNY
+    CNY_ENGINES = {
+        LMMEngineQwen, LMMEngineDoubao, LMMEngineDeepSeek, LMMEngineZhipu,
+        LMMEngineSiliconflow, DashScopeEmbeddingEngine, DoubaoEmbeddingEngine
+    }
+    # Other engines use USD
+    USD_ENGINES = {
+        LMMEngineOpenAI, LMMEngineAnthropic, LMMEngineAzureOpenAI, LMMEngineGemini,
+        LMMEngineOpenRouter, LMMEnginevLLM, LMMEngineHuggingFace, LMMEngineGroq,
+        LMMEngineMonica, LMMEngineAWSBedrock, OpenAIEmbeddingEngine,
+        GeminiEmbeddingEngine, AzureOpenAIEmbeddingEngine, JinaEmbeddingEngine
+    }
+
+    @classmethod
+    def get_currency_symbol(cls, engine) -> str:
+        engine_type = type(engine)
+
+        if engine_type in cls.CNY_ENGINES:
+            return "¥"
+        elif engine_type in cls.USD_ENGINES:
+            return "$"
+        else:
+            return "$"
+
+    @classmethod
+    def format_cost(cls, cost: float, engine) -> str:
+        currency = cls.get_currency_symbol(engine)
+        return f"{cost:.7f}{currency}"
+
+    @classmethod
+    def add_costs(cls, cost1: str, cost2: str) -> str:
+        currency_symbols = ["$", "¥", "¥", "€", "£"]
+        currency1 = currency2 = "$"
+        value1 = value2 = 0.0
+
+        if isinstance(cost1, (int, float)):
+            value1 = float(cost1)
+            currency1 = "$"
+        else:
+            cost1_str = str(cost1)
+            for symbol in currency_symbols:
+                if symbol in cost1_str:
+                    value1 = float(cost1_str.replace(symbol, "").strip())
+                    currency1 = symbol
+                    break
+            else:
+                try:
+                    value1 = float(cost1_str)
+                    currency1 = "$"
+                except:
+                    value1 = 0.0
+
+        if isinstance(cost2, (int, float)):
+            value2 = float(cost2)
+            currency2 = "$"
+        else:
+            cost2_str = str(cost2)
+            for symbol in currency_symbols:
+                if symbol in cost2_str:
+                    value2 = float(cost2_str.replace(symbol, "").strip())
+                    currency2 = symbol
+                    break
+            else:
+                try:
+                    value2 = float(cost2_str)
+                    currency2 = "$"
+                except:
+                    value2 = 0.0
+
+        if currency1 != currency2:
+            print(f"Warning: Different currencies in cost accumulation: {currency1} and {currency2}")
+            currency = currency1
+        else:
+            currency = currency1
+
+        total_value = value1 + value2
+        return f"{total_value:.6f}{currency}"
+
+class LLMAgent:
+    def __init__(self, engine_params=None, system_prompt=None, engine=None):
+        if engine is None:
+            if engine_params is not None:
+                engine_type = engine_params.get("engine_type")
+                if engine_type == "openai":
+                    self.engine = LMMEngineOpenAI(**engine_params)
+                elif engine_type == "anthropic":
+                    self.engine = LMMEngineAnthropic(**engine_params)
+                elif engine_type == "azure":
+                    self.engine = LMMEngineAzureOpenAI(**engine_params)
+                elif engine_type == "vllm":
+                    self.engine = LMMEnginevLLM(**engine_params)
+                elif engine_type == "huggingface":
+                    self.engine = LMMEngineHuggingFace(**engine_params)
+                elif engine_type == "gemini":
+                    self.engine = LMMEngineGemini(**engine_params)
+                elif engine_type == "open_router":
+                    self.engine = LMMEngineOpenRouter(**engine_params)
+                elif engine_type == "dashscope":
+                    self.engine = LMMEngineQwen(**engine_params)
+                elif engine_type == "doubao":
+                    self.engine = LMMEngineDoubao(**engine_params)
+                elif engine_type == "deepseek":
+                    self.engine = LMMEngineDeepSeek(**engine_params)
+                elif engine_type == "zhipu":
+                    self.engine = LMMEngineZhipu(**engine_params)
+                elif engine_type == "groq":
+                    self.engine = LMMEngineGroq(**engine_params)
+                elif engine_type == "siliconflow":
+                    self.engine = LMMEngineSiliconflow(**engine_params)
+                elif engine_type == "monica":
+                    self.engine = LMMEngineMonica(**engine_params)
+                elif engine_type == "aws_bedrock":
+                    self.engine = LMMEngineAWSBedrock(**engine_params)
+                else:
+                    raise ValueError("engine_type is not supported")
+            else:
+                raise ValueError("engine_params must be provided")
+        else:
+            self.engine = engine
+
+        self.messages = []  # Empty messages
+
+        if system_prompt:
+            self.add_system_prompt(system_prompt)
+        else:
+            self.add_system_prompt("You are a helpful assistant.")
+
+    def encode_image(self, image_content):
+        # if image_content is a path to an image file, check type of the image_content to verify
+        if isinstance(image_content, str):
+            with open(image_content, "rb") as image_file:
+                return base64.b64encode(image_file.read()).decode("utf-8")
+        else:
+            return base64.b64encode(image_content).decode("utf-8")
+
+    def reset(
+        self,
+    ):
+
+        self.messages = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": self.system_prompt}],
+            }
+        ]
+
+    def add_system_prompt(self, system_prompt):
+        self.system_prompt = system_prompt
+        if len(self.messages) > 0:
+            self.messages[0] = {
+                "role": "system",
+                "content": [{"type": "text", "text": self.system_prompt}],
+            }
+        else:
+            self.messages.append(
+                {
+                    "role": "system",
+                    "content": [{"type": "text", "text": self.system_prompt}],
+                }
+            )
+
+    def remove_message_at(self, index):
+        """Remove a message at a given index"""
+        if index < len(self.messages):
+            self.messages.pop(index)
+
+    def replace_message_at(
+        self, index, text_content, image_content=None, image_detail="high"
+    ):
+        """Replace a message at a given index"""
+        if index < len(self.messages):
+            self.messages[index] = {
+                "role": self.messages[index]["role"],
+                "content": [{"type": "text", "text": text_content}],
+            }
+            if image_content:
+                base64_image = self.encode_image(image_content)
+                self.messages[index]["content"].append(
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/png;base64,{base64_image}",
+                            "detail": image_detail,
+                        },
+                    }
+                )
+
+    def add_message(
+        self,
+        text_content,
+        image_content=None,
+        role=None,
+        image_detail="high",
+        put_text_last=False,
+    ):
+        """Add a new message to the list of messages"""
+
+        # API-style inference from OpenAI and similar services
+        if isinstance(
+            self.engine,
+            (
+                LMMEngineAnthropic,
+                LMMEngineAzureOpenAI,
+                LMMEngineHuggingFace,
+                LMMEngineOpenAI,
+                LMMEngineOpenRouter,
+                LMMEnginevLLM,
+                LMMEngineGemini,
+                LMMEngineQwen,
+                LMMEngineDoubao,
+                LMMEngineDeepSeek,
+                LMMEngineZhipu,
+                LMMEngineGroq,
+                LMMEngineSiliconflow,
+                LMMEngineMonica,
+                LMMEngineAWSBedrock,
+            ),
+        ):
+            # infer role from previous message
+            if role != "user":
+                if self.messages[-1]["role"] == "system":
+                    role = "user"
+                elif self.messages[-1]["role"] == "user":
+                    role = "assistant"
+                elif self.messages[-1]["role"] == "assistant":
+                    role = "user"
+
+            message = {
+                "role": role,
+                "content": [{"type": "text", "text": text_content}],
+            }
+
+            if isinstance(image_content, np.ndarray) or image_content:
+                # Check if image_content is a list or a single image
+                if isinstance(image_content, list):
+                    # If image_content is a list of images, loop through each image
+                    for image in image_content:
+                        base64_image = self.encode_image(image)
+                        message["content"].append(
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": f"data:image/png;base64,{base64_image}",
+                                    "detail": image_detail,
+                                },
+                            }
+                        )
+                else:
+                    # If image_content is a single image, handle it directly
+                    base64_image = self.encode_image(image_content)
+                    message["content"].append(
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/png;base64,{base64_image}",
+                                "detail": image_detail,
+                            },
+                        }
+                    )
+
+            # Rotate text to be the last message if desired
+            if put_text_last:
+                text_content = message["content"].pop(0)
+                message["content"].append(text_content)
+
+            self.messages.append(message)
+
+        # For API-style inference from Anthropic
+        elif isinstance(self.engine, (LMMEngineAnthropic, LMMEngineAWSBedrock)):
+            # infer role from previous message
+            if role != "user":
+                if self.messages[-1]["role"] == "system":
+                    role = "user"
+                elif self.messages[-1]["role"] == "user":
+                    role = "assistant"
+                elif self.messages[-1]["role"] == "assistant":
+                    role = "user"
+
+            message = {
+                "role": role,
+                "content": [{"type": "text", "text": text_content}],
+            }
+
+            if image_content:
+                # Check if image_content is a list or a single image
+                if isinstance(image_content, list):
+                    # If image_content is a list of images, loop through each image
+                    for image in image_content:
+                        base64_image = self.encode_image(image)
+                        message["content"].append(
+                            {
+                                "type": "image",
+                                "source": {
+                                    "type": "base64",
+                                    "media_type": "image/png",
+                                    "data": base64_image,
+                                },
+                            }
+                        )
+                else:
+                    # If image_content is a single image, handle it directly
+                    base64_image = self.encode_image(image_content)
+                    message["content"].append(
+                        {
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": "image/png",
+                                "data": base64_image,
+                            },
+                        }
+                    )
+            self.messages.append(message)
+
+        # Locally hosted vLLM model inference
+        elif isinstance(self.engine, LMMEnginevLLM):
+            # infer role from previous message
+            if role != "user":
+                if self.messages[-1]["role"] == "system":
+                    role = "user"
+                elif self.messages[-1]["role"] == "user":
+                    role = "assistant"
+                elif self.messages[-1]["role"] == "assistant":
+                    role = "user"
+
+            message = {
+                "role": role,
+                "content": [{"type": "text", "text": text_content}],
+            }
+
+            if image_content:
+                # Check if image_content is a list or a single image
+                if isinstance(image_content, list):
+                    # If image_content is a list of images, loop through each image
+                    for image in image_content:
+                        base64_image = self.encode_image(image)
+                        message["content"].append(
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": f"data:image;base64,{base64_image}"
+                                },
+                            }
+                        )
+                else:
+                    # If image_content is a single image, handle it directly
+                    base64_image = self.encode_image(image_content)
+                    message["content"].append(
+                        {
+                            "type": "image_url",
+                            "image_url": {"url": f"data:image;base64,{base64_image}"},
+                        }
+                    )
+
+            self.messages.append(message)
+        else:
+            raise ValueError("engine_type is not supported")
+
+    def get_response(
+        self,
+        user_message=None,
+        messages=None,
+        temperature=0.0,
+        max_new_tokens=None,
+        **kwargs,
+    ):
+        """Generate the next response based on previous messages"""
+        if messages is None:
+            messages = self.messages
+        if user_message:
+            messages.append(
+                {"role": "user", "content": [{"type": "text", "text": user_message}]}
+            )
+
+        content, total_tokens, cost = self.engine.generate(
+            messages,
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,  # type: ignore
+            **kwargs,
+        )
+
+        cost_string = CostManager.format_cost(cost, self.engine)
+
+        return content, total_tokens, cost_string
+
+class EmbeddingAgent:
+    def __init__(self, engine_params=None, engine=None):
+        if engine is None:
+            if engine_params is not None:
+                engine_type = engine_params.get("engine_type")
+                if engine_type == "openai":
+                    self.engine = OpenAIEmbeddingEngine(**engine_params)
+                elif engine_type == "gemini":
+                    self.engine = GeminiEmbeddingEngine(**engine_params)
+                elif engine_type == "azure":
+                    self.engine = AzureOpenAIEmbeddingEngine(**engine_params)
+                elif engine_type == "dashscope":
+                    self.engine = DashScopeEmbeddingEngine(**engine_params)
+                elif engine_type == "doubao":
+                    self.engine = DoubaoEmbeddingEngine(**engine_params)
+                elif engine_type == "jina":
+                    self.engine = JinaEmbeddingEngine(**engine_params)
+                else:
+                    raise ValueError(f"Embedding engine type '{engine_type}' is not supported")
+            else:
+                raise ValueError("engine_params must be provided")
+        else:
+            self.engine = engine
+
+    def get_embeddings(self, text):
+        """Get embeddings for the given text
+
+        Args:
+            text (str): The text to get embeddings for
+
+        Returns:
+            numpy.ndarray: The embeddings for the text
+        """
+        embeddings, total_tokens, cost = self.engine.get_embeddings(text)
+        cost_string = CostManager.format_cost(cost, self.engine)
+        return embeddings, total_tokens, cost_string
+
+
+    def get_similarity(self, text1, text2):
+        """Calculate the cosine similarity between two texts
+
+        Args:
+            text1 (str): First text
+            text2 (str): Second text
+
+        Returns:
+            float: Cosine similarity score between the two texts
+        """
+        embeddings1, tokens1, cost1 = self.get_embeddings(text1)
+        embeddings2, tokens2, cost2 = self.get_embeddings(text2)
+
+        # Calculate cosine similarity
+        dot_product = np.dot(embeddings1, embeddings2)
+        norm1 = np.linalg.norm(embeddings1)
+        norm2 = np.linalg.norm(embeddings2)
+
+        similarity = dot_product / (norm1 * norm2)
+        total_tokens = tokens1 + tokens2
+        total_cost = CostManager.add_costs(cost1, cost2)
+
+        return similarity, total_tokens, total_cost
+
+    def batch_get_embeddings(self, texts):
+        """Get embeddings for multiple texts
+
+        Args:
+            texts (List[str]): List of texts to get embeddings for
+
+        Returns:
+            List[numpy.ndarray]: List of embeddings for each text
+        """
+        embeddings = []
+        total_tokens = [0, 0, 0]
+        if texts:
+            first_embedding, first_tokens, first_cost = self.get_embeddings(texts[0])
+            embeddings.append(first_embedding)
+            total_tokens[0] += first_tokens[0]
+            total_tokens[1] += first_tokens[1]
+            total_tokens[2] += first_tokens[2]
+            total_cost = first_cost
+
+            for text in texts[1:]:
+                embedding, tokens, cost = self.get_embeddings(text)
+                embeddings.append(embedding)
+                total_tokens[0] += tokens[0]
+                total_tokens[1] += tokens[1]
+                total_tokens[2] += tokens[2]
+                total_cost = CostManager.add_costs(total_cost, cost)
+        else:
+            currency = CostManager.get_currency_symbol(self.engine)
+            total_cost = f"0.0{currency}"
+
+        return embeddings, total_tokens, total_cost
+
+
+class WebSearchAgent:
+    def __init__(self, engine_params=None, engine=None):
+        if engine is None:
+            if engine_params is not None:
+                self.engine_type = engine_params.get("engine_type")
+                if self.engine_type == "bocha":
+                    self.engine = BochaAISearchEngine(**engine_params)
+                elif self.engine_type == "exa":
+                    self.engine = ExaResearchEngine(**engine_params)
+                else:
+                    raise ValueError(f"Web search engine type '{self.engine_type}' is not supported")
+            else:
+                raise ValueError("engine_params must be provided")
+        else:
+            self.engine = engine
+
+    def get_answer(self, query, **kwargs):
+        """Get a direct answer for the query
+
+        Args:
+            query (str): The search query
+            **kwargs: Additional arguments to pass to the search engine
+
+        Returns:
+            str: The answer text
+        """
+        if isinstance(self.engine, BochaAISearchEngine):
+            answer, tokens, cost = self.engine.get_answer(query, **kwargs)
+            return answer, tokens, str(cost)
+
+        elif isinstance(self.engine, ExaResearchEngine):
+            # For Exa, we'll use the chat_research method which returns a complete answer
+            # results, tokens, cost = self.engine.search(query, **kwargs)
+            results, tokens, cost = self.engine.chat_research(query, **kwargs)
+            if isinstance(results, dict) and "messages" in results:
+                for message in results.get("messages", []):
+                    if message.get("type") == "answer":
+                        return message.get("content", ""), tokens, str(cost)
+            return str(results), tokens, str(cost)
+
+        else:
+            raise ValueError(f"Web search engine type '{self.engine_type}' is not supported")
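For orientation, the sketch below shows how the LLMAgent class added in gui_agents/core/mllm.py might be driven. It is an illustrative sketch, not code from the package: the "model" key in engine_params and the screenshot path are assumptions, and LMMEngineOpenAI (defined in gui_agents/core/engine.py) may expect different constructor arguments or environment variables.

# Illustrative usage sketch; engine_params keys other than "engine_type" are assumptions.
from gui_agents.core.mllm import LLMAgent

agent = LLMAgent(
    engine_params={
        "engine_type": "openai",  # dispatched to LMMEngineOpenAI in __init__
        "model": "gpt-4o",        # hypothetical key, forwarded via **engine_params
    },
    system_prompt="You are a GUI automation assistant.",
)

# add_message infers the role from the previous message (system -> user here)
# and attaches the screenshot as a base64-encoded image_url content part.
agent.add_message(
    "Describe the current screen.",
    image_content="screenshot.png",  # hypothetical path; raw bytes are also accepted
)

content, total_tokens, cost_string = agent.get_response(temperature=0.0)
print(content)
print(cost_string)  # e.g. "0.0012345$": seven decimals plus a currency suffix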
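Per-call costs circulate through these classes as currency-tagged strings: CostManager.format_cost appends "¥" for the engines listed in CNY_ENGINES and "$" otherwise, and CostManager.add_costs parses the symbol back out when accumulating. A minimal sketch of that round trip, with made-up values:

from gui_agents.core.mllm import CostManager

a = "0.0012000$"                     # shape produced by CostManager.format_cost
b = "0.0030000$"
total = CostManager.add_costs(a, b)  # parses both values, keeps a single currency symbol
print(total)                         # "0.004200$" (add_costs re-formats to six decimals)
# Mixing currencies (e.g. "$" and "¥") prints a warning and keeps the first symbol.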