camel-ai 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (116)
  1. camel/__init__.py +1 -1
  2. camel/agents/_types.py +41 -0
  3. camel/agents/_utils.py +188 -0
  4. camel/agents/chat_agent.py +570 -965
  5. camel/agents/knowledge_graph_agent.py +7 -1
  6. camel/agents/multi_hop_generator_agent.py +1 -1
  7. camel/configs/base_config.py +10 -13
  8. camel/configs/deepseek_config.py +4 -30
  9. camel/configs/gemini_config.py +5 -31
  10. camel/configs/openai_config.py +14 -32
  11. camel/configs/qwen_config.py +36 -36
  12. camel/datagen/self_improving_cot.py +81 -3
  13. camel/datagen/self_instruct/filter/instruction_filter.py +19 -3
  14. camel/datagen/self_instruct/self_instruct.py +53 -4
  15. camel/datasets/__init__.py +28 -0
  16. camel/datasets/base.py +969 -0
  17. camel/embeddings/openai_embedding.py +10 -1
  18. camel/environments/__init__.py +16 -0
  19. camel/environments/base.py +503 -0
  20. camel/extractors/__init__.py +16 -0
  21. camel/extractors/base.py +263 -0
  22. camel/interpreters/docker/Dockerfile +12 -0
  23. camel/interpreters/docker_interpreter.py +19 -1
  24. camel/interpreters/subprocess_interpreter.py +42 -17
  25. camel/loaders/__init__.py +2 -0
  26. camel/loaders/mineru_extractor.py +250 -0
  27. camel/memories/agent_memories.py +16 -1
  28. camel/memories/blocks/chat_history_block.py +10 -2
  29. camel/memories/blocks/vectordb_block.py +1 -0
  30. camel/memories/context_creators/score_based.py +20 -3
  31. camel/memories/records.py +10 -0
  32. camel/messages/base.py +8 -8
  33. camel/models/__init__.py +2 -0
  34. camel/models/_utils.py +57 -0
  35. camel/models/aiml_model.py +48 -17
  36. camel/models/anthropic_model.py +41 -3
  37. camel/models/azure_openai_model.py +39 -3
  38. camel/models/base_audio_model.py +92 -0
  39. camel/models/base_model.py +132 -4
  40. camel/models/cohere_model.py +88 -11
  41. camel/models/deepseek_model.py +107 -63
  42. camel/models/fish_audio_model.py +18 -8
  43. camel/models/gemini_model.py +133 -15
  44. camel/models/groq_model.py +72 -10
  45. camel/models/internlm_model.py +14 -3
  46. camel/models/litellm_model.py +9 -2
  47. camel/models/mistral_model.py +42 -5
  48. camel/models/model_manager.py +57 -3
  49. camel/models/moonshot_model.py +33 -4
  50. camel/models/nemotron_model.py +32 -3
  51. camel/models/nvidia_model.py +43 -3
  52. camel/models/ollama_model.py +139 -17
  53. camel/models/openai_audio_models.py +87 -2
  54. camel/models/openai_compatible_model.py +37 -3
  55. camel/models/openai_model.py +158 -46
  56. camel/models/qwen_model.py +61 -4
  57. camel/models/reka_model.py +53 -3
  58. camel/models/samba_model.py +209 -4
  59. camel/models/sglang_model.py +153 -14
  60. camel/models/siliconflow_model.py +16 -3
  61. camel/models/stub_model.py +46 -4
  62. camel/models/togetherai_model.py +38 -3
  63. camel/models/vllm_model.py +37 -3
  64. camel/models/yi_model.py +36 -3
  65. camel/models/zhipuai_model.py +38 -3
  66. camel/retrievers/__init__.py +3 -0
  67. camel/retrievers/hybrid_retrival.py +237 -0
  68. camel/toolkits/__init__.py +20 -3
  69. camel/toolkits/arxiv_toolkit.py +2 -1
  70. camel/toolkits/ask_news_toolkit.py +4 -2
  71. camel/toolkits/audio_analysis_toolkit.py +238 -0
  72. camel/toolkits/base.py +22 -3
  73. camel/toolkits/code_execution.py +2 -0
  74. camel/toolkits/dappier_toolkit.py +2 -1
  75. camel/toolkits/data_commons_toolkit.py +38 -12
  76. camel/toolkits/excel_toolkit.py +172 -0
  77. camel/toolkits/function_tool.py +13 -0
  78. camel/toolkits/github_toolkit.py +5 -1
  79. camel/toolkits/google_maps_toolkit.py +2 -1
  80. camel/toolkits/google_scholar_toolkit.py +2 -0
  81. camel/toolkits/human_toolkit.py +0 -3
  82. camel/toolkits/image_analysis_toolkit.py +202 -0
  83. camel/toolkits/linkedin_toolkit.py +3 -2
  84. camel/toolkits/meshy_toolkit.py +3 -2
  85. camel/toolkits/mineru_toolkit.py +178 -0
  86. camel/toolkits/networkx_toolkit.py +240 -0
  87. camel/toolkits/notion_toolkit.py +2 -0
  88. camel/toolkits/openbb_toolkit.py +3 -2
  89. camel/toolkits/page_script.js +376 -0
  90. camel/toolkits/reddit_toolkit.py +11 -3
  91. camel/toolkits/retrieval_toolkit.py +6 -1
  92. camel/toolkits/semantic_scholar_toolkit.py +2 -1
  93. camel/toolkits/stripe_toolkit.py +8 -2
  94. camel/toolkits/sympy_toolkit.py +44 -1
  95. camel/toolkits/video_analysis_toolkit.py +407 -0
  96. camel/toolkits/{video_toolkit.py → video_download_toolkit.py} +21 -25
  97. camel/toolkits/web_toolkit.py +1307 -0
  98. camel/toolkits/whatsapp_toolkit.py +3 -2
  99. camel/toolkits/zapier_toolkit.py +191 -0
  100. camel/types/__init__.py +2 -2
  101. camel/types/agents/__init__.py +16 -0
  102. camel/types/agents/tool_calling_record.py +52 -0
  103. camel/types/enums.py +3 -0
  104. camel/types/openai_types.py +16 -14
  105. camel/utils/__init__.py +2 -1
  106. camel/utils/async_func.py +2 -2
  107. camel/utils/commons.py +114 -1
  108. camel/verifiers/__init__.py +23 -0
  109. camel/verifiers/base.py +340 -0
  110. camel/verifiers/models.py +82 -0
  111. camel/verifiers/python_verifier.py +202 -0
  112. camel_ai-0.2.23.dist-info/METADATA +671 -0
  113. {camel_ai-0.2.21.dist-info → camel_ai-0.2.23.dist-info}/RECORD +127 -99
  114. {camel_ai-0.2.21.dist-info → camel_ai-0.2.23.dist-info}/WHEEL +1 -1
  115. camel_ai-0.2.21.dist-info/METADATA +0 -528
  116. {camel_ai-0.2.21.dist-info → camel_ai-0.2.23.dist-info/licenses}/LICENSE +0 -0
camel/models/openai_model.py

@@ -13,9 +13,10 @@
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
 import os
 import warnings
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Type, Union

-from openai import OpenAI, Stream
+from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
+from pydantic import BaseModel

 from camel.configs import OPENAI_API_PARAMS, ChatGPTConfig
 from camel.messages import OpenAIMessage
@@ -31,6 +32,16 @@ from camel.utils import (
     api_keys_required,
 )

+UNSUPPORTED_PARAMS = {
+    "temperature",
+    "top_p",
+    "presence_penalty",
+    "frequency_penalty",
+    "logprobs",
+    "top_logprobs",
+    "logit_bias",
+}
+

 class OpenAIModel(BaseModelBackend):
     r"""OpenAI API in a unified BaseModelBackend interface.
@@ -68,15 +79,45 @@ class OpenAIModel(BaseModelBackend):
             model_config_dict = ChatGPTConfig().as_dict()
         api_key = api_key or os.environ.get("OPENAI_API_KEY")
         url = url or os.environ.get("OPENAI_API_BASE_URL")
+
         super().__init__(
             model_type, model_config_dict, api_key, url, token_counter
         )
+
         self._client = OpenAI(
             timeout=180,
             max_retries=3,
             base_url=self._url,
             api_key=self._api_key,
         )
+        self._async_client = AsyncOpenAI(
+            timeout=180,
+            max_retries=3,
+            base_url=self._url,
+            api_key=self._api_key,
+        )
+
+    def _sanitize_config(self, config_dict: Dict[str, Any]) -> Dict[str, Any]:
+        """Sanitize the model configuration for O1 models."""
+
+        if self.model_type in [
+            ModelType.O1,
+            ModelType.O1_MINI,
+            ModelType.O1_PREVIEW,
+            ModelType.O3_MINI,
+        ]:
+            warnings.warn(
+                "Warning: You are using an O1 model (O1_MINI or O1_PREVIEW), "
+                "which has certain limitations, reference: "
+                "`https://platform.openai.com/docs/guides/reasoning`.",
+                UserWarning,
+            )
+            return {
+                k: v
+                for k, v in config_dict.items()
+                if k not in UNSUPPORTED_PARAMS
+            }
+        return config_dict

     @property
     def token_counter(self) -> BaseTokenCounter:
@@ -90,70 +131,141 @@ class OpenAIModel(BaseModelBackend):
             self._token_counter = OpenAITokenCounter(self.model_type)
         return self._token_counter

-    def run(
+    def _run(
         self,
         messages: List[OpenAIMessage],
+        response_format: Optional[Type[BaseModel]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
     ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
         r"""Runs inference of OpenAI chat completion.

         Args:
             messages (List[OpenAIMessage]): Message list with the chat history
                 in OpenAI API format.
+            response_format (Optional[Type[BaseModel]]): The format of the
+                response.
+            tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
+                use for the request.

         Returns:
             Union[ChatCompletion, Stream[ChatCompletionChunk]]:
                 `ChatCompletion` in the non-stream mode, or
                 `Stream[ChatCompletionChunk]` in the stream mode.
         """
-        # o1-preview and o1-mini have Beta limitations
-        # reference: https://platform.openai.com/docs/guides/reasoning
-        if self.model_type in [
-            ModelType.O1,
-            ModelType.O1_MINI,
-            ModelType.O1_PREVIEW,
-            ModelType.O3_MINI,
-        ]:
-            warnings.warn(
-                "Warning: You are using an O1 model (O1_MINI or O1_PREVIEW), "
-                "which has certain limitations, reference: "
-                "`https://platform.openai.com/docs/guides/reasoning`.",
-                UserWarning,
-            )
+        response_format = response_format or self.model_config_dict.get(
+            "response_format", None
+        )
+        if response_format:
+            return self._request_parse(messages, response_format, tools)
+        else:
+            return self._request_chat_completion(messages, tools)

-        # Check and remove unsupported parameters and reset the fixed
-        # parameters
-        unsupported_keys = [
-            "temperature",
-            "top_p",
-            "presence_penalty",
-            "frequency_penalty",
-            "logprobs",
-            "top_logprobs",
-            "logit_bias",
-        ]
-        for key in unsupported_keys:
-            if key in self.model_config_dict:
-                del self.model_config_dict[key]
-
-        if self.model_config_dict.get("response_format"):
-            # stream is not supported in beta.chat.completions.parse
-            if "stream" in self.model_config_dict:
-                del self.model_config_dict["stream"]
-
-            response = self._client.beta.chat.completions.parse(
-                messages=messages,
-                model=self.model_type,
-                **self.model_config_dict,
-            )
+    async def _arun(
+        self,
+        messages: List[OpenAIMessage],
+        response_format: Optional[Type[BaseModel]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
+        r"""Runs inference of OpenAI chat completion in async mode.
+
+        Args:
+            messages (List[OpenAIMessage]): Message list with the chat history
+                in OpenAI API format.
+            response_format (Optional[Type[BaseModel]]): The format of the
+                response.
+            tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
+                use for the request.
+
+        Returns:
+            Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
+                `ChatCompletion` in the non-stream mode, or
+                `AsyncStream[ChatCompletionChunk]` in the stream mode.
+        """
+        response_format = response_format or self.model_config_dict.get(
+            "response_format", None
+        )
+        if response_format:
+            return await self._arequest_parse(messages, response_format, tools)
+        else:
+            return await self._arequest_chat_completion(messages, tools)
+
+    def _request_chat_completion(
+        self,
+        messages: List[OpenAIMessage],
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+        request_config = self.model_config_dict.copy()
+
+        if tools:
+            request_config["tools"] = tools
+
+        request_config = self._sanitize_config(request_config)
+
+        return self._client.chat.completions.create(
+            messages=messages,
+            model=self.model_type,
+            **request_config,
+        )
+
+    async def _arequest_chat_completion(
+        self,
+        messages: List[OpenAIMessage],
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
+        request_config = self.model_config_dict.copy()
+
+        if tools:
+            request_config["tools"] = tools
+
+        request_config = self._sanitize_config(request_config)
+
+        return await self._async_client.chat.completions.create(
+            messages=messages,
+            model=self.model_type,
+            **request_config,
+        )
+
+    def _request_parse(
+        self,
+        messages: List[OpenAIMessage],
+        response_format: Type[BaseModel],
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> ChatCompletion:
+        request_config = self.model_config_dict.copy()
+
+        request_config["response_format"] = response_format
+        request_config.pop("stream", None)
+        if tools is not None:
+            request_config["tools"] = tools
+
+        request_config = self._sanitize_config(request_config)
+
+        return self._client.beta.chat.completions.parse(
+            messages=messages,
+            model=self.model_type,
+            **request_config,
+        )
+
+    async def _arequest_parse(
+        self,
+        messages: List[OpenAIMessage],
+        response_format: Type[BaseModel],
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> ChatCompletion:
+        request_config = self.model_config_dict.copy()
+
+        request_config["response_format"] = response_format
+        request_config.pop("stream", None)
+        if tools is not None:
+            request_config["tools"] = tools

-        return self._to_chat_completion(response)
+        request_config = self._sanitize_config(request_config)

-        response = self._client.chat.completions.create(
+        return await self._async_client.beta.chat.completions.parse(
             messages=messages,
             model=self.model_type,
-            **self.model_config_dict,
+            **request_config,
         )
-        return response

     def check_model_config(self):
         r"""Check whether the model configuration contains any
camel/models/qwen_model.py

@@ -13,13 +13,15 @@
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

 import os
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Type, Union

-from openai import OpenAI, Stream
+from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
+from pydantic import BaseModel

 from camel.configs import QWEN_API_PARAMS, QwenConfig
 from camel.messages import OpenAIMessage
 from camel.models import BaseModelBackend
+from camel.models._utils import try_modify_message_with_format
 from camel.types import (
     ChatCompletion,
     ChatCompletionChunk,
@@ -81,10 +83,46 @@ class QwenModel(BaseModelBackend):
             api_key=self._api_key,
             base_url=self._url,
         )
+        self._async_client = AsyncOpenAI(
+            timeout=180,
+            max_retries=3,
+            api_key=self._api_key,
+            base_url=self._url,
+        )
+
+    async def _arun(
+        self,
+        messages: List[OpenAIMessage],
+        response_format: Optional[Type[BaseModel]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
+        r"""Runs inference of Qwen chat completion.
+
+        Args:
+            messages (List[OpenAIMessage]): Message list with the chat history
+                in OpenAI API format.
+
+        Returns:
+            Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
+                `ChatCompletion` in the non-stream mode, or
+                `AsyncStream[ChatCompletionChunk]` in the stream mode.
+        """
+        request_config = self._prepare_request(
+            messages, response_format, tools
+        )
+
+        response = await self._async_client.chat.completions.create(
+            messages=messages,
+            model=self.model_type,
+            **request_config,
+        )
+        return response

-    def run(
+    def _run(
         self,
         messages: List[OpenAIMessage],
+        response_format: Optional[Type[BaseModel]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
     ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
         r"""Runs inference of Qwen chat completion.

@@ -97,13 +135,32 @@ class QwenModel(BaseModelBackend):
                 `ChatCompletion` in the non-stream mode, or
                 `Stream[ChatCompletionChunk]` in the stream mode.
         """
+        request_config = self._prepare_request(
+            messages, response_format, tools
+        )
+
         response = self._client.chat.completions.create(
             messages=messages,
             model=self.model_type,
-            **self.model_config_dict,
+            **request_config,
         )
         return response

+    def _prepare_request(
+        self,
+        messages: List[OpenAIMessage],
+        response_format: Optional[Type[BaseModel]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> Dict[str, Any]:
+        request_config = self.model_config_dict.copy()
+        if tools:
+            request_config["tools"] = tools
+        elif response_format:
+            try_modify_message_with_format(messages[-1], response_format)
+            request_config["response_format"] = {"type": "json_object"}
+
+        return request_config
+
     @property
     def token_counter(self) -> BaseTokenCounter:
         r"""Initialize the token counter for the model backend.
camel/models/reka_model.py

@@ -11,7 +11,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
+
+from pydantic import BaseModel

 from camel.configs import REKA_API_PARAMS, RekaConfig
 from camel.messages import OpenAIMessage
@@ -70,7 +72,7 @@ class RekaModel(BaseModelBackend):
         url: Optional[str] = None,
         token_counter: Optional[BaseTokenCounter] = None,
     ) -> None:
-        from reka.client import Reka
+        from reka.client import AsyncReka, Reka

         if model_config_dict is None:
             model_config_dict = RekaConfig().as_dict()
@@ -80,6 +82,9 @@ class RekaModel(BaseModelBackend):
             model_type, model_config_dict, api_key, url, token_counter
         )
         self._client = Reka(api_key=self._api_key, base_url=self._url)
+        self._async_client = AsyncReka(
+            api_key=self._api_key, base_url=self._url
+        )

     def _convert_reka_to_openai_response(
         self, response: 'ChatResponse'
@@ -117,6 +122,8 @@ class RekaModel(BaseModelBackend):
     def _convert_openai_to_reka_messages(
         self,
         messages: List[OpenAIMessage],
+        response_format: Optional[Type[BaseModel]] = None,
+        tools: Optional[List[str]] = None,
     ) -> List["ChatMessage"]:
         r"""Converts OpenAI API messages to Reka API messages.

@@ -173,9 +180,52 @@ class RekaModel(BaseModelBackend):
             )
         return self._token_counter

-    def run(
+    async def _arun(
+        self,
+        messages: List[OpenAIMessage],
+        response_format: Optional[Type[BaseModel]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> ChatCompletion:
+        r"""Runs inference of Mistral chat completion.
+
+        Args:
+            messages (List[OpenAIMessage]): Message list with the chat history
+                in OpenAI API format.
+
+        Returns:
+            ChatCompletion.
+        """
+        reka_messages = self._convert_openai_to_reka_messages(messages)
+
+        response = await self._async_client.chat.create(
+            messages=reka_messages,
+            model=self.model_type,
+            **self.model_config_dict,
+        )
+
+        openai_response = self._convert_reka_to_openai_response(response)
+
+        # Add AgentOps LLM Event tracking
+        if LLMEvent:
+            llm_event = LLMEvent(
+                thread_id=openai_response.id,
+                prompt=" ".join(
+                    [message.get("content") for message in messages]  # type: ignore[misc]
+                ),
+                prompt_tokens=openai_response.usage.input_tokens,  # type: ignore[union-attr]
+                completion=openai_response.choices[0].message.content,
+                completion_tokens=openai_response.usage.output_tokens,  # type: ignore[union-attr]
+                model=self.model_type,
+            )
+            record(llm_event)
+
+        return openai_response
+
+    def _run(
         self,
         messages: List[OpenAIMessage],
+        response_format: Optional[Type[BaseModel]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
     ) -> ChatCompletion:
         r"""Runs inference of Mistral chat completion.
 
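Both the Reka and SambaNova backends wrap their responses in the same optional AgentOps telemetry guard: the `if LLMEvent:` check makes tracking a silent no-op when the integration is unavailable. The sketch below illustrates only that guard pattern with stand-in types; the real `LLMEvent` and `record` come from camel-ai's AgentOps integration, which is not part of this hunk.

```python
# Minimal sketch of the optional-telemetry guard used in _arun above. The
# stand-in classes are illustrative; in camel-ai, LLMEvent/record come from
# the AgentOps integration and LLMEvent is falsy when it is not installed.
from dataclasses import dataclass


@dataclass
class _FakeLLMEvent:  # stand-in for the AgentOps LLMEvent
    thread_id: str
    prompt: str
    completion: str


def _fake_record(event: _FakeLLMEvent) -> None:  # stand-in for record()
    print(f"tracked {event.thread_id}: {len(event.completion)} chars")


LLMEvent = _FakeLLMEvent  # set to None to simulate AgentOps being absent
record = _fake_record

if LLMEvent:  # tracking silently no-ops when LLMEvent is None
    record(LLMEvent(thread_id="chatcmpl-1", prompt="Hi", completion="Hello!"))
```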
camel/models/samba_model.py

@@ -15,10 +15,11 @@ import json
 import os
 import time
 import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Type, Union

 import httpx
-from openai import OpenAI, Stream
+from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
+from pydantic import BaseModel

 from camel.configs import (
     SAMBA_CLOUD_API_PARAMS,
@@ -105,6 +106,12 @@ class SambaModel(BaseModelBackend):
             base_url=self._url,
             api_key=self._api_key,
         )
+        self._async_client = AsyncOpenAI(
+            timeout=180,
+            max_retries=3,
+            base_url=self._url,
+            api_key=self._api_key,
+        )

     @property
     def token_counter(self) -> BaseTokenCounter:
@@ -148,8 +155,35 @@ class SambaModel(BaseModelBackend):
                 " SambaNova service"
             )

-    def run(  # type: ignore[misc]
-        self, messages: List[OpenAIMessage]
+    async def _arun(  # type: ignore[misc]
+        self,
+        messages: List[OpenAIMessage],
+        response_format: Optional[Type[BaseModel]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
+        r"""Runs SambaNova's service.
+
+        Args:
+            messages (List[OpenAIMessage]): Message list with the chat history
+                in OpenAI API format.
+
+        Returns:
+            Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
+                `ChatCompletion` in the non-stream mode, or
+                `AsyncStream[ChatCompletionChunk]` in the stream mode.
+        """
+        if "tools" in self.model_config_dict:
+            del self.model_config_dict["tools"]
+        if self.model_config_dict.get("stream") is True:
+            return await self._arun_streaming(messages)
+        else:
+            return await self._arun_non_streaming(messages)
+
+    def _run(  # type: ignore[misc]
+        self,
+        messages: List[OpenAIMessage],
+        response_format: Optional[Type[BaseModel]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
     ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
         r"""Runs SambaNova's service.

@@ -398,3 +432,174 @@ class SambaModel(BaseModelBackend):
             bool: Whether the model is in stream mode.
         """
         return self.model_config_dict.get('stream', False)
+
+    async def _arun_streaming(
+        self, messages: List[OpenAIMessage]
+    ) -> AsyncStream[ChatCompletionChunk]:
+        r"""Handles streaming inference with SambaNova's API.
+
+        Args:
+            messages (List[OpenAIMessage]): A list of messages representing the
+                chat history in OpenAI API format.
+
+        Returns:
+            AsyncStream[ChatCompletionChunk]: A generator yielding
+                `ChatCompletionChunk` objects as they are received from the
+                API.
+
+        Raises:
+            RuntimeError: If the HTTP request fails.
+            ValueError: If the API doesn't support stream mode.
+        """
+        # Handle SambaNova's Cloud API
+        if self._url == "https://api.sambanova.ai/v1":
+            response = await self._async_client.chat.completions.create(
+                messages=messages,
+                model=self.model_type,
+                **self.model_config_dict,
+            )
+
+            # Add AgentOps LLM Event tracking
+            if LLMEvent:
+                llm_event = LLMEvent(
+                    thread_id=response.id,
+                    prompt=" ".join(
+                        [message.get("content") for message in messages]  # type: ignore[misc]
+                    ),
+                    prompt_tokens=response.usage.prompt_tokens,  # type: ignore[union-attr]
+                    completion=response.choices[0].message.content,
+                    completion_tokens=response.usage.completion_tokens,  # type: ignore[union-attr]
+                    model=self.model_type,
+                )
+                record(llm_event)
+
+            return response
+
+        elif self._url == "https://sambaverse.sambanova.ai/api/predict":
+            raise ValueError(
+                "https://sambaverse.sambanova.ai/api/predict doesn't support"
+                " stream mode"
+            )
+        raise RuntimeError(f"Unknown URL: {self._url}")
+
+    async def _arun_non_streaming(
+        self, messages: List[OpenAIMessage]
+    ) -> ChatCompletion:
+        r"""Handles non-streaming inference with SambaNova's API.
+
+        Args:
+            messages (List[OpenAIMessage]): A list of messages representing the
+                message in OpenAI API format.
+
+        Returns:
+            ChatCompletion: A `ChatCompletion` object containing the complete
+                response from the API.
+
+        Raises:
+            RuntimeError: If the HTTP request fails.
+            ValueError: If the JSON response cannot be decoded or is missing
+                expected data.
+        """
+        # Handle SambaNova's Cloud API
+        if self._url == "https://api.sambanova.ai/v1":
+            response = await self._async_client.chat.completions.create(
+                messages=messages,
+                model=self.model_type,
+                **self.model_config_dict,
+            )
+
+            # Add AgentOps LLM Event tracking
+            if LLMEvent:
+                llm_event = LLMEvent(
+                    thread_id=response.id,
+                    prompt=" ".join(
+                        [message.get("content") for message in messages]  # type: ignore[misc]
+                    ),
+                    prompt_tokens=response.usage.prompt_tokens,  # type: ignore[union-attr]
+                    completion=response.choices[0].message.content,
+                    completion_tokens=response.usage.completion_tokens,  # type: ignore[union-attr]
+                    model=self.model_type,
+                )
+                record(llm_event)
+
+            return response
+
+        # Handle SambaNova's Sambaverse API
+        else:
+            headers = {
+                "Content-Type": "application/json",
+                "key": str(self._api_key),
+                "modelName": self.model_type,
+            }
+
+            data = {
+                "instance": json.dumps(
+                    {
+                        "conversation_id": str(uuid.uuid4()),
+                        "messages": messages,
+                    }
+                ),
+                "params": {
+                    "do_sample": {"type": "bool", "value": "true"},
+                    "max_tokens_to_generate": {
+                        "type": "int",
+                        "value": str(self.model_config_dict.get("max_tokens")),
+                    },
+                    "process_prompt": {"type": "bool", "value": "true"},
+                    "repetition_penalty": {
+                        "type": "float",
+                        "value": str(
+                            self.model_config_dict.get("repetition_penalty")
+                        ),
+                    },
+                    "return_token_count_only": {
+                        "type": "bool",
+                        "value": "false",
+                    },
+                    "select_expert": {
+                        "type": "str",
+                        "value": self.model_type.split("/")[1],
+                    },
+                    "stop_sequences": {
+                        "type": "str",
+                        "value": self.model_config_dict.get("stop_sequences"),
+                    },
+                    "temperature": {
+                        "type": "float",
+                        "value": str(
+                            self.model_config_dict.get("temperature")
+                        ),
+                    },
+                    "top_k": {
+                        "type": "int",
+                        "value": str(self.model_config_dict.get("top_k")),
+                    },
+                    "top_p": {
+                        "type": "float",
+                        "value": str(self.model_config_dict.get("top_p")),
+                    },
+                },
+            }
+
+            try:
+                # Send the request and handle the response
+                with httpx.Client() as client:
+                    response = client.post(
+                        self._url,  # type: ignore[arg-type]
+                        headers=headers,
+                        json=data,
+                    )
+
+                raw_text = response.text
+                # Split the string into two dictionaries
+                dicts = raw_text.split("}\n{")
+
+                # Keep only the last dictionary
+                last_dict = "{" + dicts[-1]
+
+                # Parse the dictionary
+                last_dict = json.loads(last_dict)
+                return self._sambaverse_to_openai_response(last_dict)  # type: ignore[arg-type]
+
+            except httpx.HTTPStatusError:
+                raise RuntimeError(f"HTTP request failed: {raw_text}")
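
The Sambaverse branch of `_arun_non_streaming` above parses the raw HTTP body by splitting concatenated JSON objects on `"}\n{"` and keeping only the last one. The standalone sketch below replays that parsing step on a fabricated two-object payload; the payload contents are invented for illustration.

```python
# Standalone illustration of the Sambaverse response parsing used in
# _arun_non_streaming above. The raw_text payload is made up; the
# split/keep-last logic mirrors the diff.
import json

raw_text = '{"status": "in_progress"}\n{"result": {"completion": "Hello!"}}'

# Split the concatenated JSON objects and keep only the last one.
dicts = raw_text.split("}\n{")
last_dict = json.loads("{" + dicts[-1])

print(last_dict["result"]["completion"])  # Hello!
```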