camel-ai 0.2.58__py3-none-any.whl → 0.2.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +126 -9
- camel/agents/critic_agent.py +73 -8
- camel/benchmarks/__init__.py +2 -0
- camel/benchmarks/browsecomp.py +854 -0
- camel/configs/cohere_config.py +1 -1
- camel/configs/mistral_config.py +1 -1
- camel/configs/openai_config.py +3 -0
- camel/configs/reka_config.py +1 -1
- camel/configs/samba_config.py +2 -2
- camel/datagen/cot_datagen.py +29 -34
- camel/embeddings/jina_embedding.py +8 -1
- camel/embeddings/sentence_transformers_embeddings.py +2 -2
- camel/embeddings/vlm_embedding.py +9 -2
- camel/human.py +14 -0
- camel/memories/records.py +3 -0
- camel/messages/base.py +15 -3
- camel/models/azure_openai_model.py +1 -0
- camel/models/model_factory.py +2 -2
- camel/retrievers/bm25_retriever.py +1 -2
- camel/retrievers/hybrid_retrival.py +2 -2
- camel/societies/role_playing.py +50 -0
- camel/societies/workforce/role_playing_worker.py +17 -8
- camel/societies/workforce/workforce.py +70 -14
- camel/storages/vectordb_storages/oceanbase.py +1 -2
- camel/toolkits/async_browser_toolkit.py +5 -1
- camel/toolkits/base.py +4 -2
- camel/toolkits/browser_toolkit.py +6 -3
- camel/toolkits/dalle_toolkit.py +4 -0
- camel/toolkits/excel_toolkit.py +11 -3
- camel/toolkits/github_toolkit.py +43 -25
- camel/toolkits/image_analysis_toolkit.py +3 -0
- camel/toolkits/jina_reranker_toolkit.py +194 -77
- camel/toolkits/mcp_toolkit.py +60 -16
- camel/toolkits/page_script.js +40 -28
- camel/toolkits/twitter_toolkit.py +6 -1
- camel/toolkits/video_analysis_toolkit.py +3 -0
- camel/toolkits/video_download_toolkit.py +3 -0
- camel/toolkits/wolfram_alpha_toolkit.py +46 -22
- camel/types/enums.py +14 -5
- {camel_ai-0.2.58.dist-info → camel_ai-0.2.60.dist-info}/METADATA +7 -9
- {camel_ai-0.2.58.dist-info → camel_ai-0.2.60.dist-info}/RECORD +44 -43
- {camel_ai-0.2.58.dist-info → camel_ai-0.2.60.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.58.dist-info → camel_ai-0.2.60.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,854 @@

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import base64
import hashlib
import json
import os
import random
import traceback
from collections import defaultdict
from multiprocessing.pool import ThreadPool
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, Field

from camel.agents.chat_agent import ChatAgent
from camel.benchmarks.base import BaseBenchmark
from camel.logger import get_logger
from camel.societies.role_playing import RolePlaying
from camel.societies.workforce.workforce import Workforce
from camel.tasks.task import Task

logger = get_logger(__name__)


class Message(BaseModel):
    role: str
    content: str
    variant: Optional[str] = None


MessageList = List[Message]


class QueryResponse(BaseModel):
    r"""A structured query response for benchmark evaluation.

    This class defines the expected format for model responses to benchmark
    questions, including explanation, exact answer, and confidence score.
    """

    explanation: str = Field(
        description="""your explanation for your final answer."""
    )
    exact_answer: str = Field(description="""your succinct, final answer.""")
    confidence: str = Field(
        description="""
        your confidence score between 0% and 100% for your answer.
        """
    )


class GradingResponse(BaseModel):
    r"""A structured grading response for evaluating model answers.

    This class defines the expected format for grading responses, including
    extracted answer, reasoning about correctness, binary correctness
    judgment, and confidence score extraction.
    """

    extracted_final_answer: str = Field(
        description="""
        The final exact answer extracted from the [response].
        Put the extracted answer as 'None' if there is no exact, final answer to
        extract from the response."""
    )
    reasoning: str = Field(
        description="""
        Explain why the extracted_final_answer is correct or incorrect
        based on [correct_answer], focusing only on if there are meaningful
        differences between [correct_answer] and the extracted_final_answer.
        Do not comment on any background to the problem, do not attempt
        to solve the problem, do not argue for any answer different
        than [correct_answer], focus only on whether the answers match."""
    )
    correct: str = Field(
        description="""Answer 'yes' if extracted_final_answer matches the
        [correct_answer] given above, or is within a small margin of error for
        numerical problems. Answer 'no' otherwise, i.e. if there is any
        inconsistency, ambiguity, non-equivalency, or if the extracted answer
        is incorrect."""
    )
    confidence: str = Field(
        description="""The extracted confidence score between 0%
        and 100% from [response]. Put 100 if there is no confidence score
        available.
        """
    )


class SingleEvalResult(BaseModel):
    r"""Result of evaluating a single benchmark sample.

    This class stores the evaluation results for a single benchmark example,
    including score, HTML representation, conversation history, and metrics.
    """

    score: Optional[float] = None
    html: str
    convo: MessageList
    metrics: Dict[str, float] = Field(default_factory=dict)


class EvalResult(BaseModel):
    r"""Result of running a complete benchmark evaluation.

    This class aggregates results from multiple sample evaluations, storing
    the overall score, detailed metrics, HTML reports, and conversation logs.
    """

    score: Optional[float] = None  # top-line metric
    metrics: Optional[Dict[str, float]] = None  # other metrics
    htmls: List[str]  # strings of valid HTML
    convos: List[MessageList]  # sampled conversations
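

# Editor's note (not part of the release): the pydantic models above are
# passed as ``response_format`` so that agent replies parse into a fixed
# schema. A minimal sketch, assuming pydantic v2's ``model_validate_json``:
_demo_raw = '{"explanation": "2 + 2 = 4", "exact_answer": "4", "confidence": "98"}'
_demo_parsed = QueryResponse.model_validate_json(_demo_raw)
assert _demo_parsed.exact_answer == "4"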


# Define the message template first
_message_template = """
<div class="message {{ role }}">
    <div class="role">
        {{ role }}
        {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
    </div>
    <div class="content">
        <pre>{{ content }}</pre>
    </div>
</div>
"""

# TODO: Add necessary prompts when tuning.
QUERY_TEMPLATE = """
{question}

Your response should be in the following format:
Explanation: {{your explanation for your final answer}}
Exact Answer: {{your succinct, final answer}}
Confidence: {{your confidence score between 0% and 100% for your answer}}
""".strip()

SUMMARIZE_TEMPLATE = """
Based on the chat history:
{chat_history}

answer the question:
{query}
"""

FORMAT_JSON_TEMPLATE = """
format content into json:
{content}
"""

GRADER_TEMPLATE = """
Judge whether the following [response] to [question] is correct or not
based on the precise and unambiguous [correct_answer] below.

[question]: {question}

[response]: {response}

Your judgement must be in the format and criteria specified below:

extracted_final_answer: The final exact answer extracted from the [response].
Put the extracted answer as 'None' if there is no exact, final answer to
extract from the response.

[correct_answer]: {correct_answer}

reasoning: Explain why the extracted_final_answer is correct or incorrect
based on [correct_answer], focusing only on if there are meaningful
differences between [correct_answer] and the extracted_final_answer.
Do not comment on any background to the problem, do not attempt
to solve the problem, do not argue for any answer different
than [correct_answer], focus only on whether the answers match.

correct: Answer 'yes' if extracted_final_answer matches the
[correct_answer] given above, or is within a small margin of error for
numerical problems. Answer 'no' otherwise, i.e. if there is any
inconsistency, ambiguity, non-equivalency, or if the extracted answer is
incorrect.


confidence: The extracted confidence score between 0% and 100%
from [response]. Put 100 if there is no confidence score available.
""".strip()


HTML_JINJA = """
<h3>Question:</h3>
{{ message_to_html(prompt_messages) | safe }}
<h3>Sampled message</h3>
{{ message_to_html(next_message) | safe }}
<h3>Results</h3>
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""

_report_template = """<!DOCTYPE html>
<html>
<head>
    <style>
        .message {
            padding: 8px 16px;
            margin-bottom: 8px;
            border-radius: 4px;
        }
        .message.user {
            background-color: #B2DFDB;
            color: #00695C;
        }
        .message.assistant {
            background-color: #B39DDB;
            color: #4527A0;
        }
        .message.system {
            background-color: #EEEEEE;
            color: #212121;
        }
        .role {
            font-weight: bold;
            margin-bottom: 4px;
        }
        .variant {
            color: #795548;
        }
        table, th, td {
            border: 1px solid black;
        }
        pre {
            white-space: pre-wrap;
        }
    </style>
</head>
<body>
{% if metrics %}
<h1>Metrics</h1>
<table>
    <tr>
        <th>Metric</th>
        <th>Value</th>
    </tr>
    <tr>
        <td><b>Score</b></td>
        <td>{{ score | float | round(3) }}</td>
    </tr>
    {% for name, value in metrics.items() %}
    <tr>
        <td>{{ name }}</td>
        <td>{{ value }}</td>
    </tr>
    {% endfor %}
</table>
{% endif %}
<h1>Examples</h1>
{% for html in htmls %}
{{ html | safe }}
<hr>
{% endfor %}
</body>
</html>
"""


class JinjaEnv:
    r"""A class that encapsulates the Jinja environment setup."""

    _instance: Optional['JinjaEnv'] = None
    _env = None

    def __init__(self):
        r"""Initialize the JinjaEnv instance if not already initialized."""
        if not getattr(self, '_initialized', False):
            self._initialized = True

    def __new__(cls):
        r"""Implement singleton pattern to ensure only one instance exists."""
        if cls._instance is None:
            cls._instance = super(JinjaEnv, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    @classmethod
    def get_instance(cls):
        r"""Get the singleton instance of JinjaEnv.

        Returns:
            JinjaEnv: The singleton instance.
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @property
    def env(self):
        r"""Lazily initialize and return the Jinja environment.

        Returns:
            jinja2.Environment: The Jinja environment instance.
        """
        if self._env is None:
            # Lazy import of jinja2
            import jinja2

            # Create the Jinja environment
            self._env = jinja2.Environment(
                loader=jinja2.BaseLoader(),
                undefined=jinja2.StrictUndefined,
                autoescape=jinja2.select_autoescape(["html", "xml"]),
            )

            # Register the message_to_html function
            self._env.globals["message_to_html"] = self.message_to_html

        return self._env

    def from_string(self, template_str):
        r"""Create a template from the given string.

        Args:
            template_str (str): The template string.

        Returns:
            jinja2.Template: The compiled template.
        """
        return self.env.from_string(template_str)

    @staticmethod
    def message_to_html(message: Message) -> str:
        r"""Generate HTML snippet (inside a <div>) for a message.

        Args:
            message (Message): The message to convert to HTML.

        Returns:
            str: The HTML representation of the message.
        """
        return (
            JinjaEnv.get_instance()
            .from_string(_message_template)
            .render(
                role=message.role,
                content=message.content,
                variant=message.variant,
            )
        )
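

# Editor's note (not part of the release): a minimal sketch of rendering one
# message through the singleton environment above (requires jinja2):
_demo_msg_html = JinjaEnv.message_to_html(
    Message(role="user", content="What is 2 + 2?")
)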


def derive_key(password: str, length: int) -> bytes:
    r"""Derive a fixed-length key from the password using SHA256."""
    hasher = hashlib.sha256()
    hasher.update(password.encode())
    key = hasher.digest()
    return key * (length // len(key)) + key[: length % len(key)]


def decrypt(ciphertext_b64: str, password: str) -> str:
    r"""Decrypt base64-encoded ciphertext with XOR."""
    encrypted = base64.b64decode(ciphertext_b64)
    key = derive_key(password, len(encrypted))
    decrypted = bytes(a ^ b for a, b in zip(encrypted, key))
    return decrypted.decode()
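

# Editor's note (not part of the release): XOR with a repeated SHA256-derived
# key is symmetric, so encryption is the same operation applied to plaintext.
# A hypothetical round trip against decrypt() above:
def _demo_encrypt(plaintext: str, password: str) -> str:
    data = plaintext.encode()
    key = derive_key(password, len(data))
    return base64.b64encode(bytes(a ^ b for a, b in zip(data, key))).decode()


assert decrypt(_demo_encrypt("hello", "canary"), "canary") == "hello"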


def _compute_stat(values: list, stat: str):
    import numpy as np

    if stat == "mean":
        return np.mean(values)
    elif stat == "std":
        return np.std(values)
    elif stat == "min":
        return np.min(values)
    elif stat == "max":
        return np.max(values)
    else:
        raise ValueError(f"Unknown {stat =}")


def aggregate_results(
    single_eval_results: List[SingleEvalResult],
    default_stats: Tuple[str, str] = ("mean", "std"),
    name2stats: Optional[Dict[str, Tuple[str]]] = None,
) -> EvalResult:
    r"""Aggregate results from multiple evaluations into a single EvalResult.

    Args:
        single_eval_results (List[SingleEvalResult]): A list of
            `SingleEvalResult` objects.
        default_stats (Tuple[str, str]): A tuple of default statistics to
            compute. (default: :obj:`("mean", "std")`)
        name2stats (Optional[Dict[str, Tuple[str]]]): A dictionary mapping
            metric names to statistics to compute. (default: :obj:`None`)

    Returns:
        EvalResult: An `EvalResult` object containing aggregated results.
    """
    name2stats = name2stats or {}
    name2values = defaultdict(list)
    htmls = []
    convos = []

    for single_eval_result in single_eval_results:
        for name, value in single_eval_result.metrics.items():
            name2values[name].append(value)
        if single_eval_result.score is not None:
            name2values["score"].append(single_eval_result.score)
        htmls.append(single_eval_result.html)
        convos.append(single_eval_result.convo)

    final_metrics = {}
    for name, values in name2values.items():
        stats = name2stats.get(name, default_stats)
        for stat in stats:
            key = name if stat == "mean" else f"{name}:{stat}"
            final_metrics[key] = _compute_stat(values, stat)

    return EvalResult(
        score=final_metrics.pop("score", None),
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
    )
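

# Editor's note (not part of the release): a small, hypothetical aggregation.
# Scores feed the top-line mean; every other metric gets "mean" and ":std"
# entries, computed by numpy inside _compute_stat:
_demo_eval = aggregate_results(
    [
        SingleEvalResult(score=1.0, html="", convo=[], metrics={"is_correct": 1}),
        SingleEvalResult(score=0.0, html="", convo=[], metrics={"is_correct": 0}),
    ]
)
assert _demo_eval.score == 0.5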


class BrowseCompBenchmark(BaseBenchmark):
    r"""BrowseComp Benchmark for evaluating browser-based comprehension tasks.

    This benchmark evaluates the ability of language models to comprehend and
    answer questions based on browser-based content, measuring accuracy and
    performance.
    """

    def __init__(
        self,
        save_to: str,
        processes: int = 1,
        num_examples: Optional[int] = None,
        n_repeats: int = 1,
    ):
        r"""Initialize the BrowseComp benchmark.

        Args:
            save_to (str): The file to save the results.
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
            num_examples (Optional[int]): Number of examples to evaluate.
                If None, all examples are used. Controls the sample size for
                testing. (default: :obj:`None`)
            n_repeats (int, optional): Number of times to repeat each
                example. Useful for evaluating consistency across multiple
                runs. (default: :obj:`1`)
        """
        # The BrowseComp benchmark doesn't download any data, so the
        # current path is passed as the data_dir to the superclass init.
        current_path = os.path.dirname(os.path.abspath(__file__))

        super().__init__("browsecomp", current_path, save_to, processes)
        self.num_examples = num_examples
        self.n_repeats = n_repeats
        self.examples: List[Dict[str, Any]] = []
        self.load()
        self._raw_results: List[Any] = []
        self._validated_results: List[SingleEvalResult] = []
        self._eval_result: EvalResult
        self.jinja_env = JinjaEnv.get_instance()

    def download(self):
        r"""Download the BrowseComp dataset.

        This method is implemented to maintain compatibility with the
        BaseBenchmark interface, but BrowseComp doesn't require downloading
        data separately.

        Returns:
            self: The benchmark instance
        """
        logger.info("BrowseComp benchmark does not require downloading data.")
        return self

    def load(self):
        r"""Load the BrowseComp dataset.

        This method loads the dataset from a remote CSV file, converts each
        row to a dictionary, and applies sampling if num_examples is
        specified. It also handles repeating examples if n_repeats > 1.

        Returns:
            self: The benchmark instance
        """
        # Load dataset from remote CSV
        import pandas

        df = pandas.read_csv(
            "https://openaipublic.blob.core.windows.net/simple-evals/browse_comp_test_set.csv"
        )
        # Convert each row to a dictionary
        examples = [row.to_dict() for _, row in df.iterrows()]

        # Sample examples if num_examples is specified
        if self.num_examples:
            assert (
                self.n_repeats == 1
            ), "n_repeats is only supported when num_examples is None"
            rng = random.Random(0)  # Use fixed seed for reproducibility
            examples = rng.sample(examples, self.num_examples)

        # Repeat examples if n_repeats > 1
        self.examples = examples * self.n_repeats
        return self

    @property
    def train(self):
        r"""Get the training set.

        This property is implemented to maintain compatibility with the
        BaseBenchmark interface, but BrowseComp doesn't have a training set.

        Raises:
            NotImplementedError: BrowseComp does not have a training set.
        """
        raise NotImplementedError("BrowseComp does not have a training set.")

    def run(  # type: ignore[override]
        self,
        pipeline_template: Union[ChatAgent, RolePlaying, Workforce],
        chat_turn_limit: int = 10,
        roleplaying_summarizer: Optional[ChatAgent] = None,
        task_json_formatter: Optional[ChatAgent] = None,
    ) -> None:
        r"""Run the benchmark by processing each example in parallel.

        This method applies the provided pipeline to each example in the
        dataset using a thread pool for parallel execution. It shows
        progress using tqdm and stores the results in self._raw_results.

        Args:
            pipeline_template (Union[ChatAgent, RolePlaying, Workforce]): The
                template agent or framework to use for processing examples.
                Can be a ChatAgent, RolePlaying, or Workforce instance that
                will be cloned for each example.
            chat_turn_limit (int): Maximum number of conversation turns
                allowed when using a RolePlaying pipeline.
                (default: :obj:`10`)
            roleplaying_summarizer (Optional[ChatAgent]): Optional ChatAgent
                to summarize RolePlaying conversations. If None and
                RolePlaying is used, a default summarizer will be created.
                (default: :obj:`None`)
            task_json_formatter (Optional[ChatAgent]): Optional ChatAgent to
                format task JSON. If None and Workforce is used, a default
                formatter will be created. (default: :obj:`None`)
        """
        from tqdm import tqdm

        # Worker function, executed in a thread pool below
        def process_benchmark_row(row: Dict[str, Any]) -> Dict[str, Any]:
            r"""Process a single benchmark row by extracting the problem and
            answer, creating a pipeline instance, and generating a response
            using the appropriate method based on the pipeline type.

            Args:
                row (Dict[str, Any]): A row from the dataset containing
                    encrypted problem and answer, along with a canary for
                    decryption.

            Returns:
                Dict[str, Any]: A dictionary containing the decrypted
                    problem, expected answer, model response, and structured
                    response fields.
            """

            problem = decrypt(row.get("problem", ""), row.get("canary", ""))
            answer = decrypt(row.get("answer", ""), row.get("canary", ""))
            try:
                input_message = QUERY_TEMPLATE.format(question=problem)

                if isinstance(pipeline_template, ChatAgent):
                    pipeline = pipeline_template.clone()  # type: ignore[assignment]

                    response_text = pipeline.step(
                        input_message, response_format=QueryResponse
                    )
                elif isinstance(pipeline_template, Workforce):
                    pipeline = pipeline_template.clone()  # type: ignore[assignment]
                    task = Task(content=input_message, id="0")
                    task = pipeline.process_task(task)  # type: ignore[attr-defined]
                    if task_json_formatter:
                        formatter_in_process = task_json_formatter.clone()
                    else:
                        formatter_in_process = ChatAgent(
                            "You are a helpful assistant."
                        )
                    response_text = formatter_in_process.step(
                        FORMAT_JSON_TEMPLATE.format(content=task.result),
                        response_format=QueryResponse,
                    )

                elif isinstance(pipeline_template, RolePlaying):
                    # RolePlaying follows a different flow: it is cloned
                    # with the task prompt baked in.
                    pipeline = pipeline_template.clone(  # type: ignore[assignment]
                        task_prompt=input_message
                    )

                    n = 0
                    input_msg = pipeline.init_chat()  # type: ignore[attr-defined]
                    chat_history = []
                    while n < chat_turn_limit:
                        n += 1
                        assistant_response, user_response = pipeline.step(
                            input_msg
                        )
                        if assistant_response.terminated:  # type: ignore[attr-defined]
                            break
                        if user_response.terminated:  # type: ignore[attr-defined]
                            break
                        if "CAMEL_TASK_DONE" in user_response.msg.content:  # type: ignore[attr-defined]
                            break

                        chat_history.append(
                            f"AI User: {user_response.msg.content}"  # type: ignore[attr-defined]
                        )
                        chat_history.append(
                            f"AI Assistant: {assistant_response.msg.content}"  # type: ignore[attr-defined]
                        )
                        input_msg = assistant_response.msg  # type: ignore[attr-defined]

                    chat_history_str = "\n".join(chat_history)
                    if roleplaying_summarizer:
                        summarizer_in_process = roleplaying_summarizer.clone()
                    else:
                        summarizer_in_process = ChatAgent(
                            "You are a helpful assistant."
                        )

                    summarize_prompt = SUMMARIZE_TEMPLATE.format(
                        chat_history=chat_history_str,
                        query=input_message,
                    )
                    response_text = summarizer_in_process.step(
                        summarize_prompt, response_format=QueryResponse
                    )
                else:
                    raise NotImplementedError(
                        f"{type(pipeline_template)} is not supported."
                    )
                # Parse the response JSON
                response_dict = json.loads(response_text.msg.content)

                # Format the response as a key-value string
                formatted_response = f"""
Explanation: {response_dict['explanation']}

Exact Answer: {response_dict['exact_answer']}
Confidence: {response_dict['confidence']}"""

                # Create the result dictionary
                raw_result = {}
                raw_result['problem'] = problem
                raw_result['expected_answer'] = answer
                raw_result['response'] = formatted_response
                # Keep the original dict for reference
                raw_result['response_dict'] = response_dict

                return raw_result
            except Exception as e:
                # Log any errors that occur during evaluation
                logger.error(f"Error evaluating result: {e}")
                logger.error(traceback.format_exc())
                return {
                    'problem': problem,
                    'expected_answer': answer,
                    'response': traceback.format_exc(),
                    'response_dict': {},
                }

        # Process examples in parallel with a thread pool
        pool_class = ThreadPool
        with pool_class(min(self.processes, len(self.examples))) as pool:
            self._raw_results = list(
                tqdm(
                    pool.imap(process_benchmark_row, self.examples),
                    total=len(self.examples),
                )
            )

    def make_report(self, eval_result: EvalResult) -> str:
        r"""Create a standalone HTML report from an EvalResult."""
        return self.jinja_env.from_string(_report_template).render(
            score=eval_result.score,
            metrics=eval_result.metrics,
            htmls=eval_result.htmls,
        )

    def validate(self, grader: Optional[ChatAgent] = None) -> None:
        r"""Validate the raw results using the GRADER_TEMPLATE and ChatAgent.

        This method evaluates the correctness of each response using
        multiple threads. A dedicated chat agent is created in each thread.
        The chat agent compares each raw result with the expected answer,
        and the grading results are aggregated into a report.

        Args:
            grader: The ChatAgent used for validation. If None, a default
                agent is created in each thread. If provided, it is used as
                a template and cloned into a new agent in each thread.
                (default: :obj:`None`)
        """
        from tqdm import tqdm

        def validate_each_one(raw_result: Dict[str, Any]) -> SingleEvalResult:
            r"""Format the prompt for the ChatAgent grader, send it for
            evaluation, extract the correctness assessment, and create an
            HTML representation of the result.

            Args:
                raw_result (Dict[str, Any]): A dictionary containing
                    'problem', 'response', and 'expected_answer' keys.

            Returns:
                SingleEvalResult: An evaluation result object with score,
                    metrics, and HTML.
            """
            # Format the template
            prompt = GRADER_TEMPLATE.format(
                question=raw_result['problem'],
                response=raw_result['response'],
                correct_answer=raw_result['expected_answer'],
            )
            if grader:
                grader_in_process = grader.clone()
            else:
                grader_in_process = ChatAgent("You are a helpful assistant.")

            # Create a conversation list for the result
            convo = [
                Message(content=raw_result['problem'], role="user"),
                Message(content=raw_result['response'], role="assistant"),
            ]

            try:
                response = grader_in_process.step(
                    prompt, response_format=GradingResponse
                )

                content = json.loads(response.msg.content)

                grade_result = content['correct']

                # Convert to binary metrics (1 for correct, 0 for incorrect)
                is_correct = int(grade_result == "yes")
                is_incorrect = int(grade_result == "no")

                # Set the score (1 for correct, 0 for incorrect)
                score = is_correct

                # Generate HTML representation of the result
                html = self.jinja_env.from_string(HTML_JINJA).render(
                    prompt_messages=Message(
                        content=raw_result.get('problem', ''), role="user"
                    ),
                    next_message=Message(
                        content=raw_result.get('response', ''),
                        role="assistant",
                    ),
                    score=score,
                    correct_answer=raw_result.get('expected_answer', ''),
                    extracted_answer=raw_result.get('response_dict', {}).get(
                        'exact_answer', ''
                    ),
                )
                # Return the evaluation result
                return SingleEvalResult(
                    html=html,
                    score=score,
                    convo=convo,
                    metrics={
                        "is_correct": is_correct,
                        "is_incorrect": is_incorrect,
                    },
                )
            except Exception as e:
                # Log any errors that occur during evaluation
                logger.error(f"Error evaluating result: {e}")
                logger.error(traceback.format_exc())
                html = self.jinja_env.from_string(HTML_JINJA).render(
                    prompt_messages=Message(
                        content=raw_result.get('problem', ''), role="user"
                    ),
                    next_message=Message(
                        content=raw_result.get('response', ''),
                        role="assistant",
                    ),
                    score=0,
                    correct_answer=raw_result.get('expected_answer', ''),
                    extracted_answer=raw_result.get('response_dict', {}).get(
                        'exact_answer', ''
                    ),
                )
                return SingleEvalResult(
                    html=html,
                    score=0,
                    convo=convo,
                    metrics={
                        "is_correct": 0,
                        "is_incorrect": 1,
                    },
                )

        pool_class = ThreadPool
        with pool_class(min(self.processes, len(self._raw_results))) as pool:
            self._validated_results = list(
                tqdm(
                    pool.imap(validate_each_one, self._raw_results),
                    total=len(self._raw_results),
                )
            )

        aggregate_metrics = {
            "is_correct": sum(
                result.metrics["is_correct"]
                for result in self._validated_results
            )
            / len(self._validated_results),
            "is_incorrect": sum(
                result.metrics["is_incorrect"]
                for result in self._validated_results
            )
            / len(self._validated_results),
        }
        logger.info("AGGREGATE METRICS")
        logger.info(aggregate_metrics)
        logger.info("##################")

        output_d = {
            "accuracy": aggregate_metrics["is_correct"],
        }

        logger.info(f"Accuracy: {output_d['accuracy']:.3f}")

        self._eval_result = aggregate_results(self._validated_results)

        report_filename = self.save_to
        logger.info(f"Writing report to {report_filename}")
        with open(report_filename, "w") as fh:
            fh.write(self.make_report(self._eval_result))