camel-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic. Click here for more details.

Files changed (99) hide show
  1. camel/__init__.py +1 -11
  2. camel/agents/__init__.py +5 -5
  3. camel/agents/chat_agent.py +124 -63
  4. camel/agents/critic_agent.py +28 -17
  5. camel/agents/deductive_reasoner_agent.py +235 -0
  6. camel/agents/embodied_agent.py +92 -40
  7. camel/agents/role_assignment_agent.py +27 -17
  8. camel/agents/task_agent.py +60 -34
  9. camel/agents/tool_agents/base.py +0 -1
  10. camel/agents/tool_agents/hugging_face_tool_agent.py +7 -4
  11. camel/configs.py +119 -7
  12. camel/embeddings/__init__.py +2 -0
  13. camel/embeddings/base.py +3 -2
  14. camel/embeddings/openai_embedding.py +3 -3
  15. camel/embeddings/sentence_transformers_embeddings.py +65 -0
  16. camel/functions/__init__.py +13 -3
  17. camel/functions/google_maps_function.py +335 -0
  18. camel/functions/math_functions.py +7 -7
  19. camel/functions/openai_function.py +344 -42
  20. camel/functions/search_functions.py +100 -35
  21. camel/functions/twitter_function.py +484 -0
  22. camel/functions/weather_functions.py +36 -23
  23. camel/generators.py +65 -46
  24. camel/human.py +17 -11
  25. camel/interpreters/__init__.py +25 -0
  26. camel/interpreters/base.py +49 -0
  27. camel/{utils/python_interpreter.py → interpreters/internal_python_interpreter.py} +129 -48
  28. camel/interpreters/interpreter_error.py +19 -0
  29. camel/interpreters/subprocess_interpreter.py +190 -0
  30. camel/loaders/__init__.py +22 -0
  31. camel/{functions/base_io_functions.py → loaders/base_io.py} +38 -35
  32. camel/{functions/unstructured_io_fuctions.py → loaders/unstructured_io.py} +199 -110
  33. camel/memories/__init__.py +17 -7
  34. camel/memories/agent_memories.py +156 -0
  35. camel/memories/base.py +97 -32
  36. camel/memories/blocks/__init__.py +21 -0
  37. camel/memories/{chat_history_memory.py → blocks/chat_history_block.py} +34 -34
  38. camel/memories/blocks/vectordb_block.py +101 -0
  39. camel/memories/context_creators/__init__.py +3 -2
  40. camel/memories/context_creators/score_based.py +32 -20
  41. camel/memories/records.py +6 -5
  42. camel/messages/__init__.py +2 -2
  43. camel/messages/base.py +99 -16
  44. camel/messages/func_message.py +7 -4
  45. camel/models/__init__.py +4 -2
  46. camel/models/anthropic_model.py +132 -0
  47. camel/models/base_model.py +3 -2
  48. camel/models/model_factory.py +10 -8
  49. camel/models/open_source_model.py +25 -13
  50. camel/models/openai_model.py +9 -10
  51. camel/models/stub_model.py +6 -5
  52. camel/prompts/__init__.py +7 -5
  53. camel/prompts/ai_society.py +21 -14
  54. camel/prompts/base.py +54 -47
  55. camel/prompts/code.py +22 -14
  56. camel/prompts/evaluation.py +8 -5
  57. camel/prompts/misalignment.py +26 -19
  58. camel/prompts/object_recognition.py +35 -0
  59. camel/prompts/prompt_templates.py +14 -8
  60. camel/prompts/role_description_prompt_template.py +16 -10
  61. camel/prompts/solution_extraction.py +9 -5
  62. camel/prompts/task_prompt_template.py +24 -21
  63. camel/prompts/translation.py +9 -5
  64. camel/responses/agent_responses.py +5 -2
  65. camel/retrievers/__init__.py +24 -0
  66. camel/retrievers/auto_retriever.py +319 -0
  67. camel/retrievers/base.py +64 -0
  68. camel/retrievers/bm25_retriever.py +149 -0
  69. camel/retrievers/vector_retriever.py +166 -0
  70. camel/societies/__init__.py +1 -1
  71. camel/societies/babyagi_playing.py +56 -32
  72. camel/societies/role_playing.py +188 -133
  73. camel/storages/__init__.py +18 -0
  74. camel/storages/graph_storages/__init__.py +23 -0
  75. camel/storages/graph_storages/base.py +82 -0
  76. camel/storages/graph_storages/graph_element.py +74 -0
  77. camel/storages/graph_storages/neo4j_graph.py +582 -0
  78. camel/storages/key_value_storages/base.py +1 -2
  79. camel/storages/key_value_storages/in_memory.py +1 -2
  80. camel/storages/key_value_storages/json.py +8 -13
  81. camel/storages/vectordb_storages/__init__.py +33 -0
  82. camel/storages/vectordb_storages/base.py +202 -0
  83. camel/storages/vectordb_storages/milvus.py +396 -0
  84. camel/storages/vectordb_storages/qdrant.py +371 -0
  85. camel/terminators/__init__.py +1 -1
  86. camel/terminators/base.py +2 -3
  87. camel/terminators/response_terminator.py +21 -12
  88. camel/terminators/token_limit_terminator.py +5 -3
  89. camel/types/__init__.py +12 -6
  90. camel/types/enums.py +86 -13
  91. camel/types/openai_types.py +10 -5
  92. camel/utils/__init__.py +18 -13
  93. camel/utils/commons.py +242 -81
  94. camel/utils/token_counting.py +135 -15
  95. {camel_ai-0.1.1.dist-info → camel_ai-0.1.3.dist-info}/METADATA +116 -74
  96. camel_ai-0.1.3.dist-info/RECORD +101 -0
  97. {camel_ai-0.1.1.dist-info → camel_ai-0.1.3.dist-info}/WHEEL +1 -1
  98. camel/memories/context_creators/base.py +0 -72
  99. camel_ai-0.1.1.dist-info/RECORD +0 -75
@@ -30,6 +30,7 @@ class ChatAgentResponse:
30
30
  to terminate the chat session.
31
31
  info (Dict[str, Any]): Extra information about the chat message.
32
32
  """
33
+
33
34
  msgs: List[BaseMessage]
34
35
  terminated: bool
35
36
  info: Dict[str, Any]
@@ -37,6 +38,8 @@ class ChatAgentResponse:
37
38
  @property
38
39
  def msg(self):
39
40
  if len(self.msgs) != 1:
40
- raise RuntimeError("Property msg is only available "
41
- "for a single message in msgs.")
41
+ raise RuntimeError(
42
+ "Property msg is only available "
43
+ "for a single message in msgs."
44
+ )
42
45
  return self.msgs[0]
@@ -0,0 +1,24 @@
1
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2
+ # Licensed under the Apache License, Version 2.0 (the “License”);
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an “AS IS” BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+ from .auto_retriever import AutoRetriever
15
+ from .base import BaseRetriever
16
+ from .bm25_retriever import BM25Retriever
17
+ from .vector_retriever import VectorRetriever
18
+
19
+ __all__ = [
20
+ 'BaseRetriever',
21
+ 'VectorRetriever',
22
+ 'AutoRetriever',
23
+ 'BM25Retriever',
24
+ ]
@@ -0,0 +1,319 @@
1
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2
+ # Licensed under the Apache License, Version 2.0 (the “License”);
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an “AS IS” BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+ import datetime
15
+ import os
16
+ import re
17
+ from pathlib import Path
18
+ from typing import List, Optional, Tuple, Union
19
+ from urllib.parse import urlparse
20
+
21
+ from camel.embeddings import BaseEmbedding, OpenAIEmbedding
22
+ from camel.retrievers.vector_retriever import VectorRetriever
23
+ from camel.storages import (
24
+ BaseVectorStorage,
25
+ MilvusStorage,
26
+ QdrantStorage,
27
+ VectorDBQuery,
28
+ )
29
+ from camel.types import StorageType
30
+
31
+ DEFAULT_TOP_K_RESULTS = 1
32
+ DEFAULT_SIMILARITY_THRESHOLD = 0.75
33
+
34
+
35
class AutoRetriever:
    r"""Facilitates the automatic retrieval of information using a
    query-based approach with pre-defined elements.

    Attributes:
        url_and_api_key (Optional[Tuple[str, str]]): URL and API key for
            accessing the vector storage remotely.
        vector_storage_local_path (Optional[str]): Local path for vector
            storage, if applicable.
        storage_type (Optional[StorageType]): The type of vector storage to
            use. Defaults to `StorageType.MILVUS`.
        embedding_model (Optional[BaseEmbedding]): Model used for embedding
            queries and documents. Defaults to `OpenAIEmbedding()`.
        is_url (bool): Whether the most recently handled input path was
            recognised as a URL. Updated by `_collection_name_generator`.
    """

    def __init__(
        self,
        url_and_api_key: Optional[Tuple[str, str]] = None,
        vector_storage_local_path: Optional[str] = None,
        storage_type: Optional[StorageType] = None,
        embedding_model: Optional[BaseEmbedding] = None,
    ) -> None:
        self.storage_type = storage_type or StorageType.MILVUS
        self.embedding_model = embedding_model or OpenAIEmbedding()
        self.vector_storage_local_path = vector_storage_local_path
        self.url_and_api_key = url_and_api_key
        # Give the attribute a defined value before any path is handled;
        # `_collection_name_generator` overwrites it for every input.
        self.is_url = False

    def _initialize_vector_storage(
        self,
        collection_name: Optional[str] = None,
    ) -> BaseVectorStorage:
        r"""Sets up and returns a vector storage instance with specified
        parameters.

        Args:
            collection_name (Optional[str]): Name of the collection in the
                vector storage.

        Returns:
            BaseVectorStorage: Configured vector storage instance.

        Raises:
            ValueError: If Milvus storage is selected without
                `url_and_api_key`, or if the storage type is unsupported.
        """
        if self.storage_type == StorageType.MILVUS:
            if self.url_and_api_key is None:
                raise ValueError(
                    "URL and API key required for Milvus storage are not "
                    "provided."
                )
            return MilvusStorage(
                vector_dim=self.embedding_model.get_output_dim(),
                collection_name=collection_name,
                url_and_api_key=self.url_and_api_key,
            )

        if self.storage_type == StorageType.QDRANT:
            return QdrantStorage(
                vector_dim=self.embedding_model.get_output_dim(),
                collection_name=collection_name,
                path=self.vector_storage_local_path,
                url_and_api_key=self.url_and_api_key,
            )

        raise ValueError(
            f"Unsupported vector storage type: {self.storage_type}"
        )

    def _collection_name_generator(self, content_input_path: str) -> str:
        r"""Generates a valid collection name from a given file path or URL.

        As a side effect, records on `self.is_url` whether the input was
        recognised as a URL (it has both a scheme and a network location).

        Args:
            content_input_path (str): The input URL or file path from which
                to generate the collection name.

        Returns:
            str: A sanitized collection name containing only letters, digits
                and underscores, without leading or trailing underscores and
                at most 30 characters long.
        """
        # Check path type
        parsed_url = urlparse(content_input_path)
        self.is_url = all([parsed_url.scheme, parsed_url.netloc])

        if self.is_url:
            # Only the literal "https://" prefix is dropped; any other scheme
            # stays in the name (e.g. "http://a.b" becomes "http_a_b").
            raw_name = content_input_path.replace("https://", "")
        else:
            # For file paths, only the stem (no directory, no suffix) is used
            raw_name = Path(content_input_path).stem

        # Collapse every run of disallowed characters into one underscore
        collection_name = re.sub(r'[^0-9a-zA-Z]+', '_', raw_name)
        # Ensure the collection name does not start or end with underscore
        collection_name = collection_name.strip("_")
        # Limit the maximum length of the collection name to 30 characters
        return collection_name[:30]

    def _get_file_modified_date_from_file(self, content_input_path: str) -> str:
        r"""Retrieves the last modified date and time of a given file.

        Args:
            content_input_path (str): The file path of the content whose
                modified date is to be retrieved.

        Returns:
            str: The last modified time of the file as an ISO 8601 string
                with second precision.
        """
        mod_time = os.path.getmtime(content_input_path)
        return datetime.datetime.fromtimestamp(mod_time).isoformat(
            timespec='seconds'
        )

    def _get_file_modified_date_from_storage(
        self, vector_storage_instance: BaseVectorStorage
    ) -> str:
        r"""Retrieves the last modified date recorded in the metadata of a
        record held by the given vector storage.

        Args:
            vector_storage_instance (BaseVectorStorage): The vector storage
                where modified date is to be retrieved from meta data.

        Returns:
            str: The last modified date from vector storage.

        Raises:
            ValueError: If the query returns no record, or the returned
                record has no payload.
        """
        # Insert any query to get modified date from vector db
        # NOTE: Can be optimized when CAMEL vector storage support
        # direct chunk payload extraction
        query_vector_any = self.embedding_model.embed(obj="any_query")
        query_any = VectorDBQuery(query_vector_any, top_k=1)
        result_any = vector_storage_instance.query(query_any)

        # Fail with a clear message instead of an IndexError on `[0]`
        if not result_any:
            raise ValueError(
                "The vector storage returned no record, "
                "please check the collection"
            )

        # Extract the file's last modified date from the metadata
        # in the query result
        payload = result_any[0].record.payload
        if payload is None:
            raise ValueError(
                "The vector storage exists but the payload is None, "
                "please check the collection"
            )
        return payload["metadata"]['last_modified']

    def run_vector_retriever(
        self,
        query: str,
        content_input_paths: Union[str, List[str]],
        top_k: int = DEFAULT_TOP_K_RESULTS,
        similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
        return_detailed_info: bool = False,
    ) -> str:
        r"""Executes the automatic vector retriever process using vector
        storage.

        Args:
            query (str): Query string for information retriever.
            content_input_paths (Union[str, List[str]]): Paths to local
                files or remote URLs.
            top_k (int, optional): The number of top results to return during
                retrieve. Must be a positive integer. Defaults to
                `DEFAULT_TOP_K_RESULTS`.
            similarity_threshold (float, optional): The similarity threshold
                for filtering results. Defaults to
                `DEFAULT_SIMILARITY_THRESHOLD`.
            return_detailed_info (bool, optional): Whether to return detailed
                information including similarity score, content path and
                metadata. Defaults to False.

        Returns:
            str: By default, returns only the text information. If
                `return_detailed_info` is `True`, return detailed information
                including similarity score, content path and metadata.

        Raises:
            ValueError: If `content_input_paths` is empty.
            RuntimeError: If any error occurs while a path is being
                processed or queried. This includes the `ValueError` raised
                when a stored record has no payload, which is wrapped here.
        """
        if not content_input_paths:
            raise ValueError("content_input_paths cannot be empty.")

        if isinstance(content_input_paths, str):
            content_input_paths = [content_input_paths]

        # NOTE(review): the retriever is built with its own defaults rather
        # than `self.embedding_model` - confirm both use the same embedding
        # dimension when a custom model is supplied.
        vr = VectorRetriever()

        detailed_parts: List[str] = []
        text_parts: List[str] = []

        for content_input_path in content_input_paths:
            # Generate a valid collection name (also refreshes `self.is_url`)
            collection_name = self._collection_name_generator(
                content_input_path
            )
            try:
                vector_storage_instance = self._initialize_vector_storage(
                    collection_name
                )
                storage_is_empty = (
                    vector_storage_instance.status().vector_count == 0
                )

                # Check the modified time of the input file path, only works
                # for local path since no standard way for remote url
                file_is_modified = False
                if not storage_is_empty and not self.is_url:
                    modified_date_from_file = (
                        self._get_file_modified_date_from_file(
                            content_input_path
                        )
                    )
                    modified_date_from_storage = (
                        self._get_file_modified_date_from_storage(
                            vector_storage_instance
                        )
                    )
                    file_is_modified = (
                        modified_date_from_file != modified_date_from_storage
                    )

                if storage_is_empty or file_is_modified:
                    # Rebuild the collection from the current content
                    vector_storage_instance.clear()
                    vr.process(content_input_path, vector_storage_instance)

                # Retrieve info by given query from the vector storage
                retrieved_info = vr.query(
                    query, vector_storage_instance, top_k, similarity_threshold
                )
                for info in retrieved_info:
                    detailed_parts.append("\n" + str(info))
                    text_parts.append("\n" + str(info['text']))

            except Exception as e:
                raise RuntimeError(
                    f"Error in auto vector retriever processing: {e!s}"
                ) from e

        # Assemble the answer once, after every path has been handled
        header = f"Original Query:\n{{{query}}}\nRetrieved Context:"
        selected_parts = detailed_parts if return_detailed_info else text_parts
        return header + "".join(selected_parts)
@@ -0,0 +1,64 @@
1
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2
+ # Licensed under the Apache License, Version 2.0 (the “License”);
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an “AS IS” BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+ from abc import ABC, abstractmethod
15
+ from typing import Any, Dict, List
16
+
17
DEFAULT_TOP_K_RESULTS = 1


class BaseRetriever(ABC):
    r"""Abstract interface shared by every information retriever.

    Concrete retrievers must provide `__init__`, `process` and `query`.
    """

    @abstractmethod
    def __init__(self) -> None:
        r"""Initializes the retriever. Subclasses must override this."""

    @abstractmethod
    def process(
        self,
        content_input_path: str,
        chunk_type: str = "chunk_by_title",
        **kwargs: Any,
    ) -> None:
        r"""Processes content from a file or URL, divides it into chunks by
        using `Unstructured IO`, then stores it internally. This method must
        be called before executing queries with the retriever.

        Args:
            content_input_path (str): File path or URL of the content to be
                processed.
            chunk_type (str): Type of chunking going to apply. Defaults to
                "chunk_by_title".
            **kwargs (Any): Additional keyword arguments for content parsing.
        """

    @abstractmethod
    def query(
        self, query: str, top_k: int = DEFAULT_TOP_K_RESULTS, **kwargs: Any
    ) -> List[Dict[str, Any]]:
        r"""Query the results. Subclasses should implement this
        method according to their specific needs.

        Args:
            query (str): Query string for information retriever.
            top_k (int, optional): The number of top results to return during
                retriever. Must be a positive integer. Defaults to
                `DEFAULT_TOP_K_RESULTS`.
            **kwargs (Any): Flexible keyword arguments for additional
                parameters, like `similarity_threshold`.
        """
@@ -0,0 +1,149 @@
1
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2
+ # Licensed under the Apache License, Version 2.0 (the “License”);
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an “AS IS” BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+ from typing import Any, Dict, List
15
+
16
+ import numpy as np
17
+
18
+ from camel.loaders import UnstructuredIO
19
+ from camel.retrievers import BaseRetriever
20
+
21
+ DEFAULT_TOP_K_RESULTS = 1
22
+
23
+
24
class BM25Retriever(BaseRetriever):
    r"""An implementation of the `BaseRetriever` using the `BM25` model.

    This class facilitates the retrieval of relevant information using a
    query-based approach; it ranks documents based on the occurrence and
    frequency of the query terms.

    Attributes:
        bm25 (BM25Okapi): An instance of the BM25Okapi class used for
            calculating document scores. `None` until `process` is called.
        content_input_path (str): The path to the content that has been
            processed and stored.
        chunks (List[Any]): A list of document chunks processed from the
            input content.

    References:
        https://github.com/dorianbrown/rank_bm25
    """

    def __init__(self) -> None:
        r"""Initializes the BM25Retriever.

        Raises:
            ImportError: If the `rank_bm25` package is not installed.
        """

        # Imported here so a missing optional dependency is reported at
        # construction time with an installation hint.
        try:
            from rank_bm25 import BM25Okapi
        except ImportError as e:
            raise ImportError(
                "Package `rank_bm25` not installed, install by running"
                " 'pip install rank_bm25'"
            ) from e

        self.bm25: BM25Okapi = None
        self.content_input_path: str = ""
        self.chunks: List[Any] = []

    def process(
        self,
        content_input_path: str,
        chunk_type: str = "chunk_by_title",
        **kwargs: Any,
    ) -> None:
        r"""Processes content from a file or URL, divides it into chunks by
        using `Unstructured IO`, then stores it internally. This method must
        be called before executing queries with the retriever.

        Args:
            content_input_path (str): File path or URL of the content to be
                processed.
            chunk_type (str): Type of chunking going to apply. Defaults to
                "chunk_by_title".
            **kwargs (Any): Additional keyword arguments for content parsing.
        """
        from rank_bm25 import BM25Okapi

        # Load and preprocess documents
        self.content_input_path = content_input_path
        unstructured_modules = UnstructuredIO()
        elements = unstructured_modules.parse_file_or_url(
            content_input_path, **kwargs
        )
        self.chunks = unstructured_modules.chunk_elements(
            chunk_type=chunk_type, elements=elements
        )

        # Tokenize each chunk on single spaces; `query` splits the same way
        tokenized_corpus = [str(chunk).split(" ") for chunk in self.chunks]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def query(  # type: ignore[override]
        self,
        query: str,
        top_k: int = DEFAULT_TOP_K_RESULTS,
    ) -> List[Dict[str, Any]]:
        r"""Executes a query and compiles the results.

        Args:
            query (str): Query string for information retriever.
            top_k (int, optional): The number of top results to return during
                retriever. Must be a positive integer. If it exceeds the
                number of stored chunks, every chunk is returned. Defaults to
                `DEFAULT_TOP_K_RESULTS`.

        Returns:
            List[Dict[str, Any]]: One dict per selected chunk with the keys
                'similarity score', 'content path', 'metadata' and 'text',
                ordered from highest to lowest score.

        Raises:
            ValueError: If `top_k` is less than or equal to 0, or if the BM25
                model has not been initialized by calling `process` first.

        Note:
            Unlike `BaseRetriever.query`, this implementation does not accept
            extra keyword arguments.
        """

        if top_k <= 0:
            raise ValueError("top_k must be a positive integer.")

        if self.bm25 is None:
            raise ValueError(
                "BM25 model is not initialized. Call `process`"
                " first."
            )

        # Preprocess query similarly to how documents were processed
        processed_query = query.split(" ")
        # Retrieve documents based on BM25 scores
        scores = self.bm25.get_scores(processed_query)

        # `np.argpartition` rejects a kth outside the array bounds, so never
        # ask for more results than there are scored chunks.
        result_count = min(top_k, len(scores))
        if result_count == 0:
            return []
        top_k_indices = np.argpartition(scores, -result_count)[-result_count:]

        formatted_results = [
            {
                'similarity score': scores[i],
                'content path': self.content_input_path,
                'metadata': self.chunks[i].metadata.to_dict(),
                'text': str(self.chunks[i]),
            }
            for i in top_k_indices
        ]

        # `argpartition` leaves the selected indices unordered, so sort the
        # results by 'similarity score' from high to low
        formatted_results.sort(
            key=lambda x: x['similarity score'], reverse=True
        )

        return formatted_results