camel-ai 0.1.5.5__py3-none-any.whl → 0.1.5.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of camel-ai might be problematic.

Files changed (97)
  1. camel/__init__.py +1 -1
  2. camel/agents/chat_agent.py +3 -3
  3. camel/agents/critic_agent.py +1 -1
  4. camel/agents/deductive_reasoner_agent.py +4 -4
  5. camel/agents/embodied_agent.py +1 -1
  6. camel/agents/knowledge_graph_agent.py +13 -17
  7. camel/agents/role_assignment_agent.py +1 -1
  8. camel/agents/search_agent.py +4 -5
  9. camel/agents/task_agent.py +5 -6
  10. camel/configs/__init__.py +15 -0
  11. camel/configs/gemini_config.py +98 -0
  12. camel/configs/groq_config.py +119 -0
  13. camel/configs/litellm_config.py +1 -1
  14. camel/configs/mistral_config.py +81 -0
  15. camel/configs/ollama_config.py +1 -1
  16. camel/configs/openai_config.py +1 -1
  17. camel/configs/vllm_config.py +103 -0
  18. camel/configs/zhipuai_config.py +1 -1
  19. camel/embeddings/__init__.py +2 -0
  20. camel/embeddings/mistral_embedding.py +89 -0
  21. camel/interpreters/__init__.py +2 -0
  22. camel/interpreters/ipython_interpreter.py +167 -0
  23. camel/models/__init__.py +10 -0
  24. camel/models/anthropic_model.py +7 -2
  25. camel/models/azure_openai_model.py +152 -0
  26. camel/models/base_model.py +9 -2
  27. camel/models/gemini_model.py +215 -0
  28. camel/models/groq_model.py +131 -0
  29. camel/models/litellm_model.py +26 -4
  30. camel/models/mistral_model.py +169 -0
  31. camel/models/model_factory.py +33 -5
  32. camel/models/ollama_model.py +21 -2
  33. camel/models/open_source_model.py +11 -3
  34. camel/models/openai_model.py +7 -2
  35. camel/models/stub_model.py +4 -4
  36. camel/models/vllm_model.py +138 -0
  37. camel/models/zhipuai_model.py +7 -4
  38. camel/prompts/__init__.py +2 -2
  39. camel/prompts/task_prompt_template.py +4 -4
  40. camel/prompts/{descripte_video_prompt.py → video_description_prompt.py} +1 -1
  41. camel/retrievers/auto_retriever.py +2 -0
  42. camel/storages/graph_storages/neo4j_graph.py +5 -0
  43. camel/toolkits/__init__.py +36 -0
  44. camel/toolkits/base.py +1 -1
  45. camel/toolkits/code_execution.py +1 -1
  46. camel/toolkits/github_toolkit.py +3 -2
  47. camel/toolkits/google_maps_toolkit.py +367 -0
  48. camel/toolkits/math_toolkit.py +79 -0
  49. camel/toolkits/open_api_toolkit.py +548 -0
  50. camel/toolkits/retrieval_toolkit.py +76 -0
  51. camel/toolkits/search_toolkit.py +326 -0
  52. camel/toolkits/slack_toolkit.py +308 -0
  53. camel/toolkits/twitter_toolkit.py +522 -0
  54. camel/toolkits/weather_toolkit.py +173 -0
  55. camel/types/enums.py +163 -30
  56. camel/utils/__init__.py +4 -0
  57. camel/utils/async_func.py +1 -1
  58. camel/utils/token_counting.py +182 -40
  59. {camel_ai-0.1.5.5.dist-info → camel_ai-0.1.5.9.dist-info}/METADATA +43 -3
  60. camel_ai-0.1.5.9.dist-info/RECORD +165 -0
  61. camel/functions/__init__.py +0 -51
  62. camel/functions/google_maps_function.py +0 -335
  63. camel/functions/math_functions.py +0 -61
  64. camel/functions/open_api_function.py +0 -508
  65. camel/functions/retrieval_functions.py +0 -61
  66. camel/functions/search_functions.py +0 -298
  67. camel/functions/slack_functions.py +0 -286
  68. camel/functions/twitter_function.py +0 -479
  69. camel/functions/weather_functions.py +0 -144
  70. camel_ai-0.1.5.5.dist-info/RECORD +0 -155
  71. /camel/{functions → toolkits}/open_api_specs/biztoc/__init__.py +0 -0
  72. /camel/{functions → toolkits}/open_api_specs/biztoc/ai-plugin.json +0 -0
  73. /camel/{functions → toolkits}/open_api_specs/biztoc/openapi.yaml +0 -0
  74. /camel/{functions → toolkits}/open_api_specs/coursera/__init__.py +0 -0
  75. /camel/{functions → toolkits}/open_api_specs/coursera/openapi.yaml +0 -0
  76. /camel/{functions → toolkits}/open_api_specs/create_qr_code/__init__.py +0 -0
  77. /camel/{functions → toolkits}/open_api_specs/create_qr_code/openapi.yaml +0 -0
  78. /camel/{functions → toolkits}/open_api_specs/klarna/__init__.py +0 -0
  79. /camel/{functions → toolkits}/open_api_specs/klarna/openapi.yaml +0 -0
  80. /camel/{functions → toolkits}/open_api_specs/nasa_apod/__init__.py +0 -0
  81. /camel/{functions → toolkits}/open_api_specs/nasa_apod/openapi.yaml +0 -0
  82. /camel/{functions → toolkits}/open_api_specs/outschool/__init__.py +0 -0
  83. /camel/{functions → toolkits}/open_api_specs/outschool/ai-plugin.json +0 -0
  84. /camel/{functions → toolkits}/open_api_specs/outschool/openapi.yaml +0 -0
  85. /camel/{functions → toolkits}/open_api_specs/outschool/paths/__init__.py +0 -0
  86. /camel/{functions → toolkits}/open_api_specs/outschool/paths/get_classes.py +0 -0
  87. /camel/{functions → toolkits}/open_api_specs/outschool/paths/search_teachers.py +0 -0
  88. /camel/{functions → toolkits}/open_api_specs/security_config.py +0 -0
  89. /camel/{functions → toolkits}/open_api_specs/speak/__init__.py +0 -0
  90. /camel/{functions → toolkits}/open_api_specs/speak/openapi.yaml +0 -0
  91. /camel/{functions → toolkits}/open_api_specs/web_scraper/__init__.py +0 -0
  92. /camel/{functions → toolkits}/open_api_specs/web_scraper/ai-plugin.json +0 -0
  93. /camel/{functions → toolkits}/open_api_specs/web_scraper/openapi.yaml +0 -0
  94. /camel/{functions → toolkits}/open_api_specs/web_scraper/paths/__init__.py +0 -0
  95. /camel/{functions → toolkits}/open_api_specs/web_scraper/paths/scraper.py +0 -0
  96. /camel/{functions → toolkits}/openai_function.py +0 -0
  97. {camel_ai-0.1.5.5.dist-info → camel_ai-0.1.5.9.dist-info}/WHEEL +0 -0
@@ -0,0 +1,215 @@
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+ from camel.configs import Gemini_API_PARAMS
+ from camel.messages import OpenAIMessage
+ from camel.models import BaseModelBackend
+ from camel.types import (
+     ChatCompletion,
+     ChatCompletionMessage,
+     Choice,
+     ModelType,
+ )
+ from camel.utils import (
+     BaseTokenCounter,
+     GeminiTokenCounter,
+     api_keys_required,
+ )
+
+ if TYPE_CHECKING:
+     from google.generativeai.types import ContentsType, GenerateContentResponse
+
+
+ class GeminiModel(BaseModelBackend):
+     r"""Gemini API in a unified BaseModelBackend interface."""
+
+     # NOTE: Currently "stream": True is not supported with Gemini due to the
+     # limitation of the current camel design.
+
+     def __init__(
+         self,
+         model_type: ModelType,
+         model_config_dict: Dict[str, Any],
+         api_key: Optional[str] = None,
+         url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
+     ) -> None:
+         r"""Constructor for Gemini backend.
+
+         Args:
+             model_type (ModelType): Model for which a backend is created.
+             model_config_dict (Dict[str, Any]): A dictionary that will
+                 be fed into generate_content().
+             api_key (Optional[str]): The API key for authenticating with the
+                 gemini service. (default: :obj:`None`)
+             url (Optional[str]): The url to the gemini service.
+             token_counter (Optional[BaseTokenCounter]): Token counter to use
+                 for the model. If not provided, `GeminiTokenCounter` will be
+                 used.
+         """
+         import os
+
+         import google.generativeai as genai
+         from google.generativeai.types.generation_types import GenerationConfig
+
+         super().__init__(
+             model_type, model_config_dict, api_key, url, token_counter
+         )
+         self._api_key = api_key or os.environ.get("GOOGLE_API_KEY")
+         genai.configure(api_key=self._api_key)
+         self._client = genai.GenerativeModel(self.model_type.value)
+
+         keys = list(self.model_config_dict.keys())
+         generation_config_dict = {
+             k: self.model_config_dict.pop(k)
+             for k in keys
+             if hasattr(GenerationConfig, k)
+         }
+         generation_config = genai.types.GenerationConfig(
+             **generation_config_dict
+         )
+         self.model_config_dict["generation_config"] = generation_config
+
+     @property
+     def token_counter(self) -> BaseTokenCounter:
+         r"""Initialize the token counter for the model backend.
+
+         Returns:
+             BaseTokenCounter: The token counter following the model's
+                 tokenization style.
+         """
+         if not self._token_counter:
+             self._token_counter = GeminiTokenCounter(self.model_type)
+         return self._token_counter
+
+     @api_keys_required("GOOGLE_API_KEY")
+     def run(
+         self,
+         messages: List[OpenAIMessage],
+     ) -> ChatCompletion:
+ r"""Runs inference of Gemini model.
103
+ This method can handle multimodal input
104
+
105
+ Args:
106
+ messages: Message list or Message with the chat history
107
+ in OpenAi format.
108
+
109
+ Returns:
110
+ response: A ChatCompletion object formatted for the OpenAI API.
111
+ """
112
+ response = self._client.generate_content(
113
+ contents=self.to_gemini_req(messages),
114
+ **self.model_config_dict,
115
+ )
116
+ response.resolve()
117
+ return self.to_openai_response(response)
118
+
119
+ def check_model_config(self):
120
+ r"""Check whether the model configuration contains any
121
+ unexpected arguments to Gemini API.
122
+
123
+ Raises:
124
+ ValueError: If the model configuration dictionary contains any
125
+ unexpected arguments to OpenAI API.
126
+ """
127
+ if self.model_config_dict is not None:
128
+ for param in self.model_config_dict:
129
+ if param not in Gemini_API_PARAMS:
130
+ raise ValueError(
131
+ f"Unexpected argument `{param}` is "
132
+ "input into Gemini model backend."
133
+ )
134
+
+     @property
+     def stream(self) -> bool:
+         r"""Returns whether the model is in stream mode,
+         which sends partial results each time.
+
+         Returns:
+             bool: Whether the model is in stream mode.
+         """
+         return self.model_config_dict.get('stream', False)
+
+     def to_gemini_req(self, messages: List[OpenAIMessage]) -> 'ContentsType':
+         r"""Converts the request from the OpenAI API format to
+         the Gemini API request format.
+
+         Args:
+             messages: The request object from the OpenAI API.
+
+         Returns:
+             converted_messages: A list of messages formatted for Gemini API.
+         """
+         # role reference
+         # https://ai.google.dev/api/python/google/generativeai/protos/Content
+         converted_messages = []
+         for message in messages:
+             role = message.get('role')
+             if role == 'assistant':
+                 role_to_gemini = 'model'
+             else:
+                 role_to_gemini = 'user'
+             converted_message = {
+                 "role": role_to_gemini,
+                 "parts": message.get("content"),
+             }
+             converted_messages.append(converted_message)
+         return converted_messages
+
+     def to_openai_response(
+         self,
+         response: 'GenerateContentResponse',
+     ) -> ChatCompletion:
+         r"""Converts the response from the Gemini API to the OpenAI API
+         response format.
+
+         Args:
+             response: The response object returned by the Gemini API
+
+         Returns:
+             openai_response: A ChatCompletion object formatted for
+                 the OpenAI API.
+         """
+         import time
+         import uuid
+
+         openai_response = ChatCompletion(
+             id=f"chatcmpl-{uuid.uuid4().hex!s}",
+             object="chat.completion",
+             created=int(time.time()),
+             model=self.model_type.value,
+             choices=[],
+         )
+         for i, candidate in enumerate(response.candidates):
+             content = ""
+             if candidate.content and len(candidate.content.parts) > 0:
+                 content = candidate.content.parts[0].text
+             finish_reason = candidate.finish_reason
+             finish_reason_mapping = {
+                 "FinishReason.STOP": "stop",
+                 "FinishReason.SAFETY": "content_filter",
+                 "FinishReason.RECITATION": "content_filter",
+                 "FinishReason.MAX_TOKENS": "length",
+             }
+             finish_reason = finish_reason_mapping.get(finish_reason, "stop")
+             choice = Choice(
+                 index=i,
+                 message=ChatCompletionMessage(
+                     role="assistant", content=content
+                 ),
+                 finish_reason=finish_reason,
+             )
+             openai_response.choices.append(choice)
+         return openai_response
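
For orientation, here is a minimal usage sketch of the new Gemini backend (not part of the diff). The enum member `ModelType.GEMINI_1_5_FLASH` and the config field are assumptions; check `camel/types/enums.py` and `camel/configs/gemini_config.py` in this release for the exact names.

# Hedged sketch: driving GeminiModel directly, assuming the enum member
# below exists and GOOGLE_API_KEY is set (required by @api_keys_required).
import os

from camel.models import GeminiModel
from camel.types import ModelType

os.environ["GOOGLE_API_KEY"] = "..."  # replace with a real key

model = GeminiModel(
    model_type=ModelType.GEMINI_1_5_FLASH,  # assumed member name
    model_config_dict={"temperature": 0.2},
)
# Messages use the OpenAI format; to_gemini_req() maps "assistant" to "model".
completion = model.run([{"role": "user", "content": "One-line greeting, please."}])
print(completion.choices[0].message.content)

Note that any `model_config_dict` keys matching `GenerationConfig` attributes (such as `temperature`) are popped into a `generation_config` object in `__init__` before reaching `generate_content()`.
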
@@ -0,0 +1,131 @@
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ import os
+ from typing import Any, Dict, List, Optional, Union
+
+ from openai import OpenAI, Stream
+
+ from camel.configs import GROQ_API_PARAMS
+ from camel.messages import OpenAIMessage
+ from camel.models import BaseModelBackend
+ from camel.types import (
+     ChatCompletion,
+     ChatCompletionChunk,
+     ModelType,
+ )
+ from camel.utils import (
+     BaseTokenCounter,
+     OpenAITokenCounter,
+     api_keys_required,
+ )
+
+
+ class GroqModel(BaseModelBackend):
+     r"""LLM API served by Groq in a unified BaseModelBackend interface."""
+
+     def __init__(
+         self,
+         model_type: ModelType,
+         model_config_dict: Dict[str, Any],
+         api_key: Optional[str] = None,
+         url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
+     ) -> None:
+ r"""Constructor for Groq backend.
46
+
47
+ Args:
48
+ model_type (str): Model for which a backend is created.
49
+ model_config_dict (Dict[str, Any]): A dictionary of parameters for
50
+ the model configuration.
51
+ api_key (Optional[str]): The API key for authenticating with the
52
+ Groq service. (default: :obj:`None`).
53
+ url (Optional[str]): The url to the Groq service. (default:
54
+ :obj:`None`)
55
+ token_counter (Optional[BaseTokenCounter]): Token counter to use
56
+ for the model. If not provided, `OpenAITokenCounter(ModelType.
57
+ GPT_3_5_TURBO)` will be used.
58
+ """
59
+ super().__init__(
60
+ model_type, model_config_dict, api_key, url, token_counter
61
+ )
62
+ self._url = url or "https://api.groq.com/openai/v1"
63
+ self._api_key = api_key or os.environ.get("GROQ_API_KEY")
64
+ self._client = OpenAI(
65
+ timeout=60,
66
+ max_retries=3,
67
+ api_key=self._api_key,
68
+ base_url=self._url,
69
+ )
70
+ self._token_counter = token_counter
71
+
72
+ @property
73
+ def token_counter(self) -> BaseTokenCounter:
74
+ r"""Initialize the token counter for the model backend.
75
+
76
+ Returns:
77
+ BaseTokenCounter: The token counter following the model's
78
+ tokenization style.
79
+ """
80
+ # Make sure you have the access to these open-source model in
81
+ # HuggingFace
82
+ if not self._token_counter:
83
+ self._token_counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
84
+ return self._token_counter
85
+
86
+ @api_keys_required("GROQ_API_KEY")
87
+ def run(
88
+ self,
89
+ messages: List[OpenAIMessage],
90
+ ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
91
+ r"""Runs inference of OpenAI chat completion.
92
+
93
+ Args:
94
+ messages (List[OpenAIMessage]): Message list with the chat history
95
+ in OpenAI API format.
96
+
97
+ Returns:
98
+ Union[ChatCompletion, Stream[ChatCompletionChunk]]:
99
+ `ChatCompletion` in the non-stream mode, or
100
+ `Stream[ChatCompletionChunk]` in the stream mode.
101
+ """
102
+ response = self._client.chat.completions.create(
103
+ messages=messages,
104
+ model=self.model_type.value,
105
+ **self.model_config_dict,
106
+ )
107
+
108
+ return response
109
+
110
+ def check_model_config(self):
111
+ r"""Check whether the model configuration contains any unexpected
112
+ arguments to Groq API. But Groq API does not have any additional
113
+ arguments to check.
114
+
115
+ Raises:
116
+ ValueError: If the model configuration dictionary contains any
117
+ unexpected arguments to Groq API.
118
+ """
119
+ for param in self.model_config_dict:
120
+ if param not in GROQ_API_PARAMS:
121
+ raise ValueError(
122
+ f"Unexpected argument `{param}` is "
123
+ "input into Groq model backend."
124
+ )
125
+
126
+ @property
127
+ def stream(self) -> bool:
128
+ r"""Returns whether the model supports streaming. But Groq API does
129
+ not support streaming.
130
+ """
131
+ return False
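
The Groq backend follows the same pattern; a hedged sketch (the `ModelType` member for a Groq-served Llama model is an assumption; see `camel/types/enums.py`):

# Hedged sketch: GroqModel reuses the OpenAI SDK against Groq's
# OpenAI-compatible endpoint at https://api.groq.com/openai/v1.
import os

from camel.models import GroqModel
from camel.types import ModelType

os.environ["GROQ_API_KEY"] = "..."  # checked by @api_keys_required

model = GroqModel(
    model_type=ModelType.GROQ_LLAMA_3_8_B,  # assumed member name
    model_config_dict={"temperature": 0.0},
)
completion = model.run([{"role": "user", "content": "Why is the sky blue?"}])
print(completion.choices[0].message.content)
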
@@ -16,7 +16,7 @@ from typing import Any, Dict, List, Optional
  from camel.configs import LITELLM_API_PARAMS
  from camel.messages import OpenAIMessage
  from camel.types import ChatCompletion
- from camel.utils import LiteLLMTokenCounter
+ from camel.utils import BaseTokenCounter, LiteLLMTokenCounter


  class LiteLLMModel:
@@ -30,6 +30,7 @@ class LiteLLMModel:
          model_config_dict: Dict[str, Any],
          api_key: Optional[str] = None,
          url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
      ) -> None:
          r"""Constructor for LiteLLM backend.

@@ -42,11 +43,14 @@ class LiteLLMModel:
                  model service. (default: :obj:`None`)
              url (Optional[str]): The url to the model service. (default:
                  :obj:`None`)
+             token_counter (Optional[BaseTokenCounter]): Token counter to use
+                 for the model. If not provided, `LiteLLMTokenCounter` will
+                 be used.
          """
          self.model_type = model_type
          self.model_config_dict = model_config_dict
          self._client = None
-         self._token_counter: Optional[LiteLLMTokenCounter] = None
+         self._token_counter = token_counter
          self.check_model_config()
          self._url = url
          self._api_key = api_key
@@ -98,8 +102,10 @@ class LiteLLMModel:
                  tokenization style.
          """
          if not self._token_counter:
-             self._token_counter = LiteLLMTokenCounter(self.model_type)
-         return self._token_counter
+             self._token_counter = LiteLLMTokenCounter(  # type: ignore[assignment]
+                 self.model_type
+             )
+         return self._token_counter  # type: ignore[return-value]

      def run(
          self,
@@ -138,3 +144,19 @@ class LiteLLMModel:
                      f"Unexpected argument `{param}` is "
                      "input into LiteLLM model backend."
                  )
+
+     @property
+     def token_limit(self) -> int:
+         """Returns the maximum token limit for the given model.
+
+         Returns:
+             int: The maximum token limit for the given model.
+         """
+         max_tokens = self.model_config_dict.get("max_tokens")
+         if isinstance(max_tokens, int):
+             return max_tokens
+         print(
+             "Must set `max_tokens` as an integer in `model_config_dict` when"
+             " setting up the model. Using 4096 as default value."
+         )
+         return 4096
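
The effect of the new `token_limit` property is easy to demonstrate (a sketch; the "gpt-3.5-turbo" model string is illustrative, and `LiteLLMModel` is assumed to be exported from `camel.models`):

# Hedged sketch of the token_limit fallback added above.
from camel.models import LiteLLMModel

model = LiteLLMModel("gpt-3.5-turbo", {"max_tokens": 2048})
print(model.token_limit)  # 2048, read from model_config_dict

model = LiteLLMModel("gpt-3.5-turbo", {})
print(model.token_limit)  # prints the warning, then returns the 4096 default
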
@@ -0,0 +1,169 @@
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ import os
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+ if TYPE_CHECKING:
+     from mistralai.models.chat_completion import ChatCompletionResponse
+
+ from camel.configs import MISTRAL_API_PARAMS
+ from camel.messages import OpenAIMessage
+ from camel.models import BaseModelBackend
+ from camel.types import ChatCompletion, ModelType
+ from camel.utils import (
+     BaseTokenCounter,
+     MistralTokenCounter,
+     api_keys_required,
+ )
+
+
+ class MistralModel(BaseModelBackend):
+     r"""Mistral API in a unified BaseModelBackend interface."""
+
+     # TODO: Support tool calling.
+
+     def __init__(
+         self,
+         model_type: ModelType,
+         model_config_dict: Dict[str, Any],
+         api_key: Optional[str] = None,
+         url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
+     ) -> None:
+         r"""Constructor for Mistral backend.
+
+         Args:
+             model_type (ModelType): Model for which a backend is created,
+                 one of MISTRAL_* series.
+             model_config_dict (Dict[str, Any]): A dictionary that will
+                 be fed into `MistralClient.chat`.
+             api_key (Optional[str]): The API key for authenticating with the
+                 mistral service. (default: :obj:`None`)
+             url (Optional[str]): The url to the mistral service.
+             token_counter (Optional[BaseTokenCounter]): Token counter to use
+                 for the model. If not provided, `MistralTokenCounter` will be
+                 used.
+         """
+         super().__init__(
+             model_type, model_config_dict, api_key, url, token_counter
+         )
+         self._api_key = api_key or os.environ.get("MISTRAL_API_KEY")
+
+         from mistralai.client import MistralClient
+
+         self._client = MistralClient(api_key=self._api_key)
+         self._token_counter: Optional[BaseTokenCounter] = None
+
+     def _convert_response_from_mistral_to_openai(
+         self, response: 'ChatCompletionResponse'
+     ) -> ChatCompletion:
+         tool_calls = None
+         if response.choices[0].message.tool_calls is not None:
+             tool_calls = [
+                 dict(
+                     id=tool_call.id,
+                     function={
+                         "name": tool_call.function.name,
+                         "arguments": tool_call.function.arguments,
+                     },
+                     type=tool_call.type.value,
+                 )
+                 for tool_call in response.choices[0].message.tool_calls
+             ]
+
+         obj = ChatCompletion.construct(
+             id=response.id,
+             choices=[
+                 dict(
+                     index=response.choices[0].index,
+                     message={
+                         "role": response.choices[0].message.role,
+                         "content": response.choices[0].message.content,
+                         "tool_calls": tool_calls,
+                     },
+                     finish_reason=response.choices[0].finish_reason.value
+                     if response.choices[0].finish_reason
+                     else None,
+                 )
+             ],
+             created=response.created,
+             model=response.model,
+             object="chat.completion",
+             usage=response.usage,
+         )
+
+         return obj
+
+     @property
+     def token_counter(self) -> BaseTokenCounter:
+         r"""Initialize the token counter for the model backend.
+
+         Returns:
+             BaseTokenCounter: The token counter following the model's
+                 tokenization style.
+         """
+         if not self._token_counter:
+             self._token_counter = MistralTokenCounter(
+                 model_type=self.model_type
+             )
+         return self._token_counter
+
+     @api_keys_required("MISTRAL_API_KEY")
+     def run(
+         self,
+         messages: List[OpenAIMessage],
+     ) -> ChatCompletion:
+         r"""Runs inference of Mistral chat completion.
+
+         Args:
+             messages (List[OpenAIMessage]): Message list with the chat history
+                 in OpenAI API format.
+
+         Returns:
+             ChatCompletion
+         """
+         response = self._client.chat(
+             messages=messages,
+             model=self.model_type.value,
+             **self.model_config_dict,
+         )
+
+         response = self._convert_response_from_mistral_to_openai(response)  # type: ignore[assignment]
+
+         return response  # type: ignore[return-value]
+
+     def check_model_config(self):
+         r"""Check whether the model configuration contains any
+         unexpected arguments to Mistral API.
+
+         Raises:
+             ValueError: If the model configuration dictionary contains any
+                 unexpected arguments to Mistral API.
+         """
+         for param in self.model_config_dict:
+             if param not in MISTRAL_API_PARAMS:
+                 raise ValueError(
+                     f"Unexpected argument `{param}` is "
+                     "input into Mistral model backend."
+                 )
+
+     @property
+     def stream(self) -> bool:
+         r"""Returns whether the model is in stream mode, which sends partial
+         results each time. Mistral doesn't support stream mode.
+
+         Returns:
+             bool: Whether the model is in stream mode.
+         """
+         return False
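
A sketch for the Mistral backend; `run()` converts the native `ChatCompletionResponse` into an OpenAI-style `ChatCompletion`, so callers see a uniform shape (`ModelType.MISTRAL_7B` is an assumed member name):

# Hedged sketch: MistralModel returns OpenAI-shaped responses.
import os

from camel.models import MistralModel
from camel.types import ModelType

os.environ["MISTRAL_API_KEY"] = "..."

model = MistralModel(
    model_type=ModelType.MISTRAL_7B,  # assumed member name
    model_config_dict={"temperature": 0.3},
)
completion = model.run([{"role": "user", "content": "Name three French cheeses."}])
print(completion)  # an OpenAI-style ChatCompletion built via .construct()
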
@@ -14,14 +14,20 @@
  from typing import Any, Dict, Optional, Union

  from camel.models.anthropic_model import AnthropicModel
+ from camel.models.azure_openai_model import AzureOpenAIModel
  from camel.models.base_model import BaseModelBackend
+ from camel.models.gemini_model import GeminiModel
+ from camel.models.groq_model import GroqModel
  from camel.models.litellm_model import LiteLLMModel
+ from camel.models.mistral_model import MistralModel
  from camel.models.ollama_model import OllamaModel
  from camel.models.open_source_model import OpenSourceModel
  from camel.models.openai_model import OpenAIModel
  from camel.models.stub_model import StubModel
+ from camel.models.vllm_model import VLLMModel
  from camel.models.zhipuai_model import ZhipuAIModel
  from camel.types import ModelPlatformType, ModelType
+ from camel.utils import BaseTokenCounter


  class ModelFactory:
@@ -36,6 +42,7 @@ class ModelFactory:
          model_platform: ModelPlatformType,
          model_type: Union[ModelType, str],
          model_config_dict: Dict,
+         token_counter: Optional[BaseTokenCounter] = None,
          api_key: Optional[str] = None,
          url: Optional[str] = None,
      ) -> BaseModelBackend:
@@ -48,6 +55,10 @@
                  created can be a `str` for open source platforms.
              model_config_dict (Dict): A dictionary that will be fed into
                  the backend constructor.
+             token_counter (Optional[BaseTokenCounter]): Token counter to use
+                 for the model. If not provided, `OpenAITokenCounter(ModelType.
+                 GPT_3_5_TURBO)` will be used if the model platform doesn't
+                 provide an official token counter.
              api_key (Optional[str]): The API key for authenticating with the
                  model service.
              url (Optional[str]): The url to the model service.
@@ -59,17 +70,26 @@
              BaseModelBackend: The initialized backend.
          """
          model_class: Any
-
          if isinstance(model_type, ModelType):
              if model_platform.is_open_source and model_type.is_open_source:
                  model_class = OpenSourceModel
-                 return model_class(model_type, model_config_dict, url)
+                 return model_class(
+                     model_type, model_config_dict, url, token_counter
+                 )
              if model_platform.is_openai and model_type.is_openai:
                  model_class = OpenAIModel
+             elif model_platform.is_azure and model_type.is_azure_openai:
+                 model_class = AzureOpenAIModel
              elif model_platform.is_anthropic and model_type.is_anthropic:
                  model_class = AnthropicModel
+             elif model_type.is_groq:
+                 model_class = GroqModel
              elif model_platform.is_zhipuai and model_type.is_zhipuai:
                  model_class = ZhipuAIModel
+             elif model_platform.is_gemini and model_type.is_gemini:
+                 model_class = GeminiModel
+             elif model_platform.is_mistral and model_type.is_mistral:
+                 model_class = MistralModel
              elif model_type == ModelType.STUB:
                  model_class = StubModel
              else:
@@ -80,7 +100,14 @@
          elif isinstance(model_type, str):
              if model_platform.is_ollama:
                  model_class = OllamaModel
-                 return model_class(model_type, model_config_dict, url)
+                 return model_class(
+                     model_type, model_config_dict, url, token_counter
+                 )
+             elif model_platform.is_vllm:
+                 model_class = VLLMModel
+                 return model_class(
+                     model_type, model_config_dict, url, api_key, token_counter
+                 )
              elif model_platform.is_litellm:
                  model_class = LiteLLMModel
              else:
@@ -90,5 +117,6 @@
              )
          else:
              raise ValueError(f"Invalid model type `{model_type}` provided.")
-
-         return model_class(model_type, model_config_dict, api_key, url)
+         return model_class(
+             model_type, model_config_dict, api_key, url, token_counter
+         )