camel-ai 0.2.79a0__py3-none-any.whl → 0.2.79a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic. Click here for more details.

@@ -13,6 +13,7 @@
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
14
  import copy
15
15
  import os
16
+ import warnings
16
17
  from typing import Any, Callable, Dict, List, Optional, Type, Union
17
18
 
18
19
  from openai import AsyncAzureOpenAI, AsyncStream, AzureOpenAI, Stream
@@ -60,7 +61,8 @@ class AzureOpenAIModel(BaseModelBackend):
60
61
 
61
62
  Args:
62
63
  model_type (Union[ModelType, str]): Model for which a backend is
63
- created, one of GPT_* series.
64
+ created. Should be the deployment name you chose when you deployed
65
+ an Azure model.
64
66
  model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
65
67
  that will be fed into:obj:`openai.ChatCompletion.create()`. If
66
68
  :obj:`None`, :obj:`ChatGPTConfig().as_dict()` will be used.
@@ -71,8 +73,6 @@ class AzureOpenAIModel(BaseModelBackend):
71
73
  (default: :obj:`None`)
72
74
  api_version (Optional[str], optional): The api version for the model.
73
75
  (default: :obj:`None`)
74
- azure_deployment_name (Optional[str], optional): The deployment name
75
- you chose when you deployed an azure model. (default: :obj:`None`)
76
76
  azure_ad_token (Optional[str], optional): Your Azure Active Directory
77
77
  token, https://www.microsoft.com/en-us/security/business/
78
78
  identity-access/microsoft-entra-id. (default: :obj:`None`)
@@ -88,8 +88,23 @@ class AzureOpenAIModel(BaseModelBackend):
88
88
  (default: :obj:`None`)
89
89
  max_retries (int, optional): Maximum number of retries for API calls.
90
90
  (default: :obj:`3`)
91
+ client (Optional[Any], optional): A custom synchronous AzureOpenAI
92
+ client instance. If provided, this client will be used instead of
93
+ creating a new one. Useful for RL frameworks like AReaL or rLLM
94
+ that provide Azure OpenAI-compatible clients. The client should
95
+ implement the AzureOpenAI client interface with
96
+ `.chat.completions.create()` and `.beta.chat.completions.parse()`
97
+ methods. (default: :obj:`None`)
98
+ async_client (Optional[Any], optional): A custom asynchronous
99
+ AzureOpenAI client instance. If provided, this client will be
100
+ used instead of creating a new one. The client should implement
101
+ the AsyncAzureOpenAI client interface. (default: :obj:`None`)
102
+ azure_deployment_name (Optional[str], optional): **Deprecated**.
103
+ Use `model_type` parameter instead. This parameter is kept for
104
+ backward compatibility and will be removed in a future version.
105
+ (default: :obj:`None`)
91
106
  **kwargs (Any): Additional arguments to pass to the client
92
- initialization.
107
+ initialization. Ignored if custom clients are provided.
93
108
 
94
109
  References:
95
110
  https://learn.microsoft.com/en-us/azure/ai-services/openai/
@@ -104,12 +119,35 @@ class AzureOpenAIModel(BaseModelBackend):
104
119
  timeout: Optional[float] = None,
105
120
  token_counter: Optional[BaseTokenCounter] = None,
106
121
  api_version: Optional[str] = None,
107
- azure_deployment_name: Optional[str] = None,
108
122
  azure_ad_token_provider: Optional["AzureADTokenProvider"] = None,
109
123
  azure_ad_token: Optional[str] = None,
110
124
  max_retries: int = 3,
125
+ client: Optional[Any] = None,
126
+ async_client: Optional[Any] = None,
127
+ azure_deployment_name: Optional[str] = None,
111
128
  **kwargs: Any,
112
129
  ) -> None:
130
+ # Handle deprecated azure_deployment_name parameter
131
+ if azure_deployment_name is not None:
132
+ warnings.warn(
133
+ "The 'azure_deployment_name' parameter is deprecated. "
134
+ "Please use 'model_type' parameter instead. "
135
+ "The 'azure_deployment_name' parameter is being ignored.",
136
+ DeprecationWarning,
137
+ stacklevel=2,
138
+ )
139
+
140
+ # Handle deprecated AZURE_DEPLOYMENT_NAME environment variable
141
+ if os.environ.get("AZURE_DEPLOYMENT_NAME") is not None:
142
+ warnings.warn(
143
+ "The 'AZURE_DEPLOYMENT_NAME' environment variable is "
144
+ "deprecated. Please use the 'model_type' parameter "
145
+ "instead. The 'AZURE_DEPLOYMENT_NAME' environment "
146
+ "variable is being ignored.",
147
+ DeprecationWarning,
148
+ stacklevel=2,
149
+ )
150
+
113
151
  if model_config_dict is None:
114
152
  model_config_dict = ChatGPTConfig().as_dict()
115
153
  api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY")
@@ -120,9 +158,6 @@ class AzureOpenAIModel(BaseModelBackend):
120
158
  )
121
159
 
122
160
  self.api_version = api_version or os.environ.get("AZURE_API_VERSION")
123
- self._azure_deployment_name = azure_deployment_name or os.environ.get(
124
- "AZURE_DEPLOYMENT_NAME"
125
- )
126
161
  self._azure_ad_token = azure_ad_token or os.environ.get(
127
162
  "AZURE_AD_TOKEN"
128
163
  )
@@ -132,62 +167,73 @@ class AzureOpenAIModel(BaseModelBackend):
132
167
  "Must provide either the `api_version` argument "
133
168
  "or `AZURE_API_VERSION` environment variable."
134
169
  )
135
- if self._azure_deployment_name is None:
136
- raise ValueError(
137
- "Must provide either the `azure_deployment_name` argument "
138
- "or `AZURE_DEPLOYMENT_NAME` environment variable."
139
- )
140
170
 
141
- if is_langfuse_available():
142
- from langfuse.openai import AsyncAzureOpenAI as LangfuseAsyncOpenAI
143
- from langfuse.openai import AzureOpenAI as LangfuseOpenAI
144
-
145
- self._client = LangfuseOpenAI(
146
- azure_endpoint=str(self._url),
147
- azure_deployment=self._azure_deployment_name,
148
- api_version=self.api_version,
149
- api_key=self._api_key,
150
- azure_ad_token=self._azure_ad_token,
151
- azure_ad_token_provider=self.azure_ad_token_provider,
152
- timeout=self._timeout,
153
- max_retries=max_retries,
154
- **kwargs,
155
- )
156
- self._async_client = LangfuseAsyncOpenAI(
157
- azure_endpoint=str(self._url),
158
- azure_deployment=self._azure_deployment_name,
159
- api_version=self.api_version,
160
- api_key=self._api_key,
161
- azure_ad_token=self._azure_ad_token,
162
- azure_ad_token_provider=self.azure_ad_token_provider,
163
- timeout=self._timeout,
164
- max_retries=max_retries,
165
- **kwargs,
166
- )
171
+ # Use custom clients if provided, otherwise create new ones
172
+ if client is not None:
173
+ # Use the provided custom sync client
174
+ self._client = client
167
175
  else:
168
- self._client = AzureOpenAI(
169
- azure_endpoint=str(self._url),
170
- azure_deployment=self._azure_deployment_name,
171
- api_version=self.api_version,
172
- api_key=self._api_key,
173
- azure_ad_token=self._azure_ad_token,
174
- azure_ad_token_provider=self.azure_ad_token_provider,
175
- timeout=self._timeout,
176
- max_retries=max_retries,
177
- **kwargs,
178
- )
176
+ # Create default sync client
177
+ if is_langfuse_available():
178
+ from langfuse.openai import AzureOpenAI as LangfuseOpenAI
179
+
180
+ self._client = LangfuseOpenAI(
181
+ azure_endpoint=str(self._url),
182
+ azure_deployment=str(self.model_type),
183
+ api_version=self.api_version,
184
+ api_key=self._api_key,
185
+ azure_ad_token=self._azure_ad_token,
186
+ azure_ad_token_provider=self.azure_ad_token_provider,
187
+ timeout=self._timeout,
188
+ max_retries=max_retries,
189
+ **kwargs,
190
+ )
191
+ else:
192
+ self._client = AzureOpenAI(
193
+ azure_endpoint=str(self._url),
194
+ azure_deployment=str(self.model_type),
195
+ api_version=self.api_version,
196
+ api_key=self._api_key,
197
+ azure_ad_token=self._azure_ad_token,
198
+ azure_ad_token_provider=self.azure_ad_token_provider,
199
+ timeout=self._timeout,
200
+ max_retries=max_retries,
201
+ **kwargs,
202
+ )
179
203
 
180
- self._async_client = AsyncAzureOpenAI(
181
- azure_endpoint=str(self._url),
182
- azure_deployment=self._azure_deployment_name,
183
- api_version=self.api_version,
184
- api_key=self._api_key,
185
- azure_ad_token=self._azure_ad_token,
186
- azure_ad_token_provider=self.azure_ad_token_provider,
187
- timeout=self._timeout,
188
- max_retries=max_retries,
189
- **kwargs,
190
- )
204
+ if async_client is not None:
205
+ # Use the provided custom async client
206
+ self._async_client = async_client
207
+ else:
208
+ # Create default async client
209
+ if is_langfuse_available():
210
+ from langfuse.openai import (
211
+ AsyncAzureOpenAI as LangfuseAsyncOpenAI,
212
+ )
213
+
214
+ self._async_client = LangfuseAsyncOpenAI(
215
+ azure_endpoint=str(self._url),
216
+ azure_deployment=str(self.model_type),
217
+ api_version=self.api_version,
218
+ api_key=self._api_key,
219
+ azure_ad_token=self._azure_ad_token,
220
+ azure_ad_token_provider=self.azure_ad_token_provider,
221
+ timeout=self._timeout,
222
+ max_retries=max_retries,
223
+ **kwargs,
224
+ )
225
+ else:
226
+ self._async_client = AsyncAzureOpenAI(
227
+ azure_endpoint=str(self._url),
228
+ azure_deployment=str(self.model_type),
229
+ api_version=self.api_version,
230
+ api_key=self._api_key,
231
+ azure_ad_token=self._azure_ad_token,
232
+ azure_ad_token_provider=self.azure_ad_token_provider,
233
+ timeout=self._timeout,
234
+ max_retries=max_retries,
235
+ **kwargs,
236
+ )
191
237
 
192
238
  @property
193
239
  def token_counter(self) -> BaseTokenCounter:
@@ -330,7 +376,7 @@ class AzureOpenAIModel(BaseModelBackend):
330
376
 
331
377
  return self._client.chat.completions.create(
332
378
  messages=messages,
333
- model=self._azure_deployment_name, # type:ignore[arg-type]
379
+ model=str(self.model_type),
334
380
  **request_config,
335
381
  )
336
382
 
@@ -346,7 +392,7 @@ class AzureOpenAIModel(BaseModelBackend):
346
392
 
347
393
  return await self._async_client.chat.completions.create(
348
394
  messages=messages,
349
- model=self._azure_deployment_name, # type:ignore[arg-type]
395
+ model=str(self.model_type),
350
396
  **request_config,
351
397
  )
352
398
 
@@ -367,7 +413,7 @@ class AzureOpenAIModel(BaseModelBackend):
367
413
 
368
414
  return self._client.beta.chat.completions.parse(
369
415
  messages=messages,
370
- model=self._azure_deployment_name, # type:ignore[arg-type]
416
+ model=str(self.model_type),
371
417
  **request_config,
372
418
  )
373
419
 
@@ -388,7 +434,7 @@ class AzureOpenAIModel(BaseModelBackend):
388
434
 
389
435
  return await self._async_client.beta.chat.completions.parse(
390
436
  messages=messages,
391
- model=self._azure_deployment_name, # type:ignore[arg-type]
437
+ model=str(self.model_type),
392
438
  **request_config,
393
439
  )
394
440
 
@@ -414,7 +460,7 @@ class AzureOpenAIModel(BaseModelBackend):
414
460
  # Use the beta streaming API for structured outputs
415
461
  return self._client.beta.chat.completions.stream(
416
462
  messages=messages,
417
- model=self.model_type,
463
+ model=str(self.model_type),
418
464
  response_format=response_format,
419
465
  **request_config,
420
466
  )
@@ -441,7 +487,7 @@ class AzureOpenAIModel(BaseModelBackend):
441
487
  # Use the beta streaming API for structured outputs
442
488
  return self._async_client.beta.chat.completions.stream(
443
489
  messages=messages,
444
- model=self.model_type,
490
+ model=str(self.model_type),
445
491
  response_format=response_format,
446
492
  **request_config,
447
493
  )
@@ -13,7 +13,7 @@
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
14
  import json
15
15
  import os
16
- from typing import ClassVar, Dict, Optional, Type, Union
16
+ from typing import Any, ClassVar, Dict, Optional, Type, Union
17
17
 
18
18
  from camel.models.aiml_model import AIMLModel
19
19
  from camel.models.amd_model import AMDModel
@@ -119,6 +119,8 @@ class ModelFactory:
119
119
  url: Optional[str] = None,
120
120
  timeout: Optional[float] = None,
121
121
  max_retries: int = 3,
122
+ client: Optional[Any] = None,
123
+ async_client: Optional[Any] = None,
122
124
  **kwargs,
123
125
  ) -> BaseModelBackend:
124
126
  r"""Creates an instance of `BaseModelBackend` of the specified type.
@@ -145,6 +147,14 @@ class ModelFactory:
145
147
  for API calls. (default: :obj:`None`)
146
148
  max_retries (int, optional): Maximum number of retries
147
149
  for API calls. (default: :obj:`3`)
150
+ client (Optional[Any], optional): A custom synchronous client
151
+ instance. Supported by models that use OpenAI-compatible APIs.
152
+ The client should implement the appropriate client interface
153
+ for the platform. (default: :obj:`None`)
154
+ async_client (Optional[Any], optional): A custom asynchronous
155
+ client instance. Supported by models that use OpenAI-compatible
156
+ APIs. The client should implement the appropriate async client
157
+ interface for the platform. (default: :obj:`None`)
148
158
  **kwargs: Additional model-specific parameters that will be passed
149
159
  to the model constructor. For example, Azure OpenAI models may
150
160
  require `api_version`, `azure_deployment_name`,
@@ -190,6 +200,12 @@ class ModelFactory:
190
200
  if model_class is None:
191
201
  raise ValueError(f"Unknown model platform `{model_platform}`")
192
202
 
203
+ # Pass client and async_client via kwargs if provided
204
+ if client is not None:
205
+ kwargs['client'] = client
206
+ if async_client is not None:
207
+ kwargs['async_client'] = async_client
208
+
193
209
  return model_class(
194
210
  model_type=model_type,
195
211
  model_config_dict=model_config_dict,
@@ -78,9 +78,21 @@ class OpenAICompatibleModel(BaseModelBackend):
78
78
  (default: :obj:`None`)
79
79
  max_retries (int, optional): Maximum number of retries for API calls.
80
80
  (default: :obj:`3`)
81
+ client (Optional[Any], optional): A custom synchronous
82
+ OpenAI-compatible client instance. If provided, this client will
83
+ be used instead of creating a new one. Useful for RL frameworks
84
+ like AReaL or rLLM that provide OpenAI-compatible clients (e.g.,
85
+ ArealOpenAI). The client should implement the OpenAI client
86
+ interface with `.chat.completions.create()` and `.beta.chat.
87
+ completions.parse()` methods. (default: :obj:`None`)
88
+ async_client (Optional[Any], optional): A custom asynchronous
89
+ OpenAI-compatible client instance. If provided, this client will
90
+ be used instead of creating a new one. The client should implement
91
+ the AsyncOpenAI client interface. (default: :obj:`None`)
81
92
  **kwargs (Any): Additional arguments to pass to the
82
93
  OpenAI client initialization. These can include parameters like
83
94
  'organization', 'default_headers', 'http_client', etc.
95
+ Ignored if custom clients are provided.
84
96
  """
85
97
 
86
98
  def __init__(
@@ -92,6 +104,8 @@ class OpenAICompatibleModel(BaseModelBackend):
92
104
  token_counter: Optional[BaseTokenCounter] = None,
93
105
  timeout: Optional[float] = None,
94
106
  max_retries: int = 3,
107
+ client: Optional[Any] = None,
108
+ async_client: Optional[Any] = None,
95
109
  **kwargs: Any,
96
110
  ) -> None:
97
111
  api_key = api_key or os.environ.get("OPENAI_COMPATIBILITY_API_KEY")
@@ -107,39 +121,55 @@ class OpenAICompatibleModel(BaseModelBackend):
107
121
  timeout,
108
122
  max_retries,
109
123
  )
110
- if is_langfuse_available():
111
- from langfuse.openai import AsyncOpenAI as LangfuseAsyncOpenAI
112
- from langfuse.openai import OpenAI as LangfuseOpenAI
113
-
114
- self._client = LangfuseOpenAI(
115
- timeout=self._timeout,
116
- max_retries=max_retries,
117
- base_url=self._url,
118
- api_key=self._api_key,
119
- **kwargs,
120
- )
121
- self._async_client = LangfuseAsyncOpenAI(
122
- timeout=self._timeout,
123
- max_retries=max_retries,
124
- base_url=self._url,
125
- api_key=self._api_key,
126
- **kwargs,
127
- )
124
+
125
+ # Use custom clients if provided, otherwise create new ones
126
+ if client is not None:
127
+ # Use the provided custom sync client
128
+ self._client = client
128
129
  else:
129
- self._client = OpenAI(
130
- timeout=self._timeout,
131
- max_retries=max_retries,
132
- base_url=self._url,
133
- api_key=self._api_key,
134
- **kwargs,
135
- )
136
- self._async_client = AsyncOpenAI(
137
- timeout=self._timeout,
138
- max_retries=max_retries,
139
- base_url=self._url,
140
- api_key=self._api_key,
141
- **kwargs,
142
- )
130
+ # Create default sync client
131
+ if is_langfuse_available():
132
+ from langfuse.openai import OpenAI as LangfuseOpenAI
133
+
134
+ self._client = LangfuseOpenAI(
135
+ timeout=self._timeout,
136
+ max_retries=max_retries,
137
+ base_url=self._url,
138
+ api_key=self._api_key,
139
+ **kwargs,
140
+ )
141
+ else:
142
+ self._client = OpenAI(
143
+ timeout=self._timeout,
144
+ max_retries=max_retries,
145
+ base_url=self._url,
146
+ api_key=self._api_key,
147
+ **kwargs,
148
+ )
149
+
150
+ if async_client is not None:
151
+ # Use the provided custom async client
152
+ self._async_client = async_client
153
+ else:
154
+ # Create default async client
155
+ if is_langfuse_available():
156
+ from langfuse.openai import AsyncOpenAI as LangfuseAsyncOpenAI
157
+
158
+ self._async_client = LangfuseAsyncOpenAI(
159
+ timeout=self._timeout,
160
+ max_retries=max_retries,
161
+ base_url=self._url,
162
+ api_key=self._api_key,
163
+ **kwargs,
164
+ )
165
+ else:
166
+ self._async_client = AsyncOpenAI(
167
+ timeout=self._timeout,
168
+ max_retries=max_retries,
169
+ base_url=self._url,
170
+ api_key=self._api_key,
171
+ **kwargs,
172
+ )
143
173
 
144
174
  @observe()
145
175
  def _run(
@@ -90,9 +90,21 @@ class OpenAIModel(BaseModelBackend):
90
90
  (default: :obj:`None`)
91
91
  max_retries (int, optional): Maximum number of retries for API calls.
92
92
  (default: :obj:`3`)
93
+ client (Optional[Any], optional): A custom synchronous OpenAI client
94
+ instance. If provided, this client will be used instead of
95
+ creating a new one. Useful for RL frameworks like AReaL or rLLM
96
+ that provide OpenAI-compatible clients. The client should
97
+ implement the OpenAI client interface with
98
+ `.chat.completions.create()` and `.beta.chat.completions.parse()`
99
+ methods. (default: :obj:`None`)
100
+ async_client (Optional[Any], optional): A custom asynchronous OpenAI
101
+ client instance. If provided, this client will be used instead of
102
+ creating a new one. The client should implement the AsyncOpenAI
103
+ client interface. (default: :obj:`None`)
93
104
  **kwargs (Any): Additional arguments to pass to the
94
105
  OpenAI client initialization. These can include parameters like
95
106
  'organization', 'default_headers', 'http_client', etc.
107
+ Ignored if custom clients are provided.
96
108
  """
97
109
 
98
110
  @api_keys_required(
@@ -109,6 +121,8 @@ class OpenAIModel(BaseModelBackend):
109
121
  token_counter: Optional[BaseTokenCounter] = None,
110
122
  timeout: Optional[float] = None,
111
123
  max_retries: int = 3,
124
+ client: Optional[Any] = None,
125
+ async_client: Optional[Any] = None,
112
126
  **kwargs: Any,
113
127
  ) -> None:
114
128
  if model_config_dict is None:
@@ -124,42 +138,54 @@ class OpenAIModel(BaseModelBackend):
124
138
  model_type, model_config_dict, api_key, url, token_counter, timeout
125
139
  )
126
140
 
127
- if is_langfuse_available():
128
- from langfuse.openai import AsyncOpenAI as LangfuseAsyncOpenAI
129
- from langfuse.openai import OpenAI as LangfuseOpenAI
130
-
131
- # Create Langfuse client with base parameters and additional
132
- # arguments
133
- self._client = LangfuseOpenAI(
134
- timeout=self._timeout,
135
- max_retries=self._max_retries,
136
- base_url=self._url,
137
- api_key=self._api_key,
138
- **kwargs,
139
- )
140
- self._async_client = LangfuseAsyncOpenAI(
141
- timeout=self._timeout,
142
- max_retries=self._max_retries,
143
- base_url=self._url,
144
- api_key=self._api_key,
145
- **kwargs,
146
- )
141
+ # Use custom clients if provided, otherwise create new ones
142
+ if client is not None:
143
+ # Use the provided custom sync client
144
+ self._client = client
147
145
  else:
148
- # Create client with base parameters and additional arguments
149
- self._client = OpenAI(
150
- timeout=self._timeout,
151
- max_retries=self._max_retries,
152
- base_url=self._url,
153
- api_key=self._api_key,
154
- **kwargs,
155
- )
156
- self._async_client = AsyncOpenAI(
157
- timeout=self._timeout,
158
- max_retries=self._max_retries,
159
- base_url=self._url,
160
- api_key=self._api_key,
161
- **kwargs,
162
- )
146
+ # Create default sync client
147
+ if is_langfuse_available():
148
+ from langfuse.openai import OpenAI as LangfuseOpenAI
149
+
150
+ self._client = LangfuseOpenAI(
151
+ timeout=self._timeout,
152
+ max_retries=self._max_retries,
153
+ base_url=self._url,
154
+ api_key=self._api_key,
155
+ **kwargs,
156
+ )
157
+ else:
158
+ self._client = OpenAI(
159
+ timeout=self._timeout,
160
+ max_retries=self._max_retries,
161
+ base_url=self._url,
162
+ api_key=self._api_key,
163
+ **kwargs,
164
+ )
165
+
166
+ if async_client is not None:
167
+ # Use the provided custom async client
168
+ self._async_client = async_client
169
+ else:
170
+ # Create default async client
171
+ if is_langfuse_available():
172
+ from langfuse.openai import AsyncOpenAI as LangfuseAsyncOpenAI
173
+
174
+ self._async_client = LangfuseAsyncOpenAI(
175
+ timeout=self._timeout,
176
+ max_retries=self._max_retries,
177
+ base_url=self._url,
178
+ api_key=self._api_key,
179
+ **kwargs,
180
+ )
181
+ else:
182
+ self._async_client = AsyncOpenAI(
183
+ timeout=self._timeout,
184
+ max_retries=self._max_retries,
185
+ base_url=self._url,
186
+ api_key=self._api_key,
187
+ **kwargs,
188
+ )
163
189
 
164
190
  def _sanitize_config(self, config_dict: Dict[str, Any]) -> Dict[str, Any]:
165
191
  r"""Sanitize the model configuration for O1 models."""
@@ -88,8 +88,16 @@ class SambaModel(BaseModelBackend):
88
88
  (default: :obj:`None`)
89
89
  max_retries (int, optional): Maximum number of retries for API calls.
90
90
  (default: :obj:`3`)
91
+ client (Optional[Any], optional): A custom synchronous
92
+ OpenAI-compatible client instance. If provided, this client will
93
+ be used instead of creating a new one. Only applicable when using
94
+ SambaNova Cloud API. (default: :obj:`None`)
95
+ async_client (Optional[Any], optional): A custom asynchronous
96
+ OpenAI-compatible client instance. If provided, this client will
97
+ be used instead of creating a new one. Only applicable when using
98
+ SambaNova Cloud API. (default: :obj:`None`)
91
99
  **kwargs (Any): Additional arguments to pass to the client
92
- initialization.
100
+ initialization. Ignored if custom clients are provided.
93
101
  """
94
102
 
95
103
  @api_keys_required(
@@ -106,6 +114,8 @@ class SambaModel(BaseModelBackend):
106
114
  token_counter: Optional[BaseTokenCounter] = None,
107
115
  timeout: Optional[float] = None,
108
116
  max_retries: int = 3,
117
+ client: Optional[Any] = None,
118
+ async_client: Optional[Any] = None,
109
119
  **kwargs: Any,
110
120
  ) -> None:
111
121
  if model_config_dict is None:
@@ -126,21 +136,30 @@ class SambaModel(BaseModelBackend):
126
136
  max_retries,
127
137
  )
128
138
 
139
+ # Only create clients for Cloud API mode
129
140
  if self._url == "https://api.sambanova.ai/v1":
130
- self._client = OpenAI(
131
- timeout=self._timeout,
132
- max_retries=self._max_retries,
133
- base_url=self._url,
134
- api_key=self._api_key,
135
- **kwargs,
136
- )
137
- self._async_client = AsyncOpenAI(
138
- timeout=self._timeout,
139
- max_retries=self._max_retries,
140
- base_url=self._url,
141
- api_key=self._api_key,
142
- **kwargs,
143
- )
141
+ # Use custom clients if provided, otherwise create new ones
142
+ if client is not None:
143
+ self._client = client
144
+ else:
145
+ self._client = OpenAI(
146
+ timeout=self._timeout,
147
+ max_retries=self._max_retries,
148
+ base_url=self._url,
149
+ api_key=self._api_key,
150
+ **kwargs,
151
+ )
152
+
153
+ if async_client is not None:
154
+ self._async_client = async_client
155
+ else:
156
+ self._async_client = AsyncOpenAI(
157
+ timeout=self._timeout,
158
+ max_retries=self._max_retries,
159
+ base_url=self._url,
160
+ api_key=self._api_key,
161
+ **kwargs,
162
+ )
144
163
 
145
164
  @property
146
165
  def token_counter(self) -> BaseTokenCounter: