camel-ai 0.1.6.8__py3-none-any.whl → 0.1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +5 -2
- camel/configs/__init__.py +10 -3
- camel/configs/samba_config.py +67 -8
- camel/embeddings/openai_embedding.py +13 -6
- camel/loaders/firecrawl_reader.py +24 -0
- camel/models/model_factory.py +2 -2
- camel/models/ollama_model.py +26 -4
- camel/models/samba_model.py +257 -97
- camel/models/vllm_model.py +24 -2
- camel/retrievers/auto_retriever.py +7 -6
- camel/retrievers/vector_retriever.py +11 -7
- camel/toolkits/__init__.py +3 -0
- camel/toolkits/reddit_toolkit.py +229 -0
- camel/toolkits/retrieval_toolkit.py +27 -11
- camel/types/enums.py +1 -17
- camel/utils/constants.py +8 -2
- {camel_ai-0.1.6.8.dist-info → camel_ai-0.1.7.0.dist-info}/METADATA +5 -3
- {camel_ai-0.1.6.8.dist-info → camel_ai-0.1.7.0.dist-info}/RECORD +20 -19
- {camel_ai-0.1.6.8.dist-info → camel_ai-0.1.7.0.dist-info}/WHEEL +0 -0
camel/models/samba_model.py
CHANGED
|
@@ -11,14 +11,23 @@
|
|
|
11
11
|
# See the License for the specific language governing permissions and
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
|
+
import json
|
|
14
15
|
import os
|
|
16
|
+
import time
|
|
17
|
+
import uuid
|
|
15
18
|
from typing import Any, Dict, List, Optional, Union
|
|
16
19
|
|
|
20
|
+
import httpx
|
|
17
21
|
from openai import OpenAI, Stream
|
|
18
22
|
|
|
19
|
-
from camel.configs import
|
|
23
|
+
from camel.configs import SAMBA_FAST_API_PARAMS, SAMBA_VERSE_API_PARAMS
|
|
20
24
|
from camel.messages import OpenAIMessage
|
|
21
|
-
from camel.types import
|
|
25
|
+
from camel.types import (
|
|
26
|
+
ChatCompletion,
|
|
27
|
+
ChatCompletionChunk,
|
|
28
|
+
CompletionUsage,
|
|
29
|
+
ModelType,
|
|
30
|
+
)
|
|
22
31
|
from camel.utils import (
|
|
23
32
|
BaseTokenCounter,
|
|
24
33
|
OpenAITokenCounter,
|
|
@@ -31,7 +40,7 @@ class SambaModel:
|
|
|
31
40
|
|
|
32
41
|
def __init__(
|
|
33
42
|
self,
|
|
34
|
-
model_type:
|
|
43
|
+
model_type: str,
|
|
35
44
|
model_config_dict: Dict[str, Any],
|
|
36
45
|
api_key: Optional[str] = None,
|
|
37
46
|
url: Optional[str] = None,
|
|
@@ -40,21 +49,29 @@ class SambaModel:
|
|
|
40
49
|
r"""Constructor for SambaNova backend.
|
|
41
50
|
|
|
42
51
|
Args:
|
|
43
|
-
model_type (
|
|
44
|
-
created.
|
|
52
|
+
model_type (str): Model for which a SambaNova backend is
|
|
53
|
+
created. Supported models via Fast API: `https://sambanova.ai/
|
|
54
|
+
fast-api?api_ref=128521`. Supported models via SambaVerse API
|
|
55
|
+
is listed in `https://sambaverse.sambanova.ai/models`.
|
|
45
56
|
model_config_dict (Dict[str, Any]): A dictionary that will
|
|
46
57
|
be fed into API request.
|
|
47
58
|
api_key (Optional[str]): The API key for authenticating with the
|
|
48
59
|
SambaNova service. (default: :obj:`None`)
|
|
49
|
-
url (Optional[str]): The url to the SambaNova service.
|
|
50
|
-
:obj:`"https://fast-api.snova.ai/
|
|
60
|
+
url (Optional[str]): The url to the SambaNova service. Current
|
|
61
|
+
support SambaNova Fast API: :obj:`"https://fast-api.snova.ai/
|
|
62
|
+
v1/chat/ completions"` and SambaVerse API: :obj:`"https://
|
|
63
|
+
sambaverse.sambanova.ai/api/predict"`. (default::obj:`"https://
|
|
64
|
+
fast-api.snova.ai/v1/chat/completions"`)
|
|
51
65
|
token_counter (Optional[BaseTokenCounter]): Token counter to use
|
|
52
66
|
for the model. If not provided, `OpenAITokenCounter(ModelType.
|
|
53
67
|
GPT_4O_MINI)` will be used.
|
|
54
68
|
"""
|
|
55
69
|
self.model_type = model_type
|
|
56
70
|
self._api_key = api_key or os.environ.get("SAMBA_API_KEY")
|
|
57
|
-
self._url = url or os.environ.get(
|
|
71
|
+
self._url = url or os.environ.get(
|
|
72
|
+
"SAMBA_API_BASE_URL",
|
|
73
|
+
"https://fast-api.snova.ai/v1/chat/completions",
|
|
74
|
+
)
|
|
58
75
|
self._token_counter = token_counter
|
|
59
76
|
self.model_config_dict = model_config_dict
|
|
60
77
|
self.check_model_config()
|
|
@@ -79,12 +96,26 @@ class SambaModel:
|
|
|
79
96
|
ValueError: If the model configuration dictionary contains any
|
|
80
97
|
unexpected arguments to SambaNova API.
|
|
81
98
|
"""
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
99
|
+
if self._url == "https://fast-api.snova.ai/v1/chat/completions":
|
|
100
|
+
for param in self.model_config_dict:
|
|
101
|
+
if param not in SAMBA_FAST_API_PARAMS:
|
|
102
|
+
raise ValueError(
|
|
103
|
+
f"Unexpected argument `{param}` is "
|
|
104
|
+
"input into SambaNova Fast API."
|
|
105
|
+
)
|
|
106
|
+
elif self._url == "https://sambaverse.sambanova.ai/api/predict":
|
|
107
|
+
for param in self.model_config_dict:
|
|
108
|
+
if param not in SAMBA_VERSE_API_PARAMS:
|
|
109
|
+
raise ValueError(
|
|
110
|
+
f"Unexpected argument `{param}` is "
|
|
111
|
+
"input into SambaVerse API."
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
else:
|
|
115
|
+
raise ValueError(
|
|
116
|
+
f"{self._url} is not supported, please check the url to the"
|
|
117
|
+
" SambaNova service"
|
|
118
|
+
)
|
|
88
119
|
|
|
89
120
|
@api_keys_required("SAMBA_API_KEY")
|
|
90
121
|
def run( # type: ignore[misc]
|
|
@@ -125,38 +156,44 @@ class SambaModel:
|
|
|
125
156
|
RuntimeError: If the HTTP request fails.
|
|
126
157
|
"""
|
|
127
158
|
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
159
|
+
# Handle SambaNova's Fast API
|
|
160
|
+
if self._url == "https://fast-api.snova.ai/v1/chat/completions":
|
|
161
|
+
headers = {
|
|
162
|
+
"Authorization": f"Basic {self._api_key}",
|
|
163
|
+
"Content-Type": "application/json",
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
data = {
|
|
167
|
+
"messages": messages,
|
|
168
|
+
"max_tokens": self.token_limit,
|
|
169
|
+
"stop": self.model_config_dict.get("stop"),
|
|
170
|
+
"model": self.model_type,
|
|
171
|
+
"stream": True,
|
|
172
|
+
"stream_options": self.model_config_dict.get("stream_options"),
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
try:
|
|
176
|
+
with httpx.stream(
|
|
177
|
+
"POST",
|
|
178
|
+
self._url,
|
|
179
|
+
headers=headers,
|
|
180
|
+
json=data,
|
|
181
|
+
) as api_response:
|
|
182
|
+
stream = Stream[ChatCompletionChunk](
|
|
183
|
+
cast_to=ChatCompletionChunk,
|
|
184
|
+
response=api_response,
|
|
185
|
+
client=OpenAI(api_key="required_but_not_used"),
|
|
186
|
+
)
|
|
187
|
+
for chunk in stream:
|
|
188
|
+
yield chunk
|
|
189
|
+
except httpx.HTTPError as e:
|
|
190
|
+
raise RuntimeError(f"HTTP request failed: {e!s}")
|
|
191
|
+
|
|
192
|
+
elif self._url == "https://sambaverse.sambanova.ai/api/predict":
|
|
193
|
+
raise ValueError(
|
|
194
|
+
"https://sambaverse.sambanova.ai/api/predict doesn't support"
|
|
195
|
+
" stream mode"
|
|
196
|
+
)
|
|
160
197
|
|
|
161
198
|
def _run_non_streaming(
|
|
162
199
|
self, messages: List[OpenAIMessage]
|
|
@@ -177,62 +214,138 @@ class SambaModel:
|
|
|
177
214
|
expected data.
|
|
178
215
|
"""
|
|
179
216
|
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
217
|
+
# Handle SambaNova's Fast API
|
|
218
|
+
if self._url == "https://fast-api.snova.ai/v1/chat/completions":
|
|
219
|
+
headers = {
|
|
220
|
+
"Authorization": f"Basic {self._api_key}",
|
|
221
|
+
"Content-Type": "application/json",
|
|
222
|
+
}
|
|
223
|
+
|
|
224
|
+
data = {
|
|
225
|
+
"messages": messages,
|
|
226
|
+
"max_tokens": self.token_limit,
|
|
227
|
+
"stop": self.model_config_dict.get("stop"),
|
|
228
|
+
"model": self.model_type,
|
|
229
|
+
"stream": True,
|
|
230
|
+
"stream_options": self.model_config_dict.get("stream_options"),
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
try:
|
|
234
|
+
with httpx.stream(
|
|
235
|
+
"POST",
|
|
236
|
+
self._url,
|
|
237
|
+
headers=headers,
|
|
238
|
+
json=data,
|
|
239
|
+
) as api_response:
|
|
240
|
+
samba_response = []
|
|
241
|
+
for chunk in api_response.iter_text():
|
|
242
|
+
if chunk.startswith('data: '):
|
|
243
|
+
chunk = chunk[6:]
|
|
244
|
+
if '[DONE]' in chunk:
|
|
245
|
+
break
|
|
246
|
+
json_data = json.loads(chunk)
|
|
247
|
+
samba_response.append(json_data)
|
|
248
|
+
return self._fastapi_to_openai_response(samba_response)
|
|
249
|
+
except httpx.HTTPError as e:
|
|
250
|
+
raise RuntimeError(f"HTTP request failed: {e!s}")
|
|
251
|
+
except json.JSONDecodeError as e:
|
|
252
|
+
raise ValueError(f"Failed to decode JSON response: {e!s}")
|
|
253
|
+
|
|
254
|
+
# Handle SambaNova's Sambaverse API
|
|
255
|
+
else:
|
|
256
|
+
headers = {
|
|
257
|
+
"Content-Type": "application/json",
|
|
258
|
+
"key": str(self._api_key),
|
|
259
|
+
"modelName": self.model_type,
|
|
260
|
+
}
|
|
261
|
+
|
|
262
|
+
data = {
|
|
263
|
+
"instance": json.dumps(
|
|
264
|
+
{
|
|
265
|
+
"conversation_id": str(uuid.uuid4()),
|
|
266
|
+
"messages": messages,
|
|
267
|
+
}
|
|
268
|
+
),
|
|
269
|
+
"params": {
|
|
270
|
+
"do_sample": {"type": "bool", "value": "true"},
|
|
271
|
+
"max_tokens_to_generate": {
|
|
272
|
+
"type": "int",
|
|
273
|
+
"value": str(self.model_config_dict.get("max_tokens")),
|
|
274
|
+
},
|
|
275
|
+
"process_prompt": {"type": "bool", "value": "true"},
|
|
276
|
+
"repetition_penalty": {
|
|
277
|
+
"type": "float",
|
|
278
|
+
"value": str(
|
|
279
|
+
self.model_config_dict.get("repetition_penalty")
|
|
280
|
+
),
|
|
281
|
+
},
|
|
282
|
+
"return_token_count_only": {
|
|
283
|
+
"type": "bool",
|
|
284
|
+
"value": "false",
|
|
285
|
+
},
|
|
286
|
+
"select_expert": {
|
|
287
|
+
"type": "str",
|
|
288
|
+
"value": self.model_type.split('/')[1],
|
|
289
|
+
},
|
|
290
|
+
"stop_sequences": {
|
|
291
|
+
"type": "str",
|
|
292
|
+
"value": self.model_config_dict.get("stop_sequences"),
|
|
293
|
+
},
|
|
294
|
+
"temperature": {
|
|
295
|
+
"type": "float",
|
|
296
|
+
"value": str(
|
|
297
|
+
self.model_config_dict.get("temperature")
|
|
298
|
+
),
|
|
299
|
+
},
|
|
300
|
+
"top_k": {
|
|
301
|
+
"type": "int",
|
|
302
|
+
"value": str(self.model_config_dict.get("top_k")),
|
|
303
|
+
},
|
|
304
|
+
"top_p": {
|
|
305
|
+
"type": "float",
|
|
306
|
+
"value": str(self.model_config_dict.get("top_p")),
|
|
307
|
+
},
|
|
308
|
+
},
|
|
309
|
+
}
|
|
310
|
+
|
|
311
|
+
try:
|
|
312
|
+
# Send the request and handle the response
|
|
313
|
+
with httpx.Client() as client:
|
|
314
|
+
response = client.post(
|
|
315
|
+
self._url, # type: ignore[arg-type]
|
|
316
|
+
headers=headers,
|
|
317
|
+
json=data,
|
|
318
|
+
)
|
|
319
|
+
|
|
320
|
+
raw_text = response.text
|
|
321
|
+
# Split the string into two dictionaries
|
|
322
|
+
dicts = raw_text.split('}\n{')
|
|
323
|
+
|
|
324
|
+
# Keep only the last dictionary
|
|
325
|
+
last_dict = '{' + dicts[-1]
|
|
326
|
+
|
|
327
|
+
# Parse the dictionary
|
|
328
|
+
last_dict = json.loads(last_dict)
|
|
329
|
+
return self._sambaverse_to_openai_response(last_dict) # type: ignore[arg-type]
|
|
330
|
+
|
|
331
|
+
except httpx.HTTPStatusError:
|
|
332
|
+
raise RuntimeError(f"HTTP request failed: {raw_text}")
|
|
333
|
+
|
|
334
|
+
def _fastapi_to_openai_response(
|
|
220
335
|
self, samba_response: List[Dict[str, Any]]
|
|
221
336
|
) -> ChatCompletion:
|
|
222
|
-
r"""Converts SambaNova response chunks into an
|
|
223
|
-
|
|
337
|
+
r"""Converts SambaNova Fast API response chunks into an
|
|
338
|
+
OpenAI-compatible response.
|
|
224
339
|
|
|
225
340
|
Args:
|
|
226
341
|
samba_response (List[Dict[str, Any]]): A list of dictionaries
|
|
227
|
-
representing partial responses from the SambaNova API.
|
|
342
|
+
representing partial responses from the SambaNova Fast API.
|
|
228
343
|
|
|
229
344
|
Returns:
|
|
230
345
|
ChatCompletion: A `ChatCompletion` object constructed from the
|
|
231
346
|
aggregated response data.
|
|
232
|
-
|
|
233
|
-
Raises:
|
|
234
|
-
ValueError: If the response data is invalid or incomplete.
|
|
235
347
|
"""
|
|
348
|
+
|
|
236
349
|
# Step 1: Combine the content from each chunk
|
|
237
350
|
full_content = ""
|
|
238
351
|
for chunk in samba_response:
|
|
@@ -268,17 +381,64 @@ class SambaModel:
|
|
|
268
381
|
|
|
269
382
|
return obj
|
|
270
383
|
|
|
384
|
+
def _sambaverse_to_openai_response(
|
|
385
|
+
self, samba_response: Dict[str, Any]
|
|
386
|
+
) -> ChatCompletion:
|
|
387
|
+
r"""Converts SambaVerse API response into an OpenAI-compatible
|
|
388
|
+
response.
|
|
389
|
+
|
|
390
|
+
Args:
|
|
391
|
+
samba_response (Dict[str, Any]): A dictionary representing
|
|
392
|
+
responses from the SambaVerse API.
|
|
393
|
+
|
|
394
|
+
Returns:
|
|
395
|
+
ChatCompletion: A `ChatCompletion` object constructed from the
|
|
396
|
+
aggregated response data.
|
|
397
|
+
"""
|
|
398
|
+
choices = [
|
|
399
|
+
dict(
|
|
400
|
+
index=0,
|
|
401
|
+
message={
|
|
402
|
+
"role": 'assistant',
|
|
403
|
+
"content": samba_response['result']['responses'][0][
|
|
404
|
+
'completion'
|
|
405
|
+
],
|
|
406
|
+
},
|
|
407
|
+
finish_reason=samba_response['result']['responses'][0][
|
|
408
|
+
'stop_reason'
|
|
409
|
+
],
|
|
410
|
+
)
|
|
411
|
+
]
|
|
412
|
+
|
|
413
|
+
obj = ChatCompletion.construct(
|
|
414
|
+
id=None,
|
|
415
|
+
choices=choices,
|
|
416
|
+
created=int(time.time()),
|
|
417
|
+
model=self.model_type,
|
|
418
|
+
object="chat.completion",
|
|
419
|
+
# SambaVerse API only provide `total_tokens`
|
|
420
|
+
usage=CompletionUsage(
|
|
421
|
+
completion_tokens=0,
|
|
422
|
+
prompt_tokens=0,
|
|
423
|
+
total_tokens=int(
|
|
424
|
+
samba_response['result']['responses'][0][
|
|
425
|
+
'total_tokens_count'
|
|
426
|
+
]
|
|
427
|
+
),
|
|
428
|
+
),
|
|
429
|
+
)
|
|
430
|
+
|
|
431
|
+
return obj
|
|
432
|
+
|
|
271
433
|
@property
|
|
272
434
|
def token_limit(self) -> int:
|
|
273
|
-
r"""Returns the maximum token limit for
|
|
435
|
+
r"""Returns the maximum token limit for the given model.
|
|
274
436
|
|
|
275
437
|
Returns:
|
|
276
438
|
int: The maximum token limit for the given model.
|
|
277
439
|
"""
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
or self.model_type.token_limit
|
|
281
|
-
)
|
|
440
|
+
max_tokens = self.model_config_dict["max_tokens"]
|
|
441
|
+
return max_tokens
|
|
282
442
|
|
|
283
443
|
@property
|
|
284
444
|
def stream(self) -> bool:
|
camel/models/vllm_model.py
CHANGED
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
|
14
14
|
import os
|
|
15
|
+
import subprocess
|
|
15
16
|
from typing import Any, Dict, List, Optional, Union
|
|
16
17
|
|
|
17
18
|
from openai import OpenAI, Stream
|
|
@@ -52,17 +53,38 @@ class VLLMModel:
|
|
|
52
53
|
"""
|
|
53
54
|
self.model_type = model_type
|
|
54
55
|
self.model_config_dict = model_config_dict
|
|
55
|
-
self._url =
|
|
56
|
+
self._url = (
|
|
57
|
+
url
|
|
58
|
+
or os.environ.get("VLLM_BASE_URL")
|
|
59
|
+
or "http://localhost:8000/v1"
|
|
60
|
+
)
|
|
61
|
+
if not url and not os.environ.get("VLLM_BASE_URL"):
|
|
62
|
+
self._start_server()
|
|
56
63
|
# Use OpenAI cilent as interface call vLLM
|
|
57
64
|
self._client = OpenAI(
|
|
58
65
|
timeout=60,
|
|
59
66
|
max_retries=3,
|
|
60
|
-
base_url=self._url
|
|
67
|
+
base_url=self._url,
|
|
61
68
|
api_key=api_key,
|
|
62
69
|
)
|
|
63
70
|
self._token_counter = token_counter
|
|
64
71
|
self.check_model_config()
|
|
65
72
|
|
|
73
|
+
def _start_server(self) -> None:
|
|
74
|
+
r"""Starts the vllm server in a subprocess."""
|
|
75
|
+
try:
|
|
76
|
+
subprocess.Popen(
|
|
77
|
+
["vllm", "server", "--port", "8000"],
|
|
78
|
+
stdout=subprocess.PIPE,
|
|
79
|
+
stderr=subprocess.PIPE,
|
|
80
|
+
)
|
|
81
|
+
print(
|
|
82
|
+
f"vllm server started on http://localhost:8000/v1 "
|
|
83
|
+
f"for {self.model_type} model"
|
|
84
|
+
)
|
|
85
|
+
except Exception as e:
|
|
86
|
+
print(f"Failed to start vllm server: {e}")
|
|
87
|
+
|
|
66
88
|
@property
|
|
67
89
|
def token_counter(self) -> BaseTokenCounter:
|
|
68
90
|
r"""Initialize the token counter for the model backend.
|
|
@@ -25,15 +25,13 @@ from camel.storages import (
|
|
|
25
25
|
VectorDBQuery,
|
|
26
26
|
)
|
|
27
27
|
from camel.types import StorageType
|
|
28
|
+
from camel.utils import Constants
|
|
28
29
|
|
|
29
30
|
try:
|
|
30
31
|
from unstructured.documents.elements import Element
|
|
31
32
|
except ImportError:
|
|
32
33
|
Element = None
|
|
33
34
|
|
|
34
|
-
DEFAULT_TOP_K_RESULTS = 1
|
|
35
|
-
DEFAULT_SIMILARITY_THRESHOLD = 0.75
|
|
36
|
-
|
|
37
35
|
|
|
38
36
|
class AutoRetriever:
|
|
39
37
|
r"""Facilitates the automatic retrieval of information using a
|
|
@@ -178,9 +176,10 @@ class AutoRetriever:
|
|
|
178
176
|
self,
|
|
179
177
|
query: str,
|
|
180
178
|
contents: Union[str, List[str], Element, List[Element]],
|
|
181
|
-
top_k: int = DEFAULT_TOP_K_RESULTS,
|
|
182
|
-
similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
|
|
179
|
+
top_k: int = Constants.DEFAULT_TOP_K_RESULTS,
|
|
180
|
+
similarity_threshold: float = Constants.DEFAULT_SIMILARITY_THRESHOLD,
|
|
183
181
|
return_detailed_info: bool = False,
|
|
182
|
+
max_characters: int = 500,
|
|
184
183
|
) -> dict[str, Sequence[Collection[str]]]:
|
|
185
184
|
r"""Executes the automatic vector retriever process using vector
|
|
186
185
|
storage.
|
|
@@ -198,6 +197,8 @@ class AutoRetriever:
|
|
|
198
197
|
return_detailed_info (bool, optional): Whether to return detailed
|
|
199
198
|
information including similarity score, content path and
|
|
200
199
|
metadata. Defaults to `False`.
|
|
200
|
+
max_characters (int): Max number of characters in each chunk.
|
|
201
|
+
Defaults to `500`.
|
|
201
202
|
|
|
202
203
|
Returns:
|
|
203
204
|
dict[str, Sequence[Collection[str]]]: By default, returns
|
|
@@ -262,7 +263,7 @@ class AutoRetriever:
|
|
|
262
263
|
storage=vector_storage_instance,
|
|
263
264
|
embedding_model=self.embedding_model,
|
|
264
265
|
)
|
|
265
|
-
vr.process(content)
|
|
266
|
+
vr.process(content=content, max_characters=max_characters)
|
|
266
267
|
else:
|
|
267
268
|
vr = VectorRetriever(
|
|
268
269
|
storage=vector_storage_instance,
|
|
@@ -25,15 +25,13 @@ from camel.storages import (
|
|
|
25
25
|
VectorDBQuery,
|
|
26
26
|
VectorRecord,
|
|
27
27
|
)
|
|
28
|
+
from camel.utils import Constants
|
|
28
29
|
|
|
29
30
|
try:
|
|
30
31
|
from unstructured.documents.elements import Element
|
|
31
32
|
except ImportError:
|
|
32
33
|
Element = None
|
|
33
34
|
|
|
34
|
-
DEFAULT_TOP_K_RESULTS = 1
|
|
35
|
-
DEFAULT_SIMILARITY_THRESHOLD = 0.75
|
|
36
|
-
|
|
37
35
|
|
|
38
36
|
class VectorRetriever(BaseRetriever):
|
|
39
37
|
r"""An implementation of the `BaseRetriever` by using vector storage and
|
|
@@ -76,6 +74,7 @@ class VectorRetriever(BaseRetriever):
|
|
|
76
74
|
self,
|
|
77
75
|
content: Union[str, Element],
|
|
78
76
|
chunk_type: str = "chunk_by_title",
|
|
77
|
+
max_characters: int = 500,
|
|
79
78
|
**kwargs: Any,
|
|
80
79
|
) -> None:
|
|
81
80
|
r"""Processes content from a file or URL, divides it into chunks by
|
|
@@ -87,6 +86,8 @@ class VectorRetriever(BaseRetriever):
|
|
|
87
86
|
string content or Element object.
|
|
88
87
|
chunk_type (str): Type of chunking going to apply. Defaults to
|
|
89
88
|
"chunk_by_title".
|
|
89
|
+
max_characters (int): Max number of characters in each chunk.
|
|
90
|
+
Defaults to `500`.
|
|
90
91
|
**kwargs (Any): Additional keyword arguments for content parsing.
|
|
91
92
|
"""
|
|
92
93
|
if isinstance(content, Element):
|
|
@@ -101,7 +102,9 @@ class VectorRetriever(BaseRetriever):
|
|
|
101
102
|
elements = [self.uio.create_element_from_text(text=content)]
|
|
102
103
|
if elements:
|
|
103
104
|
chunks = self.uio.chunk_elements(
|
|
104
|
-
chunk_type=chunk_type,
|
|
105
|
+
chunk_type=chunk_type,
|
|
106
|
+
elements=elements,
|
|
107
|
+
max_characters=max_characters,
|
|
105
108
|
)
|
|
106
109
|
if not elements:
|
|
107
110
|
warnings.warn(
|
|
@@ -142,8 +145,8 @@ class VectorRetriever(BaseRetriever):
|
|
|
142
145
|
def query(
|
|
143
146
|
self,
|
|
144
147
|
query: str,
|
|
145
|
-
top_k: int = DEFAULT_TOP_K_RESULTS,
|
|
146
|
-
similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
|
|
148
|
+
top_k: int = Constants.DEFAULT_TOP_K_RESULTS,
|
|
149
|
+
similarity_threshold: float = Constants.DEFAULT_SIMILARITY_THRESHOLD,
|
|
147
150
|
) -> List[Dict[str, Any]]:
|
|
148
151
|
r"""Executes a query in vector storage and compiles the retrieved
|
|
149
152
|
results into a dictionary.
|
|
@@ -154,7 +157,8 @@ class VectorRetriever(BaseRetriever):
|
|
|
154
157
|
for filtering results. Defaults to
|
|
155
158
|
`DEFAULT_SIMILARITY_THRESHOLD`.
|
|
156
159
|
top_k (int, optional): The number of top results to return during
|
|
157
|
-
retriever. Must be a positive integer. Defaults to
|
|
160
|
+
retriever. Must be a positive integer. Defaults to
|
|
161
|
+
`DEFAULT_TOP_K_RESULTS`.
|
|
158
162
|
|
|
159
163
|
Returns:
|
|
160
164
|
List[Dict[str, Any]]: Concatenated list of the query results.
|
camel/toolkits/__init__.py
CHANGED
|
@@ -29,6 +29,7 @@ from .weather_toolkit import WEATHER_FUNCS, WeatherToolkit
|
|
|
29
29
|
from .slack_toolkit import SLACK_FUNCS, SlackToolkit
|
|
30
30
|
from .dalle_toolkit import DALLE_FUNCS, DalleToolkit
|
|
31
31
|
from .linkedin_toolkit import LINKEDIN_FUNCS, LinkedInToolkit
|
|
32
|
+
from .reddit_toolkit import REDDIT_FUNCS, RedditToolkit
|
|
32
33
|
|
|
33
34
|
from .base import BaseToolkit
|
|
34
35
|
from .code_execution import CodeExecutionToolkit
|
|
@@ -49,6 +50,7 @@ __all__ = [
|
|
|
49
50
|
'SLACK_FUNCS',
|
|
50
51
|
'DALLE_FUNCS',
|
|
51
52
|
'LINKEDIN_FUNCS',
|
|
53
|
+
'REDDIT_FUNCS',
|
|
52
54
|
'BaseToolkit',
|
|
53
55
|
'GithubToolkit',
|
|
54
56
|
'MathToolkit',
|
|
@@ -61,5 +63,6 @@ __all__ = [
|
|
|
61
63
|
'RetrievalToolkit',
|
|
62
64
|
'OpenAPIToolkit',
|
|
63
65
|
'LinkedInToolkit',
|
|
66
|
+
'RedditToolkit',
|
|
64
67
|
'CodeExecutionToolkit',
|
|
65
68
|
]
|