camel-ai 0.1.6.2__py3-none-any.whl → 0.1.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- camel/__init__.py +1 -1
- camel/configs/gemini_config.py +0 -1
- camel/configs/groq_config.py +1 -1
- camel/configs/mistral_config.py +14 -10
- camel/embeddings/mistral_embedding.py +5 -5
- camel/interpreters/docker_interpreter.py +1 -1
- camel/loaders/__init__.py +1 -2
- camel/loaders/base_io.py +118 -52
- camel/loaders/jina_url_reader.py +6 -6
- camel/loaders/unstructured_io.py +34 -295
- camel/models/__init__.py +2 -0
- camel/models/mistral_model.py +120 -26
- camel/models/model_factory.py +3 -3
- camel/models/openai_compatibility_model.py +105 -0
- camel/retrievers/auto_retriever.py +40 -52
- camel/retrievers/bm25_retriever.py +9 -6
- camel/retrievers/vector_retriever.py +29 -20
- camel/storages/object_storages/__init__.py +22 -0
- camel/storages/object_storages/amazon_s3.py +205 -0
- camel/storages/object_storages/azure_blob.py +166 -0
- camel/storages/object_storages/base.py +115 -0
- camel/storages/object_storages/google_cloud.py +152 -0
- camel/toolkits/retrieval_toolkit.py +6 -6
- camel/toolkits/search_toolkit.py +4 -4
- camel/types/enums.py +7 -0
- camel/utils/token_counting.py +7 -3
- {camel_ai-0.1.6.2.dist-info → camel_ai-0.1.6.5.dist-info}/METADATA +9 -5
- {camel_ai-0.1.6.2.dist-info → camel_ai-0.1.6.5.dist-info}/RECORD +29 -23
- {camel_ai-0.1.6.2.dist-info → camel_ai-0.1.6.5.dist-info}/WHEEL +0 -0
camel/models/model_factory.py
CHANGED
|
@@ -22,6 +22,7 @@ from camel.models.litellm_model import LiteLLMModel
|
|
|
22
22
|
from camel.models.mistral_model import MistralModel
|
|
23
23
|
from camel.models.ollama_model import OllamaModel
|
|
24
24
|
from camel.models.open_source_model import OpenSourceModel
|
|
25
|
+
from camel.models.openai_compatibility_model import OpenAICompatibilityModel
|
|
25
26
|
from camel.models.openai_model import OpenAIModel
|
|
26
27
|
from camel.models.stub_model import StubModel
|
|
27
28
|
from camel.models.vllm_model import VLLMModel
|
|
@@ -105,11 +106,10 @@ class ModelFactory:
|
|
|
105
106
|
)
|
|
106
107
|
elif model_platform.is_vllm:
|
|
107
108
|
model_class = VLLMModel
|
|
108
|
-
return model_class(
|
|
109
|
-
model_type, model_config_dict, url, api_key, token_counter
|
|
110
|
-
)
|
|
111
109
|
elif model_platform.is_litellm:
|
|
112
110
|
model_class = LiteLLMModel
|
|
111
|
+
elif model_platform.is_openai_compatibility_model:
|
|
112
|
+
model_class = OpenAICompatibilityModel
|
|
113
113
|
else:
|
|
114
114
|
raise ValueError(
|
|
115
115
|
f"Unknown pair of model platform `{model_platform}` "
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the “License”);
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an “AS IS” BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
|
+
|
|
15
|
+
from typing import Any, Dict, List, Optional, Union
|
|
16
|
+
|
|
17
|
+
from openai import OpenAI, Stream
|
|
18
|
+
|
|
19
|
+
from camel.messages import OpenAIMessage
|
|
20
|
+
from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
|
|
21
|
+
from camel.utils import (
|
|
22
|
+
BaseTokenCounter,
|
|
23
|
+
OpenAITokenCounter,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class OpenAICompatibilityModel:
|
|
28
|
+
r"""Constructor for model backend supporting OpenAI compatibility."""
|
|
29
|
+
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
model_type: str,
|
|
33
|
+
model_config_dict: Dict[str, Any],
|
|
34
|
+
api_key: str,
|
|
35
|
+
url: str,
|
|
36
|
+
token_counter: Optional[BaseTokenCounter] = None,
|
|
37
|
+
) -> None:
|
|
38
|
+
r"""Constructor for model backend.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
model_type (ModelType): Model for which a backend is created.
|
|
42
|
+
model_config_dict (Dict[str, Any]): A dictionary that will
|
|
43
|
+
be fed into openai.ChatCompletion.create().
|
|
44
|
+
api_key (str): The API key for authenticating with the
|
|
45
|
+
model service. (default: :obj:`None`)
|
|
46
|
+
url (str): The url to the model service. (default:
|
|
47
|
+
:obj:`None`)
|
|
48
|
+
token_counter (Optional[BaseTokenCounter]): Token counter to use
|
|
49
|
+
for the model. If not provided, `OpenAITokenCounter(ModelType.
|
|
50
|
+
GPT_3_5_TURBO)` will be used.
|
|
51
|
+
"""
|
|
52
|
+
self.model_type = model_type
|
|
53
|
+
self.model_config_dict = model_config_dict
|
|
54
|
+
self._token_counter = token_counter
|
|
55
|
+
self._client = OpenAI(
|
|
56
|
+
timeout=60,
|
|
57
|
+
max_retries=3,
|
|
58
|
+
api_key=api_key,
|
|
59
|
+
base_url=url,
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
def run(
|
|
63
|
+
self,
|
|
64
|
+
messages: List[OpenAIMessage],
|
|
65
|
+
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
|
|
66
|
+
r"""Runs inference of OpenAI chat completion.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
messages (List[OpenAIMessage]): Message list with the chat history
|
|
70
|
+
in OpenAI API format.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
Union[ChatCompletion, Stream[ChatCompletionChunk]]:
|
|
74
|
+
`ChatCompletion` in the non-stream mode, or
|
|
75
|
+
`Stream[ChatCompletionChunk]` in the stream mode.
|
|
76
|
+
"""
|
|
77
|
+
response = self._client.chat.completions.create(
|
|
78
|
+
messages=messages,
|
|
79
|
+
model=self.model_type,
|
|
80
|
+
**self.model_config_dict,
|
|
81
|
+
)
|
|
82
|
+
return response
|
|
83
|
+
|
|
84
|
+
@property
|
|
85
|
+
def token_counter(self) -> BaseTokenCounter:
|
|
86
|
+
r"""Initialize the token counter for the model backend.
|
|
87
|
+
|
|
88
|
+
Returns:
|
|
89
|
+
OpenAITokenCounter: The token counter following the model's
|
|
90
|
+
tokenization style.
|
|
91
|
+
"""
|
|
92
|
+
|
|
93
|
+
if not self._token_counter:
|
|
94
|
+
self._token_counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
|
|
95
|
+
return self._token_counter
|
|
96
|
+
|
|
97
|
+
@property
|
|
98
|
+
def stream(self) -> bool:
|
|
99
|
+
r"""Returns whether the model is in stream mode, which sends partial
|
|
100
|
+
results each time.
|
|
101
|
+
|
|
102
|
+
Returns:
|
|
103
|
+
bool: Whether the model is in stream mode.
|
|
104
|
+
"""
|
|
105
|
+
return self.model_config_dict.get('stream', False)
|
|
@@ -15,7 +15,7 @@ import datetime
|
|
|
15
15
|
import os
|
|
16
16
|
import re
|
|
17
17
|
from pathlib import Path
|
|
18
|
-
from typing import List, Optional, Tuple, Union
|
|
18
|
+
from typing import Collection, List, Optional, Sequence, Tuple, Union
|
|
19
19
|
from urllib.parse import urlparse
|
|
20
20
|
|
|
21
21
|
from camel.embeddings import BaseEmbedding, OpenAIEmbedding
|
|
@@ -97,36 +97,36 @@ class AutoRetriever:
|
|
|
97
97
|
f"Unsupported vector storage type: {self.storage_type}"
|
|
98
98
|
)
|
|
99
99
|
|
|
100
|
-
def _collection_name_generator(self,
|
|
100
|
+
def _collection_name_generator(self, content: str) -> str:
|
|
101
101
|
r"""Generates a valid collection name from a given file path or URL.
|
|
102
102
|
|
|
103
103
|
Args:
|
|
104
|
-
|
|
105
|
-
generate the collection name.
|
|
104
|
+
contents (str): Local file path, remote URL or string content.
|
|
106
105
|
|
|
107
106
|
Returns:
|
|
108
107
|
str: A sanitized, valid collection name suitable for use.
|
|
109
108
|
"""
|
|
110
|
-
# Check
|
|
111
|
-
parsed_url = urlparse(
|
|
112
|
-
|
|
109
|
+
# Check if the content is URL
|
|
110
|
+
parsed_url = urlparse(content)
|
|
111
|
+
is_url = all([parsed_url.scheme, parsed_url.netloc])
|
|
113
112
|
|
|
114
113
|
# Convert given path into a collection name, ensuring it only
|
|
115
114
|
# contains numbers, letters, and underscores
|
|
116
|
-
if
|
|
115
|
+
if is_url:
|
|
117
116
|
# For URLs, remove https://, replace /, and any characters not
|
|
118
117
|
# allowed by Milvus with _
|
|
119
118
|
collection_name = re.sub(
|
|
120
119
|
r'[^0-9a-zA-Z]+',
|
|
121
120
|
'_',
|
|
122
|
-
|
|
121
|
+
content.replace("https://", ""),
|
|
123
122
|
)
|
|
124
|
-
|
|
123
|
+
elif os.path.exists(content):
|
|
125
124
|
# For file paths, get the stem and replace spaces with _, also
|
|
126
125
|
# ensuring only allowed characters are present
|
|
127
|
-
collection_name = re.sub(
|
|
128
|
-
|
|
129
|
-
|
|
126
|
+
collection_name = re.sub(r'[^0-9a-zA-Z]+', '_', Path(content).stem)
|
|
127
|
+
else:
|
|
128
|
+
# the content is string input
|
|
129
|
+
collection_name = content[:10]
|
|
130
130
|
|
|
131
131
|
# Ensure the collection name does not start or end with underscore
|
|
132
132
|
collection_name = collection_name.strip("_")
|
|
@@ -193,18 +193,18 @@ class AutoRetriever:
|
|
|
193
193
|
def run_vector_retriever(
|
|
194
194
|
self,
|
|
195
195
|
query: str,
|
|
196
|
-
|
|
196
|
+
contents: Union[str, List[str]],
|
|
197
197
|
top_k: int = DEFAULT_TOP_K_RESULTS,
|
|
198
198
|
similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
|
|
199
199
|
return_detailed_info: bool = False,
|
|
200
|
-
) -> str:
|
|
200
|
+
) -> dict[str, Sequence[Collection[str]]]:
|
|
201
201
|
r"""Executes the automatic vector retriever process using vector
|
|
202
202
|
storage.
|
|
203
203
|
|
|
204
204
|
Args:
|
|
205
205
|
query (str): Query string for information retriever.
|
|
206
|
-
|
|
207
|
-
|
|
206
|
+
contents (Union[str, List[str]]): Local file paths, remote URLs or
|
|
207
|
+
string contents.
|
|
208
208
|
top_k (int, optional): The number of top results to return during
|
|
209
209
|
retrieve. Must be a positive integer. Defaults to
|
|
210
210
|
`DEFAULT_TOP_K_RESULTS`.
|
|
@@ -216,31 +216,26 @@ class AutoRetriever:
|
|
|
216
216
|
metadata. Defaults to `False`.
|
|
217
217
|
|
|
218
218
|
Returns:
|
|
219
|
-
|
|
220
|
-
`return_detailed_info` is
|
|
221
|
-
|
|
219
|
+
dict[str, Sequence[Collection[str]]]: By default, returns
|
|
220
|
+
only the text information. If `return_detailed_info` is
|
|
221
|
+
`True`, return detailed information including similarity
|
|
222
|
+
score, content path and metadata.
|
|
222
223
|
|
|
223
224
|
Raises:
|
|
224
225
|
ValueError: If there's an vector storage existing with content
|
|
225
226
|
name in the vector path but the payload is None. If
|
|
226
|
-
`
|
|
227
|
+
`contents` is empty.
|
|
227
228
|
RuntimeError: If any errors occur during the retrieve process.
|
|
228
229
|
"""
|
|
229
|
-
if not
|
|
230
|
-
raise ValueError("
|
|
230
|
+
if not contents:
|
|
231
|
+
raise ValueError("content cannot be empty.")
|
|
231
232
|
|
|
232
|
-
|
|
233
|
-
[content_input_paths]
|
|
234
|
-
if isinstance(content_input_paths, str)
|
|
235
|
-
else content_input_paths
|
|
236
|
-
)
|
|
233
|
+
contents = [contents] if isinstance(contents, str) else contents
|
|
237
234
|
|
|
238
235
|
all_retrieved_info = []
|
|
239
|
-
for
|
|
236
|
+
for content in contents:
|
|
240
237
|
# Generate a valid collection name
|
|
241
|
-
collection_name = self._collection_name_generator(
|
|
242
|
-
content_input_path
|
|
243
|
-
)
|
|
238
|
+
collection_name = self._collection_name_generator(content)
|
|
244
239
|
try:
|
|
245
240
|
vector_storage_instance = self._initialize_vector_storage(
|
|
246
241
|
collection_name
|
|
@@ -251,13 +246,11 @@ class AutoRetriever:
|
|
|
251
246
|
file_is_modified = False # initialize with a default value
|
|
252
247
|
if (
|
|
253
248
|
vector_storage_instance.status().vector_count != 0
|
|
254
|
-
and
|
|
249
|
+
and os.path.exists(content)
|
|
255
250
|
):
|
|
256
251
|
# Get original modified date from file
|
|
257
252
|
modified_date_from_file = (
|
|
258
|
-
self._get_file_modified_date_from_file(
|
|
259
|
-
content_input_path
|
|
260
|
-
)
|
|
253
|
+
self._get_file_modified_date_from_file(content)
|
|
261
254
|
)
|
|
262
255
|
# Get modified date from vector storage
|
|
263
256
|
modified_date_from_storage = (
|
|
@@ -280,18 +273,16 @@ class AutoRetriever:
|
|
|
280
273
|
# Process and store the content to the vector storage
|
|
281
274
|
vr = VectorRetriever(
|
|
282
275
|
storage=vector_storage_instance,
|
|
283
|
-
similarity_threshold=similarity_threshold,
|
|
284
276
|
embedding_model=self.embedding_model,
|
|
285
277
|
)
|
|
286
|
-
vr.process(
|
|
278
|
+
vr.process(content)
|
|
287
279
|
else:
|
|
288
280
|
vr = VectorRetriever(
|
|
289
281
|
storage=vector_storage_instance,
|
|
290
|
-
similarity_threshold=similarity_threshold,
|
|
291
282
|
embedding_model=self.embedding_model,
|
|
292
283
|
)
|
|
293
284
|
# Retrieve info by given query from the vector storage
|
|
294
|
-
retrieved_info = vr.query(query, top_k)
|
|
285
|
+
retrieved_info = vr.query(query, top_k, similarity_threshold)
|
|
295
286
|
all_retrieved_info.extend(retrieved_info)
|
|
296
287
|
except Exception as e:
|
|
297
288
|
raise RuntimeError(
|
|
@@ -318,20 +309,17 @@ class AutoRetriever:
|
|
|
318
309
|
# Select the 'top_k' results
|
|
319
310
|
all_retrieved_info = all_retrieved_info_sorted[:top_k]
|
|
320
311
|
|
|
321
|
-
|
|
322
|
-
retrieved_infos_text = "\n".join(
|
|
323
|
-
info['text'] for info in all_retrieved_info if 'text' in info
|
|
324
|
-
)
|
|
312
|
+
text_retrieved_info = [item['text'] for item in all_retrieved_info]
|
|
325
313
|
|
|
326
|
-
detailed_info =
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
314
|
+
detailed_info = {
|
|
315
|
+
"Original Query": query,
|
|
316
|
+
"Retrieved Context": all_retrieved_info,
|
|
317
|
+
}
|
|
330
318
|
|
|
331
|
-
text_info =
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
319
|
+
text_info = {
|
|
320
|
+
"Original Query": query,
|
|
321
|
+
"Retrieved Context": text_retrieved_info,
|
|
322
|
+
}
|
|
335
323
|
|
|
336
324
|
if return_detailed_info:
|
|
337
325
|
return detailed_info
|
|
@@ -74,13 +74,16 @@ class BM25Retriever(BaseRetriever):
|
|
|
74
74
|
elements = self.unstructured_modules.parse_file_or_url(
|
|
75
75
|
content_input_path, **kwargs
|
|
76
76
|
)
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
77
|
+
if elements:
|
|
78
|
+
self.chunks = self.unstructured_modules.chunk_elements(
|
|
79
|
+
chunk_type=chunk_type, elements=elements
|
|
80
|
+
)
|
|
80
81
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
82
|
+
# Convert chunks to a list of strings for tokenization
|
|
83
|
+
tokenized_corpus = [str(chunk).split(" ") for chunk in self.chunks]
|
|
84
|
+
self.bm25 = BM25Okapi(tokenized_corpus)
|
|
85
|
+
else:
|
|
86
|
+
self.bm25 = None
|
|
84
87
|
|
|
85
88
|
def query(
|
|
86
89
|
self,
|
|
@@ -11,7 +11,10 @@
|
|
|
11
11
|
# See the License for the specific language governing permissions and
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
|
+
import os
|
|
15
|
+
import warnings
|
|
14
16
|
from typing import Any, Dict, List, Optional
|
|
17
|
+
from urllib.parse import urlparse
|
|
15
18
|
|
|
16
19
|
from camel.embeddings import BaseEmbedding, OpenAIEmbedding
|
|
17
20
|
from camel.loaders import UnstructuredIO
|
|
@@ -38,24 +41,18 @@ class VectorRetriever(BaseRetriever):
|
|
|
38
41
|
embedding_model (BaseEmbedding): Embedding model used to generate
|
|
39
42
|
vector embeddings.
|
|
40
43
|
storage (BaseVectorStorage): Vector storage to query.
|
|
41
|
-
similarity_threshold (float, optional): The similarity threshold
|
|
42
|
-
for filtering results. Defaults to `DEFAULT_SIMILARITY_THRESHOLD`.
|
|
43
44
|
unstructured_modules (UnstructuredIO): A module for parsing files and
|
|
44
45
|
URLs and chunking content based on specified parameters.
|
|
45
46
|
"""
|
|
46
47
|
|
|
47
48
|
def __init__(
|
|
48
49
|
self,
|
|
49
|
-
similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
|
|
50
50
|
embedding_model: Optional[BaseEmbedding] = None,
|
|
51
51
|
storage: Optional[BaseVectorStorage] = None,
|
|
52
52
|
) -> None:
|
|
53
53
|
r"""Initializes the retriever class with an optional embedding model.
|
|
54
54
|
|
|
55
55
|
Args:
|
|
56
|
-
similarity_threshold (float, optional): The similarity threshold
|
|
57
|
-
for filtering results. Defaults to
|
|
58
|
-
`DEFAULT_SIMILARITY_THRESHOLD`.
|
|
59
56
|
embedding_model (Optional[BaseEmbedding]): The embedding model
|
|
60
57
|
instance. Defaults to `OpenAIEmbedding` if not provided.
|
|
61
58
|
storage (BaseVectorStorage): Vector storage to query.
|
|
@@ -68,12 +65,11 @@ class VectorRetriever(BaseRetriever):
|
|
|
68
65
|
vector_dim=self.embedding_model.get_output_dim()
|
|
69
66
|
)
|
|
70
67
|
)
|
|
71
|
-
self.
|
|
72
|
-
self.unstructured_modules: UnstructuredIO = UnstructuredIO()
|
|
68
|
+
self.uio: UnstructuredIO = UnstructuredIO()
|
|
73
69
|
|
|
74
70
|
def process(
|
|
75
71
|
self,
|
|
76
|
-
|
|
72
|
+
content: str,
|
|
77
73
|
chunk_type: str = "chunk_by_title",
|
|
78
74
|
**kwargs: Any,
|
|
79
75
|
) -> None:
|
|
@@ -82,18 +78,27 @@ class VectorRetriever(BaseRetriever):
|
|
|
82
78
|
vector storage.
|
|
83
79
|
|
|
84
80
|
Args:
|
|
85
|
-
|
|
86
|
-
processed.
|
|
81
|
+
contents (str): Local file path, remote URL or string content.
|
|
87
82
|
chunk_type (str): Type of chunking going to apply. Defaults to
|
|
88
83
|
"chunk_by_title".
|
|
89
84
|
**kwargs (Any): Additional keyword arguments for content parsing.
|
|
90
85
|
"""
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
)
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
86
|
+
# Check if the content is URL
|
|
87
|
+
parsed_url = urlparse(content)
|
|
88
|
+
is_url = all([parsed_url.scheme, parsed_url.netloc])
|
|
89
|
+
if is_url or os.path.exists(content):
|
|
90
|
+
elements = self.uio.parse_file_or_url(content, **kwargs)
|
|
91
|
+
else:
|
|
92
|
+
elements = [self.uio.create_element_from_text(text=content)]
|
|
93
|
+
if elements:
|
|
94
|
+
chunks = self.uio.chunk_elements(
|
|
95
|
+
chunk_type=chunk_type, elements=elements
|
|
96
|
+
)
|
|
97
|
+
if not elements:
|
|
98
|
+
warnings.warn(
|
|
99
|
+
f"No elements were extracted from the content: {content}"
|
|
100
|
+
)
|
|
101
|
+
return
|
|
97
102
|
# Iterate to process and store embeddings, set batch of 50
|
|
98
103
|
for i in range(0, len(chunks), 50):
|
|
99
104
|
batch_chunks = chunks[i : i + 50]
|
|
@@ -105,7 +110,7 @@ class VectorRetriever(BaseRetriever):
|
|
|
105
110
|
# Prepare the payload for each vector record, includes the content
|
|
106
111
|
# path, chunk metadata, and chunk text
|
|
107
112
|
for vector, chunk in zip(batch_vectors, batch_chunks):
|
|
108
|
-
content_path_info = {"content path":
|
|
113
|
+
content_path_info = {"content path": content}
|
|
109
114
|
chunk_metadata = {"metadata": chunk.metadata.to_dict()}
|
|
110
115
|
chunk_text = {"text": str(chunk)}
|
|
111
116
|
combined_dict = {
|
|
@@ -124,12 +129,16 @@ class VectorRetriever(BaseRetriever):
|
|
|
124
129
|
self,
|
|
125
130
|
query: str,
|
|
126
131
|
top_k: int = DEFAULT_TOP_K_RESULTS,
|
|
132
|
+
similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
|
|
127
133
|
) -> List[Dict[str, Any]]:
|
|
128
134
|
r"""Executes a query in vector storage and compiles the retrieved
|
|
129
135
|
results into a dictionary.
|
|
130
136
|
|
|
131
137
|
Args:
|
|
132
138
|
query (str): Query string for information retriever.
|
|
139
|
+
similarity_threshold (float, optional): The similarity threshold
|
|
140
|
+
for filtering results. Defaults to
|
|
141
|
+
`DEFAULT_SIMILARITY_THRESHOLD`.
|
|
133
142
|
top_k (int, optional): The number of top results to return during
|
|
134
143
|
retriever. Must be a positive integer. Defaults to 1.
|
|
135
144
|
|
|
@@ -161,7 +170,7 @@ class VectorRetriever(BaseRetriever):
|
|
|
161
170
|
formatted_results = []
|
|
162
171
|
for result in query_results:
|
|
163
172
|
if (
|
|
164
|
-
result.similarity >=
|
|
173
|
+
result.similarity >= similarity_threshold
|
|
165
174
|
and result.record.payload is not None
|
|
166
175
|
):
|
|
167
176
|
result_dict = {
|
|
@@ -182,7 +191,7 @@ class VectorRetriever(BaseRetriever):
|
|
|
182
191
|
'text': (
|
|
183
192
|
f"No suitable information retrieved "
|
|
184
193
|
f"from {content_path} with similarity_threshold"
|
|
185
|
-
f" = {
|
|
194
|
+
f" = {similarity_threshold}."
|
|
186
195
|
)
|
|
187
196
|
}
|
|
188
197
|
]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the “License”);
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an “AS IS” BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
|
+
from .amazon_s3 import AmazonS3Storage
|
|
15
|
+
from .azure_blob import AzureBlobStorage
|
|
16
|
+
from .google_cloud import GoogleCloudStorage
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
"AmazonS3Storage",
|
|
20
|
+
"AzureBlobStorage",
|
|
21
|
+
"GoogleCloudStorage",
|
|
22
|
+
]
|