camel-ai 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of camel-ai might be problematic.
- camel/__init__.py +1 -1
- camel/agents/__init__.py +2 -0
- camel/agents/chat_agent.py +40 -53
- camel/agents/knowledge_graph_agent.py +221 -0
- camel/configs/__init__.py +29 -0
- camel/configs/anthropic_config.py +73 -0
- camel/configs/base_config.py +22 -0
- camel/configs/openai_config.py +132 -0
- camel/embeddings/openai_embedding.py +7 -2
- camel/functions/__init__.py +13 -8
- camel/functions/open_api_function.py +380 -0
- camel/functions/open_api_specs/coursera/__init__.py +13 -0
- camel/functions/open_api_specs/coursera/openapi.yaml +82 -0
- camel/functions/open_api_specs/klarna/__init__.py +13 -0
- camel/functions/open_api_specs/klarna/openapi.yaml +87 -0
- camel/functions/open_api_specs/speak/__init__.py +13 -0
- camel/functions/open_api_specs/speak/openapi.yaml +151 -0
- camel/functions/openai_function.py +3 -1
- camel/functions/retrieval_functions.py +61 -0
- camel/functions/slack_functions.py +275 -0
- camel/models/__init__.py +2 -0
- camel/models/anthropic_model.py +16 -2
- camel/models/base_model.py +8 -2
- camel/models/model_factory.py +7 -3
- camel/models/openai_audio_models.py +251 -0
- camel/models/openai_model.py +12 -4
- camel/models/stub_model.py +5 -1
- camel/retrievers/__init__.py +2 -0
- camel/retrievers/auto_retriever.py +47 -36
- camel/retrievers/base.py +42 -37
- camel/retrievers/bm25_retriever.py +10 -19
- camel/retrievers/cohere_rerank_retriever.py +108 -0
- camel/retrievers/vector_retriever.py +43 -26
- camel/storages/vectordb_storages/qdrant.py +3 -1
- camel/toolkits/__init__.py +21 -0
- camel/toolkits/base.py +22 -0
- camel/toolkits/github_toolkit.py +245 -0
- camel/types/__init__.py +6 -0
- camel/types/enums.py +44 -3
- camel/utils/__init__.py +4 -2
- camel/utils/commons.py +97 -173
- {camel_ai-0.1.3.dist-info → camel_ai-0.1.5.dist-info}/METADATA +9 -3
- {camel_ai-0.1.3.dist-info → camel_ai-0.1.5.dist-info}/RECORD +44 -26
- camel/configs.py +0 -271
- {camel_ai-0.1.3.dist-info → camel_ai-0.1.5.dist-info}/WHEEL +0 -0
camel/models/base_model.py
CHANGED
```diff
@@ -12,7 +12,7 @@
 # limitations under the License.
 # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 from openai import Stream
 
@@ -27,17 +27,23 @@ class BaseModelBackend(ABC):
     """
 
     def __init__(
-        self,
+        self,
+        model_type: ModelType,
+        model_config_dict: Dict[str, Any],
+        api_key: Optional[str] = None,
     ) -> None:
         r"""Constructor for the model backend.
 
         Args:
             model_type (ModelType): Model for which a backend is created.
             model_config_dict (Dict[str, Any]): A config dictionary.
+            api_key (Optional[str]): The API key for authenticating with the
+                LLM service.
         """
         self.model_type = model_type
 
         self.model_config_dict = model_config_dict
+        self._api_key = api_key
         self.check_model_config()
 
     @property
```
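Every backend constructor now takes the model type, the config dict, and an optional `api_key`, which the base class stores as `self._api_key` before running `check_model_config()`. A minimal sketch of the new subclassing contract; `MyBackend` is hypothetical, and a real subclass must also implement the remaining abstract methods of `BaseModelBackend` (such as `run`):

```python
from typing import Any, Dict, Optional

from camel.models import BaseModelBackend
from camel.types import ModelType


class MyBackend(BaseModelBackend):  # hypothetical illustration
    def __init__(
        self,
        model_type: ModelType,
        model_config_dict: Dict[str, Any],
        api_key: Optional[str] = None,
    ) -> None:
        # The base constructor stores the key as self._api_key and then
        # calls self.check_model_config().
        super().__init__(model_type, model_config_dict, api_key)

    def check_model_config(self) -> None:
        # No extra validation in this sketch.
        pass
```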
camel/models/model_factory.py
CHANGED
```diff
@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 from camel.models.anthropic_model import AnthropicModel
 from camel.models.base_model import BaseModelBackend
@@ -30,7 +30,9 @@ class ModelFactory:
 
     @staticmethod
     def create(
-        model_type: ModelType,
+        model_type: ModelType,
+        model_config_dict: Dict,
+        api_key: Optional[str] = None,
     ) -> BaseModelBackend:
         r"""Creates an instance of `BaseModelBackend` of the specified type.
 
@@ -38,6 +40,8 @@ class ModelFactory:
             model_type (ModelType): Model for which a backend is created.
             model_config_dict (Dict): A dictionary that will be fed into
                 the backend constructor.
+            api_key (Optional[str]): The API key for authenticating with the
+                LLM service.
 
         Raises:
             ValueError: If there is no backend for the model.
@@ -57,5 +61,5 @@ class ModelFactory:
         else:
             raise ValueError(f"Unknown model type `{model_type}` is input")
 
-        inst = model_class(model_type, model_config_dict)
+        inst = model_class(model_type, model_config_dict, api_key)
         return inst
```
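With the factory now forwarding `api_key` into the backend constructor, callers can supply a key per model instead of relying only on environment variables. A hedged usage sketch; `ChatGPTConfig` and `ModelType.GPT_3_5_TURBO` are assumed from other parts of camel-ai, and the key value is a placeholder:

```python
from camel.configs import ChatGPTConfig  # assumed export in 0.1.5
from camel.models import ModelFactory
from camel.types import ModelType

# The key is passed through ModelFactory.create to the backend constructor;
# when omitted, backends such as OpenAIModel fall back to OPENAI_API_KEY.
model = ModelFactory.create(
    model_type=ModelType.GPT_3_5_TURBO,
    model_config_dict=ChatGPTConfig().__dict__,
    api_key="sk-placeholder",
)
```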
camel/models/openai_audio_models.py
ADDED

```diff
@@ -0,0 +1,251 @@
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+import os
+from typing import Any, List, Optional, Union
+
+from openai import OpenAI, _legacy_response
+
+from camel.types import AudioModelType, VoiceType
+
+
+class OpenAIAudioModels:
+    r"""Provides access to OpenAI's Text-to-Speech (TTS) and Speech-to-Text
+    (STT) models."""
+
+    def __init__(
+        self,
+    ) -> None:
+        r"""Initialize an instance of OpenAI."""
+        url = os.environ.get('OPENAI_API_BASE_URL')
+        self._client = OpenAI(timeout=120, max_retries=3, base_url=url)
+
+    def text_to_speech(
+        self,
+        input: str,
+        model_type: AudioModelType = AudioModelType.TTS_1,
+        voice: VoiceType = VoiceType.ALLOY,
+        storage_path: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Union[
+        List[_legacy_response.HttpxBinaryResponseContent],
+        _legacy_response.HttpxBinaryResponseContent,
+    ]:
+        r"""Convert text to speech using OpenAI's TTS model. This method
+        converts the given input text to speech using the specified model and
+        voice.
+
+        Args:
+            input (str): The text to be converted to speech.
+            model_type (AudioModelType, optional): The TTS model to use.
+                Defaults to `AudioModelType.TTS_1`.
+            voice (VoiceType, optional): The voice to be used for generating
+                speech. Defaults to `VoiceType.ALLOY`.
+            storage_path (str, optional): The local path to store the
+                generated speech file if provided, defaults to `None`.
+            **kwargs (Any): Extra kwargs passed to the TTS API.
+
+        Returns:
+            Union[List[_legacy_response.HttpxBinaryResponseContent],
+                _legacy_response.HttpxBinaryResponseContent]: A list of
+                response content objects if the input is longer than 4096
+                characters, a single response content object otherwise.
+
+        Raises:
+            Exception: If there's an error during the TTS API call.
+        """
+        try:
+            # The model only supports at most 4096 characters at a time.
+            max_chunk_size = 4095
+            audio_chunks = []
+            chunk_index = 0
+            if len(input) > max_chunk_size:
+                while input:
+                    if len(input) <= max_chunk_size:
+                        chunk = input
+                        input = ''
+                    else:
+                        # Find the nearest period before the chunk size limit
+                        while input[max_chunk_size - 1] != '.':
+                            max_chunk_size -= 1
+
+                        chunk = input[:max_chunk_size]
+                        input = input[max_chunk_size:].lstrip()
+
+                    response = self._client.audio.speech.create(
+                        model=model_type.value,
+                        voice=voice.value,
+                        input=chunk,
+                        **kwargs,
+                    )
+                    if storage_path:
+                        try:
+                            # Create a new storage path for each chunk
+                            file_name, file_extension = os.path.splitext(
+                                storage_path
+                            )
+                            new_storage_path = (
+                                f"{file_name}_{chunk_index}{file_extension}"
+                            )
+                            response.write_to_file(new_storage_path)
+                            chunk_index += 1
+                        except Exception as e:
+                            raise Exception(
+                                "Error during writing the file"
+                            ) from e
+
+                    audio_chunks.append(response)
+                return audio_chunks
+
+            else:
+                response = self._client.audio.speech.create(
+                    model=model_type.value,
+                    voice=voice.value,
+                    input=input,
+                    **kwargs,
+                )
+
+                if storage_path:
+                    try:
+                        response.write_to_file(storage_path)
+                    except Exception as e:
+                        raise Exception("Error during writing the file") from e
+
+                return response
+
+        except Exception as e:
+            raise Exception("Error during TTS API call") from e
+
+    def _split_audio(
+        self, audio_file_path: str, chunk_size_mb: int = 24
+    ) -> list:
+        r"""Split the audio file into smaller chunks, since the Whisper API
+        only supports files that are less than 25 MB.
+
+        Args:
+            audio_file_path (str): Path to the input audio file.
+            chunk_size_mb (int, optional): Size of each chunk in megabytes.
+                Defaults to `24`.
+
+        Returns:
+            list: List of paths to the split audio files.
+        """
+        from pydub import AudioSegment
+
+        audio = AudioSegment.from_file(audio_file_path)
+        audio_format = os.path.splitext(audio_file_path)[1][1:].lower()
+
+        # Calculate chunk size in bytes
+        chunk_size_bytes = chunk_size_mb * 1024 * 1024
+
+        # Number of chunks needed
+        num_chunks = os.path.getsize(audio_file_path) // chunk_size_bytes + 1
+
+        # Create a directory to store the chunks
+        output_dir = os.path.splitext(audio_file_path)[0] + "_chunks"
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Get audio chunk length in milliseconds
+        chunk_size_milliseconds = len(audio) // (num_chunks)
+
+        # Split the audio into chunks
+        split_files = []
+        for i in range(num_chunks):
+            start = i * chunk_size_milliseconds
+            end = (i + 1) * chunk_size_milliseconds
+            if i + 1 == num_chunks:
+                chunk = audio[start:]
+            else:
+                chunk = audio[start:end]
+            # Create new chunk path
+            chunk_path = os.path.join(output_dir, f"chunk_{i}.{audio_format}")
+            chunk.export(chunk_path, format=audio_format)
+            split_files.append(chunk_path)
+        return split_files
+
+    def speech_to_text(
+        self,
+        audio_file_path: str,
+        translate_into_english: bool = False,
+        **kwargs: Any,
+    ) -> str:
+        r"""Convert speech audio to text.
+
+        Args:
+            audio_file_path (str): The audio file path, supporting one of
+                these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or
+                webm.
+            translate_into_english (bool, optional): Whether to translate the
+                speech into English. Defaults to `False`.
+            **kwargs (Any): Extra keyword arguments passed to the
+                Speech-to-Text (STT) API.
+
+        Returns:
+            str: The output text.
+
+        Raises:
+            ValueError: If the audio file format is not supported.
+            Exception: If there's an error during the STT API call.
+        """
+        supported_formats = [
+            "flac",
+            "mp3",
+            "mp4",
+            "mpeg",
+            "mpga",
+            "m4a",
+            "ogg",
+            "wav",
+            "webm",
+        ]
+        file_format = audio_file_path.split(".")[-1].lower()
+
+        if file_format not in supported_formats:
+            raise ValueError(f"Unsupported audio file format: {file_format}")
+        try:
+            if os.path.getsize(audio_file_path) > 24 * 1024 * 1024:
+                # Split audio into chunks
+                audio_chunks = self._split_audio(audio_file_path)
+                texts = []
+                for chunk_path in audio_chunks:
+                    audio_data = open(chunk_path, "rb")
+                    if translate_into_english:
+                        translation = self._client.audio.translations.create(
+                            model="whisper-1", file=audio_data, **kwargs
+                        )
+                        texts.append(translation.text)
+                    else:
+                        transcription = (
+                            self._client.audio.transcriptions.create(
+                                model="whisper-1", file=audio_data, **kwargs
+                            )
+                        )
+                        texts.append(transcription.text)
+                    os.remove(chunk_path)  # Delete temporary chunk file
+                return " ".join(texts)
+            else:
+                # Process the entire audio file
+                audio_data = open(audio_file_path, "rb")
+
+                if translate_into_english:
+                    translation = self._client.audio.translations.create(
+                        model="whisper-1", file=audio_data, **kwargs
+                    )
+                    return translation.text
+                else:
+                    transcription = self._client.audio.transcriptions.create(
+                        model="whisper-1", file=audio_data, **kwargs
+                    )
+                    return transcription.text
+        except Exception as e:
+            raise Exception("Error during STT API call") from e
```
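The new class wraps both directions of the audio API behind one object. A usage sketch, assuming `OpenAIAudioModels` is exported from `camel.models` (the `camel/models/__init__.py` change in this release suggests it) and that `OPENAI_API_KEY` is set; file names and the input string are placeholders:

```python
from camel.models import OpenAIAudioModels  # assumed export

audio_models = OpenAIAudioModels()

# TTS: inputs longer than 4096 characters are split at sentence ends and
# written as speech_0.mp3, speech_1.mp3, ...; shorter inputs yield one file.
audio_models.text_to_speech(
    input="CAMEL is a multi-agent framework.",
    storage_path="speech.mp3",
)

# STT: files above ~24 MB are split with pydub before being sent to Whisper.
text = audio_models.speech_to_text("speech.mp3")
print(text)
```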
camel/models/openai_model.py
CHANGED
```diff
@@ -16,7 +16,7 @@ from typing import Any, Dict, List, Optional, Union
 
 from openai import OpenAI, Stream
 
-from camel.configs import
+from camel.configs import OPENAI_API_PARAMS
 from camel.messages import OpenAIMessage
 from camel.models import BaseModelBackend
 from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
@@ -27,7 +27,10 @@ class OpenAIModel(BaseModelBackend):
     r"""OpenAI API in a unified BaseModelBackend interface."""
 
     def __init__(
-        self,
+        self,
+        model_type: ModelType,
+        model_config_dict: Dict[str, Any],
+        api_key: Optional[str] = None,
     ) -> None:
         r"""Constructor for OpenAI backend.
 
@@ -36,10 +39,15 @@ class OpenAIModel(BaseModelBackend):
                 one of GPT_* series.
             model_config_dict (Dict[str, Any]): A dictionary that will
                 be fed into openai.ChatCompletion.create().
+            api_key (Optional[str]): The API key for authenticating with the
+                OpenAI service. (default: :obj:`None`)
         """
         super().__init__(model_type, model_config_dict)
         url = os.environ.get('OPENAI_API_BASE_URL', None)
-        self.
+        self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
+        self._client = OpenAI(
+            timeout=60, max_retries=3, base_url=url, api_key=self._api_key
+        )
         self._token_counter: Optional[BaseTokenCounter] = None
 
     @property
@@ -86,7 +94,7 @@ class OpenAIModel(BaseModelBackend):
                 unexpected arguments to OpenAI API.
         """
         for param in self.model_config_dict:
-            if param not in
+            if param not in OPENAI_API_PARAMS:
                 raise ValueError(
                     f"Unexpected argument `{param}` is "
                     "input into OpenAI model backend."
```
camel/models/stub_model.py
CHANGED
```diff
@@ -50,11 +50,15 @@ class StubModel(BaseModelBackend):
     model_type = ModelType.STUB
 
     def __init__(
-        self,
+        self,
+        model_type: ModelType,
+        model_config_dict: Dict[str, Any],
+        api_key: Optional[str] = None,
    ) -> None:
        r"""All arguments are unused for the dummy model."""
        super().__init__(model_type, model_config_dict)
        self._token_counter: Optional[BaseTokenCounter] = None
+        self._api_key = api_key
 
    @property
    def token_counter(self) -> BaseTokenCounter:
```
camel/retrievers/__init__.py
CHANGED
```diff
@@ -14,6 +14,7 @@
 from .auto_retriever import AutoRetriever
 from .base import BaseRetriever
 from .bm25_retriever import BM25Retriever
+from .cohere_rerank_retriever import CohereRerankRetriever
 from .vector_retriever import VectorRetriever
 
 __all__ = [
@@ -21,4 +22,5 @@ __all__ = [
     'VectorRetriever',
     'AutoRetriever',
     'BM25Retriever',
+    'CohereRerankRetriever',
 ]
```
camel/retrievers/auto_retriever.py
CHANGED

```diff
@@ -63,7 +63,8 @@ class AutoRetriever:
         self,
         collection_name: Optional[str] = None,
     ) -> BaseVectorStorage:
-        r"""Sets up and returns a vector storage instance with specified
+        r"""Sets up and returns a vector storage instance with specified
+        parameters.
 
         Args:
             collection_name (Optional[str]): Name of the collection in the
@@ -195,7 +196,8 @@ class AutoRetriever:
         similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
         return_detailed_info: bool = False,
     ) -> str:
-        r"""Executes the automatic vector retriever process using vector
+        r"""Executes the automatic vector retriever process using vector
+        storage.
 
         Args:
             query (str): Query string for information retriever.
@@ -233,9 +235,7 @@ class AutoRetriever:
 
         vr = VectorRetriever()
 
-
-        retrieved_infos_text = ""
-
+        all_retrieved_info = []
         for content_input_path in content_input_paths:
             # Generate a valid collection name
             collection_name = self._collection_name_generator(
@@ -278,42 +278,53 @@ class AutoRetriever:
                     # Clear the vector storage
                     vector_storage_instance.clear()
                     # Process and store the content to the vector storage
-                    vr
+                    vr = VectorRetriever(
+                        storage=vector_storage_instance,
+                        similarity_threshold=similarity_threshold,
+                    )
+                    vr.process(content_input_path)
+                else:
+                    vr = VectorRetriever(
+                        storage=vector_storage_instance,
+                        similarity_threshold=similarity_threshold,
+                    )
                 # Retrieve info by given query from the vector storage
-                retrieved_info = vr.query(
-
-                )
-                # Reorganize the retrieved info with original query
-                for info in retrieved_info:
-                    retrieved_infos += "\n" + str(info)
-                    retrieved_infos_text += "\n" + str(info['text'])
-            output = (
-                "Original Query:"
-                + "\n"
-                + "{"
-                + query
-                + "}"
-                + "\n"
-                + "Retrieved Context:"
-                + retrieved_infos
-            )
-            output_text = (
-                "Original Query:"
-                + "\n"
-                + "{"
-                + query
-                + "}"
-                + "\n"
-                + "Retrieved Context:"
-                + retrieved_infos_text
-            )
-
+                retrieved_info = vr.query(query, top_k)
+                all_retrieved_info.extend(retrieved_info)
             except Exception as e:
                 raise RuntimeError(
                     f"Error in auto vector retriever processing: {e!s}"
                 ) from e
 
+        # Split records into those with and without a 'similarity score'.
+        # Records with 'similarity score' lower than 'similarity_threshold'
+        # will not have a 'similarity score' in the output content.
+        with_score = [
+            info for info in all_retrieved_info if 'similarity score' in info
+        ]
+        without_score = [
+            info
+            for info in all_retrieved_info
+            if 'similarity score' not in info
+        ]
+        # Sort only the list with scores
+        with_score_sorted = sorted(
+            with_score, key=lambda x: x['similarity score'], reverse=True
+        )
+        # Merge back the sorted scored items with the non-scored items
+        all_retrieved_info_sorted = with_score_sorted + without_score
+        # Select the 'top_k' results
+        all_retrieved_info = all_retrieved_info_sorted[:top_k]
+
+        retrieved_infos = "\n".join(str(info) for info in all_retrieved_info)
+        retrieved_infos_text = "\n".join(
+            info['text'] for info in all_retrieved_info if 'text' in info
+        )
+
+        detailed_info = f"Original Query:\n{{ {query} }}\nRetrieved Context:\n{retrieved_infos}"
+        text_info = f"Original Query:\n{{ {query} }}\nRetrieved Context:\n{retrieved_infos_text}"
+
         if return_detailed_info:
-            return
+            return detailed_info
         else:
-            return
+            return text_info
```
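The rewrite collects results across all input paths, sorts any records that carry a `'similarity score'`, and only then truncates to `top_k`, so ranking now happens globally rather than per source. A hedged end-to-end sketch; the constructor arguments are assumed from neighbouring camel-ai releases and may differ in 0.1.5:

```python
from camel.retrievers import AutoRetriever
from camel.types import StorageType  # assumed enum

retriever = AutoRetriever(
    vector_storage_local_path="local_data/",  # assumed parameter name
    storage_type=StorageType.QDRANT,
)

# Results from both paths are merged, sorted by similarity score, and cut
# to top_k before formatting.
print(
    retriever.run_vector_retriever(
        query="What is CAMEL-AI?",
        content_input_paths=["https://www.camel-ai.org/", "local_doc.pdf"],
        top_k=3,
        return_detailed_info=True,
    )
)
```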
camel/retrievers/base.py
CHANGED
```diff
@@ -12,53 +12,58 @@
 # limitations under the License.
 # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
 from abc import ABC, abstractmethod
-from typing import Any,
+from typing import Any, Callable
 
 DEFAULT_TOP_K_RESULTS = 1
 
 
-
-    r"""
-
+def _query_unimplemented(self, *input: Any) -> None:
+    r"""Defines the query behavior performed at every call.
+
+    Query the results. Subclasses should implement this
+    method according to their specific needs.
+
+    It should be overridden by all subclasses.
+
+    .. note::
+        Although the recipe for forward pass needs to be defined within
+        this function, one should call the :class:`BaseRetriever` instance
+        afterwards instead of this since the former takes care of running the
+        registered hooks while the latter silently ignores them.
     """
+    raise NotImplementedError(
+        f"Retriever [{type(self).__name__}] is missing the required \"query\" function"
+    )
 
-    @abstractmethod
-    def __init__(self) -> None:
-        pass
 
-
-
-
-
-        chunk_type: str = "chunk_by_title",
-        **kwargs: Any,
-    ) -> None:
-        r"""Processes content from a file or URL, divides it into chunks by
+def _process_unimplemented(self, *input: Any) -> None:
+    r"""Defines the process behavior performed at every call.
+
+    Processes content from a file or URL, divides it into chunks by
     using `Unstructured IO`, then stored internally. This method must be
     called before executing queries with the retriever.
 
-        content_input_path (str): File path or URL of the content to be
-            processed.
-        chunk_type (str): Type of chunking going to apply. Defaults to
-            "chunk_by_title".
-        **kwargs (Any): Additional keyword arguments for content parsing.
-        """
-        pass
+    Should be overridden by all subclasses.
 
-
-
-
-
-
-
+    .. note::
+        Although the recipe for forward pass needs to be defined within
+        this function, one should call the :class:`BaseRetriever` instance
+        afterwards instead of this since the former takes care of running the
+        registered hooks while the latter silently ignores them.
+    """
+    raise NotImplementedError(
+        f"Retriever [{type(self).__name__}] is missing the required \"process\" function"
+    )
+
+
+class BaseRetriever(ABC):
+    r"""Abstract base class for implementing various types of information
+    retrievers.
+    """
 
-
-
-        top_k (int, optional): The number of top results to return during
-            retriever. Must be a positive integer. Defaults to
-            `DEFAULT_TOP_K_RESULTS`.
-        **kwargs (Any): Flexible keyword arguments for additional
-            parameters, like `similarity_threshold`.
-        """
+    @abstractmethod
+    def __init__(self) -> None:
         pass
+
+    process: Callable[..., Any] = _process_unimplemented
+    query: Callable[..., Any] = _query_unimplemented
```