camel-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of camel-ai might be problematic.
- camel/__init__.py +1 -11
- camel/agents/__init__.py +5 -5
- camel/agents/chat_agent.py +124 -63
- camel/agents/critic_agent.py +28 -17
- camel/agents/deductive_reasoner_agent.py +235 -0
- camel/agents/embodied_agent.py +92 -40
- camel/agents/role_assignment_agent.py +27 -17
- camel/agents/task_agent.py +60 -34
- camel/agents/tool_agents/base.py +0 -1
- camel/agents/tool_agents/hugging_face_tool_agent.py +7 -4
- camel/configs.py +119 -7
- camel/embeddings/__init__.py +2 -0
- camel/embeddings/base.py +3 -2
- camel/embeddings/openai_embedding.py +3 -3
- camel/embeddings/sentence_transformers_embeddings.py +65 -0
- camel/functions/__init__.py +13 -3
- camel/functions/google_maps_function.py +335 -0
- camel/functions/math_functions.py +7 -7
- camel/functions/openai_function.py +344 -42
- camel/functions/search_functions.py +100 -35
- camel/functions/twitter_function.py +484 -0
- camel/functions/weather_functions.py +36 -23
- camel/generators.py +65 -46
- camel/human.py +17 -11
- camel/interpreters/__init__.py +25 -0
- camel/interpreters/base.py +49 -0
- camel/{utils/python_interpreter.py → interpreters/internal_python_interpreter.py} +129 -48
- camel/interpreters/interpreter_error.py +19 -0
- camel/interpreters/subprocess_interpreter.py +190 -0
- camel/loaders/__init__.py +22 -0
- camel/{functions/base_io_functions.py → loaders/base_io.py} +38 -35
- camel/{functions/unstructured_io_fuctions.py → loaders/unstructured_io.py} +199 -110
- camel/memories/__init__.py +17 -7
- camel/memories/agent_memories.py +156 -0
- camel/memories/base.py +97 -32
- camel/memories/blocks/__init__.py +21 -0
- camel/memories/{chat_history_memory.py → blocks/chat_history_block.py} +34 -34
- camel/memories/blocks/vectordb_block.py +101 -0
- camel/memories/context_creators/__init__.py +3 -2
- camel/memories/context_creators/score_based.py +32 -20
- camel/memories/records.py +6 -5
- camel/messages/__init__.py +2 -2
- camel/messages/base.py +99 -16
- camel/messages/func_message.py +7 -4
- camel/models/__init__.py +4 -2
- camel/models/anthropic_model.py +132 -0
- camel/models/base_model.py +3 -2
- camel/models/model_factory.py +10 -8
- camel/models/open_source_model.py +25 -13
- camel/models/openai_model.py +9 -10
- camel/models/stub_model.py +6 -5
- camel/prompts/__init__.py +7 -5
- camel/prompts/ai_society.py +21 -14
- camel/prompts/base.py +54 -47
- camel/prompts/code.py +22 -14
- camel/prompts/evaluation.py +8 -5
- camel/prompts/misalignment.py +26 -19
- camel/prompts/object_recognition.py +35 -0
- camel/prompts/prompt_templates.py +14 -8
- camel/prompts/role_description_prompt_template.py +16 -10
- camel/prompts/solution_extraction.py +9 -5
- camel/prompts/task_prompt_template.py +24 -21
- camel/prompts/translation.py +9 -5
- camel/responses/agent_responses.py +5 -2
- camel/retrievers/__init__.py +24 -0
- camel/retrievers/auto_retriever.py +319 -0
- camel/retrievers/base.py +64 -0
- camel/retrievers/bm25_retriever.py +149 -0
- camel/retrievers/vector_retriever.py +166 -0
- camel/societies/__init__.py +1 -1
- camel/societies/babyagi_playing.py +56 -32
- camel/societies/role_playing.py +188 -133
- camel/storages/__init__.py +18 -0
- camel/storages/graph_storages/__init__.py +23 -0
- camel/storages/graph_storages/base.py +82 -0
- camel/storages/graph_storages/graph_element.py +74 -0
- camel/storages/graph_storages/neo4j_graph.py +582 -0
- camel/storages/key_value_storages/base.py +1 -2
- camel/storages/key_value_storages/in_memory.py +1 -2
- camel/storages/key_value_storages/json.py +8 -13
- camel/storages/vectordb_storages/__init__.py +33 -0
- camel/storages/vectordb_storages/base.py +202 -0
- camel/storages/vectordb_storages/milvus.py +396 -0
- camel/storages/vectordb_storages/qdrant.py +371 -0
- camel/terminators/__init__.py +1 -1
- camel/terminators/base.py +2 -3
- camel/terminators/response_terminator.py +21 -12
- camel/terminators/token_limit_terminator.py +5 -3
- camel/types/__init__.py +12 -6
- camel/types/enums.py +86 -13
- camel/types/openai_types.py +10 -5
- camel/utils/__init__.py +18 -13
- camel/utils/commons.py +242 -81
- camel/utils/token_counting.py +135 -15
- {camel_ai-0.1.1.dist-info → camel_ai-0.1.3.dist-info}/METADATA +116 -74
- camel_ai-0.1.3.dist-info/RECORD +101 -0
- {camel_ai-0.1.1.dist-info → camel_ai-0.1.3.dist-info}/WHEEL +1 -1
- camel/memories/context_creators/base.py +0 -72
- camel_ai-0.1.1.dist-info/RECORD +0 -75
camel/responses/agent_responses.py

@@ -30,6 +30,7 @@ class ChatAgentResponse:
             to terminate the chat session.
         info (Dict[str, Any]): Extra information about the chat message.
     """
+
     msgs: List[BaseMessage]
     terminated: bool
     info: Dict[str, Any]

@@ -37,6 +38,8 @@ class ChatAgentResponse:
     @property
     def msg(self):
         if len(self.msgs) != 1:
-            raise RuntimeError(
-
+            raise RuntimeError(
+                "Property msg is only available "
+                "for a single message in msgs."
+            )
         return self.msgs[0]
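For orientation, a minimal sketch (not part of the diff) of how the reworked `msg` property behaves; the import paths follow the file list above, and the direct `BaseMessage` construction here is an assumption made for illustration:

from camel.messages import BaseMessage
from camel.responses import ChatAgentResponse
from camel.types import RoleType

# Hypothetical single-message response: .msg returns the lone message.
message = BaseMessage(
    role_name="Assistant",
    role_type=RoleType.ASSISTANT,
    meta_dict=None,
    content="Hello!",
)
response = ChatAgentResponse(msgs=[message], terminated=False, info={})
print(response.msg.content)  # "Hello!"

# With zero (or several) messages, the new multi-line RuntimeError fires.
empty = ChatAgentResponse(msgs=[], terminated=True, info={})
try:
    _ = empty.msg
except RuntimeError as e:
    print(e)  # Property msg is only available for a single message in msgs.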
camel/retrievers/__init__.py
ADDED

@@ -0,0 +1,24 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .auto_retriever import AutoRetriever
from .base import BaseRetriever
from .bm25_retriever import BM25Retriever
from .vector_retriever import VectorRetriever

__all__ = [
    'BaseRetriever',
    'VectorRetriever',
    'AutoRetriever',
    'BM25Retriever',
]
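With this `__init__.py` in place, all four retrievers are importable from the package root; a quick sanity check (a sketch, assuming the wheel and its optional vector-store dependencies are installed):

from camel.retrievers import (
    AutoRetriever,
    BaseRetriever,
    BM25Retriever,
    VectorRetriever,
)

# The names available here mirror __all__ in the new module above.
print([cls.__name__ for cls in (BaseRetriever, VectorRetriever,
                                AutoRetriever, BM25Retriever)])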
camel/retrievers/auto_retriever.py
ADDED

@@ -0,0 +1,319 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import datetime
import os
import re
from pathlib import Path
from typing import List, Optional, Tuple, Union
from urllib.parse import urlparse

from camel.embeddings import BaseEmbedding, OpenAIEmbedding
from camel.retrievers.vector_retriever import VectorRetriever
from camel.storages import (
    BaseVectorStorage,
    MilvusStorage,
    QdrantStorage,
    VectorDBQuery,
)
from camel.types import StorageType

DEFAULT_TOP_K_RESULTS = 1
DEFAULT_SIMILARITY_THRESHOLD = 0.75


class AutoRetriever:
    r"""Facilitates the automatic retrieval of information using a
    query-based approach with pre-defined elements.

    Attributes:
        url_and_api_key (Optional[Tuple[str, str]]): URL and API key for
            accessing the vector storage remotely.
        vector_storage_local_path (Optional[str]): Local path for vector
            storage, if applicable.
        storage_type (Optional[StorageType]): The type of vector storage to
            use. Defaults to `StorageType.MILVUS`.
        embedding_model (Optional[BaseEmbedding]): Model used for embedding
            queries and documents. Defaults to `OpenAIEmbedding()`.
    """

    def __init__(
        self,
        url_and_api_key: Optional[Tuple[str, str]] = None,
        vector_storage_local_path: Optional[str] = None,
        storage_type: Optional[StorageType] = None,
        embedding_model: Optional[BaseEmbedding] = None,
    ):
        self.storage_type = storage_type or StorageType.MILVUS
        self.embedding_model = embedding_model or OpenAIEmbedding()
        self.vector_storage_local_path = vector_storage_local_path
        self.url_and_api_key = url_and_api_key

    def _initialize_vector_storage(
        self,
        collection_name: Optional[str] = None,
    ) -> BaseVectorStorage:
        r"""Sets up and returns a vector storage instance with specified parameters.

        Args:
            collection_name (Optional[str]): Name of the collection in the
                vector storage.

        Returns:
            BaseVectorStorage: Configured vector storage instance.
        """
        if self.storage_type == StorageType.MILVUS:
            if self.url_and_api_key is None:
                raise ValueError(
                    "URL and API key required for Milvus storage are not"
                    " provided."
                )
            return MilvusStorage(
                vector_dim=self.embedding_model.get_output_dim(),
                collection_name=collection_name,
                url_and_api_key=self.url_and_api_key,
            )

        if self.storage_type == StorageType.QDRANT:
            return QdrantStorage(
                vector_dim=self.embedding_model.get_output_dim(),
                collection_name=collection_name,
                path=self.vector_storage_local_path,
                url_and_api_key=self.url_and_api_key,
            )

        raise ValueError(
            f"Unsupported vector storage type: {self.storage_type}"
        )

    def _collection_name_generator(self, content_input_path: str) -> str:
        r"""Generates a valid collection name from a given file path or URL.

        Args:
            content_input_path (str): The input URL or file path from which
                to generate the collection name.

        Returns:
            str: A sanitized, valid collection name suitable for use.
        """
        # Check path type
        parsed_url = urlparse(content_input_path)
        self.is_url = all([parsed_url.scheme, parsed_url.netloc])

        # Convert given path into a collection name, ensuring it only
        # contains numbers, letters, and underscores
        if self.is_url:
            # For URLs, remove https://, replace /, and any characters not
            # allowed by Milvus with _
            collection_name = re.sub(
                r'[^0-9a-zA-Z]+',
                '_',
                content_input_path.replace("https://", ""),
            )
        else:
            # For file paths, get the stem and replace spaces with _, also
            # ensuring only allowed characters are present
            collection_name = re.sub(
                r'[^0-9a-zA-Z]+', '_', Path(content_input_path).stem
            )

        # Ensure the collection name does not start or end with underscore
        collection_name = collection_name.strip("_")
        # Limit the maximum length of the collection name to 30 characters
        collection_name = collection_name[:30]
        return collection_name

    def _get_file_modified_date_from_file(self, content_input_path: str) -> str:
        r"""Retrieves the last modified date and time of a given file. This
        function takes a file path as input and returns the last modified
        date and time of that file.

        Args:
            content_input_path (str): The file path of the content whose
                modified date is to be retrieved.

        Returns:
            str: The last modified time from file.
        """
        mod_time = os.path.getmtime(content_input_path)
        readable_mod_time = datetime.datetime.fromtimestamp(mod_time).isoformat(
            timespec='seconds'
        )
        return readable_mod_time

    def _get_file_modified_date_from_storage(
        self, vector_storage_instance: BaseVectorStorage
    ) -> str:
        r"""Retrieves the last modified date and time of a given file. This
        function takes a vector storage instance as input and returns the
        last modified date from its metadata.

        Args:
            vector_storage_instance (BaseVectorStorage): The vector storage
                where the modified date is to be retrieved from metadata.

        Returns:
            str: The last modified date from vector storage.
        """

        # Insert any query to get modified date from vector db
        # NOTE: Can be optimized when CAMEL vector storage supports
        # direct chunk payload extraction
        query_vector_any = self.embedding_model.embed(obj="any_query")
        query_any = VectorDBQuery(query_vector_any, top_k=1)
        result_any = vector_storage_instance.query(query_any)

        # Extract the file's last modified date from the metadata
        # in the query result
        if result_any[0].record.payload is not None:
            file_modified_date_from_meta = result_any[0].record.payload[
                "metadata"
            ]['last_modified']
        else:
            raise ValueError(
                "The vector storage exists but the payload is None,"
                " please check the collection"
            )

        return file_modified_date_from_meta

    def run_vector_retriever(
        self,
        query: str,
        content_input_paths: Union[str, List[str]],
        top_k: int = DEFAULT_TOP_K_RESULTS,
        similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
        return_detailed_info: bool = False,
    ) -> str:
        r"""Executes the automatic vector retriever process using vector storage.

        Args:
            query (str): Query string for information retriever.
            content_input_paths (Union[str, List[str]]): Paths to local
                files or remote URLs.
            top_k (int, optional): The number of top results to return during
                retrieval. Must be a positive integer. Defaults to
                `DEFAULT_TOP_K_RESULTS`.
            similarity_threshold (float, optional): The similarity threshold
                for filtering results. Defaults to
                `DEFAULT_SIMILARITY_THRESHOLD`.
            return_detailed_info (bool, optional): Whether to return detailed
                information including similarity score, content path and
                metadata. Defaults to False.

        Returns:
            str: By default, returns only the text information. If
                `return_detailed_info` is `True`, returns detailed
                information including similarity score, content path and
                metadata.

        Raises:
            ValueError: If a vector storage exists with the content name in
                the vector path but the payload is None, or if
                `content_input_paths` is empty.
            RuntimeError: If any errors occur during the retrieval process.
        """
        if not content_input_paths:
            raise ValueError("content_input_paths cannot be empty.")

        content_input_paths = (
            [content_input_paths]
            if isinstance(content_input_paths, str)
            else content_input_paths
        )

        vr = VectorRetriever()

        retrieved_infos = ""
        retrieved_infos_text = ""

        for content_input_path in content_input_paths:
            # Generate a valid collection name
            collection_name = self._collection_name_generator(
                content_input_path
            )
            try:
                vector_storage_instance = self._initialize_vector_storage(
                    collection_name
                )

                # Check the modified time of the input file path; this only
                # works for a local path since there is no standard way for
                # a remote URL
                file_is_modified = False  # initialize with a default value
                if (
                    vector_storage_instance.status().vector_count != 0
                    and not self.is_url
                ):
                    # Get original modified date from file
                    modified_date_from_file = (
                        self._get_file_modified_date_from_file(
                            content_input_path
                        )
                    )
                    # Get modified date from vector storage
                    modified_date_from_storage = (
                        self._get_file_modified_date_from_storage(
                            vector_storage_instance
                        )
                    )
                    # Determine if the file has been modified since the last
                    # check
                    file_is_modified = (
                        modified_date_from_file != modified_date_from_storage
                    )

                if (
                    vector_storage_instance.status().vector_count == 0
                    or file_is_modified
                ):
                    # Clear the vector storage
                    vector_storage_instance.clear()
                    # Process and store the content to the vector storage
                    vr.process(content_input_path, vector_storage_instance)
                # Retrieve info by given query from the vector storage
                retrieved_info = vr.query(
                    query, vector_storage_instance, top_k, similarity_threshold
                )
                # Reorganize the retrieved info with original query
                for info in retrieved_info:
                    retrieved_infos += "\n" + str(info)
                    retrieved_infos_text += "\n" + str(info['text'])
                output = (
                    "Original Query:"
                    + "\n"
                    + "{"
                    + query
                    + "}"
                    + "\n"
                    + "Retrieved Context:"
                    + retrieved_infos
                )
                output_text = (
                    "Original Query:"
                    + "\n"
                    + "{"
                    + query
                    + "}"
                    + "\n"
                    + "Retrieved Context:"
                    + retrieved_infos_text
                )

            except Exception as e:
                raise RuntimeError(
                    f"Error in auto vector retriever processing: {e!s}"
                ) from e

        if return_detailed_info:
            return output
        else:
            return output_text
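A hedged usage sketch for the new `AutoRetriever` (the local path and URL are placeholders, Qdrant is chosen to avoid the Milvus URL/API-key requirement, and `OPENAI_API_KEY` must be set for the default `OpenAIEmbedding`):

from camel.retrievers import AutoRetriever
from camel.types import StorageType

# Placeholder local path for an on-disk Qdrant collection.
auto_retriever = AutoRetriever(
    vector_storage_local_path="local_data/",
    storage_type=StorageType.QDRANT,
)

retrieved_info = auto_retriever.run_vector_retriever(
    query="What is CAMEL-AI?",
    content_input_paths=["https://www.camel-ai.org/"],  # placeholder URL
    top_k=1,
    return_detailed_info=False,
)
# Prints "Original Query:\n{...}\nRetrieved Context:..." as assembled above.
print(retrieved_info)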
camel/retrievers/base.py
ADDED

@@ -0,0 +1,64 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod
from typing import Any, Dict, List

DEFAULT_TOP_K_RESULTS = 1


class BaseRetriever(ABC):
    r"""Abstract base class for implementing various types of information
    retrievers.
    """

    @abstractmethod
    def __init__(self) -> None:
        pass

    @abstractmethod
    def process(
        self,
        content_input_path: str,
        chunk_type: str = "chunk_by_title",
        **kwargs: Any,
    ) -> None:
        r"""Processes content from a file or URL, divides it into chunks
        using `Unstructured IO`, then stores it internally. This method must
        be called before executing queries with the retriever.

        Args:
            content_input_path (str): File path or URL of the content to be
                processed.
            chunk_type (str): Type of chunking to apply. Defaults to
                "chunk_by_title".
            **kwargs (Any): Additional keyword arguments for content parsing.
        """
        pass

    @abstractmethod
    def query(
        self, query: str, top_k: int = DEFAULT_TOP_K_RESULTS, **kwargs: Any
    ) -> List[Dict[str, Any]]:
        r"""Queries the results. Subclasses should implement this method
        according to their specific needs.

        Args:
            query (str): Query string for information retriever.
            top_k (int, optional): The number of top results to return during
                retrieval. Must be a positive integer. Defaults to
                `DEFAULT_TOP_K_RESULTS`.
            **kwargs (Any): Flexible keyword arguments for additional
                parameters, like `similarity_threshold`.
        """
        pass
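To show what the abstract contract requires, here is a minimal hypothetical subclass (not part of the package) that scores lines of a local text file by keyword overlap:

from typing import Any, Dict, List

from camel.retrievers import BaseRetriever

DEFAULT_TOP_K_RESULTS = 1


class KeywordRetriever(BaseRetriever):
    r"""Hypothetical retriever ranking lines of a text file by how many
    query terms they contain."""

    def __init__(self) -> None:
        self.lines: List[str] = []
        self.content_input_path: str = ""

    def process(
        self,
        content_input_path: str,
        chunk_type: str = "chunk_by_title",  # kept for interface parity; unused
        **kwargs: Any,
    ) -> None:
        self.content_input_path = content_input_path
        with open(content_input_path, encoding="utf-8") as f:
            self.lines = [line.strip() for line in f if line.strip()]

    def query(
        self, query: str, top_k: int = DEFAULT_TOP_K_RESULTS, **kwargs: Any
    ) -> List[Dict[str, Any]]:
        terms = query.lower().split()
        scored = [
            {
                'similarity score': sum(term in line.lower() for term in terms),
                'content path': self.content_input_path,
                'text': line,
            }
            for line in self.lines
        ]
        scored.sort(key=lambda item: item['similarity score'], reverse=True)
        return scored[:top_k]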
camel/retrievers/bm25_retriever.py
ADDED

@@ -0,0 +1,149 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import Any, Dict, List

import numpy as np

from camel.loaders import UnstructuredIO
from camel.retrievers import BaseRetriever

DEFAULT_TOP_K_RESULTS = 1


class BM25Retriever(BaseRetriever):
    r"""An implementation of the `BaseRetriever` using the `BM25` model.

    This class facilitates the retrieval of relevant information using a
    query-based approach, ranking documents based on the occurrence and
    frequency of the query terms.

    Attributes:
        bm25 (BM25Okapi): An instance of the BM25Okapi class used for
            calculating document scores.
        content_input_path (str): The path to the content that has been
            processed and stored.
        chunks (List[Any]): A list of document chunks processed from the
            input content.

    References:
        https://github.com/dorianbrown/rank_bm25
    """

    def __init__(self) -> None:
        r"""Initializes the BM25Retriever."""

        try:
            from rank_bm25 import BM25Okapi
        except ImportError as e:
            raise ImportError(
                "Package `rank_bm25` not installed, install by running"
                " 'pip install rank_bm25'"
            ) from e

        self.bm25: BM25Okapi = None
        self.content_input_path: str = ""
        self.chunks: List[Any] = []

    def process(
        self,
        content_input_path: str,
        chunk_type: str = "chunk_by_title",
        **kwargs: Any,
    ) -> None:
        r"""Processes content from a file or URL, divides it into chunks
        using `Unstructured IO`, then stores it internally. This method must
        be called before executing queries with the retriever.

        Args:
            content_input_path (str): File path or URL of the content to be
                processed.
            chunk_type (str): Type of chunking to apply. Defaults to
                "chunk_by_title".
            **kwargs (Any): Additional keyword arguments for content parsing.
        """
        from rank_bm25 import BM25Okapi

        # Load and preprocess documents
        self.content_input_path = content_input_path
        unstructured_modules = UnstructuredIO()
        elements = unstructured_modules.parse_file_or_url(
            content_input_path, **kwargs
        )
        self.chunks = unstructured_modules.chunk_elements(
            chunk_type=chunk_type, elements=elements
        )

        # Convert chunks to a list of strings for tokenization
        tokenized_corpus = [str(chunk).split(" ") for chunk in self.chunks]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def query(  # type: ignore[override]
        self,
        query: str,
        top_k: int = DEFAULT_TOP_K_RESULTS,
    ) -> List[Dict[str, Any]]:
        r"""Executes a query and compiles the results.

        Args:
            query (str): Query string for information retriever.
            top_k (int, optional): The number of top results to return during
                retrieval. Must be a positive integer. Defaults to
                `DEFAULT_TOP_K_RESULTS`.

        Returns:
            List[Dict[str, Any]]: Concatenated list of the query results.

        Raises:
            ValueError: If `top_k` is less than or equal to 0, or if the BM25
                model has not been initialized by calling `process` first.

        Note:
            `storage` and `kwargs` parameters are included to maintain
            compatibility with the `BaseRetriever` interface but are not used
            in this implementation.
        """

        if top_k <= 0:
            raise ValueError("top_k must be a positive integer.")

        if self.bm25 is None:
            raise ValueError(
                "BM25 model is not initialized. Call `process`"
                " first."
            )

        # Preprocess query similarly to how documents were processed
        processed_query = query.split(" ")
        # Retrieve documents based on BM25 scores
        scores = self.bm25.get_scores(processed_query)

        top_k_indices = np.argpartition(scores, -top_k)[-top_k:]

        formatted_results = []
        for i in top_k_indices:
            result_dict = {
                'similarity score': scores[i],
                'content path': self.content_input_path,
                'metadata': self.chunks[i].metadata.to_dict(),
                'text': str(self.chunks[i]),
            }
            formatted_results.append(result_dict)

        # Sort the list of dictionaries by 'similarity score' from high to low
        formatted_results.sort(
            key=lambda x: x['similarity score'], reverse=True
        )

        return formatted_results
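And a short usage sketch for `BM25Retriever` (the document path is a placeholder; `rank_bm25` and the `unstructured` dependency behind `UnstructuredIO` are assumed to be installed):

from camel.retrievers import BM25Retriever

bm25_retriever = BM25Retriever()
# Placeholder path: any local document or URL UnstructuredIO can parse.
bm25_retriever.process(content_input_path="local_data/camel_paper.pdf")

results = bm25_retriever.query(query="What is CAMEL?", top_k=1)
for result in results:
    print(result['similarity score'], result['text'][:80])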