camel-ai 0.2.15a0__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic. Click here for more details.

Files changed (95) hide show
  1. camel/__init__.py +1 -1
  2. camel/agents/chat_agent.py +18 -4
  3. camel/agents/multi_hop_generator_agent.py +85 -0
  4. camel/agents/programmed_agent_instruction.py +148 -0
  5. camel/benchmarks/__init__.py +13 -1
  6. camel/benchmarks/apibank.py +565 -0
  7. camel/benchmarks/apibench.py +500 -0
  8. camel/benchmarks/gaia.py +4 -4
  9. camel/benchmarks/nexus.py +518 -0
  10. camel/benchmarks/ragbench.py +333 -0
  11. camel/bots/__init__.py +1 -1
  12. camel/bots/discord/__init__.py +26 -0
  13. camel/bots/discord/discord_app.py +384 -0
  14. camel/bots/discord/discord_installation.py +64 -0
  15. camel/bots/discord/discord_store.py +160 -0
  16. camel/configs/__init__.py +3 -0
  17. camel/configs/anthropic_config.py +17 -15
  18. camel/configs/internlm_config.py +60 -0
  19. camel/data_collector/base.py +5 -5
  20. camel/data_collector/sharegpt_collector.py +2 -2
  21. camel/datagen/__init__.py +6 -2
  22. camel/datagen/{o1datagen.py → cotdatagen.py} +19 -6
  23. camel/datagen/self_instruct/__init__.py +36 -0
  24. camel/datagen/self_instruct/filter/__init__.py +34 -0
  25. camel/datagen/self_instruct/filter/filter_function.py +216 -0
  26. camel/datagen/self_instruct/filter/filter_registry.py +56 -0
  27. camel/datagen/self_instruct/filter/instruction_filter.py +81 -0
  28. camel/datagen/self_instruct/self_instruct.py +393 -0
  29. camel/datagen/self_instruct/templates.py +382 -0
  30. camel/datahubs/huggingface.py +12 -2
  31. camel/datahubs/models.py +2 -3
  32. camel/embeddings/mistral_embedding.py +5 -1
  33. camel/embeddings/openai_compatible_embedding.py +6 -1
  34. camel/embeddings/openai_embedding.py +5 -1
  35. camel/interpreters/e2b_interpreter.py +5 -1
  36. camel/loaders/__init__.py +2 -0
  37. camel/loaders/apify_reader.py +5 -1
  38. camel/loaders/chunkr_reader.py +5 -1
  39. camel/loaders/firecrawl_reader.py +0 -30
  40. camel/loaders/panda_reader.py +337 -0
  41. camel/logger.py +11 -5
  42. camel/messages/__init__.py +10 -4
  43. camel/messages/conversion/conversation_models.py +5 -0
  44. camel/messages/func_message.py +30 -22
  45. camel/models/__init__.py +2 -0
  46. camel/models/anthropic_model.py +6 -23
  47. camel/models/azure_openai_model.py +1 -2
  48. camel/models/cohere_model.py +13 -1
  49. camel/models/deepseek_model.py +5 -1
  50. camel/models/gemini_model.py +15 -2
  51. camel/models/groq_model.py +5 -1
  52. camel/models/internlm_model.py +143 -0
  53. camel/models/mistral_model.py +19 -8
  54. camel/models/model_factory.py +3 -0
  55. camel/models/nemotron_model.py +5 -1
  56. camel/models/nvidia_model.py +5 -1
  57. camel/models/openai_model.py +5 -1
  58. camel/models/qwen_model.py +5 -1
  59. camel/models/reka_model.py +5 -1
  60. camel/models/reward/__init__.py +2 -0
  61. camel/models/reward/nemotron_model.py +5 -1
  62. camel/models/reward/skywork_model.py +88 -0
  63. camel/models/samba_model.py +5 -1
  64. camel/models/togetherai_model.py +5 -1
  65. camel/models/yi_model.py +5 -1
  66. camel/models/zhipuai_model.py +5 -1
  67. camel/schemas/openai_converter.py +5 -1
  68. camel/storages/graph_storages/nebula_graph.py +89 -20
  69. camel/storages/graph_storages/neo4j_graph.py +138 -0
  70. camel/synthetic_datagen/source2synth/data_processor.py +373 -0
  71. camel/synthetic_datagen/source2synth/models.py +68 -0
  72. camel/synthetic_datagen/source2synth/user_data_processor_config.py +73 -0
  73. camel/toolkits/__init__.py +4 -0
  74. camel/toolkits/arxiv_toolkit.py +20 -3
  75. camel/toolkits/dappier_toolkit.py +196 -0
  76. camel/toolkits/function_tool.py +61 -61
  77. camel/toolkits/google_scholar_toolkit.py +9 -0
  78. camel/toolkits/meshy_toolkit.py +5 -1
  79. camel/toolkits/notion_toolkit.py +1 -1
  80. camel/toolkits/openbb_toolkit.py +869 -0
  81. camel/toolkits/search_toolkit.py +91 -5
  82. camel/toolkits/stripe_toolkit.py +5 -1
  83. camel/toolkits/twitter_toolkit.py +24 -16
  84. camel/types/__init__.py +4 -2
  85. camel/types/enums.py +34 -1
  86. camel/types/openai_types.py +6 -4
  87. camel/types/unified_model_type.py +5 -0
  88. camel/utils/__init__.py +2 -0
  89. camel/utils/commons.py +104 -19
  90. camel/utils/token_counting.py +3 -3
  91. {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/METADATA +160 -177
  92. {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/RECORD +94 -69
  93. {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/WHEEL +1 -1
  94. camel/bots/discord_app.py +0 -138
  95. {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/LICENSE +0 -0
@@ -0,0 +1,500 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import json
16
+ import logging
17
+ import random
18
+ from pathlib import Path
19
+ from typing import Any, Dict, Literal, Optional
20
+
21
+ import tree_sitter_python as tspython
22
+ from tqdm import tqdm
23
+ from tree_sitter import Language, Parser
24
+
25
+ from camel.agents import ChatAgent
26
+ from camel.benchmarks.base import BaseBenchmark
27
+ from camel.messages import BaseMessage
28
+ from camel.utils import download_github_subdirectory
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
# Mapping of dataset names to the data files belonging to each dataset.
# The 'Oracle' retriever is used here, which means the full API
# documentation is included in the prompt (no retrieval step).
# File-name stems differ slightly between datasets ("tensorflowhub"
# questions vs. "tensorflow" eval/train), hence the explicit pairs.
dataset_mapping = {
    name: {
        "api": f"{name}_api.jsonl",
        "eval": f"{stem}_eval.json",
        "train": f"{stem}_train.json",
        "questions": f"questions_{name}_oracle.jsonl",
    }
    for name, stem in (
        ("huggingface", "huggingface"),
        ("tensorflowhub", "tensorflow"),
        ("torchhub", "torchhub"),
    )
}
56
+
57
+
58
# This function is migrated from the original repo:
# https://github.com/ShishirPatil/gorilla
def encode_question(question: str, dataset_name: str) -> str:
    r"""Encode multiple prompt instructions into a single string.

    Args:
        question (str): The user question to wrap.
        dataset_name (str): One of ``"huggingface"``,
            ``"tensorflowhub"`` or ``"torchhub"``.

    Returns:
        str: The full prompt: the question, the answer-format
            instructions, and the dataset-specific domain list.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
            (Previously this case only logged a message and left
            ``domains`` unbound, crashing below with ``NameError``.)
    """
    # NOTE(review): spelling oddities inside the domain lists (e.g.
    # "Zero-Shor", "enhancemenmt") come verbatim from the upstream
    # Gorilla prompts and presumably match the dataset's domain
    # labels used for evaluation — do not "fix" them.
    if dataset_name == "torchhub":
        domains = "1. $DOMAIN is inferred from the task description and \
            should include one of {Classification, Semantic Segmentation, \
            Object Detection, Audio Separation, Video Classification, \
            Text-to-Speech}."
    elif dataset_name == "huggingface":
        domains = "1. $DOMAIN should include one of {Multimodal Feature \
            Extraction, Multimodal Text-to-Image, Multimodal \
            Image-to-Text, Multimodal Text-to-Video, \
            Multimodal Visual Question Answering, Multimodal Document \
            Question Answer, Multimodal Graph Machine Learning, \
            Computer Vision Depth Estimation, Computer Vision Image \
            Classification, Computer Vision Object Detection, \
            Computer Vision Image Segmentation, Computer Vision \
            Image-to-Image, Computer Vision Unconditional \
            Image Generation, Computer Vision Video Classification, \
            Computer Vision Zero-Shor Image Classification, \
            Natural Language Processing Text Classification, \
            Natural Language Processing Token Classification, \
            Natural Language Processing Table Question Answering, \
            Natural Language Processing Question Answering, \
            Natural Language Processing, Zero-Shot Classification \
            Natural Language Processing Translation, Natural Language \
            Processing Summarization, Natural Language Processing \
            Conversational, Natural Language Processing Text \
            Generation, Natural Language Processing Fill-Mask, \
            Natural Language Processing Text2Text Generation, \
            Natural Language Processing Sentence Similarity, \
            Audio Text-to-Speech, Audio Automatic Speech Recognition, \
            Audio Audio-to-Audio, Audio Audio Classification, \
            Audio Voice Activity Detection, Tabular Tabular \
            Classification, Tabular Tabular Regression, \
            Reinforcement Learning Reinforcement Learning, \
            Reinforcement Learning Robotics }"
    elif dataset_name == "tensorflowhub":
        domains = "1. $DOMAIN is inferred from the task description \
            and should include one of {text-sequence-alignment, \
            text-embedding, text-language-model, text-preprocessing, \
            text-classification, text-generation, text-question-answering, \
            text-retrieval-question-answering, text-segmentation, \
            text-to-mel, image-classification, image-feature-vector, \
            image-object-detection, image-segmentation, \
            image-generator, image-pose-detection, image-rnn-agent, \
            image-augmentation, image-classifier, image-style-transfer, \
            image-aesthetic-quality, image-depth-estimation, \
            image-super-resolution, image-deblurring, image-extrapolation, \
            image-text-recognition, image-dehazing, image-deraining, \
            image-enhancemenmt, image-classification-logits, \
            image-frame-interpolation, image-text-detection, image-denoising, \
            image-others, video-classification, video-feature-extraction, \
            video-generation, video-audio-text, video-text, \
            audio-embedding, audio-event-classification, audio-command-detection, \
            audio-paralinguists-classification, audio-speech-to-text, \
            audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}"
    else:
        # Fail fast instead of logging and falling through to a
        # NameError when `domains` is referenced below.
        raise ValueError(f"Error: API name {dataset_name} is not supported.")

    prompt = (
        question
        + "\nWrite a python program in 1 to 2 lines to call API in "
        + dataset_name
        + ".\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, \
            <<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, \
            <<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. \
            Here are the requirements:\n"
        + domains
        + "\n2. The $API_CALL should have only 1 line of code \
            that calls api.\n 3. The $API_PROVIDER should be the \
            programming framework used.\n4. $EXPLANATION should be \
            a step-by-step explanation.\n5. The $CODE is the python code.\n6. \
            Do not repeat the format in your answer."
    )
    return prompt
136
+
137
+
138
class APIBenchBenchmark(BaseBenchmark):
    r"""APIBench Benchmark adopted from `Gorilla: Large Language Model
    Connected with Massive APIs`
    <https://huggingface.co/datasets/gorilla-llm/APIBench>.

    Args:
        data_dir (str): The directory to save the data.
        save_to (str): The file to save the results.
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)
    """

    # TODO: Integrate retriever (pending)

    def __init__(
        self,
        data_dir: str,
        save_to: str,
        processes: int = 1,
    ):
        r"""Initialize the APIBench benchmark.

        Args:
            data_dir (str): The directory to save the data.
            save_to (str): The file to save the results.
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        super().__init__("apibench", data_dir, save_to, processes)

    def download(self):
        r"""Download the APIBench dataset.

        Fetches the API/eval/train files from the HuggingFace dataset
        snapshot and the question files from the gorilla GitHub repo.
        """
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id="gorilla-llm/APIBench",
            repo_type="dataset",
            local_dir=self.data_dir,
            local_dir_use_symlinks=True,
        )

        repo = "ShishirPatil/gorilla"
        subdir = "/gorilla/eval/eval-data/questions"
        download_github_subdirectory(repo, subdir, self.data_dir)

    def load(self, dataset_name: str, force_download: bool = False):  # type: ignore[override]
        r"""Load the APIBench Benchmark dataset.

        Args:
            dataset_name (str): Name of the specific dataset to be loaded.
            force_download (bool, optional): Whether to force
                download the data. (default: :obj:`False`)
        """
        if force_download:
            logger.info("Force downloading data.")
            self.download()

        def load_json_lines(file_path: Path):
            r"""Helper function to load JSON lines from a file."""
            try:
                with open(file_path, "r") as f:
                    return [json.loads(line) for line in f]
            except FileNotFoundError:
                raise FileNotFoundError(f"File not found: {file_path}")
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"Error decoding JSON in file {file_path}: {e}"
                )

        dataset_path = self.data_dir / dataset_name
        if not dataset_path.exists():
            raise FileNotFoundError(
                f"Dataset directory does not exist: {dataset_path}"
            )

        for label in ['api', 'eval', 'questions']:
            file_name = dataset_mapping[dataset_name][label]
            # 'questions' files live in the per-dataset subdirectory
            # (from the gorilla GitHub repo); 'api'/'eval' files sit at
            # the top of data_dir (from the HuggingFace snapshot).
            file_path = (
                dataset_path / file_name
                if label == 'questions'
                else self.data_dir / file_name
            )

            data = load_json_lines(file_path)
            if label == 'eval':
                # Each eval line wraps the actual record under 'api_data'.
                data = [item['api_data'] for item in data]
            self._data[label] = data

        # Pre-parse every reference api_call into an AST once, so run()
        # does not re-parse the reference database per question.
        ast_database = []
        for data in self._data['api']:
            ast_database.append(ast_parse(data['api_call']))
        self._data['ast'] = ast_database

    def run(  # type: ignore[override]
        self,
        agent: ChatAgent,
        dataset_name: Literal["huggingface", "tensorflowhub", "torchhub"],
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The agent to run the benchmark.
            dataset_name (Literal["huggingface", "tensorflowhub",
                "torchhub"]): The dataset to run the benchmark.
            randomize (bool, optional): Whether to randomize the data.
                (default: :obj:`False`)
            subset (Optional[int], optional): The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: Totals, correct/hallucination counts, and
                the derived accuracy and hallucination rates
                (``"N/A"`` when no tasks ran).
        """
        if dataset_name not in dataset_mapping:
            raise ValueError(f"Invalid value for dataset: {dataset_name}.")

        logger.info(f"Running APIBench benchmark on {dataset_name}.")
        self.load(dataset_name)
        # Work on a copy so shuffling/slicing does not mutate the
        # dataset stored in self._data (previously shuffled in place).
        datas = list(self._data['questions'])

        # Shuffle and subset data if necessary
        if randomize:
            random.shuffle(datas)
        if subset:
            datas = datas[:subset]

        logger.info(f"Number of tasks: {len(datas)}")

        # Initialize results storage
        self._results = []

        with open(self.save_to, "w") as f:
            for question in tqdm(datas, desc="Running"):
                prompt = encode_question(question["text"], dataset_name)
                msg = BaseMessage.make_user_message(
                    role_name="User", content=prompt
                )
                try:
                    # Generate response
                    responses = agent.step(msg)
                    response = responses.msgs[0].content
                    api_database = self._data['api']
                    qa_pairs = self._data['eval']
                    ast_database = self._data['ast']
                    question_id = question['question_id']

                    # Evaluate response
                    error, correct, hallucination = evaluate_response(
                        response,
                        question_id,
                        dataset_name,
                        api_database,
                        qa_pairs,
                        ast_database,
                    )
                    self._results.append(
                        {
                            "question": question,
                            "agent_response": response,
                            "correct": correct,
                            "hallucination": hallucination,
                            "error": str(error) if error else None,
                        }
                    )
                except Exception as e:
                    logger.warning(
                        f"Error in processing task: {question}: {e}"
                    )
                    self._results.append(
                        {
                            "question": question,
                            "agent_response": None,
                            "correct": False,
                            "hallucination": False,
                            "error": str(e),
                        }
                    )

                # Reset the agent between tasks so conversation context
                # does not leak from one question to the next.
                agent.reset()

                f.write(json.dumps(self._results[-1], indent=2) + "\n")
                f.flush()

        # Summarize (use self._results consistently; previously mixed
        # self.results and self._results).
        total = len(self._results)
        correct = sum(r["correct"] for r in self._results)
        hallucination = sum(r["hallucination"] for r in self._results)

        return {
            "total": total,
            "correct": correct,
            "hallucination": hallucination,
            "accuracy": correct / total if total else "N/A",
            "hallucination rate": hallucination / total if total else "N/A",
        }
344
+
345
+
346
# This code is modified from the
# evaluators in the original repo
# https://github.com/ShishirPatil/gorilla
# Get all the subtrees given a root_node
def get_all_sub_trees(root_node):
    r"""Collect one entry per visited node of a depth-first, stack-based
    traversal of ``root_node``.

    Each entry is ``[str(node), depth, node, first_child_text]`` where
    ``first_child_text`` is ``node.children[0].text`` for internal nodes
    and ``None`` for leaves. Only children that themselves have children
    are pushed for further traversal.
    """
    pending = [(root_node, 1)]
    collected = []
    while pending:
        node, node_depth = pending.pop()
        first_child_text = (
            node.children[0].text if node.child_count > 0 else None
        )
        collected.append([str(node), node_depth, node, first_child_text])
        for child in node.children:
            # Leaves are recorded via their parent entry, not pushed.
            if len(child.children) != 0:
                pending.append((child, node_depth + 1))
    return collected
376
+
377
+
378
# Parse the program into AST trees
def ast_parse(candidate):
    r"""Parse a Python snippet with tree-sitter and return the root node
    of the resulting AST."""
    py_language = Language(tspython.language())
    py_parser = Parser(py_language)
    return py_parser.parse(bytes(candidate, "utf8")).root_node
385
+
386
+
387
# Get all the arguments in the ast tree
def get_args(node, dataset_name):
    r"""Extract argument texts from the call expression reached via
    ``node.children[0].children[0].children[1]``.

    Selection rules differ per dataset: huggingface keeps keyword values
    and bare positionals; tensorflowhub keeps ``model=`` values and bare
    positionals; torchhub keeps only ``repo_or_dir``/``model`` values.
    Returns an empty list for a childless node or an unknown dataset.
    """
    if node.child_count == 0:
        return []
    collected = []
    if dataset_name == "huggingface":
        for arg in node.children[0].children[0].children[1].children:
            source = arg.text.decode()
            if "=" in source:
                # keyword argument: keep the value node's text
                collected.append(arg.children[2].text)
            elif source not in ("(", ")", ","):
                collected.append(arg.text)
    elif dataset_name == "tensorflowhub":
        for arg in node.children[0].children[0].children[1].children:
            source = arg.text.decode()
            if 'model=' in source or 'model =' in source:
                collected.append(arg.children[2].text)
            elif source not in ("(", ")", ","):
                collected.append(arg.text)
    elif dataset_name == "torchhub":
        for arg in node.children[0].children[0].children[1].children:
            source = arg.text.decode()
            if "repo_or_dir" in source or "model" in source:
                collected.append(arg.children[2].text)
    return collected
423
+
424
+
425
# Check if there is an api match
def ast_check(candidate_subtree_list, base_tree_list, dataset_name):
    r"""Return the index of the first reference tree in ``base_tree_list``
    whose API name and arguments are all found in a candidate subtree,
    or ``-1`` when no reference matches."""
    for ref_idx, ref_tree in enumerate(base_tree_list):
        # Skip reference entries whose call node has no children.
        if ref_tree.children[0].children[0].child_count == 0:
            continue
        api_name = ref_tree.children[0].children[0].children[0].text
        # Find the candidate subtree rooted at the same API name.
        # NOTE(review): when nothing matches, the loop deliberately
        # falls through with the *last* entry bound — this mirrors the
        # upstream Gorilla evaluator and is kept for parity.
        for entry in candidate_subtree_list:
            if entry[3] == api_name:
                break
        matched_node = entry[2]
        ref_args = get_args(ref_tree, dataset_name)
        if len(ref_args) == 0:
            continue
        # Every reference argument (quotes stripped) must appear in the
        # candidate subtree's source text for a match.
        if all(
            arg.decode().lstrip("'").rstrip("'")
            in matched_node.text.decode()
            for arg in ref_args
        ):
            return ref_idx
    return -1
450
+
451
+
452
def evaluate_response(
    response, question_id, dataset_name, api_database, qa_pairs, ast_database
):
    r"""Evaluate an agent response against the reference API database.

    Args:
        response (str): The raw agent output.
        question_id (int): 1-based question id, indexing ``qa_pairs``.
        dataset_name (str): Dataset key used for argument extraction.
        api_database (list): Reference API records (with ``'domain'``).
        qa_pairs (list): Ground-truth records (with ``'domain'``).
        ast_database (list): Pre-parsed reference ASTs.

    Returns:
        tuple: ``(error, correct, hallucination)`` where ``error`` is an
            exception (or ``None``) and the other two are booleans.
    """
    try:
        # Index the "api_call" domain
        output = response.split("api_call")
        if len(output) == 1:
            api_call = output[0]
        else:
            # Parse the output between "api_call" and "api_provider"
            output = output[1].split("api_provider")[0]
            if ":" not in output:
                start = 0
            else:
                start = output.index(":")
            if ")" not in output:
                end = -2
            else:
                end = output.rindex(")")
            api_call = output[start + 2 : end + 1]

        try:
            ast_tree = ast_parse(api_call)
        except Exception as parse_error:
            print(f"Error parsing api_call: {api_call}, error: {parse_error}")
            return parse_error, False, False
        # Search for a subtree
        ast_subtree_list = get_all_sub_trees(ast_tree)
        # Check which ast tree is matching
        database_index = ast_check(
            ast_subtree_list, ast_database, dataset_name
        )
        # We cannot index this ast in our database: hallucination.
        # BUGFIX: previously execution fell through and indexed
        # api_database[-1] (the last entry), which could overwrite the
        # hallucination verdict or report (False, False). Return the
        # hallucination verdict immediately instead.
        if database_index == -1:
            return None, False, True
        # We index our reference api_call
        ref_api_call = api_database[database_index]
        # Check for functionality: domains must agree with ground truth.
        if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
            return None, True, False
        else:
            return None, False, False
    except Exception as e:
        print(f'Error parsing response: {response}, error: {e}')
        return e, False, False
camel/benchmarks/gaia.py CHANGED
@@ -25,8 +25,8 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, Union
25
25
  from tqdm import tqdm
26
26
 
27
27
  from camel.agents import ChatAgent
28
- from camel.benchmarks import BaseBenchmark
29
- from camel.messages.base import BaseMessage
28
+ from camel.benchmarks.base import BaseBenchmark
29
+ from camel.messages import BaseMessage
30
30
  from camel.retrievers.auto_retriever import AutoRetriever
31
31
 
32
32
  logger = logging.getLogger(__name__)
@@ -280,11 +280,11 @@ class GAIABenchmark(BaseBenchmark):
280
280
  f"Skipping task because file not found: {file_path}"
281
281
  )
282
282
  return False
283
- if file_path.suffix in ['.pdf', '.docx', '.doc', '.txt']:
283
+ if file_path.suffix in [".pdf", ".docx", ".doc", ".txt"]:
284
284
  if not self.retriever.reset(task_id=task["task_id"]):
285
285
  return False
286
286
  retrieved_info = self.retriever.retrieve(
287
- query=task["Question"], contents=[task['file_name']]
287
+ query=task["Question"], contents=[task["file_name"]]
288
288
  )
289
289
  retrieved_content = [
290
290
  item["text"]