camel-ai 0.1.5.4__py3-none-any.whl → 0.1.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- camel/__init__.py +1 -1
- camel/agents/knowledge_graph_agent.py +11 -15
- camel/agents/task_agent.py +0 -1
- camel/configs/__init__.py +12 -0
- camel/configs/gemini_config.py +97 -0
- camel/configs/litellm_config.py +8 -18
- camel/configs/ollama_config.py +85 -0
- camel/configs/zhipuai_config.py +78 -0
- camel/embeddings/openai_embedding.py +2 -2
- camel/functions/search_functions.py +5 -14
- camel/functions/slack_functions.py +5 -7
- camel/functions/twitter_function.py +3 -8
- camel/functions/weather_functions.py +3 -8
- camel/interpreters/__init__.py +2 -0
- camel/interpreters/docker_interpreter.py +235 -0
- camel/loaders/__init__.py +2 -0
- camel/loaders/base_io.py +5 -9
- camel/loaders/jina_url_reader.py +99 -0
- camel/loaders/unstructured_io.py +4 -6
- camel/models/__init__.py +2 -0
- camel/models/anthropic_model.py +6 -4
- camel/models/gemini_model.py +203 -0
- camel/models/litellm_model.py +49 -21
- camel/models/model_factory.py +4 -2
- camel/models/nemotron_model.py +14 -6
- camel/models/ollama_model.py +11 -17
- camel/models/openai_audio_models.py +10 -2
- camel/models/openai_model.py +4 -3
- camel/models/zhipuai_model.py +12 -6
- camel/retrievers/bm25_retriever.py +3 -8
- camel/retrievers/cohere_rerank_retriever.py +3 -5
- camel/storages/__init__.py +2 -0
- camel/storages/graph_storages/neo4j_graph.py +3 -7
- camel/storages/key_value_storages/__init__.py +2 -0
- camel/storages/key_value_storages/redis.py +169 -0
- camel/storages/vectordb_storages/milvus.py +3 -7
- camel/storages/vectordb_storages/qdrant.py +3 -7
- camel/toolkits/__init__.py +2 -0
- camel/toolkits/code_execution.py +69 -0
- camel/toolkits/github_toolkit.py +5 -9
- camel/types/enums.py +53 -1
- camel/utils/__init__.py +4 -2
- camel/utils/async_func.py +42 -0
- camel/utils/commons.py +31 -49
- camel/utils/token_counting.py +74 -1
- {camel_ai-0.1.5.4.dist-info → camel_ai-0.1.5.6.dist-info}/METADATA +12 -3
- {camel_ai-0.1.5.4.dist-info → camel_ai-0.1.5.6.dist-info}/RECORD +48 -39
- {camel_ai-0.1.5.4.dist-info → camel_ai-0.1.5.6.dist-info}/WHEEL +0 -0
camel/models/litellm_model.py
CHANGED
@@ -11,24 +11,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional
 
 from camel.configs import LITELLM_API_PARAMS
 from camel.messages import OpenAIMessage
+from camel.types import ChatCompletion
 from camel.utils import LiteLLMTokenCounter
 
-if TYPE_CHECKING:
-    from litellm.utils import CustomStreamWrapper, ModelResponse
-
 
 class LiteLLMModel:
     r"""Constructor for LiteLLM backend with OpenAI compatibility."""
 
-    # NOTE: Currently stream mode is not supported due to the
-    # limitation of the current camel design.
+    # NOTE: Currently stream mode is not supported.
 
     def __init__(
-        self, model_type: str, model_config_dict: Dict[str, Any]
+        self,
+        model_type: str,
+        model_config_dict: Dict[str, Any],
+        api_key: Optional[str] = None,
+        url: Optional[str] = None,
     ) -> None:
         r"""Constructor for LiteLLM backend.
 
@@ -37,12 +38,48 @@ class LiteLLMModel:
                 such as GPT-3.5-turbo, Claude-2, etc.
             model_config_dict (Dict[str, Any]): A dictionary of parameters for
                 the model configuration.
+            api_key (Optional[str]): The API key for authenticating with the
+                model service. (default: :obj:`None`)
+            url (Optional[str]): The url to the model service. (default:
+                :obj:`None`)
         """
         self.model_type = model_type
         self.model_config_dict = model_config_dict
         self._client = None
         self._token_counter: Optional[LiteLLMTokenCounter] = None
         self.check_model_config()
+        self._url = url
+        self._api_key = api_key
+
+    def _convert_response_from_litellm_to_openai(
+        self, response
+    ) -> ChatCompletion:
+        r"""Converts a response from the LiteLLM format to the OpenAI format.
+
+        Parameters:
+            response (LiteLLMResponse): The response object from LiteLLM.
+
+        Returns:
+            ChatCompletion: The response object in OpenAI's format.
+        """
+        return ChatCompletion.construct(
+            id=response.id,
+            choices=[
+                {
+                    "index": response.choices[0].index,
+                    "message": {
+                        "role": response.choices[0].message.role,
+                        "content": response.choices[0].message.content,
+                    },
+                    "finish_reason": response.choices[0].finish_reason,
+                }
+            ],
+            created=response.created,
+            model=response.model,
+            object=response.object,
+            system_fingerprint=response.system_fingerprint,
+            usage=response.usage,
+        )
 
     @property
     def client(self):
@@ -67,7 +104,7 @@ class LiteLLMModel:
     def run(
         self,
         messages: List[OpenAIMessage],
-    ) -> Union['ModelResponse', 'CustomStreamWrapper']:
+    ) -> ChatCompletion:
         r"""Runs inference of LiteLLM chat completion.
 
         Args:
@@ -75,15 +112,16 @@ class LiteLLMModel:
                 in OpenAI format.
 
         Returns:
-            Union['ModelResponse', 'CustomStreamWrapper']:
-                `ModelResponse` in the non-stream mode, or
-                `CustomStreamWrapper` in the stream mode.
+            ChatCompletion
         """
         response = self.client(
+            api_key=self._api_key,
+            base_url=self._url,
             model=self.model_type,
            messages=messages,
            **self.model_config_dict,
         )
+        response = self._convert_response_from_litellm_to_openai(response)
         return response
 
     def check_model_config(self):
@@ -100,13 +138,3 @@ class LiteLLMModel:
                     f"Unexpected argument `{param}` is "
                     "input into LiteLLM model backend."
                 )
-
-    @property
-    def stream(self) -> bool:
-        r"""Returns whether the model is in stream mode, which sends partial
-        results each time.
-
-        Returns:
-            bool: Whether the model is in stream mode.
-        """
-        return self.model_config_dict.get('stream', False)
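In short, `LiteLLMModel` now accepts explicit credentials and normalizes litellm's response into camel's own `ChatCompletion`, so callers no longer see `ModelResponse` or `CustomStreamWrapper`. A minimal usage sketch based only on the signatures above (the model name and config key are illustrative and must be permitted by `LITELLM_API_PARAMS`):

    from camel.models.litellm_model import LiteLLMModel

    model = LiteLLMModel(
        model_type="gpt-3.5-turbo",              # any litellm-routable model string
        model_config_dict={"temperature": 0.2},  # keys checked against LITELLM_API_PARAMS
        api_key="sk-...",                        # forwarded to the client per the diff
    )
    response = model.run([{"role": "user", "content": "Say hello."}])
    # `response` is a camel ChatCompletion; since it is built with
    # `construct`, its choices remain plain dicts.
    print(response.choices[0]["message"]["content"])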
camel/models/model_factory.py
CHANGED
@@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, Union
 
 from camel.models.anthropic_model import AnthropicModel
 from camel.models.base_model import BaseModelBackend
+from camel.models.gemini_model import GeminiModel
 from camel.models.litellm_model import LiteLLMModel
 from camel.models.ollama_model import OllamaModel
 from camel.models.open_source_model import OpenSourceModel
@@ -59,7 +60,6 @@ class ModelFactory:
             BaseModelBackend: The initialized backend.
         """
         model_class: Any
-
         if isinstance(model_type, ModelType):
             if model_platform.is_open_source and model_type.is_open_source:
                 model_class = OpenSourceModel
@@ -70,6 +70,8 @@ class ModelFactory:
                 model_class = AnthropicModel
             elif model_platform.is_zhipuai and model_type.is_zhipuai:
                 model_class = ZhipuAIModel
+            elif model_platform.is_gemini and model_type.is_gemini:
+                model_class = GeminiModel
             elif model_type == ModelType.STUB:
                 model_class = StubModel
             else:
@@ -80,6 +82,7 @@ class ModelFactory:
         elif isinstance(model_type, str):
             if model_platform.is_ollama:
                 model_class = OllamaModel
+                return model_class(model_type, model_config_dict, url)
             elif model_platform.is_litellm:
                 model_class = LiteLLMModel
             else:
@@ -89,5 +92,4 @@ class ModelFactory:
                 )
         else:
             raise ValueError(f"Invalid model type `{model_type}` provided.")
-
         return model_class(model_type, model_config_dict, api_key, url)
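Condensed, the updated dispatch behaves like the sketch below. This is a restatement of the hunks above for readability, not the factory's actual code:

    from camel.models.gemini_model import GeminiModel
    from camel.models.litellm_model import LiteLLMModel
    from camel.models.ollama_model import OllamaModel
    from camel.types import ModelType

    def pick_backend(model_platform, model_type):
        # Abridged restatement of ModelFactory's branch order.
        if isinstance(model_type, ModelType):
            if model_platform.is_gemini and model_type.is_gemini:
                return GeminiModel      # new branch in this release
        elif isinstance(model_type, str):
            if model_platform.is_ollama:
                # New early return upstream: Ollama is built without an
                # api_key, as model_class(model_type, model_config_dict, url).
                return OllamaModel
            if model_platform.is_litellm:
                return LiteLLMModel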
camel/models/nemotron_model.py
CHANGED
@@ -20,7 +20,7 @@ from camel.messages import OpenAIMessage
 from camel.types import ChatCompletion, ModelType
 from camel.utils import (
     BaseTokenCounter,
-    model_api_key_required,
+    api_keys_required,
 )
 
 
@@ -33,6 +33,7 @@ class NemotronModel:
         self,
         model_type: ModelType,
         api_key: Optional[str] = None,
+        url: Optional[str] = None,
     ) -> None:
         r"""Constructor for Nvidia backend.
 
@@ -40,18 +41,25 @@ class NemotronModel:
             model_type (ModelType): Model for which a backend is created.
             api_key (Optional[str]): The API key for authenticating with the
                 Nvidia service. (default: :obj:`None`)
+            url (Optional[str]): The url to the Nvidia service. (default:
+                :obj:`None`)
         """
         self.model_type = model_type
-
+        self._url = url or os.environ.get("NVIDIA_API_BASE_URL")
         self._api_key = api_key or os.environ.get("NVIDIA_API_KEY")
-        if not self._api_key:
-            raise ValueError("NVIDIA_API_KEY should be set.")
+        if not self._url or not self._api_key:
+            raise ValueError(
+                "NVIDIA_API_BASE_URL and NVIDIA_API_KEY should be set."
+            )
         self._client = OpenAI(
-            timeout=60,
+            timeout=60,
+            max_retries=3,
+            base_url=self._url,
+            api_key=self._api_key,
         )
         self._token_counter: Optional[BaseTokenCounter] = None
 
-    @model_api_key_required
+    @api_keys_required("NVIDIA_API_KEY")
     def run(
         self,
         messages: List[OpenAIMessage],
camel/models/ollama_model.py
CHANGED
@@ -11,12 +11,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
-import os
 from typing import Any, Dict, List, Optional, Union
 
 from openai import OpenAI, Stream
 
-from camel.configs import OPENAI_API_PARAMS
+from camel.configs import OLLAMA_API_PARAMS
 from camel.messages import OpenAIMessage
 from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
 from camel.utils import BaseTokenCounter, OpenAITokenCounter
@@ -25,39 +24,34 @@ from camel.utils import BaseTokenCounter, OpenAITokenCounter
 class OllamaModel:
     r"""Ollama service interface."""
 
-    # NOTE: Current `ModelType and `TokenCounter` desigen is not suitable,
-    # stream mode is not supported
-
     def __init__(
         self,
         model_type: str,
         model_config_dict: Dict[str, Any],
-        api_key: Optional[str] = None,
         url: Optional[str] = None,
     ) -> None:
         r"""Constructor for Ollama backend with OpenAI compatibility.
 
+        # Reference: https://github.com/ollama/ollama/blob/main/docs/openai.md
+
         Args:
             model_type (str): Model for which a backend is created.
             model_config_dict (Dict[str, Any]): A dictionary that will
                 be fed into openai.ChatCompletion.create().
-            api_key (Optional[str]): The API key for authenticating with the
-                model service.
-            url (Optional[str]): The url to the model service.
+            url (Optional[str]): The url to the model service. (default:
+                :obj:`None`)
         """
         self.model_type = model_type
         self.model_config_dict = model_config_dict
-        self._url = url or os.environ.get('OPENAI_API_BASE_URL')
-        self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
         # Use OpenAI cilent as interface call Ollama
-        # Reference: https://github.com/ollama/ollama/blob/main/docs/openai.md
         self._client = OpenAI(
             timeout=60,
             max_retries=3,
-            base_url=self._url,
-            api_key=self._api_key,
+            base_url=url,
+            api_key="ollama",  # required but ignored
         )
         self._token_counter: Optional[BaseTokenCounter] = None
+        self.check_model_config()
 
     @property
     def token_counter(self) -> BaseTokenCounter:
@@ -74,17 +68,17 @@ class OllamaModel:
 
     def check_model_config(self):
         r"""Check whether the model configuration contains any
-        unexpected arguments to OpenAI API.
+        unexpected arguments to Ollama API.
 
         Raises:
             ValueError: If the model configuration dictionary contains any
                 unexpected arguments to OpenAI API.
         """
         for param in self.model_config_dict:
-            if param not in OPENAI_API_PARAMS:
+            if param not in OLLAMA_API_PARAMS:
                 raise ValueError(
                     f"Unexpected argument `{param}` is "
-                    "input into OpenAI model backend."
+                    "input into Ollama model backend."
                 )
 
     def run(
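Since the backend now always authenticates with the placeholder key "ollama", pointing it at a local server takes only a url. A sketch (the model name and config key are examples; Ollama's OpenAI-compatible endpoint conventionally lives under `/v1`):

    from camel.models.ollama_model import OllamaModel

    model = OllamaModel(
        model_type="mistral",                    # any model pulled into local Ollama
        model_config_dict={"temperature": 0.4},  # keys checked against OLLAMA_API_PARAMS
        url="http://localhost:11434/v1",         # default local OpenAI-compatible endpoint
    )
    response = model.run([{"role": "user", "content": "Hi"}])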
camel/models/openai_audio_models.py
CHANGED

@@ -25,10 +25,18 @@ class OpenAIAudioModels:
 
     def __init__(
         self,
+        api_key: Optional[str] = None,
+        url: Optional[str] = None,
     ) -> None:
         r"""Initialize an instance of OpenAI."""
-
-        self._client = OpenAI()
+        self._url = url or os.environ.get("OPENAI_API_BASE_URL")
+        self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
+        self._client = OpenAI(
+            timeout=120,
+            max_retries=3,
+            base_url=self._url,
+            api_key=self._api_key,
+        )
 
     def text_to_speech(
         self,
camel/models/openai_model.py
CHANGED
@@ -23,7 +23,7 @@ from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
 from camel.utils import (
     BaseTokenCounter,
     OpenAITokenCounter,
-    model_api_key_required,
+    api_keys_required,
 )
 
 
@@ -46,7 +46,8 @@ class OpenAIModel(BaseModelBackend):
                 be fed into openai.ChatCompletion.create().
             api_key (Optional[str]): The API key for authenticating with the
                 OpenAI service. (default: :obj:`None`)
-            url (Optional[str]): The url to the OpenAI service.
+            url (Optional[str]): The url to the OpenAI service. (default:
+                :obj:`None`)
         """
         super().__init__(model_type, model_config_dict, api_key, url)
         self._url = url or os.environ.get("OPENAI_API_BASE_URL")
@@ -71,7 +72,7 @@ class OpenAIModel(BaseModelBackend):
         self._token_counter = OpenAITokenCounter(self.model_type)
         return self._token_counter
 
-    @model_api_key_required
+    @api_keys_required("OPENAI_API_KEY")
     def run(
         self,
         messages: List[OpenAIMessage],
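Across these files the bare `model_api_key_required` decorator is replaced by a parameterized `api_keys_required(...)`, whose implementation lives in `camel/utils/commons.py` and is not shown in this diff. A hedged sketch of what such a decorator plausibly does, offered as an assumption rather than the actual code:

    import functools
    import os

    def api_keys_required(*env_var_names):
        # Hypothetical reconstruction; the real helper is in
        # camel/utils/commons.py and may differ in details
        # (e.g. how it consults self._api_key).
        def decorator(func):
            @functools.wraps(func)
            def wrapper(self, *args, **kwargs):
                missing = [
                    name for name in env_var_names
                    if not (getattr(self, "_api_key", None) or os.environ.get(name))
                ]
                if missing:
                    raise ValueError(f"Missing API key(s): {', '.join(missing)}")
                return func(self, *args, **kwargs)
            return wrapper
        return decorator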
camel/models/zhipuai_model.py
CHANGED
@@ -17,14 +17,14 @@ from typing import Any, Dict, List, Optional, Union
 
 from openai import OpenAI, Stream
 
-from camel.configs import OPENAI_API_PARAMS
+from camel.configs import ZHIPUAI_API_PARAMS
 from camel.messages import OpenAIMessage
 from camel.models import BaseModelBackend
 from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
 from camel.utils import (
     BaseTokenCounter,
     OpenAITokenCounter,
-    model_api_key_required,
+    api_keys_required,
 )
 
 
@@ -47,10 +47,16 @@ class ZhipuAIModel(BaseModelBackend):
                 be fed into openai.ChatCompletion.create().
             api_key (Optional[str]): The API key for authenticating with the
                 ZhipuAI service. (default: :obj:`None`)
+            url (Optional[str]): The url to the ZhipuAI service. (default:
+                :obj:`None`)
         """
         super().__init__(model_type, model_config_dict)
         self._url = url or os.environ.get("ZHIPUAI_API_BASE_URL")
         self._api_key = api_key or os.environ.get("ZHIPUAI_API_KEY")
+        if not self._url or not self._api_key:
+            raise ValueError(
+                "ZHIPUAI_API_BASE_URL and ZHIPUAI_API_KEY should be set."
+            )
         self._client = OpenAI(
             timeout=60,
             max_retries=3,
@@ -59,7 +65,7 @@ class ZhipuAIModel(BaseModelBackend):
         )
         self._token_counter: Optional[BaseTokenCounter] = None
 
-    @model_api_key_required
+    @api_keys_required("ZHIPUAI_API_KEY")
     def run(
         self,
         messages: List[OpenAIMessage],
@@ -104,13 +110,13 @@ class ZhipuAIModel(BaseModelBackend):
 
         Raises:
             ValueError: If the model configuration dictionary contains any
-                unexpected arguments to OpenAI API.
+                unexpected arguments to ZhipuAI API.
         """
         for param in self.model_config_dict:
-            if param not in OPENAI_API_PARAMS:
+            if param not in ZHIPUAI_API_PARAMS:
                 raise ValueError(
                     f"Unexpected argument `{param}` is "
-                    "input into OpenAI model backend."
+                    "input into ZhipuAI model backend."
                 )
         pass
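Note the asymmetry with the OpenAI backend: ZhipuAI raises in `__init__` when either value is missing, instead of deferring to the decorated `run`. So construction requires both up front:

    import os

    # Both must be set (or passed explicitly) before constructing the backend;
    # the base URL value here is an example, not taken from this diff.
    os.environ["ZHIPUAI_API_BASE_URL"] = "https://open.bigmodel.cn/api/paas/v4/"
    os.environ["ZHIPUAI_API_KEY"] = "..."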
camel/retrievers/bm25_retriever.py
CHANGED

@@ -17,6 +17,7 @@ import numpy as np
 
 from camel.loaders import UnstructuredIO
 from camel.retrievers import BaseRetriever
+from camel.utils import dependencies_required
 
 DEFAULT_TOP_K_RESULTS = 1
 
@@ -40,16 +41,10 @@ class BM25Retriever(BaseRetriever):
         https://github.com/dorianbrown/rank_bm25
     """
 
+    @dependencies_required('rank_bm25')
     def __init__(self) -> None:
         r"""Initializes the BM25Retriever."""
-
-        try:
-            from rank_bm25 import BM25Okapi
-        except ImportError as e:
-            raise ImportError(
-                "Package `rank_bm25` not installed, install by running 'pip "
-                "install rank_bm25'"
-            ) from e
+        from rank_bm25 import BM25Okapi
 
         self.bm25: BM25Okapi = None
         self.content_input_path: str = ""
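`dependencies_required`, imported from `camel.utils` here and in the retriever and storage files below, replaces the hand-rolled try/except import guards; its implementation is not part of this diff. A hedged sketch of the usual shape of such a decorator, offered as an assumption:

    import functools
    import importlib.util

    def dependencies_required(*package_names):
        # Hypothetical reconstruction; camel's actual helper (added to
        # camel/utils in this release) may differ.
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                missing = [
                    p for p in package_names
                    if importlib.util.find_spec(p) is None
                ]
                if missing:
                    raise ImportError(
                        f"Missing required packages: {', '.join(missing)}. "
                        f"Install them with `pip install {' '.join(missing)}`."
                    )
                return func(*args, **kwargs)
            return wrapper
        return decorator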
camel/retrievers/cohere_rerank_retriever.py
CHANGED

@@ -15,6 +15,7 @@ import os
 from typing import Any, Dict, List, Optional
 
 from camel.retrievers import BaseRetriever
+from camel.utils import dependencies_required
 
 DEFAULT_TOP_K_RESULTS = 1
 
@@ -32,6 +33,7 @@ class CohereRerankRetriever(BaseRetriever):
         https://txt.cohere.com/rerank/
     """
 
+    @dependencies_required('cohere')
     def __init__(
         self,
         model_name: str = "rerank-multilingual-v2.0",
@@ -56,11 +58,7 @@ class CohereRerankRetriever(BaseRetriever):
             ValueError: If the API key is neither passed as an argument nor
                 set in the environment variable.
         """
-
-        try:
-            import cohere
-        except ImportError as e:
-            raise ImportError("Package 'cohere' is not installed") from e
+        import cohere
 
         try:
             self.api_key = api_key or os.environ["COHERE_API_KEY"]
camel/storages/__init__.py
CHANGED
@@ -17,6 +17,7 @@ from .graph_storages.neo4j_graph import Neo4jGraph
 from .key_value_storages.base import BaseKeyValueStorage
 from .key_value_storages.in_memory import InMemoryKeyValueStorage
 from .key_value_storages.json import JsonStorage
+from .key_value_storages.redis import RedisStorage
 from .vectordb_storages.base import (
     BaseVectorStorage,
     VectorDBQuery,
@@ -30,6 +31,7 @@ __all__ = [
     'BaseKeyValueStorage',
     'InMemoryKeyValueStorage',
     'JsonStorage',
+    'RedisStorage',
     'VectorRecord',
     'BaseVectorStorage',
     'VectorDBQuery',
camel/storages/graph_storages/neo4j_graph.py
CHANGED

@@ -16,6 +16,7 @@ from hashlib import md5
 from typing import Any, Dict, List, Optional
 
 from camel.storages.graph_storages import BaseGraphStorage, GraphElement
+from camel.utils import dependencies_required
 
 logger = logging.getLogger(__name__)
 
@@ -81,6 +82,7 @@ class Neo4jGraph(BaseGraphStorage):
             than `LIST_LIMIT` elements from results. Defaults to `False`.
     """
 
+    @dependencies_required('neo4j')
     def __init__(
         self,
         url: str,
@@ -91,13 +93,7 @@ class Neo4jGraph(BaseGraphStorage):
         truncate: bool = False,
     ) -> None:
         r"""Create a new Neo4j graph instance."""
-        try:
-            import neo4j
-        except ImportError:
-            raise ValueError(
-                "Could not import neo4j python package. "
-                "Please install it with `pip install neo4j`."
-            )
+        import neo4j
 
         self.driver = neo4j.GraphDatabase.driver(
             url, auth=(username, password)
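With the decorator applied, a missing `neo4j` package now fails before `__init__` runs; construction itself is unchanged. For example (the URI and credentials are placeholders):

    from camel.storages import Neo4jGraph

    graph = Neo4jGraph(
        url="neo4j://localhost:7687",  # standard local Neo4j URI
        username="neo4j",
        password="...",
    )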
camel/storages/key_value_storages/__init__.py
CHANGED

@@ -15,9 +15,11 @@
 from .base import BaseKeyValueStorage
 from .in_memory import InMemoryKeyValueStorage
 from .json import JsonStorage
+from .redis import RedisStorage
 
 __all__ = [
     'BaseKeyValueStorage',
     'InMemoryKeyValueStorage',
     'JsonStorage',
+    'RedisStorage',
 ]
camel/storages/key_value_storages/redis.py
ADDED

@@ -0,0 +1,169 @@
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+
+import asyncio
+import json
+import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from camel.storages.key_value_storages import BaseKeyValueStorage
+
+if TYPE_CHECKING:
+    from redis.asyncio import Redis
+
+logger = logging.getLogger(__name__)
+
+
+class RedisStorage(BaseKeyValueStorage):
+    r"""A concrete implementation of the :obj:`BaseCacheStorage` using Redis as
+    the backend. This is suitable for distributed cache systems that require
+    persistence and high availability.
+    """
+
+    def __init__(
+        self,
+        sid: str,
+        url: str = "redis://localhost:6379",
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+        **kwargs,
+    ) -> None:
+        r"""Initializes the RedisStorage instance with the provided URL and
+        options.
+
+        Args:
+            sid (str): The ID for the storage instance to identify the
+                record space.
+            url (str): The URL for connecting to the Redis server.
+            **kwargs: Additional keyword arguments for Redis client
+                configuration.
+
+        Raises:
+            ImportError: If the `redis.asyncio` module is not installed.
+        """
+        try:
+            import redis.asyncio as aredis
+        except ImportError as exc:
+            logger.error(
+                "Please install `redis` first. You can install it by "
+                "running `pip install redis`."
+            )
+            raise exc
+
+        self._client: Optional[aredis.Redis] = None
+        self._url = url
+        self._sid = sid
+        self._loop = loop or asyncio.get_event_loop()
+
+        self._create_client(**kwargs)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc, tb):
+        self._run_async(self.close())
+
+    async def close(self) -> None:
+        r"""Closes the Redis client asynchronously."""
+        if self._client:
+            await self._client.close()
+
+    def _create_client(self, **kwargs) -> None:
+        r"""Creates the Redis client with the provided URL and options.
+
+        Args:
+            **kwargs: Additional keyword arguments for Redis client
+                configuration.
+        """
+        import redis.asyncio as aredis
+
+        self._client = aredis.from_url(self._url, **kwargs)
+
+    @property
+    def client(self) -> Optional["Redis"]:
+        r"""Returns the Redis client instance.
+
+        Returns:
+            redis.asyncio.Redis: The Redis client instance.
+        """
+        return self._client
+
+    def save(
+        self, records: List[Dict[str, Any]], expire: Optional[int] = None
+    ) -> None:
+        r"""Saves a batch of records to the key-value storage system."""
+        try:
+            self._run_async(self._async_save(records, expire))
+        except Exception as e:
+            logger.error(f"Error in save: {e}")
+
+    def load(self) -> List[Dict[str, Any]]:
+        r"""Loads all stored records from the key-value storage system.
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries, where each dictionary
+                represents a stored record.
+        """
+        try:
+            return self._run_async(self._async_load())
+        except Exception as e:
+            logger.error(f"Error in load: {e}")
+            return []
+
+    def clear(self) -> None:
+        r"""Removes all records from the key-value storage system."""
+        try:
+            self._run_async(self._async_clear())
+        except Exception as e:
+            logger.error(f"Error in clear: {e}")
+
+    async def _async_save(
+        self, records: List[Dict[str, Any]], expire: Optional[int] = None
+    ) -> None:
+        if self._client is None:
+            raise ValueError("Redis client is not initialized")
+        try:
+            value = json.dumps(records)
+            if expire:
+                await self._client.setex(self._sid, expire, value)
+            else:
+                await self._client.set(self._sid, value)
+        except Exception as e:
+            logger.error(f"Error saving records: {e}")
+
+    async def _async_load(self) -> List[Dict[str, Any]]:
+        if self._client is None:
+            raise ValueError("Redis client is not initialized")
+        try:
+            value = await self._client.get(self._sid)
+            if value:
+                return json.loads(value)
+            return []
+        except Exception as e:
+            logger.error(f"Error loading records: {e}")
+            return []
+
+    async def _async_clear(self) -> None:
+        if self._client is None:
+            raise ValueError("Redis client is not initialized")
+        try:
+            await self._client.delete(self._sid)
+        except Exception as e:
+            logger.error(f"Error clearing records: {e}")
+
+    def _run_async(self, coro):
+        if not self._loop.is_running():
+            return self._loop.run_until_complete(coro)
+        else:
+            future = asyncio.run_coroutine_threadsafe(coro, self._loop)
+            return future.result()