camel-ai 0.2.66__py3-none-any.whl → 0.2.68__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (68)
  1. camel/__init__.py +1 -1
  2. camel/configs/__init__.py +3 -0
  3. camel/configs/qianfan_config.py +85 -0
  4. camel/environments/__init__.py +12 -0
  5. camel/environments/rlcards_env.py +860 -0
  6. camel/interpreters/docker/Dockerfile +2 -5
  7. camel/loaders/firecrawl_reader.py +4 -4
  8. camel/memories/blocks/vectordb_block.py +8 -1
  9. camel/memories/context_creators/score_based.py +123 -19
  10. camel/models/__init__.py +2 -0
  11. camel/models/aiml_model.py +8 -0
  12. camel/models/anthropic_model.py +122 -2
  13. camel/models/aws_bedrock_model.py +8 -0
  14. camel/models/azure_openai_model.py +14 -5
  15. camel/models/base_model.py +4 -0
  16. camel/models/cohere_model.py +9 -2
  17. camel/models/crynux_model.py +8 -0
  18. camel/models/deepseek_model.py +8 -0
  19. camel/models/gemini_model.py +8 -0
  20. camel/models/groq_model.py +8 -0
  21. camel/models/internlm_model.py +8 -0
  22. camel/models/litellm_model.py +5 -0
  23. camel/models/lmstudio_model.py +14 -1
  24. camel/models/mistral_model.py +15 -1
  25. camel/models/model_factory.py +6 -0
  26. camel/models/modelscope_model.py +8 -0
  27. camel/models/moonshot_model.py +8 -0
  28. camel/models/nemotron_model.py +17 -2
  29. camel/models/netmind_model.py +8 -0
  30. camel/models/novita_model.py +8 -0
  31. camel/models/nvidia_model.py +8 -0
  32. camel/models/ollama_model.py +8 -0
  33. camel/models/openai_compatible_model.py +23 -5
  34. camel/models/openai_model.py +21 -4
  35. camel/models/openrouter_model.py +8 -0
  36. camel/models/ppio_model.py +8 -0
  37. camel/models/qianfan_model.py +104 -0
  38. camel/models/qwen_model.py +8 -0
  39. camel/models/reka_model.py +18 -3
  40. camel/models/samba_model.py +17 -3
  41. camel/models/sglang_model.py +20 -5
  42. camel/models/siliconflow_model.py +8 -0
  43. camel/models/stub_model.py +8 -1
  44. camel/models/togetherai_model.py +8 -0
  45. camel/models/vllm_model.py +7 -0
  46. camel/models/volcano_model.py +14 -1
  47. camel/models/watsonx_model.py +4 -1
  48. camel/models/yi_model.py +8 -0
  49. camel/models/zhipuai_model.py +8 -0
  50. camel/societies/workforce/prompts.py +71 -22
  51. camel/societies/workforce/role_playing_worker.py +3 -8
  52. camel/societies/workforce/single_agent_worker.py +37 -9
  53. camel/societies/workforce/task_channel.py +25 -20
  54. camel/societies/workforce/utils.py +104 -14
  55. camel/societies/workforce/worker.py +98 -16
  56. camel/societies/workforce/workforce.py +1289 -101
  57. camel/societies/workforce/workforce_logger.py +613 -0
  58. camel/tasks/task.py +16 -5
  59. camel/toolkits/__init__.py +2 -0
  60. camel/toolkits/code_execution.py +1 -1
  61. camel/toolkits/playwright_mcp_toolkit.py +2 -1
  62. camel/toolkits/pptx_toolkit.py +4 -4
  63. camel/types/enums.py +32 -0
  64. camel/types/unified_model_type.py +5 -0
  65. {camel_ai-0.2.66.dist-info → camel_ai-0.2.68.dist-info}/METADATA +4 -3
  66. {camel_ai-0.2.66.dist-info → camel_ai-0.2.68.dist-info}/RECORD +68 -64
  67. {camel_ai-0.2.66.dist-info → camel_ai-0.2.68.dist-info}/WHEEL +0 -0
  68. {camel_ai-0.2.66.dist-info → camel_ai-0.2.68.dist-info}/licenses/LICENSE +0 -0

camel/interpreters/docker/Dockerfile CHANGED
@@ -55,11 +55,8 @@ RUN curl -fsSL https://install.python-poetry.org | python3.10 - && \
 # Upgrade pip and install base Python packages
 RUN python3.10 -m pip install --upgrade pip setuptools wheel
 
-# Install uv
-RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
-    mv /root/.local/bin/uv /usr/local/bin/uv && \
-    mv /root/.local/bin/uvx /usr/local/bin/uvx && \
-    chmod +x /usr/local/bin/uv /usr/local/bin/uvx
+# Install uv using pip instead of the shell script
+RUN pip install uv
 
 # Setup working directory
 WORKDIR /workspace

camel/loaders/firecrawl_reader.py CHANGED
@@ -98,8 +98,8 @@ class Firecrawl:
     def scrape(
         self,
         url: str,
-        params: Optional[Dict[str, Any]] = None,
-    ) -> Dict:
+        params: Optional[Dict[str, str]] = None,
+    ) -> Dict[str, str]:
         r"""To scrape a single URL. This function supports advanced scraping
         by setting different parameters and returns the full scraped data as a
         dictionary.
@@ -108,11 +108,11 @@ class Firecrawl:
 
         Args:
             url (str): The URL to read.
-            params (Optional[Dict[str, Any]]): Additional parameters for the
+            params (Optional[Dict[str, str]]): Additional parameters for the
                 scrape request.
 
         Returns:
-            Dict: The scraped data.
+            Dict[str, str]: The scraped data.
 
         Raises:
             RuntimeError: If the scrape process fails.
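
The `scrape` signature now promises `str` values in both `params` and the returned dictionary. A minimal usage sketch against the new signature (the API key comes from the environment, and the parameter key and printed field are illustrative assumptions, not taken from this diff):

    from camel.loaders import Firecrawl

    firecrawl = Firecrawl()  # assumes FIRECRAWL_API_KEY is set in the environment

    # Both params and the result are now Dict[str, str] rather than Dict[str, Any]
    data = firecrawl.scrape(
        "https://www.camel-ai.org",
        params={"onlyMainContent": "true"},  # hypothetical parameter key
    )
    print(data.get("markdown", ""))  # hypothetical result field
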

camel/memories/blocks/vectordb_block.py CHANGED
@@ -89,13 +89,20 @@ class VectorDBBlock(MemoryBlock):
             records (List[MemoryRecord]): Memory records to be added to the
                 memory.
         """
+        # Filter out records with empty message content
+        valid_records = [
+            record
+            for record in records
+            if record.message.content and record.message.content.strip()
+        ]
+
         v_records = [
             VectorRecord(
                 vector=self.embedding.embed(record.message.content),
                 payload=record.to_dict(),
                 id=str(record.uuid),
             )
-            for record in records
+            for record in valid_records
         ]
         self.storage.add(v_records)
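
The new guard drops records whose content is `None`, empty, or whitespace-only before any embedding call is made, avoiding wasted requests on empty strings. The filter predicate in isolation, as a quick sketch (sample values are illustrative):

    from typing import Optional

    def has_content(content: Optional[str]) -> bool:
        # Mirrors the new filter: rejects None, "", and whitespace-only text
        return bool(content and content.strip())

    samples = ["hello", "", "   ", None, "\n\t", "tool output"]
    print([s for s in samples if has_content(s)])  # ['hello', 'tool output']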
 
camel/memories/context_creators/score_based.py CHANGED
@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 from pydantic import BaseModel
 
@@ -74,6 +74,8 @@ class ScoreBasedContextCreator(BaseContextCreator):
         3. Final output maintains chronological order and in history memory,
            the score of each message decreases according to keep_rate. The
            newer the message, the higher the score.
+        4. Tool calls and their responses are kept together to maintain
+           API compatibility
 
         Args:
             records (List[ContextRecord]): List of context records with scores
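
The new point 4 matters because OpenAI-format chat APIs reject a `tool` message whose matching assistant tool call is missing from the context. Schematically, this is the pair that must survive truncation together (an illustrative message sequence, not taken from this diff):

    messages = [
        {"role": "user", "content": "What's the weather in Paris?"},
        # The assistant turn that issues the tool call ...
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"city": "Paris"}',
                    },
                }
            ],
        },
        # ... and the tool result, which is only valid while the call above remains
        {"role": "tool", "tool_call_id": "call_123", "content": "18°C, clear"},
    ]
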
@@ -126,12 +128,17 @@ class ScoreBasedContextCreator(BaseContextCreator):
         )
 
         # ======================
-        # 3. Token Calculation
+        # 3. Tool Call Relationship Mapping
+        # ======================
+        tool_call_groups = self._group_tool_calls_and_responses(regular_units)
+
+        # ======================
+        # 4. Token Calculation
         # ======================
         total_tokens = system_tokens + sum(u.num_tokens for u in regular_units)
 
         # ======================
-        # 4. Early Return if Within Limit
+        # 5. Early Return if Within Limit
         # ======================
         if total_tokens <= self.token_limit:
             sorted_units = sorted(
@@ -140,7 +147,7 @@ class ScoreBasedContextCreator(BaseContextCreator):
             return self._assemble_output(sorted_units, system_unit)
 
         # ======================
-        # 5. Truncation Logic
+        # 6. Truncation Logic with Tool Call Awareness
         # ======================
         logger.warning(
             f"Context truncation required "
@@ -148,24 +155,12 @@ class ScoreBasedContextCreator(BaseContextCreator):
             f"pruning low-score messages."
         )
 
-        # Sort for truncation: high scores first, older messages first at same
-        # score
-        sorted_for_truncation = sorted(
-            regular_units, key=self._truncation_sort_key
+        remaining_units = self._truncate_with_tool_call_awareness(
+            regular_units, tool_call_groups, system_tokens
         )
 
-        # Reverse to process from lowest score (end of sorted list)
-        remaining_units = []
-        current_total = system_tokens
-
-        for unit in sorted_for_truncation:
-            potential_total = current_total + unit.num_tokens
-            if potential_total <= self.token_limit:
-                remaining_units.append(unit)
-                current_total = potential_total
-
         # ======================
-        # 6. Output Assembly
+        # 7. Output Assembly
         # ======================
 
         # In case system message is the only message in memory when sorted
@@ -180,6 +175,115 @@ class ScoreBasedContextCreator(BaseContextCreator):
         final_units = sorted(remaining_units, key=self._conversation_sort_key)
         return self._assemble_output(final_units, system_unit)
+
+    def _group_tool_calls_and_responses(
+        self, units: List[_ContextUnit]
+    ) -> Dict[str, List[_ContextUnit]]:
+        r"""Groups tool calls with their corresponding responses.
+
+        Args:
+            units (List[_ContextUnit]): List of context units to analyze.
+
+        Returns:
+            Dict[str, List[_ContextUnit]]: Mapping from tool_call_id to the
+                list of related units (tool call + responses).
+        """
+        tool_call_groups: Dict[str, List[_ContextUnit]] = {}
+
+        for unit in units:
+            message = unit.record.memory_record.message
+            backend_role = unit.record.memory_record.role_at_backend
+
+            # Check if this is a tool call message
+            if hasattr(message, 'func_name') and hasattr(
+                message, 'tool_call_id'
+            ):
+                tool_call_id = getattr(message, 'tool_call_id', None)
+                if tool_call_id:
+                    if tool_call_id not in tool_call_groups:
+                        tool_call_groups[tool_call_id] = []
+                    tool_call_groups[tool_call_id].append(unit)
+
+            # Check if this is a tool response message
+            elif backend_role == OpenAIBackendRole.FUNCTION:
+                tool_call_id = None
+                if hasattr(message, 'tool_call_id'):
+                    tool_call_id = getattr(message, 'tool_call_id', None)
+                elif hasattr(message, 'result') and hasattr(
+                    message, 'tool_call_id'
+                ):
+                    tool_call_id = getattr(message, 'tool_call_id', None)
+
+                if tool_call_id:
+                    if tool_call_id not in tool_call_groups:
+                        tool_call_groups[tool_call_id] = []
+                    tool_call_groups[tool_call_id].append(unit)
+
+        return tool_call_groups
+
+    def _truncate_with_tool_call_awareness(
+        self,
+        regular_units: List[_ContextUnit],
+        tool_call_groups: Dict[str, List[_ContextUnit]],
+        system_tokens: int,
+    ) -> List[_ContextUnit]:
+        r"""Truncates messages while preserving tool call-response pairs.
+
+        Args:
+            regular_units (List[_ContextUnit]): All regular message units.
+            tool_call_groups (Dict[str, List[_ContextUnit]]): Grouped tool
+                calls.
+            system_tokens (int): Tokens used by the system message.
+
+        Returns:
+            List[_ContextUnit]: Units that fit within the token limit.
+        """
+        # Create sets for quick lookup of tool call related units
+        tool_call_unit_ids = set()
+        for group in tool_call_groups.values():
+            for unit in group:
+                tool_call_unit_ids.add(unit.record.memory_record.uuid)
+
+        # Separate tool call groups and standalone units
+        standalone_units = [
+            u
+            for u in regular_units
+            if u.record.memory_record.uuid not in tool_call_unit_ids
+        ]
+
+        # Sort standalone units for truncation (high scores first)
+        standalone_units.sort(key=self._truncation_sort_key)
+
+        # Sort tool call groups by their best (highest) score
+        sorted_tool_groups = []
+        for _tool_call_id, group in tool_call_groups.items():
+            # Use the highest score in the group as the group's score
+            best_score = max(unit.record.score for unit in group)
+            latest_timestamp = max(unit.record.timestamp for unit in group)
+            group_tokens = sum(unit.num_tokens for unit in group)
+            sorted_tool_groups.append(
+                ((-best_score, -latest_timestamp), group, group_tokens)
+            )
+
+        sorted_tool_groups.sort(key=lambda x: x[0])
+
+        # Greedy selection to fit within the token limit
+        remaining_units = []
+        current_tokens = system_tokens
+
+        # First, try to include complete tool call groups
+        for _, group, group_tokens in sorted_tool_groups:
+            if current_tokens + group_tokens <= self.token_limit:
+                remaining_units.extend(group)
+                current_tokens += group_tokens
+
+        # Then, include standalone units
+        for unit in standalone_units:
+            if current_tokens + unit.num_tokens <= self.token_limit:
+                remaining_units.append(unit)
+                current_tokens += unit.num_tokens
+
+        return remaining_units
+
     def _extract_system_message(
         self, records: List[ContextRecord]
     ) -> Tuple[Optional[_ContextUnit], List[_ContextUnit]]:
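
The selection policy in `_truncate_with_tool_call_awareness` is a two-pass greedy fill: whole tool-call groups are admitted first in score order, then standalone messages fill the remaining budget. A toy sketch of the same loop (scores and token counts are made up; the real code works on `_ContextUnit` objects):

    # (score, tokens); a tool-call group is admitted whole or not at all
    tool_groups = [(0.9, 120), (0.4, 300)]
    standalone = [(0.8, 50), (0.7, 200), (0.2, 40)]
    limit, used, kept = 400, 0, []

    for score, tokens in sorted(tool_groups, key=lambda g: -g[0]):
        if used + tokens <= limit:  # never splits a call/response pair
            used += tokens
            kept.append(("group", score, tokens))

    for score, tokens in sorted(standalone, key=lambda u: -u[0]):
        if used + tokens <= limit:
            used += tokens
            kept.append(("msg", score, tokens))

    print(kept, used)
    # [('group', 0.9, 120), ('msg', 0.8, 50), ('msg', 0.7, 200)] 370

Note the trade-off this accepts: a large group that misses the budget is skipped entirely and smaller, lower-score items may take its place, in exchange for never emitting an orphaned tool call or response.
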
camel/models/__init__.py CHANGED
@@ -41,6 +41,7 @@ from .openai_compatible_model import OpenAICompatibleModel
 from .openai_model import OpenAIModel
 from .openrouter_model import OpenRouterModel
 from .ppio_model import PPIOModel
+from .qianfan_model import QianfanModel
 from .qwen_model import QwenModel
 from .reka_model import RekaModel
 from .samba_model import SambaModel
@@ -97,5 +98,6 @@ __all__ = [
     'VolcanoModel',
     'LMStudioModel',
     'WatsonXModel',
+    'QianfanModel',
     'CrynuxModel',
 ]
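
With the export in place, the new backend can be constructed directly. A sketch of what that presumably looks like (the model id, config fields, and credential handling are assumptions inferred from the new qianfan_config.py and qianfan_model.py files, not confirmed by this diff):

    from camel.configs import QianfanConfig
    from camel.models import QianfanModel

    # Assumes a Qianfan API credential is available in the environment
    model = QianfanModel(
        model_type="ernie-4.0-8k",  # hypothetical Qianfan model id
        model_config_dict=QianfanConfig(temperature=0.2).as_dict(),
    )
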
camel/models/aiml_model.py CHANGED
@@ -46,6 +46,10 @@ class AIMLModel(OpenAICompatibleModel):
             API calls. If not provided, will fall back to the MODEL_TIMEOUT
             environment variable or default to 180 seconds.
             (default: :obj:`None`)
+        max_retries (int, optional): Maximum number of retries for API calls.
+            (default: :obj:`3`)
+        **kwargs (Any): Additional arguments to pass to the client
+            initialization.
     """
 
     @api_keys_required([("api_key", "AIML_API_KEY")])
@@ -57,6 +61,8 @@ class AIMLModel(OpenAICompatibleModel):
         url: Optional[str] = None,
         token_counter: Optional[BaseTokenCounter] = None,
         timeout: Optional[float] = None,
+        max_retries: int = 3,
+        **kwargs: Any,
     ) -> None:
         if model_config_dict is None:
             model_config_dict = AIMLConfig().as_dict()
@@ -73,6 +79,8 @@ class AIMLModel(OpenAICompatibleModel):
             url=url,
             token_counter=token_counter,
             timeout=timeout,
+            max_retries=max_retries,
+            **kwargs,
         )
 
     def check_model_config(self):
camel/models/anthropic_model.py CHANGED
@@ -12,11 +12,14 @@
 # limitations under the License.
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
 import os
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
+
+from openai import AsyncStream, Stream
 
 from camel.configs import ANTHROPIC_API_PARAMS, AnthropicConfig
+from camel.messages import OpenAIMessage
 from camel.models.openai_compatible_model import OpenAICompatibleModel
-from camel.types import ModelType
+from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
 from camel.utils import (
     AnthropicTokenCounter,
     BaseTokenCounter,
@@ -25,6 +28,47 @@ from camel.utils import (
 )
 
 
+def strip_trailing_whitespace_from_messages(
+    messages: List[OpenAIMessage],
+) -> List[OpenAIMessage]:
+    r"""Strip trailing whitespace from all message contents in a list of
+    messages. This is necessary because the Anthropic API doesn't allow
+    trailing whitespace in message content.
+
+    Args:
+        messages (List[OpenAIMessage]): List of messages to process.
+
+    Returns:
+        List[OpenAIMessage]: The processed messages with trailing whitespace
+            removed.
+    """
+    if not messages:
+        return messages
+
+    # Create a deep copy to avoid modifying the original messages
+    processed_messages = [dict(msg) for msg in messages]
+
+    # Process each message
+    for msg in processed_messages:
+        if "content" in msg and msg["content"] is not None:
+            if isinstance(msg["content"], str):
+                msg["content"] = msg["content"].rstrip()
+            elif isinstance(msg["content"], list):
+                # Handle content that's a list of content parts (e.g., for
+                # multimodal content)
+                for i, part in enumerate(msg["content"]):
+                    if (
+                        isinstance(part, dict)
+                        and "text" in part
+                        and isinstance(part["text"], str)
+                    ):
+                        part["text"] = part["text"].rstrip()
+                    elif isinstance(part, str):
+                        msg["content"][i] = part.rstrip()
+
+    return processed_messages  # type: ignore[return-value]
+
+
 class AnthropicModel(OpenAICompatibleModel):
     r"""Anthropic API in a unified OpenAICompatibleModel interface.
 
@@ -46,6 +90,10 @@ class AnthropicModel(OpenAICompatibleModel):
             API calls. If not provided, will fall back to the MODEL_TIMEOUT
             environment variable or default to 180 seconds.
             (default: :obj:`None`)
+        max_retries (int, optional): Maximum number of retries for API calls.
+            (default: :obj:`3`)
+        **kwargs (Any): Additional arguments to pass to the client
+            initialization.
     """
 
     @api_keys_required(
@@ -62,6 +110,8 @@ class AnthropicModel(OpenAICompatibleModel):
         url: Optional[str] = None,
         token_counter: Optional[BaseTokenCounter] = None,
         timeout: Optional[float] = None,
+        max_retries: int = 3,
+        **kwargs: Any,
     ) -> None:
         if model_config_dict is None:
             model_config_dict = AnthropicConfig().as_dict()
@@ -79,8 +129,13 @@ class AnthropicModel(OpenAICompatibleModel):
             url=url,
             token_counter=token_counter,
             timeout=timeout,
+            max_retries=max_retries,
+            **kwargs,
         )
 
+        # Monkey patch the AnthropicTokenCounter to handle trailing whitespace
+        self._patch_anthropic_token_counter()
+
     @property
     def token_counter(self) -> BaseTokenCounter:
         r"""Initialize the token counter for the model backend.
@@ -107,3 +162,68 @@ class AnthropicModel(OpenAICompatibleModel):
                 f"Unexpected argument `{param}` is "
                 "input into Anthropic model backend."
             )
+
+    def _request_chat_completion(
+        self,
+        messages: List[OpenAIMessage],
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+        # Strip trailing whitespace from all message contents to prevent
+        # Anthropic API errors
+        processed_messages = strip_trailing_whitespace_from_messages(messages)
+
+        # Call the parent class method
+        return super()._request_chat_completion(processed_messages, tools)
+
+    async def _arequest_chat_completion(
+        self,
+        messages: List[OpenAIMessage],
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
+        # Strip trailing whitespace from all message contents to prevent
+        # Anthropic API errors
+        processed_messages = strip_trailing_whitespace_from_messages(messages)
+
+        # Call the parent class method
+        return await super()._arequest_chat_completion(
+            processed_messages, tools
+        )
+
+    def _patch_anthropic_token_counter(self):
+        r"""Monkey patch the AnthropicTokenCounter class to handle trailing
+        whitespace.
+
+        This patches the count_tokens_from_messages method to strip trailing
+        whitespace from message content before sending to the Anthropic API.
+        """
+        import functools
+
+        from anthropic.types import MessageParam
+
+        from camel.utils import AnthropicTokenCounter
+
+        original_count_tokens = (
+            AnthropicTokenCounter.count_tokens_from_messages
+        )
+
+        @functools.wraps(original_count_tokens)
+        def patched_count_tokens(self, messages):
+            # Process messages to remove trailing whitespace
+            processed_messages = strip_trailing_whitespace_from_messages(
+                messages
+            )
+
+            # Use the processed messages with the original method
+            return self.client.messages.count_tokens(
+                messages=[
+                    MessageParam(
+                        content=str(msg["content"]),
+                        role="user" if msg["role"] == "user" else "assistant",
+                    )
+                    for msg in processed_messages
+                ],
+                model=self.model,
+            ).input_tokens
+
+        # Apply the monkey patch
+        AnthropicTokenCounter.count_tokens_from_messages = patched_count_tokens
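
The helper is module-level, so its behavior is easy to check in isolation; it trims plain string content as well as the text parts of list-form (multimodal) content. A quick sketch (sample messages are illustrative):

    from camel.models.anthropic_model import (
        strip_trailing_whitespace_from_messages,
    )

    messages = [
        {"role": "user", "content": "Hello   \n"},
        {"role": "assistant", "content": [{"type": "text", "text": "Hi.  "}]},
    ]
    cleaned = strip_trailing_whitespace_from_messages(messages)
    print(cleaned[0]["content"])             # 'Hello'
    print(cleaned[1]["content"][0]["text"])  # 'Hi.'

One caveat visible in the implementation: only the top-level message dicts are copied, so text parts inside list-form content are trimmed in place and the caller's originals are modified.
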
camel/models/aws_bedrock_model.py CHANGED
@@ -50,6 +50,10 @@ class AWSBedrockModel(OpenAICompatibleModel):
             API calls. If not provided, will fall back to the MODEL_TIMEOUT
             environment variable or default to 180 seconds.
             (default: :obj:`None`)
+        max_retries (int, optional): Maximum number of retries for API calls.
+            (default: :obj:`3`)
+        **kwargs (Any): Additional arguments to pass to the client
+            initialization.
 
     References:
         https://docs.aws.amazon.com/bedrock/latest/APIReference/welcome.html
@@ -69,6 +73,8 @@ class AWSBedrockModel(OpenAICompatibleModel):
         url: Optional[str] = None,
         token_counter: Optional[BaseTokenCounter] = None,
         timeout: Optional[float] = None,
+        max_retries: int = 3,
+        **kwargs: Any,
     ) -> None:
         if model_config_dict is None:
             model_config_dict = BedrockConfig().as_dict()
@@ -84,6 +90,8 @@ class AWSBedrockModel(OpenAICompatibleModel):
             url=url,
             token_counter=token_counter,
             timeout=timeout,
+            max_retries=max_retries,
+            **kwargs,
         )
 
     async def _arun(
camel/models/azure_openai_model.py CHANGED
@@ -76,7 +76,10 @@ class AzureOpenAIModel(BaseModelBackend):
             API calls. If not provided, will fall back to the MODEL_TIMEOUT
             environment variable or default to 180 seconds.
             (default: :obj:`None`)
-
+        max_retries (int, optional): Maximum number of retries for API calls.
+            (default: :obj:`3`)
+        **kwargs (Any): Additional arguments to pass to the client
+            initialization.
 
     References:
         https://learn.microsoft.com/en-us/azure/ai-services/openai/
@@ -94,6 +97,8 @@ class AzureOpenAIModel(BaseModelBackend):
         azure_deployment_name: Optional[str] = None,
         azure_ad_token_provider: Optional["AzureADTokenProvider"] = None,
         azure_ad_token: Optional[str] = None,
+        max_retries: int = 3,
+        **kwargs: Any,
     ) -> None:
         if model_config_dict is None:
             model_config_dict = ChatGPTConfig().as_dict()
@@ -135,7 +140,8 @@ class AzureOpenAIModel(BaseModelBackend):
                 azure_ad_token=self._azure_ad_token,
                 azure_ad_token_provider=self.azure_ad_token_provider,
                 timeout=self._timeout,
-                max_retries=3,
+                max_retries=max_retries,
+                **kwargs,
             )
             self._async_client = LangfuseAsyncOpenAI(
                 azure_endpoint=str(self._url),
@@ -145,7 +151,8 @@ class AzureOpenAIModel(BaseModelBackend):
                 azure_ad_token=self._azure_ad_token,
                 azure_ad_token_provider=self.azure_ad_token_provider,
                 timeout=self._timeout,
-                max_retries=3,
+                max_retries=max_retries,
+                **kwargs,
             )
         else:
             self._client = AzureOpenAI(
@@ -156,7 +163,8 @@ class AzureOpenAIModel(BaseModelBackend):
                 azure_ad_token=self._azure_ad_token,
                 azure_ad_token_provider=self.azure_ad_token_provider,
                 timeout=self._timeout,
-                max_retries=3,
+                max_retries=max_retries,
+                **kwargs,
             )
 
             self._async_client = AsyncAzureOpenAI(
@@ -167,7 +175,8 @@ class AzureOpenAIModel(BaseModelBackend):
                 azure_ad_token=self._azure_ad_token,
                 azure_ad_token_provider=self.azure_ad_token_provider,
                 timeout=self._timeout,
-                max_retries=3,
+                max_retries=max_retries,
+                **kwargs,
             )
 
     @property
camel/models/base_model.py CHANGED
@@ -71,6 +71,8 @@ class BaseModelBackend(ABC, metaclass=ModelBackendMeta):
             :obj:`OpenAITokenCounter` will be used. (default: :obj:`None`)
         timeout (Optional[float], optional): The timeout value in seconds for
             API calls. (default: :obj:`None`)
+        max_retries (int, optional): Maximum number of retries
+            for API calls. (default: :obj:`3`)
     """
 
     def __init__(
@@ -81,6 +83,7 @@ class BaseModelBackend(ABC, metaclass=ModelBackendMeta):
         url: Optional[str] = None,
         token_counter: Optional[BaseTokenCounter] = None,
         timeout: Optional[float] = None,
+        max_retries: int = 3,
    ) -> None:
         self.model_type: UnifiedModelType = UnifiedModelType(model_type)
         if model_config_dict is None:
@@ -90,6 +93,7 @@ class BaseModelBackend(ABC, metaclass=ModelBackendMeta):
         self._url = url
         self._token_counter = token_counter
         self._timeout = timeout
+        self._max_retries = max_retries
         self.check_model_config()
 
     @property
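
This base-class change is the hook that the per-provider diffs above and below plug into: each backend now accepts `max_retries` plus arbitrary client kwargs and forwards them to its SDK client. A construction sketch via the factory (this assumes `ModelFactory.create` forwards extra keyword arguments, which the model_factory.py entry in the file list suggests but this diff does not show; the header is illustrative):

    from camel.models import ModelFactory
    from camel.types import ModelPlatformType, ModelType

    model = ModelFactory.create(
        model_platform=ModelPlatformType.OPENAI,
        model_type=ModelType.GPT_4O_MINI,
        max_retries=5,  # retried inside the OpenAI SDK client
        default_headers={"X-Trace-Id": "demo-123"},  # forwarded via **kwargs
    )
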
camel/models/cohere_model.py CHANGED
@@ -76,6 +76,8 @@ class CohereModel(BaseModelBackend):
             API calls. If not provided, will fall back to the MODEL_TIMEOUT
             environment variable or default to 180 seconds.
             (default: :obj:`None`)
+        **kwargs (Any): Additional arguments to pass to the client
+            initialization.
     """
 
     @api_keys_required(
@@ -91,6 +93,7 @@ class CohereModel(BaseModelBackend):
         url: Optional[str] = None,
         token_counter: Optional[BaseTokenCounter] = None,
         timeout: Optional[float] = None,
+        **kwargs: Any,
     ):
         import cohere
 
@@ -105,10 +108,14 @@ class CohereModel(BaseModelBackend):
             model_type, model_config_dict, api_key, url, token_counter, timeout
         )
         self._client = cohere.ClientV2(
-            timeout=self._timeout, api_key=self._api_key
+            timeout=self._timeout,
+            api_key=self._api_key,
+            **kwargs,
         )
         self._async_client = cohere.AsyncClientV2(
-            timeout=self._timeout, api_key=self._api_key
+            timeout=self._timeout,
+            api_key=self._api_key,
+            **kwargs,
         )
 
     def _to_openai_response(self, response: 'ChatResponse') -> ChatCompletion:
camel/models/crynux_model.py CHANGED
@@ -46,6 +46,10 @@ class CrynuxModel(OpenAICompatibleModel):
             API calls. If not provided, will fall back to the MODEL_TIMEOUT
             environment variable or default to 180 seconds.
             (default: :obj:`None`)
+        max_retries (int, optional): Maximum number of retries for API calls.
+            (default: :obj:`3`)
+        **kwargs (Any): Additional arguments to pass to the client
+            initialization.
     """
 
     @api_keys_required(
@@ -61,6 +65,8 @@ class CrynuxModel(OpenAICompatibleModel):
         url: Optional[str] = None,
         token_counter: Optional[BaseTokenCounter] = None,
         timeout: Optional[float] = None,
+        max_retries: int = 3,
+        **kwargs: Any,
     ) -> None:
         if model_config_dict is None:
             model_config_dict = CrynuxConfig().as_dict()
@@ -76,6 +82,8 @@ class CrynuxModel(OpenAICompatibleModel):
             url=url,
             token_counter=token_counter,
             timeout=timeout,
+            max_retries=max_retries,
+            **kwargs,
         )
 
     def check_model_config(self):
camel/models/deepseek_model.py CHANGED
@@ -78,6 +78,10 @@ class DeepSeekModel(OpenAICompatibleModel):
             API calls. If not provided, will fall back to the MODEL_TIMEOUT
             environment variable or default to 180 seconds.
             (default: :obj:`None`)
+        max_retries (int, optional): Maximum number of retries for API calls.
+            (default: :obj:`3`)
+        **kwargs (Any): Additional arguments to pass to the client
+            initialization.
 
     References:
         https://api-docs.deepseek.com/
@@ -96,6 +100,8 @@ class DeepSeekModel(OpenAICompatibleModel):
         url: Optional[str] = None,
         token_counter: Optional[BaseTokenCounter] = None,
         timeout: Optional[float] = None,
+        max_retries: int = 3,
+        **kwargs: Any,
     ) -> None:
         if model_config_dict is None:
             model_config_dict = DeepSeekConfig().as_dict()
@@ -112,6 +118,8 @@ class DeepSeekModel(OpenAICompatibleModel):
             url=url,
             token_counter=token_counter,
             timeout=timeout,
+            max_retries=max_retries,
+            **kwargs,
         )
 
     def _prepare_request(