camel-ai 0.2.58__py3-none-any.whl → 0.2.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (44)
  1. camel/__init__.py +1 -1
  2. camel/agents/chat_agent.py +126 -9
  3. camel/agents/critic_agent.py +73 -8
  4. camel/benchmarks/__init__.py +2 -0
  5. camel/benchmarks/browsecomp.py +854 -0
  6. camel/configs/cohere_config.py +1 -1
  7. camel/configs/mistral_config.py +1 -1
  8. camel/configs/openai_config.py +3 -0
  9. camel/configs/reka_config.py +1 -1
  10. camel/configs/samba_config.py +2 -2
  11. camel/datagen/cot_datagen.py +29 -34
  12. camel/embeddings/jina_embedding.py +8 -1
  13. camel/embeddings/sentence_transformers_embeddings.py +2 -2
  14. camel/embeddings/vlm_embedding.py +9 -2
  15. camel/human.py +14 -0
  16. camel/memories/records.py +3 -0
  17. camel/messages/base.py +15 -3
  18. camel/models/azure_openai_model.py +1 -0
  19. camel/models/model_factory.py +2 -2
  20. camel/retrievers/bm25_retriever.py +1 -2
  21. camel/retrievers/hybrid_retrival.py +2 -2
  22. camel/societies/role_playing.py +50 -0
  23. camel/societies/workforce/role_playing_worker.py +17 -8
  24. camel/societies/workforce/workforce.py +70 -14
  25. camel/storages/vectordb_storages/oceanbase.py +1 -2
  26. camel/toolkits/async_browser_toolkit.py +5 -1
  27. camel/toolkits/base.py +4 -2
  28. camel/toolkits/browser_toolkit.py +6 -3
  29. camel/toolkits/dalle_toolkit.py +4 -0
  30. camel/toolkits/excel_toolkit.py +11 -3
  31. camel/toolkits/github_toolkit.py +43 -25
  32. camel/toolkits/image_analysis_toolkit.py +3 -0
  33. camel/toolkits/jina_reranker_toolkit.py +194 -77
  34. camel/toolkits/mcp_toolkit.py +60 -16
  35. camel/toolkits/page_script.js +40 -28
  36. camel/toolkits/twitter_toolkit.py +6 -1
  37. camel/toolkits/video_analysis_toolkit.py +3 -0
  38. camel/toolkits/video_download_toolkit.py +3 -0
  39. camel/toolkits/wolfram_alpha_toolkit.py +46 -22
  40. camel/types/enums.py +14 -5
  41. {camel_ai-0.2.58.dist-info → camel_ai-0.2.60.dist-info}/METADATA +7 -9
  42. {camel_ai-0.2.58.dist-info → camel_ai-0.2.60.dist-info}/RECORD +44 -43
  43. {camel_ai-0.2.58.dist-info → camel_ai-0.2.60.dist-info}/WHEEL +0 -0
  44. {camel_ai-0.2.58.dist-info → camel_ai-0.2.60.dist-info}/licenses/LICENSE +0 -0
camel/benchmarks/browsecomp.py
@@ -0,0 +1,854 @@
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+
+ import base64
+ import hashlib
+ import json
+ import os
+ import random
+ import traceback
+ from collections import defaultdict
+ from multiprocessing.pool import ThreadPool
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ from pydantic import BaseModel, Field
+
+ from camel.agents.chat_agent import ChatAgent
+ from camel.benchmarks.base import BaseBenchmark
+ from camel.logger import get_logger
+ from camel.societies.role_playing import RolePlaying
+ from camel.societies.workforce.workforce import Workforce
+ from camel.tasks.task import Task
+
+ logger = get_logger(__name__)
+
+
+ class Message(BaseModel):
+     role: str
+     content: str
+     variant: Optional[str] = None
+
+
+ MessageList = List[Message]
+
+
+ class QueryResponse(BaseModel):
+     r"""A structured query response for benchmark evaluation.
+
+     This class defines the expected format for model responses to benchmark
+     questions, including explanation, exact answer, and confidence score.
+     """
+
+     explanation: str = Field(
+         description="""your explanation for your final answer."""
+     )
+     exact_answer: str = Field(description="""your succinct, final answer.""")
+     confidence: str = Field(
+         description="""
+         your confidence score between 0|\%| and 100|\%| for your answer.
+         """
+     )
+
+
+ class GradingResponse(BaseModel):
+     r"""A structured grading response for evaluating model answers.
+
+     This class defines the expected format for grading responses, including
+     extracted answer, reasoning about correctness, binary correctness judgment,
+     and confidence score extraction.
+     """
+
+     extracted_final_answer: str = Field(
+         description="""
+         The final exact answer extracted from the [response].
+         Put the extracted answer as 'None' if there is no exact, final answer to
+         extract from the response."""
+     )
+     reasoning: str = Field(
+         description="""
+         Explain why the extracted_final_answer is correct or incorrect
+         based on [correct_answer], focusing only on if there are meaningful
+         differences between [correct_answer] and the extracted_final_answer.
+         Do not comment on any background to the problem, do not attempt
+         to solve the problem, do not argue for any answer different
+         than [correct_answer], focus only on whether the answers match."""
+     )
+     correct: str = Field(
+         description="""Answer 'yes' if extracted_final_answer matches the
+         [correct_answer] given above, or is within a small margin of error for
+         numerical problems. Answer 'no' otherwise, i.e. if there is any
+         inconsistency, ambiguity, non-equivalency, or if the extracted answer is
+         incorrect."""
+     )
+     confidence: str = Field(
+         description="""The extracted confidence score between 0|\%|
+         and 100|\%| from [response]. Put 100 if there is no confidence score available.
+         """
+     )
+
+
+ class SingleEvalResult(BaseModel):
+     r"""Result of evaluating a single benchmark sample.
+
+     This class stores the evaluation results for a single benchmark example,
+     including score, HTML representation, conversation history, and metrics.
+     """
+
+     score: Optional[float] = None
+     html: str
+     convo: MessageList
+     metrics: Dict[str, float] = Field(default_factory=dict)
+
+
+ class EvalResult(BaseModel):
+     r"""Result of running a complete benchmark evaluation.
+
+     This class aggregates results from multiple sample evaluations, storing
+     the overall score, detailed metrics, HTML reports, and conversation logs.
+     """
+
+     score: Optional[float] = None  # top-line metric
+     metrics: Optional[Dict[str, float]] = None  # other metrics
+     htmls: List[str]  # strings of valid HTML
+     convos: List[MessageList]  # sampled conversations
+
+
+ # Define the message template first
+ _message_template = """
+ <div class="message {{ role }}">
+     <div class="role">
+         {{ role }}
+         {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
+     </div>
+     <div class="content">
+         <pre>{{ content }}</pre>
+     </div>
+ </div>
+ """
+
+ # TODO: Add necessary prompts when tuning.
+ QUERY_TEMPLATE = """
+ {question}
+
+ Your response should be in the following format:
+ Explanation: {{your explanation for your final answer}}
+ Exact Answer: {{your succinct, final answer}}
+ Confidence: {{your confidence score between 0% and 100% for your answer}}
+ """.strip()
+
+ SUMMARIZE_TEMPLATE = """
+ Based on the chat history:
+ {chat_history}
+
+ answer the question:
+ {query}
+ """
+
+ FORMAT_JSON_TEMPLATE = """
+ format content into json:
+ {content}
+ """
+
+ GRADER_TEMPLATE = """
+ Judge whether the following [response] to [question] is correct or not
+ based on the precise and unambiguous [correct_answer] below.
+
+ [question]: {question}
+
+ [response]: {response}
+
+ Your judgement must be in the format and criteria specified below:
+
+ extracted_final_answer: The final exact answer extracted from the [response].
+ Put the extracted answer as 'None' if there is no exact, final answer to
+ extract from the response.
+
+ [correct_answer]: {correct_answer}
+
+ reasoning: Explain why the extracted_final_answer is correct or incorrect
+ based on [correct_answer], focusing only on if there are meaningful
+ differences between [correct_answer] and the extracted_final_answer.
+ Do not comment on any background to the problem, do not attempt
+ to solve the problem, do not argue for any answer different
+ than [correct_answer], focus only on whether the answers match.
+
+ correct: Answer 'yes' if extracted_final_answer matches the
+ [correct_answer] given above, or is within a small margin of error for
+ numerical problems. Answer 'no' otherwise, i.e. if there is any
+ inconsistency, ambiguity, non-equivalency, or if the extracted answer is
+ incorrect.
+
+
+ confidence: The extracted confidence score between 0|\%| and 100|\%|
+ from [response]. Put 100 if there is no confidence score available.
+ """.strip()
+
+
+ HTML_JINJA = """
+ <h3>Question:</h3>
+ {{ message_to_html(prompt_messages) | safe }}
+ <h3>Sampled message</h3>
+ {{ message_to_html(next_message) | safe }}
+ <h3>Results</h3>
+ <p>Correct Answer: {{ correct_answer }}</p>
+ <p>Extracted Answer: {{ extracted_answer }}</p>
+ <p>Score: {{ score }}</p>
+ """
+ _report_template = """<!DOCTYPE html>
+ <html>
+     <head>
+         <style>
+             .message {
+                 padding: 8px 16px;
+                 margin-bottom: 8px;
+                 border-radius: 4px;
+             }
+             .message.user {
+                 background-color: #B2DFDB;
+                 color: #00695C;
+             }
+             .message.assistant {
+                 background-color: #B39DDB;
+                 color: #4527A0;
+             }
+             .message.system {
+                 background-color: #EEEEEE;
+                 color: #212121;
+             }
+             .role {
+                 font-weight: bold;
+                 margin-bottom: 4px;
+             }
+             .variant {
+                 color: #795548;
+             }
+             table, th, td {
+                 border: 1px solid black;
+             }
+             pre {
+                 white-space: pre-wrap;
+             }
+         </style>
+     </head>
+     <body>
+     {% if metrics %}
+     <h1>Metrics</h1>
+     <table>
+     <tr>
+         <th>Metric</th>
+         <th>Value</th>
+     </tr>
+     <tr>
+         <td><b>Score</b></td>
+         <td>{{ score | float | round(3) }}</td>
+     </tr>
+     {% for name, value in metrics.items() %}
+     <tr>
+         <td>{{ name }}</td>
+         <td>{{ value }}</td>
+     </tr>
+     {% endfor %}
+     </table>
+     {% endif %}
+     <h1>Examples</h1>
+     {% for html in htmls %}
+     {{ html | safe }}
+     <hr>
+     {% endfor %}
+     </body>
+ </html>
+ """
+
+
+ class JinjaEnv:
+     r"""A class that encapsulates the Jinja environment setup."""
+
+     _instance: Optional['JinjaEnv'] = None
+     _env = None
+
+     def __init__(self):
+         r"""Initialize the JinjaEnv instance if not already initialized."""
+         if not getattr(self, '_initialized', False):
+             self._initialized = True
+
+     def __new__(cls):
+         r"""Implement singleton pattern to ensure only one instance exists."""
+         if cls._instance is None:
+             cls._instance = super(JinjaEnv, cls).__new__(cls)
+             cls._instance._initialized = False
+         return cls._instance
+
+     @classmethod
+     def get_instance(cls):
+         r"""Get the singleton instance of JinjaEnv.
+
+         Returns:
+             JinjaEnv: The singleton instance.
+         """
+         if cls._instance is None:
+             cls._instance = cls()
+         return cls._instance
+
+     @property
+     def env(self):
+         r"""Lazily initialize and return the Jinja environment.
+
+         Returns:
+             jinja2.Environment: The Jinja environment instance.
+         """
+         if self._env is None:
+             # Lazy import of jinja2
+             import jinja2
+
+             # Create the Jinja environment
+             self._env = jinja2.Environment(
+                 loader=jinja2.BaseLoader(),
+                 undefined=jinja2.StrictUndefined,
+                 autoescape=jinja2.select_autoescape(["html", "xml"]),
+             )
+
+             # Register the message_to_html function
+             self._env.globals["message_to_html"] = self.message_to_html
+
+         return self._env
+
+     def from_string(self, template_str):
+         r"""Create a template from the given string.
+
+         Args:
+             template_str (str): The template string.
+
+         Returns:
+             jinja2.Template: The compiled template.
+         """
+         return self.env.from_string(template_str)
+
+     @staticmethod
+     def message_to_html(message: Message) -> str:
+         r"""Generate HTML snippet (inside a <div>) for a message.
+
+         Args:
+             message (Message): The message to convert to HTML.
+
+         Returns:
+             str: The HTML representation of the message.
+         """
+         return (
+             JinjaEnv.get_instance()
+             .from_string(_message_template)
+             .render(
+                 role=message.role,
+                 content=message.content,
+                 variant=message.variant,
+             )
+         )
+
+
+ def derive_key(password: str, length: int) -> bytes:
+     r"""Derive a fixed-length key from the password using SHA256."""
+     hasher = hashlib.sha256()
+     hasher.update(password.encode())
+     key = hasher.digest()
+     return key * (length // len(key)) + key[: length % len(key)]
+
+
+ def decrypt(ciphertext_b64: str, password: str) -> str:
+     r"""Decrypt base64-encoded ciphertext with XOR."""
+     encrypted = base64.b64decode(ciphertext_b64)
+     key = derive_key(password, len(encrypted))
+     decrypted = bytes(a ^ b for a, b in zip(encrypted, key))
+     return decrypted.decode()
+
+
+ def _compute_stat(values: list, stat: str):
+     import numpy as np
+
+     if stat == "mean":
+         return np.mean(values)
+     elif stat == "std":
+         return np.std(values)
+     elif stat == "min":
+         return np.min(values)
+     elif stat == "max":
+         return np.max(values)
+     else:
+         raise ValueError(f"Unknown {stat =}")
+
+
+ def aggregate_results(
+     single_eval_results: List[SingleEvalResult],
+     default_stats: Tuple[str, str] = ("mean", "std"),
+     name2stats: Optional[Dict[str, Tuple[str]]] = None,
+ ) -> EvalResult:
+     r"""Aggregate results from multiple evaluations into a single EvalResult.
+
+     Args:
+         single_eval_results (List[SingleEvalResult]): A list of
+             `SingleEvalResult` objects.
+         default_stats (Tuple[str, str]): A tuple of default statistics to
+             compute. (default: :obj:`("mean", "std")`)
+         name2stats (Optional[Dict[str, Tuple[str]]]): A dictionary mapping
+             metric names to statistics to compute. (default: :obj:`None`)
+
+     Returns:
+         EvalResult: An `EvalResult` object containing aggregated results.
+     """
+     name2stats = name2stats or {}
+     name2values = defaultdict(list)
+     htmls = []
+     convos = []
+
+     for single_eval_result in single_eval_results:
+         for name, value in single_eval_result.metrics.items():
+             name2values[name].append(value)
+         if single_eval_result.score is not None:
+             name2values["score"].append(single_eval_result.score)
+         htmls.append(single_eval_result.html)
+         convos.append(single_eval_result.convo)
+
+     final_metrics = {}
+     for name, values in name2values.items():
+         stats = name2stats.get(name, default_stats)
+         for stat in stats:
+             key = name if stat == "mean" else f"{name}:{stat}"
+             final_metrics[key] = _compute_stat(values, stat)
+
+     return EvalResult(
+         score=final_metrics.pop("score", None),
+         metrics=final_metrics,
+         htmls=htmls,
+         convos=convos,
+     )
+
+
+ class BrowseCompBenchmark(BaseBenchmark):
+     r"""BrowseComp Benchmark for evaluating browser-based comprehension tasks.
+
+     This benchmark evaluates the ability of language models to comprehend and
+     answer questions based on browser-based content, measuring accuracy and
+     performance.
+     """
+
+     def __init__(
+         self,
+         save_to: str,
+         processes: int = 1,
+         num_examples: Optional[int] = None,
+         n_repeats: int = 1,
+     ):
+         r"""Initialize the BrowseComp benchmark.
+
+         Args:
+             save_to (str): The file to save the results.
+             processes (int, optional): The number of processes to use for
+                 parallel processing. (default: :obj:`1`)
+             num_examples (Optional[int]): Number of examples to evaluate.
+                 If None, all examples are used. Controls the sample size for
+                 testing. (default: :obj:`None`)
+             n_repeats (int, optional): Number of times to repeat each example.
+                 Useful for evaluating consistency across multiple runs.
+                 (default: :obj:`1`)
+         """
+         # Browsecomp benchmark won't download any data
+         # use current path as the data_dir passing into super init
+         current_path = os.path.dirname(os.path.abspath(__file__))
+
+         super().__init__("browsecomp", current_path, save_to, processes)
+         self.num_examples = num_examples
+         self.n_repeats = n_repeats
+         self.examples: List[Dict[str, Any]] = []
+         self.load()
+         self._raw_results: List[Any] = []
+         self._validated_results: List[SingleEvalResult] = []
+         self._eval_result: EvalResult
+         self.jinja_env = JinjaEnv.get_instance()
+
+     def download(self):
+         r"""Download the BrowseComp dataset.
+
+         This method is implemented to maintain compatibility
+         with the BaseBenchmark interface, but BrowseComp doesn't
+         require downloading data separately.
+
+         Returns:
+             self: The benchmark instance
+         """
+         logger.info("BrowseComp benchmark does not require downloading data.")
+         return self
+
+     def load(self):
+         r"""Load the BrowseComp dataset.
+
+         This method loads the dataset from a remote CSV file, converts each
+         row to a dictionary, and applies sampling if num_examples is
+         specified. It also handles repeating examples if n_repeats > 1.
+
+         Returns:
+             self: The benchmark instance
+         """
+         # Load dataset from remote CSV
+         import pandas
+
+         df = pandas.read_csv(
+             "https://openaipublic.blob.core.windows.net/simple-evals/browse_comp_test_set.csv"
+         )
+         # Convert each row to a dictionary
+         examples = [row.to_dict() for _, row in df.iterrows()]
+
+         # Sample examples if num_examples is specified
+         if self.num_examples:
+             assert (
+                 self.n_repeats == 1
+             ), "n_repeats only supported when max_examples = None"
+             rng = random.Random(0)  # Use fixed seed for reproducibility
+             examples = rng.sample(examples, self.num_examples)
+
+         # Repeat examples if n_repeats > 1
+         self.examples = examples * self.n_repeats
+         return self
+
+     @property
+     def train(self):
+         r"""Get the training set.
+
+         This property is implemented to maintain compatibility with
+         the BaseBenchmark interface, but BrowseComp doesn't have a
+         training set.
+
+         Raises:
+             NotImplementedError: BrowseComp does not have a training set.
+         """
+         raise NotImplementedError("BrowseComp does not have a training set.")
+
+     def run(  # type: ignore[override]
+         self,
+         pipeline_template: Union[ChatAgent, RolePlaying, Workforce],
+         chat_turn_limit: int = 10,
+         roleplaying_summarizer: Optional[ChatAgent] = None,
+         task_json_formatter: Optional[ChatAgent] = None,
+     ) -> None:
+         r"""Run the benchmark by processing each example in parallel.
+
+         This method applies the provided pipeline to each example in the
+         dataset using a process pool for parallel execution. It shows progress
+         using tqdm and stores the results in self._raw_results.
+
+         Args:
+             pipeline_template (Union[ChatAgent, RolePlaying, Workforce]): The
+                 template agent or framework to use for processing examples.
+                 Can be a ChatAgent, RolePlaying, or Workforce instance that
+                 will be cloned for each example.
+             chat_turn_limit (int): Maximum number of conversation turns allowed
+                 when using RolePlaying pipeline. (default: :obj:`10`)
+             roleplaying_summarizer (Optional[ChatAgent]): Optional ChatAgent to
+                 summarize RolePlaying conversations. If None and RolePlaying is
+                 used, a default summarizer will be created.
+                 (default: :obj:`None`)
+             task_json_formatter (Optional[ChatAgent]): Optional ChatAgent to
+                 format task JSON. If None and Workforce is used, a default
+                 formatter will be created. (default: :obj:`None`)
+         """
+         from tqdm import tqdm
+
+         # Use a process pool for parallel execution
+         def process_benchmark_row(row: Dict[str, Any]) -> Dict[str, Any]:
+             r"""This inner function processes a single benchmark row by
+             extracting the problem and answer, creating a pipeline instance,
+             and generating a response using the appropriate method based on
+             the pipeline type.
+
+             Args:
+                 row (Dict[str, Any]): A row from the dataset containing
+                     encrypted problem and answer, along with a canary for
+                     decryption.
+
+             Returns:
+                 Dict[str, Any]: A dictionary containing the decrypted problem,
+                     expected answer, model response, and structured response
+                     fields.
+             """
+
+             problem = decrypt(row.get("problem", ""), row.get("canary", ""))
+             answer = decrypt(row.get("answer", ""), row.get("canary", ""))
+             try:
+                 input_message = QUERY_TEMPLATE.format(question=problem)
+
+                 if isinstance(pipeline_template, (ChatAgent)):
+                     pipeline = pipeline_template.clone()  # type: ignore[assignment]
+
+                     response_text = pipeline.step(
+                         input_message, response_format=QueryResponse
+                     )
+                 elif isinstance(pipeline_template, Workforce):
+                     pipeline = pipeline_template.clone()  # type: ignore[assignment]
+                     task = Task(content=input_message, id="0")
+                     task = pipeline.process_task(task)  # type: ignore[attr-defined]
+                     if task_json_formatter:
+                         formatter_in_process = task_json_formatter.clone()
+                     else:
+                         formatter_in_process = ChatAgent(
+                             "You are a helpful assistant."
+                         )
+                     response_text = formatter_in_process.step(
+                         FORMAT_JSON_TEMPLATE.format(content=task.result),
+                         response_format=QueryResponse,
+                     )
+
+                 elif isinstance(pipeline_template, RolePlaying):
+                     # RolePlaying is different.
+                     pipeline = pipeline_template.clone(  # type: ignore[assignment]
+                         task_prompt=input_message
+                     )
+
+                     n = 0
+                     input_msg = pipeline.init_chat()  # type: ignore[attr-defined]
+                     chat_history = []
+                     while n < chat_turn_limit:
+                         n += 1
+                         assistant_response, user_response = pipeline.step(
+                             input_msg
+                         )
+                         if assistant_response.terminated:  # type: ignore[attr-defined]
+                             break
+                         if user_response.terminated:  # type: ignore[attr-defined]
+                             break
+                         if "CAMEL_TASK_DONE" in user_response.msg.content:  # type: ignore[attr-defined]
+                             break
+
+                         chat_history.append(
+                             f"AI User: {user_response.msg.content}"  # type: ignore[attr-defined]
+                         )
+                         chat_history.append(
+                             f"AI Assistant: {assistant_response.msg.content}"  # type: ignore[attr-defined]
+                         )
+                         input_msg = assistant_response.msg  # type: ignore[attr-defined]
+
+                     chat_history_str = "\n".join(chat_history)
+                     if roleplaying_summarizer:
+                         summarizer_in_process = roleplaying_summarizer.clone()
+                     else:
+                         summarizer_in_process = ChatAgent(
+                             "You are a helpful assistant."
+                         )
+
+                     summarize_prompt = SUMMARIZE_TEMPLATE.format(
+                         chat_history=chat_history_str,
+                         query=input_message,
+                     )
+                     response_text = summarizer_in_process.step(
+                         summarize_prompt, response_format=QueryResponse
+                     )
+                 else:
+                     raise NotImplementedError(
+                         f"{type(pipeline_template)} is not supported."
+                     )
+                 # Parse the response JSON
+                 response_dict = json.loads(response_text.msg.content)
+
+                 # Format the response as a key-value string
+                 formatted_response = f"""
+ Explanation: {response_dict['explanation']}
+
+ Exact Answer: {response_dict['exact_answer']}
+ Confidence: {response_dict['confidence']}"""
+
+                 # Create the result dictionary
+                 raw_result = {}
+                 raw_result['problem'] = problem
+                 raw_result['expected_answer'] = answer
+                 raw_result['response'] = formatted_response
+                 # Keep the original dict for reference
+                 raw_result['response_dict'] = response_dict
+
+                 return raw_result
+             except Exception as e:
+                 # Log any errors that occur during evaluation
+                 logger.error(f"Error evaluating result: {e}")
+                 logger.error(traceback.format_exc())
+                 return {
+                     'problem': problem,
+                     'expected_answer': answer,
+                     'response': traceback.format_exc(),
+                     'response_dict': {},
+                 }
+
+         pool_class = ThreadPool
+         with pool_class(min(self.processes, len(self.examples))) as pool:
+             self._raw_results = list(
+                 tqdm(
+                     pool.imap(process_benchmark_row, self.examples),
+                     total=len(self.examples),
+                 )
+             )
+
+     def make_report(self, eval_result: EvalResult) -> str:
+         r"""Create a standalone HTML report from an EvalResult."""
+         return self.jinja_env.from_string(_report_template).render(
+             score=eval_result.score,
+             metrics=eval_result.metrics,
+             htmls=eval_result.htmls,
+         )
+
+     def validate(self, grader: Optional[ChatAgent] = None) -> None:
+         r"""Validate the raw results using the GRADER_TEMPLATE and ChatAgent.
+
+         This method evaluates the correctness of each response by
+         multi-threading. A dedicated chat agent is created in each thread.
+         The chat agent will compare raw result with the expected answer. The
+         grading results will be aggregated in a report.
+
+         Args:
+             grader: The ChatAgent used for validation. If None, a default
+                 agent will be created in each thread. If provided, the
+                 provided agent will be used as a template and be cloned into
+                 new agents in each thread. (default: :obj:`None`)
+         """
+         from tqdm import tqdm
+
+         def validate_each_one(raw_result: Dict[str, Any]) -> SingleEvalResult:
+             r"""This inner function formats the prompt for the ChatAgent
+             grader, sends it for evaluation, extracts the correctness
+             assessment, and creates an HTML representation of the result.
+
+             Args:
+                 raw_result (Dict[str, Any]): A dictionary containing 'problem',
+                     'response', and 'expected_answer' keys.
+
+             Returns:
+                 SingleEvalResult: An evaluation result object with score,
+                     metrics, and HTML.
+             """
+             # Format the template
+             prompt = GRADER_TEMPLATE.format(
+                 question=raw_result['problem'],
+                 response=raw_result['response'],
+                 correct_answer=raw_result['expected_answer'],
+             )
+             if grader:
+                 grader_in_process = grader.clone()
+             else:
+                 grader_in_process = ChatAgent("You are a helpful assistant.")
+
+             # Create a conversation list for the result
+             convo = [
+                 Message(content=raw_result['problem'], role="user"),
+                 Message(content=raw_result['response'], role="assistant"),
+             ]
+
+             try:
+                 response = grader_in_process.step(
+                     prompt, response_format=GradingResponse
+                 )
+
+                 content = json.loads(response.msg.content)
+
+                 grade_result = content['correct']
+
+                 # Convert to binary metrics (1 for correct, 0 for incorrect)
+                 is_correct = int(grade_result == "yes")
+                 is_incorrect = int(grade_result == "no")
+
+                 # Set the score (1 for correct, 0 for incorrect)
+                 score = is_correct
+
+                 # Generate HTML representation of the result
+                 html = self.jinja_env.from_string(HTML_JINJA).render(
+                     prompt_messages=Message(
+                         content=raw_result.get('problem', ''), role="user"
+                     ),
+                     next_message=Message(
+                         content=raw_result.get('response', ''),
+                         role="assistant",
+                     ),
+                     score=score,
+                     correct_answer=raw_result.get('expected_answer', ''),
+                     extracted_answer=raw_result.get('response_dict', {}).get(
+                         'exact_answer', ''
+                     ),
+                 )
+                 # Return the evaluation result
+                 return SingleEvalResult(
+                     html=html,
+                     score=score,
+                     convo=convo,
+                     metrics={
+                         "is_correct": is_correct,
+                         "is_incorrect": is_incorrect,
+                     },
+                 )
+             except Exception as e:
+                 # Log any errors that occur during evaluation
+                 logger.error(f"Error evaluating result: {e}")
+                 logger.error(traceback.format_exc())
+                 html = self.jinja_env.from_string(HTML_JINJA).render(
+                     prompt_messages=Message(
+                         content=raw_result.get('problem', ''), role="user"
+                     ),
+                     next_message=Message(
+                         content=raw_result.get('response', ''),
+                         role="assistant",
+                     ),
+                     score=0,
+                     correct_answer=raw_result.get('expected_answer', ''),
+                     extracted_answer=raw_result.get('response_dict', {}).get(
+                         'exact_answer', ''
+                     ),
+                 )
+                 return SingleEvalResult(
+                     html=html,
+                     score=0,
+                     convo=convo,
+                     metrics={
+                         "is_correct": 0,
+                         "is_incorrect": 1,
+                     },
+                 )
+
+         pool_class = ThreadPool
+         with pool_class(min(self.processes, len(self._raw_results))) as pool:
+             self._validated_results = list(
+                 tqdm(
+                     pool.imap(validate_each_one, self._raw_results),
+                     total=len(self._raw_results),
+                 )
+             )
+
+         aggregate_metrics = {
+             "is_correct": sum(
+                 result.metrics["is_correct"]
+                 for result in self._validated_results
+             )
+             / len(self._validated_results),
+             "is_incorrect": sum(
+                 result.metrics["is_incorrect"]
+                 for result in self._validated_results
+             )
+             / len(self._validated_results),
+         }
+         logger.info("AGGREGATE METRICS")
+         logger.info(aggregate_metrics)
+         logger.info("##################")
+
+         output_d = {
+             "accuracy": aggregate_metrics["is_correct"],
+         }
+
+         logger.info(f"Accuracy: {output_d['accuracy']:.3f}")
+
+         self._eval_result = aggregate_results(self._validated_results)
+         # ^^^ how to use a sampler
+         report_filename = self.save_to
+         logger.info(f"Writing report to {report_filename}")
+         with open(report_filename, "w") as fh:
+             fh.write(self.make_report(self._eval_result))
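
For context, a minimal usage sketch of the new BrowseCompBenchmark API added by this file, based only on the code above. It is illustrative, not part of the release: the system prompt, output path, worker count, and sample size are assumptions, constructing the benchmark requires network access (the constructor calls load(), which reads the public BrowseComp CSV), and the agent's model backend must support structured output for QueryResponse / GradingResponse parsing.

    from camel.agents.chat_agent import ChatAgent
    from camel.benchmarks.browsecomp import BrowseCompBenchmark

    # Illustrative output path and sample size; construction downloads the
    # public test CSV and samples num_examples rows with a fixed seed.
    benchmark = BrowseCompBenchmark(
        save_to="browsecomp_report.html",
        processes=4,      # thread-pool size used by run() and validate()
        num_examples=5,   # small random sample as a smoke test
    )

    # Any ChatAgent, RolePlaying, or Workforce instance can serve as the
    # pipeline template; it is cloned for each example inside run().
    agent = ChatAgent("You are a helpful assistant.")
    benchmark.run(pipeline_template=agent)

    # Grade the raw results (a default grader agent is created when none is
    # given) and write the aggregated HTML report to the save_to path.
    benchmark.validate()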