camel-ai 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +30 -0
- camel/agents/__init__.py +40 -0
- camel/agents/base.py +29 -0
- camel/agents/chat_agent.py +539 -0
- camel/agents/critic_agent.py +179 -0
- camel/agents/embodied_agent.py +138 -0
- camel/agents/role_assignment_agent.py +117 -0
- camel/agents/task_agent.py +382 -0
- camel/agents/tool_agents/__init__.py +20 -0
- camel/agents/tool_agents/base.py +40 -0
- camel/agents/tool_agents/hugging_face_tool_agent.py +203 -0
- camel/configs.py +159 -0
- camel/embeddings/__init__.py +20 -0
- camel/embeddings/base.py +65 -0
- camel/embeddings/openai_embedding.py +74 -0
- camel/functions/__init__.py +27 -0
- camel/functions/base_io_functions.py +261 -0
- camel/functions/math_functions.py +61 -0
- camel/functions/openai_function.py +88 -0
- camel/functions/search_functions.py +309 -0
- camel/functions/unstructured_io_fuctions.py +616 -0
- camel/functions/weather_functions.py +136 -0
- camel/generators.py +263 -0
- camel/human.py +130 -0
- camel/memories/__init__.py +28 -0
- camel/memories/base.py +75 -0
- camel/memories/chat_history_memory.py +111 -0
- camel/memories/context_creators/__init__.py +18 -0
- camel/memories/context_creators/base.py +72 -0
- camel/memories/context_creators/score_based.py +130 -0
- camel/memories/records.py +92 -0
- camel/messages/__init__.py +38 -0
- camel/messages/base.py +223 -0
- camel/messages/func_message.py +106 -0
- camel/models/__init__.py +26 -0
- camel/models/base_model.py +110 -0
- camel/models/model_factory.py +59 -0
- camel/models/open_source_model.py +144 -0
- camel/models/openai_model.py +103 -0
- camel/models/stub_model.py +106 -0
- camel/prompts/__init__.py +38 -0
- camel/prompts/ai_society.py +121 -0
- camel/prompts/base.py +227 -0
- camel/prompts/code.py +111 -0
- camel/prompts/evaluation.py +40 -0
- camel/prompts/misalignment.py +84 -0
- camel/prompts/prompt_templates.py +117 -0
- camel/prompts/role_description_prompt_template.py +53 -0
- camel/prompts/solution_extraction.py +44 -0
- camel/prompts/task_prompt_template.py +56 -0
- camel/prompts/translation.py +42 -0
- camel/responses/__init__.py +18 -0
- camel/responses/agent_responses.py +42 -0
- camel/societies/__init__.py +20 -0
- camel/societies/babyagi_playing.py +254 -0
- camel/societies/role_playing.py +456 -0
- camel/storages/__init__.py +23 -0
- camel/storages/key_value_storages/__init__.py +23 -0
- camel/storages/key_value_storages/base.py +57 -0
- camel/storages/key_value_storages/in_memory.py +51 -0
- camel/storages/key_value_storages/json.py +97 -0
- camel/terminators/__init__.py +23 -0
- camel/terminators/base.py +44 -0
- camel/terminators/response_terminator.py +118 -0
- camel/terminators/token_limit_terminator.py +55 -0
- camel/types/__init__.py +54 -0
- camel/types/enums.py +176 -0
- camel/types/openai_types.py +39 -0
- camel/utils/__init__.py +47 -0
- camel/utils/commons.py +243 -0
- camel/utils/python_interpreter.py +435 -0
- camel/utils/token_counting.py +220 -0
- camel_ai-0.1.1.dist-info/METADATA +311 -0
- camel_ai-0.1.1.dist-info/RECORD +75 -0
- camel_ai-0.1.1.dist-info/WHEEL +4 -0
camel/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the “License”);
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an “AS IS” BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
|
+
import camel.agents
|
|
15
|
+
import camel.configs
|
|
16
|
+
import camel.generators
|
|
17
|
+
import camel.messages
|
|
18
|
+
import camel.prompts
|
|
19
|
+
import camel.types
|
|
20
|
+
import camel.utils
|
|
21
|
+
import camel.functions
|
|
22
|
+
import camel.memories
|
|
23
|
+
import camel.storages
|
|
24
|
+
|
|
25
|
+
# Package version string for this release of CAMEL.
__version__ = '0.1.1'

# Names exported by ``from camel import *``: the version string and the
# ``camel`` name bound by the submodule imports above.
__all__ = ['__version__', 'camel']
|
camel/agents/__init__.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the “License”);
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an “AS IS” BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
|
+
from .base import BaseAgent
|
|
15
|
+
from .chat_agent import ChatAgent
|
|
16
|
+
from .task_agent import (
|
|
17
|
+
TaskSpecifyAgent,
|
|
18
|
+
TaskPlannerAgent,
|
|
19
|
+
TaskCreationAgent,
|
|
20
|
+
TaskPrioritizationAgent,
|
|
21
|
+
)
|
|
22
|
+
from .critic_agent import CriticAgent
|
|
23
|
+
from .tool_agents.base import BaseToolAgent
|
|
24
|
+
from .tool_agents.hugging_face_tool_agent import HuggingFaceToolAgent
|
|
25
|
+
from .embodied_agent import EmbodiedAgent
|
|
26
|
+
from .role_assignment_agent import RoleAssignmentAgent
|
|
27
|
+
|
|
28
|
+
# Public agent classes exported by ``from camel.agents import *``,
# grouped by the submodule they are imported from.
__all__ = [
    'BaseAgent', 'ChatAgent',
    'TaskSpecifyAgent', 'TaskPlannerAgent', 'TaskCreationAgent',
    'TaskPrioritizationAgent',
    'CriticAgent',
    'BaseToolAgent', 'HuggingFaceToolAgent',
    'EmbodiedAgent', 'RoleAssignmentAgent',
]
|
camel/agents/base.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the “License”);
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an “AS IS” BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
|
+
from abc import ABC, abstractmethod
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BaseAgent(ABC):
    r"""Abstract interface implemented by every CAMEL agent.

    Subclasses must override both :meth:`reset` and :meth:`step`; until they
    do, the class cannot be instantiated.
    """

    @abstractmethod
    def reset(self, *args: Any, **kwargs: Any) -> Any:
        r"""Return the agent to its initial state."""

    @abstractmethod
    def step(self, *args: Any, **kwargs: Any) -> Any:
        r"""Advance the agent by a single step."""
|
|
@@ -0,0 +1,539 @@
|
|
|
1
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the “License”);
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an “AS IS” BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
|
+
import json
|
|
15
|
+
from collections import defaultdict
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
18
|
+
|
|
19
|
+
from openai import Stream
|
|
20
|
+
|
|
21
|
+
from camel.agents import BaseAgent
|
|
22
|
+
from camel.configs import BaseConfig, ChatGPTConfig
|
|
23
|
+
from camel.functions import OpenAIFunction
|
|
24
|
+
from camel.memories import (
|
|
25
|
+
BaseMemory,
|
|
26
|
+
ChatHistoryMemory,
|
|
27
|
+
MemoryRecord,
|
|
28
|
+
ScoreBasedContextCreator,
|
|
29
|
+
)
|
|
30
|
+
from camel.messages import BaseMessage, FunctionCallingMessage, OpenAIMessage
|
|
31
|
+
from camel.models import BaseModelBackend, ModelFactory
|
|
32
|
+
from camel.responses import ChatAgentResponse
|
|
33
|
+
from camel.terminators import ResponseTerminator
|
|
34
|
+
from camel.types import (
|
|
35
|
+
ChatCompletion,
|
|
36
|
+
ChatCompletionChunk,
|
|
37
|
+
ModelType,
|
|
38
|
+
OpenAIBackendRole,
|
|
39
|
+
RoleType,
|
|
40
|
+
)
|
|
41
|
+
from camel.utils import get_model_encoding
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass(frozen=True)
class FunctionCallingRecord:
    r"""Immutable log entry describing one function call made in a chat.

    Attributes:
        func_name (str): Name of the function that was invoked.
        args (Dict[str, Any]): Keyword arguments the function was called
            with.
        result (Any): Value the function returned.
    """
    func_name: str
    args: Dict[str, Any]
    result: Any

    def __str__(self) -> str:
        r"""Render the record as a human-readable, multi-line summary.

        Returns:
            str: The function name on the first line, followed by
                tab-indented ``Args`` and ``Result`` lines.
        """
        parts = (
            f"Function Execution: {self.func_name}",
            f"\tArgs: {self.args}",
            f"\tResult: {self.result}",
        )
        return "\n".join(parts)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class ChatAgent(BaseAgent):
    r"""Class for managing conversations of CAMEL Chat Agents.

    Args:
        system_message (BaseMessage): The system message for the chat agent.
        model_type (ModelType, optional): The LLM model to use for generating
            responses. (default: :obj:`ModelType.GPT_3_5_TURBO`)
        model_config (BaseConfig, optional): Configuration options for the
            LLM model. If `None`, a default :obj:`ChatGPTConfig` is used.
            (default: :obj:`None`)
        memory (BaseMemory, optional): The agent memory for managing chat
            messages. If `None`, a :obj:`ChatHistoryMemory` will be used.
            (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. Only used when :obj:`memory` is `None`.
            (default: :obj:`None`)
        token_limit (int, optional): The maximum number of tokens in a context.
            The context will be automatically pruned to fulfill the limitation.
            If `None`, it will be set according to the backend model.
            (default: :obj:`None`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
        function_list (List[OpenAIFunction], optional): List of available
            :obj:`OpenAIFunction`. (default: :obj:`None`)
        response_terminators (List[ResponseTerminator], optional): List of
            :obj:`ResponseTerminator` bound to this chat agent.
            (default: :obj:`None`)
    """

    def __init__(
        self,
        system_message: BaseMessage,
        model_type: Optional[ModelType] = None,
        model_config: Optional[BaseConfig] = None,
        memory: Optional[BaseMemory] = None,
        message_window_size: Optional[int] = None,
        token_limit: Optional[int] = None,
        output_language: Optional[str] = None,
        function_list: Optional[List[OpenAIFunction]] = None,
        response_terminators: Optional[List[ResponseTerminator]] = None,
    ) -> None:

        self.orig_sys_message: BaseMessage = system_message
        self.system_message = system_message
        self.role_name: str = system_message.role_name
        self.role_type: RoleType = system_message.role_type
        self.output_language: Optional[str] = output_language
        if self.output_language is not None:
            # Must run before init_messages() below so that memory is seeded
            # with the language-augmented system message.
            self.set_output_language(self.output_language)

        self.model_type: ModelType = (model_type if model_type is not None
                                      else ModelType.GPT_3_5_TURBO)

        # Map each function name (as the model refers to it) to its callable.
        self.func_dict: Dict[str, Callable] = {}
        if function_list is not None:
            for func in function_list:
                self.func_dict[func.name] = func.func
        self.model_config = model_config or ChatGPTConfig()

        self.model_backend: BaseModelBackend = ModelFactory.create(
            self.model_type, self.model_config.__dict__)
        self.model_token_limit = token_limit or self.model_backend.token_limit
        context_creator = ScoreBasedContextCreator(
            self.model_backend.token_counter,
            self.model_token_limit,
        )
        self.memory: BaseMemory = memory or ChatHistoryMemory(
            context_creator, window_size=message_window_size)

        self.terminated: bool = False
        self.response_terminators = response_terminators or []
        self.init_messages()

    def reset(self) -> None:
        r"""Resets the :obj:`ChatAgent` to its initial state.

        Clears the termination flag, re-initializes the memory with the
        current system message (see :meth:`init_messages`), and resets every
        response terminator. Nothing is returned.
        """
        self.terminated = False
        self.init_messages()
        for terminator in self.response_terminators:
            terminator.reset()

    @property
    def system_message(self) -> BaseMessage:
        r"""The getter method for the property :obj:`system_message`.

        Returns:
            BaseMessage: The system message of this agent.
        """
        return self._system_message

    @system_message.setter
    def system_message(self, message: BaseMessage):
        r"""The setter method for the property :obj:`system_message`.

        Args:
            message (BaseMessage): The message to be set as the
                new system message of this agent.
        """
        self._system_message = message

    def is_function_calling_enabled(self) -> bool:
        r"""Whether OpenAI function calling is enabled for this agent.

        Returns:
            bool: Whether OpenAI function calling is enabled for this
                agent, determined by whether the dictionary of functions
                is empty.
        """
        return len(self.func_dict) > 0

    def update_memory(self, message: BaseMessage,
                      role: OpenAIBackendRole) -> None:
        r"""Updates the agent memory with a new message.

        Args:
            message (BaseMessage): The new message to add to the stored
                messages.
            role (OpenAIBackendRole): The backend role type.
        """
        self.memory.write_record(MemoryRecord(message, role))

    def set_output_language(self, output_language: str) -> BaseMessage:
        r"""Sets the output language for the system message. This method
        updates the output language for the system message. The output
        language determines the language in which the output text should be
        generated.

        The instruction is always appended to the *original* system message,
        so repeated calls do not stack instructions. Only
        :obj:`system_message` is updated here; the system record already
        stored in memory changes only when :meth:`init_messages` (e.g. via
        :meth:`reset`) runs again.

        Args:
            output_language (str): The desired output language.

        Returns:
            BaseMessage: The updated system message object.
        """
        self.output_language = output_language
        content = (self.orig_sys_message.content +
                   ("\nRegardless of the input language, "
                    f"you must output text in {output_language}."))
        self.system_message = self.system_message.create_new_instance(content)
        return self.system_message

    def get_info(self, id: Optional[str], usage: Optional[Dict[str, int]],
                 termination_reasons: List[str], num_tokens: int,
                 called_funcs: List[FunctionCallingRecord]) -> Dict[str, Any]:
        r"""Returns a dictionary containing information about the chat session.

        Args:
            id (str, optional): The ID of the chat session.
            usage (Dict[str, int], optional): Information about the usage of
                the LLM model.
            termination_reasons (List[str]): The reasons for the termination
                of the chat session.
            num_tokens (int): The number of tokens used in the chat session.
            called_funcs (List[FunctionCallingRecord]): The list of function
                calling records, containing the information of called
                functions.

        Returns:
            Dict[str, Any]: The chat session information.
        """
        return {
            "id": id,
            "usage": usage,
            "termination_reasons": termination_reasons,
            "num_tokens": num_tokens,
            "called_functions": called_funcs,
        }

    def init_messages(self) -> None:
        r"""Initializes the stored messages list with the initial system
        message.

        Clears the memory and writes the current :obj:`system_message` as
        the first record.
        """
        system_record = MemoryRecord(self.system_message,
                                     OpenAIBackendRole.SYSTEM)
        self.memory.clear()
        self.memory.write_record(system_record)

    def record_message(self, message: BaseMessage) -> None:
        r"""Records the externally provided message into the agent memory as if
        it were an answer of the :obj:`ChatAgent` from the backend. Currently,
        the choice of the critic is submitted with this method.

        Args:
            message (BaseMessage): An external message to be recorded in the
                memory.
        """
        self.update_memory(message, OpenAIBackendRole.ASSISTANT)

    def step(
        self,
        input_message: BaseMessage,
    ) -> ChatAgentResponse:
        r"""Performs a single step in the chat session by generating a response
        to the input message.

        If the model requests a function call (batch mode only), the
        function is executed, the call and its result are written to memory,
        and the model is queried again. This repeats until the model returns
        a regular answer or the context exceeds the token limit.

        Args:
            input_message (BaseMessage): The input message to the agent.
                Its `role` field that specifies the role at backend may be
                either `user` or `assistant` but it will be set to `user`
                anyway since for the self agent any incoming message is
                external.

        Returns:
            ChatAgentResponse: A struct containing the output messages,
                a boolean indicating whether the chat session has terminated,
                and information about the chat session.
        """
        self.update_memory(input_message, OpenAIBackendRole.USER)

        output_messages: List[BaseMessage]
        info: Dict[str, Any]
        called_funcs: List[FunctionCallingRecord] = []
        while True:
            # Format messages and get the token number
            openai_messages: Optional[List[OpenAIMessage]]

            try:
                openai_messages, num_tokens = self.memory.get_context()
            except RuntimeError as e:
                # NOTE(review): assumes the context creator raises
                # RuntimeError(<message>, <num_tokens>) when the context
                # cannot fit the token limit — confirm against the creator.
                return self.step_token_exceed(e.args[1], called_funcs,
                                              "max_tokens_exceeded")

            # Obtain the model's response
            response = self.model_backend.run(openai_messages)

            if isinstance(response, ChatCompletion):
                output_messages, finish_reasons, usage_dict, response_id = (
                    self.handle_batch_response(response))
            else:
                output_messages, finish_reasons, usage_dict, response_id = (
                    self.handle_stream_response(response, num_tokens))

            # Guard against an empty choice list before peeking at the first
            # finish reason.
            if (self.is_function_calling_enabled() and finish_reasons
                    and finish_reasons[0] == 'function_call'
                    and isinstance(response, ChatCompletion)):
                # Do function calling
                func_assistant_msg, func_result_msg, func_record = (
                    self.step_function_call(response))

                # Update the messages so the next model query sees both the
                # call and its result.
                self.update_memory(func_assistant_msg,
                                   OpenAIBackendRole.ASSISTANT)
                self.update_memory(func_result_msg, OpenAIBackendRole.FUNCTION)

                # Record the function calling
                called_funcs.append(func_record)
            else:
                # Function calling disabled or not a function calling

                # Loop over responses terminators, get list of termination
                # tuples with whether the terminator terminates the agent
                # and termination reason
                termination = [
                    terminator.is_terminated(output_messages)
                    for terminator in self.response_terminators
                ]
                # Terminate the agent if any of the terminator terminates
                self.terminated, termination_reason = next(
                    ((terminated, termination_reason)
                     for terminated, termination_reason in termination
                     if terminated), (False, None))
                # For now only retain the first termination reason
                if self.terminated and termination_reason is not None:
                    finish_reasons = [termination_reason] * len(finish_reasons)

                info = self.get_info(
                    response_id,
                    usage_dict,
                    finish_reasons,
                    num_tokens,
                    called_funcs,
                )
                break

        return ChatAgentResponse(output_messages, self.terminated, info)

    def handle_batch_response(
        self, response: ChatCompletion
    ) -> Tuple[List[BaseMessage], List[str], Dict[str, int], str]:
        r"""Converts a non-streamed model response into output messages.

        Args:
            response (ChatCompletion): Model response.

        Returns:
            tuple: A tuple of list of output `ChatMessage`, list of
                finish reasons, usage dictionary, and response id.
        """
        output_messages: List[BaseMessage] = []
        for choice in response.choices:
            chat_message = BaseMessage(
                role_name=self.role_name,
                role_type=self.role_type,
                meta_dict=dict(),
                # Function-call responses carry no text content.
                content=choice.message.content or "",
            )
            output_messages.append(chat_message)
        finish_reasons = [
            str(choice.finish_reason) for choice in response.choices
        ]
        usage = (response.usage.model_dump()
                 if response.usage is not None else {})
        return (
            output_messages,
            finish_reasons,
            usage,
            response.id,
        )

    def handle_stream_response(
        self,
        response: Stream[ChatCompletionChunk],
        prompt_tokens: int,
    ) -> Tuple[List[BaseMessage], List[str], Dict[str, int], str]:
        r"""Collects a streamed model response into complete output messages.

        Content deltas are concatenated per choice index. A choice's finish
        reason comes from whichever chunk carries a non-`None`
        `finish_reason`, even if that chunk also carries content. One output
        message is built per choice index, ordered by index, so messages line
        up with the returned finish reasons. A choice that never reports a
        finish reason gets the empty string.

        Args:
            response (Stream[ChatCompletionChunk]): Streamed model response.
            prompt_tokens (int): Number of input prompt tokens.

        Returns:
            tuple: A tuple of list of output `ChatMessage`, list of
                finish reasons, usage dictionary, and response id.
        """
        content_dict: Dict[int, str] = defaultdict(str)
        finish_reasons_dict: Dict[int, str] = {}
        response_id: str = ""
        for chunk in response:
            response_id = chunk.id
            for choice in chunk.choices:
                index = choice.index
                # `content` is None for role-only, function-call and final
                # chunks; still register the index so every choice yields
                # exactly one message.
                content_dict[index] += choice.delta.content or ""
                if choice.finish_reason is not None:
                    finish_reasons_dict[index] = choice.finish_reason

        indices = sorted(content_dict)
        output_messages: List[BaseMessage] = [
            BaseMessage(role_name=self.role_name,
                        role_type=self.role_type,
                        meta_dict=dict(),
                        content=content_dict[index]) for index in indices
        ]
        finish_reasons = [finish_reasons_dict.get(index, "")
                          for index in indices]
        usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
        return output_messages, finish_reasons, usage_dict, response_id

    def step_token_exceed(self, num_tokens: int,
                          called_funcs: List[FunctionCallingRecord],
                          termination_reason: str) -> ChatAgentResponse:
        r"""Return trivial response containing number of tokens and information
        of called functions when the number of tokens exceeds.

        Also marks the agent as terminated.

        Args:
            num_tokens (int): Number of tokens in the messages.
            called_funcs (List[FunctionCallingRecord]): List of information
                objects of functions called in the current step.
            termination_reason (str): String of termination reason.

        Returns:
            ChatAgentResponse: The struct containing trivial outputs and
                information about token number and called functions.
        """
        self.terminated = True
        output_messages: List[BaseMessage] = []

        info = self.get_info(
            None,
            None,
            [termination_reason],
            num_tokens,
            called_funcs,
        )

        return ChatAgentResponse(
            output_messages,
            self.terminated,
            info,
        )

    @staticmethod
    def _parse_function_args(args_str: str) -> Dict[str, Any]:
        r"""Parses the JSON-encoded arguments of a model function call.

        Strict JSON parsing is tried first, so string values that
        legitimately contain apostrophes (e.g. ``"Xi'an"``) are preserved.
        Only if that fails are single quotes rewritten to double quotes and
        parsing retried, to tolerate models that emit Python-style quoting.
        An empty or whitespace-only string is treated as "no arguments".

        Args:
            args_str (str): The raw ``arguments`` string from the model.

        Returns:
            Dict[str, Any]: The parsed keyword arguments.

        Raises:
            json.JSONDecodeError: If neither parse attempt succeeds.
        """
        if not args_str or not args_str.strip():
            return {}
        try:
            return json.loads(args_str)
        except json.JSONDecodeError:
            return json.loads(args_str.replace("\'", "\""))

    def step_function_call(
        self,
        response: ChatCompletion,
    ) -> Tuple[FunctionCallingMessage, FunctionCallingMessage,
               FunctionCallingRecord]:
        r"""Execute the function with arguments following the model's response.

        Args:
            response (ChatCompletion): The response obtained by calling the
                model.

        Returns:
            tuple: A tuple consisting of two obj:`FunctionCallingMessage`,
                one about the arguments and the other about the execution
                result, and a struct for logging information about this
                function call.

        Raises:
            RuntimeError: If the first choice carries no function call.
            KeyError: If the model names a function not in the agent's
                function list.
            ValueError: If executing the function raises; the original
                exception is chained as ``__cause__``.
        """
        # Note that when function calling is enabled, `n` is set to 1.
        choice = response.choices[0]
        if choice.message.function_call is None:
            raise RuntimeError("Function call is None")
        func_name = choice.message.function_call.name
        func = self.func_dict[func_name]

        args_str: str = choice.message.function_call.arguments
        args = self._parse_function_args(args_str)

        # Pass the extracted arguments to the indicated function
        try:
            result = func(**args)
        except Exception as e:
            raise ValueError(
                f"Execution of function {func.__name__} failed with "
                f"arguments being {args}.") from e

        assist_msg = FunctionCallingMessage(
            role_name=self.role_name,
            role_type=self.role_type,
            meta_dict=None,
            content="",
            func_name=func_name,
            args=args,
        )
        func_msg = FunctionCallingMessage(
            role_name=self.role_name,
            role_type=self.role_type,
            meta_dict=None,
            content="",
            func_name=func_name,
            result=result,
        )

        # Record information about this function call
        func_record = FunctionCallingRecord(func_name, args, result)
        return assist_msg, func_msg, func_record

    def get_usage_dict(self, output_messages: List[BaseMessage],
                       prompt_tokens: int) -> Dict[str, int]:
        r"""Get usage dictionary when using the stream mode.

        Completion tokens are counted locally with the model's tiktoken
        encoding, since this method builds the usage figures itself instead
        of reading them from the response.

        Args:
            output_messages (list): List of output messages.
            prompt_tokens (int): Number of input prompt tokens.

        Returns:
            dict: Usage dictionary.
        """
        encoding = get_model_encoding(self.model_type.value_for_tiktoken)
        completion_tokens = 0
        for message in output_messages:
            completion_tokens += len(encoding.encode(message.content))
        usage_dict = dict(completion_tokens=completion_tokens,
                          prompt_tokens=prompt_tokens,
                          total_tokens=completion_tokens + prompt_tokens)
        return usage_dict

    def __repr__(self) -> str:
        r"""Returns a string representation of the :obj:`ChatAgent`.

        Returns:
            str: The string representation of the :obj:`ChatAgent`.
        """
        return (
            f"ChatAgent({self.role_name}, {self.role_type}, {self.model_type})"
        )
|