camel-ai 0.1.1__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of camel-ai might be problematic.
- camel/__init__.py +1 -11
- camel/agents/__init__.py +7 -5
- camel/agents/chat_agent.py +134 -86
- camel/agents/critic_agent.py +28 -17
- camel/agents/deductive_reasoner_agent.py +235 -0
- camel/agents/embodied_agent.py +92 -40
- camel/agents/knowledge_graph_agent.py +221 -0
- camel/agents/role_assignment_agent.py +27 -17
- camel/agents/task_agent.py +60 -34
- camel/agents/tool_agents/base.py +0 -1
- camel/agents/tool_agents/hugging_face_tool_agent.py +7 -4
- camel/configs/__init__.py +29 -0
- camel/configs/anthropic_config.py +73 -0
- camel/configs/base_config.py +22 -0
- camel/{configs.py → configs/openai_config.py} +37 -64
- camel/embeddings/__init__.py +2 -0
- camel/embeddings/base.py +3 -2
- camel/embeddings/openai_embedding.py +10 -5
- camel/embeddings/sentence_transformers_embeddings.py +65 -0
- camel/functions/__init__.py +18 -3
- camel/functions/google_maps_function.py +335 -0
- camel/functions/math_functions.py +7 -7
- camel/functions/open_api_function.py +380 -0
- camel/functions/open_api_specs/coursera/__init__.py +13 -0
- camel/functions/open_api_specs/coursera/openapi.yaml +82 -0
- camel/functions/open_api_specs/klarna/__init__.py +13 -0
- camel/functions/open_api_specs/klarna/openapi.yaml +87 -0
- camel/functions/open_api_specs/speak/__init__.py +13 -0
- camel/functions/open_api_specs/speak/openapi.yaml +151 -0
- camel/functions/openai_function.py +346 -42
- camel/functions/retrieval_functions.py +61 -0
- camel/functions/search_functions.py +100 -35
- camel/functions/slack_functions.py +275 -0
- camel/functions/twitter_function.py +484 -0
- camel/functions/weather_functions.py +36 -23
- camel/generators.py +65 -46
- camel/human.py +17 -11
- camel/interpreters/__init__.py +25 -0
- camel/interpreters/base.py +49 -0
- camel/{utils/python_interpreter.py → interpreters/internal_python_interpreter.py} +129 -48
- camel/interpreters/interpreter_error.py +19 -0
- camel/interpreters/subprocess_interpreter.py +190 -0
- camel/loaders/__init__.py +22 -0
- camel/{functions/base_io_functions.py → loaders/base_io.py} +38 -35
- camel/{functions/unstructured_io_fuctions.py → loaders/unstructured_io.py} +199 -110
- camel/memories/__init__.py +17 -7
- camel/memories/agent_memories.py +156 -0
- camel/memories/base.py +97 -32
- camel/memories/blocks/__init__.py +21 -0
- camel/memories/{chat_history_memory.py → blocks/chat_history_block.py} +34 -34
- camel/memories/blocks/vectordb_block.py +101 -0
- camel/memories/context_creators/__init__.py +3 -2
- camel/memories/context_creators/score_based.py +32 -20
- camel/memories/records.py +6 -5
- camel/messages/__init__.py +2 -2
- camel/messages/base.py +99 -16
- camel/messages/func_message.py +7 -4
- camel/models/__init__.py +6 -2
- camel/models/anthropic_model.py +146 -0
- camel/models/base_model.py +10 -3
- camel/models/model_factory.py +17 -11
- camel/models/open_source_model.py +25 -13
- camel/models/openai_audio_models.py +251 -0
- camel/models/openai_model.py +20 -13
- camel/models/stub_model.py +10 -5
- camel/prompts/__init__.py +7 -5
- camel/prompts/ai_society.py +21 -14
- camel/prompts/base.py +54 -47
- camel/prompts/code.py +22 -14
- camel/prompts/evaluation.py +8 -5
- camel/prompts/misalignment.py +26 -19
- camel/prompts/object_recognition.py +35 -0
- camel/prompts/prompt_templates.py +14 -8
- camel/prompts/role_description_prompt_template.py +16 -10
- camel/prompts/solution_extraction.py +9 -5
- camel/prompts/task_prompt_template.py +24 -21
- camel/prompts/translation.py +9 -5
- camel/responses/agent_responses.py +5 -2
- camel/retrievers/__init__.py +26 -0
- camel/retrievers/auto_retriever.py +330 -0
- camel/retrievers/base.py +69 -0
- camel/retrievers/bm25_retriever.py +140 -0
- camel/retrievers/cohere_rerank_retriever.py +108 -0
- camel/retrievers/vector_retriever.py +183 -0
- camel/societies/__init__.py +1 -1
- camel/societies/babyagi_playing.py +56 -32
- camel/societies/role_playing.py +188 -133
- camel/storages/__init__.py +18 -0
- camel/storages/graph_storages/__init__.py +23 -0
- camel/storages/graph_storages/base.py +82 -0
- camel/storages/graph_storages/graph_element.py +74 -0
- camel/storages/graph_storages/neo4j_graph.py +582 -0
- camel/storages/key_value_storages/base.py +1 -2
- camel/storages/key_value_storages/in_memory.py +1 -2
- camel/storages/key_value_storages/json.py +8 -13
- camel/storages/vectordb_storages/__init__.py +33 -0
- camel/storages/vectordb_storages/base.py +202 -0
- camel/storages/vectordb_storages/milvus.py +396 -0
- camel/storages/vectordb_storages/qdrant.py +373 -0
- camel/terminators/__init__.py +1 -1
- camel/terminators/base.py +2 -3
- camel/terminators/response_terminator.py +21 -12
- camel/terminators/token_limit_terminator.py +5 -3
- camel/toolkits/__init__.py +21 -0
- camel/toolkits/base.py +22 -0
- camel/toolkits/github_toolkit.py +245 -0
- camel/types/__init__.py +18 -6
- camel/types/enums.py +129 -15
- camel/types/openai_types.py +10 -5
- camel/utils/__init__.py +20 -13
- camel/utils/commons.py +170 -85
- camel/utils/token_counting.py +135 -15
- {camel_ai-0.1.1.dist-info → camel_ai-0.1.4.dist-info}/METADATA +123 -75
- camel_ai-0.1.4.dist-info/RECORD +119 -0
- {camel_ai-0.1.1.dist-info → camel_ai-0.1.4.dist-info}/WHEEL +1 -1
- camel/memories/context_creators/base.py +0 -72
- camel_ai-0.1.1.dist-info/RECORD +0 -75
camel/retrievers/auto_retriever.py
ADDED
@@ -0,0 +1,330 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import datetime
import os
import re
from pathlib import Path
from typing import List, Optional, Tuple, Union
from urllib.parse import urlparse

from camel.embeddings import BaseEmbedding, OpenAIEmbedding
from camel.retrievers.vector_retriever import VectorRetriever
from camel.storages import (
    BaseVectorStorage,
    MilvusStorage,
    QdrantStorage,
    VectorDBQuery,
)
from camel.types import StorageType

DEFAULT_TOP_K_RESULTS = 1
DEFAULT_SIMILARITY_THRESHOLD = 0.75


class AutoRetriever:
    r"""Facilitates the automatic retrieval of information using a
    query-based approach with pre-defined elements.

    Attributes:
        url_and_api_key (Optional[Tuple[str, str]]): URL and API key for
            accessing the vector storage remotely.
        vector_storage_local_path (Optional[str]): Local path for vector
            storage, if applicable.
        storage_type (Optional[StorageType]): The type of vector storage to
            use. Defaults to `StorageType.MILVUS`.
        embedding_model (Optional[BaseEmbedding]): Model used for embedding
            queries and documents. Defaults to `OpenAIEmbedding()`.
    """

    def __init__(
        self,
        url_and_api_key: Optional[Tuple[str, str]] = None,
        vector_storage_local_path: Optional[str] = None,
        storage_type: Optional[StorageType] = None,
        embedding_model: Optional[BaseEmbedding] = None,
    ):
        self.storage_type = storage_type or StorageType.MILVUS
        self.embedding_model = embedding_model or OpenAIEmbedding()
        self.vector_storage_local_path = vector_storage_local_path
        self.url_and_api_key = url_and_api_key

    def _initialize_vector_storage(
        self,
        collection_name: Optional[str] = None,
    ) -> BaseVectorStorage:
        r"""Sets up and returns a vector storage instance with the specified
        parameters.

        Args:
            collection_name (Optional[str]): Name of the collection in the
                vector storage.

        Returns:
            BaseVectorStorage: Configured vector storage instance.
        """
        if self.storage_type == StorageType.MILVUS:
            if self.url_and_api_key is None:
                raise ValueError(
                    "URL and API key required for Milvus storage are not"
                    " provided."
                )
            return MilvusStorage(
                vector_dim=self.embedding_model.get_output_dim(),
                collection_name=collection_name,
                url_and_api_key=self.url_and_api_key,
            )

        if self.storage_type == StorageType.QDRANT:
            return QdrantStorage(
                vector_dim=self.embedding_model.get_output_dim(),
                collection_name=collection_name,
                path=self.vector_storage_local_path,
                url_and_api_key=self.url_and_api_key,
            )

        raise ValueError(
            f"Unsupported vector storage type: {self.storage_type}"
        )

    def _collection_name_generator(self, content_input_path: str) -> str:
        r"""Generates a valid collection name from a given file path or URL.

        Args:
            content_input_path (str): The input URL or file path from which
                to generate the collection name.

        Returns:
            str: A sanitized, valid collection name suitable for use.
        """
        # Check path type
        parsed_url = urlparse(content_input_path)
        self.is_url = all([parsed_url.scheme, parsed_url.netloc])

        # Convert the given path into a collection name, ensuring it only
        # contains numbers, letters, and underscores
        if self.is_url:
            # For URLs, remove "https://" and replace "/" and any characters
            # not allowed by Milvus with "_"
            collection_name = re.sub(
                r'[^0-9a-zA-Z]+',
                '_',
                content_input_path.replace("https://", ""),
            )
        else:
            # For file paths, take the stem and replace spaces with "_",
            # also ensuring only allowed characters are present
            collection_name = re.sub(
                r'[^0-9a-zA-Z]+', '_', Path(content_input_path).stem
            )

        # Ensure the collection name does not start or end with an underscore
        collection_name = collection_name.strip("_")
        # Limit the maximum length of the collection name to 30 characters
        collection_name = collection_name[:30]
        return collection_name

    def _get_file_modified_date_from_file(
        self, content_input_path: str
    ) -> str:
        r"""Retrieves the last modified date and time of a given file. This
        function takes a file path as input and returns the last modified
        date and time of that file.

        Args:
            content_input_path (str): The file path of the content whose
                modified date is to be retrieved.

        Returns:
            str: The last modified time from the file.
        """
        mod_time = os.path.getmtime(content_input_path)
        readable_mod_time = datetime.datetime.fromtimestamp(
            mod_time
        ).isoformat(timespec='seconds')
        return readable_mod_time

    def _get_file_modified_date_from_storage(
        self, vector_storage_instance: BaseVectorStorage
    ) -> str:
        r"""Retrieves the last modified date and time of a given file. This
        function takes a vector storage instance as input and returns the
        last modified date from its metadata.

        Args:
            vector_storage_instance (BaseVectorStorage): The vector storage
                from whose metadata the modified date is to be retrieved.

        Returns:
            str: The last modified date from vector storage.
        """

        # Insert any query to get the modified date from the vector db
        # NOTE: Can be optimized once CAMEL vector storage supports direct
        # chunk payload extraction
        query_vector_any = self.embedding_model.embed(obj="any_query")
        query_any = VectorDBQuery(query_vector_any, top_k=1)
        result_any = vector_storage_instance.query(query_any)

        # Extract the file's last modified date from the metadata
        # in the query result
        if result_any[0].record.payload is not None:
            file_modified_date_from_meta = result_any[0].record.payload[
                "metadata"
            ]['last_modified']
        else:
            raise ValueError(
                "The vector storage exists but the payload is None,"
                " please check the collection."
            )

        return file_modified_date_from_meta

    def run_vector_retriever(
        self,
        query: str,
        content_input_paths: Union[str, List[str]],
        top_k: int = DEFAULT_TOP_K_RESULTS,
        similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
        return_detailed_info: bool = False,
    ) -> str:
        r"""Executes the automatic vector retriever process using vector
        storage.

        Args:
            query (str): Query string for information retrieval.
            content_input_paths (Union[str, List[str]]): Paths to local
                files or remote URLs.
            top_k (int, optional): The number of top results to return
                during retrieval. Must be a positive integer. Defaults to
                `DEFAULT_TOP_K_RESULTS`.
            similarity_threshold (float, optional): The similarity threshold
                for filtering results. Defaults to
                `DEFAULT_SIMILARITY_THRESHOLD`.
            return_detailed_info (bool, optional): Whether to return detailed
                information including similarity score, content path and
                metadata. Defaults to False.

        Returns:
            str: By default, returns only the text information. If
                `return_detailed_info` is `True`, returns detailed
                information including similarity score, content path and
                metadata.

        Raises:
            ValueError: If a vector storage exists for the content name in
                the vector path but its payload is None, or if
                `content_input_paths` is empty.
            RuntimeError: If any errors occur during the retrieval process.
        """
        if not content_input_paths:
            raise ValueError("content_input_paths cannot be empty.")

        content_input_paths = (
            [content_input_paths]
            if isinstance(content_input_paths, str)
            else content_input_paths
        )

        vr = VectorRetriever()

        all_retrieved_info = []
        for content_input_path in content_input_paths:
            # Generate a valid collection name
            collection_name = self._collection_name_generator(
                content_input_path
            )
            try:
                vector_storage_instance = self._initialize_vector_storage(
                    collection_name
                )

                # Check the modified time of the input file path; this only
                # works for local paths since there is no standard way to do
                # it for remote URLs
                file_is_modified = False  # initialize with a default value
                if (
                    vector_storage_instance.status().vector_count != 0
                    and not self.is_url
                ):
                    # Get the original modified date from the file
                    modified_date_from_file = (
                        self._get_file_modified_date_from_file(
                            content_input_path
                        )
                    )
                    # Get the modified date from the vector storage
                    modified_date_from_storage = (
                        self._get_file_modified_date_from_storage(
                            vector_storage_instance
                        )
                    )
                    # Determine if the file has been modified since the last
                    # check
                    file_is_modified = (
                        modified_date_from_file != modified_date_from_storage
                    )

                if (
                    vector_storage_instance.status().vector_count == 0
                    or file_is_modified
                ):
                    # Clear the vector storage
                    vector_storage_instance.clear()
                    # Process and store the content to the vector storage
                    vr = VectorRetriever(
                        storage=vector_storage_instance,
                        similarity_threshold=similarity_threshold,
                    )
                    vr.process(content_input_path)
                else:
                    vr = VectorRetriever(
                        storage=vector_storage_instance,
                        similarity_threshold=similarity_threshold,
                    )
                # Retrieve info for the given query from the vector storage
                retrieved_info = vr.query(query, top_k)
                all_retrieved_info.extend(retrieved_info)
            except Exception as e:
                raise RuntimeError(
                    f"Error in auto vector retriever processing: {e!s}"
                ) from e

        # Split records into those with and without a 'similarity score'.
        # Records with a 'similarity score' lower than 'similarity_threshold'
        # will not have a 'similarity score' in the output content.
        with_score = [
            info for info in all_retrieved_info if 'similarity score' in info
        ]
        without_score = [
            info
            for info in all_retrieved_info
            if 'similarity score' not in info
        ]
        # Sort only the list with scores
        with_score_sorted = sorted(
            with_score, key=lambda x: x['similarity score'], reverse=True
        )
        # Merge the sorted scored items back with the non-scored items
        all_retrieved_info_sorted = with_score_sorted + without_score
        # Select the 'top_k' results
        all_retrieved_info = all_retrieved_info_sorted[:top_k]

        retrieved_infos = "\n".join(str(info) for info in all_retrieved_info)
        retrieved_infos_text = "\n".join(
            info['text'] for info in all_retrieved_info if 'text' in info
        )

        detailed_info = (
            f"Original Query:\n{{ {query} }}\n"
            f"Retrieved Context:\n{retrieved_infos}"
        )
        text_info = (
            f"Original Query:\n{{ {query} }}\n"
            f"Retrieved Context:\n{retrieved_infos_text}"
        )

        if return_detailed_info:
            return detailed_info
        else:
            return text_info
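Based on the `__init__` and `run_vector_retriever` signatures added above, a minimal usage sketch might look like the following. The document path is a placeholder, `camel.retrievers` is assumed to re-export `AutoRetriever` via the new `__init__.py`, and the default `OpenAIEmbedding` assumes an OpenAI API key is configured.

from camel.retrievers import AutoRetriever
from camel.types import StorageType

# Qdrant with a local path avoids the remote URL/API-key pair that
# _initialize_vector_storage requires for Milvus.
auto_retriever = AutoRetriever(
    vector_storage_local_path="local_data/vector_storage",  # placeholder
    storage_type=StorageType.QDRANT,
)

# Returns a formatted string; pass return_detailed_info=True to include
# similarity scores, content paths and metadata in the output.
retrieved_info = auto_retriever.run_vector_retriever(
    query="What is CAMEL-AI?",
    content_input_paths="local_data/camel_paper.pdf",  # placeholder path
    top_k=1,
)
print(retrieved_info)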
camel/retrievers/base.py
ADDED
@@ -0,0 +1,69 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod
from typing import Any, Callable

DEFAULT_TOP_K_RESULTS = 1


def _query_unimplemented(self, *input: Any) -> None:
    r"""Defines the query behavior performed at every call.

    Queries the results. Subclasses should implement this method according
    to their specific needs.

    It should be overridden by all subclasses.

    .. note::
        Although the recipe for the forward pass needs to be defined within
        this function, one should call the :class:`BaseRetriever` instance
        afterwards instead of this, since the former takes care of running
        the registered hooks while the latter silently ignores them.
    """
    raise NotImplementedError(
        f"Retriever [{type(self).__name__}] is missing the required \"query\" function"
    )


def _process_unimplemented(self, *input: Any) -> None:
    r"""Defines the process behavior performed at every call.

    Processes content from a file or URL, divides it into chunks using
    `Unstructured IO`, then stores it internally. This method must be
    called before executing queries with the retriever.

    Should be overridden by all subclasses.

    .. note::
        Although the recipe for the forward pass needs to be defined within
        this function, one should call the :class:`BaseRetriever` instance
        afterwards instead of this, since the former takes care of running
        the registered hooks while the latter silently ignores them.
    """
    raise NotImplementedError(
        f"Retriever [{type(self).__name__}] is missing the required \"process\" function"
    )


class BaseRetriever(ABC):
    r"""Abstract base class for implementing various types of information
    retrievers.
    """

    @abstractmethod
    def __init__(self) -> None:
        pass

    process: Callable[..., Any] = _process_unimplemented
    query: Callable[..., Any] = _query_unimplemented
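For illustration only, a hypothetical subclass satisfying this contract (a concrete `__init__`, plus `process` and `query` overriding the unimplemented defaults above) could look like the sketch below; the class name, chunking rule, and result format are made up.

from typing import Any, Dict, List

from camel.retrievers import BaseRetriever


class KeywordRetriever(BaseRetriever):
    r"""Hypothetical retriever returning chunks that contain the query."""

    def __init__(self) -> None:
        self.chunks: List[str] = []

    def process(self, content: str) -> None:
        # Naive "chunking": one chunk per non-empty line of the input text.
        self.chunks = [ln for ln in content.splitlines() if ln.strip()]

    def query(self, query: str, top_k: int = 1) -> List[Dict[str, Any]]:
        # Case-insensitive substring match stands in for real scoring.
        hits = [c for c in self.chunks if query.lower() in c.lower()]
        return [{'text': c} for c in hits[:top_k]]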
camel/retrievers/bm25_retriever.py
ADDED
@@ -0,0 +1,140 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import Any, Dict, List

import numpy as np

from camel.loaders import UnstructuredIO
from camel.retrievers import BaseRetriever

DEFAULT_TOP_K_RESULTS = 1


class BM25Retriever(BaseRetriever):
    r"""An implementation of the `BaseRetriever` using the `BM25` model.

    This class facilitates the retrieval of relevant information using a
    query-based approach, ranking documents based on the occurrence and
    frequency of the query terms.

    Attributes:
        bm25 (BM25Okapi): An instance of the BM25Okapi class used for
            calculating document scores.
        content_input_path (str): The path to the content that has been
            processed and stored.
        unstructured_modules (UnstructuredIO): A module for parsing files
            and URLs and chunking content based on specified parameters.

    References:
        https://github.com/dorianbrown/rank_bm25
    """

    def __init__(self) -> None:
        r"""Initializes the BM25Retriever."""

        try:
            from rank_bm25 import BM25Okapi
        except ImportError as e:
            raise ImportError(
                "Package `rank_bm25` not installed, install by running"
                " 'pip install rank_bm25'"
            ) from e

        self.bm25: BM25Okapi = None
        self.content_input_path: str = ""
        self.unstructured_modules: UnstructuredIO = UnstructuredIO()

    def process(
        self,
        content_input_path: str,
        chunk_type: str = "chunk_by_title",
        **kwargs: Any,
    ) -> None:
        r"""Processes content from a file or URL, divides it into chunks
        using `Unstructured IO`, then stores it internally. This method must
        be called before executing queries with the retriever.

        Args:
            content_input_path (str): File path or URL of the content to be
                processed.
            chunk_type (str): Type of chunking to apply. Defaults to
                "chunk_by_title".
            **kwargs (Any): Additional keyword arguments for content parsing.
        """
        from rank_bm25 import BM25Okapi

        # Load and preprocess documents
        self.content_input_path = content_input_path
        elements = self.unstructured_modules.parse_file_or_url(
            content_input_path, **kwargs
        )
        self.chunks = self.unstructured_modules.chunk_elements(
            chunk_type=chunk_type, elements=elements
        )

        # Convert chunks to a list of strings for tokenization
        tokenized_corpus = [str(chunk).split(" ") for chunk in self.chunks]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def query(
        self,
        query: str,
        top_k: int = DEFAULT_TOP_K_RESULTS,
    ) -> List[Dict[str, Any]]:
        r"""Executes a query and compiles the results.

        Args:
            query (str): Query string for information retrieval.
            top_k (int, optional): The number of top results to return.
                Must be a positive integer. Defaults to
                `DEFAULT_TOP_K_RESULTS`.

        Returns:
            List[Dict[str, Any]]: Concatenated list of the query results.

        Raises:
            ValueError: If `top_k` is less than or equal to 0, or if the
                BM25 model has not been initialized by calling `process`
                first.
        """

        if top_k <= 0:
            raise ValueError("top_k must be a positive integer.")
        if self.bm25 is None or not self.chunks:
            raise ValueError(
                "BM25 model is not initialized. Call `process` first."
            )

        # Preprocess the query the same way the documents were processed
        processed_query = query.split(" ")
        # Retrieve documents based on BM25 scores
        scores = self.bm25.get_scores(processed_query)

        top_k_indices = np.argpartition(scores, -top_k)[-top_k:]

        formatted_results = []
        for i in top_k_indices:
            result_dict = {
                'similarity score': scores[i],
                'content path': self.content_input_path,
                'metadata': self.chunks[i].metadata.to_dict(),
                'text': str(self.chunks[i]),
            }
            formatted_results.append(result_dict)

        # Sort the list of dictionaries by 'similarity score', high to low
        formatted_results.sort(
            key=lambda x: x['similarity score'], reverse=True
        )

        return formatted_results
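Assuming `rank_bm25` is installed and `camel.retrievers` re-exports `BM25Retriever`, a sketch of the process-then-query flow might be the following; the input path is a placeholder.

from camel.retrievers import BM25Retriever

bm25_retriever = BM25Retriever()

# `process` must run before `query`; any file path or URL that
# Unstructured IO can parse should work here.
bm25_retriever.process(content_input_path="local_data/camel_paper.pdf")

results = bm25_retriever.query(query="What is CAMEL-AI?", top_k=1)
for result in results:
    print(result['similarity score'], result['text'])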
camel/retrievers/cohere_rerank_retriever.py
ADDED
@@ -0,0 +1,108 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import os
from typing import Any, Dict, List, Optional

from camel.retrievers import BaseRetriever

DEFAULT_TOP_K_RESULTS = 1


class CohereRerankRetriever(BaseRetriever):
    r"""An implementation of the `BaseRetriever` using the `Cohere
    Re-ranking` model.

    Attributes:
        model_name (str): The model name to use for re-ranking.
        api_key (Optional[str]): The API key for authenticating with the
            Cohere service.

    References:
        https://txt.cohere.com/rerank/
    """

    def __init__(
        self,
        model_name: str = "rerank-multilingual-v2.0",
        api_key: Optional[str] = None,
    ) -> None:
        r"""Initializes an instance of the CohereRerankRetriever. This
        constructor sets up a client for interacting with the Cohere API
        using the specified model name and API key. If the API key is not
        provided, it attempts to retrieve it from the COHERE_API_KEY
        environment variable.

        Args:
            model_name (str): The name of the model to be used for
                re-ranking. Defaults to 'rerank-multilingual-v2.0'.
            api_key (Optional[str]): The API key for authenticating requests
                to the Cohere API. If not provided, the method will attempt
                to retrieve the key from the environment variable
                'COHERE_API_KEY'.

        Raises:
            ImportError: If the 'cohere' package is not installed.
            ValueError: If the API key is neither passed as an argument nor
                set in the environment variable.
        """

        try:
            import cohere
        except ImportError as e:
            raise ImportError("Package 'cohere' is not installed.") from e

        try:
            self.api_key = api_key or os.environ["COHERE_API_KEY"]
        except KeyError as e:
            raise ValueError(
                "Must pass in a Cohere API key or specify it via the"
                " COHERE_API_KEY environment variable."
            ) from e

        self.co = cohere.Client(self.api_key)
        self.model_name = model_name

    def query(
        self,
        query: str,
        retrieved_result: List[Dict[str, Any]],
        top_k: int = DEFAULT_TOP_K_RESULTS,
    ) -> List[Dict[str, Any]]:
        r"""Queries and compiles results using the Cohere re-ranking model.

        Args:
            query (str): Query string for information retrieval.
            retrieved_result (List[Dict[str, Any]]): The content to be
                re-ranked; this should be the output from a `BaseRetriever`
                such as `VectorRetriever`.
            top_k (int, optional): The number of top results to return.
                Must be a positive integer. Defaults to
                `DEFAULT_TOP_K_RESULTS`.

        Returns:
            List[Dict[str, Any]]: Concatenated list of the query results.
        """
        rerank_results = self.co.rerank(
            query=query,
            documents=retrieved_result,
            top_n=top_k,
            model=self.model_name,
        )
        formatted_results = []
        for result in rerank_results.results:
            selected_chunk = retrieved_result[result.index]
            selected_chunk['similarity score'] = result.relevance_score
            formatted_results.append(selected_chunk)
        return formatted_results
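As a sketch of second-stage re-ranking, assuming the `cohere` package is installed, `COHERE_API_KEY` is set, and hand-written candidate dicts stand in for real `VectorRetriever` or `BM25Retriever` output:

from camel.retrievers import CohereRerankRetriever

# Candidates in the dict format produced by the retrievers above; both
# entries and their paths are made up purely for illustration.
candidates = [
    {'text': "CAMEL is a multi-agent communication framework.",
     'content path': "local_data/camel_paper.pdf"},
    {'text': "BM25 ranks documents by query-term frequency.",
     'content path': "local_data/ir_notes.md"},
]

reranker = CohereRerankRetriever()  # reads COHERE_API_KEY if no api_key
reranked = reranker.query(
    query="What is CAMEL-AI?",
    retrieved_result=candidates,
    top_k=1,
)
print(reranked)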