camel-ai 0.2.15a0__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +18 -4
- camel/agents/multi_hop_generator_agent.py +85 -0
- camel/agents/programmed_agent_instruction.py +148 -0
- camel/benchmarks/__init__.py +13 -1
- camel/benchmarks/apibank.py +565 -0
- camel/benchmarks/apibench.py +500 -0
- camel/benchmarks/gaia.py +4 -4
- camel/benchmarks/nexus.py +518 -0
- camel/benchmarks/ragbench.py +333 -0
- camel/bots/__init__.py +1 -1
- camel/bots/discord/__init__.py +26 -0
- camel/bots/discord/discord_app.py +384 -0
- camel/bots/discord/discord_installation.py +64 -0
- camel/bots/discord/discord_store.py +160 -0
- camel/configs/__init__.py +3 -0
- camel/configs/anthropic_config.py +17 -15
- camel/configs/internlm_config.py +60 -0
- camel/data_collector/base.py +5 -5
- camel/data_collector/sharegpt_collector.py +2 -2
- camel/datagen/__init__.py +6 -2
- camel/datagen/{o1datagen.py → cotdatagen.py} +19 -6
- camel/datagen/self_instruct/__init__.py +36 -0
- camel/datagen/self_instruct/filter/__init__.py +34 -0
- camel/datagen/self_instruct/filter/filter_function.py +216 -0
- camel/datagen/self_instruct/filter/filter_registry.py +56 -0
- camel/datagen/self_instruct/filter/instruction_filter.py +81 -0
- camel/datagen/self_instruct/self_instruct.py +393 -0
- camel/datagen/self_instruct/templates.py +382 -0
- camel/datahubs/huggingface.py +12 -2
- camel/datahubs/models.py +2 -3
- camel/embeddings/mistral_embedding.py +5 -1
- camel/embeddings/openai_compatible_embedding.py +6 -1
- camel/embeddings/openai_embedding.py +5 -1
- camel/interpreters/e2b_interpreter.py +5 -1
- camel/loaders/__init__.py +2 -0
- camel/loaders/apify_reader.py +5 -1
- camel/loaders/chunkr_reader.py +5 -1
- camel/loaders/firecrawl_reader.py +0 -30
- camel/loaders/panda_reader.py +337 -0
- camel/logger.py +11 -5
- camel/messages/__init__.py +10 -4
- camel/messages/conversion/conversation_models.py +5 -0
- camel/messages/func_message.py +30 -22
- camel/models/__init__.py +2 -0
- camel/models/anthropic_model.py +6 -23
- camel/models/azure_openai_model.py +1 -2
- camel/models/cohere_model.py +13 -1
- camel/models/deepseek_model.py +5 -1
- camel/models/gemini_model.py +15 -2
- camel/models/groq_model.py +5 -1
- camel/models/internlm_model.py +143 -0
- camel/models/mistral_model.py +19 -8
- camel/models/model_factory.py +3 -0
- camel/models/nemotron_model.py +5 -1
- camel/models/nvidia_model.py +5 -1
- camel/models/openai_model.py +5 -1
- camel/models/qwen_model.py +5 -1
- camel/models/reka_model.py +5 -1
- camel/models/reward/__init__.py +2 -0
- camel/models/reward/nemotron_model.py +5 -1
- camel/models/reward/skywork_model.py +88 -0
- camel/models/samba_model.py +5 -1
- camel/models/togetherai_model.py +5 -1
- camel/models/yi_model.py +5 -1
- camel/models/zhipuai_model.py +5 -1
- camel/schemas/openai_converter.py +5 -1
- camel/storages/graph_storages/nebula_graph.py +89 -20
- camel/storages/graph_storages/neo4j_graph.py +138 -0
- camel/synthetic_datagen/source2synth/data_processor.py +373 -0
- camel/synthetic_datagen/source2synth/models.py +68 -0
- camel/synthetic_datagen/source2synth/user_data_processor_config.py +73 -0
- camel/toolkits/__init__.py +4 -0
- camel/toolkits/arxiv_toolkit.py +20 -3
- camel/toolkits/dappier_toolkit.py +196 -0
- camel/toolkits/function_tool.py +61 -61
- camel/toolkits/google_scholar_toolkit.py +9 -0
- camel/toolkits/meshy_toolkit.py +5 -1
- camel/toolkits/notion_toolkit.py +1 -1
- camel/toolkits/openbb_toolkit.py +869 -0
- camel/toolkits/search_toolkit.py +91 -5
- camel/toolkits/stripe_toolkit.py +5 -1
- camel/toolkits/twitter_toolkit.py +24 -16
- camel/types/__init__.py +4 -2
- camel/types/enums.py +34 -1
- camel/types/openai_types.py +6 -4
- camel/types/unified_model_type.py +5 -0
- camel/utils/__init__.py +2 -0
- camel/utils/commons.py +104 -19
- camel/utils/token_counting.py +3 -3
- {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/METADATA +160 -177
- {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/RECORD +94 -69
- {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/WHEEL +1 -1
- camel/bots/discord_app.py +0 -138
- {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,500 @@
|
|
|
1
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import logging
|
|
17
|
+
import random
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any, Dict, Literal, Optional
|
|
20
|
+
|
|
21
|
+
import tree_sitter_python as tspython
|
|
22
|
+
from tqdm import tqdm
|
|
23
|
+
from tree_sitter import Language, Parser
|
|
24
|
+
|
|
25
|
+
from camel.agents import ChatAgent
|
|
26
|
+
from camel.benchmarks.base import BaseBenchmark
|
|
27
|
+
from camel.messages import BaseMessage
|
|
28
|
+
from camel.utils import download_github_subdirectory
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Mapping of dataset names to file names.
# The 'Oracle' retriever is used here, which means the full
# API documentation will be included in the prompt.
# Per dataset: 'api' = API docs (JSONL), 'eval' = gold QA pairs,
# 'train' = training split, 'questions' = oracle question files.
dataset_mapping = {
    "huggingface": {
        "api": "huggingface_api.jsonl",
        "eval": "huggingface_eval.json",
        "train": "huggingface_train.json",
        "questions": "questions_huggingface_oracle.jsonl",
    },
    "tensorflowhub": {
        "api": "tensorflowhub_api.jsonl",
        "eval": "tensorflow_eval.json",
        "train": "tensorflow_train.json",
        "questions": "questions_tensorflowhub_oracle.jsonl",
    },
    "torchhub": {
        "api": "torchhub_api.jsonl",
        "eval": "torchhub_eval.json",
        "train": "torchhub_train.json",
        "questions": "questions_torchhub_oracle.jsonl",
    },
}
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# This function is migrated from the original repo:
# https://github.com/ShishirPatil/gorilla
def encode_question(question: str, dataset_name: str) -> str:
    r"""Encode multiple prompt instructions into a single string.

    Args:
        question (str): The raw benchmark question text.
        dataset_name (str): One of ``"huggingface"``,
            ``"tensorflowhub"`` or ``"torchhub"``.

    Returns:
        str: The full prompt to send to the agent.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    if dataset_name == "torchhub":
        domains = "1. $DOMAIN is inferred from the task description and \
should include one of {Classification, Semantic Segmentation, \
Object Detection, Audio Separation, Video Classification, \
Text-to-Speech}."
    elif dataset_name == "huggingface":
        domains = "1. $DOMAIN should include one of {Multimodal Feature \
Extraction, Multimodal Text-to-Image, Multimodal \
Image-to-Text, Multimodal Text-to-Video, \
Multimodal Visual Question Answering, Multimodal Document \
Question Answer, Multimodal Graph Machine Learning, \
Computer Vision Depth Estimation, Computer Vision Image \
Classification, Computer Vision Object Detection, \
Computer Vision Image Segmentation, Computer Vision \
Image-to-Image, Computer Vision Unconditional \
Image Generation, Computer Vision Video Classification, \
Computer Vision Zero-Shor Image Classification, \
Natural Language Processing Text Classification, \
Natural Language Processing Token Classification, \
Natural Language Processing Table Question Answering, \
Natural Language Processing Question Answering, \
Natural Language Processing, Zero-Shot Classification \
Natural Language Processing Translation, Natural Language \
Processing Summarization, Natural Language Processing \
Conversational, Natural Language Processing Text \
Generation, Natural Language Processing Fill-Mask, \
Natural Language Processing Text2Text Generation, \
Natural Language Processing Sentence Similarity, \
Audio Text-to-Speech, Audio Automatic Speech Recognition, \
Audio Audio-to-Audio, Audio Audio Classification, \
Audio Voice Activity Detection, Tabular Tabular \
Classification, Tabular Tabular Regression, \
Reinforcement Learning Reinforcement Learning, \
Reinforcement Learning Robotics }"
    elif dataset_name == "tensorflowhub":
        domains = "1. $DOMAIN is inferred from the task description \
and should include one of {text-sequence-alignment, \
text-embedding, text-language-model, text-preprocessing, \
text-classification, text-generation, text-question-answering, \
text-retrieval-question-answering, text-segmentation, \
text-to-mel, image-classification, image-feature-vector, \
image-object-detection, image-segmentation, \
image-generator, image-pose-detection, image-rnn-agent, \
image-augmentation, image-classifier, image-style-transfer, \
image-aesthetic-quality, image-depth-estimation, \
image-super-resolution, image-deblurring, image-extrapolation, \
image-text-recognition, image-dehazing, image-deraining, \
image-enhancemenmt, image-classification-logits, \
image-frame-interpolation, image-text-detection, image-denoising, \
image-others, video-classification, video-feature-extraction, \
video-generation, video-audio-text, video-text, \
audio-embedding, audio-event-classification, audio-command-detection, \
audio-paralinguists-classification, audio-speech-to-text, \
audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}"
    else:
        # Previously this branch only logged and then fell through,
        # crashing with a NameError on the undefined `domains`;
        # fail fast with a clear error instead.
        raise ValueError(f"Unsupported dataset name: {dataset_name}")

    prompt = (
        question
        + "\nWrite a python program in 1 to 2 lines to call API in "
        + dataset_name
        + ".\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, \
<<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, \
<<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. \
Here are the requirements:\n"
        + domains
        + "\n2. The $API_CALL should have only 1 line of code \
that calls api.\n 3. The $API_PROVIDER should be the \
programming framework used.\n4. $EXPLANATION should be \
a step-by-step explanation.\n5. The $CODE is the python code.\n6. \
Do not repeat the format in your answer."
    )
    return prompt
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class APIBenchBenchmark(BaseBenchmark):
    r"""APIBench Benchmark adopted from `Gorilla: Large Language Model
    Connected with Massive APIs`
    <https://huggingface.co/datasets/gorilla-llm/APIBench>.

    Args:
        data_dir (str): The directory to save the data.
        save_to (str): The file to save the results.
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)
    """

    # TODO: Integrate retriever (pending)

    def __init__(
        self,
        data_dir: str,
        save_to: str,
        processes: int = 1,
    ):
        r"""Initialize the APIBench benchmark.

        Args:
            data_dir (str): The directory to save the data.
            save_to (str): The file to save the results.
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        super().__init__("apibench", data_dir, save_to, processes)

    def download(self) -> None:
        r"""Download the APIBench dataset.

        Fetches the dataset snapshot from the Hugging Face Hub and the
        oracle question files from the original Gorilla GitHub repository
        into ``self.data_dir``.
        """
        # Imported lazily so huggingface_hub is only needed when
        # downloading.
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id="gorilla-llm/APIBench",
            repo_type="dataset",
            local_dir=self.data_dir,
            local_dir_use_symlinks=True,
        )

        # The oracle question files live in the Gorilla repo, not in
        # the Hugging Face dataset snapshot.
        repo = "ShishirPatil/gorilla"
        subdir = "/gorilla/eval/eval-data/questions"
        data_dir = self.data_dir

        download_github_subdirectory(repo, subdir, data_dir)

    def load(self, dataset_name: str, force_download: bool = False):  # type: ignore[override]
        r"""Load the APIBench Benchmark dataset.

        Populates ``self._data`` with the ``api``, ``eval`` and
        ``questions`` splits for the given dataset, plus a pre-parsed
        ``ast`` list (one tree-sitter tree per reference API call).

        Args:
            dataset_name (str): Name of the specific dataset to be loaded.
            force_download (bool, optional): Whether to force
                download the data. (default: :obj:`False`)
        """

        if force_download:
            logger.info("Force downloading data.")
            self.download()

        def load_json_lines(file_path: Path):
            r"""Helper function to load JSON lines from a file."""
            try:
                with open(file_path, "r") as f:
                    return [json.loads(line) for line in f]
            except FileNotFoundError:
                raise FileNotFoundError(f"File not found: {file_path}")
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"Error decoding JSON in file {file_path}: {e}"
                )

        dataset_path = self.data_dir / dataset_name
        if not dataset_path.exists():
            raise FileNotFoundError(
                f"Dataset directory does not exist: {dataset_path}"
            )

        for label in ['api', 'eval', 'questions']:
            file_name = dataset_mapping[dataset_name][label]
            # Question files sit in the per-dataset subdirectory; the
            # api and eval files sit directly under the data directory.
            file_path = (
                dataset_path / file_name
                if label == 'questions'
                else self.data_dir / file_name
            )

            # Load data based on label type
            if label in ['api', 'questions', 'eval']:
                data = load_json_lines(file_path)

                if label == 'eval':
                    # Extract 'api_data' specifically for eval label
                    data = [item['api_data'] for item in data]

                self._data[label] = data
            else:
                raise ValueError(f"Unknown label: {label}")

        # Pre-parse every reference API call once so evaluation does not
        # re-parse the database for each question.
        ast_database = []
        for data in self._data['api']:
            ast_tree = ast_parse(data['api_call'])
            ast_database.append(ast_tree)
        self._data['ast'] = ast_database

    def run(  # type: ignore[override]
        self,
        agent: ChatAgent,
        dataset_name: Literal["huggingface", "tensorflowhub", "torchhub"],
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The agent to run the
                benchmark.
            dataset_name (Literal["huggingface",
                "tensorflowhub", "torchhub"]):
                The dataset to run the benchmark.
            randomize (bool, optional): Whether to randomize the data.
                (default: :obj:`False`)
            subset (Optional[int], optional): The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: Aggregate counts plus accuracy and
            hallucination rate (``"N/A"`` when no task was run).
        """

        if dataset_name not in dataset_mapping:
            raise ValueError(f"Invalid value for dataset: {dataset_name}.")

        logger.info(f"Running APIBench benchmark on {dataset_name}.")
        self.load(dataset_name)
        datas = self._data['questions']

        # Shuffle and subset data if necessary
        if randomize:
            random.shuffle(datas)
        if subset:
            datas = datas[:subset]

        logger.info(f"Number of tasks: {len(datas)}")

        # Initialize results storage
        self._results = []

        with open(self.save_to, "w") as f:
            for question in tqdm(datas, desc="Running"):
                prompt = encode_question(question["text"], dataset_name)
                msg = BaseMessage.make_user_message(
                    role_name="User", content=prompt
                )
                try:
                    # Generate response
                    responses = agent.step(msg)
                    response = responses.msgs[0].content
                    api_database = self._data['api']
                    qa_pairs = self._data['eval']
                    ast_database = self._data['ast']
                    question_id = question['question_id']

                    # Evaluate response
                    error, correct, hallucination = evaluate_response(
                        response,
                        question_id,
                        dataset_name,
                        api_database,
                        qa_pairs,
                        ast_database,
                    )
                    self._results.append(
                        {
                            "question": question,
                            "agent_response": response,
                            "correct": correct,
                            "hallucination": hallucination,
                            "error": str(error) if error else None,
                        }
                    )
                except Exception as e:
                    logger.warning(
                        f"Error in processing task: {question}: {e}"
                    )
                    self._results.append(
                        {
                            "question": question,
                            "agent_response": None,
                            "correct": False,
                            "hallucination": False,
                            "error": str(e),
                        }
                    )

                # Reset the agent between questions so no conversation
                # state leaks across tasks.
                agent.reset()

                # Persist each result immediately. NOTE(review): writes
                # indented multi-line JSON objects separated by newlines,
                # not strict JSONL — confirm downstream readers expect
                # this format.
                f.write(json.dumps(self._results[-1], indent=2) + "\n")
                f.flush()

        total = len(self._results)
        correct = sum(r["correct"] for r in self.results)
        hallucination = sum(r["hallucination"] for r in self.results)

        return {
            "total": total,
            "correct": correct,
            "hallucination": hallucination,
            "accuracy": correct / total if total else "N/A",
            "hallucination rate": hallucination / total if total else "N/A",
        }
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
# This code is modified from the
# evaluators in the original repo
# https://github.com/ShishirPatil/gorilla
# Get all the subtrees given a root_node
def get_all_sub_trees(root_node):
    """Collect ``[sexp, depth, node, first_child_text]`` records for every
    internal node reachable from ``root_node``.

    Traversal is depth-first via an explicit LIFO stack; leaf children are
    recorded only when the root itself is a leaf (children are pushed only
    if they have children of their own).
    """
    records = []
    pending = [[root_node, 1]]
    while pending:
        node, node_depth = pending.pop()
        if node.child_count > 0:
            first_child_text = node.children[0].text
        else:
            first_child_text = None
        records.append([str(node), node_depth, node, first_child_text])
        for child in node.children:
            if len(child.children) != 0:
                pending.append([child, node_depth + 1])
    return records
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
# Parse the program into AST trees
def ast_parse(candidate):
    """Parse ``candidate`` Python source with tree-sitter and return the
    root node of the resulting AST."""
    language = Language(tspython.language())
    tree_parser = Parser(language)
    return tree_parser.parse(bytes(candidate, "utf8")).root_node
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
# Get all the arguments in the ast tree
def get_args(node, dataset_name):
    """Extract the argument texts of the API call rooted at ``node``
    using the dataset-specific rules of the original Gorilla evaluator.

    Returns a list of raw ``bytes`` argument texts; empty if ``node``
    has no children or the dataset name is unrecognized.
    """
    if node.child_count == 0:
        return []
    collected = []
    punctuation = ("(", ")", ",")
    if dataset_name == "huggingface":
        for arg in node.children[0].children[0].children[1].children:
            arg_text = arg.text.decode()
            # Keyword arguments contribute their value; positional
            # arguments contribute their full text.
            if "=" in arg_text:
                collected.append(arg.children[2].text)
            elif arg_text not in punctuation:
                collected.append(arg.text)
    elif dataset_name == "tensorflowhub":
        for arg in node.children[0].children[0].children[1].children:
            arg_text = arg.text.decode()
            if 'model=' in arg_text or 'model =' in arg_text:
                collected.append(arg.children[2].text)
            elif arg_text not in punctuation:
                collected.append(arg.text)
    elif dataset_name == "torchhub":
        for arg in node.children[0].children[0].children[1].children:
            arg_text = arg.text.decode()
            # torchhub only keeps repo/model keyword values.
            if "repo_or_dir" in arg_text or "model" in arg_text:
                collected.append(arg.children[2].text)
    return collected
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
# Check if there is an api match
def ast_check(candidate_subtree_list, base_tree_list, dataset_name):
    """Return the index of the first reference API whose call the
    candidate response matches, or ``-1`` if none matches.

    Args:
        candidate_subtree_list: Records produced by ``get_all_sub_trees``
            for the candidate response ([sexp, depth, node,
            first_child_text]).
        base_tree_list: Reference AST root nodes from the API database.
        dataset_name: Dataset whose argument-extraction rules apply.

    Returns:
        int: Index into ``base_tree_list`` of the match, or ``-1``.
    """
    for idx, base_tree in enumerate(base_tree_list):
        # Skip reference entries whose call node has no children.
        if base_tree.children[0].children[0].child_count == 0:
            continue
        api_name = base_tree.children[0].children[0].children[0].text
        # Find the candidate subtree whose first child matches the API
        # name (both sides are bytes).
        for candidate_tree in candidate_subtree_list:
            if candidate_tree[3] == api_name:
                break
        # Now we have a sub-tree
        # NOTE(review): if no candidate matched above, the loop variable
        # still holds the *last* candidate (there is no for/else guard),
        # so the argument check below runs against an arbitrary subtree —
        # verify this fall-through is intended (behavior inherited from
        # the original Gorilla evaluator).
        candidate_tree = candidate_tree[2]
        args_list = get_args(base_tree, dataset_name)
        # References with no extractable arguments cannot be matched.
        if len(args_list) == 0:
            continue
        # Every reference argument (surrounding quotes stripped) must
        # appear somewhere in the candidate subtree's source text.
        ast_match = True
        for arg in args_list:
            if (
                arg.decode().lstrip("'").rstrip("'")
                not in candidate_tree.text.decode()
            ):
                ast_match = False
                break
        if ast_match:
            return idx
    return -1
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
def evaluate_response(
    response, question_id, dataset_name, api_database, qa_pairs, ast_database
):
    r"""Evaluate an agent response against the reference API database.

    Args:
        response (str): Raw agent output expected to contain an
            ``api_call`` section.
        question_id (int): 1-based id of the question being evaluated
            (indexes ``qa_pairs`` at ``question_id - 1``).
        dataset_name (str): One of the supported dataset names.
        api_database (list): Reference API records.
        qa_pairs (list): Gold QA pairs aligned with question ids.
        ast_database (list): Pre-parsed ASTs of the reference API calls.

    Returns:
        tuple: ``(error, correct, hallucination)`` where ``error`` is the
        exception raised during parsing/evaluation or ``None``.
    """
    try:
        # Index the "api_call" domain
        output = response.split("api_call")
        if len(output) == 1:
            api_call = output[0]
        else:
            # Parse the text between "api_call" and "api_provider",
            # trimming the leading ": " and anything past the last ")".
            output = output[1].split("api_provider")[0]
            start = output.index(":") if ":" in output else 0
            end = output.rindex(")") if ")" in output else -2
            api_call = output[start + 2 : end + 1]

        try:
            ast_tree = ast_parse(api_call)
        except Exception as parse_error:
            print(f"Error parsing api_call: {api_call}, error: {parse_error}")
            return parse_error, False, False

        # Search for a subtree
        ast_subtree_list = get_all_sub_trees(ast_tree)
        # Check which ast tree is matching
        database_index = ast_check(
            ast_subtree_list, ast_database, dataset_name
        )
        # We cannot index this ast in our database: the call is a
        # hallucination. Return immediately — the previous code fell
        # through and indexed api_database[-1] (the last entry), which
        # could silently overwrite the hallucination verdict with
        # correct=True whenever that entry's domain happened to match.
        if database_index == -1:
            return None, False, True

        # We index our reference api_call
        ref_api_call = api_database[database_index]
        # Check for functionality: the matched API must belong to the
        # same domain as the gold answer for this question.
        if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
            return None, True, False
        return None, False, False
    except Exception as e:
        print(f'Error parsing response: {response}, error: {e}')
        return e, False, False
|
camel/benchmarks/gaia.py
CHANGED
|
@@ -25,8 +25,8 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
|
|
25
25
|
from tqdm import tqdm
|
|
26
26
|
|
|
27
27
|
from camel.agents import ChatAgent
|
|
28
|
-
from camel.benchmarks import BaseBenchmark
|
|
29
|
-
from camel.messages
|
|
28
|
+
from camel.benchmarks.base import BaseBenchmark
|
|
29
|
+
from camel.messages import BaseMessage
|
|
30
30
|
from camel.retrievers.auto_retriever import AutoRetriever
|
|
31
31
|
|
|
32
32
|
logger = logging.getLogger(__name__)
|
|
@@ -280,11 +280,11 @@ class GAIABenchmark(BaseBenchmark):
|
|
|
280
280
|
f"Skipping task because file not found: {file_path}"
|
|
281
281
|
)
|
|
282
282
|
return False
|
|
283
|
-
if file_path.suffix in [
|
|
283
|
+
if file_path.suffix in [".pdf", ".docx", ".doc", ".txt"]:
|
|
284
284
|
if not self.retriever.reset(task_id=task["task_id"]):
|
|
285
285
|
return False
|
|
286
286
|
retrieved_info = self.retriever.retrieve(
|
|
287
|
-
query=task["Question"], contents=[task[
|
|
287
|
+
query=task["Question"], contents=[task["file_name"]]
|
|
288
288
|
)
|
|
289
289
|
retrieved_content = [
|
|
290
290
|
item["text"]
|