camel-ai 0.1.1__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic. Click here for more details.

Files changed (117)
  1. camel/__init__.py +1 -11
  2. camel/agents/__init__.py +7 -5
  3. camel/agents/chat_agent.py +134 -86
  4. camel/agents/critic_agent.py +28 -17
  5. camel/agents/deductive_reasoner_agent.py +235 -0
  6. camel/agents/embodied_agent.py +92 -40
  7. camel/agents/knowledge_graph_agent.py +221 -0
  8. camel/agents/role_assignment_agent.py +27 -17
  9. camel/agents/task_agent.py +60 -34
  10. camel/agents/tool_agents/base.py +0 -1
  11. camel/agents/tool_agents/hugging_face_tool_agent.py +7 -4
  12. camel/configs/__init__.py +29 -0
  13. camel/configs/anthropic_config.py +73 -0
  14. camel/configs/base_config.py +22 -0
  15. camel/{configs.py → configs/openai_config.py} +37 -64
  16. camel/embeddings/__init__.py +2 -0
  17. camel/embeddings/base.py +3 -2
  18. camel/embeddings/openai_embedding.py +10 -5
  19. camel/embeddings/sentence_transformers_embeddings.py +65 -0
  20. camel/functions/__init__.py +18 -3
  21. camel/functions/google_maps_function.py +335 -0
  22. camel/functions/math_functions.py +7 -7
  23. camel/functions/open_api_function.py +380 -0
  24. camel/functions/open_api_specs/coursera/__init__.py +13 -0
  25. camel/functions/open_api_specs/coursera/openapi.yaml +82 -0
  26. camel/functions/open_api_specs/klarna/__init__.py +13 -0
  27. camel/functions/open_api_specs/klarna/openapi.yaml +87 -0
  28. camel/functions/open_api_specs/speak/__init__.py +13 -0
  29. camel/functions/open_api_specs/speak/openapi.yaml +151 -0
  30. camel/functions/openai_function.py +346 -42
  31. camel/functions/retrieval_functions.py +61 -0
  32. camel/functions/search_functions.py +100 -35
  33. camel/functions/slack_functions.py +275 -0
  34. camel/functions/twitter_function.py +484 -0
  35. camel/functions/weather_functions.py +36 -23
  36. camel/generators.py +65 -46
  37. camel/human.py +17 -11
  38. camel/interpreters/__init__.py +25 -0
  39. camel/interpreters/base.py +49 -0
  40. camel/{utils/python_interpreter.py → interpreters/internal_python_interpreter.py} +129 -48
  41. camel/interpreters/interpreter_error.py +19 -0
  42. camel/interpreters/subprocess_interpreter.py +190 -0
  43. camel/loaders/__init__.py +22 -0
  44. camel/{functions/base_io_functions.py → loaders/base_io.py} +38 -35
  45. camel/{functions/unstructured_io_fuctions.py → loaders/unstructured_io.py} +199 -110
  46. camel/memories/__init__.py +17 -7
  47. camel/memories/agent_memories.py +156 -0
  48. camel/memories/base.py +97 -32
  49. camel/memories/blocks/__init__.py +21 -0
  50. camel/memories/{chat_history_memory.py → blocks/chat_history_block.py} +34 -34
  51. camel/memories/blocks/vectordb_block.py +101 -0
  52. camel/memories/context_creators/__init__.py +3 -2
  53. camel/memories/context_creators/score_based.py +32 -20
  54. camel/memories/records.py +6 -5
  55. camel/messages/__init__.py +2 -2
  56. camel/messages/base.py +99 -16
  57. camel/messages/func_message.py +7 -4
  58. camel/models/__init__.py +6 -2
  59. camel/models/anthropic_model.py +146 -0
  60. camel/models/base_model.py +10 -3
  61. camel/models/model_factory.py +17 -11
  62. camel/models/open_source_model.py +25 -13
  63. camel/models/openai_audio_models.py +251 -0
  64. camel/models/openai_model.py +20 -13
  65. camel/models/stub_model.py +10 -5
  66. camel/prompts/__init__.py +7 -5
  67. camel/prompts/ai_society.py +21 -14
  68. camel/prompts/base.py +54 -47
  69. camel/prompts/code.py +22 -14
  70. camel/prompts/evaluation.py +8 -5
  71. camel/prompts/misalignment.py +26 -19
  72. camel/prompts/object_recognition.py +35 -0
  73. camel/prompts/prompt_templates.py +14 -8
  74. camel/prompts/role_description_prompt_template.py +16 -10
  75. camel/prompts/solution_extraction.py +9 -5
  76. camel/prompts/task_prompt_template.py +24 -21
  77. camel/prompts/translation.py +9 -5
  78. camel/responses/agent_responses.py +5 -2
  79. camel/retrievers/__init__.py +26 -0
  80. camel/retrievers/auto_retriever.py +330 -0
  81. camel/retrievers/base.py +69 -0
  82. camel/retrievers/bm25_retriever.py +140 -0
  83. camel/retrievers/cohere_rerank_retriever.py +108 -0
  84. camel/retrievers/vector_retriever.py +183 -0
  85. camel/societies/__init__.py +1 -1
  86. camel/societies/babyagi_playing.py +56 -32
  87. camel/societies/role_playing.py +188 -133
  88. camel/storages/__init__.py +18 -0
  89. camel/storages/graph_storages/__init__.py +23 -0
  90. camel/storages/graph_storages/base.py +82 -0
  91. camel/storages/graph_storages/graph_element.py +74 -0
  92. camel/storages/graph_storages/neo4j_graph.py +582 -0
  93. camel/storages/key_value_storages/base.py +1 -2
  94. camel/storages/key_value_storages/in_memory.py +1 -2
  95. camel/storages/key_value_storages/json.py +8 -13
  96. camel/storages/vectordb_storages/__init__.py +33 -0
  97. camel/storages/vectordb_storages/base.py +202 -0
  98. camel/storages/vectordb_storages/milvus.py +396 -0
  99. camel/storages/vectordb_storages/qdrant.py +373 -0
  100. camel/terminators/__init__.py +1 -1
  101. camel/terminators/base.py +2 -3
  102. camel/terminators/response_terminator.py +21 -12
  103. camel/terminators/token_limit_terminator.py +5 -3
  104. camel/toolkits/__init__.py +21 -0
  105. camel/toolkits/base.py +22 -0
  106. camel/toolkits/github_toolkit.py +245 -0
  107. camel/types/__init__.py +18 -6
  108. camel/types/enums.py +129 -15
  109. camel/types/openai_types.py +10 -5
  110. camel/utils/__init__.py +20 -13
  111. camel/utils/commons.py +170 -85
  112. camel/utils/token_counting.py +135 -15
  113. {camel_ai-0.1.1.dist-info → camel_ai-0.1.4.dist-info}/METADATA +123 -75
  114. camel_ai-0.1.4.dist-info/RECORD +119 -0
  115. {camel_ai-0.1.1.dist-info → camel_ai-0.1.4.dist-info}/WHEEL +1 -1
  116. camel/memories/context_creators/base.py +0 -72
  117. camel_ai-0.1.1.dist-info/RECORD +0 -75
@@ -0,0 +1,330 @@
1
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2
+ # Licensed under the Apache License, Version 2.0 (the “License”);
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an “AS IS” BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+ import datetime
15
+ import os
16
+ import re
17
+ from pathlib import Path
18
+ from typing import List, Optional, Tuple, Union
19
+ from urllib.parse import urlparse
20
+
21
+ from camel.embeddings import BaseEmbedding, OpenAIEmbedding
22
+ from camel.retrievers.vector_retriever import VectorRetriever
23
+ from camel.storages import (
24
+ BaseVectorStorage,
25
+ MilvusStorage,
26
+ QdrantStorage,
27
+ VectorDBQuery,
28
+ )
29
+ from camel.types import StorageType
30
+
31
# Default number of results `AutoRetriever.run_vector_retriever` returns.
DEFAULT_TOP_K_RESULTS = 1
# Default similarity threshold that `run_vector_retriever` forwards to each
# `VectorRetriever` it builds.
DEFAULT_SIMILARITY_THRESHOLD = 0.75
33
+
34
+
35
class AutoRetriever:
    r"""Facilitates the automatic retrieval of information using a
    query-based approach with pre-defined elements.

    For every input path, the retriever:

    1. derives a collection name from the path;
    2. opens the matching vector storage collection;
    3. (re)indexes the content when the collection is empty, or when a local
       file's modification time differs from the one stored with the indexed
       chunks;
    4. queries the collection.

    Attributes:
        url_and_api_key (Optional[Tuple[str, str]]): URL and API key for
            accessing the vector storage remotely.
        vector_storage_local_path (Optional[str]): Local path for vector
            storage. Only used when `storage_type` is `StorageType.QDRANT`.
        storage_type (Optional[StorageType]): The type of vector storage to
            use. Defaults to `StorageType.MILVUS`.
        embedding_model (Optional[BaseEmbedding]): Model used for embedding
            queries and documents. Defaults to `OpenAIEmbedding()`.
        is_url (bool): Whether the path most recently passed to
            `_collection_name_generator` was recognised as a URL.
    """

    def __init__(
        self,
        url_and_api_key: Optional[Tuple[str, str]] = None,
        vector_storage_local_path: Optional[str] = None,
        storage_type: Optional[StorageType] = None,
        embedding_model: Optional[BaseEmbedding] = None,
    ):
        self.storage_type = storage_type or StorageType.MILVUS
        self.embedding_model = embedding_model or OpenAIEmbedding()
        self.vector_storage_local_path = vector_storage_local_path
        self.url_and_api_key = url_and_api_key
        # Updated by `_collection_name_generator`. Initialised here so the
        # attribute always exists.
        self.is_url = False

    def _initialize_vector_storage(
        self,
        collection_name: Optional[str] = None,
    ) -> BaseVectorStorage:
        r"""Sets up and returns a vector storage instance with specified
        parameters.

        Args:
            collection_name (Optional[str]): Name of the collection in the
                vector storage.

        Returns:
            BaseVectorStorage: Configured vector storage instance, sized to
                the output dimension of `self.embedding_model`.

        Raises:
            ValueError: If Milvus is selected but `url_and_api_key` is not
                set, or if `storage_type` is not supported.
        """
        if self.storage_type == StorageType.MILVUS:
            if self.url_and_api_key is None:
                raise ValueError(
                    "URL and API key required for Milvus storage are not "
                    "provided."
                )
            return MilvusStorage(
                vector_dim=self.embedding_model.get_output_dim(),
                collection_name=collection_name,
                url_and_api_key=self.url_and_api_key,
            )

        if self.storage_type == StorageType.QDRANT:
            return QdrantStorage(
                vector_dim=self.embedding_model.get_output_dim(),
                collection_name=collection_name,
                path=self.vector_storage_local_path,
                url_and_api_key=self.url_and_api_key,
            )

        raise ValueError(
            f"Unsupported vector storage type: {self.storage_type}"
        )

    def _collection_name_generator(self, content_input_path: str) -> str:
        r"""Generates a valid collection name from a given file path or URL.

        Side effect: sets `self.is_url`. The input counts as a URL when
        `urlparse` yields both a scheme and a network location.

        Args:
            content_input_path (str): The input URL or file path from which
                to generate the collection name.

        Returns:
            str: A name made only of ASCII letters, digits and underscores.
                It has no leading or trailing underscore and is at most 30
                characters long.
        """
        # Check path type
        parsed_url = urlparse(content_input_path)
        self.is_url = all([parsed_url.scheme, parsed_url.netloc])

        # Convert given path into a collection name, ensuring it only
        # contains numbers, letters, and underscores
        if self.is_url:
            # For URLs, remove https://, replace /, and any characters not
            # allowed by Milvus with _
            collection_name = re.sub(
                r'[^0-9a-zA-Z]+',
                '_',
                content_input_path.replace("https://", ""),
            )
        else:
            # For file paths, use only the stem (file name without directory
            # or extension) and collapse disallowed characters to _
            collection_name = re.sub(
                r'[^0-9a-zA-Z]+', '_', Path(content_input_path).stem
            )

        # Ensure the collection name does not start or end with underscore
        collection_name = collection_name.strip("_")
        # Limit the maximum length of the collection name to 30 characters
        collection_name = collection_name[:30]
        return collection_name

    def _get_file_modified_date_from_file(self, content_input_path: str) -> str:
        r"""Retrieves the last modified date and time of a local file.

        Args:
            content_input_path (str): The file path of the content whose
                modified date is to be retrieved.

        Returns:
            str: The file's modification time as a local, naive ISO-8601
                string with seconds precision.
        """
        mod_time = os.path.getmtime(content_input_path)
        readable_mod_time = datetime.datetime.fromtimestamp(mod_time).isoformat(
            timespec='seconds'
        )
        return readable_mod_time

    def _get_file_modified_date_from_storage(
        self, vector_storage_instance: BaseVectorStorage
    ) -> str:
        r"""Retrieves the last modified date recorded in the metadata of the
        chunks stored in a vector storage.

        Args:
            vector_storage_instance (BaseVectorStorage): The vector storage
                to read the `last_modified` metadata from. Expected to be
                non-empty.

        Returns:
            str: The value of `payload["metadata"]["last_modified"]` on the
                top query hit.

        Raises:
            ValueError: If the top hit has no payload.
        """
        # Run an arbitrary query just to get one stored record back.
        # NOTE: Can be optimized when CAMEL vector storage support
        # direct chunk payload extraction
        query_vector_any = self.embedding_model.embed(obj="any_query")
        query_any = VectorDBQuery(query_vector_any, top_k=1)
        result_any = vector_storage_instance.query(query_any)

        # Extract the file's last modified date from the metadata
        # in the query result
        if result_any[0].record.payload is not None:
            file_modified_date_from_meta = result_any[0].record.payload[
                "metadata"
            ]['last_modified']
        else:
            raise ValueError(
                "The vector storage exists but the payload is None, "
                "please check the collection"
            )

        return file_modified_date_from_meta

    def run_vector_retriever(
        self,
        query: str,
        content_input_paths: Union[str, List[str]],
        top_k: int = DEFAULT_TOP_K_RESULTS,
        similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
        return_detailed_info: bool = False,
    ) -> str:
        r"""Executes the automatic vector retriever process using vector
        storage.

        Args:
            query (str): Query string for information retriever.
            content_input_paths (Union[str, List[str]]): Paths to local
                files or remote URLs.
            top_k (int, optional): The number of top results to return during
                retrieve. Must be a positive integer. Defaults to
                `DEFAULT_TOP_K_RESULTS`.
            similarity_threshold (float, optional): The similarity threshold
                for filtering results. Defaults to
                `DEFAULT_SIMILARITY_THRESHOLD`.
            return_detailed_info (bool, optional): Whether to return detailed
                information including similarity score, content path and
                metadata. Defaults to False.

        Returns:
            str: By default, returns only the text information. If
                `return_detailed_info` is `True`, return detailed information
                including similarity score, content path and metadata.

        Raises:
            ValueError: If `content_input_paths` is empty.
            RuntimeError: If any error occurs while setting up storage,
                indexing or querying a path. This includes a stored payload
                being None. The original exception is chained as the cause.
        """
        if not content_input_paths:
            raise ValueError("content_input_paths cannot be empty.")

        content_input_paths = (
            [content_input_paths]
            if isinstance(content_input_paths, str)
            else content_input_paths
        )

        all_retrieved_info = []
        for content_input_path in content_input_paths:
            # Generate a valid collection name (also updates self.is_url)
            collection_name = self._collection_name_generator(
                content_input_path
            )
            try:
                vector_storage_instance = self._initialize_vector_storage(
                    collection_name
                )
                # Read the record count once; it drives both the freshness
                # check and the re-index decision below.
                vector_count = vector_storage_instance.status().vector_count

                # Check the modified time of the input file path, only works
                # for local path since no standard way for remote url
                file_is_modified = False
                if vector_count != 0 and not self.is_url:
                    modified_date_from_file = (
                        self._get_file_modified_date_from_file(
                            content_input_path
                        )
                    )
                    modified_date_from_storage = (
                        self._get_file_modified_date_from_storage(
                            vector_storage_instance
                        )
                    )
                    file_is_modified = (
                        modified_date_from_file != modified_date_from_storage
                    )

                needs_reindex = vector_count == 0 or file_is_modified
                if needs_reindex:
                    # Drop stale chunks before re-processing the content
                    vector_storage_instance.clear()

                # Use the same embedding model that sized the storage, so
                # the vector dimensions of stored chunks and queries match.
                vr = VectorRetriever(
                    embedding_model=self.embedding_model,
                    storage=vector_storage_instance,
                    similarity_threshold=similarity_threshold,
                )
                if needs_reindex:
                    # Process and store the content to the vector storage
                    vr.process(content_input_path)

                # Retrieve info by given query from the vector storage
                retrieved_info = vr.query(query, top_k)
                all_retrieved_info.extend(retrieved_info)
            except Exception as e:
                raise RuntimeError(
                    f"Error in auto vector retriever processing: {e!s}"
                ) from e

        # Split records into those with and without a 'similarity score'.
        # Per the original author's note, records below
        # 'similarity_threshold' come back without a 'similarity score'.
        with_score = [
            info for info in all_retrieved_info if 'similarity score' in info
        ]
        without_score = [
            info
            for info in all_retrieved_info
            if 'similarity score' not in info
        ]
        # Sort only the list with scores, best first
        with_score_sorted = sorted(
            with_score, key=lambda x: x['similarity score'], reverse=True
        )
        # Scored items first, then the unscored ones; keep the 'top_k' best
        all_retrieved_info = (with_score_sorted + without_score)[:top_k]

        retrieved_infos = "\n".join(str(info) for info in all_retrieved_info)
        retrieved_infos_text = "\n".join(
            info['text'] for info in all_retrieved_info if 'text' in info
        )

        if return_detailed_info:
            return (
                f"Original Query:\n{{ {query} }}\n"
                f"Retrieved Context:\n{retrieved_infos}"
            )
        return (
            f"Original Query:\n{{ {query} }}\n"
            f"Retrieved Context:\n{retrieved_infos_text}"
        )
@@ -0,0 +1,69 @@
1
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2
+ # Licensed under the Apache License, Version 2.0 (the “License”);
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an “AS IS” BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+ from abc import ABC, abstractmethod
15
+ from typing import Any, Callable
16
+
17
# Default number of results a retriever returns per query. It is not
# referenced in the visible part of this module; subclasses define their own
# copy.
DEFAULT_TOP_K_RESULTS = 1
18
+
19
+
20
+ def _query_unimplemented(self, *input: Any) -> None:
21
+ r"""Defines the query behavior performed at every call.
22
+
23
+ Query the results. Subclasses should implement this
24
+ method according to their specific needs.
25
+
26
+ It should be overridden by all subclasses.
27
+
28
+ .. note::
29
+ Although the recipe for forward pass needs to be defined within
30
+ this function, one should call the :class:`BaseRetriever` instance
31
+ afterwards instead of this since the former takes care of running the
32
+ registered hooks while the latter silently ignores them.
33
+ """
34
+ raise NotImplementedError(
35
+ f"Retriever [{type(self).__name__}] is missing the required \"query\" function"
36
+ )
37
+
38
+
39
+ def _process_unimplemented(self, *input: Any) -> None:
40
+ r"""Defines the process behavior performed at every call.
41
+
42
+ Processes content from a file or URL, divides it into chunks by
43
+ using `Unstructured IO`,then stored internally. This method must be
44
+ called before executing queries with the retriever.
45
+
46
+ Should be overridden by all subclasses.
47
+
48
+ .. note::
49
+ Although the recipe for forward pass needs to be defined within
50
+ this function, one should call the :class:`BaseRetriever` instance
51
+ afterwards instead of this since the former takes care of running the
52
+ registered hooks while the latter silently ignores them.
53
+ """
54
+ raise NotImplementedError(
55
+ f"Retriever [{type(self).__name__}] is missing the required \"process\" function"
56
+ )
57
+
58
+
59
class BaseRetriever(ABC):
    r"""Common interface shared by the information retrievers.

    A concrete retriever provides two operations:

    * ``process`` -- ingest content so that it can be searched later;
    * ``query`` -- search the previously ingested content.

    Both default to module-level placeholders that raise
    :obj:`NotImplementedError` when called. Because ``__init__`` is abstract,
    a subclass must define its own constructor before it can be instantiated.
    """

    query: Callable[..., Any] = _query_unimplemented
    process: Callable[..., Any] = _process_unimplemented

    @abstractmethod
    def __init__(self) -> None:
        r"""Initialise the retriever; every subclass must override this."""
@@ -0,0 +1,140 @@
1
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2
+ # Licensed under the Apache License, Version 2.0 (the “License”);
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an “AS IS” BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+ from typing import Any, Dict, List
15
+
16
+ import numpy as np
17
+
18
+ from camel.loaders import UnstructuredIO
19
+ from camel.retrievers import BaseRetriever
20
+
21
# Default number of chunks `BM25Retriever.query` returns.
DEFAULT_TOP_K_RESULTS = 1
22
+
23
+
24
class BM25Retriever(BaseRetriever):
    r"""An implementation of the `BaseRetriever` using the `BM25` model.

    This class facilitates the retriever of relevant information using a
    query-based approach, it ranks documents based on the occurrence and
    frequency of the query terms.

    Attributes:
        bm25 (BM25Okapi): An instance of the BM25Okapi class used for
            calculating document scores. `None` until `process` succeeds.
        content_input_path (str): The path to the content that has been
            processed and stored.
        chunks (List[Any]): The chunks produced by the last `process` call.
            Empty until then.
        unstructured_modules (UnstructuredIO): A module for parsing files and
            URLs and chunking content based on specified parameters.

    References:
        https://github.com/dorianbrown/rank_bm25
    """

    def __init__(self) -> None:
        r"""Initializes the BM25Retriever.

        Raises:
            ImportError: If the `rank_bm25` package is not installed.
        """

        try:
            from rank_bm25 import BM25Okapi
        except ImportError as e:
            raise ImportError(
                "Package `rank_bm25` not installed, install by running 'pip install rank_bm25'"
            ) from e

        # None until `process` builds the index.
        self.bm25: BM25Okapi = None
        self.content_input_path: str = ""
        # Initialised so `query` can safely check it before `process` runs.
        self.chunks: List[Any] = []
        self.unstructured_modules: UnstructuredIO = UnstructuredIO()

    def process(
        self,
        content_input_path: str,
        chunk_type: str = "chunk_by_title",
        **kwargs: Any,
    ) -> None:
        r"""Processes content from a file or URL, divides it into chunks by
        using `Unstructured IO`, then stores them internally. This method must
        be called before executing queries with the retriever.

        Args:
            content_input_path (str): File path or URL of the content to be
                processed.
            chunk_type (str): Type of chunking going to apply. Defaults to
                "chunk_by_title".
            **kwargs (Any): Additional keyword arguments for content parsing.

        Raises:
            ValueError: If no chunks could be produced from the content.
        """
        from rank_bm25 import BM25Okapi

        # Load and preprocess documents
        self.content_input_path = content_input_path
        elements = self.unstructured_modules.parse_file_or_url(
            content_input_path, **kwargs
        )
        self.chunks = self.unstructured_modules.chunk_elements(
            chunk_type=chunk_type, elements=elements
        )

        if not self.chunks:
            # An empty corpus would make BM25Okapi divide by zero when it
            # computes the average document length; fail with a clear
            # message instead.
            self.bm25 = None
            raise ValueError(
                f"No content could be extracted from {content_input_path}."
            )

        # Tokenize each chunk on single spaces; `query` must use the same
        # tokenization for the scores to be meaningful.
        tokenized_corpus = [str(chunk).split(" ") for chunk in self.chunks]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def query(
        self,
        query: str,
        top_k: int = DEFAULT_TOP_K_RESULTS,
    ) -> List[Dict[str, Any]]:
        r"""Executes a query and compiles the results.

        Args:
            query (str): Query string for information retriever.
            top_k (int, optional): The number of top results to return during
                retriever. Must be a positive integer. If it is larger than
                the number of stored chunks, all chunks are returned.
                Defaults to `DEFAULT_TOP_K_RESULTS`.

        Returns:
            List[Dict[str, Any]]: Up to `top_k` results, sorted by
                'similarity score' from high to low. Each result has the keys
                'similarity score', 'content path', 'metadata' and 'text'.

        Raises:
            ValueError: If `top_k` is less than or equal to 0, or if the
                BM25 model has not been initialized by calling `process`
                first.
        """

        if top_k <= 0:
            raise ValueError("top_k must be a positive integer.")
        if self.bm25 is None or not self.chunks:
            raise ValueError(
                "BM25 model is not initialized. Call `process` first."
            )

        # Preprocess query similarly to how documents were processed
        processed_query = query.split(" ")
        # Retrieve documents based on BM25 scores
        scores = self.bm25.get_scores(processed_query)

        # np.argpartition raises when kth is out of bounds, so never ask for
        # more entries than there are scores.
        top_k = min(top_k, len(scores))
        top_k_indices = np.argpartition(scores, -top_k)[-top_k:]

        formatted_results = [
            {
                'similarity score': scores[i],
                'content path': self.content_input_path,
                'metadata': self.chunks[i].metadata.to_dict(),
                'text': str(self.chunks[i]),
            }
            for i in top_k_indices
        ]

        # argpartition does not order its partition; sort high to low
        formatted_results.sort(
            key=lambda x: x['similarity score'], reverse=True
        )

        return formatted_results
@@ -0,0 +1,108 @@
1
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2
+ # Licensed under the Apache License, Version 2.0 (the “License”);
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an “AS IS” BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+ import os
15
+ from typing import Any, Dict, List, Optional
16
+
17
+ from camel.retrievers import BaseRetriever
18
+
19
# Default number of re-ranked results `CohereRerankRetriever.query` requests
# (passed to the Cohere API as `top_n`).
DEFAULT_TOP_K_RESULTS = 1
20
+
21
+
22
class CohereRerankRetriever(BaseRetriever):
    r"""An implementation of the `BaseRetriever` using the `Cohere Re-ranking`
    model.

    Attributes:
        model_name (str): The model name to use for re-ranking.
        api_key (Optional[str]): The API key for authenticating with the
            Cohere service.
        co: The `cohere.Client` used to issue re-rank requests.

    References:
        https://txt.cohere.com/rerank/
    """

    def __init__(
        self,
        model_name: str = "rerank-multilingual-v2.0",
        api_key: Optional[str] = None,
    ) -> None:
        r"""Initializes an instance of the CohereRerankRetriever. This
        constructor sets up a client for interacting with the Cohere API using
        the specified model name and API key. If the API key is not provided,
        it attempts to retrieve it from the COHERE_API_KEY environment
        variable.

        Args:
            model_name (str): The name of the model to be used for re-ranking.
                Defaults to 'rerank-multilingual-v2.0'.
            api_key (Optional[str]): The API key for authenticating requests
                to the Cohere API. If not provided, the method will attempt to
                retrieve the key from the environment variable
                'COHERE_API_KEY'.

        Raises:
            ImportError: If the 'cohere' package is not installed.
            ValueError: If the API key is neither passed as an argument nor
                set in the environment variable.
        """

        try:
            import cohere
        except ImportError as e:
            raise ImportError("Package 'cohere' is not installed") from e

        try:
            self.api_key = api_key or os.environ["COHERE_API_KEY"]
        except KeyError as e:
            # os.environ raises KeyError (not ValueError) for a missing
            # variable; translate it into the documented ValueError.
            raise ValueError(
                "Must pass in cohere api key or specify via COHERE_API_KEY environment variable."
            ) from e

        self.co = cohere.Client(self.api_key)
        self.model_name = model_name

    def query(
        self,
        query: str,
        retrieved_result: List[Dict[str, Any]],
        top_k: int = DEFAULT_TOP_K_RESULTS,
    ) -> List[Dict[str, Any]]:
        r"""Queries and compiles results using the Cohere re-ranking model.

        Args:
            query (str): Query string for information retriever.
            retrieved_result (List[Dict[str, Any]]): The content to be
                re-ranked, should be the output from `BaseRetriever` like
                `VectorRetriever`.
            top_k (int, optional): The number of top results to return during
                retriever. Must be a positive integer. Defaults to
                `DEFAULT_TOP_K_RESULTS`.

        Returns:
            List[Dict[str, Any]]: Copies of the selected entries from
                `retrieved_result`, in the order returned by the re-ranker.
                Each copy has its 'similarity score' set to the re-ranker's
                relevance score. The input dictionaries are not modified.
        """
        if not retrieved_result:
            # Nothing to re-rank; avoid a pointless (and possibly rejected)
            # API call.
            return []

        rerank_results = self.co.rerank(
            query=query,
            documents=retrieved_result,
            top_n=top_k,
            model=self.model_name,
        )

        formatted_results = []
        for result in rerank_results.results:
            # Copy so the caller's result dictionaries are left untouched.
            selected_chunk = dict(retrieved_result[result.index])
            selected_chunk['similarity score'] = result.relevance_score
            formatted_results.append(selected_chunk)
        return formatted_results