camel-ai 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic; review the changes below for details.

Files changed (69)
  1. camel/__init__.py +1 -1
  2. camel/agents/chat_agent.py +362 -237
  3. camel/benchmarks/__init__.py +11 -1
  4. camel/benchmarks/apibank.py +560 -0
  5. camel/benchmarks/apibench.py +496 -0
  6. camel/benchmarks/gaia.py +2 -2
  7. camel/benchmarks/nexus.py +518 -0
  8. camel/datagen/__init__.py +21 -0
  9. camel/datagen/cotdatagen.py +448 -0
  10. camel/datagen/self_instruct/__init__.py +36 -0
  11. camel/datagen/self_instruct/filter/__init__.py +34 -0
  12. camel/datagen/self_instruct/filter/filter_function.py +216 -0
  13. camel/datagen/self_instruct/filter/filter_registry.py +56 -0
  14. camel/datagen/self_instruct/filter/instruction_filter.py +81 -0
  15. camel/datagen/self_instruct/self_instruct.py +393 -0
  16. camel/datagen/self_instruct/templates.py +384 -0
  17. camel/datahubs/huggingface.py +12 -2
  18. camel/datahubs/models.py +4 -2
  19. camel/embeddings/mistral_embedding.py +5 -1
  20. camel/embeddings/openai_compatible_embedding.py +6 -1
  21. camel/embeddings/openai_embedding.py +5 -1
  22. camel/interpreters/e2b_interpreter.py +5 -1
  23. camel/loaders/apify_reader.py +5 -1
  24. camel/loaders/chunkr_reader.py +5 -1
  25. camel/loaders/firecrawl_reader.py +0 -30
  26. camel/logger.py +11 -5
  27. camel/messages/conversion/sharegpt/hermes/hermes_function_formatter.py +4 -1
  28. camel/models/anthropic_model.py +5 -1
  29. camel/models/azure_openai_model.py +1 -2
  30. camel/models/cohere_model.py +5 -1
  31. camel/models/deepseek_model.py +5 -1
  32. camel/models/gemini_model.py +5 -1
  33. camel/models/groq_model.py +5 -1
  34. camel/models/mistral_model.py +5 -1
  35. camel/models/nemotron_model.py +5 -1
  36. camel/models/nvidia_model.py +5 -1
  37. camel/models/openai_model.py +5 -1
  38. camel/models/qwen_model.py +5 -1
  39. camel/models/reka_model.py +5 -1
  40. camel/models/reward/nemotron_model.py +5 -1
  41. camel/models/samba_model.py +5 -1
  42. camel/models/togetherai_model.py +5 -1
  43. camel/models/yi_model.py +5 -1
  44. camel/models/zhipuai_model.py +5 -1
  45. camel/retrievers/auto_retriever.py +8 -0
  46. camel/retrievers/vector_retriever.py +6 -3
  47. camel/schemas/openai_converter.py +5 -1
  48. camel/societies/role_playing.py +4 -4
  49. camel/societies/workforce/workforce.py +2 -2
  50. camel/storages/graph_storages/nebula_graph.py +119 -27
  51. camel/storages/graph_storages/neo4j_graph.py +138 -0
  52. camel/toolkits/__init__.py +4 -0
  53. camel/toolkits/arxiv_toolkit.py +20 -3
  54. camel/toolkits/dappier_toolkit.py +196 -0
  55. camel/toolkits/function_tool.py +61 -61
  56. camel/toolkits/meshy_toolkit.py +5 -1
  57. camel/toolkits/notion_toolkit.py +1 -1
  58. camel/toolkits/openbb_toolkit.py +869 -0
  59. camel/toolkits/search_toolkit.py +91 -5
  60. camel/toolkits/stripe_toolkit.py +5 -1
  61. camel/toolkits/twitter_toolkit.py +24 -16
  62. camel/types/enums.py +7 -1
  63. camel/types/unified_model_type.py +5 -0
  64. camel/utils/__init__.py +4 -0
  65. camel/utils/commons.py +142 -20
  66. {camel_ai-0.2.14.dist-info → camel_ai-0.2.16.dist-info}/METADATA +17 -5
  67. {camel_ai-0.2.14.dist-info → camel_ai-0.2.16.dist-info}/RECORD +69 -55
  68. {camel_ai-0.2.14.dist-info → camel_ai-0.2.16.dist-info}/LICENSE +0 -0
  69. {camel_ai-0.2.14.dist-info → camel_ai-0.2.16.dist-info}/WHEEL +0 -0
@@ -32,7 +32,7 @@ from typing import (
32
32
 
33
33
  from openai.types.chat import ChatCompletionMessageToolCall
34
34
  from openai.types.chat.chat_completion_message_tool_call import Function
35
- from pydantic import BaseModel
35
+ from pydantic import BaseModel, ValidationError
36
36
 
37
37
  from camel.agents.base import BaseAgent
38
38
  from camel.memories import (
@@ -59,6 +59,7 @@ from camel.types import (
59
59
  )
60
60
  from camel.utils import (
61
61
  func_string_to_callable,
62
+ generate_prompt_for_structured_output,
62
63
  get_model_encoding,
63
64
  get_pydantic_object_schema,
64
65
  json_to_function_code,
@@ -142,17 +143,21 @@ class ChatAgent(BaseAgent):
142
143
  (default: :obj:`None`)
143
144
  output_language (str, optional): The language to be output by the
144
145
  agent. (default: :obj:`None`)
145
- tools (List[FunctionTool], optional): List of available
146
- :obj:`FunctionTool`. (default: :obj:`None`)
147
- external_tools (List[FunctionTool], optional): List of external tools
148
- (:obj:`FunctionTool`) bind to one chat agent. When these tools
149
- are called, the agent will directly return the request instead of
146
+ tools (Optional[List[Union[FunctionTool, Callable]]], optional): List
147
+ of available :obj:`FunctionTool` or :obj:`Callable`. (default:
148
+ :obj:`None`)
149
+ external_tools (Optional[List[Union[FunctionTool, Callable]]],
150
+ optional): List of external tools (:obj:`FunctionTool` or or
151
+ :obj:`Callable`) bind to one chat agent. When these tools are
152
+ called, the agent will directly return the request instead of
150
153
  processing it. (default: :obj:`None`)
151
154
  response_terminators (List[ResponseTerminator], optional): List of
152
155
  :obj:`ResponseTerminator` bind to one chat agent.
153
156
  (default: :obj:`None`)
154
157
  scheduling_strategy (str): name of function that defines how to select
155
158
  the next model in ModelManager. (default: :str:`round_robin`)
159
+ single_iteration (bool): Whether to let the agent perform only one
160
+ model calling at each step. (default: :obj:`False`)
156
161
  """
157
162
 
158
163
  def __init__(
@@ -165,11 +170,13 @@ class ChatAgent(BaseAgent):
165
170
  message_window_size: Optional[int] = None,
166
171
  token_limit: Optional[int] = None,
167
172
  output_language: Optional[str] = None,
168
- tools: Optional[List[FunctionTool]] = None,
169
- external_tools: Optional[List[FunctionTool]] = None,
173
+ tools: Optional[List[Union[FunctionTool, Callable]]] = None,
174
+ external_tools: Optional[List[Union[FunctionTool, Callable]]] = None,
170
175
  response_terminators: Optional[List[ResponseTerminator]] = None,
171
176
  scheduling_strategy: str = "round_robin",
177
+ single_iteration: bool = False,
172
178
  ) -> None:
179
+ # Initialize the system message, converting string to BaseMessage if needed
173
180
  if isinstance(system_message, str):
174
181
  system_message = BaseMessage.make_assistant_message(
175
182
  role_name='Assistant', content=system_message
@@ -192,34 +199,35 @@ class ChatAgent(BaseAgent):
192
199
  ),
193
200
  scheduling_strategy=scheduling_strategy,
194
201
  )
195
-
196
202
  self.model_type = self.model_backend.model_type
197
203
 
198
- # Tool registration
199
- external_tools = external_tools or []
200
- tools = tools or []
201
- all_tools = tools + external_tools
202
- self.external_tool_names = [
203
- tool.get_function_name() for tool in external_tools
204
+ # Initialize tools
205
+ self.tools: List[FunctionTool] = (
206
+ self._initialize_tools(tools) if tools else []
207
+ )
208
+ self.external_tools: List[FunctionTool] = (
209
+ self._initialize_tools(external_tools) if external_tools else []
210
+ )
211
+ self.external_tool_names: List[str] = [
212
+ tool.get_function_name() for tool in self.external_tools
204
213
  ]
205
- self.func_dict = {
206
- tool.get_function_name(): tool.func for tool in all_tools
214
+ self.all_tools = self.tools + self.external_tools or []
215
+
216
+ # Create tool dictionaries and configure backend tools if necessary
217
+ self.tool_dict = {
218
+ tool.get_function_name(): tool for tool in self.all_tools
207
219
  }
208
- self.tool_dict = {tool.get_function_name(): tool for tool in all_tools}
209
220
 
210
221
  # If the user set tools from `ChatAgent`, it will override the
211
222
  # configured tools in `BaseModelBackend`.
212
- if all_tools:
223
+ if self.all_tools:
213
224
  logger.warning(
214
225
  "Overriding the configured tools in `BaseModelBackend` with the tools from `ChatAgent`."
215
226
  )
216
227
  tool_schema_list = [
217
- tool.get_openai_tool_schema() for tool in all_tools
228
+ tool.get_openai_tool_schema() for tool in self.all_tools
218
229
  ]
219
230
  self.model_backend.model_config_dict['tools'] = tool_schema_list
220
- self.tool_schema_list = tool_schema_list
221
-
222
- self.model_config_dict = self.model_backend.model_config_dict
223
231
 
224
232
  self.model_token_limit = token_limit or self.model_backend.token_limit
225
233
  context_creator = ScoreBasedContextCreator(
@@ -237,8 +245,86 @@ class ChatAgent(BaseAgent):
237
245
  self.terminated: bool = False
238
246
  self.response_terminators = response_terminators or []
239
247
  self.init_messages()
240
-
241
248
  self.tool_prompt_added = False
249
+ self.single_iteration = single_iteration
250
+
251
+ def _initialize_tools(
252
+ self, tools: List[Union[FunctionTool, Callable]]
253
+ ) -> List[FunctionTool]:
254
+ r"""Helper method to initialize tools as FunctionTool instances."""
255
+ from camel.toolkits import FunctionTool
256
+
257
+ func_tools = []
258
+ for tool in tools:
259
+ if not isinstance(tool, FunctionTool):
260
+ tool = FunctionTool(tool)
261
+ func_tools.append(tool)
262
+ return func_tools
263
+
264
+ def add_tool(
265
+ self, tool: Union[FunctionTool, Callable], is_external: bool = False
266
+ ) -> None:
267
+ r"""Add a tool to the agent, specifying if it's an external tool."""
268
+ # Initialize the tool
269
+ initialized_tool = self._initialize_tools([tool])
270
+
271
+ # Update tools or external tools based on is_external flag
272
+ if is_external:
273
+ self.external_tools = self.external_tools + initialized_tool
274
+ self.external_tool_names.extend(
275
+ tool.get_function_name() for tool in initialized_tool
276
+ )
277
+ else:
278
+ self.tools = self.tools + initialized_tool
279
+
280
+ # Rebuild all_tools, and tool_dict
281
+ self.all_tools = self.tools + self.external_tools
282
+ self.tool_dict = {
283
+ tool.get_function_name(): tool for tool in self.all_tools
284
+ }
285
+
286
+ tool_schema_list = [
287
+ tool.get_openai_tool_schema() for tool in self.all_tools
288
+ ]
289
+ self.model_backend.model_config_dict['tools'] = tool_schema_list
290
+
291
+ def remove_tool(self, tool_name: str, is_external: bool = False) -> bool:
292
+ r"""Remove a tool by name, specifying if it's an external tool."""
293
+ tool_list = self.external_tools if is_external else self.tools
294
+ if not tool_list:
295
+ return False
296
+
297
+ for tool in tool_list:
298
+ if tool.get_function_name() == tool_name:
299
+ tool_list.remove(tool)
300
+ if is_external:
301
+ self.external_tool_names.remove(tool_name)
302
+ # Reinitialize the tool dictionary
303
+ self.all_tools = (self.tools or []) + (
304
+ self.external_tools or []
305
+ )
306
+ self.tool_dict = {
307
+ tool.get_function_name(): tool for tool in self.all_tools
308
+ }
309
+ tool_schema_list = [
310
+ tool.get_openai_tool_schema() for tool in self.all_tools
311
+ ]
312
+ self.model_backend.model_config_dict['tools'] = (
313
+ tool_schema_list
314
+ )
315
+ return True
316
+ return False
317
+
318
+ def list_tools(self) -> dict:
319
+ r"""List all tools, separated into normal and external tools."""
320
+ normal_tools = [
321
+ tool.get_function_name() for tool in (self.tools or [])
322
+ ]
323
+ external_tools = [
324
+ tool.get_function_name() for tool in (self.external_tools or [])
325
+ ]
326
+
327
+ return {"normal_tools": normal_tools, "external_tools": external_tools}
242
328
 
243
329
  # ruff: noqa: E501
244
330
  def _generate_tool_prompt(self, tool_schema_list: List[Dict]) -> str:
@@ -264,9 +350,7 @@ class ChatAgent(BaseAgent):
264
350
 
265
351
  tool_prompt_str = "\n".join(tool_prompts)
266
352
 
267
- final_prompt = f'''
268
- # Tool prompt
269
- TOOL_PROMPT = f"""
353
+ final_prompt = f"""
270
354
  You have access to the following functions:
271
355
 
272
356
  {tool_prompt_str}
@@ -274,19 +358,16 @@ class ChatAgent(BaseAgent):
274
358
  If you choose to call a function ONLY reply in the following format with no
275
359
  prefix or suffix:
276
360
 
277
- <function=example_function_name>{{"example_name": "example_value"}}
278
- </function>
361
+ <function=example_function_name>{{"example_name": "example_value"}}</function>
279
362
 
280
363
  Reminder:
281
- - Function calls MUST follow the specified format, start with <function=
282
- and end with </function>
364
+ - Function calls MUST follow the specified format, start with <function= and end with </function>
283
365
  - Required parameters MUST be specified
284
366
  - Only call one function at a time
285
367
  - Put the entire function call reply on one line
286
368
  - If there is no function call available, answer the question like normal
287
369
  with your current knowledge and do not tell the user about function calls
288
370
  """
289
- '''
290
371
  return final_prompt
291
372
 
292
373
  def _parse_tool_response(self, response: str):
@@ -310,7 +391,7 @@ class ChatAgent(BaseAgent):
310
391
  args = json.loads(args_string)
311
392
  return {"function": function_name, "arguments": args}
312
393
  except json.JSONDecodeError as error:
313
- print(f"Error parsing function arguments: {error}")
394
+ logger.error(f"Error parsing function arguments: {error}")
314
395
  return None
315
396
  return None
316
397
 
@@ -342,14 +423,13 @@ class ChatAgent(BaseAgent):
342
423
  self._system_message = message
343
424
 
344
425
  def is_tools_added(self) -> bool:
345
- r"""Whether OpenAI function calling is enabled for this agent.
426
+ r"""Whether tool calling is enabled for this agent.
346
427
 
347
428
  Returns:
348
- bool: Whether OpenAI function calling is enabled for this
349
- agent, determined by whether the dictionary of tools
350
- is empty.
429
+ bool: Whether tool calling is enabled for this agent, determined
430
+ by whether the dictionary of tools is empty.
351
431
  """
352
- return len(self.func_dict) > 0
432
+ return len(self.tool_dict) > 0
353
433
 
354
434
  def update_memory(
355
435
  self, message: BaseMessage, role: OpenAIBackendRole
@@ -415,7 +495,7 @@ class ChatAgent(BaseAgent):
415
495
  Args:
416
496
  session_id (str, optional): The ID of the chat session.
417
497
  usage (Dict[str, int], optional): Information about the usage of
418
- the LLM model.
498
+ the LLM.
419
499
  termination_reasons (List[str]): The reasons for the termination
420
500
  of the chat session.
421
501
  num_tokens (int): The number of tokens used in the chat session.
@@ -470,26 +550,23 @@ class ChatAgent(BaseAgent):
470
550
  input_message: Union[BaseMessage, str],
471
551
  response_format: Optional[Type[BaseModel]] = None,
472
552
  ) -> ChatAgentResponse:
473
- r"""Performs a single step in the chat session by generating a response
553
+ r"""Executes a single step in the chat session, generating a response
474
554
  to the input message.
475
555
 
476
556
  Args:
477
- input_message (Union[BaseMessage, str]): The input message to the
478
- agent. For BaseMessage input, its `role` field that specifies
479
- the role at backend may be either `user` or `assistant` but it
480
- will be set to `user` anyway since for the self agent any
481
- incoming message is external. For str input, the `role_name` would be `User`.
482
- response_format (Optional[Type[BaseModel]], optional): A pydantic
483
- model class that includes value types and field descriptions
484
- used to generate a structured response by LLM. This schema
485
- helps in defining the expected output format. (default:
557
+ input_message (Union[BaseMessage, str]): The input message for the
558
+ agent. If provided as a BaseMessage, the `role` is adjusted to
559
+ `user` to indicate an external message.
560
+ response_format (Optional[Type[BaseModel]], optional): A Pydantic
561
+ model defining the expected structure of the response. Used to
562
+ generate a structured response if provided. (default:
486
563
  :obj:`None`)
487
564
 
488
565
  Returns:
489
- ChatAgentResponse: A struct containing the output messages,
490
- a boolean indicating whether the chat session has terminated,
491
- and information about the chat session.
566
+ ChatAgentResponse: Contains output messages, a termination status
567
+ flag, and session information.
492
568
  """
569
+
493
570
  if (
494
571
  self.model_backend.model_config_dict.get("response_format")
495
572
  and response_format
@@ -499,173 +576,135 @@ class ChatAgent(BaseAgent):
499
576
  "the model configuration and in the ChatAgent step."
500
577
  )
501
578
 
502
- original_model_dict = self.model_backend.model_config_dict
579
+ self.original_model_dict = self.model_backend.model_config_dict
503
580
  if response_format and self.model_type in {"gpt-4o", "gpt-4o-mini"}:
504
- self.model_backend.model_config_dict = original_model_dict.copy()
581
+ self.model_backend.model_config_dict = (
582
+ self.original_model_dict.copy()
583
+ )
505
584
  self.model_backend.model_config_dict["response_format"] = (
506
585
  response_format
507
586
  )
508
587
 
588
+ # Convert input message to BaseMessage if necessary
509
589
  if isinstance(input_message, str):
510
590
  input_message = BaseMessage.make_user_message(
511
591
  role_name='User', content=input_message
512
592
  )
513
593
 
514
- if "llama" in self.model_type.lower():
515
- if (
516
- self.model_backend.model_config_dict.get("tools", None)
517
- and not self.tool_prompt_added
518
- ):
519
- tool_prompt = self._generate_tool_prompt(self.tool_schema_list)
594
+ # Handle tool prompt injection if needed
595
+ if (
596
+ self.is_tools_added()
597
+ and not self.model_type.support_native_tool_calling
598
+ and not self.tool_prompt_added
599
+ ):
600
+ self._inject_tool_prompt()
520
601
 
521
- tool_sys_msg = BaseMessage.make_assistant_message(
522
- role_name="Assistant",
523
- content=tool_prompt,
524
- )
602
+ # Add user input to memory
603
+ self.update_memory(input_message, OpenAIBackendRole.USER)
525
604
 
526
- self.update_memory(tool_sys_msg, OpenAIBackendRole.SYSTEM)
527
- self.tool_prompt_added = True
605
+ return self._handle_step(response_format, self.single_iteration)
528
606
 
529
- self.update_memory(input_message, OpenAIBackendRole.USER)
607
+ def _inject_tool_prompt(self) -> None:
608
+ r"""Generate and add the tool prompt to memory."""
609
+ tool_prompt = self._generate_tool_prompt(
610
+ self.model_backend.model_config_dict["tools"]
611
+ )
612
+ tool_msg = BaseMessage.make_assistant_message(
613
+ role_name="Assistant", content=tool_prompt
614
+ )
615
+ self.update_memory(tool_msg, OpenAIBackendRole.SYSTEM)
616
+ self.tool_prompt_added = True
530
617
 
531
- tool_call_records: List[FunctionCallingRecord] = []
532
- while True:
533
- # Check if token has exceeded
534
- try:
535
- openai_messages, num_tokens = self.memory.get_context()
536
- except RuntimeError as e:
537
- return self._step_token_exceed(
538
- e.args[1], tool_call_records, "max_tokens_exceeded"
539
- )
618
+ def _handle_step(
619
+ self,
620
+ response_format: Optional[Type[BaseModel]],
621
+ single_step: bool,
622
+ ) -> ChatAgentResponse:
623
+ r"""Handles a single or multi-step interaction."""
540
624
 
541
- (
542
- response,
543
- output_messages,
544
- finish_reasons,
545
- usage_dict,
546
- response_id,
547
- ) = self._step_model_response(openai_messages, num_tokens)
548
- # If the model response is not a function call, meaning the
549
- # model has generated a message response, break the loop
550
- if (
551
- not self.is_tools_added()
552
- or not isinstance(response, ChatCompletion)
553
- or "</function>" not in response.choices[0].message.content # type: ignore[operator]
554
- ):
555
- break
556
-
557
- parsed_content = self._parse_tool_response(
558
- response.choices[0].message.content # type: ignore[arg-type]
559
- )
625
+ if (
626
+ self.model_backend.model_config_dict.get("tool_choice")
627
+ == "required"
628
+ and not single_step
629
+ ):
630
+ raise ValueError(
631
+ "`tool_choice` cannot be set to `required` for multi-step"
632
+ " mode. To proceed, set `single_iteration` to `True`."
633
+ )
560
634
 
561
- response.choices[0].message.tool_calls = [
562
- ChatCompletionMessageToolCall(
563
- id=str(uuid.uuid4()),
564
- function=Function(
565
- arguments=str(parsed_content["arguments"]).replace(
566
- "'", '"'
567
- ),
568
- name=str(parsed_content["function"]),
569
- ),
570
- type="function",
571
- )
572
- ]
635
+ # Record function calls made during the session
636
+ tool_call_records: List[FunctionCallingRecord] = []
573
637
 
574
- # Check for external tool call
575
- tool_call_request = response.choices[0].message.tool_calls[0]
576
- if tool_call_request.function.name in self.external_tool_names:
577
- # if model calls an external tool, directly return the
578
- # request
579
- info = self._step_get_info(
580
- output_messages,
581
- finish_reasons,
582
- usage_dict,
583
- response_id,
584
- tool_call_records,
585
- num_tokens,
586
- tool_call_request,
587
- )
588
- return ChatAgentResponse(
589
- msgs=output_messages,
590
- terminated=self.terminated,
591
- info=info,
592
- )
638
+ external_tool_request = None
593
639
 
594
- # Normal function calling
595
- tool_call_records.append(
596
- self._step_tool_call_and_update(response)
640
+ while True:
641
+ try:
642
+ openai_messages, num_tokens = self.memory.get_context()
643
+ except RuntimeError as e:
644
+ self.model_backend.model_config_dict = self.original_model_dict
645
+ return self._step_token_exceed(
646
+ e.args[1], tool_call_records, "max_tokens_exceeded"
597
647
  )
598
648
 
599
- if response_format is not None:
600
- (
601
- output_messages,
602
- finish_reasons,
603
- usage_dict,
604
- response_id,
605
- tool_call,
606
- num_tokens,
607
- ) = self._structure_output_with_function(response_format)
608
- tool_call_records.append(tool_call)
649
+ # Prompt engineering approach for structured output for non-native tool calling models
650
+ inject_prompt_for_structured_output = (
651
+ response_format
652
+ and not self.model_type.support_native_structured_output
653
+ )
654
+
655
+ if inject_prompt_for_structured_output:
656
+ # update last openai message
657
+ usr_msg = openai_messages.pop()
658
+ usr_msg["content"] = generate_prompt_for_structured_output(
659
+ response_format,
660
+ usr_msg["content"], # type: ignore [arg-type]
661
+ )
662
+ openai_messages.append(usr_msg)
609
663
 
610
- info = self._step_get_info(
664
+ # Process model response
665
+ (
666
+ response,
611
667
  output_messages,
612
668
  finish_reasons,
613
669
  usage_dict,
614
670
  response_id,
615
- tool_call_records,
616
- num_tokens,
617
- )
618
-
619
- if len(output_messages) == 1:
620
- # Auto record if the output result is a single message
621
- self.record_message(output_messages[0])
622
- else:
623
- logger.warning(
624
- "Multiple messages returned in `step()`, message won't be "
625
- "recorded automatically. Please call `record_message()` "
626
- "to record the selected message manually."
627
- )
628
-
629
- return ChatAgentResponse(
630
- msgs=output_messages, terminated=self.terminated, info=info
631
- )
632
-
633
- else:
634
- self.update_memory(input_message, OpenAIBackendRole.USER)
671
+ ) = self._step_model_response(openai_messages, num_tokens)
635
672
 
636
- tool_call_records: List[FunctionCallingRecord] = [] # type: ignore[no-redef]
637
- while True:
638
- # Check if token has exceeded
673
+ # Try to parse structured output to return a Pydantic object
674
+ if inject_prompt_for_structured_output and isinstance(
675
+ response, ChatCompletion
676
+ ):
677
+ content = response.choices[0].message.content
639
678
  try:
640
- openai_messages, num_tokens = self.memory.get_context()
641
- except RuntimeError as e:
642
- self.model_backend.model_config_dict = original_model_dict
643
- return self._step_token_exceed(
644
- e.args[1], tool_call_records, "max_tokens_exceeded"
679
+ json_content = json.loads(str(content))
680
+ output_messages[0].parsed = response_format(**json_content) # type: ignore [assignment, misc]
681
+ except json.JSONDecodeError as e:
682
+ logger.error(
683
+ f"Failed in parsing the output into JSON: {e}"
684
+ )
685
+ output_messages[0].parsed = None
686
+ except ValidationError as e:
687
+ logger.warning(
688
+ "Successfully generating JSON response, "
689
+ "but failed in parsing it into Pydantic object :"
690
+ f"{e}, return the JSON response in parsed field"
645
691
  )
692
+ output_messages[0].parsed = json_content
646
693
 
647
- (
648
- response,
649
- output_messages,
650
- finish_reasons,
651
- usage_dict,
652
- response_id,
653
- ) = self._step_model_response(openai_messages, num_tokens)
654
- # If the model response is not a function call, meaning the
655
- # model has generated a message response, break the loop
656
- if (
657
- not self.is_tools_added()
658
- or not isinstance(response, ChatCompletion)
659
- or not response.choices[0].message.tool_calls
660
- ):
661
- break
662
-
663
- # Check for external tool call
664
- tool_call_request = response.choices[0].message.tool_calls[0]
665
-
666
- if tool_call_request.function.name in self.external_tool_names:
667
- # if model calls an external tool, directly return the
668
- # request
694
+ # Finalize on standard response in multi-step mode
695
+ if self._is_standard_response(response):
696
+ break
697
+
698
+ # Handle tool requests
699
+ tool_request = self._extract_tool_call(response)
700
+ if isinstance(response, ChatCompletion) and tool_request:
701
+ response.choices[0].message.tool_calls = [tool_request]
702
+ tool_call_records.append(
703
+ self._step_tool_call_and_update(response)
704
+ )
705
+
706
+ if tool_request.function.name in self.external_tool_names:
707
+ external_tool_request = tool_request
669
708
  info = self._step_get_info(
670
709
  output_messages,
671
710
  finish_reasons,
@@ -673,58 +712,132 @@ class ChatAgent(BaseAgent):
673
712
  response_id,
674
713
  tool_call_records,
675
714
  num_tokens,
676
- tool_call_request,
715
+ tool_request,
716
+ )
717
+ self._log_final_output(output_messages)
718
+ self.model_backend.model_config_dict = (
719
+ self.original_model_dict
677
720
  )
678
-
679
- self.model_backend.model_config_dict = original_model_dict
680
721
  return ChatAgentResponse(
681
722
  msgs=output_messages,
682
723
  terminated=self.terminated,
683
724
  info=info,
684
725
  )
685
726
 
686
- # Normal function calling
687
- tool_call_records.append(
688
- self._step_tool_call_and_update(response)
689
- )
690
-
691
- if (
692
- response_format is not None
693
- and self.model_type.support_native_tool_calling
694
- and self.model_type not in {"gpt-4o", "gpt-4o-mini"}
695
- ):
696
- (
697
- output_messages,
698
- finish_reasons,
699
- usage_dict,
700
- response_id,
701
- tool_call,
702
- num_tokens,
703
- ) = self._structure_output_with_function(response_format)
704
- tool_call_records.append(tool_call)
727
+ # Single-step mode ends after one iteration
728
+ if single_step:
729
+ break
705
730
 
706
- info = self._step_get_info(
731
+ # Optional structured output via function calling
732
+ if (
733
+ response_format
734
+ and not inject_prompt_for_structured_output
735
+ and self.model_type
736
+ not in {
737
+ "gpt-4o",
738
+ "gpt-4o-mini",
739
+ }
740
+ ):
741
+ (
707
742
  output_messages,
708
743
  finish_reasons,
709
744
  usage_dict,
710
745
  response_id,
711
- tool_call_records,
746
+ tool_call,
712
747
  num_tokens,
713
- )
748
+ ) = self._structure_output_with_function(response_format)
749
+ tool_call_records.append(tool_call)
750
+
751
+ # Final info and response
752
+ info = self._step_get_info(
753
+ output_messages,
754
+ finish_reasons,
755
+ usage_dict,
756
+ response_id,
757
+ tool_call_records,
758
+ num_tokens,
759
+ external_tool_request,
760
+ )
761
+ self._log_final_output(output_messages)
762
+ self.model_backend.model_config_dict = self.original_model_dict
763
+ return ChatAgentResponse(
764
+ msgs=output_messages, terminated=self.terminated, info=info
765
+ )
766
+
767
+ def _extract_tool_call(
768
+ self, response: Any
769
+ ) -> Optional[ChatCompletionMessageToolCall]:
770
+ r"""Extract the tool call from the model response, if present.
771
+
772
+ Args:
773
+ response (Any): The model's response object.
714
774
 
715
- if len(output_messages) == 1:
716
- # Auto record if the output result is a single message
717
- self.record_message(output_messages[0])
718
- else:
719
- logger.warning(
720
- "Multiple messages returned in `step()`, message won't be "
721
- "recorded automatically. Please call `record_message()` "
722
- "to record the selected message manually."
775
+ Returns:
776
+ Optional[ChatCompletionMessageToolCall]: The parsed tool call if
777
+ present, otherwise None.
778
+ """
779
+ # Check if the response contains tool calls
780
+ if (
781
+ self.is_tools_added()
782
+ and not self.model_type.support_native_tool_calling
783
+ and "</function>" in response.choices[0].message.content
784
+ ):
785
+ parsed_content = self._parse_tool_response(
786
+ response.choices[0].message.content
787
+ )
788
+ if parsed_content:
789
+ return ChatCompletionMessageToolCall(
790
+ id=str(uuid.uuid4()),
791
+ function=Function(
792
+ arguments=str(parsed_content["arguments"]).replace(
793
+ "'", '"'
794
+ ),
795
+ name=str(parsed_content["function"]),
796
+ ),
797
+ type="function",
723
798
  )
799
+ elif (
800
+ self.is_tools_added()
801
+ and self.model_type.support_native_tool_calling
802
+ and response.choices[0].message.tool_calls
803
+ ):
804
+ return response.choices[0].message.tool_calls[0]
805
+
806
+ # No tool call found
807
+ return None
808
+
809
+ def _is_standard_response(self, response: Any) -> bool:
810
+ r"""Determine if the provided response is a standard reply without
811
+ tool calls.
812
+
813
+ Args:
814
+ response (Any): The response object to evaluate.
815
+
816
+ Returns:
817
+ bool: `True` if the response is a standard reply, `False`
818
+ otherwise.
819
+ """
820
+ if not self.is_tools_added():
821
+ return True
724
822
 
725
- self.model_backend.model_config_dict = original_model_dict
726
- return ChatAgentResponse(
727
- msgs=output_messages, terminated=self.terminated, info=info
823
+ if not isinstance(response, ChatCompletion):
824
+ return True
825
+
826
+ if self.model_type.support_native_tool_calling:
827
+ return response.choices[0].message.tool_calls is None
828
+
829
+ return "</function>" not in str(
830
+ response.choices[0].message.content or ""
831
+ )
832
+
833
+ def _log_final_output(self, output_messages: List[BaseMessage]) -> None:
834
+ r"""Log final messages or warnings about multiple responses."""
835
+ if len(output_messages) == 1:
836
+ self.record_message(output_messages[0])
837
+ else:
838
+ logger.warning(
839
+ "Multiple messages returned in `step()`. Record "
840
+ "selected message manually using `record_message()`."
728
841
  )
729
842
 
730
843
  async def step_async(
@@ -740,7 +853,8 @@ class ChatAgent(BaseAgent):
740
853
  agent. For BaseMessage input, its `role` field that specifies
741
854
  the role at backend may be either `user` or `assistant` but it
742
855
  will be set to `user` anyway since for the self agent any
743
- incoming message is external. For str input, the `role_name` would be `User`.
856
+ incoming message is external. For str input, the `role_name`
857
+ would be `User`.
744
858
  response_format (Optional[Type[BaseModel]], optional): A pydantic
745
859
  model class that includes value types and field descriptions
746
860
  used to generate a structured response by LLM. This schema
@@ -779,13 +893,13 @@ class ChatAgent(BaseAgent):
779
893
  if (
780
894
  not self.is_tools_added()
781
895
  or not isinstance(response, ChatCompletion)
782
- or response.choices[0].message.tool_calls is None
896
+ or not response.choices[0].message.tool_calls
783
897
  ):
784
898
  break
785
899
 
786
900
  # Check for external tool call
787
- tool_call_request = response.choices[0].message.tool_calls[0]
788
- if tool_call_request.function.name in self.external_tool_names:
901
+ external_tool_request = response.choices[0].message.tool_calls[0]
902
+ if external_tool_request.function.name in self.external_tool_names:
789
903
  # if model calls an external tool, directly return the request
790
904
  info = self._step_get_info(
791
905
  output_messages,
@@ -794,7 +908,7 @@ class ChatAgent(BaseAgent):
794
908
  response_id,
795
909
  tool_call_records,
796
910
  num_tokens,
797
- tool_call_request,
911
+ external_tool_request,
798
912
  )
799
913
  return ChatAgentResponse(
800
914
  msgs=output_messages, terminated=self.terminated, info=info
@@ -859,7 +973,7 @@ class ChatAgent(BaseAgent):
859
973
 
860
974
  # Perform function calling
861
975
  func_assistant_msg, func_result_msg, tool_call_record = (
862
- self.step_tool_call(response)
976
+ self._step_tool_call(response)
863
977
  )
864
978
 
865
979
  # Update the messages
@@ -913,11 +1027,9 @@ class ChatAgent(BaseAgent):
913
1027
  func_callable = func_string_to_callable(func_str)
914
1028
  func = FunctionTool(func_callable)
915
1029
 
916
- original_func_dict = self.func_dict
917
1030
  original_model_dict = self.model_backend.model_config_dict
918
1031
 
919
1032
  # Replace the original tools with the structuring function
920
- self.func_dict = {func.get_function_name(): func.func}
921
1033
  self.tool_dict = {func.get_function_name(): func}
922
1034
  self.model_backend.model_config_dict = original_model_dict.copy()
923
1035
  self.model_backend.model_config_dict["tools"] = [
@@ -945,7 +1057,6 @@ class ChatAgent(BaseAgent):
945
1057
  base_message_item.content = json.dumps(tool_call_record.result)
946
1058
 
947
1059
  # Recover the original tools
948
- self.func_dict = original_func_dict
949
1060
  self.model_backend.model_config_dict = original_model_dict
950
1061
 
951
1062
  return (
@@ -1245,7 +1356,7 @@ class ChatAgent(BaseAgent):
1245
1356
  info=info,
1246
1357
  )
1247
1358
 
1248
- def step_tool_call(
1359
+ def _step_tool_call(
1249
1360
  self,
1250
1361
  response: ChatCompletion,
1251
1362
  ) -> Tuple[
@@ -1268,7 +1379,9 @@ class ChatAgent(BaseAgent):
1268
1379
  raise RuntimeError("Tool call is None")
1269
1380
  func_name = choice.message.tool_calls[0].function.name
1270
1381
 
1271
- args = json.loads(choice.message.tool_calls[0].function.arguments)
1382
+ arguments_str = choice.message.tool_calls[0].function.arguments
1383
+ args = self._safe_json_loads(arguments_str)
1384
+
1272
1385
  tool = self.tool_dict[func_name]
1273
1386
  result = tool(**args)
1274
1387
 
@@ -1295,6 +1408,18 @@ class ChatAgent(BaseAgent):
1295
1408
  )
1296
1409
  return assist_msg, func_msg, func_record
1297
1410
 
1411
+ def _safe_json_loads(self, arguments_str):
1412
+ # Replace Python types with their JSON equivalents
1413
+ arguments_str = arguments_str.replace("None", "null")
1414
+ arguments_str = arguments_str.replace("True", "true")
1415
+ arguments_str = arguments_str.replace("False", "false")
1416
+
1417
+ # Attempt to parse the corrected string
1418
+ try:
1419
+ return json.loads(arguments_str)
1420
+ except json.JSONDecodeError as e:
1421
+ raise ValueError(f"Invalid JSON format: {e}")
1422
+
1298
1423
  async def step_tool_call_async(
1299
1424
  self,
1300
1425
  response: ChatCompletion,