camel-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +1 -11
- camel/agents/__init__.py +5 -5
- camel/agents/chat_agent.py +124 -63
- camel/agents/critic_agent.py +28 -17
- camel/agents/deductive_reasoner_agent.py +235 -0
- camel/agents/embodied_agent.py +92 -40
- camel/agents/role_assignment_agent.py +27 -17
- camel/agents/task_agent.py +60 -34
- camel/agents/tool_agents/base.py +0 -1
- camel/agents/tool_agents/hugging_face_tool_agent.py +7 -4
- camel/configs.py +119 -7
- camel/embeddings/__init__.py +2 -0
- camel/embeddings/base.py +3 -2
- camel/embeddings/openai_embedding.py +3 -3
- camel/embeddings/sentence_transformers_embeddings.py +65 -0
- camel/functions/__init__.py +13 -3
- camel/functions/google_maps_function.py +335 -0
- camel/functions/math_functions.py +7 -7
- camel/functions/openai_function.py +344 -42
- camel/functions/search_functions.py +100 -35
- camel/functions/twitter_function.py +484 -0
- camel/functions/weather_functions.py +36 -23
- camel/generators.py +65 -46
- camel/human.py +17 -11
- camel/interpreters/__init__.py +25 -0
- camel/interpreters/base.py +49 -0
- camel/{utils/python_interpreter.py → interpreters/internal_python_interpreter.py} +129 -48
- camel/interpreters/interpreter_error.py +19 -0
- camel/interpreters/subprocess_interpreter.py +190 -0
- camel/loaders/__init__.py +22 -0
- camel/{functions/base_io_functions.py → loaders/base_io.py} +38 -35
- camel/{functions/unstructured_io_fuctions.py → loaders/unstructured_io.py} +199 -110
- camel/memories/__init__.py +17 -7
- camel/memories/agent_memories.py +156 -0
- camel/memories/base.py +97 -32
- camel/memories/blocks/__init__.py +21 -0
- camel/memories/{chat_history_memory.py → blocks/chat_history_block.py} +34 -34
- camel/memories/blocks/vectordb_block.py +101 -0
- camel/memories/context_creators/__init__.py +3 -2
- camel/memories/context_creators/score_based.py +32 -20
- camel/memories/records.py +6 -5
- camel/messages/__init__.py +2 -2
- camel/messages/base.py +99 -16
- camel/messages/func_message.py +7 -4
- camel/models/__init__.py +4 -2
- camel/models/anthropic_model.py +132 -0
- camel/models/base_model.py +3 -2
- camel/models/model_factory.py +10 -8
- camel/models/open_source_model.py +25 -13
- camel/models/openai_model.py +9 -10
- camel/models/stub_model.py +6 -5
- camel/prompts/__init__.py +7 -5
- camel/prompts/ai_society.py +21 -14
- camel/prompts/base.py +54 -47
- camel/prompts/code.py +22 -14
- camel/prompts/evaluation.py +8 -5
- camel/prompts/misalignment.py +26 -19
- camel/prompts/object_recognition.py +35 -0
- camel/prompts/prompt_templates.py +14 -8
- camel/prompts/role_description_prompt_template.py +16 -10
- camel/prompts/solution_extraction.py +9 -5
- camel/prompts/task_prompt_template.py +24 -21
- camel/prompts/translation.py +9 -5
- camel/responses/agent_responses.py +5 -2
- camel/retrievers/__init__.py +24 -0
- camel/retrievers/auto_retriever.py +319 -0
- camel/retrievers/base.py +64 -0
- camel/retrievers/bm25_retriever.py +149 -0
- camel/retrievers/vector_retriever.py +166 -0
- camel/societies/__init__.py +1 -1
- camel/societies/babyagi_playing.py +56 -32
- camel/societies/role_playing.py +188 -133
- camel/storages/__init__.py +18 -0
- camel/storages/graph_storages/__init__.py +23 -0
- camel/storages/graph_storages/base.py +82 -0
- camel/storages/graph_storages/graph_element.py +74 -0
- camel/storages/graph_storages/neo4j_graph.py +582 -0
- camel/storages/key_value_storages/base.py +1 -2
- camel/storages/key_value_storages/in_memory.py +1 -2
- camel/storages/key_value_storages/json.py +8 -13
- camel/storages/vectordb_storages/__init__.py +33 -0
- camel/storages/vectordb_storages/base.py +202 -0
- camel/storages/vectordb_storages/milvus.py +396 -0
- camel/storages/vectordb_storages/qdrant.py +371 -0
- camel/terminators/__init__.py +1 -1
- camel/terminators/base.py +2 -3
- camel/terminators/response_terminator.py +21 -12
- camel/terminators/token_limit_terminator.py +5 -3
- camel/types/__init__.py +12 -6
- camel/types/enums.py +86 -13
- camel/types/openai_types.py +10 -5
- camel/utils/__init__.py +18 -13
- camel/utils/commons.py +242 -81
- camel/utils/token_counting.py +135 -15
- {camel_ai-0.1.1.dist-info → camel_ai-0.1.3.dist-info}/METADATA +116 -74
- camel_ai-0.1.3.dist-info/RECORD +101 -0
- {camel_ai-0.1.1.dist-info → camel_ai-0.1.3.dist-info}/WHEEL +1 -1
- camel/memories/context_creators/base.py +0 -72
- camel_ai-0.1.1.dist-info/RECORD +0 -75
|
@@ -0,0 +1,371 @@
|
|
|
1
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the “License”);
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an “AS IS” BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
|
+
from dataclasses import asdict
|
|
15
|
+
from datetime import datetime
|
|
16
|
+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
|
17
|
+
|
|
18
|
+
from camel.storages.vectordb_storages import (
|
|
19
|
+
BaseVectorStorage,
|
|
20
|
+
VectorDBQuery,
|
|
21
|
+
VectorDBQueryResult,
|
|
22
|
+
VectorDBStatus,
|
|
23
|
+
VectorRecord,
|
|
24
|
+
)
|
|
25
|
+
from camel.types import VectorDistance
|
|
26
|
+
|
|
27
|
+
# Module-level registry of local (on-disk) Qdrant clients, keyed by storage
# path. Each value is a `(client, reference_count)` tuple: `QdrantStorage`
# reuses the stored client when the same path is opened again (incrementing
# the count) and removes the entry once the count would drop to zero.
_qdrant_local_client_map: Dict[str, Tuple[Any, int]] = {}
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class QdrantStorage(BaseVectorStorage):
    r"""An implementation of the `BaseVectorStorage` for interacting with
    Qdrant, a vector search engine.

    The detailed information about Qdrant is available at:
    `Qdrant <https://qdrant.tech/>`_

    Args:
        vector_dim (int): The dimension of storing vectors.
        collection_name (Optional[str], optional): Name for the collection in
            the Qdrant. If not provided, set it to the current time with iso
            format. (default: :obj:`None`)
        url_and_api_key (Optional[Tuple[str, str]], optional): Tuple containing
            the URL and API key for connecting to a remote Qdrant instance.
            (default: :obj:`None`)
        path (Optional[str], optional): Path to a directory for initializing a
            local Qdrant client. (default: :obj:`None`)
        distance (VectorDistance, optional): The distance metric for vector
            comparison (default: :obj:`VectorDistance.COSINE`)
        delete_collection_on_del (bool, optional): Flag to determine if the
            collection should be deleted upon object destruction.
            (default: :obj:`False`)
        **kwargs (Any): Additional keyword arguments for initializing
            `QdrantClient`.

    Notes:
        - If `url_and_api_key` is provided, it takes priority and the client
          will attempt to connect to the remote Qdrant instance using the URL
          endpoint.
        - If `url_and_api_key` is not provided and `path` is given, the client
          will use the local path to initialize Qdrant.
        - If neither `url_and_api_key` nor `path` is provided, the client will
          be initialized with an in-memory storage (`":memory:"`).
    """

    def __init__(
        self,
        vector_dim: int,
        collection_name: Optional[str] = None,
        url_and_api_key: Optional[Tuple[str, str]] = None,
        path: Optional[str] = None,
        distance: VectorDistance = VectorDistance.COSINE,
        delete_collection_on_del: bool = False,
        **kwargs: Any,
    ) -> None:
        # Set the attributes `__del__` relies on before anything can raise,
        # so a partially constructed instance is torn down cleanly.
        self._local_path: Optional[str] = None
        self.delete_collection_on_del = False

        try:
            from qdrant_client import QdrantClient
        except ImportError as exc:
            raise ImportError(
                "Please install `qdrant-client` first. You can install it by "
                "running `pip install qdrant-client`."
            ) from exc

        self._client: QdrantClient
        self._create_client(url_and_api_key, path, **kwargs)

        self.vector_dim = vector_dim
        self.distance = distance
        self.collection_name = (
            collection_name or self._generate_collection_name()
        )

        self._check_and_create_collection()

        # Only honour the flag once the collection is known to be valid: a
        # dimension mismatch above must never cause the user's existing
        # collection to be deleted when this instance is garbage-collected.
        self.delete_collection_on_del = delete_collection_on_del

    def __del__(self):
        r"""Releases this instance's reference to a shared local client and
        deletes the collection if :obj:`delete_collection_on_del` is
        :obj:`True`.
        """
        local_path = getattr(self, "_local_path", None)
        if local_path is not None:
            # Decrease the reference count of the shared local client and
            # drop the entry entirely once no instance uses it any more.
            entry = _qdrant_local_client_map.pop(local_path, None)
            if entry is not None:
                client, count = entry
                if count > 1:
                    _qdrant_local_client_map[local_path] = (client, count - 1)
            # Guard against releasing the same reference twice.
            self._local_path = None

        if getattr(self, "delete_collection_on_del", False):
            self._delete_collection(self.collection_name)

    def _create_client(
        self,
        url_and_api_key: Optional[Tuple[str, str]],
        path: Optional[str],
        **kwargs: Any,
    ) -> None:
        r"""Creates (or reuses) the underlying `QdrantClient`.

        Args:
            url_and_api_key (Optional[Tuple[str, str]]): URL and API key of a
                remote Qdrant instance; takes priority when provided.
            path (Optional[str]): Directory for a local Qdrant client.
            **kwargs (Any): Extra keyword arguments for `QdrantClient`.
        """
        from qdrant_client import QdrantClient

        if url_and_api_key is not None:
            self._client = QdrantClient(
                url=url_and_api_key[0],
                api_key=url_and_api_key[1],
                **kwargs,
            )
        elif path is not None:
            # Avoid creating a local client multiple times for the same path,
            # which is prohibited by Qdrant: share one client per path and
            # keep a reference count in the module-level map.
            if path in _qdrant_local_client_map:
                self._client, count = _qdrant_local_client_map[path]
                _qdrant_local_client_map[path] = (self._client, count + 1)
            else:
                self._client = QdrantClient(path=path, **kwargs)
                _qdrant_local_client_map[path] = (self._client, 1)
            # Record the path only after this instance actually holds a
            # reference, so `__del__` never releases one it did not take
            # (e.g. when `QdrantClient(path=...)` raises).
            self._local_path = path
        else:
            self._client = QdrantClient(":memory:", **kwargs)

    def _check_and_create_collection(self) -> None:
        r"""Validates the dimension of an existing collection, or creates the
        collection if it does not exist yet.

        Raises:
            ValueError: If the existing collection's vector dimension differs
                from :obj:`vector_dim`.
        """
        if self._collection_exists(self.collection_name):
            in_dim = self._get_collection_info(self.collection_name)[
                "vector_dim"
            ]
            if in_dim != self.vector_dim:
                # The name of collection has to be confirmed by the user
                raise ValueError(
                    "Vector dimension of the existing collection "
                    f'"{self.collection_name}" ({in_dim}) is different from '
                    f"the given embedding dim ({self.vector_dim})."
                )
        else:
            self._create_collection(
                collection_name=self.collection_name,
                size=self.vector_dim,
                distance=self.distance,
            )

    def _create_collection(
        self,
        collection_name: str,
        size: int,
        distance: VectorDistance = VectorDistance.COSINE,
        **kwargs: Any,
    ) -> None:
        r"""Creates a new collection in the database.

        Uses `recreate_collection`, so an existing collection with the same
        name is replaced.

        Args:
            collection_name (str): Name of the collection to be created.
            size (int): Dimensionality of vectors to be stored in this
                collection.
            distance (VectorDistance, optional): The distance metric to be used
                for vector similarity. (default: :obj:`VectorDistance.COSINE`)
            **kwargs (Any): Additional keyword arguments.
        """
        from qdrant_client.http.models import Distance, VectorParams

        distance_map = {
            VectorDistance.DOT: Distance.DOT,
            VectorDistance.COSINE: Distance.COSINE,
            VectorDistance.EUCLIDEAN: Distance.EUCLID,
        }
        self._client.recreate_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(
                size=size,
                distance=distance_map[distance],
            ),
            **kwargs,
        )

    def _delete_collection(
        self,
        collection_name: str,
        **kwargs: Any,
    ) -> None:
        r"""Deletes an existing collection from the database.

        Args:
            collection_name (str): Name of the collection to be deleted.
            **kwargs (Any): Additional keyword arguments.
        """
        self._client.delete_collection(
            collection_name=collection_name, **kwargs
        )

    def _collection_exists(self, collection_name: str) -> bool:
        r"""Returns whether the collection exists in the database."""
        return any(
            c.name == collection_name
            for c in self._client.get_collections().collections
        )

    def _generate_collection_name(self) -> str:
        r"""Generates a collection name if the user doesn't provide one."""
        return datetime.now().isoformat()

    def _get_collection_info(self, collection_name: str) -> Dict[str, Any]:
        r"""Retrieves details of an existing collection.

        Args:
            collection_name (str): Name of the collection to be checked.

        Returns:
            Dict[str, Any]: A dictionary containing details about the
                collection. `vector_dim` is :obj:`None` when the collection
                is not configured with a single unnamed `VectorParams`.
        """
        from qdrant_client.http.models import VectorParams

        # TODO: check more information
        collection_info = self._client.get_collection(
            collection_name=collection_name
        )
        vector_config = collection_info.config.params.vectors
        return {
            "vector_dim": vector_config.size
            if isinstance(vector_config, VectorParams)
            else None,
            "vector_count": collection_info.points_count,
            "status": collection_info.status,
            "vectors_count": collection_info.vectors_count,
            "config": collection_info.config,
        }

    def add(
        self,
        records: List[VectorRecord],
        **kwargs,
    ) -> None:
        r"""Adds a list of vectors to the specified collection.

        Args:
            records (List[VectorRecord]): List of vector records to be added.
            **kwargs (Any): Additional keyword arguments.

        Raises:
            RuntimeError: If there was an error in the addition process.
        """
        from qdrant_client.http.models import PointStruct, UpdateStatus

        qdrant_points = [PointStruct(**asdict(p)) for p in records]
        op_info = self._client.upsert(
            collection_name=self.collection_name,
            points=qdrant_points,
            wait=True,
            **kwargs,
        )
        if op_info.status != UpdateStatus.COMPLETED:
            raise RuntimeError(
                "Failed to add vectors in Qdrant, operation info: "
                f"{op_info}."
            )

    def delete(
        self,
        ids: List[str],
        **kwargs: Any,
    ) -> None:
        r"""Deletes a list of vectors identified by their IDs from the storage.

        Args:
            ids (List[str]): List of unique identifiers for the vectors to be
                deleted.
            **kwargs (Any): Additional keyword arguments.

        Raises:
            RuntimeError: If there is an error during the deletion process.
        """
        from qdrant_client.http.models import PointIdsList, UpdateStatus

        points = cast(List[Union[str, int]], ids)
        op_info = self._client.delete(
            collection_name=self.collection_name,
            points_selector=PointIdsList(points=points),
            wait=True,
            **kwargs,
        )
        if op_info.status != UpdateStatus.COMPLETED:
            raise RuntimeError(
                "Failed to delete vectors in Qdrant, operation info: "
                f"{op_info}"
            )

    def status(self) -> VectorDBStatus:
        r"""Returns the vector dimension and vector count of the collection."""
        status = self._get_collection_info(self.collection_name)
        return VectorDBStatus(
            vector_dim=status["vector_dim"],
            vector_count=status["vector_count"],
        )

    def query(
        self,
        query: VectorDBQuery,
        **kwargs: Any,
    ) -> List[VectorDBQueryResult]:
        r"""Searches for similar vectors in the storage based on the provided
        query.

        Args:
            query (VectorDBQuery): The query object containing the search
                vector and the number of top similar vectors to retrieve.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            List[VectorDBQueryResult]: A list of vectors retrieved from the
                storage based on similarity to the query vector.
        """
        # TODO: filter
        search_result = self._client.search(
            collection_name=self.collection_name,
            query_vector=query.query_vector,
            with_payload=True,
            with_vectors=True,
            limit=query.top_k,
            **kwargs,
        )
        query_results = []
        for point in search_result:
            query_results.append(
                VectorDBQueryResult.construct(
                    similarity=point.score,
                    id=str(point.id),
                    payload=point.payload,
                    vector=point.vector,  # type: ignore[arg-type]
                )
            )

        return query_results

    def clear(self) -> None:
        r"""Remove all vectors from the storage."""
        self._delete_collection(self.collection_name)
        self._create_collection(
            collection_name=self.collection_name,
            size=self.vector_dim,
            distance=self.distance,
        )

    def load(self) -> None:
        r"""Load the collection hosted on cloud service."""
        pass

    @property
    def client(self) -> Any:
        r"""Provides access to the underlying vector database client."""
        return self._client
|
camel/terminators/__init__.py
CHANGED
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
14
|
from .base import BaseTerminator
|
|
15
|
-
from .response_terminator import
|
|
15
|
+
from .response_terminator import ResponseTerminator, ResponseWordsTerminator
|
|
16
16
|
from .token_limit_terminator import TokenLimitTerminator
|
|
17
17
|
|
|
18
18
|
__all__ = [
|
camel/terminators/base.py
CHANGED
|
@@ -18,7 +18,6 @@ from camel.messages import BaseMessage
|
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class BaseTerminator(ABC):
|
|
21
|
-
|
|
22
21
|
def __init__(self, *args, **kwargs) -> None:
|
|
23
22
|
self._terminated: bool = False
|
|
24
23
|
self._termination_reason: Optional[str] = None
|
|
@@ -33,10 +32,10 @@ class BaseTerminator(ABC):
|
|
|
33
32
|
|
|
34
33
|
|
|
35
34
|
class ResponseTerminator(BaseTerminator):
|
|
36
|
-
|
|
37
35
|
@abstractmethod
|
|
38
36
|
def is_terminated(
|
|
39
|
-
|
|
37
|
+
self, messages: List[BaseMessage]
|
|
38
|
+
) -> Tuple[bool, Optional[str]]:
|
|
40
39
|
pass
|
|
41
40
|
|
|
42
41
|
@abstractmethod
|
|
@@ -34,9 +34,12 @@ class ResponseWordsTerminator(ResponseTerminator):
|
|
|
34
34
|
(default: :obj:`TerminationMode.ANY`)
|
|
35
35
|
"""
|
|
36
36
|
|
|
37
|
-
def __init__(
|
|
38
|
-
|
|
39
|
-
|
|
37
|
+
def __init__(
|
|
38
|
+
self,
|
|
39
|
+
words_dict: Dict[str, int],
|
|
40
|
+
case_sensitive: bool = False,
|
|
41
|
+
mode: TerminationMode = TerminationMode.ANY,
|
|
42
|
+
):
|
|
40
43
|
super().__init__()
|
|
41
44
|
self.words_dict = words_dict
|
|
42
45
|
self.case_sensitive = case_sensitive
|
|
@@ -50,11 +53,14 @@ class ResponseWordsTerminator(ResponseTerminator):
|
|
|
50
53
|
for word in self.words_dict:
|
|
51
54
|
threshold = self.words_dict[word]
|
|
52
55
|
if threshold <= 0:
|
|
53
|
-
raise ValueError(
|
|
54
|
-
|
|
56
|
+
raise ValueError(
|
|
57
|
+
f"Threshold for word `{word}` should "
|
|
58
|
+
f"be larger than 0, got `{threshold}`"
|
|
59
|
+
)
|
|
55
60
|
|
|
56
61
|
def is_terminated(
|
|
57
|
-
|
|
62
|
+
self, messages: List[BaseMessage]
|
|
63
|
+
) -> Tuple[bool, Optional[str]]:
|
|
58
64
|
r"""Whether terminate the agent by checking the occurrence
|
|
59
65
|
of specified words reached to preset thresholds.
|
|
60
66
|
|
|
@@ -90,10 +96,12 @@ class ResponseWordsTerminator(ResponseTerminator):
|
|
|
90
96
|
for word, value in self._word_count_dict[i].items():
|
|
91
97
|
if value >= self.words_dict[word]:
|
|
92
98
|
reached += 1
|
|
93
|
-
reason = (
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
99
|
+
reason = (
|
|
100
|
+
f"Word `{word}` appears {value} times in the "
|
|
101
|
+
f"{i + 1} message of the response which has "
|
|
102
|
+
f"reached termination threshold "
|
|
103
|
+
f"{self.words_dict[word]}."
|
|
104
|
+
)
|
|
97
105
|
reasons.append(reason)
|
|
98
106
|
all_reasons.append(reasons)
|
|
99
107
|
num_reached.append(reached)
|
|
@@ -108,8 +116,9 @@ class ResponseWordsTerminator(ResponseTerminator):
|
|
|
108
116
|
self._terminated = True
|
|
109
117
|
self._termination_reason = "\n".join(all_reasons[i])
|
|
110
118
|
else:
|
|
111
|
-
raise ValueError(
|
|
112
|
-
|
|
119
|
+
raise ValueError(
|
|
120
|
+
f"Unsupported termination mode " f"`{self.mode}`"
|
|
121
|
+
)
|
|
113
122
|
return self._terminated, self._termination_reason
|
|
114
123
|
|
|
115
124
|
def reset(self):
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
14
|
from typing import Optional, Tuple
|
|
15
15
|
|
|
16
|
-
from camel.terminators import BaseTerminator
|
|
16
|
+
from camel.terminators.base import BaseTerminator
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
class TokenLimitTerminator(BaseTerminator):
|
|
@@ -29,8 +29,10 @@ class TokenLimitTerminator(BaseTerminator):
|
|
|
29
29
|
|
|
30
30
|
def _validate(self):
|
|
31
31
|
if self.token_limit <= 0:
|
|
32
|
-
raise ValueError(
|
|
33
|
-
|
|
32
|
+
raise ValueError(
|
|
33
|
+
f"`token_limit` should be a "
|
|
34
|
+
f"value larger than 0, got {self.token_limit}."
|
|
35
|
+
)
|
|
34
36
|
|
|
35
37
|
def is_terminated(self, num_tokens: int) -> Tuple[bool, Optional[str]]:
|
|
36
38
|
r"""Whether terminate the agent by checking number of
|
camel/types/__init__.py
CHANGED
|
@@ -12,24 +12,27 @@
|
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
14
|
from .enums import (
|
|
15
|
-
|
|
15
|
+
EmbeddingModelType,
|
|
16
16
|
ModelType,
|
|
17
|
+
OpenAIBackendRole,
|
|
18
|
+
OpenAIImageDetailType,
|
|
19
|
+
OpenAIImageType,
|
|
20
|
+
RoleType,
|
|
21
|
+
StorageType,
|
|
17
22
|
TaskType,
|
|
18
23
|
TerminationMode,
|
|
19
|
-
OpenAIBackendRole,
|
|
20
|
-
EmbeddingModelType,
|
|
21
24
|
VectorDistance,
|
|
22
25
|
)
|
|
23
26
|
from .openai_types import (
|
|
24
|
-
Choice,
|
|
25
27
|
ChatCompletion,
|
|
28
|
+
ChatCompletionAssistantMessageParam,
|
|
26
29
|
ChatCompletionChunk,
|
|
30
|
+
ChatCompletionFunctionMessageParam,
|
|
27
31
|
ChatCompletionMessage,
|
|
28
32
|
ChatCompletionMessageParam,
|
|
29
33
|
ChatCompletionSystemMessageParam,
|
|
30
34
|
ChatCompletionUserMessageParam,
|
|
31
|
-
|
|
32
|
-
ChatCompletionFunctionMessageParam,
|
|
35
|
+
Choice,
|
|
33
36
|
CompletionUsage,
|
|
34
37
|
)
|
|
35
38
|
|
|
@@ -41,6 +44,7 @@ __all__ = [
|
|
|
41
44
|
'OpenAIBackendRole',
|
|
42
45
|
'EmbeddingModelType',
|
|
43
46
|
'VectorDistance',
|
|
47
|
+
'StorageType',
|
|
44
48
|
'Choice',
|
|
45
49
|
'ChatCompletion',
|
|
46
50
|
'ChatCompletionChunk',
|
|
@@ -51,4 +55,6 @@ __all__ = [
|
|
|
51
55
|
'ChatCompletionAssistantMessageParam',
|
|
52
56
|
'ChatCompletionFunctionMessageParam',
|
|
53
57
|
'CompletionUsage',
|
|
58
|
+
'OpenAIImageType',
|
|
59
|
+
'OpenAIImageDetailType',
|
|
54
60
|
]
|