camel-ai 0.2.65__py3-none-any.whl → 0.2.67__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. camel/__init__.py +1 -1
  2. camel/agents/mcp_agent.py +1 -5
  3. camel/configs/__init__.py +3 -0
  4. camel/configs/qianfan_config.py +85 -0
  5. camel/models/__init__.py +2 -0
  6. camel/models/aiml_model.py +8 -0
  7. camel/models/anthropic_model.py +8 -0
  8. camel/models/aws_bedrock_model.py +8 -0
  9. camel/models/azure_openai_model.py +14 -5
  10. camel/models/base_model.py +4 -0
  11. camel/models/cohere_model.py +9 -2
  12. camel/models/crynux_model.py +8 -0
  13. camel/models/deepseek_model.py +8 -0
  14. camel/models/gemini_model.py +8 -0
  15. camel/models/groq_model.py +8 -0
  16. camel/models/internlm_model.py +8 -0
  17. camel/models/litellm_model.py +5 -0
  18. camel/models/lmstudio_model.py +14 -1
  19. camel/models/mistral_model.py +15 -1
  20. camel/models/model_factory.py +6 -0
  21. camel/models/modelscope_model.py +8 -0
  22. camel/models/moonshot_model.py +8 -0
  23. camel/models/nemotron_model.py +17 -2
  24. camel/models/netmind_model.py +8 -0
  25. camel/models/novita_model.py +8 -0
  26. camel/models/nvidia_model.py +8 -0
  27. camel/models/ollama_model.py +8 -0
  28. camel/models/openai_compatible_model.py +23 -5
  29. camel/models/openai_model.py +21 -4
  30. camel/models/openrouter_model.py +8 -0
  31. camel/models/ppio_model.py +8 -0
  32. camel/models/qianfan_model.py +104 -0
  33. camel/models/qwen_model.py +8 -0
  34. camel/models/reka_model.py +18 -3
  35. camel/models/samba_model.py +17 -3
  36. camel/models/sglang_model.py +20 -5
  37. camel/models/siliconflow_model.py +8 -0
  38. camel/models/stub_model.py +8 -1
  39. camel/models/togetherai_model.py +8 -0
  40. camel/models/vllm_model.py +7 -0
  41. camel/models/volcano_model.py +14 -1
  42. camel/models/watsonx_model.py +4 -1
  43. camel/models/yi_model.py +8 -0
  44. camel/models/zhipuai_model.py +8 -0
  45. camel/societies/workforce/prompts.py +33 -17
  46. camel/societies/workforce/role_playing_worker.py +5 -10
  47. camel/societies/workforce/single_agent_worker.py +3 -5
  48. camel/societies/workforce/task_channel.py +16 -18
  49. camel/societies/workforce/utils.py +104 -65
  50. camel/societies/workforce/workforce.py +1263 -100
  51. camel/societies/workforce/workforce_logger.py +613 -0
  52. camel/tasks/task.py +77 -6
  53. camel/toolkits/__init__.py +2 -0
  54. camel/toolkits/code_execution.py +1 -1
  55. camel/toolkits/function_tool.py +79 -7
  56. camel/toolkits/mcp_toolkit.py +70 -19
  57. camel/toolkits/playwright_mcp_toolkit.py +2 -1
  58. camel/toolkits/pptx_toolkit.py +4 -4
  59. camel/types/enums.py +32 -0
  60. camel/types/unified_model_type.py +5 -0
  61. camel/utils/mcp_client.py +1 -35
  62. {camel_ai-0.2.65.dist-info → camel_ai-0.2.67.dist-info}/METADATA +3 -3
  63. {camel_ai-0.2.65.dist-info → camel_ai-0.2.67.dist-info}/RECORD +65 -62
  64. {camel_ai-0.2.65.dist-info → camel_ai-0.2.67.dist-info}/WHEEL +0 -0
  65. {camel_ai-0.2.65.dist-info → camel_ai-0.2.67.dist-info}/licenses/LICENSE +0 -0
@@ -56,6 +56,10 @@ class ModelScopeModel(OpenAICompatibleModel):
56
56
  API calls. If not provided, will fall back to the MODEL_TIMEOUT
57
57
  environment variable or default to 180 seconds.
58
58
  (default: :obj:`None`)
59
+ max_retries (int, optional): Maximum number of retries for API calls.
60
+ (default: :obj:`3`)
61
+ **kwargs (Any): Additional arguments to pass to the client
62
+ initialization.
59
63
  """
60
64
 
61
65
  @api_keys_required(
@@ -71,6 +75,8 @@ class ModelScopeModel(OpenAICompatibleModel):
71
75
  url: Optional[str] = None,
72
76
  token_counter: Optional[BaseTokenCounter] = None,
73
77
  timeout: Optional[float] = None,
78
+ max_retries: int = 3,
79
+ **kwargs: Any,
74
80
  ) -> None:
75
81
  if model_config_dict is None:
76
82
  model_config_dict = ModelScopeConfig().as_dict()
@@ -87,6 +93,8 @@ class ModelScopeModel(OpenAICompatibleModel):
87
93
  url=url,
88
94
  token_counter=token_counter,
89
95
  timeout=timeout,
96
+ max_retries=max_retries,
97
+ **kwargs,
90
98
  )
91
99
 
92
100
  def _post_handle_response(
@@ -54,6 +54,10 @@ class MoonshotModel(OpenAICompatibleModel):
54
54
  API calls. If not provided, will fall back to the MODEL_TIMEOUT
55
55
  environment variable or default to 180 seconds.
56
56
  (default: :obj:`None`)
57
+ max_retries (int, optional): Maximum number of retries for API calls.
58
+ (default: :obj:`3`)
59
+ **kwargs (Any): Additional arguments to pass to the client
60
+ initialization.
57
61
  """
58
62
 
59
63
  @api_keys_required([("api_key", "MOONSHOT_API_KEY")])
@@ -65,6 +69,8 @@ class MoonshotModel(OpenAICompatibleModel):
65
69
  url: Optional[str] = None,
66
70
  token_counter: Optional[BaseTokenCounter] = None,
67
71
  timeout: Optional[float] = None,
72
+ max_retries: int = 3,
73
+ **kwargs: Any,
68
74
  ) -> None:
69
75
  if model_config_dict is None:
70
76
  model_config_dict = MoonshotConfig().as_dict()
@@ -81,6 +87,8 @@ class MoonshotModel(OpenAICompatibleModel):
81
87
  url=url,
82
88
  token_counter=token_counter,
83
89
  timeout=timeout,
90
+ max_retries=max_retries,
91
+ **kwargs,
84
92
  )
85
93
 
86
94
  async def _arun(
@@ -12,7 +12,7 @@
12
12
  # limitations under the License.
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
14
  import os
15
- from typing import Optional, Union
15
+ from typing import Any, Optional, Union
16
16
 
17
17
  from camel.models.openai_compatible_model import OpenAICompatibleModel
18
18
  from camel.types import ModelType
@@ -36,6 +36,10 @@ class NemotronModel(OpenAICompatibleModel):
36
36
  API calls. If not provided, will fall back to the MODEL_TIMEOUT
37
37
  environment variable or default to 180 seconds.
38
38
  (default: :obj:`None`)
39
+ max_retries (int, optional): Maximum number of retries for API calls.
40
+ (default: :obj:`3`)
41
+ **kwargs (Any): Additional arguments to pass to the client
42
+ initialization.
39
43
 
40
44
  Notes:
41
45
  Nemotron model doesn't support additional model config like OpenAI.
@@ -52,13 +56,24 @@ class NemotronModel(OpenAICompatibleModel):
52
56
  api_key: Optional[str] = None,
53
57
  url: Optional[str] = None,
54
58
  timeout: Optional[float] = None,
59
+ max_retries: int = 3,
60
+ **kwargs: Any,
55
61
  ) -> None:
56
62
  url = url or os.environ.get(
57
63
  "NVIDIA_API_BASE_URL", "https://integrate.api.nvidia.com/v1"
58
64
  )
59
65
  api_key = api_key or os.environ.get("NVIDIA_API_KEY")
60
66
  timeout = timeout or float(os.environ.get("MODEL_TIMEOUT", 180))
61
- super().__init__(model_type, {}, api_key, url, None, timeout)
67
+ super().__init__(
68
+ model_type,
69
+ {},
70
+ api_key,
71
+ url,
72
+ None,
73
+ timeout,
74
+ max_retries=max_retries,
75
+ **kwargs,
76
+ )
62
77
 
63
78
  @property
64
79
  def token_counter(self) -> BaseTokenCounter:
@@ -47,6 +47,10 @@ class NetmindModel(OpenAICompatibleModel):
47
47
  API calls. If not provided, will fall back to the MODEL_TIMEOUT
48
48
  environment variable or default to 180 seconds.
49
49
  (default: :obj:`None`)
50
+ max_retries (int, optional): Maximum number of retries for API calls.
51
+ (default: :obj:`3`)
52
+ **kwargs (Any): Additional arguments to pass to the client
53
+ initialization.
50
54
  """
51
55
 
52
56
  @api_keys_required(
@@ -62,6 +66,8 @@ class NetmindModel(OpenAICompatibleModel):
62
66
  url: Optional[str] = None,
63
67
  token_counter: Optional[BaseTokenCounter] = None,
64
68
  timeout: Optional[float] = None,
69
+ max_retries: int = 3,
70
+ **kwargs: Any,
65
71
  ) -> None:
66
72
  if model_config_dict is None:
67
73
  model_config_dict = NetmindConfig().as_dict()
@@ -78,6 +84,8 @@ class NetmindModel(OpenAICompatibleModel):
78
84
  url=url,
79
85
  token_counter=token_counter,
80
86
  timeout=timeout,
87
+ max_retries=max_retries,
88
+ **kwargs,
81
89
  )
82
90
 
83
91
  def check_model_config(self):
@@ -47,6 +47,10 @@ class NovitaModel(OpenAICompatibleModel):
47
47
  API calls. If not provided, will fall back to the MODEL_TIMEOUT
48
48
  environment variable or default to 180 seconds.
49
49
  (default: :obj:`None`)
50
+ max_retries (int, optional): Maximum number of retries for API calls.
51
+ (default: :obj:`3`)
52
+ **kwargs (Any): Additional arguments to pass to the client
53
+ initialization.
50
54
  """
51
55
 
52
56
  @api_keys_required(
@@ -62,6 +66,8 @@ class NovitaModel(OpenAICompatibleModel):
62
66
  url: Optional[str] = None,
63
67
  token_counter: Optional[BaseTokenCounter] = None,
64
68
  timeout: Optional[float] = None,
69
+ max_retries: int = 3,
70
+ **kwargs: Any,
65
71
  ) -> None:
66
72
  if model_config_dict is None:
67
73
  model_config_dict = NovitaConfig().as_dict()
@@ -77,6 +83,8 @@ class NovitaModel(OpenAICompatibleModel):
77
83
  url=url,
78
84
  token_counter=token_counter,
79
85
  timeout=timeout,
86
+ max_retries=max_retries,
87
+ **kwargs,
80
88
  )
81
89
 
82
90
  def check_model_config(self):
@@ -43,6 +43,10 @@ class NvidiaModel(OpenAICompatibleModel):
43
43
  API calls. If not provided, will fall back to the MODEL_TIMEOUT
44
44
  environment variable or default to 180 seconds.
45
45
  (default: :obj:`None`)
46
+ max_retries (int, optional): Maximum number of retries for API calls.
47
+ (default: :obj:`3`)
48
+ **kwargs (Any): Additional arguments to pass to the client
49
+ initialization.
46
50
  """
47
51
 
48
52
  @api_keys_required(
@@ -58,6 +62,8 @@ class NvidiaModel(OpenAICompatibleModel):
58
62
  url: Optional[str] = None,
59
63
  token_counter: Optional[BaseTokenCounter] = None,
60
64
  timeout: Optional[float] = None,
65
+ max_retries: int = 3,
66
+ **kwargs: Any,
61
67
  ) -> None:
62
68
  if model_config_dict is None:
63
69
  model_config_dict = NvidiaConfig().as_dict()
@@ -73,6 +79,8 @@ class NvidiaModel(OpenAICompatibleModel):
73
79
  url=url,
74
80
  token_counter=token_counter,
75
81
  timeout=timeout,
82
+ max_retries=max_retries,
83
+ **kwargs,
76
84
  )
77
85
 
78
86
  def check_model_config(self):
@@ -47,6 +47,10 @@ class OllamaModel(OpenAICompatibleModel):
47
47
  API calls. If not provided, will fall back to the MODEL_TIMEOUT
48
48
  environment variable or default to 180 seconds.
49
49
  (default: :obj:`None`)
50
+ max_retries (int, optional): Maximum number of retries for API calls.
51
+ (default: :obj:`3`)
52
+ **kwargs (Any): Additional arguments to pass to the client
53
+ initialization.
50
54
 
51
55
  References:
52
56
  https://github.com/ollama/ollama/blob/main/docs/openai.md
@@ -60,6 +64,8 @@ class OllamaModel(OpenAICompatibleModel):
60
64
  url: Optional[str] = None,
61
65
  token_counter: Optional[BaseTokenCounter] = None,
62
66
  timeout: Optional[float] = None,
67
+ max_retries: int = 3,
68
+ **kwargs: Any,
63
69
  ) -> None:
64
70
  if model_config_dict is None:
65
71
  model_config_dict = OllamaConfig().as_dict()
@@ -77,6 +83,8 @@ class OllamaModel(OpenAICompatibleModel):
77
83
  url=self._url,
78
84
  token_counter=token_counter,
79
85
  timeout=timeout,
86
+ max_retries=max_retries,
87
+ **kwargs,
80
88
  )
81
89
 
82
90
  def _start_server(self) -> None:
@@ -67,6 +67,11 @@ class OpenAICompatibleModel(BaseModelBackend):
67
67
  API calls. If not provided, will fall back to the MODEL_TIMEOUT
68
68
  environment variable or default to 180 seconds.
69
69
  (default: :obj:`None`)
70
+ max_retries (int, optional): Maximum number of retries for API calls.
71
+ (default: :obj:`3`)
72
+ **kwargs (Any): Additional arguments to pass to the
73
+ OpenAI client initialization. These can include parameters like
74
+ 'organization', 'default_headers', 'http_client', etc.
70
75
  """
71
76
 
72
77
  def __init__(
@@ -77,12 +82,21 @@ class OpenAICompatibleModel(BaseModelBackend):
77
82
  url: Optional[str] = None,
78
83
  token_counter: Optional[BaseTokenCounter] = None,
79
84
  timeout: Optional[float] = None,
85
+ max_retries: int = 3,
86
+ **kwargs: Any,
80
87
  ) -> None:
81
88
  api_key = api_key or os.environ.get("OPENAI_COMPATIBILITY_API_KEY")
82
89
  url = url or os.environ.get("OPENAI_COMPATIBILITY_API_BASE_URL")
83
90
  timeout = timeout or float(os.environ.get("MODEL_TIMEOUT", 180))
91
+
84
92
  super().__init__(
85
- model_type, model_config_dict, api_key, url, token_counter, timeout
93
+ model_type,
94
+ model_config_dict,
95
+ api_key,
96
+ url,
97
+ token_counter,
98
+ timeout,
99
+ max_retries,
86
100
  )
87
101
  if is_langfuse_available():
88
102
  from langfuse.openai import AsyncOpenAI as LangfuseAsyncOpenAI
@@ -90,28 +104,32 @@ class OpenAICompatibleModel(BaseModelBackend):
90
104
 
91
105
  self._client = LangfuseOpenAI(
92
106
  timeout=self._timeout,
93
- max_retries=3,
107
+ max_retries=max_retries,
94
108
  base_url=self._url,
95
109
  api_key=self._api_key,
110
+ **kwargs,
96
111
  )
97
112
  self._async_client = LangfuseAsyncOpenAI(
98
113
  timeout=self._timeout,
99
- max_retries=3,
114
+ max_retries=max_retries,
100
115
  base_url=self._url,
101
116
  api_key=self._api_key,
117
+ **kwargs,
102
118
  )
103
119
  else:
104
120
  self._client = OpenAI(
105
121
  timeout=self._timeout,
106
- max_retries=3,
122
+ max_retries=max_retries,
107
123
  base_url=self._url,
108
124
  api_key=self._api_key,
125
+ **kwargs,
109
126
  )
110
127
  self._async_client = AsyncOpenAI(
111
128
  timeout=self._timeout,
112
- max_retries=3,
129
+ max_retries=max_retries,
113
130
  base_url=self._url,
114
131
  api_key=self._api_key,
132
+ **kwargs,
115
133
  )
116
134
 
117
135
  @observe()
@@ -76,6 +76,11 @@ class OpenAIModel(BaseModelBackend):
76
76
  API calls. If not provided, will fall back to the MODEL_TIMEOUT
77
77
  environment variable or default to 180 seconds.
78
78
  (default: :obj:`None`)
79
+ max_retries (int, optional): Maximum number of retries for API calls.
80
+ (default: :obj:`3`)
81
+ **kwargs (Any): Additional arguments to pass to the
82
+ OpenAI client initialization. These can include parameters like
83
+ 'organization', 'default_headers', 'http_client', etc.
79
84
  """
80
85
 
81
86
  @api_keys_required(
@@ -91,6 +96,8 @@ class OpenAIModel(BaseModelBackend):
91
96
  url: Optional[str] = None,
92
97
  token_counter: Optional[BaseTokenCounter] = None,
93
98
  timeout: Optional[float] = None,
99
+ max_retries: int = 3,
100
+ **kwargs: Any,
94
101
  ) -> None:
95
102
  if model_config_dict is None:
96
103
  model_config_dict = ChatGPTConfig().as_dict()
@@ -98,6 +105,9 @@ class OpenAIModel(BaseModelBackend):
98
105
  url = url or os.environ.get("OPENAI_API_BASE_URL")
99
106
  timeout = timeout or float(os.environ.get("MODEL_TIMEOUT", 180))
100
107
 
108
+ # Store additional client args for later use
109
+ self._max_retries = max_retries
110
+
101
111
  super().__init__(
102
112
  model_type, model_config_dict, api_key, url, token_counter, timeout
103
113
  )
@@ -106,30 +116,37 @@ class OpenAIModel(BaseModelBackend):
106
116
  from langfuse.openai import AsyncOpenAI as LangfuseAsyncOpenAI
107
117
  from langfuse.openai import OpenAI as LangfuseOpenAI
108
118
 
119
+ # Create Langfuse client with base parameters and additional
120
+ # arguments
109
121
  self._client = LangfuseOpenAI(
110
122
  timeout=self._timeout,
111
- max_retries=3,
123
+ max_retries=self._max_retries,
112
124
  base_url=self._url,
113
125
  api_key=self._api_key,
126
+ **kwargs,
114
127
  )
115
128
  self._async_client = LangfuseAsyncOpenAI(
116
129
  timeout=self._timeout,
117
- max_retries=3,
130
+ max_retries=self._max_retries,
118
131
  base_url=self._url,
119
132
  api_key=self._api_key,
133
+ **kwargs,
120
134
  )
121
135
  else:
136
+ # Create client with base parameters and additional arguments
122
137
  self._client = OpenAI(
123
138
  timeout=self._timeout,
124
- max_retries=3,
139
+ max_retries=self._max_retries,
125
140
  base_url=self._url,
126
141
  api_key=self._api_key,
142
+ **kwargs,
127
143
  )
128
144
  self._async_client = AsyncOpenAI(
129
145
  timeout=self._timeout,
130
- max_retries=3,
146
+ max_retries=self._max_retries,
131
147
  base_url=self._url,
132
148
  api_key=self._api_key,
149
+ **kwargs,
133
150
  )
134
151
 
135
152
  def _sanitize_config(self, config_dict: Dict[str, Any]) -> Dict[str, Any]:
@@ -46,6 +46,10 @@ class OpenRouterModel(OpenAICompatibleModel):
46
46
  API calls. If not provided, will fall back to the MODEL_TIMEOUT
47
47
  environment variable or default to 180 seconds.
48
48
  (default: :obj:`None`)
49
+ max_retries (int, optional): Maximum number of retries for API calls.
50
+ (default: :obj:`3`)
51
+ **kwargs (Any): Additional arguments to pass to the client
52
+ initialization.
49
53
  """
50
54
 
51
55
  @api_keys_required([("api_key", "OPENROUTER_API_KEY")])
@@ -57,6 +61,8 @@ class OpenRouterModel(OpenAICompatibleModel):
57
61
  url: Optional[str] = None,
58
62
  token_counter: Optional[BaseTokenCounter] = None,
59
63
  timeout: Optional[float] = None,
64
+ max_retries: int = 3,
65
+ **kwargs: Any,
60
66
  ) -> None:
61
67
  if model_config_dict is None:
62
68
  model_config_dict = OpenRouterConfig().as_dict()
@@ -72,6 +78,8 @@ class OpenRouterModel(OpenAICompatibleModel):
72
78
  url=url,
73
79
  token_counter=token_counter,
74
80
  timeout=timeout,
81
+ max_retries=max_retries,
82
+ **kwargs,
75
83
  )
76
84
 
77
85
  def check_model_config(self):
@@ -47,6 +47,10 @@ class PPIOModel(OpenAICompatibleModel):
47
47
  API calls. If not provided, will fall back to the MODEL_TIMEOUT
48
48
  environment variable or default to 180 seconds.
49
49
  (default: :obj:`None`)
50
+ max_retries (int, optional): Maximum number of retries for API calls.
51
+ (default: :obj:`3`)
52
+ **kwargs (Any): Additional arguments to pass to the client
53
+ initialization.
50
54
  """
51
55
 
52
56
  @api_keys_required(
@@ -62,6 +66,8 @@ class PPIOModel(OpenAICompatibleModel):
62
66
  url: Optional[str] = None,
63
67
  token_counter: Optional[BaseTokenCounter] = None,
64
68
  timeout: Optional[float] = None,
69
+ max_retries: int = 3,
70
+ **kwargs: Any,
65
71
  ) -> None:
66
72
  if model_config_dict is None:
67
73
  model_config_dict = PPIOConfig().as_dict()
@@ -77,6 +83,8 @@ class PPIOModel(OpenAICompatibleModel):
77
83
  url=url,
78
84
  token_counter=token_counter,
79
85
  timeout=timeout,
86
+ max_retries=max_retries,
87
+ **kwargs,
80
88
  )
81
89
 
82
90
  def check_model_config(self):
@@ -0,0 +1,104 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import os
16
+ from typing import Any, Dict, Optional, Union
17
+
18
+ from camel.configs import QIANFAN_API_PARAMS, QianfanConfig
19
+ from camel.models.openai_compatible_model import OpenAICompatibleModel
20
+ from camel.types import ModelType
21
+ from camel.utils import (
22
+ BaseTokenCounter,
23
+ api_keys_required,
24
+ )
25
+
26
+
class QianfanModel(OpenAICompatibleModel):
    r"""Constructor for Qianfan backend with OpenAI compatibility.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created, supported model can be found here:
            https://cloud.baidu.com/doc/QIANFANWORKSHOP/s/Wm9cvy6rl
        model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
            that will be fed into :obj:`openai.ChatCompletion.create()`. If
            :obj:`None`, :obj:`QianfanConfig().as_dict()` will be used.
            (default: :obj:`None`)
        api_key (Optional[str], optional): The API key for authenticating with
            the Qianfan service. If not provided, the value of the
            :obj:`QIANFAN_API_KEY` environment variable is used.
            (default: :obj:`None`)
        url (Optional[str], optional): The base URL of the Qianfan service.
            If not provided, the value of the :obj:`QIANFAN_API_BASE_URL`
            environment variable is used, falling back to
            "https://qianfan.baidubce.com/v2". (default: :obj:`None`)
        token_counter (Optional[BaseTokenCounter], optional): Token counter to
            use for the model. If not provided, :obj:`OpenAITokenCounter(
            ModelType.GPT_4O_MINI)` will be used.
            (default: :obj:`None`)
        timeout (Optional[float], optional): The timeout value in seconds for
            API calls. If not provided, will fall back to the MODEL_TIMEOUT
            environment variable or default to 180 seconds.
            (default: :obj:`None`)
        max_retries (int, optional): Maximum number of retries for API calls.
            (default: :obj:`3`)
        **kwargs (Any): Additional arguments to pass to the client
            initialization.
    """

    @api_keys_required(
        [
            ("api_key", 'QIANFAN_API_KEY'),
        ]
    )
    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
        timeout: Optional[float] = None,
        max_retries: int = 3,
        **kwargs: Any,
    ) -> None:
        if model_config_dict is None:
            model_config_dict = QianfanConfig().as_dict()
        # Explicit arguments win; otherwise fall back to the environment.
        api_key = api_key or os.environ.get("QIANFAN_API_KEY")
        # The base URL (not the full /chat/completions endpoint) is expected
        # here because the OpenAI-compatible client appends the route itself.
        url = url or os.environ.get(
            "QIANFAN_API_BASE_URL",
            "https://qianfan.baidubce.com/v2",
        )
        timeout = timeout or float(os.environ.get("MODEL_TIMEOUT", 180))
        super().__init__(
            model_type=model_type,
            model_config_dict=model_config_dict,
            api_key=api_key,
            url=url,
            token_counter=token_counter,
            timeout=timeout,
            max_retries=max_retries,
            **kwargs,
        )

    def check_model_config(self):
        r"""Check whether the model configuration contains any
        unexpected arguments to Qianfan API.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected arguments to Qianfan API.
        """
        for param in self.model_config_dict:
            if param not in QIANFAN_API_PARAMS:
                raise ValueError(
                    f"Unexpected argument `{param}` is "
                    "input into QIANFAN model backend."
                )
@@ -54,6 +54,10 @@ class QwenModel(OpenAICompatibleModel):
54
54
  API calls. If not provided, will fall back to the MODEL_TIMEOUT
55
55
  environment variable or default to 180 seconds.
56
56
  (default: :obj:`None`)
57
+ max_retries (int, optional): Maximum number of retries for API calls.
58
+ (default: :obj:`3`)
59
+ **kwargs (Any): Additional arguments to pass to the client
60
+ initialization.
57
61
  """
58
62
 
59
63
  @api_keys_required(
@@ -69,6 +73,8 @@ class QwenModel(OpenAICompatibleModel):
69
73
  url: Optional[str] = None,
70
74
  token_counter: Optional[BaseTokenCounter] = None,
71
75
  timeout: Optional[float] = None,
76
+ max_retries: int = 3,
77
+ **kwargs: Any,
72
78
  ) -> None:
73
79
  if model_config_dict is None:
74
80
  model_config_dict = QwenConfig().as_dict()
@@ -85,6 +91,8 @@ class QwenModel(OpenAICompatibleModel):
85
91
  url=url,
86
92
  token_counter=token_counter,
87
93
  timeout=timeout,
94
+ max_retries=max_retries,
95
+ **kwargs,
88
96
  )
89
97
 
90
98
  def _post_handle_response(
@@ -72,6 +72,8 @@ class RekaModel(BaseModelBackend):
72
72
  API calls. If not provided, will fall back to the MODEL_TIMEOUT
73
73
  environment variable or default to 180 seconds.
74
74
  (default: :obj:`None`)
75
+ **kwargs (Any): Additional arguments to pass to the client
76
+ initialization.
75
77
  """
76
78
 
77
79
  @api_keys_required(
@@ -88,6 +90,7 @@ class RekaModel(BaseModelBackend):
88
90
  url: Optional[str] = None,
89
91
  token_counter: Optional[BaseTokenCounter] = None,
90
92
  timeout: Optional[float] = None,
93
+ **kwargs: Any,
91
94
  ) -> None:
92
95
  from reka.client import AsyncReka, Reka
93
96
 
@@ -97,13 +100,25 @@ class RekaModel(BaseModelBackend):
97
100
  url = url or os.environ.get("REKA_API_BASE_URL")
98
101
  timeout = timeout or float(os.environ.get("MODEL_TIMEOUT", 180))
99
102
  super().__init__(
100
- model_type, model_config_dict, api_key, url, token_counter, timeout
103
+ model_type,
104
+ model_config_dict,
105
+ api_key,
106
+ url,
107
+ token_counter,
108
+ timeout,
109
+ **kwargs,
101
110
  )
102
111
  self._client = Reka(
103
- api_key=self._api_key, base_url=self._url, timeout=self._timeout
112
+ api_key=self._api_key,
113
+ base_url=self._url,
114
+ timeout=self._timeout,
115
+ **kwargs,
104
116
  )
105
117
  self._async_client = AsyncReka(
106
- api_key=self._api_key, base_url=self._url, timeout=self._timeout
118
+ api_key=self._api_key,
119
+ base_url=self._url,
120
+ timeout=self._timeout,
121
+ **kwargs,
107
122
  )
108
123
 
109
124
  def _convert_reka_to_openai_response(
@@ -88,6 +88,10 @@ class SambaModel(BaseModelBackend):
88
88
  API calls. If not provided, will fall back to the MODEL_TIMEOUT
89
89
  environment variable or default to 180 seconds.
90
90
  (default: :obj:`None`)
91
+ max_retries (int, optional): Maximum number of retries for API calls.
92
+ (default: :obj:`3`)
93
+ **kwargs (Any): Additional arguments to pass to the client
94
+ initialization.
91
95
  """
92
96
 
93
97
  @api_keys_required(
@@ -103,6 +107,8 @@ class SambaModel(BaseModelBackend):
103
107
  url: Optional[str] = None,
104
108
  token_counter: Optional[BaseTokenCounter] = None,
105
109
  timeout: Optional[float] = None,
110
+ max_retries: int = 3,
111
+ **kwargs: Any,
106
112
  ) -> None:
107
113
  if model_config_dict is None:
108
114
  model_config_dict = SambaCloudAPIConfig().as_dict()
@@ -113,21 +119,29 @@ class SambaModel(BaseModelBackend):
113
119
  )
114
120
  timeout = timeout or float(os.environ.get("MODEL_TIMEOUT", 180))
115
121
  super().__init__(
116
- model_type, model_config_dict, api_key, url, token_counter, timeout
122
+ model_type,
123
+ model_config_dict,
124
+ api_key,
125
+ url,
126
+ token_counter,
127
+ timeout,
128
+ max_retries,
117
129
  )
118
130
 
119
131
  if self._url == "https://api.sambanova.ai/v1":
120
132
  self._client = OpenAI(
121
133
  timeout=self._timeout,
122
- max_retries=3,
134
+ max_retries=self._max_retries,
123
135
  base_url=self._url,
124
136
  api_key=self._api_key,
137
+ **kwargs,
125
138
  )
126
139
  self._async_client = AsyncOpenAI(
127
140
  timeout=self._timeout,
128
- max_retries=3,
141
+ max_retries=self._max_retries,
129
142
  base_url=self._url,
130
143
  api_key=self._api_key,
144
+ **kwargs,
131
145
  )
132
146
 
133
147
  @property