camel-ai 0.1.5.7__py3-none-any.whl → 0.1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (44)
  1. camel/__init__.py +1 -1
  2. camel/agents/chat_agent.py +2 -2
  3. camel/agents/critic_agent.py +1 -1
  4. camel/agents/deductive_reasoner_agent.py +4 -4
  5. camel/agents/embodied_agent.py +1 -1
  6. camel/agents/knowledge_graph_agent.py +2 -2
  7. camel/agents/role_assignment_agent.py +1 -1
  8. camel/agents/search_agent.py +4 -5
  9. camel/agents/task_agent.py +5 -5
  10. camel/configs/__init__.py +9 -0
  11. camel/configs/groq_config.py +119 -0
  12. camel/configs/mistral_config.py +81 -0
  13. camel/configs/ollama_config.py +1 -1
  14. camel/configs/vllm_config.py +103 -0
  15. camel/embeddings/__init__.py +2 -0
  16. camel/embeddings/mistral_embedding.py +89 -0
  17. camel/interpreters/__init__.py +2 -0
  18. camel/interpreters/ipython_interpreter.py +167 -0
  19. camel/models/__init__.py +8 -0
  20. camel/models/anthropic_model.py +7 -2
  21. camel/models/azure_openai_model.py +152 -0
  22. camel/models/base_model.py +5 -1
  23. camel/models/gemini_model.py +14 -2
  24. camel/models/groq_model.py +131 -0
  25. camel/models/litellm_model.py +10 -4
  26. camel/models/mistral_model.py +169 -0
  27. camel/models/model_factory.py +30 -3
  28. camel/models/ollama_model.py +5 -2
  29. camel/models/open_source_model.py +11 -3
  30. camel/models/openai_model.py +7 -2
  31. camel/models/stub_model.py +4 -4
  32. camel/models/vllm_model.py +138 -0
  33. camel/models/zhipuai_model.py +7 -3
  34. camel/prompts/__init__.py +2 -2
  35. camel/prompts/task_prompt_template.py +4 -4
  36. camel/prompts/{descripte_video_prompt.py → video_description_prompt.py} +1 -1
  37. camel/retrievers/auto_retriever.py +2 -2
  38. camel/storages/graph_storages/neo4j_graph.py +5 -0
  39. camel/types/enums.py +152 -35
  40. camel/utils/__init__.py +2 -0
  41. camel/utils/token_counting.py +148 -40
  42. {camel_ai-0.1.5.7.dist-info → camel_ai-0.1.6.0.dist-info}/METADATA +42 -3
  43. {camel_ai-0.1.5.7.dist-info → camel_ai-0.1.6.0.dist-info}/RECORD +44 -35
  44. {camel_ai-0.1.5.7.dist-info → camel_ai-0.1.6.0.dist-info}/WHEEL +0 -0
camel/interpreters/ipython_interpreter.py ADDED
@@ -0,0 +1,167 @@
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+
+ import queue
+ import re
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+ from camel.interpreters.base import BaseInterpreter
+ from camel.interpreters.interpreter_error import InterpreterError
+
+ if TYPE_CHECKING:
+     from jupyter_client import BlockingKernelClient, KernelManager
+
+ TIMEOUT = 30
+
+
+ class JupyterKernelInterpreter(BaseInterpreter):
+     r"""A class for executing code strings in a Jupyter kernel.
+
+     Args:
+         require_confirm (bool, optional): If `True`, prompt the user before
+             running code strings for security. Defaults to `True`.
+         print_stdout (bool, optional): If `True`, print the standard
+             output of the executed code. Defaults to `False`.
+         print_stderr (bool, optional): If `True`, print the standard error
+             of the executed code. Defaults to `True`.
+     """
+
+     def __init__(
+         self,
+         require_confirm: bool = True,
+         print_stdout: bool = False,
+         print_stderr: bool = True,
+     ) -> None:
+         self.require_confirm = require_confirm
+         self.print_stdout = print_stdout
+         self.print_stderr = print_stderr
+
+         self.kernel_manager: Optional[KernelManager] = None
+         self.client: Optional[BlockingKernelClient] = None
+
+     def __del__(self) -> None:
+         r"""Clean up the kernel and client."""
+         if self.kernel_manager:
+             self.kernel_manager.shutdown_kernel()
+         if self.client:
+             self.client.stop_channels()
+
+     def _initialize_if_needed(self) -> None:
+         r"""Initialize the kernel manager and client if they are not
+         already initialized.
+         """
+         if self.kernel_manager is not None:
+             return
+
+         from jupyter_client.manager import start_new_kernel
+
+         self.kernel_manager, self.client = start_new_kernel()
+
+     @staticmethod
+     def _clean_ipython_output(output: str) -> str:
+         r"""Remove ANSI escape sequences from the output."""
+         ansi_escape = re.compile(r'\x1B[@-_][0-?]*[ -/]*[@-~]')
+         return ansi_escape.sub('', output)
+
+     def _execute(self, code: str, timeout: float) -> str:
+         r"""Execute the code in the Jupyter kernel and return the result."""
+         if not self.kernel_manager or not self.client:
+             raise InterpreterError("Jupyter client is not initialized.")
+
+         self.client.execute(code)
+         outputs = []
+         while True:
+             try:
+                 msg = self.client.get_iopub_msg(timeout=timeout)
+                 msg_content = msg["content"]
+                 msg_type = msg.get("msg_type", None)
+
+                 if msg_content.get("execution_state", None) == "idle":
+                     break
+
+                 if msg_type == "error":
+                     traceback = "\n".join(msg_content["traceback"])
+                     outputs.append(traceback)
+                 elif msg_type == "stream":
+                     outputs.append(msg_content["text"])
+                 elif msg_type in ["execute_result", "display_data"]:
+                     outputs.append(msg_content["data"]["text/plain"])
+                     if "image/png" in msg_content["data"]:
+                         outputs.append(
+                             f"\n![image](data:image/png;base64,{msg_content['data']['image/png']})\n"
+                         )
+             except queue.Empty:
+                 outputs.append("Time out")
+                 break
+             except Exception as e:
+                 outputs.append(f"Exception occurred: {e!s}")
+                 break
+
+         exec_result = "\n".join(outputs)
+         return self._clean_ipython_output(exec_result)
+
+     def run(self, code: str, code_type: str) -> str:
+         r"""Executes the given code in the Jupyter kernel.
+
+         Args:
+             code (str): The code string to execute.
+             code_type (str): The type of code to execute (e.g., 'python',
+                 'bash').
+
+         Returns:
+             str: A string containing the captured result of the
+                 executed code.
+
+         Raises:
+             InterpreterError: If an error occurs during code execution.
+         """
+         self._initialize_if_needed()
+
+         if code_type == "bash":
+             code = f"%%bash\n({code})"
+         try:
+             result = self._execute(code, timeout=TIMEOUT)
+         except Exception as e:
+             raise InterpreterError(f"Execution failed: {e!s}")
+
+         return result
+
+     def supported_code_types(self) -> List[str]:
+         r"""Provides the code types supported by the interpreter.
+
+         Returns:
+             List[str]: Supported code types.
+         """
+         return ["python", "bash"]
+
+     def update_action_space(self, action_space: Dict[str, Any]) -> None:
+         r"""Updates the action space for the interpreter.
+
+         Args:
+             action_space (Dict[str, Any]): A dictionary representing the
+                 new or updated action space.
+
+         Raises:
+             RuntimeError: Always raised, because `JupyterKernelInterpreter`
+                 does not support updating the action space.
+         """
+         raise RuntimeError(
+             "JupyterKernelInterpreter doesn't support `action_space`."
+         )
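
For context, a minimal usage sketch of the new interpreter (my illustration, not part of the diff). It assumes `jupyter_client` and `ipykernel` are installed, and that the two-line change to camel/interpreters/__init__.py re-exports the class:

    from camel.interpreters import JupyterKernelInterpreter

    interpreter = JupyterKernelInterpreter(require_confirm=False)

    # The kernel persists across calls, so state carries over.
    print(interpreter.run("x = 21\nprint(x * 2)", code_type="python"))  # 42
    print(interpreter.run("print(x + 1)", code_type="python"))          # 22

    # Bash is wrapped in a %%bash cell magic by run().
    print(interpreter.run("echo hello", code_type="bash"))              # hello

Because `_initialize_if_needed()` lazily starts a single kernel on the first `run()` call, repeated calls are cheap and share interpreter state.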
camel/models/__init__.py CHANGED
@@ -12,9 +12,12 @@
  # limitations under the License.
  # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
  from .anthropic_model import AnthropicModel
+ from .azure_openai_model import AzureOpenAIModel
  from .base_model import BaseModelBackend
  from .gemini_model import GeminiModel
+ from .groq_model import GroqModel
  from .litellm_model import LiteLLMModel
+ from .mistral_model import MistralModel
  from .model_factory import ModelFactory
  from .nemotron_model import NemotronModel
  from .ollama_model import OllamaModel
@@ -22,12 +25,16 @@ from .open_source_model import OpenSourceModel
  from .openai_audio_models import OpenAIAudioModels
  from .openai_model import OpenAIModel
  from .stub_model import StubModel
+ from .vllm_model import VLLMModel
  from .zhipuai_model import ZhipuAIModel

  __all__ = [
      'BaseModelBackend',
      'OpenAIModel',
+     'AzureOpenAIModel',
      'AnthropicModel',
+     'MistralModel',
+     'GroqModel',
      'StubModel',
      'ZhipuAIModel',
      'OpenSourceModel',
@@ -36,5 +43,6 @@ __all__ = [
      'OpenAIAudioModels',
      'NemotronModel',
      'OllamaModel',
+     'VLLMModel',
      'GeminiModel',
  ]
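
Since the export list is part of this diff, a quick smoke test of the new surface is just an import (my example, not part of the release):

    from camel.models import (
        AzureOpenAIModel,
        GroqModel,
        MistralModel,
        VLLMModel,
    )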
camel/models/anthropic_model.py CHANGED
@@ -36,6 +36,7 @@ class AnthropicModel(BaseModelBackend):
          model_config_dict: Dict[str, Any],
          api_key: Optional[str] = None,
          url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
      ) -> None:
          r"""Constructor for Anthropic backend.

@@ -48,12 +49,16 @@ class AnthropicModel(BaseModelBackend):
                  Anthropic service. (default: :obj:`None`)
              url (Optional[str]): The url to the Anthropic service. (default:
                  :obj:`None`)
+             token_counter (Optional[BaseTokenCounter]): Token counter to use
+                 for the model. If not provided, `AnthropicTokenCounter` will
+                 be used.
          """
-         super().__init__(model_type, model_config_dict, api_key, url)
+         super().__init__(
+             model_type, model_config_dict, api_key, url, token_counter
+         )
          self._api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
          self._url = url or os.environ.get("ANTHROPIC_API_BASE_URL")
          self.client = Anthropic(api_key=self._api_key, base_url=self._url)
-         self._token_counter: Optional[BaseTokenCounter] = None

      def _convert_response_from_anthropic_to_openai(self, response):
          # openai ^1.0.0 format, reference openai/types/chat/chat_completion.py
camel/models/azure_openai_model.py ADDED
@@ -0,0 +1,152 @@
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ import os
+ from typing import Any, Dict, List, Optional, Union
+
+ from openai import AzureOpenAI, Stream
+
+ from camel.configs import OPENAI_API_PARAMS
+ from camel.messages import OpenAIMessage
+ from camel.models.base_model import BaseModelBackend
+ from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
+ from camel.utils import BaseTokenCounter, OpenAITokenCounter, api_keys_required
+
+
+ class AzureOpenAIModel(BaseModelBackend):
+     r"""Azure OpenAI API in a unified BaseModelBackend interface.
+     Doc: https://learn.microsoft.com/en-us/azure/ai-services/openai/
+     """
+
+     def __init__(
+         self,
+         model_type: ModelType,
+         model_config_dict: Dict[str, Any],
+         api_key: Optional[str] = None,
+         url: Optional[str] = None,
+         api_version: Optional[str] = None,
+         azure_deployment_name: Optional[str] = None,
+     ) -> None:
+         r"""Constructor for Azure OpenAI backend.
+
+         Args:
+             model_type (ModelType): Model for which a backend is created,
+                 one of GPT_* series.
+             model_config_dict (Dict[str, Any]): A dictionary that will
+                 be fed into openai.ChatCompletion.create().
+             api_key (Optional[str]): The API key for authenticating with the
+                 OpenAI service. (default: :obj:`None`)
+             url (Optional[str]): The url to the OpenAI service. (default:
+                 :obj:`None`)
+             api_version (Optional[str]): The api version for the model.
+             azure_deployment_name (Optional[str]): The deployment name you
+                 chose when you deployed an Azure model. (default:
+                 :obj:`None`)
+         """
+         super().__init__(model_type, model_config_dict, api_key, url)
+         self._url = url or os.environ.get("AZURE_OPENAI_ENDPOINT")
+         self._api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY")
+         self.api_version = api_version or os.environ.get("AZURE_API_VERSION")
+         self.azure_deployment_name = azure_deployment_name or os.environ.get(
+             "AZURE_DEPLOYMENT_NAME"
+         )
+
+         if self._url is None:
+             raise ValueError(
+                 "Must provide either the `url` argument "
+                 "or `AZURE_OPENAI_ENDPOINT` environment variable."
+             )
+         if self._api_key is None:
+             raise ValueError(
+                 "Must provide either the `api_key` argument "
+                 "or `AZURE_OPENAI_API_KEY` environment variable."
+             )
+         if self.api_version is None:
+             raise ValueError(
+                 "Must provide either the `api_version` argument "
+                 "or `AZURE_API_VERSION` environment variable."
+             )
+         if self.azure_deployment_name is None:
+             raise ValueError(
+                 "Must provide either the `azure_deployment_name` argument "
+                 "or `AZURE_DEPLOYMENT_NAME` environment variable."
+             )
+         self.model = str(self.azure_deployment_name)
+
+         self._client = AzureOpenAI(
+             azure_endpoint=str(self._url),
+             azure_deployment=self.azure_deployment_name,
+             api_version=self.api_version,
+             api_key=self._api_key,
+             timeout=60,
+             max_retries=3,
+         )
+         self._token_counter: Optional[BaseTokenCounter] = None
+
+     @property
+     def token_counter(self) -> BaseTokenCounter:
+         r"""Initialize the token counter for the model backend.
+
+         Returns:
+             BaseTokenCounter: The token counter following the model's
+                 tokenization style.
+         """
+         if not self._token_counter:
+             self._token_counter = OpenAITokenCounter(self.model_type)
+         return self._token_counter
+
+     @api_keys_required("AZURE_OPENAI_API_KEY", "AZURE_API_VERSION")
+     def run(
+         self,
+         messages: List[OpenAIMessage],
+     ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+         r"""Runs inference of Azure OpenAI chat completion.
+
+         Args:
+             messages (List[OpenAIMessage]): Message list with the chat
+                 history in OpenAI API format.
+
+         Returns:
+             Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+                 `ChatCompletion` in the non-stream mode, or
+                 `Stream[ChatCompletionChunk]` in the stream mode.
+         """
+         response = self._client.chat.completions.create(
+             messages=messages,
+             model=self.model,
+             **self.model_config_dict,
+         )
+         return response
+
+     def check_model_config(self):
+         r"""Check whether the model configuration contains any
+         unexpected arguments to Azure OpenAI API.
+
+         Raises:
+             ValueError: If the model configuration dictionary contains any
+                 unexpected arguments to Azure OpenAI API.
+         """
+         for param in self.model_config_dict:
+             if param not in OPENAI_API_PARAMS:
+                 raise ValueError(
+                     f"Unexpected argument `{param}` is "
+                     "input into Azure OpenAI model backend."
+                 )
+
+     @property
+     def stream(self) -> bool:
+         r"""Returns whether the model is in stream mode, which sends
+         partial results each time.
+
+         Returns:
+             bool: Whether the model is in stream mode.
+         """
+         return self.model_config_dict.get("stream", False)
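
A minimal construction sketch for the new backend (my example, not from the diff); the endpoint, key, version, and deployment values are placeholders, and `ModelType.GPT_4` is assumed to be a valid enum member:

    import os

    from camel.models import AzureOpenAIModel
    from camel.types import ModelType

    # The constructor falls back to these four environment variables when
    # the corresponding arguments are omitted (placeholder values shown).
    os.environ["AZURE_OPENAI_ENDPOINT"] = "https://my-resource.openai.azure.com"
    os.environ["AZURE_OPENAI_API_KEY"] = "..."
    os.environ["AZURE_API_VERSION"] = "2024-02-01"
    os.environ["AZURE_DEPLOYMENT_NAME"] = "my-gpt4-deployment"

    model = AzureOpenAIModel(
        model_type=ModelType.GPT_4,
        model_config_dict={"temperature": 0.2},  # checked against OPENAI_API_PARAMS
    )

    response = model.run([{"role": "user", "content": "Say hello."}])
    print(response.choices[0].message.content)

Note that requests are routed to the deployment, not to `model_type`: `run()` passes `self.model` (the deployment name) as the `model` argument.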
camel/models/base_model.py CHANGED
@@ -32,6 +32,7 @@ class BaseModelBackend(ABC):
          model_config_dict: Dict[str, Any],
          api_key: Optional[str] = None,
          url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
      ) -> None:
          r"""Constructor for the model backend.

@@ -41,13 +42,16 @@ class BaseModelBackend(ABC):
              api_key (Optional[str]): The API key for authenticating with the
                  model service.
              url (Optional[str]): The url to the model service.
+             token_counter (Optional[BaseTokenCounter]): Token counter to use
+                 for the model. If not provided, `OpenAITokenCounter` will
+                 be used.
          """
          self.model_type = model_type
-
          self.model_config_dict = model_config_dict
          self._api_key = api_key
          self._url = url
          self.check_model_config()
+         self._token_counter = token_counter

      @property
      @abstractmethod
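
The practical effect of this base-class change is dependency injection: any backend can now be handed an explicit counter instead of lazily building its default. A hedged sketch (my example; `ModelType.CLAUDE_3_SONNET` is assumed to be one of the Anthropic members in the expanded enums):

    import os

    from camel.models import AnthropicModel
    from camel.types import ModelType
    from camel.utils import OpenAITokenCounter

    os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."  # placeholder

    # Inject an OpenAI-style counter instead of the backend's default
    # AnthropicTokenCounter; any BaseTokenCounter implementation works.
    model = AnthropicModel(
        model_type=ModelType.CLAUDE_3_SONNET,
        model_config_dict={},
        token_counter=OpenAITokenCounter(ModelType.GPT_3_5_TURBO),
    )

    print(type(model.token_counter).__name__)  # OpenAITokenCounter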
camel/models/gemini_model.py CHANGED
@@ -44,6 +44,7 @@ class GeminiModel(BaseModelBackend):
          model_config_dict: Dict[str, Any],
          api_key: Optional[str] = None,
          url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
      ) -> None:
          r"""Constructor for Gemini backend.

@@ -54,17 +55,22 @@ class GeminiModel(BaseModelBackend):
              api_key (Optional[str]): The API key for authenticating with the
                  gemini service. (default: :obj:`None`)
              url (Optional[str]): The url to the gemini service.
+             token_counter (Optional[BaseTokenCounter]): Token counter to use
+                 for the model. If not provided, `GeminiTokenCounter` will be
+                 used.
          """
          import os

          import google.generativeai as genai
          from google.generativeai.types.generation_types import GenerationConfig

-         super().__init__(model_type, model_config_dict, api_key, url)
+         super().__init__(
+             model_type, model_config_dict, api_key, url, token_counter
+         )
          self._api_key = api_key or os.environ.get("GOOGLE_API_KEY")
          genai.configure(api_key=self._api_key)
          self._client = genai.GenerativeModel(self.model_type.value)
-         self._token_counter: Optional[BaseTokenCounter] = None
+
          keys = list(self.model_config_dict.keys())
          generation_config_dict = {
              k: self.model_config_dict.pop(k)
@@ -78,6 +84,12 @@ class GeminiModel(BaseModelBackend):

      @property
      def token_counter(self) -> BaseTokenCounter:
+         r"""Initialize the token counter for the model backend.
+
+         Returns:
+             BaseTokenCounter: The token counter following the model's
+                 tokenization style.
+         """
          if not self._token_counter:
              self._token_counter = GeminiTokenCounter(self.model_type)
          return self._token_counter
camel/models/groq_model.py ADDED
@@ -0,0 +1,131 @@
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ import os
+ from typing import Any, Dict, List, Optional, Union
+
+ from openai import OpenAI, Stream
+
+ from camel.configs import GROQ_API_PARAMS
+ from camel.messages import OpenAIMessage
+ from camel.models import BaseModelBackend
+ from camel.types import (
+     ChatCompletion,
+     ChatCompletionChunk,
+     ModelType,
+ )
+ from camel.utils import (
+     BaseTokenCounter,
+     OpenAITokenCounter,
+     api_keys_required,
+ )
+
+
+ class GroqModel(BaseModelBackend):
+     r"""LLM API served by Groq in a unified BaseModelBackend interface."""
+
+     def __init__(
+         self,
+         model_type: ModelType,
+         model_config_dict: Dict[str, Any],
+         api_key: Optional[str] = None,
+         url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
+     ) -> None:
+         r"""Constructor for Groq backend.
+
+         Args:
+             model_type (ModelType): Model for which a backend is created.
+             model_config_dict (Dict[str, Any]): A dictionary of parameters
+                 for the model configuration.
+             api_key (Optional[str]): The API key for authenticating with
+                 the Groq service. (default: :obj:`None`)
+             url (Optional[str]): The url to the Groq service. (default:
+                 :obj:`None`)
+             token_counter (Optional[BaseTokenCounter]): Token counter to
+                 use for the model. If not provided,
+                 `OpenAITokenCounter(ModelType.GPT_3_5_TURBO)` will be used.
+         """
+         super().__init__(
+             model_type, model_config_dict, api_key, url, token_counter
+         )
+         self._url = url or "https://api.groq.com/openai/v1"
+         self._api_key = api_key or os.environ.get("GROQ_API_KEY")
+         self._client = OpenAI(
+             timeout=60,
+             max_retries=3,
+             api_key=self._api_key,
+             base_url=self._url,
+         )
+         self._token_counter = token_counter
+
+     @property
+     def token_counter(self) -> BaseTokenCounter:
+         r"""Initialize the token counter for the model backend.
+
+         Returns:
+             BaseTokenCounter: The token counter following the model's
+                 tokenization style.
+         """
+         if not self._token_counter:
+             self._token_counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
+         return self._token_counter
+
+     @api_keys_required("GROQ_API_KEY")
+     def run(
+         self,
+         messages: List[OpenAIMessage],
+     ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+         r"""Runs inference of Groq chat completion.
+
+         Args:
+             messages (List[OpenAIMessage]): Message list with the chat
+                 history in OpenAI API format.
+
+         Returns:
+             Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+                 `ChatCompletion` in the non-stream mode, or
+                 `Stream[ChatCompletionChunk]` in the stream mode.
+         """
+         response = self._client.chat.completions.create(
+             messages=messages,
+             model=self.model_type.value,
+             **self.model_config_dict,
+         )
+
+         return response
+
+     def check_model_config(self):
+         r"""Check whether the model configuration contains any unexpected
+         arguments to the Groq API.
+
+         Raises:
+             ValueError: If the model configuration dictionary contains any
+                 unexpected arguments to the Groq API.
+         """
+         for param in self.model_config_dict:
+             if param not in GROQ_API_PARAMS:
+                 raise ValueError(
+                     f"Unexpected argument `{param}` is "
+                     "input into Groq model backend."
+                 )
+
+     @property
+     def stream(self) -> bool:
+         r"""Returns whether the model supports streaming. This backend does
+         not support streaming, so it always returns `False`.
+         """
+         return False
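
A quick usage sketch (my example, not from the diff). The enum member name is an assumption based on the expanded camel/types/enums.py in this release; substitute whichever Groq entry it actually defines:

    import os

    from camel.models import GroqModel
    from camel.types import ModelType

    os.environ["GROQ_API_KEY"] = "gsk_..."  # placeholder

    model = GroqModel(
        model_type=ModelType.GROQ_LLAMA_3_8_B,  # assumed member name
        model_config_dict={},  # validated against GROQ_API_PARAMS
    )

    response = model.run([{"role": "user", "content": "One-line joke, please."}])
    print(response.choices[0].message.content)

Because the backend points the standard OpenAI client at https://api.groq.com/openai/v1, the message format and response objects are the familiar OpenAI ones.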
camel/models/litellm_model.py CHANGED
@@ -16,7 +16,7 @@ from typing import Any, Dict, List, Optional
  from camel.configs import LITELLM_API_PARAMS
  from camel.messages import OpenAIMessage
  from camel.types import ChatCompletion
- from camel.utils import LiteLLMTokenCounter
+ from camel.utils import BaseTokenCounter, LiteLLMTokenCounter


  class LiteLLMModel:
@@ -30,6 +30,7 @@ class LiteLLMModel:
          model_config_dict: Dict[str, Any],
          api_key: Optional[str] = None,
          url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
      ) -> None:
          r"""Constructor for LiteLLM backend.

@@ -42,11 +43,14 @@ class LiteLLMModel:
                  model service. (default: :obj:`None`)
              url (Optional[str]): The url to the model service. (default:
                  :obj:`None`)
+             token_counter (Optional[BaseTokenCounter]): Token counter to use
+                 for the model. If not provided, `LiteLLMTokenCounter` will
+                 be used.
          """
          self.model_type = model_type
          self.model_config_dict = model_config_dict
          self._client = None
-         self._token_counter: Optional[LiteLLMTokenCounter] = None
+         self._token_counter = token_counter
          self.check_model_config()
          self._url = url
          self._api_key = api_key
@@ -98,8 +102,10 @@ class LiteLLMModel:
                  tokenization style.
          """
          if not self._token_counter:
-             self._token_counter = LiteLLMTokenCounter(self.model_type)
-         return self._token_counter
+             self._token_counter = LiteLLMTokenCounter(  # type: ignore[assignment]
+                 self.model_type
+             )
+         return self._token_counter  # type: ignore[return-value]

      def run(
          self,