camel-ai: camel_ai-0.1.5.4-py3-none-any.whl → camel_ai-0.1.5.6-py3-none-any.whl

This diff compares the contents of two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic. Click here for more details.

Files changed (48)
  1. camel/__init__.py +1 -1
  2. camel/agents/knowledge_graph_agent.py +11 -15
  3. camel/agents/task_agent.py +0 -1
  4. camel/configs/__init__.py +12 -0
  5. camel/configs/gemini_config.py +97 -0
  6. camel/configs/litellm_config.py +8 -18
  7. camel/configs/ollama_config.py +85 -0
  8. camel/configs/zhipuai_config.py +78 -0
  9. camel/embeddings/openai_embedding.py +2 -2
  10. camel/functions/search_functions.py +5 -14
  11. camel/functions/slack_functions.py +5 -7
  12. camel/functions/twitter_function.py +3 -8
  13. camel/functions/weather_functions.py +3 -8
  14. camel/interpreters/__init__.py +2 -0
  15. camel/interpreters/docker_interpreter.py +235 -0
  16. camel/loaders/__init__.py +2 -0
  17. camel/loaders/base_io.py +5 -9
  18. camel/loaders/jina_url_reader.py +99 -0
  19. camel/loaders/unstructured_io.py +4 -6
  20. camel/models/__init__.py +2 -0
  21. camel/models/anthropic_model.py +6 -4
  22. camel/models/gemini_model.py +203 -0
  23. camel/models/litellm_model.py +49 -21
  24. camel/models/model_factory.py +4 -2
  25. camel/models/nemotron_model.py +14 -6
  26. camel/models/ollama_model.py +11 -17
  27. camel/models/openai_audio_models.py +10 -2
  28. camel/models/openai_model.py +4 -3
  29. camel/models/zhipuai_model.py +12 -6
  30. camel/retrievers/bm25_retriever.py +3 -8
  31. camel/retrievers/cohere_rerank_retriever.py +3 -5
  32. camel/storages/__init__.py +2 -0
  33. camel/storages/graph_storages/neo4j_graph.py +3 -7
  34. camel/storages/key_value_storages/__init__.py +2 -0
  35. camel/storages/key_value_storages/redis.py +169 -0
  36. camel/storages/vectordb_storages/milvus.py +3 -7
  37. camel/storages/vectordb_storages/qdrant.py +3 -7
  38. camel/toolkits/__init__.py +2 -0
  39. camel/toolkits/code_execution.py +69 -0
  40. camel/toolkits/github_toolkit.py +5 -9
  41. camel/types/enums.py +53 -1
  42. camel/utils/__init__.py +4 -2
  43. camel/utils/async_func.py +42 -0
  44. camel/utils/commons.py +31 -49
  45. camel/utils/token_counting.py +74 -1
  46. {camel_ai-0.1.5.4.dist-info → camel_ai-0.1.5.6.dist-info}/METADATA +12 -3
  47. {camel_ai-0.1.5.4.dist-info → camel_ai-0.1.5.6.dist-info}/RECORD +48 -39
  48. {camel_ai-0.1.5.4.dist-info → camel_ai-0.1.5.6.dist-info}/WHEEL +0 -0
@@ -11,24 +11,25 @@
11
11
  # See the License for the specific language governing permissions and
12
12
  # limitations under the License.
13
13
  # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
14
+ from typing import Any, Dict, List, Optional
15
15
 
16
16
  from camel.configs import LITELLM_API_PARAMS
17
17
  from camel.messages import OpenAIMessage
18
+ from camel.types import ChatCompletion
18
19
  from camel.utils import LiteLLMTokenCounter
19
20
 
20
- if TYPE_CHECKING:
21
- from litellm.utils import CustomStreamWrapper, ModelResponse
22
-
23
21
 
24
22
  class LiteLLMModel:
25
23
  r"""Constructor for LiteLLM backend with OpenAI compatibility."""
26
24
 
27
- # NOTE: Currently "stream": True is not supported with LiteLLM due to the
28
- # limitation of the current camel design.
25
+ # NOTE: Currently stream mode is not supported.
29
26
 
30
27
  def __init__(
31
- self, model_type: str, model_config_dict: Dict[str, Any]
28
+ self,
29
+ model_type: str,
30
+ model_config_dict: Dict[str, Any],
31
+ api_key: Optional[str] = None,
32
+ url: Optional[str] = None,
32
33
  ) -> None:
33
34
  r"""Constructor for LiteLLM backend.
34
35
 
@@ -37,12 +38,48 @@ class LiteLLMModel:
37
38
  such as GPT-3.5-turbo, Claude-2, etc.
38
39
  model_config_dict (Dict[str, Any]): A dictionary of parameters for
39
40
  the model configuration.
41
+ api_key (Optional[str]): The API key for authenticating with the
42
+ model service. (default: :obj:`None`)
43
+ url (Optional[str]): The url to the model service. (default:
44
+ :obj:`None`)
40
45
  """
41
46
  self.model_type = model_type
42
47
  self.model_config_dict = model_config_dict
43
48
  self._client = None
44
49
  self._token_counter: Optional[LiteLLMTokenCounter] = None
45
50
  self.check_model_config()
51
+ self._url = url
52
+ self._api_key = api_key
53
+
54
+ def _convert_response_from_litellm_to_openai(
55
+ self, response
56
+ ) -> ChatCompletion:
57
+ r"""Converts a response from the LiteLLM format to the OpenAI format.
58
+
59
+ Parameters:
60
+ response (LiteLLMResponse): The response object from LiteLLM.
61
+
62
+ Returns:
63
+ ChatCompletion: The response object in OpenAI's format.
64
+ """
65
+ return ChatCompletion.construct(
66
+ id=response.id,
67
+ choices=[
68
+ {
69
+ "index": response.choices[0].index,
70
+ "message": {
71
+ "role": response.choices[0].message.role,
72
+ "content": response.choices[0].message.content,
73
+ },
74
+ "finish_reason": response.choices[0].finish_reason,
75
+ }
76
+ ],
77
+ created=response.created,
78
+ model=response.model,
79
+ object=response.object,
80
+ system_fingerprint=response.system_fingerprint,
81
+ usage=response.usage,
82
+ )
46
83
 
47
84
  @property
48
85
  def client(self):
@@ -67,7 +104,7 @@ class LiteLLMModel:
67
104
  def run(
68
105
  self,
69
106
  messages: List[OpenAIMessage],
70
- ) -> Union['ModelResponse', 'CustomStreamWrapper']:
107
+ ) -> ChatCompletion:
71
108
  r"""Runs inference of LiteLLM chat completion.
72
109
 
73
110
  Args:
@@ -75,15 +112,16 @@ class LiteLLMModel:
75
112
  in OpenAI format.
76
113
 
77
114
  Returns:
78
- Union[ModelResponse, CustomStreamWrapper]:
79
- `ModelResponse` in the non-stream mode, or
80
- `CustomStreamWrapper` in the stream mode.
115
+ ChatCompletion
81
116
  """
82
117
  response = self.client(
118
+ api_key=self._api_key,
119
+ base_url=self._url,
83
120
  model=self.model_type,
84
121
  messages=messages,
85
122
  **self.model_config_dict,
86
123
  )
124
+ response = self._convert_response_from_litellm_to_openai(response)
87
125
  return response
88
126
 
89
127
  def check_model_config(self):
@@ -100,13 +138,3 @@ class LiteLLMModel:
100
138
  f"Unexpected argument `{param}` is "
101
139
  "input into LiteLLM model backend."
102
140
  )
103
-
104
- @property
105
- def stream(self) -> bool:
106
- r"""Returns whether the model is in stream mode, which sends partial
107
- results each time.
108
-
109
- Returns:
110
- bool: Whether the model is in stream mode.
111
- """
112
- return self.model_config_dict.get('stream', False)
@@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, Union
15
15
 
16
16
  from camel.models.anthropic_model import AnthropicModel
17
17
  from camel.models.base_model import BaseModelBackend
18
+ from camel.models.gemini_model import GeminiModel
18
19
  from camel.models.litellm_model import LiteLLMModel
19
20
  from camel.models.ollama_model import OllamaModel
20
21
  from camel.models.open_source_model import OpenSourceModel
@@ -59,7 +60,6 @@ class ModelFactory:
59
60
  BaseModelBackend: The initialized backend.
60
61
  """
61
62
  model_class: Any
62
-
63
63
  if isinstance(model_type, ModelType):
64
64
  if model_platform.is_open_source and model_type.is_open_source:
65
65
  model_class = OpenSourceModel
@@ -70,6 +70,8 @@ class ModelFactory:
70
70
  model_class = AnthropicModel
71
71
  elif model_platform.is_zhipuai and model_type.is_zhipuai:
72
72
  model_class = ZhipuAIModel
73
+ elif model_platform.is_gemini and model_type.is_gemini:
74
+ model_class = GeminiModel
73
75
  elif model_type == ModelType.STUB:
74
76
  model_class = StubModel
75
77
  else:
@@ -80,6 +82,7 @@ class ModelFactory:
80
82
  elif isinstance(model_type, str):
81
83
  if model_platform.is_ollama:
82
84
  model_class = OllamaModel
85
+ return model_class(model_type, model_config_dict, url)
83
86
  elif model_platform.is_litellm:
84
87
  model_class = LiteLLMModel
85
88
  else:
@@ -89,5 +92,4 @@ class ModelFactory:
89
92
  )
90
93
  else:
91
94
  raise ValueError(f"Invalid model type `{model_type}` provided.")
92
-
93
95
  return model_class(model_type, model_config_dict, api_key, url)
@@ -20,7 +20,7 @@ from camel.messages import OpenAIMessage
20
20
  from camel.types import ChatCompletion, ModelType
21
21
  from camel.utils import (
22
22
  BaseTokenCounter,
23
- model_api_key_required,
23
+ api_keys_required,
24
24
  )
25
25
 
26
26
 
@@ -33,6 +33,7 @@ class NemotronModel:
33
33
  self,
34
34
  model_type: ModelType,
35
35
  api_key: Optional[str] = None,
36
+ url: Optional[str] = None,
36
37
  ) -> None:
37
38
  r"""Constructor for Nvidia backend.
38
39
 
@@ -40,18 +41,25 @@ class NemotronModel:
40
41
  model_type (ModelType): Model for which a backend is created.
41
42
  api_key (Optional[str]): The API key for authenticating with the
42
43
  Nvidia service. (default: :obj:`None`)
44
+ url (Optional[str]): The url to the Nvidia service. (default:
45
+ :obj:`None`)
43
46
  """
44
47
  self.model_type = model_type
45
- url = os.environ.get('NVIDIA_API_BASE_URL', None)
48
+ self._url = url or os.environ.get("NVIDIA_API_BASE_URL")
46
49
  self._api_key = api_key or os.environ.get("NVIDIA_API_KEY")
47
- if not url or not self._api_key:
48
- raise ValueError("The NVIDIA API base url and key should be set.")
50
+ if not self._url or not self._api_key:
51
+ raise ValueError(
52
+ "NVIDIA_API_BASE_URL and NVIDIA_API_KEY should be set."
53
+ )
49
54
  self._client = OpenAI(
50
- timeout=60, max_retries=3, base_url=url, api_key=self._api_key
55
+ timeout=60,
56
+ max_retries=3,
57
+ base_url=self._url,
58
+ api_key=self._api_key,
51
59
  )
52
60
  self._token_counter: Optional[BaseTokenCounter] = None
53
61
 
54
- @model_api_key_required
62
+ @api_keys_required("NVIDIA_API_KEY")
55
63
  def run(
56
64
  self,
57
65
  messages: List[OpenAIMessage],
@@ -11,12 +11,11 @@
11
11
  # See the License for the specific language governing permissions and
12
12
  # limitations under the License.
13
13
  # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
- import os
15
14
  from typing import Any, Dict, List, Optional, Union
16
15
 
17
16
  from openai import OpenAI, Stream
18
17
 
19
- from camel.configs import OPENAI_API_PARAMS
18
+ from camel.configs import OLLAMA_API_PARAMS
20
19
  from camel.messages import OpenAIMessage
21
20
  from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
22
21
  from camel.utils import BaseTokenCounter, OpenAITokenCounter
@@ -25,39 +24,34 @@ from camel.utils import BaseTokenCounter, OpenAITokenCounter
25
24
  class OllamaModel:
26
25
  r"""Ollama service interface."""
27
26
 
28
- # NOTE: Current `ModelType and `TokenCounter` desigen is not suitable,
29
- # stream mode is not supported
30
-
31
27
  def __init__(
32
28
  self,
33
29
  model_type: str,
34
30
  model_config_dict: Dict[str, Any],
35
- api_key: Optional[str] = None,
36
31
  url: Optional[str] = None,
37
32
  ) -> None:
38
33
  r"""Constructor for Ollama backend with OpenAI compatibility.
39
34
 
35
+ # Reference: https://github.com/ollama/ollama/blob/main/docs/openai.md
36
+
40
37
  Args:
41
38
  model_type (str): Model for which a backend is created.
42
39
  model_config_dict (Dict[str, Any]): A dictionary that will
43
40
  be fed into openai.ChatCompletion.create().
44
- api_key (Optional[str]): The API key for authenticating with the
45
- model service. (default: :obj:`None`)
46
- url (Optional[str]): The url to the model service.
41
+ url (Optional[str]): The url to the model service. (default:
42
+ :obj:`None`)
47
43
  """
48
44
  self.model_type = model_type
49
45
  self.model_config_dict = model_config_dict
50
- self._url = url or os.environ.get('OPENAI_API_BASE_URL')
51
- self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
52
46
  # Use OpenAI cilent as interface call Ollama
53
- # Reference: https://github.com/ollama/ollama/blob/main/docs/openai.md
54
47
  self._client = OpenAI(
55
48
  timeout=60,
56
49
  max_retries=3,
57
- base_url=self._url,
58
- api_key=self._api_key,
50
+ base_url=url,
51
+ api_key="ollama", # required but ignored
59
52
  )
60
53
  self._token_counter: Optional[BaseTokenCounter] = None
54
+ self.check_model_config()
61
55
 
62
56
  @property
63
57
  def token_counter(self) -> BaseTokenCounter:
@@ -74,17 +68,17 @@ class OllamaModel:
74
68
 
75
69
  def check_model_config(self):
76
70
  r"""Check whether the model configuration contains any
77
- unexpected arguments to OpenAI API.
71
+ unexpected arguments to Ollama API.
78
72
 
79
73
  Raises:
80
74
  ValueError: If the model configuration dictionary contains any
81
75
  unexpected arguments to OpenAI API.
82
76
  """
83
77
  for param in self.model_config_dict:
84
- if param not in OPENAI_API_PARAMS:
78
+ if param not in OLLAMA_API_PARAMS:
85
79
  raise ValueError(
86
80
  f"Unexpected argument `{param}` is "
87
- "input into OpenAI model backend."
81
+ "input into Ollama model backend."
88
82
  )
89
83
 
90
84
  def run(
@@ -25,10 +25,18 @@ class OpenAIAudioModels:
25
25
 
26
26
  def __init__(
27
27
  self,
28
+ api_key: Optional[str] = None,
29
+ url: Optional[str] = None,
28
30
  ) -> None:
29
31
  r"""Initialize an instance of OpenAI."""
30
- url = os.environ.get('OPENAI_API_BASE_URL')
31
- self._client = OpenAI(timeout=120, max_retries=3, base_url=url)
32
+ self._url = url or os.environ.get("OPENAI_API_BASE_URL")
33
+ self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
34
+ self._client = OpenAI(
35
+ timeout=120,
36
+ max_retries=3,
37
+ base_url=self._url,
38
+ api_key=self._api_key,
39
+ )
32
40
 
33
41
  def text_to_speech(
34
42
  self,
@@ -23,7 +23,7 @@ from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
23
23
  from camel.utils import (
24
24
  BaseTokenCounter,
25
25
  OpenAITokenCounter,
26
- model_api_key_required,
26
+ api_keys_required,
27
27
  )
28
28
 
29
29
 
@@ -46,7 +46,8 @@ class OpenAIModel(BaseModelBackend):
46
46
  be fed into openai.ChatCompletion.create().
47
47
  api_key (Optional[str]): The API key for authenticating with the
48
48
  OpenAI service. (default: :obj:`None`)
49
- url (Optional[str]): The url to the OpenAI service.
49
+ url (Optional[str]): The url to the OpenAI service. (default:
50
+ :obj:`None`)
50
51
  """
51
52
  super().__init__(model_type, model_config_dict, api_key, url)
52
53
  self._url = url or os.environ.get("OPENAI_API_BASE_URL")
@@ -71,7 +72,7 @@ class OpenAIModel(BaseModelBackend):
71
72
  self._token_counter = OpenAITokenCounter(self.model_type)
72
73
  return self._token_counter
73
74
 
74
- @model_api_key_required
75
+ @api_keys_required("OPENAI_API_KEY")
75
76
  def run(
76
77
  self,
77
78
  messages: List[OpenAIMessage],
@@ -17,14 +17,14 @@ from typing import Any, Dict, List, Optional, Union
17
17
 
18
18
  from openai import OpenAI, Stream
19
19
 
20
- from camel.configs import OPENAI_API_PARAMS
20
+ from camel.configs import ZHIPUAI_API_PARAMS
21
21
  from camel.messages import OpenAIMessage
22
22
  from camel.models import BaseModelBackend
23
23
  from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
24
24
  from camel.utils import (
25
25
  BaseTokenCounter,
26
26
  OpenAITokenCounter,
27
- model_api_key_required,
27
+ api_keys_required,
28
28
  )
29
29
 
30
30
 
@@ -47,10 +47,16 @@ class ZhipuAIModel(BaseModelBackend):
47
47
  be fed into openai.ChatCompletion.create().
48
48
  api_key (Optional[str]): The API key for authenticating with the
49
49
  ZhipuAI service. (default: :obj:`None`)
50
+ url (Optional[str]): The url to the ZhipuAI service. (default:
51
+ :obj:`None`)
50
52
  """
51
53
  super().__init__(model_type, model_config_dict)
52
54
  self._url = url or os.environ.get("ZHIPUAI_API_BASE_URL")
53
55
  self._api_key = api_key or os.environ.get("ZHIPUAI_API_KEY")
56
+ if not self._url or not self._api_key:
57
+ raise ValueError(
58
+ "ZHIPUAI_API_BASE_URL and ZHIPUAI_API_KEY should be set."
59
+ )
54
60
  self._client = OpenAI(
55
61
  timeout=60,
56
62
  max_retries=3,
@@ -59,7 +65,7 @@ class ZhipuAIModel(BaseModelBackend):
59
65
  )
60
66
  self._token_counter: Optional[BaseTokenCounter] = None
61
67
 
62
- @model_api_key_required
68
+ @api_keys_required("ZHIPUAI_API_KEY")
63
69
  def run(
64
70
  self,
65
71
  messages: List[OpenAIMessage],
@@ -104,13 +110,13 @@ class ZhipuAIModel(BaseModelBackend):
104
110
 
105
111
  Raises:
106
112
  ValueError: If the model configuration dictionary contains any
107
- unexpected arguments to OpenAI API.
113
+ unexpected arguments to ZhipuAI API.
108
114
  """
109
115
  for param in self.model_config_dict:
110
- if param not in OPENAI_API_PARAMS:
116
+ if param not in ZHIPUAI_API_PARAMS:
111
117
  raise ValueError(
112
118
  f"Unexpected argument `{param}` is "
113
- "input into OpenAI model backend."
119
+ "input into ZhipuAI model backend."
114
120
  )
115
121
  pass
116
122
 
@@ -17,6 +17,7 @@ import numpy as np
17
17
 
18
18
  from camel.loaders import UnstructuredIO
19
19
  from camel.retrievers import BaseRetriever
20
+ from camel.utils import dependencies_required
20
21
 
21
22
  DEFAULT_TOP_K_RESULTS = 1
22
23
 
@@ -40,16 +41,10 @@ class BM25Retriever(BaseRetriever):
40
41
  https://github.com/dorianbrown/rank_bm25
41
42
  """
42
43
 
44
+ @dependencies_required('rank_bm25')
43
45
  def __init__(self) -> None:
44
46
  r"""Initializes the BM25Retriever."""
45
-
46
- try:
47
- from rank_bm25 import BM25Okapi
48
- except ImportError as e:
49
- raise ImportError(
50
- "Package `rank_bm25` not installed, install by running 'pip "
51
- "install rank_bm25'"
52
- ) from e
47
+ from rank_bm25 import BM25Okapi
53
48
 
54
49
  self.bm25: BM25Okapi = None
55
50
  self.content_input_path: str = ""
@@ -15,6 +15,7 @@ import os
15
15
  from typing import Any, Dict, List, Optional
16
16
 
17
17
  from camel.retrievers import BaseRetriever
18
+ from camel.utils import dependencies_required
18
19
 
19
20
  DEFAULT_TOP_K_RESULTS = 1
20
21
 
@@ -32,6 +33,7 @@ class CohereRerankRetriever(BaseRetriever):
32
33
  https://txt.cohere.com/rerank/
33
34
  """
34
35
 
36
+ @dependencies_required('cohere')
35
37
  def __init__(
36
38
  self,
37
39
  model_name: str = "rerank-multilingual-v2.0",
@@ -56,11 +58,7 @@ class CohereRerankRetriever(BaseRetriever):
56
58
  ValueError: If the API key is neither passed as an argument nor
57
59
  set in the environment variable.
58
60
  """
59
-
60
- try:
61
- import cohere
62
- except ImportError as e:
63
- raise ImportError("Package 'cohere' is not installed") from e
61
+ import cohere
64
62
 
65
63
  try:
66
64
  self.api_key = api_key or os.environ["COHERE_API_KEY"]
@@ -17,6 +17,7 @@ from .graph_storages.neo4j_graph import Neo4jGraph
17
17
  from .key_value_storages.base import BaseKeyValueStorage
18
18
  from .key_value_storages.in_memory import InMemoryKeyValueStorage
19
19
  from .key_value_storages.json import JsonStorage
20
+ from .key_value_storages.redis import RedisStorage
20
21
  from .vectordb_storages.base import (
21
22
  BaseVectorStorage,
22
23
  VectorDBQuery,
@@ -30,6 +31,7 @@ __all__ = [
30
31
  'BaseKeyValueStorage',
31
32
  'InMemoryKeyValueStorage',
32
33
  'JsonStorage',
34
+ 'RedisStorage',
33
35
  'VectorRecord',
34
36
  'BaseVectorStorage',
35
37
  'VectorDBQuery',
@@ -16,6 +16,7 @@ from hashlib import md5
16
16
  from typing import Any, Dict, List, Optional
17
17
 
18
18
  from camel.storages.graph_storages import BaseGraphStorage, GraphElement
19
+ from camel.utils import dependencies_required
19
20
 
20
21
  logger = logging.getLogger(__name__)
21
22
 
@@ -81,6 +82,7 @@ class Neo4jGraph(BaseGraphStorage):
81
82
  than `LIST_LIMIT` elements from results. Defaults to `False`.
82
83
  """
83
84
 
85
+ @dependencies_required('neo4j')
84
86
  def __init__(
85
87
  self,
86
88
  url: str,
@@ -91,13 +93,7 @@ class Neo4jGraph(BaseGraphStorage):
91
93
  truncate: bool = False,
92
94
  ) -> None:
93
95
  r"""Create a new Neo4j graph instance."""
94
- try:
95
- import neo4j
96
- except ImportError:
97
- raise ValueError(
98
- "Could not import neo4j python package. "
99
- "Please install it with `pip install neo4j`."
100
- )
96
+ import neo4j
101
97
 
102
98
  self.driver = neo4j.GraphDatabase.driver(
103
99
  url, auth=(username, password)
@@ -15,9 +15,11 @@
15
15
  from .base import BaseKeyValueStorage
16
16
  from .in_memory import InMemoryKeyValueStorage
17
17
  from .json import JsonStorage
18
+ from .redis import RedisStorage
18
19
 
19
20
  __all__ = [
20
21
  'BaseKeyValueStorage',
21
22
  'InMemoryKeyValueStorage',
22
23
  'JsonStorage',
24
+ 'RedisStorage',
23
25
  ]
@@ -0,0 +1,169 @@
1
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2
+ # Licensed under the Apache License, Version 2.0 (the “License”);
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an “AS IS” BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+
15
+ import asyncio
16
+ import json
17
+ import logging
18
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
19
+
20
+ from camel.storages.key_value_storages import BaseKeyValueStorage
21
+
22
+ if TYPE_CHECKING:
23
+ from redis.asyncio import Redis
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class RedisStorage(BaseKeyValueStorage):
29
+ r"""A concrete implementation of the :obj:`BaseCacheStorage` using Redis as
30
+ the backend. This is suitable for distributed cache systems that require
31
+ persistence and high availability.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ sid: str,
37
+ url: str = "redis://localhost:6379",
38
+ loop: Optional[asyncio.AbstractEventLoop] = None,
39
+ **kwargs,
40
+ ) -> None:
41
+ r"""Initializes the RedisStorage instance with the provided URL and
42
+ options.
43
+
44
+ Args:
45
+ sid (str): The ID for the storage instance to identify the
46
+ record space.
47
+ url (str): The URL for connecting to the Redis server.
48
+ **kwargs: Additional keyword arguments for Redis client
49
+ configuration.
50
+
51
+ Raises:
52
+ ImportError: If the `redis.asyncio` module is not installed.
53
+ """
54
+ try:
55
+ import redis.asyncio as aredis
56
+ except ImportError as exc:
57
+ logger.error(
58
+ "Please install `redis` first. You can install it by "
59
+ "running `pip install redis`."
60
+ )
61
+ raise exc
62
+
63
+ self._client: Optional[aredis.Redis] = None
64
+ self._url = url
65
+ self._sid = sid
66
+ self._loop = loop or asyncio.get_event_loop()
67
+
68
+ self._create_client(**kwargs)
69
+
70
+ def __enter__(self):
71
+ return self
72
+
73
+ def __exit__(self, exc_type, exc, tb):
74
+ self._run_async(self.close())
75
+
76
+ async def close(self) -> None:
77
+ r"""Closes the Redis client asynchronously."""
78
+ if self._client:
79
+ await self._client.close()
80
+
81
+ def _create_client(self, **kwargs) -> None:
82
+ r"""Creates the Redis client with the provided URL and options.
83
+
84
+ Args:
85
+ **kwargs: Additional keyword arguments for Redis client
86
+ configuration.
87
+ """
88
+ import redis.asyncio as aredis
89
+
90
+ self._client = aredis.from_url(self._url, **kwargs)
91
+
92
+ @property
93
+ def client(self) -> Optional["Redis"]:
94
+ r"""Returns the Redis client instance.
95
+
96
+ Returns:
97
+ redis.asyncio.Redis: The Redis client instance.
98
+ """
99
+ return self._client
100
+
101
+ def save(
102
+ self, records: List[Dict[str, Any]], expire: Optional[int] = None
103
+ ) -> None:
104
+ r"""Saves a batch of records to the key-value storage system."""
105
+ try:
106
+ self._run_async(self._async_save(records, expire))
107
+ except Exception as e:
108
+ logger.error(f"Error in save: {e}")
109
+
110
+ def load(self) -> List[Dict[str, Any]]:
111
+ r"""Loads all stored records from the key-value storage system.
112
+
113
+ Returns:
114
+ List[Dict[str, Any]]: A list of dictionaries, where each dictionary
115
+ represents a stored record.
116
+ """
117
+ try:
118
+ return self._run_async(self._async_load())
119
+ except Exception as e:
120
+ logger.error(f"Error in load: {e}")
121
+ return []
122
+
123
+ def clear(self) -> None:
124
+ r"""Removes all records from the key-value storage system."""
125
+ try:
126
+ self._run_async(self._async_clear())
127
+ except Exception as e:
128
+ logger.error(f"Error in clear: {e}")
129
+
130
+ async def _async_save(
131
+ self, records: List[Dict[str, Any]], expire: Optional[int] = None
132
+ ) -> None:
133
+ if self._client is None:
134
+ raise ValueError("Redis client is not initialized")
135
+ try:
136
+ value = json.dumps(records)
137
+ if expire:
138
+ await self._client.setex(self._sid, expire, value)
139
+ else:
140
+ await self._client.set(self._sid, value)
141
+ except Exception as e:
142
+ logger.error(f"Error saving records: {e}")
143
+
144
+ async def _async_load(self) -> List[Dict[str, Any]]:
145
+ if self._client is None:
146
+ raise ValueError("Redis client is not initialized")
147
+ try:
148
+ value = await self._client.get(self._sid)
149
+ if value:
150
+ return json.loads(value)
151
+ return []
152
+ except Exception as e:
153
+ logger.error(f"Error loading records: {e}")
154
+ return []
155
+
156
+ async def _async_clear(self) -> None:
157
+ if self._client is None:
158
+ raise ValueError("Redis client is not initialized")
159
+ try:
160
+ await self._client.delete(self._sid)
161
+ except Exception as e:
162
+ logger.error(f"Error clearing records: {e}")
163
+
164
+ def _run_async(self, coro):
165
+ if not self._loop.is_running():
166
+ return self._loop.run_until_complete(coro)
167
+ else:
168
+ future = asyncio.run_coroutine_threadsafe(coro, self._loop)
169
+ return future.result()