camel-ai 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (76)
  1. camel/__init__.py +6 -1
  2. camel/agents/chat_agent.py +87 -6
  3. camel/agents/deductive_reasoner_agent.py +4 -1
  4. camel/benchmarks/__init__.py +18 -0
  5. camel/benchmarks/base.py +152 -0
  6. camel/benchmarks/gaia.py +478 -0
  7. camel/configs/__init__.py +6 -0
  8. camel/configs/mistral_config.py +0 -3
  9. camel/configs/nvidia_config.py +70 -0
  10. camel/configs/ollama_config.py +4 -2
  11. camel/configs/sglang_config.py +71 -0
  12. camel/configs/vllm_config.py +10 -1
  13. camel/data_collector/__init__.py +19 -0
  14. camel/data_collector/alpaca_collector.py +127 -0
  15. camel/data_collector/base.py +211 -0
  16. camel/data_collector/sharegpt_collector.py +205 -0
  17. camel/datahubs/__init__.py +23 -0
  18. camel/datahubs/base.py +136 -0
  19. camel/datahubs/huggingface.py +433 -0
  20. camel/datahubs/models.py +22 -0
  21. camel/embeddings/vlm_embedding.py +4 -1
  22. camel/interpreters/__init__.py +2 -0
  23. camel/interpreters/docker_interpreter.py +7 -2
  24. camel/interpreters/e2b_interpreter.py +136 -0
  25. camel/interpreters/subprocess_interpreter.py +7 -2
  26. camel/loaders/__init__.py +3 -1
  27. camel/loaders/base_io.py +41 -41
  28. camel/loaders/firecrawl_reader.py +0 -3
  29. camel/logger.py +112 -0
  30. camel/messages/__init__.py +3 -1
  31. camel/messages/base.py +10 -7
  32. camel/messages/conversion/__init__.py +3 -1
  33. camel/messages/conversion/alpaca.py +122 -0
  34. camel/models/__init__.py +7 -0
  35. camel/models/anthropic_model.py +14 -4
  36. camel/models/base_model.py +28 -0
  37. camel/models/groq_model.py +1 -1
  38. camel/models/model_factory.py +6 -0
  39. camel/models/model_manager.py +212 -0
  40. camel/models/nvidia_model.py +141 -0
  41. camel/models/ollama_model.py +12 -0
  42. camel/models/openai_model.py +0 -25
  43. camel/models/reward/__init__.py +22 -0
  44. camel/models/reward/base_reward_model.py +58 -0
  45. camel/models/reward/evaluator.py +63 -0
  46. camel/models/reward/nemotron_model.py +112 -0
  47. camel/models/sglang_model.py +225 -0
  48. camel/models/vllm_model.py +1 -1
  49. camel/personas/persona_hub.py +2 -2
  50. camel/retrievers/vector_retriever.py +22 -5
  51. camel/schemas/openai_converter.py +2 -2
  52. camel/societies/babyagi_playing.py +4 -1
  53. camel/societies/workforce/role_playing_worker.py +2 -2
  54. camel/societies/workforce/single_agent_worker.py +2 -2
  55. camel/societies/workforce/workforce.py +3 -3
  56. camel/storages/object_storages/amazon_s3.py +2 -2
  57. camel/storages/object_storages/azure_blob.py +2 -2
  58. camel/storages/object_storages/google_cloud.py +2 -2
  59. camel/toolkits/__init__.py +5 -0
  60. camel/toolkits/code_execution.py +42 -4
  61. camel/toolkits/function_tool.py +41 -0
  62. camel/toolkits/human_toolkit.py +1 -0
  63. camel/toolkits/math_toolkit.py +47 -16
  64. camel/toolkits/meshy_toolkit.py +185 -0
  65. camel/toolkits/search_toolkit.py +154 -2
  66. camel/toolkits/stripe_toolkit.py +273 -0
  67. camel/toolkits/twitter_toolkit.py +3 -0
  68. camel/types/__init__.py +2 -0
  69. camel/types/enums.py +68 -10
  70. camel/utils/commons.py +22 -5
  71. camel/utils/token_counting.py +26 -11
  72. {camel_ai-0.2.10.dist-info → camel_ai-0.2.12.dist-info}/METADATA +13 -6
  73. {camel_ai-0.2.10.dist-info → camel_ai-0.2.12.dist-info}/RECORD +76 -51
  74. /camel/messages/conversion/{models.py → conversation_models.py} +0 -0
  75. {camel_ai-0.2.10.dist-info → camel_ai-0.2.12.dist-info}/LICENSE +0 -0
  76. {camel_ai-0.2.10.dist-info → camel_ai-0.2.12.dist-info}/WHEEL +0 -0
camel/models/anthropic_model.py:

```diff
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
 import os
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
 from camel.configs import ANTHROPIC_API_PARAMS, AnthropicConfig
 from camel.messages import OpenAIMessage
@@ -94,19 +94,29 @@ class AnthropicModel(BaseModelBackend):
             tokenization style.
         """
         if not self._token_counter:
-            self._token_counter = AnthropicTokenCounter()
+            self._token_counter = AnthropicTokenCounter(self.model_type)
         return self._token_counter
 
-    def count_tokens_from_prompt(self, prompt: str) -> int:
+    @dependencies_required('anthropic')
+    def count_tokens_from_prompt(
+        self, prompt: str, role: Literal["user", "assistant"]
+    ) -> int:
         r"""Count the number of tokens from a prompt.
 
         Args:
             prompt (str): The prompt string.
+            role (Literal["user", "assistant"]): The role of the message
+                sender, either "user" or "assistant".
 
         Returns:
             int: The number of tokens in the prompt.
         """
-        return self.client.count_tokens(prompt)
+        from anthropic.types.beta import BetaMessageParam
+
+        return self.client.beta.messages.count_tokens(
+            messages=[BetaMessageParam(content=prompt, role=role)],
+            model=self.model_type,
+        ).input_tokens
 
     @api_keys_required("ANTHROPIC_API_KEY")
     def run(
```
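Token counting now goes through Anthropic's beta count-tokens endpoint, which needs a model name and a message role, so callers must pass the sender role explicitly. A minimal sketch of the new call, assuming `ANTHROPIC_API_KEY` is set in the environment; the model choice is illustrative, not taken from this diff:

```python
from camel.models import AnthropicModel
from camel.types import ModelType

# Sketch only: assumes ANTHROPIC_API_KEY is exported and that the
# chosen model type is available on the account.
model = AnthropicModel(model_type=ModelType.CLAUDE_3_5_SONNET)

# New in 0.2.12: the message role is required alongside the prompt.
n_tokens = model.count_tokens_from_prompt("Hello, Claude!", role="user")
print(n_tokens)
```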
camel/models/base_model.py:

```diff
@@ -21,6 +21,7 @@ from camel.types import (
     ChatCompletion,
     ChatCompletionChunk,
     ModelType,
+    ParsedChatCompletion,
     UnifiedModelType,
 )
 from camel.utils import BaseTokenCounter
@@ -114,6 +115,33 @@ class BaseModelBackend(ABC):
         """
         return self.token_counter.count_tokens_from_messages(messages)
 
+    def _to_chat_completion(
+        self, response: ParsedChatCompletion
+    ) -> ChatCompletion:
+        if len(response.choices) > 1:
+            print("Warning: Multiple response choices detected")
+
+        choice = dict(
+            index=response.choices[0].index,
+            message={
+                "role": response.choices[0].message.role,
+                "content": response.choices[0].message.content,
+                "tool_calls": response.choices[0].message.tool_calls,
+                "parsed": response.choices[0].message.parsed,
+            },
+            finish_reason=response.choices[0].finish_reason,
+        )
+
+        obj = ChatCompletion.construct(
+            id=response.id,
+            choices=[choice],
+            created=response.created,
+            model=response.model,
+            object="chat.completion",
+            usage=response.usage,
+        )
+        return obj
+
     @property
     def token_limit(self) -> int:
         r"""Returns the maximum token limit for a given model.
```
camel/models/groq_model.py:

```diff
@@ -63,7 +63,7 @@ class GroqModel(BaseModelBackend):
             model_config_dict = GroqConfig().as_dict()
         api_key = api_key or os.environ.get("GROQ_API_KEY")
         url = url or os.environ.get(
-            "GROQ_API_BASE_URL" or "https://api.groq.com/openai/v1"
+            "GROQ_API_BASE_URL", "https://api.groq.com/openai/v1"
        )
         super().__init__(
             model_type, model_config_dict, api_key, url, token_counter
```
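This one-character change fixes a real bug: `"GROQ_API_BASE_URL" or "https://..."` is a boolean expression that short-circuits to the first (truthy) string, so `os.environ.get` was being called with no default at all:

```python
import os

# Before: the `or` evaluates first, so this is just
# os.environ.get("GROQ_API_BASE_URL") and yields None whenever the
# environment variable is unset.
os.environ.get("GROQ_API_BASE_URL" or "https://api.groq.com/openai/v1")

# After: the fallback URL is passed as a proper default argument.
os.environ.get("GROQ_API_BASE_URL", "https://api.groq.com/openai/v1")
```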
camel/models/__init__.py:

```diff
@@ -22,12 +22,14 @@ from camel.models.gemini_model import GeminiModel
 from camel.models.groq_model import GroqModel
 from camel.models.litellm_model import LiteLLMModel
 from camel.models.mistral_model import MistralModel
+from camel.models.nvidia_model import NvidiaModel
 from camel.models.ollama_model import OllamaModel
 from camel.models.openai_compatible_model import OpenAICompatibleModel
 from camel.models.openai_model import OpenAIModel
 from camel.models.qwen_model import QwenModel
 from camel.models.reka_model import RekaModel
 from camel.models.samba_model import SambaModel
+from camel.models.sglang_model import SGLangModel
 from camel.models.stub_model import StubModel
 from camel.models.togetherai_model import TogetherAIModel
 from camel.models.vllm_model import VLLMModel
```
camel/models/model_factory.py:

```diff
@@ -85,6 +87,8 @@ class ModelFactory:
             model_class = OllamaModel
         elif model_platform.is_vllm:
             model_class = VLLMModel
+        elif model_platform.is_sglang:
+            model_class = SGLangModel
         elif model_platform.is_openai_compatible_model:
             model_class = OpenAICompatibleModel
         elif model_platform.is_samba:
@@ -93,6 +97,8 @@ class ModelFactory:
             model_class = TogetherAIModel
         elif model_platform.is_litellm:
             model_class = LiteLLMModel
+        elif model_platform.is_nvidia:
+            model_class = NvidiaModel
 
         elif model_platform.is_openai and model_type.is_openai:
             model_class = OpenAIModel
```
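With these branches, both new backends become reachable through the factory. A sketch under stated assumptions: the `ModelPlatformType.SGLANG` and `ModelPlatformType.NVIDIA` members are presumed additions from this release's `camel/types/enums.py` changes, and the model names and local server URL are placeholders, not taken from the diff:

```python
from camel.models import ModelFactory
from camel.types import ModelPlatformType

# Hypothetical: routes to the new SGLangModel backend via a local server.
sglang_model = ModelFactory.create(
    model_platform=ModelPlatformType.SGLANG,
    model_type="meta-llama/Llama-3.1-8B-Instruct",
    url="http://127.0.0.1:30000/v1",
)

# Hypothetical: routes to the new NvidiaModel backend.
nvidia_model = ModelFactory.create(
    model_platform=ModelPlatformType.NVIDIA,
    model_type="meta/llama-3.1-8b-instruct",
)
```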
camel/models/model_manager.py (new file):

```diff
@@ -0,0 +1,212 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+
+import logging
+from itertools import cycle
+from random import choice
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Union,
+)
+
+from openai import Stream
+
+from camel.messages import OpenAIMessage
+from camel.models.base_model import BaseModelBackend
+from camel.types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    UnifiedModelType,
+)
+from camel.utils import BaseTokenCounter
+
+logger = logging.getLogger(__name__)
+
+
+class ModelProcessingError(Exception):
+    r"""Raised when an error occurs during model processing."""
+
+    pass
+
+
+class ModelManager:
+    r"""ModelManager choosing a model from provided list.
+    Models are picked according to defined strategy.
+
+    Args:
+        models(Union[BaseModelBackend, List[BaseModelBackend]]):
+            model backend or list of model backends
+            (e.g., model instances, APIs)
+        scheduling_strategy (str): name of function that defines how
+            to select the next model. (default: :str:`round_robin`)
+    """
+
+    def __init__(
+        self,
+        models: Union[BaseModelBackend, List[BaseModelBackend]],
+        scheduling_strategy: str = "round_robin",
+    ):
+        if isinstance(models, list):
+            self.models = models
+        else:
+            self.models = [models]
+        self.models_cycle = cycle(self.models)
+        self.current_model = self.models[0]
+
+        # Set the scheduling strategy; default is round-robin
+        try:
+            self.scheduling_strategy = getattr(self, scheduling_strategy)
+        except AttributeError:
+            logger.warning(
+                f"Provided strategy: {scheduling_strategy} is not implemented."
+                f"Using default 'round robin'"
+            )
+            self.scheduling_strategy = self.round_robin
+
+    @property
+    def model_type(self) -> UnifiedModelType:
+        r"""Return type of the current model.
+
+        Returns:
+            Union[ModelType, str]: Current model type.
+        """
+        return self.current_model.model_type
+
+    @property
+    def model_config_dict(self) -> Dict[str, Any]:
+        r"""Return model_config_dict of the current model.
+
+        Returns:
+            Dict[str, Any]: Config dictionary of the current model.
+        """
+        return self.current_model.model_config_dict
+
+    @model_config_dict.setter
+    def model_config_dict(self, model_config_dict: Dict[str, Any]):
+        r"""Set model_config_dict to the current model.
+
+        Args:
+            model_config_dict (Dict[str, Any]): Config dictionary to be set at
+                current model.
+        """
+        self.current_model.model_config_dict = model_config_dict
+
+    @property
+    def current_model_index(self) -> int:
+        r"""Return the index of current model in self.models list.
+
+        Returns:
+            int: index of current model in given list of models.
+        """
+        return self.models.index(self.current_model)
+
+    @property
+    def token_limit(self):
+        r"""Returns the maximum token limit for current model.
+
+        This method retrieves the maximum token limit either from the
+        `model_config_dict` or from the model's default token limit.
+
+        Returns:
+            int: The maximum token limit for the given model.
+        """
+        return self.current_model.token_limit
+
+    @property
+    def token_counter(self) -> BaseTokenCounter:
+        r"""Return token_counter of the current model.
+
+        Returns:
+            BaseTokenCounter: The token counter following the model's
+                tokenization style.
+        """
+        return self.current_model.token_counter
+
+    def add_strategy(self, name: str, strategy_fn: Callable):
+        r"""Add a scheduling strategy method provided by user in case when
+        none of existent strategies fits.
+        When custom strategy is provided, it will be set as
+        "self.scheduling_strategy" attribute.
+
+        Args:
+            name (str): The name of the strategy.
+            strategy_fn (Callable): The scheduling strategy function.
+        """
+        if not callable(strategy_fn):
+            raise ValueError("strategy_fn must be a callable function.")
+        setattr(self, name, strategy_fn.__get__(self))
+        self.scheduling_strategy = getattr(self, name)
+        logger.info(f"Custom strategy '{name}' added.")
+
+    # Strategies
+    def round_robin(self) -> BaseModelBackend:
+        r"""Return models one by one in simple round-robin fashion.
+
+        Returns:
+            BaseModelBackend for processing incoming messages.
+        """
+        return next(self.models_cycle)
+
+    def always_first(self) -> BaseModelBackend:
+        r"""Always return the first model from self.models.
+
+        Returns:
+            BaseModelBackend for processing incoming messages.
+        """
+        return self.models[0]
+
+    def random_model(self) -> BaseModelBackend:
+        r"""Return random model from self.models list.
+
+        Returns:
+            BaseModelBackend for processing incoming messages.
+        """
+        return choice(self.models)
+
+    def run(
+        self, messages: List[OpenAIMessage]
+    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+        r"""Process a list of messages by selecting a model based on
+        the scheduling strategy.
+        Sends the entire list of messages to the selected model,
+        and returns a single response.
+
+        Args:
+            messages (List[OpenAIMessage]): Message list with the chat
+                history in OpenAI API format.
+
+        Returns:
+            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+                `ChatCompletion` in the non-stream mode, or
+                `Stream[ChatCompletionChunk]` in the stream mode.
+        """
+        self.current_model = self.scheduling_strategy()
+
+        # Pass all messages to the selected model and get the response
+        try:
+            response = self.current_model.run(messages)
+        except Exception as exc:
+            logger.error(f"Error processing with model: {self.current_model}")
+            if self.scheduling_strategy == self.always_first:
+                self.scheduling_strategy = self.round_robin
+                logger.warning(
+                    "The scheduling strategy has been changed to 'round_robin'"
+                )
+                # Skip already used one
+                self.current_model = self.scheduling_strategy()
+            raise exc
+        return response
```
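A minimal usage sketch of the new manager, rotating requests across two backends in round-robin order; the model choices are illustrative:

```python
from camel.models import ModelFactory
from camel.models.model_manager import ModelManager
from camel.types import ModelPlatformType, ModelType

# Any BaseModelBackend instances work; two OpenAI backends shown here.
manager = ModelManager(
    models=[
        ModelFactory.create(
            model_platform=ModelPlatformType.OPENAI,
            model_type=ModelType.GPT_4O_MINI,
        ),
        ModelFactory.create(
            model_platform=ModelPlatformType.OPENAI,
            model_type=ModelType.GPT_4O,
        ),
    ],
    scheduling_strategy="round_robin",
)

messages = [{"role": "user", "content": "Hello!"}]
first = manager.run(messages)   # served by the first backend
second = manager.run(messages)  # served by the second backend
```

Note that `run()` selects the model before dispatching, and on failure under `always_first` it downgrades the strategy to `round_robin` and advances past the failing backend before re-raising.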
camel/models/nvidia_model.py (new file):

```diff
@@ -0,0 +1,141 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+
+import os
+from typing import Any, Dict, List, Optional, Union
+
+from openai import OpenAI, Stream
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionChunk,
+)
+
+from camel.configs import NVIDIA_API_PARAMS, NvidiaConfig
+from camel.messages import OpenAIMessage
+from camel.models import BaseModelBackend
+from camel.types import ModelType
+from camel.utils import BaseTokenCounter, OpenAITokenCounter, api_keys_required
+
+
+class NvidiaModel(BaseModelBackend):
+    r"""NVIDIA API in a unified BaseModelBackend interface.
+
+    Args:
+        model_type (Union[ModelType, str]): Model for which a backend is
+            created, one of NVIDIA series.
+        model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
+            that will be fed into:obj:`openai.ChatCompletion.create()`. If
+            :obj:`None`, :obj:`NvidiaConfig().as_dict()` will be used.
+            (default: :obj:`None`)
+        api_key (Optional[str], optional): The API key for authenticating with
+            the NVIDIA service. (default: :obj:`None`)
+        url (Optional[str], optional): The url to the NVIDIA service.
+            (default: :obj:`None`)
+        token_counter (Optional[BaseTokenCounter], optional): Token counter to
+            use for the model. If not provided, :obj:`OpenAITokenCounter(
+            ModelType.GPT_4)` will be used.
+            (default: :obj:`None`)
+    """
+
+    def __init__(
+        self,
+        model_type: Union[ModelType, str],
+        model_config_dict: Optional[Dict[str, Any]] = None,
+        api_key: Optional[str] = None,
+        url: Optional[str] = None,
+        token_counter: Optional[BaseTokenCounter] = None,
+    ) -> None:
+        if model_config_dict is None:
+            model_config_dict = NvidiaConfig().as_dict()
+        api_key = api_key or os.environ.get("NVIDIA_API_KEY")
+        url = url or os.environ.get(
+            "NVIDIA_API_BASE_URL", "https://integrate.api.nvidia.com/v1"
+        )
+        super().__init__(
+            model_type, model_config_dict, api_key, url, token_counter
+        )
+        self._client = OpenAI(
+            timeout=60,
+            max_retries=3,
+            api_key=self._api_key,
+            base_url=self._url,
+        )
+
+    @api_keys_required("NVIDIA_API_KEY")
+    def run(
+        self,
+        messages: List[OpenAIMessage],
+    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+        r"""Runs inference of NVIDIA chat completion.
+
+        Args:
+            messages (List[OpenAIMessage]): Message list with the chat history
+                in OpenAI API format.
+
+        Returns:
+            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+                `ChatCompletion` in the non-stream mode, or
+                `Stream[ChatCompletionChunk]` in the stream mode.
+        """
+
+        # Remove tool-related parameters if no tools are specified
+        config = dict(self.model_config_dict)
+        if not config.get('tools'):  # None or empty list
+            config.pop('tools', None)
+            config.pop('tool_choice', None)
+
+        response = self._client.chat.completions.create(
+            messages=messages,
+            model=self.model_type,
+            **config,
+        )
+        return response
+
+    @property
+    def token_counter(self) -> BaseTokenCounter:
+        r"""Initialize the token counter for the model backend.
+
+        Returns:
+            OpenAITokenCounter: The token counter following the model's
+                tokenization style.
+        """
+
+        if not self._token_counter:
+            self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
+        return self._token_counter
+
+    def check_model_config(self):
+        r"""Check whether the model configuration contains any
+        unexpected arguments to NVIDIA API.
+
+        Raises:
+            ValueError: If the model configuration dictionary contains any
+                unexpected arguments to NVIDIA API.
+        """
+        for param in self.model_config_dict:
+            if param not in NVIDIA_API_PARAMS:
+                raise ValueError(
+                    f"Unexpected argument `{param}` is "
+                    "input into NVIDIA model backend."
+                )
+
+    @property
+    def stream(self) -> bool:
+        r"""Returns whether the model is in stream mode, which sends partial
+        results each time.
+
+        Returns:
+            bool: Whether the model is in stream mode.
+        """
+        return self.model_config_dict.get('stream', False)
```
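Because the backend wraps NVIDIA's OpenAI-compatible endpoint, direct use looks like any other CAMEL backend. A sketch, assuming `NVIDIA_API_KEY` is set; the model name is a placeholder from NVIDIA's catalog, not taken from this diff:

```python
from camel.models.nvidia_model import NvidiaModel

# Assumes NVIDIA_API_KEY is exported; the model name is illustrative.
model = NvidiaModel(model_type="meta/llama-3.1-8b-instruct")

response = model.run(
    [{"role": "user", "content": "Summarize CAMEL-AI in one sentence."}]
)
print(response.choices[0].message.content)
```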
camel/models/ollama_model.py:

```diff
@@ -134,6 +134,18 @@ class OllamaModel(BaseModelBackend):
             `ChatCompletion` in the non-stream mode, or
             `Stream[ChatCompletionChunk]` in the stream mode.
         """
+        if self.model_config_dict.get("response_format"):
+            # stream is not supported in beta.chat.completions.parse
+            if "stream" in self.model_config_dict:
+                del self.model_config_dict["stream"]
+
+            response = self._client.beta.chat.completions.parse(
+                messages=messages,
+                model=self.model_type,
+                **self.model_config_dict,
+            )
+
+            return self._to_chat_completion(response)
 
         response = self._client.chat.completions.create(
             messages=messages,
```
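The new branch routes structured-output requests through the OpenAI SDK's parse endpoint and normalizes the result with the shared `_to_chat_completion` helper added to `BaseModelBackend` above. A sketch under stated assumptions: a local Ollama server exposing the OpenAI-compatible API, and an `OllamaConfig` that accepts a `response_format` entry (ollama_config.py is also touched in this diff); the model name and URL are illustrative:

```python
from pydantic import BaseModel

from camel.configs import OllamaConfig
from camel.models.ollama_model import OllamaModel


class Pet(BaseModel):
    name: str
    species: str


# Assumes an Ollama server on the default port; model name is a placeholder.
model = OllamaModel(
    model_type="llama3.2",
    model_config_dict=OllamaConfig(response_format=Pet).as_dict(),
    url="http://localhost:11434/v1",
)

response = model.run([{"role": "user", "content": "Invent a pet."}])
# The parsed Pydantic object rides along in the message's "parsed" field.
```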
camel/models/openai_model.py:

```diff
@@ -24,7 +24,6 @@ from camel.types import (
     ChatCompletion,
     ChatCompletionChunk,
     ModelType,
-    ParsedChatCompletion,
 )
 from camel.utils import (
     BaseTokenCounter,
@@ -148,30 +147,6 @@ class OpenAIModel(BaseModelBackend):
         )
         return response
 
-    def _to_chat_completion(
-        self, response: "ParsedChatCompletion"
-    ) -> ChatCompletion:
-        # TODO: Handle n > 1 or warn consumers it's not supported
-        choice = dict(
-            index=response.choices[0].index,
-            message={
-                "role": response.choices[0].message.role,
-                "content": response.choices[0].message.content,
-                "tool_calls": response.choices[0].message.tool_calls,
-            },
-            finish_reason=response.choices[0].finish_reason,
-        )
-
-        obj = ChatCompletion.construct(
-            id=response.id,
-            choices=[choice],
-            created=response.created,
-            model=response.model,
-            object="chat.completion",
-            usage=response.usage,
-        )
-        return obj
-
     def check_model_config(self):
         r"""Check whether the model configuration contains any
         unexpected arguments to OpenAI API.
```
camel/models/reward/__init__.py (new file):

```diff
@@ -0,0 +1,22 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+from .base_reward_model import BaseRewardModel
+from .evaluator import Evaluator
+from .nemotron_model import NemotronRewardModel
+
+__all__ = [
+    'BaseRewardModel',
+    'NemotronRewardModel',
+    'Evaluator',
+]
```
camel/models/reward/base_reward_model.py (new file):

```diff
@@ -0,0 +1,58 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, Union
+
+from camel.types import ModelType
+
+
+class BaseRewardModel(ABC):
+    r"""Abstract base class for reward models. Reward models are used to
+    evaluate messages and return scores based on different criteria.
+
+    Subclasses should implement the 'evaluate' and 'get_scores_types' methods.
+    """
+
+    def __init__(
+        self,
+        model_type: Union[ModelType, str],
+        api_key: Optional[str] = None,
+        url: Optional[str] = None,
+    ) -> None:
+        self.model_type = model_type
+        self.api_key = api_key
+        self.url = url
+
+    @abstractmethod
+    def evaluate(self, messages: List[Dict[str, str]]) -> Dict[str, float]:
+        r"""Evaluate the messages and return scores based on different
+        criteria.
+
+        Args:
+            messages (List[Dict[str, str]]): A list of messages where each
+                message is a dictionary with 'role' and 'content'.
+
+        Returns:
+            Dict[str, float]: A dictionary mapping score types to their values.
+        """
+        pass
+
+    @abstractmethod
+    def get_scores_types(self) -> List[str]:
+        r"""Get the list of score types that the reward model can return.
+
+        Returns:
+            List[str]: A list of score types that the reward model can return.
+        """
+        pass
```
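To illustrate the contract, here is a toy subclass; the scoring heuristic is invented for the example and is not part of this release:

```python
from typing import Dict, List

from camel.models.reward import BaseRewardModel


class BrevityRewardModel(BaseRewardModel):
    """Toy reward model that favors short conversations."""

    def evaluate(self, messages: List[Dict[str, str]]) -> Dict[str, float]:
        # Map total content length onto (0, 1]; shorter scores higher.
        total = sum(len(m["content"]) for m in messages)
        return {"brevity": 1.0 / (1.0 + total / 100.0)}

    def get_scores_types(self) -> List[str]:
        return ["brevity"]


model = BrevityRewardModel(model_type="brevity-heuristic")
```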
camel/models/reward/evaluator.py (new file):

```diff
@@ -0,0 +1,63 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+from typing import Dict, List
+
+from camel.models.reward import BaseRewardModel
+
+
+class Evaluator:
+    r"""Evaluator class to evaluate messages using a reward model and filter
+    data based on the scores.
+
+    Args:
+        reward_model (BaseRewardModel): A reward model to evaluate messages.
+    """
+
+    def __init__(self, reward_model: BaseRewardModel):
+        self.reward_model = reward_model
+
+    def evaluate(self, messages: List[Dict[str, str]]) -> Dict[str, float]:
+        r"""Evaluate the messages using the reward model.
+
+        Args:
+            messages (List[Dict[str, str]]): A list of messages where each
+                message is a dictionary with 'role' and 'content'.
+
+        Returns:
+            Dict[str, float]: A dictionary mapping score types to their values.
+        """
+        scores = self.reward_model.evaluate(messages)
+        return scores
+
+    def filter_data(
+        self, messages: List[Dict[str, str]], thresholds: Dict[str, float]
+    ) -> bool:
+        r"""Filter messages based on the scores.
+
+        Args:
+            messages (List[Dict[str, str]]): A list of messages where each
+                message is a dictionary with 'role' and 'content'.
+            thresholds (Dict[str, float]): A dictionary mapping score types to
+                their values.
+
+        Returns:
+            bool: A boolean indicating whether the messages pass the filter.
+        """
+        scores = self.evaluate(messages)
+        for score_type, threshold in thresholds.items():
+            if score_type not in scores:
+                raise ValueError(f"Score type {score_type} not found.")
+            if scores.get(score_type, 0) < threshold:
+                return False
+        return True
```
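Putting the pieces together with the `NemotronRewardModel` that ships in this release; the model identifier, score type, and threshold below are illustrative assumptions (the reward model accepts `Union[ModelType, str]`, so a raw identifier string is used), and an `NVIDIA_API_KEY` is presumed to be set:

```python
from camel.models.reward import Evaluator, NemotronRewardModel

# Illustrative setup; the identifier and score names are assumptions,
# not taken from this diff.
reward_model = NemotronRewardModel(
    model_type="nvidia/nemotron-4-340b-reward",
)
evaluator = Evaluator(reward_model=reward_model)

messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
]

scores = evaluator.evaluate(messages)
keep = evaluator.filter_data(messages, thresholds={"helpfulness": 2.5})
```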