camel-ai 0.2.22__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic.
- camel/__init__.py +1 -1
- camel/agents/_types.py +41 -0
- camel/agents/_utils.py +188 -0
- camel/agents/chat_agent.py +570 -965
- camel/agents/knowledge_graph_agent.py +7 -1
- camel/agents/multi_hop_generator_agent.py +1 -1
- camel/configs/base_config.py +10 -13
- camel/configs/deepseek_config.py +4 -30
- camel/configs/gemini_config.py +5 -31
- camel/configs/openai_config.py +14 -32
- camel/configs/qwen_config.py +36 -36
- camel/datagen/self_improving_cot.py +81 -3
- camel/datagen/self_instruct/filter/instruction_filter.py +19 -3
- camel/datagen/self_instruct/self_instruct.py +52 -3
- camel/datasets/__init__.py +28 -0
- camel/datasets/base.py +969 -0
- camel/environments/__init__.py +16 -0
- camel/environments/base.py +503 -0
- camel/extractors/__init__.py +16 -0
- camel/extractors/base.py +263 -0
- camel/memories/agent_memories.py +16 -1
- camel/memories/blocks/chat_history_block.py +10 -2
- camel/memories/blocks/vectordb_block.py +1 -0
- camel/memories/context_creators/score_based.py +20 -3
- camel/memories/records.py +10 -0
- camel/messages/base.py +8 -8
- camel/models/__init__.py +2 -0
- camel/models/_utils.py +57 -0
- camel/models/aiml_model.py +48 -17
- camel/models/anthropic_model.py +41 -3
- camel/models/azure_openai_model.py +39 -3
- camel/models/base_audio_model.py +92 -0
- camel/models/base_model.py +88 -13
- camel/models/cohere_model.py +88 -11
- camel/models/deepseek_model.py +107 -45
- camel/models/fish_audio_model.py +18 -8
- camel/models/gemini_model.py +133 -15
- camel/models/groq_model.py +72 -10
- camel/models/internlm_model.py +14 -3
- camel/models/litellm_model.py +9 -2
- camel/models/mistral_model.py +42 -5
- camel/models/model_manager.py +57 -3
- camel/models/moonshot_model.py +33 -4
- camel/models/nemotron_model.py +32 -3
- camel/models/nvidia_model.py +43 -3
- camel/models/ollama_model.py +139 -17
- camel/models/openai_audio_models.py +87 -2
- camel/models/openai_compatible_model.py +37 -3
- camel/models/openai_model.py +158 -46
- camel/models/qwen_model.py +61 -4
- camel/models/reka_model.py +53 -3
- camel/models/samba_model.py +209 -4
- camel/models/sglang_model.py +153 -14
- camel/models/siliconflow_model.py +16 -3
- camel/models/stub_model.py +46 -4
- camel/models/togetherai_model.py +38 -3
- camel/models/vllm_model.py +37 -3
- camel/models/yi_model.py +36 -3
- camel/models/zhipuai_model.py +38 -3
- camel/retrievers/__init__.py +3 -0
- camel/retrievers/hybrid_retrival.py +237 -0
- camel/toolkits/__init__.py +15 -1
- camel/toolkits/arxiv_toolkit.py +2 -1
- camel/toolkits/ask_news_toolkit.py +4 -2
- camel/toolkits/audio_analysis_toolkit.py +238 -0
- camel/toolkits/base.py +22 -3
- camel/toolkits/code_execution.py +2 -0
- camel/toolkits/dappier_toolkit.py +2 -1
- camel/toolkits/data_commons_toolkit.py +38 -12
- camel/toolkits/excel_toolkit.py +172 -0
- camel/toolkits/function_tool.py +13 -0
- camel/toolkits/github_toolkit.py +5 -1
- camel/toolkits/google_maps_toolkit.py +2 -1
- camel/toolkits/google_scholar_toolkit.py +2 -0
- camel/toolkits/human_toolkit.py +0 -3
- camel/toolkits/image_analysis_toolkit.py +202 -0
- camel/toolkits/linkedin_toolkit.py +3 -2
- camel/toolkits/meshy_toolkit.py +3 -2
- camel/toolkits/mineru_toolkit.py +2 -2
- camel/toolkits/networkx_toolkit.py +240 -0
- camel/toolkits/notion_toolkit.py +2 -0
- camel/toolkits/openbb_toolkit.py +3 -2
- camel/toolkits/page_script.js +376 -0
- camel/toolkits/reddit_toolkit.py +11 -3
- camel/toolkits/retrieval_toolkit.py +6 -1
- camel/toolkits/semantic_scholar_toolkit.py +2 -1
- camel/toolkits/stripe_toolkit.py +8 -2
- camel/toolkits/sympy_toolkit.py +6 -1
- camel/toolkits/video_analysis_toolkit.py +407 -0
- camel/toolkits/{video_toolkit.py → video_download_toolkit.py} +21 -25
- camel/toolkits/web_toolkit.py +1307 -0
- camel/toolkits/whatsapp_toolkit.py +3 -2
- camel/toolkits/zapier_toolkit.py +191 -0
- camel/types/__init__.py +2 -2
- camel/types/agents/__init__.py +16 -0
- camel/types/agents/tool_calling_record.py +52 -0
- camel/types/enums.py +3 -0
- camel/types/openai_types.py +16 -14
- camel/utils/__init__.py +2 -1
- camel/utils/async_func.py +2 -2
- camel/utils/commons.py +114 -1
- camel/verifiers/__init__.py +23 -0
- camel/verifiers/base.py +340 -0
- camel/verifiers/models.py +82 -0
- camel/verifiers/python_verifier.py +202 -0
- camel_ai-0.2.23.dist-info/METADATA +671 -0
- {camel_ai-0.2.22.dist-info → camel_ai-0.2.23.dist-info}/RECORD +122 -97
- {camel_ai-0.2.22.dist-info → camel_ai-0.2.23.dist-info}/WHEEL +1 -1
- camel_ai-0.2.22.dist-info/METADATA +0 -527
- {camel_ai-0.2.22.dist-info → camel_ai-0.2.23.dist-info/licenses}/LICENSE +0 -0
camel/agents/chat_agent.py
CHANGED
|
@@ -15,8 +15,7 @@ from __future__ import annotations
|
|
|
15
15
|
|
|
16
16
|
import json
|
|
17
17
|
import logging
|
|
18
|
-
import
|
|
19
|
-
import uuid
|
|
18
|
+
import textwrap
|
|
20
19
|
from collections import defaultdict
|
|
21
20
|
from typing import (
|
|
22
21
|
TYPE_CHECKING,
|
|
@@ -25,15 +24,25 @@ from typing import (
|
|
|
25
24
|
Dict,
|
|
26
25
|
List,
|
|
27
26
|
Optional,
|
|
28
|
-
|
|
27
|
+
Set,
|
|
29
28
|
Type,
|
|
30
29
|
Union,
|
|
31
30
|
)
|
|
32
31
|
|
|
33
|
-
from openai
|
|
34
|
-
|
|
32
|
+
from openai import (
|
|
33
|
+
AsyncStream,
|
|
34
|
+
Stream,
|
|
35
|
+
)
|
|
35
36
|
from pydantic import BaseModel, ValidationError
|
|
36
37
|
|
|
38
|
+
from camel.agents._types import ModelResponse, ToolCallRequest
|
|
39
|
+
from camel.agents._utils import (
|
|
40
|
+
convert_to_function_tool,
|
|
41
|
+
convert_to_schema,
|
|
42
|
+
get_info_dict,
|
|
43
|
+
handle_logprobs,
|
|
44
|
+
safe_model_dump,
|
|
45
|
+
)
|
|
37
46
|
from camel.agents.base import BaseAgent
|
|
38
47
|
from camel.memories import (
|
|
39
48
|
AgentMemory,
|
|
@@ -48,7 +57,9 @@ from camel.models import (
|
|
|
48
57
|
ModelManager,
|
|
49
58
|
ModelProcessingError,
|
|
50
59
|
)
|
|
60
|
+
from camel.prompts import TextPrompt
|
|
51
61
|
from camel.responses import ChatAgentResponse
|
|
62
|
+
from camel.toolkits import FunctionTool
|
|
52
63
|
from camel.types import (
|
|
53
64
|
ChatCompletion,
|
|
54
65
|
ChatCompletionChunk,
|
|
@@ -57,19 +68,11 @@ from camel.types import (
|
|
|
57
68
|
OpenAIBackendRole,
|
|
58
69
|
RoleType,
|
|
59
70
|
)
|
|
60
|
-
from camel.
|
|
61
|
-
|
|
62
|
-
generate_prompt_for_structured_output,
|
|
63
|
-
get_model_encoding,
|
|
64
|
-
get_pydantic_object_schema,
|
|
65
|
-
json_to_function_code,
|
|
66
|
-
)
|
|
71
|
+
from camel.types.agents import ToolCallingRecord
|
|
72
|
+
from camel.utils import get_model_encoding
|
|
67
73
|
|
|
68
74
|
if TYPE_CHECKING:
|
|
69
|
-
from openai import Stream
|
|
70
|
-
|
|
71
75
|
from camel.terminators import ResponseTerminator
|
|
72
|
-
from camel.toolkits import FunctionTool
|
|
73
76
|
|
|
74
77
|
|
|
75
78
|
logger = logging.getLogger(__name__)
|
|
@@ -86,41 +89,15 @@ except (ImportError, AttributeError):
|
|
|
86
89
|
from camel.utils import track_agent
|
|
87
90
|
|
|
88
91
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
the tools.
|
|
96
|
-
result (Any): The execution result of calling this tool.
|
|
97
|
-
tool_call_id (str): The ID of the tool call, if available.
|
|
98
|
-
"""
|
|
99
|
-
|
|
100
|
-
tool_name: str
|
|
101
|
-
args: Dict[str, Any]
|
|
102
|
-
result: Any
|
|
103
|
-
tool_call_id: str
|
|
104
|
-
|
|
105
|
-
def __str__(self) -> str:
|
|
106
|
-
r"""Overridden version of the string function.
|
|
107
|
-
|
|
108
|
-
Returns:
|
|
109
|
-
str: Modified string to represent the tool calling.
|
|
110
|
-
"""
|
|
111
|
-
return (
|
|
112
|
-
f"Tool Execution: {self.tool_name}\n"
|
|
113
|
-
f"\tArgs: {self.args}\n"
|
|
114
|
-
f"\tResult: {self.result}\n"
|
|
115
|
-
)
|
|
116
|
-
|
|
117
|
-
def as_dict(self) -> dict[str, Any]:
|
|
118
|
-
r"""Returns the function calling record as a dictionary.
|
|
119
|
-
|
|
120
|
-
Returns:
|
|
121
|
-
dict[str, Any]: The function calling record as a dictionary.
|
|
92
|
+
SIMPLE_FORMAT_PROMPT = TextPrompt(
|
|
93
|
+
textwrap.dedent(
|
|
94
|
+
"""\
|
|
95
|
+
Please format the following content:
|
|
96
|
+
|
|
97
|
+
{content}
|
|
122
98
|
"""
|
|
123
|
-
|
|
99
|
+
)
|
|
100
|
+
)
|
|
124
101
|
|
|
125
102
|
|
|
126
103
|
@track_agent(name="ChatAgent")
|
|
@@ -148,11 +125,12 @@ class ChatAgent(BaseAgent):
|
|
|
148
125
|
tools (Optional[List[Union[FunctionTool, Callable]]], optional): List
|
|
149
126
|
of available :obj:`FunctionTool` or :obj:`Callable`. (default:
|
|
150
127
|
:obj:`None`)
|
|
151
|
-
external_tools (Optional[List[Union[FunctionTool, Callable
|
|
152
|
-
optional): List of external tools
|
|
153
|
-
:obj:`
|
|
154
|
-
|
|
155
|
-
processing it.
|
|
128
|
+
external_tools (Optional[List[Union[FunctionTool, Callable,
|
|
129
|
+
Dict[str, Any]]]], optional): List of external tools
|
|
130
|
+
(:obj:`FunctionTool` or :obj:`Callable` or :obj:`Dict[str, Any]`)
|
|
131
|
+
bind to one chat agent. When these tools are called, the agent will
|
|
132
|
+
directly return the request instead of processing it.
|
|
133
|
+
(default: :obj:`None`)
|
|
156
134
|
response_terminators (List[ResponseTerminator], optional): List of
|
|
157
135
|
:obj:`ResponseTerminator` bind to one chat agent.
|
|
158
136
|
(default: :obj:`None`)
|
|
@@ -173,265 +151,164 @@ class ChatAgent(BaseAgent):
|
|
|
173
151
|
token_limit: Optional[int] = None,
|
|
174
152
|
output_language: Optional[str] = None,
|
|
175
153
|
tools: Optional[List[Union[FunctionTool, Callable]]] = None,
|
|
176
|
-
external_tools: Optional[
|
|
154
|
+
external_tools: Optional[
|
|
155
|
+
List[Union[FunctionTool, Callable, Dict[str, Any]]]
|
|
156
|
+
] = None,
|
|
177
157
|
response_terminators: Optional[List[ResponseTerminator]] = None,
|
|
178
158
|
scheduling_strategy: str = "round_robin",
|
|
179
159
|
single_iteration: bool = False,
|
|
180
160
|
) -> None:
|
|
181
|
-
#
|
|
182
|
-
if isinstance(system_message, str):
|
|
183
|
-
system_message = BaseMessage.make_assistant_message(
|
|
184
|
-
role_name='Assistant', content=system_message
|
|
185
|
-
)
|
|
186
|
-
|
|
187
|
-
self.orig_sys_message: Optional[BaseMessage] = system_message
|
|
188
|
-
self._system_message: Optional[BaseMessage] = system_message
|
|
189
|
-
self.role_name: str = (
|
|
190
|
-
getattr(system_message, 'role_name', None) or "assistant"
|
|
191
|
-
)
|
|
192
|
-
self.role_type: RoleType = (
|
|
193
|
-
getattr(system_message, 'role_type', None) or RoleType.ASSISTANT
|
|
194
|
-
)
|
|
161
|
+
# Set up model backend
|
|
195
162
|
self.model_backend = ModelManager(
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
163
|
+
(
|
|
164
|
+
model
|
|
165
|
+
if model is not None
|
|
166
|
+
else ModelFactory.create(
|
|
167
|
+
model_platform=ModelPlatformType.DEFAULT,
|
|
168
|
+
model_type=ModelType.DEFAULT,
|
|
169
|
+
)
|
|
201
170
|
),
|
|
202
171
|
scheduling_strategy=scheduling_strategy,
|
|
203
172
|
)
|
|
204
173
|
self.model_type = self.model_backend.model_type
|
|
205
174
|
|
|
206
|
-
#
|
|
207
|
-
self.tools: List[FunctionTool] = (
|
|
208
|
-
self._initialize_tools(tools) if tools else []
|
|
209
|
-
)
|
|
210
|
-
self.external_tools: List[FunctionTool] = (
|
|
211
|
-
self._initialize_tools(external_tools) if external_tools else []
|
|
212
|
-
)
|
|
213
|
-
self.external_tool_names: List[str] = [
|
|
214
|
-
tool.get_function_name() for tool in self.external_tools
|
|
215
|
-
]
|
|
216
|
-
self.all_tools = self.tools + self.external_tools or []
|
|
217
|
-
|
|
218
|
-
# Create tool dictionaries and configure backend tools if necessary
|
|
219
|
-
self.tool_dict = {
|
|
220
|
-
tool.get_function_name(): tool for tool in self.all_tools
|
|
221
|
-
}
|
|
222
|
-
|
|
223
|
-
# If the user set tools from `ChatAgent`, it will override the
|
|
224
|
-
# configured tools in `BaseModelBackend`.
|
|
225
|
-
if self.all_tools:
|
|
226
|
-
logger.warning(
|
|
227
|
-
"Overriding the configured tools in `BaseModelBackend` with the tools from `ChatAgent`."
|
|
228
|
-
)
|
|
229
|
-
tool_schema_list = [
|
|
230
|
-
tool.get_openai_tool_schema() for tool in self.all_tools
|
|
231
|
-
]
|
|
232
|
-
self.model_backend.model_config_dict['tools'] = tool_schema_list
|
|
233
|
-
|
|
234
|
-
self.model_token_limit = token_limit or self.model_backend.token_limit
|
|
175
|
+
# Set up memory
|
|
235
176
|
context_creator = ScoreBasedContextCreator(
|
|
236
177
|
self.model_backend.token_counter,
|
|
237
|
-
self.
|
|
178
|
+
token_limit or self.model_backend.token_limit,
|
|
238
179
|
)
|
|
239
180
|
self.memory: AgentMemory = memory or ChatHistoryMemory(
|
|
240
181
|
context_creator, window_size=message_window_size
|
|
241
182
|
)
|
|
242
183
|
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
184
|
+
# Set up system message and initialize messages
|
|
185
|
+
self._original_system_message = (
|
|
186
|
+
BaseMessage.make_assistant_message(
|
|
187
|
+
role_name="Assistant", content=system_message
|
|
188
|
+
)
|
|
189
|
+
if isinstance(system_message, str)
|
|
190
|
+
else system_message
|
|
191
|
+
)
|
|
192
|
+
self._output_language = output_language
|
|
193
|
+
self._system_message = (
|
|
194
|
+
self._generate_system_message_for_output_language()
|
|
195
|
+
)
|
|
249
196
|
self.init_messages()
|
|
250
|
-
self.tool_prompt_added = False
|
|
251
|
-
self.single_iteration = single_iteration
|
|
252
197
|
|
|
253
|
-
|
|
254
|
-
self
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
if not isinstance(tool, FunctionTool):
|
|
262
|
-
tool = FunctionTool(tool)
|
|
263
|
-
func_tools.append(tool)
|
|
264
|
-
return func_tools
|
|
265
|
-
|
|
266
|
-
def add_tool(
|
|
267
|
-
self, tool: Union[FunctionTool, Callable], is_external: bool = False
|
|
268
|
-
) -> None:
|
|
269
|
-
r"""Add a tool to the agent, specifying if it's an external tool."""
|
|
270
|
-
# Initialize the tool
|
|
271
|
-
initialized_tool = self._initialize_tools([tool])
|
|
272
|
-
|
|
273
|
-
# Update tools or external tools based on is_external flag
|
|
274
|
-
if is_external:
|
|
275
|
-
self.external_tools = self.external_tools + initialized_tool
|
|
276
|
-
self.external_tool_names.extend(
|
|
277
|
-
tool.get_function_name() for tool in initialized_tool
|
|
278
|
-
)
|
|
279
|
-
else:
|
|
280
|
-
self.tools = self.tools + initialized_tool
|
|
198
|
+
# Set up role name and role type
|
|
199
|
+
self.role_name: str = (
|
|
200
|
+
getattr(self.system_message, "role_name", None) or "assistant"
|
|
201
|
+
)
|
|
202
|
+
self.role_type: RoleType = (
|
|
203
|
+
getattr(self.system_message, "role_type", None)
|
|
204
|
+
or RoleType.ASSISTANT
|
|
205
|
+
)
|
|
281
206
|
|
|
282
|
-
#
|
|
283
|
-
self.
|
|
284
|
-
|
|
285
|
-
|
|
207
|
+
# Set up tools
|
|
208
|
+
self._internal_tools = {
|
|
209
|
+
tool.get_function_name(): tool
|
|
210
|
+
for tool in [
|
|
211
|
+
convert_to_function_tool(tool) for tool in (tools or [])
|
|
212
|
+
]
|
|
286
213
|
}
|
|
287
214
|
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
215
|
+
self._external_tool_schemas = {
|
|
216
|
+
tool_schema["function"]["name"]: tool_schema
|
|
217
|
+
for tool_schema in [
|
|
218
|
+
convert_to_schema(tool) for tool in (external_tools or [])
|
|
219
|
+
]
|
|
220
|
+
}
|
|
292
221
|
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
return False
|
|
222
|
+
# Set up other properties
|
|
223
|
+
self.terminated = False
|
|
224
|
+
self.response_terminators = response_terminators or []
|
|
225
|
+
self.single_iteration = single_iteration
|
|
298
226
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
self.all_tools = (self.tools or []) + (
|
|
306
|
-
self.external_tools or []
|
|
307
|
-
)
|
|
308
|
-
self.tool_dict = {
|
|
309
|
-
tool.get_function_name(): tool for tool in self.all_tools
|
|
310
|
-
}
|
|
311
|
-
tool_schema_list = [
|
|
312
|
-
tool.get_openai_tool_schema() for tool in self.all_tools
|
|
313
|
-
]
|
|
314
|
-
self.model_backend.model_config_dict['tools'] = (
|
|
315
|
-
tool_schema_list
|
|
316
|
-
)
|
|
317
|
-
return True
|
|
318
|
-
return False
|
|
227
|
+
def reset(self):
|
|
228
|
+
r"""Resets the :obj:`ChatAgent` to its initial state."""
|
|
229
|
+
self.terminated = False
|
|
230
|
+
self.init_messages()
|
|
231
|
+
for terminator in self.response_terminators:
|
|
232
|
+
terminator.reset()
|
|
319
233
|
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
]
|
|
325
|
-
external_tools = [
|
|
326
|
-
tool.get_function_name() for tool in (self.external_tools or [])
|
|
327
|
-
]
|
|
234
|
+
@property
|
|
235
|
+
def system_message(self) -> Optional[BaseMessage]:
|
|
236
|
+
r"""Returns the system message for the agent."""
|
|
237
|
+
return self._system_message
|
|
328
238
|
|
|
329
|
-
|
|
239
|
+
@property
|
|
240
|
+
def tool_dict(self) -> Dict[str, FunctionTool]:
|
|
241
|
+
r"""Returns a dictionary of internal tools."""
|
|
242
|
+
return self._internal_tools
|
|
330
243
|
|
|
331
|
-
|
|
332
|
-
def
|
|
333
|
-
r"""
|
|
244
|
+
@property
|
|
245
|
+
def output_language(self) -> Optional[str]:
|
|
246
|
+
r"""Returns the output language for the agent."""
|
|
247
|
+
return self._output_language
|
|
334
248
|
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
249
|
+
@output_language.setter
|
|
250
|
+
def output_language(self, value: str) -> None:
|
|
251
|
+
r"""Set the output language for the agent.
|
|
338
252
|
|
|
339
|
-
|
|
340
|
-
str: A string representing the tool prompt.
|
|
253
|
+
Note that this will clear the message history.
|
|
341
254
|
"""
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
tool_description = tool_info['description']
|
|
348
|
-
tool_json = json.dumps(tool_info, indent=4)
|
|
349
|
-
|
|
350
|
-
prompt = f"Use the function '{tool_name}' to '{tool_description}':\n{tool_json}\n"
|
|
351
|
-
tool_prompts.append(prompt)
|
|
352
|
-
|
|
353
|
-
tool_prompt_str = "\n".join(tool_prompts)
|
|
354
|
-
|
|
355
|
-
final_prompt = f"""
|
|
356
|
-
You have access to the following functions:
|
|
255
|
+
self._output_language = value
|
|
256
|
+
self._system_message = (
|
|
257
|
+
self._generate_system_message_for_output_language()
|
|
258
|
+
)
|
|
259
|
+
self.init_messages()
|
|
357
260
|
|
|
358
|
-
|
|
261
|
+
def _get_full_tool_schemas(self) -> List[Dict[str, Any]]:
|
|
262
|
+
r"""Returns a list of tool schemas of all tools, including internal
|
|
263
|
+
and external tools.
|
|
264
|
+
"""
|
|
265
|
+
return list(self._external_tool_schemas.values()) + [
|
|
266
|
+
func_tool.get_openai_tool_schema()
|
|
267
|
+
for func_tool in self._internal_tools.values()
|
|
268
|
+
]
|
|
359
269
|
|
|
360
|
-
|
|
361
|
-
|
|
270
|
+
def _get_external_tool_names(self) -> Set[str]:
|
|
271
|
+
r"""Returns a set of external tool names."""
|
|
272
|
+
return set(self._external_tool_schemas.keys())
|
|
362
273
|
|
|
363
|
-
|
|
274
|
+
def add_tool(self, tool: Union[FunctionTool, Callable]) -> None:
|
|
275
|
+
r"""Add a tool to the agent."""
|
|
276
|
+
new_tool = convert_to_function_tool(tool)
|
|
277
|
+
self._internal_tools[new_tool.get_function_name()] = new_tool
|
|
364
278
|
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
- If there is no function call available, answer the question like normal
|
|
371
|
-
with your current knowledge and do not tell the user about function calls
|
|
372
|
-
"""
|
|
373
|
-
return final_prompt
|
|
279
|
+
def add_external_tool(
|
|
280
|
+
self, tool: Union[FunctionTool, Callable, Dict[str, Any]]
|
|
281
|
+
) -> None:
|
|
282
|
+
new_tool_schema = convert_to_schema(tool)
|
|
283
|
+
self._external_tool_schemas[new_tool_schema["name"]] = new_tool_schema
|
|
374
284
|
|
|
375
|
-
def
|
|
376
|
-
r"""
|
|
377
|
-
arguments.
|
|
285
|
+
def remove_tool(self, tool_name: str) -> bool:
|
|
286
|
+
r"""Remove a tool from the agent by name.
|
|
378
287
|
|
|
379
288
|
Args:
|
|
380
|
-
|
|
381
|
-
function call.
|
|
289
|
+
tool_name (str): The name of the tool to remove.
|
|
382
290
|
|
|
383
291
|
Returns:
|
|
384
|
-
|
|
385
|
-
if found, otherwise :obj:`None`.
|
|
292
|
+
bool: Whether the tool was successfully removed.
|
|
386
293
|
"""
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
function_name, args_string = match.groups()
|
|
392
|
-
try:
|
|
393
|
-
args = json.loads(args_string)
|
|
394
|
-
return {"function": function_name, "arguments": args}
|
|
395
|
-
except json.JSONDecodeError as error:
|
|
396
|
-
logger.error(f"Error parsing function arguments: {error}")
|
|
397
|
-
return None
|
|
398
|
-
return None
|
|
399
|
-
|
|
400
|
-
def reset(self):
|
|
401
|
-
r"""Resets the :obj:`ChatAgent` to its initial state."""
|
|
402
|
-
self.terminated = False
|
|
403
|
-
self.init_messages()
|
|
404
|
-
for terminator in self.response_terminators:
|
|
405
|
-
terminator.reset()
|
|
406
|
-
|
|
407
|
-
@property
|
|
408
|
-
def system_message(self) -> Optional[BaseMessage]:
|
|
409
|
-
r"""The getter method for the property :obj:`system_message`.
|
|
410
|
-
|
|
411
|
-
Returns:
|
|
412
|
-
Optional[BaseMessage]: The system message of this agent if set,
|
|
413
|
-
else :obj:`None`.
|
|
414
|
-
"""
|
|
415
|
-
return self._system_message
|
|
294
|
+
if tool_name in self._internal_tools:
|
|
295
|
+
del self._internal_tools[tool_name]
|
|
296
|
+
return True
|
|
297
|
+
return False
|
|
416
298
|
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
r"""The setter method for the property :obj:`system_message`.
|
|
299
|
+
def remove_external_tool(self, tool_name: str) -> bool:
|
|
300
|
+
r"""Remove an external tool from the agent by name.
|
|
420
301
|
|
|
421
302
|
Args:
|
|
422
|
-
|
|
423
|
-
new system message of this agent.
|
|
424
|
-
"""
|
|
425
|
-
self._system_message = message
|
|
426
|
-
|
|
427
|
-
def is_tools_added(self) -> bool:
|
|
428
|
-
r"""Whether tool calling is enabled for this agent.
|
|
303
|
+
tool_name (str): The name of the tool to remove.
|
|
429
304
|
|
|
430
305
|
Returns:
|
|
431
|
-
bool: Whether tool
|
|
432
|
-
by whether the dictionary of tools is empty.
|
|
306
|
+
bool: Whether the tool was successfully removed.
|
|
433
307
|
"""
|
|
434
|
-
|
|
308
|
+
if tool_name in self._external_tool_schemas:
|
|
309
|
+
del self._external_tool_schemas[tool_name]
|
|
310
|
+
return True
|
|
311
|
+
return False
|
|
435
312
|
|
|
436
313
|
def update_memory(
|
|
437
314
|
self, message: BaseMessage, role: OpenAIBackendRole
|
|
@@ -447,94 +324,41 @@ class ChatAgent(BaseAgent):
|
|
|
447
324
|
MemoryRecord(message=message, role_at_backend=role)
|
|
448
325
|
)
|
|
449
326
|
|
|
450
|
-
def
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
generated.
|
|
327
|
+
def _generate_system_message_for_output_language(
|
|
328
|
+
self,
|
|
329
|
+
) -> Optional[BaseMessage]:
|
|
330
|
+
r"""Generate a new system message with the output language prompt.
|
|
455
331
|
|
|
456
|
-
|
|
457
|
-
|
|
332
|
+
The output language determines the language in which the output text
|
|
333
|
+
should be generated.
|
|
458
334
|
|
|
459
335
|
Returns:
|
|
460
|
-
BaseMessage: The
|
|
336
|
+
BaseMessage: The new system message.
|
|
461
337
|
"""
|
|
462
|
-
self.
|
|
338
|
+
if not self._output_language:
|
|
339
|
+
return self._original_system_message
|
|
340
|
+
|
|
463
341
|
language_prompt = (
|
|
464
342
|
"\nRegardless of the input language, "
|
|
465
|
-
f"you must output text in {
|
|
343
|
+
f"you must output text in {self._output_language}."
|
|
466
344
|
)
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
)
|
|
345
|
+
|
|
346
|
+
if self._original_system_message is not None:
|
|
347
|
+
content = self._original_system_message.content + language_prompt
|
|
348
|
+
return self._original_system_message.create_new_instance(content)
|
|
472
349
|
else:
|
|
473
|
-
|
|
350
|
+
return BaseMessage.make_assistant_message(
|
|
474
351
|
role_name="Assistant",
|
|
475
352
|
content=language_prompt,
|
|
476
353
|
)
|
|
477
354
|
|
|
478
|
-
system_record = MemoryRecord(
|
|
479
|
-
message=self._system_message,
|
|
480
|
-
role_at_backend=OpenAIBackendRole.SYSTEM,
|
|
481
|
-
)
|
|
482
|
-
self.memory.clear()
|
|
483
|
-
self.memory.write_record(system_record)
|
|
484
|
-
return self._system_message
|
|
485
|
-
|
|
486
|
-
def get_info(
|
|
487
|
-
self,
|
|
488
|
-
session_id: Optional[str],
|
|
489
|
-
usage: Optional[Dict[str, int]],
|
|
490
|
-
termination_reasons: List[str],
|
|
491
|
-
num_tokens: int,
|
|
492
|
-
tool_calls: List[ToolCallingRecord],
|
|
493
|
-
external_tool_request: Optional[ChatCompletionMessageToolCall] = None,
|
|
494
|
-
) -> Dict[str, Any]:
|
|
495
|
-
r"""Returns a dictionary containing information about the chat session.
|
|
496
|
-
|
|
497
|
-
Args:
|
|
498
|
-
session_id (str, optional): The ID of the chat session.
|
|
499
|
-
usage (Dict[str, int], optional): Information about the usage of
|
|
500
|
-
the LLM.
|
|
501
|
-
termination_reasons (List[str]): The reasons for the termination
|
|
502
|
-
of the chat session.
|
|
503
|
-
num_tokens (int): The number of tokens used in the chat session.
|
|
504
|
-
tool_calls (List[ToolCallingRecord]): The list of function
|
|
505
|
-
calling records, containing the information of called tools.
|
|
506
|
-
external_tool_request
|
|
507
|
-
(Optional[ChatCompletionMessageToolCall], optional):
|
|
508
|
-
The tool calling request of external tools from the model.
|
|
509
|
-
These requests are directly returned to the user instead of
|
|
510
|
-
being processed by the agent automatically.
|
|
511
|
-
(default: :obj:`None`)
|
|
512
|
-
|
|
513
|
-
Returns:
|
|
514
|
-
Dict[str, Any]: The chat session information.
|
|
515
|
-
"""
|
|
516
|
-
return {
|
|
517
|
-
"id": session_id,
|
|
518
|
-
"usage": usage,
|
|
519
|
-
"termination_reasons": termination_reasons,
|
|
520
|
-
"num_tokens": num_tokens,
|
|
521
|
-
"tool_calls": tool_calls,
|
|
522
|
-
"external_tool_request": external_tool_request,
|
|
523
|
-
}
|
|
524
|
-
|
|
525
355
|
def init_messages(self) -> None:
|
|
526
356
|
r"""Initializes the stored messages list with the current system
|
|
527
357
|
message.
|
|
528
358
|
"""
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
role_at_backend=OpenAIBackendRole.SYSTEM,
|
|
533
|
-
)
|
|
534
|
-
self.memory.clear()
|
|
535
|
-
self.memory.write_record(system_record)
|
|
536
|
-
else:
|
|
537
|
-
self.memory.clear()
|
|
359
|
+
self.memory.clear()
|
|
360
|
+
if self.system_message is not None:
|
|
361
|
+
self.update_memory(self.system_message, OpenAIBackendRole.SYSTEM)
|
|
538
362
|
|
|
539
363
|
def record_message(self, message: BaseMessage) -> None:
|
|
540
364
|
r"""Records the externally provided message into the agent memory as if
|
|
@@ -547,10 +371,82 @@ class ChatAgent(BaseAgent):
|
|
|
547
371
|
"""
|
|
548
372
|
self.update_memory(message, OpenAIBackendRole.ASSISTANT)
|
|
549
373
|
|
|
374
|
+
def _try_format_message(
|
|
375
|
+
self, message: BaseMessage, response_format: Type[BaseModel]
|
|
376
|
+
) -> bool:
|
|
377
|
+
r"""Try to format the message if needed.
|
|
378
|
+
|
|
379
|
+
Returns:
|
|
380
|
+
bool: Whether the message is formatted successfully (or no format
|
|
381
|
+
is needed).
|
|
382
|
+
"""
|
|
383
|
+
if message.parsed:
|
|
384
|
+
return True
|
|
385
|
+
|
|
386
|
+
try:
|
|
387
|
+
message.parsed = response_format.model_validate_json(
|
|
388
|
+
message.content
|
|
389
|
+
)
|
|
390
|
+
return True
|
|
391
|
+
except ValidationError:
|
|
392
|
+
return False
|
|
393
|
+
|
|
394
|
+
def _format_response_if_needed(
|
|
395
|
+
self,
|
|
396
|
+
response: ModelResponse,
|
|
397
|
+
response_format: Optional[Type[BaseModel]] = None,
|
|
398
|
+
) -> None:
|
|
399
|
+
r"""Format the response if needed.
|
|
400
|
+
|
|
401
|
+
This function won't format the response under the following cases:
|
|
402
|
+
1. The response format is None (not provided)
|
|
403
|
+
2. The response is empty
|
|
404
|
+
"""
|
|
405
|
+
if response_format is None:
|
|
406
|
+
return
|
|
407
|
+
|
|
408
|
+
for message in response.output_messages:
|
|
409
|
+
if self._try_format_message(message, response_format):
|
|
410
|
+
continue
|
|
411
|
+
|
|
412
|
+
prompt = SIMPLE_FORMAT_PROMPT.format(content=message.content)
|
|
413
|
+
openai_message: OpenAIMessage = {"role": "user", "content": prompt}
|
|
414
|
+
# Explicitly set the tools to empty list to avoid calling tools
|
|
415
|
+
response = self._get_model_response(
|
|
416
|
+
[openai_message], 0, response_format, []
|
|
417
|
+
)
|
|
418
|
+
message.content = response.output_messages[0].content
|
|
419
|
+
if not self._try_format_message(message, response_format):
|
|
420
|
+
logger.warning(f"Failed to parse response: {message.content}")
|
|
421
|
+
|
|
422
|
+
async def _aformat_response_if_needed(
|
|
423
|
+
self,
|
|
424
|
+
response: ModelResponse,
|
|
425
|
+
response_format: Optional[Type[BaseModel]] = None,
|
|
426
|
+
) -> None:
|
|
427
|
+
r"""Format the response if needed."""
|
|
428
|
+
|
|
429
|
+
if response_format is None:
|
|
430
|
+
return
|
|
431
|
+
|
|
432
|
+
for message in response.output_messages:
|
|
433
|
+
self._try_format_message(message, response_format)
|
|
434
|
+
if message.parsed:
|
|
435
|
+
continue
|
|
436
|
+
|
|
437
|
+
prompt = SIMPLE_FORMAT_PROMPT.format(content=message.content)
|
|
438
|
+
openai_message: OpenAIMessage = {"role": "user", "content": prompt}
|
|
439
|
+
response = await self._aget_model_response(
|
|
440
|
+
[openai_message], 0, response_format, []
|
|
441
|
+
)
|
|
442
|
+
message.content = response.output_messages[0].content
|
|
443
|
+
self._try_format_message(message, response_format)
|
|
444
|
+
|
|
550
445
|
def step(
|
|
551
446
|
self,
|
|
552
447
|
input_message: Union[BaseMessage, str],
|
|
553
448
|
response_format: Optional[Type[BaseModel]] = None,
|
|
449
|
+
reason_params: Optional[Dict[str, Any]] = None,
|
|
554
450
|
) -> ChatAgentResponse:
|
|
555
451
|
r"""Executes a single step in the chat session, generating a response
|
|
556
452
|
to the input message.
|
|
@@ -563,295 +459,116 @@ class ChatAgent(BaseAgent):
|
|
|
563
459
|
model defining the expected structure of the response. Used to
|
|
564
460
|
generate a structured response if provided. (default:
|
|
565
461
|
:obj:`None`)
|
|
462
|
+
reason_params (Optional[Dict[str, Any]], optional): A dictionary
|
|
463
|
+
containing the parameters for the reasoning step.
|
|
464
|
+
Argument `choices` is the number of choices/candidates to
|
|
465
|
+
consider.
|
|
466
|
+
Argument `threshold` is the threshold for the probability of
|
|
467
|
+
the choices.
|
|
468
|
+
(default: :obj:`None`)
|
|
566
469
|
|
|
567
470
|
Returns:
|
|
568
471
|
ChatAgentResponse: Contains output messages, a termination status
|
|
569
472
|
flag, and session information.
|
|
570
473
|
"""
|
|
571
474
|
|
|
572
|
-
if (
|
|
573
|
-
self.model_backend.model_config_dict.get("response_format")
|
|
574
|
-
and response_format
|
|
575
|
-
):
|
|
576
|
-
logger.warning(
|
|
577
|
-
f"Overriding the response format with {response_format}."
|
|
578
|
-
)
|
|
579
|
-
|
|
580
|
-
self.original_model_dict = self.model_backend.model_config_dict
|
|
581
|
-
model_response_format_modified = False
|
|
582
|
-
if (
|
|
583
|
-
response_format
|
|
584
|
-
and self.model_type.support_native_structured_output
|
|
585
|
-
):
|
|
586
|
-
self.model_backend.model_config_dict = (
|
|
587
|
-
self.original_model_dict.copy()
|
|
588
|
-
)
|
|
589
|
-
self.model_backend.model_config_dict["response_format"] = (
|
|
590
|
-
response_format
|
|
591
|
-
)
|
|
592
|
-
model_response_format_modified = True
|
|
593
|
-
|
|
594
475
|
# Convert input message to BaseMessage if necessary
|
|
595
476
|
if isinstance(input_message, str):
|
|
596
477
|
input_message = BaseMessage.make_user_message(
|
|
597
|
-
role_name=
|
|
478
|
+
role_name="User", content=input_message
|
|
598
479
|
)
|
|
599
480
|
|
|
600
|
-
#
|
|
601
|
-
|
|
602
|
-
self.is_tools_added()
|
|
603
|
-
and not self.model_type.support_native_tool_calling
|
|
604
|
-
and not self.tool_prompt_added
|
|
605
|
-
):
|
|
606
|
-
self._inject_tool_prompt()
|
|
481
|
+
# Inject thinking steps
|
|
482
|
+
input_message = self._update_reasoning(input_message, reason_params)
|
|
607
483
|
|
|
608
484
|
# Add user input to memory
|
|
609
485
|
self.update_memory(input_message, OpenAIBackendRole.USER)
|
|
610
486
|
|
|
611
|
-
try:
|
|
612
|
-
return self._handle_step(response_format, self.single_iteration)
|
|
613
|
-
finally:
|
|
614
|
-
if model_response_format_modified:
|
|
615
|
-
# Reset model config back to original state
|
|
616
|
-
self.model_backend.model_config_dict = self.original_model_dict
|
|
617
|
-
|
|
618
|
-
def _inject_tool_prompt(self) -> None:
|
|
619
|
-
r"""Generate and add the tool prompt to memory."""
|
|
620
|
-
tool_prompt = self._generate_tool_prompt(
|
|
621
|
-
self.model_backend.model_config_dict["tools"]
|
|
622
|
-
)
|
|
623
|
-
tool_msg = BaseMessage.make_assistant_message(
|
|
624
|
-
role_name="Assistant", content=tool_prompt
|
|
625
|
-
)
|
|
626
|
-
self.update_memory(tool_msg, OpenAIBackendRole.SYSTEM)
|
|
627
|
-
self.tool_prompt_added = True
|
|
628
|
-
|
|
629
|
-
def _handle_step(
|
|
630
|
-
self,
|
|
631
|
-
response_format: Optional[Type[BaseModel]],
|
|
632
|
-
single_step: bool,
|
|
633
|
-
) -> ChatAgentResponse:
|
|
634
|
-
r"""Handles a single or multi-step interaction."""
|
|
635
|
-
|
|
636
|
-
if (
|
|
637
|
-
self.model_backend.model_config_dict.get("tool_choice")
|
|
638
|
-
== "required"
|
|
639
|
-
and not single_step
|
|
640
|
-
):
|
|
641
|
-
raise ValueError(
|
|
642
|
-
"`tool_choice` cannot be set to `required` for multi-step"
|
|
643
|
-
" mode. To proceed, set `single_iteration` to `True`."
|
|
644
|
-
)
|
|
645
|
-
|
|
646
|
-
# Record function calls made during the session
|
|
647
487
|
tool_call_records: List[ToolCallingRecord] = []
|
|
648
|
-
|
|
649
|
-
external_tool_request = None
|
|
488
|
+
external_tool_call_request: Optional[ToolCallRequest] = None
|
|
650
489
|
|
|
651
490
|
while True:
|
|
652
491
|
try:
|
|
653
492
|
openai_messages, num_tokens = self.memory.get_context()
|
|
654
493
|
except RuntimeError as e:
|
|
655
|
-
self.model_backend.model_config_dict = self.original_model_dict
|
|
656
494
|
return self._step_token_exceed(
|
|
657
495
|
e.args[1], tool_call_records, "max_tokens_exceeded"
|
|
658
496
|
)
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
497
|
+
# Get response from model backend
|
|
498
|
+
response = self._get_model_response(
|
|
499
|
+
openai_messages,
|
|
500
|
+
num_tokens,
|
|
501
|
+
response_format,
|
|
502
|
+
self._get_full_tool_schemas(),
|
|
664
503
|
)
|
|
665
504
|
|
|
666
|
-
if
|
|
667
|
-
# update last openai message
|
|
668
|
-
usr_msg = openai_messages.pop()
|
|
669
|
-
usr_msg["content"] = generate_prompt_for_structured_output(
|
|
670
|
-
response_format,
|
|
671
|
-
usr_msg["content"], # type: ignore [arg-type]
|
|
672
|
-
)
|
|
673
|
-
openai_messages.append(usr_msg)
|
|
674
|
-
|
|
675
|
-
# Process model response
|
|
676
|
-
(
|
|
677
|
-
response,
|
|
678
|
-
output_messages,
|
|
679
|
-
finish_reasons,
|
|
680
|
-
usage_dict,
|
|
681
|
-
response_id,
|
|
682
|
-
) = self._step_model_response(openai_messages, num_tokens)
|
|
683
|
-
|
|
684
|
-
# Try to parse structured output to return a Pydantic object
|
|
685
|
-
if inject_prompt_for_structured_output and isinstance(
|
|
686
|
-
response, ChatCompletion
|
|
687
|
-
):
|
|
688
|
-
content = response.choices[0].message.content
|
|
689
|
-
try:
|
|
690
|
-
json_content = json.loads(str(content))
|
|
691
|
-
output_messages[0].parsed = response_format(**json_content) # type: ignore [assignment, misc]
|
|
692
|
-
except json.JSONDecodeError as e:
|
|
693
|
-
logger.error(
|
|
694
|
-
f"Failed in parsing the output into JSON: {e}"
|
|
695
|
-
)
|
|
696
|
-
output_messages[0].parsed = None
|
|
697
|
-
except ValidationError as e:
|
|
698
|
-
logger.warning(
|
|
699
|
-
"Successfully generating JSON response, "
|
|
700
|
-
"but failed in parsing it into Pydantic object :"
|
|
701
|
-
f"{e}, return the JSON response in parsed field"
|
|
702
|
-
)
|
|
703
|
-
output_messages[0].parsed = json_content
|
|
704
|
-
|
|
705
|
-
# Finalize on standard response in multi-step mode
|
|
706
|
-
if self._is_standard_response(response):
|
|
505
|
+
if self.single_iteration:
|
|
707
506
|
break
|
|
708
507
|
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
tool_call_records.append(
|
|
714
|
-
self._step_tool_call_and_update(response)
|
|
715
|
-
)
|
|
508
|
+
if tool_call_request := response.tool_call_request:
|
|
509
|
+
if tool_call_request.tool_name in self._external_tool_schemas:
|
|
510
|
+
external_tool_call_request = tool_call_request
|
|
511
|
+
break
|
|
716
512
|
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
info = self._step_get_info(
|
|
720
|
-
output_messages,
|
|
721
|
-
finish_reasons,
|
|
722
|
-
usage_dict,
|
|
723
|
-
response_id,
|
|
724
|
-
tool_call_records,
|
|
725
|
-
num_tokens,
|
|
726
|
-
tool_request,
|
|
727
|
-
)
|
|
728
|
-
self._log_final_output(output_messages)
|
|
729
|
-
self.model_backend.model_config_dict = (
|
|
730
|
-
self.original_model_dict
|
|
731
|
-
)
|
|
732
|
-
return ChatAgentResponse(
|
|
733
|
-
msgs=output_messages,
|
|
734
|
-
terminated=self.terminated,
|
|
735
|
-
info=info,
|
|
736
|
-
)
|
|
737
|
-
|
|
738
|
-
# Single-step mode ends after one iteration
|
|
739
|
-
if single_step:
|
|
740
|
-
break
|
|
513
|
+
tool_call_records.append(self._execute_tool(tool_call_request))
|
|
514
|
+
continue
|
|
741
515
|
|
|
742
|
-
|
|
743
|
-
if (
|
|
744
|
-
response_format
|
|
745
|
-
and not inject_prompt_for_structured_output
|
|
746
|
-
and self.model_type
|
|
747
|
-
not in {
|
|
748
|
-
"gpt-4o",
|
|
749
|
-
"gpt-4o-mini",
|
|
750
|
-
}
|
|
751
|
-
):
|
|
752
|
-
(
|
|
753
|
-
output_messages,
|
|
754
|
-
finish_reasons,
|
|
755
|
-
usage_dict,
|
|
756
|
-
response_id,
|
|
757
|
-
tool_call,
|
|
758
|
-
num_tokens,
|
|
759
|
-
) = self._structure_output_with_function(response_format)
|
|
760
|
-
tool_call_records.append(tool_call)
|
|
516
|
+
break
|
|
761
517
|
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
response_id,
|
|
768
|
-
tool_call_records,
|
|
769
|
-
num_tokens,
|
|
770
|
-
external_tool_request,
|
|
771
|
-
)
|
|
772
|
-
self._log_final_output(output_messages)
|
|
773
|
-
self.model_backend.model_config_dict = self.original_model_dict
|
|
774
|
-
return ChatAgentResponse(
|
|
775
|
-
msgs=output_messages, terminated=self.terminated, info=info
|
|
518
|
+
self._format_response_if_needed(response, response_format)
|
|
519
|
+
self._record_final_output(response.output_messages)
|
|
520
|
+
|
|
521
|
+
return self._convert_to_chatagent_response(
|
|
522
|
+
response, tool_call_records, num_tokens, external_tool_call_request
|
|
776
523
|
)
|
|
777
524
|
|
|
778
|
-
def
|
|
779
|
-
self,
|
|
780
|
-
|
|
781
|
-
|
|
525
|
+
def _update_reasoning(
|
|
526
|
+
self,
|
|
527
|
+
input_message: BaseMessage,
|
|
528
|
+
reason_params: Optional[Dict[str, Any]] = None,
|
|
529
|
+
) -> BaseMessage:
|
|
530
|
+
r"""Updates the input message to include reasoning instructions and
|
|
531
|
+
adds human interaction capability.
|
|
782
532
|
|
|
783
533
|
Args:
|
|
784
|
-
|
|
534
|
+
input_message (BaseMessage): The message to be updated with
|
|
535
|
+
reasoning instructions.
|
|
536
|
+
reason_params (Optional[Dict[str, Any]], optional): Parameters for
|
|
537
|
+
the reasoning process.
|
|
785
538
|
|
|
786
539
|
Returns:
|
|
787
|
-
|
|
788
|
-
present, otherwise None.
|
|
540
|
+
BaseMessage: The updated message with reasoning instructions.
|
|
789
541
|
"""
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
"'", '"'
|
|
805
|
-
),
|
|
806
|
-
name=str(parsed_content["function"]),
|
|
807
|
-
),
|
|
808
|
-
type="function",
|
|
809
|
-
)
|
|
810
|
-
elif (
|
|
811
|
-
self.is_tools_added()
|
|
812
|
-
and self.model_type.support_native_tool_calling
|
|
813
|
-
and response.choices[0].message.tool_calls
|
|
814
|
-
):
|
|
815
|
-
return response.choices[0].message.tool_calls[0]
|
|
816
|
-
|
|
817
|
-
# No tool call found
|
|
818
|
-
return None
|
|
819
|
-
|
|
820
|
-
def _is_standard_response(self, response: Any) -> bool:
|
|
821
|
-
r"""Determine if the provided response is a standard reply without
|
|
822
|
-
tool calls.
|
|
823
|
-
|
|
824
|
-
Args:
|
|
825
|
-
response (Any): The response object to evaluate.
|
|
826
|
-
|
|
827
|
-
Returns:
|
|
828
|
-
bool: `True` if the response is a standard reply, `False`
|
|
829
|
-
otherwise.
|
|
542
|
+
if reason_params is None:
|
|
543
|
+
return input_message
|
|
544
|
+
choices = reason_params.get("choices", 3)
|
|
545
|
+
threshold = reason_params.get("threshold", 0.5)
|
|
546
|
+
|
|
547
|
+
input_message.content += f"""First, come up with potential {choices}
|
|
548
|
+
choices/candidates.
|
|
549
|
+
Next, assign a probability/credibility between 0 and 1 to each choice
|
|
550
|
+
(make sure they add up to 1).
|
|
551
|
+
Finally, if only one choice has a probability/credibility greater than
|
|
552
|
+
{threshold}, continue with that choice.
|
|
553
|
+
Otherwise, call tool `ask_human_via_console` to ask the user to decide
|
|
554
|
+
which one to continue with, give user the probability/credibility of
|
|
555
|
+
all choices, and the reason for each choice.
|
|
830
556
|
"""
|
|
831
|
-
if not self.is_tools_added():
|
|
832
|
-
return True
|
|
833
557
|
|
|
834
|
-
|
|
835
|
-
|
|
558
|
+
# Add tools to agent
|
|
559
|
+
from camel.toolkits.human_toolkit import HumanToolkit
|
|
836
560
|
|
|
837
|
-
|
|
838
|
-
|
|
561
|
+
human_toolkit = HumanToolkit()
|
|
562
|
+
self.add_tool(human_toolkit.ask_human_via_console)
|
|
839
563
|
|
|
840
|
-
return
|
|
841
|
-
response.choices[0].message.content or ""
|
|
842
|
-
)
|
|
564
|
+
return input_message
|
|
843
565
|
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
else:
|
|
849
|
-
logger.warning(
|
|
850
|
-
"Multiple messages returned in `step()`. Record "
|
|
851
|
-
"selected message manually using `record_message()`."
|
|
852
|
-
)
|
|
566
|
+
@property
|
|
567
|
+
def chat_history(self) -> List[OpenAIMessage]:
|
|
568
|
+
openai_messages, _ = self.memory.get_context()
|
|
569
|
+
return openai_messages
|
|
853
570
|
|
|
854
|
-
async def
|
|
571
|
+
async def astep(
|
|
855
572
|
self,
|
|
856
573
|
input_message: Union[BaseMessage, str],
|
|
857
574
|
response_format: Optional[Type[BaseModel]] = None,
|
|
@@ -879,12 +596,13 @@ class ChatAgent(BaseAgent):
|
|
|
879
596
|
"""
|
|
880
597
|
if isinstance(input_message, str):
|
|
881
598
|
input_message = BaseMessage.make_user_message(
|
|
882
|
-
role_name=
|
|
599
|
+
role_name="User", content=input_message
|
|
883
600
|
)
|
|
884
601
|
|
|
885
602
|
self.update_memory(input_message, OpenAIBackendRole.USER)
|
|
886
603
|
|
|
887
604
|
tool_call_records: List[ToolCallingRecord] = []
|
|
605
|
+
external_tool_call_request: Optional[ToolCallRequest] = None
|
|
888
606
|
while True:
|
|
889
607
|
try:
|
|
890
608
|
openai_messages, num_tokens = self.memory.get_context()
|
|
@@ -893,224 +611,146 @@ class ChatAgent(BaseAgent):
|
|
|
893
611
|
e.args[1], tool_call_records, "max_tokens_exceeded"
|
|
894
612
|
)
|
|
895
613
|
|
|
896
|
-
(
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
if (
|
|
905
|
-
not self.is_tools_added()
|
|
906
|
-
or not isinstance(response, ChatCompletion)
|
|
907
|
-
or not response.choices[0].message.tool_calls
|
|
908
|
-
):
|
|
614
|
+
response = await self._aget_model_response(
|
|
615
|
+
openai_messages,
|
|
616
|
+
num_tokens,
|
|
617
|
+
response_format,
|
|
618
|
+
self._get_full_tool_schemas(),
|
|
619
|
+
)
|
|
620
|
+
|
|
621
|
+
if self.single_iteration:
|
|
909
622
|
break
|
|
910
623
|
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
info = self._step_get_info(
|
|
916
|
-
output_messages,
|
|
917
|
-
finish_reasons,
|
|
918
|
-
usage_dict,
|
|
919
|
-
response_id,
|
|
920
|
-
tool_call_records,
|
|
921
|
-
num_tokens,
|
|
922
|
-
external_tool_request,
|
|
923
|
-
)
|
|
924
|
-
return ChatAgentResponse(
|
|
925
|
-
msgs=output_messages, terminated=self.terminated, info=info
|
|
926
|
-
)
|
|
624
|
+
if tool_call_request := response.tool_call_request:
|
|
625
|
+
if tool_call_request.tool_name in self._external_tool_schemas:
|
|
626
|
+
external_tool_call_request = tool_call_request
|
|
627
|
+
break
|
|
927
628
|
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
)
|
|
629
|
+
tool_call_record = await self._aexecute_tool(tool_call_request)
|
|
630
|
+
tool_call_records.append(tool_call_record)
|
|
631
|
+
continue
|
|
932
632
|
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
)
|
|
937
|
-
(
|
|
938
|
-
output_messages,
|
|
939
|
-
finish_reasons,
|
|
940
|
-
usage_dict,
|
|
941
|
-
response_id,
|
|
942
|
-
tool_call_record,
|
|
943
|
-
num_tokens,
|
|
944
|
-
) = self._structure_output_with_function(response_format)
|
|
945
|
-
tool_call_records.append(tool_call_record)
|
|
633
|
+
break
|
|
634
|
+
|
|
635
|
+
await self._aformat_response_if_needed(response, response_format)
|
|
636
|
+
self._record_final_output(response.output_messages)
|
|
946
637
|
|
|
638
|
+
return self._convert_to_chatagent_response(
|
|
639
|
+
response, tool_call_records, num_tokens, external_tool_call_request
|
|
640
|
+
)
|
|
641
|
+
|
|
642
|
+
def _convert_to_chatagent_response(
|
|
643
|
+
self,
|
|
644
|
+
response: ModelResponse,
|
|
645
|
+
tool_call_records: List[ToolCallingRecord],
|
|
646
|
+
num_tokens: int,
|
|
647
|
+
external_tool_call_request: Optional[ToolCallRequest],
|
|
648
|
+
) -> ChatAgentResponse:
|
|
649
|
+
r"""Parse the final model response into the chat agent response."""
|
|
947
650
|
info = self._step_get_info(
|
|
948
|
-
output_messages,
|
|
949
|
-
finish_reasons,
|
|
950
|
-
usage_dict,
|
|
951
|
-
response_id,
|
|
651
|
+
response.output_messages,
|
|
652
|
+
response.finish_reasons,
|
|
653
|
+
response.usage_dict,
|
|
654
|
+
response.response_id,
|
|
952
655
|
tool_call_records,
|
|
953
656
|
num_tokens,
|
|
657
|
+
external_tool_call_request,
|
|
954
658
|
)
|
|
955
659
|
|
|
660
|
+
return ChatAgentResponse(
|
|
661
|
+
msgs=response.output_messages,
|
|
662
|
+
terminated=self.terminated,
|
|
663
|
+
info=info,
|
|
664
|
+
)
|
|
665
|
+
|
|
666
|
+
def _record_final_output(self, output_messages: List[BaseMessage]) -> None:
|
|
667
|
+
r"""Log final messages or warnings about multiple responses."""
|
|
956
668
|
if len(output_messages) == 1:
|
|
957
|
-
# Auto record if the output result is a single message
|
|
958
669
|
self.record_message(output_messages[0])
|
|
959
670
|
else:
|
|
960
671
|
logger.warning(
|
|
961
|
-
"Multiple messages returned in `step()
|
|
962
|
-
"
|
|
963
|
-
"record the selected message manually."
|
|
672
|
+
"Multiple messages returned in `step()`. Record "
|
|
673
|
+
"selected message manually using `record_message()`."
|
|
964
674
|
)
|
|
965
675
|
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
) ->
|
|
973
|
-
r"""
|
|
974
|
-
records the function call in the provided list of tool calls and
|
|
975
|
-
updates the memory of the current agent.
|
|
676
|
+
def _get_model_response(
|
|
677
|
+
self,
|
|
678
|
+
openai_messages: List[OpenAIMessage],
|
|
679
|
+
num_tokens: int,
|
|
680
|
+
response_format: Optional[Type[BaseModel]] = None,
|
|
681
|
+
tool_schemas: Optional[List[Dict[str, Any]]] = None,
|
|
682
|
+
) -> ModelResponse:
|
|
683
|
+
r"""Internal function for agent step model response."""
|
|
976
684
|
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
685
|
+
response = None
|
|
686
|
+
try:
|
|
687
|
+
response = self.model_backend.run(
|
|
688
|
+
openai_messages, response_format, tool_schemas or None
|
|
689
|
+
)
|
|
690
|
+
except Exception as exc:
|
|
691
|
+
logger.error(
|
|
692
|
+
f"An error occurred while running model "
|
|
693
|
+
f"{self.model_backend.model_type}, "
|
|
694
|
+
f"index: {self.model_backend.current_model_index}",
|
|
695
|
+
exc_info=exc,
|
|
696
|
+
)
|
|
697
|
+
error_info = str(exc)
|
|
980
698
|
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
699
|
+
if not response and self.model_backend.num_models > 1:
|
|
700
|
+
raise ModelProcessingError(
|
|
701
|
+
"Unable to process messages: none of the provided models "
|
|
702
|
+
"run succesfully."
|
|
703
|
+
)
|
|
704
|
+
elif not response:
|
|
705
|
+
raise ModelProcessingError(
|
|
706
|
+
f"Unable to process messages: the only provided model "
|
|
707
|
+
f"did not run succesfully. Error: {error_info}"
|
|
708
|
+
)
|
|
984
709
|
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
self.
|
|
710
|
+
logger.info(
|
|
711
|
+
f"Model {self.model_backend.model_type}, "
|
|
712
|
+
f"index {self.model_backend.current_model_index}, "
|
|
713
|
+
f"processed these messages: {openai_messages}"
|
|
988
714
|
)
|
|
989
715
|
|
|
990
|
-
# Update the messages
|
|
991
|
-
self.update_memory(func_assistant_msg, OpenAIBackendRole.ASSISTANT)
|
|
992
|
-
self.update_memory(func_result_msg, OpenAIBackendRole.FUNCTION)
|
|
993
|
-
|
|
994
|
-
return tool_call_record
|
|
995
|
-
|
|
996
|
-
async def _step_tool_call_and_update_async(
|
|
997
|
-
self, response: ChatCompletion
|
|
998
|
-
) -> ToolCallingRecord:
|
|
999
|
-
(
|
|
1000
|
-
func_assistant_msg,
|
|
1001
|
-
func_result_msg,
|
|
1002
|
-
func_record,
|
|
1003
|
-
) = await self.step_tool_call_async(response)
|
|
1004
|
-
|
|
1005
|
-
self.update_memory(func_assistant_msg, OpenAIBackendRole.ASSISTANT)
|
|
1006
|
-
self.update_memory(func_result_msg, OpenAIBackendRole.FUNCTION)
|
|
1007
|
-
|
|
1008
|
-
return func_record
|
|
1009
|
-
|
|
1010
|
-
def _structure_output_with_function(
|
|
1011
|
-
self, response_format: Type[BaseModel]
|
|
1012
|
-
) -> Tuple[
|
|
1013
|
-
List[BaseMessage],
|
|
1014
|
-
List[str],
|
|
1015
|
-
Dict[str, int],
|
|
1016
|
-
str,
|
|
1017
|
-
ToolCallingRecord,
|
|
1018
|
-
int,
|
|
1019
|
-
]:
|
|
1020
|
-
r"""Internal function of structuring the output of the agent based on
|
|
1021
|
-
the given output schema.
|
|
1022
|
-
|
|
1023
|
-
Args:
|
|
1024
|
-
response_format (Type[BaseModel]): The output schema to use for
|
|
1025
|
-
structuring the output.
|
|
1026
|
-
|
|
1027
|
-
Returns:
|
|
1028
|
-
Tuple[List[BaseMessage], List[str], Dict[str, int], str,
|
|
1029
|
-
ToolCallingRecord, int]:
|
|
1030
|
-
A tuple containing the output messages, finish reasons, usage
|
|
1031
|
-
dictionary, response ID, function calling record, and number of
|
|
1032
|
-
tokens.
|
|
1033
|
-
"""
|
|
1034
|
-
from camel.toolkits import FunctionTool
|
|
1035
|
-
|
|
1036
|
-
schema_json = get_pydantic_object_schema(response_format)
|
|
1037
|
-
func_str = json_to_function_code(schema_json)
|
|
1038
|
-
func_callable = func_string_to_callable(func_str)
|
|
1039
|
-
func = FunctionTool(func_callable)
|
|
1040
|
-
|
|
1041
|
-
original_model_dict = self.model_backend.model_config_dict
|
|
1042
|
-
|
|
1043
|
-
# Replace the original tools with the structuring function
|
|
1044
|
-
self.tool_dict = {func.get_function_name(): func}
|
|
1045
|
-
self.model_backend.model_config_dict = original_model_dict.copy()
|
|
1046
|
-
self.model_backend.model_config_dict["tools"] = [
|
|
1047
|
-
func.get_openai_tool_schema()
|
|
1048
|
-
]
|
|
1049
|
-
self.model_backend.model_config_dict["tool_choice"] = "required"
|
|
1050
|
-
|
|
1051
|
-
openai_messages, num_tokens = self.memory.get_context()
|
|
1052
|
-
(
|
|
1053
|
-
response,
|
|
1054
|
-
output_messages,
|
|
1055
|
-
finish_reasons,
|
|
1056
|
-
usage_dict,
|
|
1057
|
-
response_id,
|
|
1058
|
-
) = self._step_model_response(openai_messages, num_tokens)
|
|
1059
|
-
|
|
1060
716
|
if isinstance(response, ChatCompletion):
|
|
1061
|
-
|
|
717
|
+
return self._handle_batch_response(response)
|
|
1062
718
|
else:
|
|
1063
|
-
|
|
1064
|
-
"Structured output is not supported for stream responses."
|
|
1065
|
-
)
|
|
719
|
+
return self._handle_stream_response(response, num_tokens)
|
|
1066
720
|
|
|
1067
|
-
|
|
1068
|
-
base_message_item.content = json.dumps(tool_call_record.result)
|
|
1069
|
-
|
|
1070
|
-
# Recover the original tools
|
|
1071
|
-
self.model_backend.model_config_dict = original_model_dict
|
|
1072
|
-
|
|
1073
|
-
return (
|
|
1074
|
-
output_messages,
|
|
1075
|
-
finish_reasons,
|
|
1076
|
-
usage_dict,
|
|
1077
|
-
response_id,
|
|
1078
|
-
tool_call_record,
|
|
1079
|
-
num_tokens,
|
|
1080
|
-
)
|
|
1081
|
-
|
|
1082
|
-
def _step_model_response(
|
|
721
|
+
async def _aget_model_response(
|
|
1083
722
|
self,
|
|
1084
723
|
openai_messages: List[OpenAIMessage],
|
|
1085
724
|
num_tokens: int,
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
List[str],
|
|
1090
|
-
Dict[str, int],
|
|
1091
|
-
str,
|
|
1092
|
-
]:
|
|
725
|
+
response_format: Optional[Type[BaseModel]] = None,
|
|
726
|
+
tool_schemas: Optional[List[Dict[str, Any]]] = None,
|
|
727
|
+
) -> ModelResponse:
|
|
1093
728
|
r"""Internal function for agent step model response."""
|
|
1094
729
|
|
|
1095
730
|
response = None
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
if not response:
|
|
731
|
+
try:
|
|
732
|
+
response = await self.model_backend.arun(
|
|
733
|
+
openai_messages, response_format, tool_schemas or None
|
|
734
|
+
)
|
|
735
|
+
except Exception as exc:
|
|
736
|
+
logger.error(
|
|
737
|
+
f"An error occurred while running model "
|
|
738
|
+
f"{self.model_backend.model_type}, "
|
|
739
|
+
f"index: {self.model_backend.current_model_index}",
|
|
740
|
+
exc_info=exc,
|
|
741
|
+
)
|
|
742
|
+
error_info = str(exc)
|
|
743
|
+
|
|
744
|
+
if not response and self.model_backend.num_models > 1:
|
|
1110
745
|
raise ModelProcessingError(
|
|
1111
746
|
"Unable to process messages: none of the provided models "
|
|
1112
747
|
"run succesfully."
|
|
1113
748
|
)
|
|
749
|
+
elif not response:
|
|
750
|
+
raise ModelProcessingError(
|
|
751
|
+
f"Unable to process messages: the only provided model "
|
|
752
|
+
f"did not run succesfully. Error: {error_info}"
|
|
753
|
+
)
|
|
1114
754
|
|
|
1115
755
|
logger.info(
|
|
1116
756
|
f"Model {self.model_backend.model_type}, "
|
|
@@ -1119,20 +759,9 @@ class ChatAgent(BaseAgent):
|
|
|
1119
759
|
)
|
|
1120
760
|
|
|
1121
761
|
if isinstance(response, ChatCompletion):
|
|
1122
|
-
|
|
1123
|
-
self.handle_batch_response(response)
|
|
1124
|
-
)
|
|
762
|
+
return self._handle_batch_response(response)
|
|
1125
763
|
else:
|
|
1126
|
-
|
|
1127
|
-
self.handle_stream_response(response, num_tokens)
|
|
1128
|
-
)
|
|
1129
|
-
return (
|
|
1130
|
-
response,
|
|
1131
|
-
output_messages,
|
|
1132
|
-
finish_reasons,
|
|
1133
|
-
usage_dict,
|
|
1134
|
-
response_id,
|
|
1135
|
-
)
|
|
764
|
+
return await self._ahandle_stream_response(response, num_tokens)
|
|
1136
765
|
|
|
1137
766
|
def _step_get_info(
|
|
1138
767
|
self,
|
|
@@ -1142,7 +771,7 @@ class ChatAgent(BaseAgent):
|
|
|
1142
771
|
response_id: str,
|
|
1143
772
|
tool_calls: List[ToolCallingRecord],
|
|
1144
773
|
num_tokens: int,
|
|
1145
|
-
|
|
774
|
+
external_tool_call_request: Optional[ToolCallRequest] = None,
|
|
1146
775
|
) -> Dict[str, Any]:
|
|
1147
776
|
r"""Process the output of a chat step and gather information about the
|
|
1148
777
|
step.
|
|
@@ -1162,9 +791,8 @@ class ChatAgent(BaseAgent):
|
|
|
1162
791
|
tool_calls (List[ToolCallingRecord]): Records of function calls
|
|
1163
792
|
made during this step.
|
|
1164
793
|
num_tokens (int): The number of tokens used in this step.
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
(default: :obj:`None`)
|
|
794
|
+
external_tool_call_request (Optional[ToolCallRequest]): The
|
|
795
|
+
request for external tool call.
|
|
1168
796
|
|
|
1169
797
|
Returns:
|
|
1170
798
|
Dict[str, Any]: A dictionary containing information about the chat
|
|
@@ -1194,103 +822,114 @@ class ChatAgent(BaseAgent):
         if self.terminated and termination_reason is not None:
             finish_reasons = [termination_reason] * len(finish_reasons)

-        ...
+        return get_info_dict(
             response_id,
             usage_dict,
             finish_reasons,
             num_tokens,
             tool_calls,
-            ...
+            external_tool_call_request,
         )
-        return info

-    def handle_batch_response(
+    def _handle_batch_response(
         self, response: ChatCompletion
-    ) -> Tuple[List[BaseMessage], List[str], Dict[str, int], str]:
+    ) -> ModelResponse:
         r"""Process a batch response from the model and extract the necessary
         information.

         Args:
-            response (
+            response (ChatCompletion): Model response.

         Returns:
-            ...
-                finish reasons, usage dictionary, and response id.
+            _ModelResponse: parsed model response.
         """
         output_messages: List[BaseMessage] = []
         for choice in response.choices:
+            meta_dict = {}
+            if logprobs_info := handle_logprobs(choice):
+                meta_dict["logprobs_info"] = logprobs_info
+
             chat_message = BaseMessage(
                 role_name=self.role_name,
                 role_type=self.role_type,
-                meta_dict=dict(),
+                meta_dict=meta_dict,
                 content=choice.message.content or "",
-                parsed=getattr(choice.message,
+                parsed=getattr(choice.message, "parsed", None),
             )
-
-            if choice.logprobs is not None:
-                tokens_logprobs = choice.logprobs.content
-
-                if tokens_logprobs is not None:
-                    # Extract and structure logprob information
-                    logprobs_info = [
-                        {
-                            "token": token_logprob.token,
-                            "logprob": token_logprob.logprob,
-                            "top_logprobs": [
-                                (top_logprob.token, top_logprob.logprob)
-                                for top_logprob in token_logprob.top_logprobs
-                            ],
-                        }
-                        for token_logprob in tokens_logprobs
-                    ]
-                    # Ensure meta_dict exists before adding logprobs info
-                    if chat_message.meta_dict is None:
-                        chat_message.meta_dict = {}
-                    chat_message.meta_dict["logprobs_info"] = logprobs_info
-            # Append the processed chat message to output
+
             output_messages.append(chat_message)

         finish_reasons = [
             str(choice.finish_reason) for choice in response.choices
         ]
-        usage = (
-            self._safe_model_dump(response.usage)
-            if response.usage is not None
-            else {}
-        )
-        return (
-            output_messages,
-            finish_reasons,
-            usage,
-            response.id,
-        )

-        ...
+        usage = {}
+        if response.usage is not None:
+            usage = safe_model_dump(response.usage)
+
+        tool_call_request: Optional[ToolCallRequest] = None
+        if tool_calls := response.choices[0].message.tool_calls:
+            tool_name = tool_calls[0].function.name
+            tool_call_id = tool_calls[0].id
+            args = json.loads(tool_calls[0].function.arguments)
+            tool_call_request = ToolCallRequest(
+                tool_name=tool_name, args=args, tool_call_id=tool_call_id
+            )

-        ...
+        return ModelResponse(
+            response=response,
+            tool_call_request=tool_call_request,
+            output_messages=output_messages,
+            finish_reasons=finish_reasons,
+            usage_dict=usage,
+            response_id=response.id or "",
+        )
+
+    def _handle_stream_response(
+        self,
+        response: Stream[ChatCompletionChunk],
+        prompt_tokens: int,
+    ) -> ModelResponse:
+        r"""Process a stream response from the model and extract the necessary
+        information.

         Args:
-            ...
+            response (dict): Model response.
+            prompt_tokens (int): Number of input prompt tokens.

         Returns:
-            ...
+            _ModelResponse: a parsed model response.
         """
-        ...
+        content_dict: defaultdict = defaultdict(lambda: "")
+        finish_reasons_dict: defaultdict = defaultdict(lambda: "")
+        output_messages: List[BaseMessage] = []
+        response_id: str = ""
+        # All choices in one response share one role
+        for chunk in response:
+            response_id = chunk.id
+            self._handle_chunk(
+                chunk, content_dict, finish_reasons_dict, output_messages
+            )
+        finish_reasons = [
+            finish_reasons_dict[i] for i in range(len(finish_reasons_dict))
+        ]
+        usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
+
+        # TODO: Handle tool calls
+        return ModelResponse(
+            response=response,
+            tool_call_request=None,
+            output_messages=output_messages,
+            finish_reasons=finish_reasons,
+            usage_dict=usage_dict,
+            response_id=response_id,
+        )

-    def
+    async def _ahandle_stream_response(
         self,
-        response:
+        response: AsyncStream[ChatCompletionChunk],
         prompt_tokens: int,
-    ) ->
+    ) -> ModelResponse:
         r"""Process a stream response from the model and extract the necessary
         information.

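Note that `_handle_batch_response` only inspects the first tool call of the first choice, and `json.loads` now parses the arguments strictly (the lenient `_safe_json_loads` helper is deleted further down in this diff). A self-contained sketch of that extraction, with minimal stand-ins for the OpenAI response objects:

```python
import json
from typing import Any, Dict, Optional


class _Function:
    # Stand-in for the OpenAI function-call payload.
    def __init__(self, name: str, arguments: str):
        self.name, self.arguments = name, arguments


class _ToolCall:
    # Stand-in for a ChatCompletion tool call entry.
    def __init__(self, id: str, function: _Function):
        self.id, self.function = id, function


def extract_first_tool_call(tool_calls) -> Optional[Dict[str, Any]]:
    """Mirror the walrus-based extraction above, returning a plain dict."""
    if not tool_calls:
        return None  # the model did not request a tool call
    first = tool_calls[0]
    return {
        "tool_name": first.function.name,
        # The provider sends arguments as a JSON string; malformed JSON
        # now raises instead of being patched up.
        "args": json.loads(first.function.arguments),
        "tool_call_id": first.id,
    }


call = _ToolCall("call_0", _Function("search", '{"query": "camel-ai"}'))
print(extract_first_tool_call([call]))
```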
@@ -1299,37 +938,58 @@ class ChatAgent(BaseAgent):
             prompt_tokens (int): Number of input prompt tokens.

         Returns:
-            ...
-                finish reasons, usage dictionary, and response id.
+            _ModelResponse: a parsed model response.
         """
         content_dict: defaultdict = defaultdict(lambda: "")
         finish_reasons_dict: defaultdict = defaultdict(lambda: "")
         output_messages: List[BaseMessage] = []
         response_id: str = ""
         # All choices in one response share one role
-        for chunk in response:
+        async for chunk in response:
             response_id = chunk.id
-            for choice in chunk.choices:
-                index = choice.index
-                delta = choice.delta
-                if delta.content is not None:
-                    # When response has not been stopped
-                    # Notice that only the first chunk_dict has the "role"
-                    content_dict[index] += delta.content
-                if choice.finish_reason:
-                    finish_reasons_dict[index] = choice.finish_reason
-                    chat_message = BaseMessage(
-                        role_name=self.role_name,
-                        role_type=self.role_type,
-                        meta_dict=dict(),
-                        content=content_dict[index],
-                    )
-                    output_messages.append(chat_message)
+            self._handle_chunk(
+                chunk, content_dict, finish_reasons_dict, output_messages
+            )
         finish_reasons = [
             finish_reasons_dict[i] for i in range(len(finish_reasons_dict))
         ]
         usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
-        ...
+
+        # TODO: Handle tool calls
+        return ModelResponse(
+            response=response,
+            tool_call_request=None,
+            output_messages=output_messages,
+            finish_reasons=finish_reasons,
+            usage_dict=usage_dict,
+            response_id=response_id,
+        )
+
+    def _handle_chunk(
+        self,
+        chunk: ChatCompletionChunk,
+        content_dict: defaultdict,
+        finish_reasons_dict: defaultdict,
+        output_messages: List[BaseMessage],
+    ) -> None:
+        r"""Handle a chunk of the model response."""
+        for choice in chunk.choices:
+            index = choice.index
+            delta = choice.delta
+            if delta.content is not None:
+                content_dict[index] += delta.content
+
+            if not choice.finish_reason:
+                continue
+
+            finish_reasons_dict[index] = choice.finish_reason
+            chat_message = BaseMessage(
+                role_name=self.role_name,
+                role_type=self.role_type,
+                meta_dict=dict(),
+                content=content_dict[index],
+            )
+            output_messages.append(chat_message)

     def _step_token_exceed(
         self,
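With `_handle_chunk` factored out, the sync and async stream handlers share the same per-choice bookkeeping: deltas accumulate under their choice index, and a finish reason finalizes that choice into a message. A toy version of that logic over plain tuples instead of `ChatCompletionChunk` objects:

```python
from collections import defaultdict


def accumulate(events):
    """events: iterable of (choice_index, delta_text, finish_reason)."""
    content = defaultdict(str)
    finish_reasons = {}
    messages = []
    for index, delta, finish_reason in events:
        if delta is not None:
            content[index] += delta  # grow this choice's text
        if not finish_reason:
            continue
        # A finish reason closes the choice: record it and emit the message.
        finish_reasons[index] = finish_reason
        messages.append(content[index])
    return messages, finish_reasons


msgs, reasons = accumulate(
    [(0, "Hel", None), (0, "lo", None), (0, None, "stop")]
)
print(msgs, reasons)  # ['Hello'] {0: 'stop'}
```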
@@ -1351,9 +1011,8 @@ class ChatAgent(BaseAgent):
             information about token number and called functions.
         """
         self.terminated = True
-        output_messages: List[BaseMessage] = []

-        info =
+        info = get_info_dict(
             None,
             None,
             [termination_reason],
@@ -1362,111 +1021,53 @@ class ChatAgent(BaseAgent):
         )

         return ChatAgentResponse(
-            msgs=
+            msgs=[],
             terminated=self.terminated,
             info=info,
         )

-    def
+    def _execute_tool(
         self,
-        response: ChatCompletion,
-    ) -> Tuple[
-        FunctionCallingMessage, FunctionCallingMessage, ToolCallingRecord
-    ]:
-        r"""Execute the function with arguments following the model's response.
+        tool_call_request: ToolCallRequest,
+    ) -> ToolCallingRecord:
+        r"""Execute the tool with arguments following the model's response.

         Args:
-            ...
-                model.
+            tool_call_request (_ToolCallRequest): The tool call request.

         Returns:
-            ...
-                one about the arguments and the other about the execution
-                result, and a struct for logging information about this
+            FunctionCallingRecord: A struct for logging information about this
                 function call.
         """
-        # Note that when function calling is enabled, `n` is set to 1.
-        choice = response.choices[0]
-        if choice.message.tool_calls is None:
-            raise RuntimeError("Tool call is None")
-        func_name = choice.message.tool_calls[0].function.name
-        arguments_str = choice.message.tool_calls[0].function.arguments
-        args = self._safe_json_loads(arguments_str)
-
-        tool = self.tool_dict[func_name]
+        func_name = tool_call_request.tool_name
+        args = tool_call_request.args
+        tool_call_id = tool_call_request.tool_call_id
+        tool = self._internal_tools[func_name]
         result = tool(**args)
-        tool_call_id = choice.message.tool_calls[0].id

-        assist_msg = FunctionCallingMessage(
-            role_name=self.role_name,
-            role_type=self.role_type,
-            meta_dict=None,
-            content="",
-            func_name=func_name,
-            args=args,
-            tool_call_id=tool_call_id,
-        )
-        func_msg = FunctionCallingMessage(
-            role_name=self.role_name,
-            role_type=self.role_type,
-            meta_dict=None,
-            content="",
-            func_name=func_name,
-            result=result,
-            tool_call_id=tool_call_id,
-        )
-
-        # Record information about this function call
-        func_record = ToolCallingRecord(
-            tool_name=func_name,
-            args=args,
-            result=result,
-            tool_call_id=tool_call_id,
-        )
-        return assist_msg, func_msg, func_record
-
-    def _safe_json_loads(self, arguments_str):
-        # Replace Python types with their JSON equivalents
-        arguments_str = arguments_str.replace("None", "null")
-        arguments_str = arguments_str.replace("True", "true")
-        arguments_str = arguments_str.replace("False", "false")
-
-        # Attempt to parse the corrected string
-        try:
-            return json.loads(arguments_str)
-        except json.JSONDecodeError as e:
-            raise ValueError(f"Invalid JSON format: {e}")
+        return self._record_tool_calling(func_name, args, result, tool_call_id)

-    async def
+    async def _aexecute_tool(
         self,
-        response: ChatCompletion,
-    ) ->
-        ...
+        tool_call_request: ToolCallRequest,
+    ) -> ToolCallingRecord:
+        func_name = tool_call_request.tool_name
+        args = tool_call_request.args
+        tool_call_id = tool_call_request.tool_call_id
+        tool = self._internal_tools[func_name]
+        result = await tool.async_call(**args)
+        return self._record_tool_calling(func_name, args, result, tool_call_id)
+
+    def _record_tool_calling(
+        self,
+        func_name: str,
+        args: Dict[str, Any],
+        result: Any,
+        tool_call_id: str,
+    ):
+        r"""Record the tool calling information in the memory, and return the
+        tool calling record.
         """
-        # Note that when function calling is enabled, `n` is set to 1.
-        choice = response.choices[0]
-        if choice.message.tool_calls is None:
-            raise RuntimeError("Tool call is None")
-        func_name = choice.message.tool_calls[0].function.name
-
-        args = json.loads(choice.message.tool_calls[0].function.arguments)
-        tool = self.tool_dict[func_name]
-        result = await tool(**args)
-        tool_call_id = choice.message.tool_calls[0].id
-
         assist_msg = FunctionCallingMessage(
             role_name=self.role_name,
             role_type=self.role_type,
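Tool execution is reduced to a lookup-call-record pipeline driven by the parsed `ToolCallRequest`; the old response parsing, JSON patching, and message construction are gone or moved. A toy mirror of the sync flow (the async variant only swaps in `await tool.async_call(**args)`):

```python
from typing import Any, Callable, Dict


def execute_tool(
    tools: Dict[str, Callable[..., Any]],
    request: Dict[str, Any],
    record: Callable[[str, Dict[str, Any], Any, str], Dict[str, Any]],
) -> Dict[str, Any]:
    # Look up the tool by name, call it, and delegate recording.
    tool = tools[request["tool_name"]]
    result = tool(**request["args"])
    return record(
        request["tool_name"], request["args"], result, request["tool_call_id"]
    )


def record(name, args, result, call_id):
    # Stands in for _record_tool_calling, which also updates agent memory.
    return {"tool_name": name, "args": args, "result": result,
            "tool_call_id": call_id}


tools = {"add": lambda a, b: a + b}
req = {"tool_name": "add", "args": {"a": 1, "b": 2}, "tool_call_id": "call_0"}
print(execute_tool(tools, req, record))
```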
@@ -1486,14 +1087,18 @@ class ChatAgent(BaseAgent):
             tool_call_id=tool_call_id,
         )

-        # Record information about this function call
-        func_record = ToolCallingRecord(
+        self.update_memory(assist_msg, OpenAIBackendRole.ASSISTANT)
+        self.update_memory(func_msg, OpenAIBackendRole.FUNCTION)
+
+        # Record information about this tool call
+        tool_record = ToolCallingRecord(
             tool_name=func_name,
             args=args,
             result=result,
             tool_call_id=tool_call_id,
         )
-        return assist_msg, func_msg, func_record
+
+        return tool_record

     def get_usage_dict(
         self, output_messages: List[BaseMessage], prompt_tokens: int
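The recording step writes the call and its result into memory as an assistant/function message pair before returning the record, so the conversation history stays consistent with what the model sees on the next turn. A toy version with a plain list standing in for the agent memory:

```python
from typing import Any, Dict, List, Tuple


def record_tool_calling(
    memory: List[Tuple[str, Dict[str, Any]]],
    func_name: str,
    args: Dict[str, Any],
    result: Any,
    tool_call_id: str,
) -> Dict[str, Any]:
    # One message for the request (assistant role), one for the result
    # (function role); the role strings stand in for OpenAIBackendRole.
    memory.append(("assistant", {"func_name": func_name, "args": args,
                                 "tool_call_id": tool_call_id}))
    memory.append(("function", {"func_name": func_name, "result": result,
                                "tool_call_id": tool_call_id}))
    return {"tool_name": func_name, "args": args, "result": result,
            "tool_call_id": tool_call_id}


mem: List[Tuple[str, Dict[str, Any]]] = []
rec = record_tool_calling(mem, "add", {"a": 1, "b": 2}, 3, "call_0")
print(len(mem), rec["result"])  # 2 3
```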
@@ -1508,15 +1113,15 @@ class ChatAgent(BaseAgent):
             dict: Usage dictionary.
         """
         encoding = get_model_encoding(self.model_type.value_for_tiktoken)
-        completion_tokens =
-        ...
-        usage_dict = dict(
+        completion_tokens = sum(
+            len(encoding.encode(message.content))
+            for message in output_messages
+        )
+        return dict(
             completion_tokens=completion_tokens,
             prompt_tokens=prompt_tokens,
             total_tokens=completion_tokens + prompt_tokens,
         )
-        return usage_dict

     def add_model_scheduling_strategy(self, name: str, strategy_fn: Callable):
         r"""Add a scheduling strategy method provided by user to ModelManger.
|