camel-ai 0.1.5.7__py3-none-any.whl → 0.1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of camel-ai might be problematic.

Files changed (44)
  1. camel/__init__.py +1 -1
  2. camel/agents/chat_agent.py +2 -2
  3. camel/agents/critic_agent.py +1 -1
  4. camel/agents/deductive_reasoner_agent.py +4 -4
  5. camel/agents/embodied_agent.py +1 -1
  6. camel/agents/knowledge_graph_agent.py +2 -2
  7. camel/agents/role_assignment_agent.py +1 -1
  8. camel/agents/search_agent.py +4 -5
  9. camel/agents/task_agent.py +5 -5
  10. camel/configs/__init__.py +9 -0
  11. camel/configs/groq_config.py +119 -0
  12. camel/configs/mistral_config.py +81 -0
  13. camel/configs/ollama_config.py +1 -1
  14. camel/configs/vllm_config.py +103 -0
  15. camel/embeddings/__init__.py +2 -0
  16. camel/embeddings/mistral_embedding.py +89 -0
  17. camel/interpreters/__init__.py +2 -0
  18. camel/interpreters/ipython_interpreter.py +167 -0
  19. camel/models/__init__.py +8 -0
  20. camel/models/anthropic_model.py +7 -2
  21. camel/models/azure_openai_model.py +152 -0
  22. camel/models/base_model.py +5 -1
  23. camel/models/gemini_model.py +14 -2
  24. camel/models/groq_model.py +131 -0
  25. camel/models/litellm_model.py +10 -4
  26. camel/models/mistral_model.py +169 -0
  27. camel/models/model_factory.py +30 -3
  28. camel/models/ollama_model.py +5 -2
  29. camel/models/open_source_model.py +11 -3
  30. camel/models/openai_model.py +7 -2
  31. camel/models/stub_model.py +4 -4
  32. camel/models/vllm_model.py +138 -0
  33. camel/models/zhipuai_model.py +7 -3
  34. camel/prompts/__init__.py +2 -2
  35. camel/prompts/task_prompt_template.py +4 -4
  36. camel/prompts/{descripte_video_prompt.py → video_description_prompt.py} +1 -1
  37. camel/retrievers/auto_retriever.py +2 -2
  38. camel/storages/graph_storages/neo4j_graph.py +5 -0
  39. camel/types/enums.py +152 -35
  40. camel/utils/__init__.py +2 -0
  41. camel/utils/token_counting.py +148 -40
  42. {camel_ai-0.1.5.7.dist-info → camel_ai-0.1.6.0.dist-info}/METADATA +42 -3
  43. {camel_ai-0.1.5.7.dist-info → camel_ai-0.1.6.0.dist-info}/RECORD +44 -35
  44. {camel_ai-0.1.5.7.dist-info → camel_ai-0.1.6.0.dist-info}/WHEEL +0 -0
camel/models/mistral_model.py ADDED
@@ -0,0 +1,169 @@
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ import os
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+ if TYPE_CHECKING:
+     from mistralai.models.chat_completion import ChatCompletionResponse
+
+ from camel.configs import MISTRAL_API_PARAMS
+ from camel.messages import OpenAIMessage
+ from camel.models import BaseModelBackend
+ from camel.types import ChatCompletion, ModelType
+ from camel.utils import (
+     BaseTokenCounter,
+     MistralTokenCounter,
+     api_keys_required,
+ )
+
+
+ class MistralModel(BaseModelBackend):
+     r"""Mistral API in a unified BaseModelBackend interface."""
+
+     # TODO: Support tool calling.
+
+     def __init__(
+         self,
+         model_type: ModelType,
+         model_config_dict: Dict[str, Any],
+         api_key: Optional[str] = None,
+         url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
+     ) -> None:
+         r"""Constructor for Mistral backend.
+
+         Args:
+             model_type (ModelType): Model for which a backend is created,
+                 one of MISTRAL_* series.
+             model_config_dict (Dict[str, Any]): A dictionary that will
+                 be fed into `MistralClient.chat`.
+             api_key (Optional[str]): The API key for authenticating with the
+                 mistral service. (default: :obj:`None`)
+             url (Optional[str]): The url to the mistral service.
+             token_counter (Optional[BaseTokenCounter]): Token counter to use
+                 for the model. If not provided, `MistralTokenCounter` will be
+                 used.
+         """
+         super().__init__(
+             model_type, model_config_dict, api_key, url, token_counter
+         )
+         self._api_key = api_key or os.environ.get("MISTRAL_API_KEY")
+
+         from mistralai.client import MistralClient
+
+         self._client = MistralClient(api_key=self._api_key)
+         self._token_counter: Optional[BaseTokenCounter] = None
+
+     def _convert_response_from_mistral_to_openai(
+         self, response: 'ChatCompletionResponse'
+     ) -> ChatCompletion:
+         tool_calls = None
+         if response.choices[0].message.tool_calls is not None:
+             tool_calls = [
+                 dict(
+                     id=tool_call.id,
+                     function={
+                         "name": tool_call.function.name,
+                         "arguments": tool_call.function.arguments,
+                     },
+                     type=tool_call.type.value,
+                 )
+                 for tool_call in response.choices[0].message.tool_calls
+             ]
+
+         obj = ChatCompletion.construct(
+             id=response.id,
+             choices=[
+                 dict(
+                     index=response.choices[0].index,
+                     message={
+                         "role": response.choices[0].message.role,
+                         "content": response.choices[0].message.content,
+                         "tool_calls": tool_calls,
+                     },
+                     finish_reason=response.choices[0].finish_reason.value
+                     if response.choices[0].finish_reason
+                     else None,
+                 )
+             ],
+             created=response.created,
+             model=response.model,
+             object="chat.completion",
+             usage=response.usage,
+         )
+
+         return obj
+
+     @property
+     def token_counter(self) -> BaseTokenCounter:
+         r"""Initialize the token counter for the model backend.
+
+         Returns:
+             BaseTokenCounter: The token counter following the model's
+                 tokenization style.
+         """
+         if not self._token_counter:
+             self._token_counter = MistralTokenCounter(
+                 model_type=self.model_type
+             )
+         return self._token_counter
+
+     @api_keys_required("MISTRAL_API_KEY")
+     def run(
+         self,
+         messages: List[OpenAIMessage],
+     ) -> ChatCompletion:
+         r"""Runs inference of Mistral chat completion.
+
+         Args:
+             messages (List[OpenAIMessage]): Message list with the chat history
+                 in OpenAI API format.
+
+         Returns:
+             ChatCompletion
+         """
+         response = self._client.chat(
+             messages=messages,
+             model=self.model_type.value,
+             **self.model_config_dict,
+         )
+
+         response = self._convert_response_from_mistral_to_openai(response)  # type:ignore[assignment]
+
+         return response  # type:ignore[return-value]
+
+     def check_model_config(self):
+         r"""Check whether the model configuration contains any
+         unexpected arguments to Mistral API.
+
+         Raises:
+             ValueError: If the model configuration dictionary contains any
+                 unexpected arguments to Mistral API.
+         """
+         for param in self.model_config_dict:
+             if param not in MISTRAL_API_PARAMS:
+                 raise ValueError(
+                     f"Unexpected argument `{param}` is "
+                     "input into Mistral model backend."
+                 )
+
+     @property
+     def stream(self) -> bool:
+         r"""Returns whether the model is in stream mode, which sends partial
+         results each time. Mistral doesn't support stream mode.
+
+         Returns:
+             bool: Whether the model is in stream mode.
+         """
+         return False
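
The new MistralModel wraps `MistralClient.chat` and converts the Mistral response into the OpenAI `ChatCompletion` shape the rest of camel expects. A minimal usage sketch, not an official snippet; the `ModelType.MISTRAL_LARGE` member name is an assumption based on the expanded camel/types/enums.py, so check the release for the exact MISTRAL_* names:

    import os

    from camel.models.mistral_model import MistralModel
    from camel.types import ModelType

    # Assumes a MISTRAL_* member was added to ModelType in this release,
    # and that MISTRAL_API_KEY is set (run() is guarded by
    # @api_keys_required("MISTRAL_API_KEY")).
    model = MistralModel(
        model_type=ModelType.MISTRAL_LARGE,
        model_config_dict={"temperature": 0.2},
        api_key=os.environ.get("MISTRAL_API_KEY"),
    )
    # Messages use the OpenAI chat format; the backend converts the
    # Mistral response back into a ChatCompletion object whose choices
    # are plain dicts (ChatCompletion.construct does not coerce them).
    completion = model.run([{"role": "user", "content": "Say hello."}])
    print(completion.choices[0]["message"]["content"])
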
camel/models/model_factory.py CHANGED
@@ -14,15 +14,20 @@
  from typing import Any, Dict, Optional, Union
 
  from camel.models.anthropic_model import AnthropicModel
+ from camel.models.azure_openai_model import AzureOpenAIModel
  from camel.models.base_model import BaseModelBackend
  from camel.models.gemini_model import GeminiModel
+ from camel.models.groq_model import GroqModel
  from camel.models.litellm_model import LiteLLMModel
+ from camel.models.mistral_model import MistralModel
  from camel.models.ollama_model import OllamaModel
  from camel.models.open_source_model import OpenSourceModel
  from camel.models.openai_model import OpenAIModel
  from camel.models.stub_model import StubModel
+ from camel.models.vllm_model import VLLMModel
  from camel.models.zhipuai_model import ZhipuAIModel
  from camel.types import ModelPlatformType, ModelType
+ from camel.utils import BaseTokenCounter
 
 
  class ModelFactory:
@@ -37,6 +42,7 @@ class ModelFactory:
          model_platform: ModelPlatformType,
          model_type: Union[ModelType, str],
          model_config_dict: Dict,
+         token_counter: Optional[BaseTokenCounter] = None,
          api_key: Optional[str] = None,
          url: Optional[str] = None,
      ) -> BaseModelBackend:
@@ -49,6 +55,10 @@
              created can be a `str` for open source platforms.
          model_config_dict (Dict): A dictionary that will be fed into
              the backend constructor.
+         token_counter (Optional[BaseTokenCounter]): Token counter to use
+             for the model. If not provided, `OpenAITokenCounter(ModelType.GPT_3_5_TURBO)`
+             will be used if the model platform doesn't provide an official
+             token counter.
          api_key (Optional[str]): The API key for authenticating with the
              model service.
          url (Optional[str]): The url to the model service.
@@ -63,15 +73,23 @@
          if isinstance(model_type, ModelType):
              if model_platform.is_open_source and model_type.is_open_source:
                  model_class = OpenSourceModel
-                 return model_class(model_type, model_config_dict, url)
+                 return model_class(
+                     model_type, model_config_dict, url, token_counter
+                 )
              if model_platform.is_openai and model_type.is_openai:
                  model_class = OpenAIModel
+             elif model_platform.is_azure and model_type.is_azure_openai:
+                 model_class = AzureOpenAIModel
              elif model_platform.is_anthropic and model_type.is_anthropic:
                  model_class = AnthropicModel
+             elif model_type.is_groq:
+                 model_class = GroqModel
              elif model_platform.is_zhipuai and model_type.is_zhipuai:
                  model_class = ZhipuAIModel
              elif model_platform.is_gemini and model_type.is_gemini:
                  model_class = GeminiModel
+             elif model_platform.is_mistral and model_type.is_mistral:
+                 model_class = MistralModel
              elif model_type == ModelType.STUB:
                  model_class = StubModel
              else:
@@ -82,7 +100,14 @@
          elif isinstance(model_type, str):
              if model_platform.is_ollama:
                  model_class = OllamaModel
-                 return model_class(model_type, model_config_dict, url)
+                 return model_class(
+                     model_type, model_config_dict, url, token_counter
+                 )
+             elif model_platform.is_vllm:
+                 model_class = VLLMModel
+                 return model_class(
+                     model_type, model_config_dict, url, api_key, token_counter
+                 )
              elif model_platform.is_litellm:
                  model_class = LiteLLMModel
              else:
@@ -92,4 +117,6 @@
                  )
          else:
              raise ValueError(f"Invalid model type `{model_type}` provided.")
-         return model_class(model_type, model_config_dict, api_key, url)
+         return model_class(
+             model_type, model_config_dict, api_key, url, token_counter
+         )
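
`ModelFactory.create` now accepts an optional `token_counter` and routes the new Azure OpenAI, Groq, Mistral, and vLLM backends. A hedged sketch of the new call shape; `ModelPlatformType.OPENAI` is an assumed member name (the enum additions live in camel/types/enums.py):

    from camel.models.model_factory import ModelFactory
    from camel.types import ModelPlatformType, ModelType

    # token_counter is optional; per the docstring, backends without an
    # official counter fall back to OpenAITokenCounter(ModelType.GPT_3_5_TURBO).
    # OPENAI_API_KEY must be set in the environment for this backend.
    backend = ModelFactory.create(
        model_platform=ModelPlatformType.OPENAI,  # assumed member name
        model_type=ModelType.GPT_4,
        model_config_dict={"temperature": 0.0},
        token_counter=None,
    )
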
camel/models/ollama_model.py CHANGED
@@ -29,6 +29,7 @@ class OllamaModel:
          model_type: str,
          model_config_dict: Dict[str, Any],
          url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
      ) -> None:
          r"""Constructor for Ollama backend with OpenAI compatibility.
 
@@ -40,6 +41,9 @@
                  be fed into openai.ChatCompletion.create().
              url (Optional[str]): The url to the model service. (default:
                  :obj:`None`)
+             token_counter (Optional[BaseTokenCounter]): Token counter to use
+                 for the model. If not provided, `OpenAITokenCounter(ModelType.GPT_3_5_TURBO)`
+                 will be used.
          """
          self.model_type = model_type
          self.model_config_dict = model_config_dict
@@ -50,7 +54,7 @@
              base_url=url,
              api_key="ollama",  # required but ignored
          )
-         self._token_counter: Optional[BaseTokenCounter] = None
+         self._token_counter = token_counter
          self.check_model_config()
 
      @property
@@ -61,7 +65,6 @@
              BaseTokenCounter: The token counter following the model's
                  tokenization style.
          """
-         # NOTE: Use OpenAITokenCounter temporarily
          if not self._token_counter:
              self._token_counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
          return self._token_counter
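
OllamaModel now takes the token counter through its constructor instead of only creating one lazily; the `OpenAITokenCounter` fallback remains. A sketch of injecting a custom counter (the local URL is an assumption about where an Ollama server is listening):

    from camel.models.ollama_model import OllamaModel
    from camel.types import ModelType
    from camel.utils import OpenAITokenCounter

    ollama = OllamaModel(
        model_type="llama3",  # any model name served by the Ollama instance
        model_config_dict={},
        url="http://localhost:11434/v1",  # assumed local Ollama endpoint
        token_counter=OpenAITokenCounter(ModelType.GPT_4),  # overrides fallback
    )
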
camel/models/open_source_model.py CHANGED
@@ -19,7 +19,10 @@ from camel.configs import OPENAI_API_PARAMS
  from camel.messages import OpenAIMessage
  from camel.models import BaseModelBackend
  from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
- from camel.utils import BaseTokenCounter, OpenSourceTokenCounter
+ from camel.utils import (
+     BaseTokenCounter,
+     OpenSourceTokenCounter,
+ )
 
 
  class OpenSourceModel(BaseModelBackend):
@@ -33,6 +36,7 @@ class OpenSourceModel(BaseModelBackend):
          model_config_dict: Dict[str, Any],
          api_key: Optional[str] = None,
          url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
      ) -> None:
          r"""Constructor for model backends of Open-source models.
 
@@ -43,9 +47,13 @@
              api_key (Optional[str]): The API key for authenticating with the
                  model service. (ignored for open-source models)
              url (Optional[str]): The url to the model service.
+             token_counter (Optional[BaseTokenCounter]): Token counter to use
+                 for the model. If not provided, `OpenSourceTokenCounter` will
+                 be used.
          """
-         super().__init__(model_type, model_config_dict, api_key, url)
-         self._token_counter: Optional[BaseTokenCounter] = None
+         super().__init__(
+             model_type, model_config_dict, api_key, url, token_counter
+         )
 
          # Check whether the input model type is open-source
          if not model_type.is_open_source:
camel/models/openai_model.py CHANGED
@@ -36,6 +36,7 @@ class OpenAIModel(BaseModelBackend):
          model_config_dict: Dict[str, Any],
          api_key: Optional[str] = None,
          url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
      ) -> None:
          r"""Constructor for OpenAI backend.
 
@@ -48,8 +49,13 @@
              OpenAI service. (default: :obj:`None`)
          url (Optional[str]): The url to the OpenAI service. (default:
              :obj:`None`)
+         token_counter (Optional[BaseTokenCounter]): Token counter to use
+             for the model. If not provided, `OpenAITokenCounter` will
+             be used.
          """
-         super().__init__(model_type, model_config_dict, api_key, url)
+         super().__init__(
+             model_type, model_config_dict, api_key, url, token_counter
+         )
          self._url = url or os.environ.get("OPENAI_API_BASE_URL")
          self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
          self._client = OpenAI(
@@ -58,7 +64,6 @@
              base_url=self._url,
              api_key=self._api_key,
          )
-         self._token_counter: Optional[BaseTokenCounter] = None
 
      @property
      def token_counter(self) -> BaseTokenCounter:
camel/models/stub_model.py CHANGED
@@ -55,12 +55,12 @@ class StubModel(BaseModelBackend):
          model_config_dict: Dict[str, Any],
          api_key: Optional[str] = None,
          url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
      ) -> None:
          r"""All arguments are unused for the dummy model."""
-         super().__init__(model_type, model_config_dict, api_key, url)
-         self._token_counter: Optional[BaseTokenCounter] = None
-         self._api_key = api_key
-         self._url = url
+         super().__init__(
+             model_type, model_config_dict, api_key, url, token_counter
+         )
 
      @property
      def token_counter(self) -> BaseTokenCounter:
camel/models/vllm_model.py ADDED
@@ -0,0 +1,138 @@
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ from typing import Any, Dict, List, Optional, Union
+
+ from openai import OpenAI, Stream
+
+ from camel.configs import VLLM_API_PARAMS
+ from camel.messages import OpenAIMessage
+ from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
+ from camel.utils import BaseTokenCounter, OpenAITokenCounter
+
+
+ # flake8: noqa: E501
+ class VLLMModel:
+     r"""vLLM service interface."""
+
+     def __init__(
+         self,
+         model_type: str,
+         model_config_dict: Dict[str, Any],
+         url: Optional[str] = None,
+         api_key: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
+     ) -> None:
+         r"""Constructor for vLLM backend with OpenAI compatibility.
+
+         # Reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+         Args:
+             model_type (str): Model for which a backend is created.
+             model_config_dict (Dict[str, Any]): A dictionary that will
+                 be fed into openai.ChatCompletion.create().
+             url (Optional[str]): The url to the model service. (default:
+                 :obj:`None`)
+             api_key (Optional[str]): The API key for authenticating with the
+                 model service.
+             token_counter (Optional[BaseTokenCounter]): Token counter to use
+                 for the model. If not provided, `OpenAITokenCounter(ModelType.GPT_3_5_TURBO)`
+                 will be used.
+         """
+         self.model_type = model_type
+         self.model_config_dict = model_config_dict
+         # Use the OpenAI client as the interface to call vLLM
+         self._client = OpenAI(
+             timeout=60,
+             max_retries=3,
+             base_url=url,
+             api_key=api_key,
+         )
+         self._token_counter = token_counter
+         self.check_model_config()
+
+     @property
+     def token_counter(self) -> BaseTokenCounter:
+         r"""Initialize the token counter for the model backend.
+
+         Returns:
+             BaseTokenCounter: The token counter following the model's
+                 tokenization style.
+         """
+         if not self._token_counter:
+             self._token_counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
+         return self._token_counter
+
+     def check_model_config(self):
+         r"""Check whether the model configuration contains any
+         unexpected arguments to vLLM API.
+
+         Raises:
+             ValueError: If the model configuration dictionary contains any
+                 unexpected arguments to OpenAI API.
+         """
+         for param in self.model_config_dict:
+             if param not in VLLM_API_PARAMS:
+                 raise ValueError(
+                     f"Unexpected argument `{param}` is "
+                     "input into vLLM model backend."
+                 )
+
+     def run(
+         self,
+         messages: List[OpenAIMessage],
+     ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+         r"""Runs inference of OpenAI chat completion.
+
+         Args:
+             messages (List[OpenAIMessage]): Message list with the chat history
+                 in OpenAI API format.
+
+         Returns:
+             Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+                 `ChatCompletion` in the non-stream mode, or
+                 `Stream[ChatCompletionChunk]` in the stream mode.
+         """
+         response = self._client.chat.completions.create(
+             messages=messages,
+             model=self.model_type,
+             **self.model_config_dict,
+         )
+         return response
+
+     @property
+     def token_limit(self) -> int:
+         """Returns the maximum token limit for the given model.
+
+         Returns:
+             int: The maximum token limit for the given model.
+         """
+         max_tokens = self.model_config_dict.get("max_tokens")
+         if isinstance(max_tokens, int):
+             return max_tokens
+         print(
+             "Must set `max_tokens` as an integer in `model_config_dict` when"
+             " setting up the model. Using 4096 as default value."
+         )
+         return 4096
+
+     @property
+     def stream(self) -> bool:
+         r"""Returns whether the model is in stream mode, which sends partial
+         results each time.
+
+         Returns:
+             bool: Whether the model is in stream mode.
+         """
+         return self.model_config_dict.get('stream', False)
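
Like the Ollama backend, VLLMModel talks to an OpenAI-compatible server, so pointing the client at a running vLLM endpoint is all that is needed. A minimal sketch; the server address, served model name, and placeholder API key are assumptions about a local deployment:

    from camel.models.vllm_model import VLLMModel

    vllm = VLLMModel(
        model_type="meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical served model
        model_config_dict={"max_tokens": 1024},  # also read by token_limit
        url="http://localhost:8000/v1",  # assumed vLLM server address
        api_key="EMPTY",  # assumed placeholder; a default vLLM server ignores it
    )
    response = vllm.run([{"role": "user", "content": "Hello"}])
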
camel/models/zhipuai_model.py CHANGED
@@ -37,6 +37,7 @@ class ZhipuAIModel(BaseModelBackend):
          model_config_dict: Dict[str, Any],
          api_key: Optional[str] = None,
          url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
      ) -> None:
          r"""Constructor for ZhipuAI backend.
 
@@ -49,8 +50,13 @@
              ZhipuAI service. (default: :obj:`None`)
          url (Optional[str]): The url to the ZhipuAI service. (default:
              :obj:`None`)
+         token_counter (Optional[BaseTokenCounter]): Token counter to use
+             for the model. If not provided, `OpenAITokenCounter(ModelType.GPT_3_5_TURBO)`
+             will be used.
          """
-         super().__init__(model_type, model_config_dict)
+         super().__init__(
+             model_type, model_config_dict, api_key, url, token_counter
+         )
          self._url = url or os.environ.get("ZHIPUAI_API_BASE_URL")
          self._api_key = api_key or os.environ.get("ZHIPUAI_API_KEY")
          if not self._url or not self._api_key:
@@ -63,7 +69,6 @@
              api_key=self._api_key,
              base_url=self._url,
          )
-         self._token_counter: Optional[BaseTokenCounter] = None
 
      @api_keys_required("ZHIPUAI_API_KEY")
      def run(
@@ -100,7 +105,6 @@
          """
 
          if not self._token_counter:
-             # It's a temporary setting for token counter.
              self._token_counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
          return self._token_counter
 
camel/prompts/__init__.py CHANGED
@@ -14,7 +14,6 @@
  from .ai_society import AISocietyPromptTemplateDict
  from .base import CodePrompt, TextPrompt, TextPromptDict
  from .code import CodePromptTemplateDict
- from .descripte_video_prompt import DescriptionVideoPromptTemplateDict
  from .evaluation import EvaluationPromptTemplateDict
  from .generate_text_embedding_data import (
      GenerateTextEmbeddingDataPromptTemplateDict,
@@ -26,6 +25,7 @@ from .role_description_prompt_template import RoleDescriptionPromptTemplateDict
  from .solution_extraction import SolutionExtractionPromptTemplateDict
  from .task_prompt_template import TaskPromptTemplateDict
  from .translation import TranslationPromptTemplateDict
+ from .video_description_prompt import VideoDescriptionPromptTemplateDict
 
  __all__ = [
      'TextPrompt',
@@ -42,5 +42,5 @@ __all__ = [
      'SolutionExtractionPromptTemplateDict',
      'GenerateTextEmbeddingDataPromptTemplateDict',
      'ObjectRecognitionPromptTemplateDict',
-     'DescriptionVideoPromptTemplateDict',
+     'VideoDescriptionPromptTemplateDict',
  ]
camel/prompts/task_prompt_template.py CHANGED
@@ -18,9 +18,6 @@ from camel.prompts.ai_society import (
      TextPromptDict,
  )
  from camel.prompts.code import CodePromptTemplateDict
- from camel.prompts.descripte_video_prompt import (
-     DescriptionVideoPromptTemplateDict,
- )
  from camel.prompts.evaluation import (
      EvaluationPromptTemplateDict,
  )
@@ -38,6 +35,9 @@ from camel.prompts.solution_extraction import (
      SolutionExtractionPromptTemplateDict,
  )
  from camel.prompts.translation import TranslationPromptTemplateDict
+ from camel.prompts.video_description_prompt import (
+     VideoDescriptionPromptTemplateDict,
+ )
  from camel.types import TaskType
 
 
@@ -64,6 +64,6 @@ class TaskPromptTemplateDict(Dict[Any, TextPromptDict]):
          TaskType.ROLE_DESCRIPTION: RoleDescriptionPromptTemplateDict(),
          TaskType.OBJECT_RECOGNITION: ObjectRecognitionPromptTemplateDict(),  # noqa: E501
          TaskType.GENERATE_TEXT_EMBEDDING_DATA: GenerateTextEmbeddingDataPromptTemplateDict(),  # noqa: E501
-         TaskType.VIDEO_DESCRIPTION: DescriptionVideoPromptTemplateDict(),  # noqa: E501
+         TaskType.VIDEO_DESCRIPTION: VideoDescriptionPromptTemplateDict(),  # noqa: E501
      }
  )
camel/prompts/{descripte_video_prompt.py → video_description_prompt.py} RENAMED
@@ -18,7 +18,7 @@ from camel.types import RoleType
 
 
  # flake8: noqa :E501
- class DescriptionVideoPromptTemplateDict(TextPromptDict):
+ class VideoDescriptionPromptTemplateDict(TextPromptDict):
      ASSISTANT_PROMPT = TextPrompt(
          """You are a master of video analysis.
  Please provide a shot description of the content of the current video."""
camel/retrievers/auto_retriever.py CHANGED
@@ -235,8 +235,6 @@ class AutoRetriever:
              else content_input_paths
          )
 
-         vr = VectorRetriever()
-
          all_retrieved_info = []
          for content_input_path in content_input_paths:
              # Generate a valid collection name
@@ -283,12 +281,14 @@
                  vr = VectorRetriever(
                      storage=vector_storage_instance,
                      similarity_threshold=similarity_threshold,
+                     embedding_model=self.embedding_model,
                  )
                  vr.process(content_input_path)
              else:
                  vr = VectorRetriever(
                      storage=vector_storage_instance,
                      similarity_threshold=similarity_threshold,
+                     embedding_model=self.embedding_model,
                  )
              # Retrieve info by given query from the vector storage
              retrieved_info = vr.query(query, top_k)
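
This change forwards the AutoRetriever's `embedding_model` into every VectorRetriever it builds, so a custom embedding (such as the MistralEmbedding added in this release) is applied consistently at both ingestion and query time. A hedged sketch; both the `MistralEmbedding` export and the `embedding_model` constructor argument of AutoRetriever are assumptions based on the file list and the `self.embedding_model` attribute above:

    from camel.embeddings import MistralEmbedding  # assumed export, new in 0.1.6.0
    from camel.retrievers.auto_retriever import AutoRetriever

    # Assumption: AutoRetriever stores this as self.embedding_model and now
    # passes it through to each VectorRetriever it constructs.
    retriever = AutoRetriever(embedding_model=MistralEmbedding())
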
camel/storages/graph_storages/neo4j_graph.py CHANGED
@@ -12,6 +12,7 @@
  # limitations under the License.
  # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
  import logging
+ import os
  from hashlib import md5
  from typing import Any, Dict, List, Optional
 
@@ -95,6 +96,10 @@ class Neo4jGraph(BaseGraphStorage):
          r"""Create a new Neo4j graph instance."""
          import neo4j
 
+         url = os.environ.get("NEO4J_URI") or url
+         username = os.environ.get("NEO4J_USERNAME") or username
+         password = os.environ.get("NEO4J_PASSWORD") or password
+
          self.driver = neo4j.GraphDatabase.driver(
              url, auth=(username, password)
          )
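
With this change the Neo4j connection details can come from the environment, and environment values take precedence over explicitly passed arguments (each is resolved as `os.environ.get(...) or <argument>`). A sketch under that assumption; the URI and credentials are placeholders, and it assumes the constructor accepts `url`, `username`, and `password` keywords:

    import os

    # Environment variables win over explicitly passed values.
    os.environ["NEO4J_URI"] = "neo4j+s://example.databases.neo4j.io"  # placeholder
    os.environ["NEO4J_USERNAME"] = "neo4j"
    os.environ["NEO4J_PASSWORD"] = "secret"  # placeholder

    from camel.storages.graph_storages.neo4j_graph import Neo4jGraph

    # Assumed keyword names; the env values above are used regardless.
    graph = Neo4jGraph(url=None, username=None, password=None)
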