camel-ai 0.2.78__py3-none-any.whl → 0.2.79a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic; the original page linked to further details.

Files changed (39)
  1. camel/__init__.py +1 -1
  2. camel/agents/_utils.py +38 -0
  3. camel/agents/chat_agent.py +1112 -287
  4. camel/datasets/base_generator.py +39 -10
  5. camel/environments/single_step.py +28 -3
  6. camel/memories/__init__.py +1 -2
  7. camel/memories/agent_memories.py +34 -0
  8. camel/memories/base.py +26 -0
  9. camel/memories/blocks/chat_history_block.py +117 -17
  10. camel/memories/context_creators/score_based.py +25 -384
  11. camel/messages/base.py +26 -0
  12. camel/models/aws_bedrock_model.py +1 -17
  13. camel/models/azure_openai_model.py +113 -67
  14. camel/models/model_factory.py +17 -1
  15. camel/models/moonshot_model.py +102 -5
  16. camel/models/openai_compatible_model.py +62 -32
  17. camel/models/openai_model.py +61 -35
  18. camel/models/samba_model.py +34 -15
  19. camel/models/sglang_model.py +41 -11
  20. camel/societies/workforce/__init__.py +2 -0
  21. camel/societies/workforce/events.py +122 -0
  22. camel/societies/workforce/role_playing_worker.py +15 -11
  23. camel/societies/workforce/single_agent_worker.py +143 -291
  24. camel/societies/workforce/utils.py +2 -1
  25. camel/societies/workforce/workflow_memory_manager.py +772 -0
  26. camel/societies/workforce/workforce.py +513 -188
  27. camel/societies/workforce/workforce_callback.py +74 -0
  28. camel/societies/workforce/workforce_logger.py +144 -140
  29. camel/societies/workforce/workforce_metrics.py +33 -0
  30. camel/storages/vectordb_storages/oceanbase.py +5 -4
  31. camel/toolkits/file_toolkit.py +166 -0
  32. camel/toolkits/message_integration.py +15 -13
  33. camel/toolkits/terminal_toolkit/terminal_toolkit.py +112 -79
  34. camel/types/enums.py +1 -0
  35. camel/utils/context_utils.py +201 -2
  36. {camel_ai-0.2.78.dist-info → camel_ai-0.2.79a1.dist-info}/METADATA +14 -13
  37. {camel_ai-0.2.78.dist-info → camel_ai-0.2.79a1.dist-info}/RECORD +39 -35
  38. {camel_ai-0.2.78.dist-info → camel_ai-0.2.79a1.dist-info}/WHEEL +0 -0
  39. {camel_ai-0.2.78.dist-info → camel_ai-0.2.79a1.dist-info}/licenses/LICENSE +0 -0
@@ -13,6 +13,7 @@
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
14
  import copy
15
15
  import os
16
+ import warnings
16
17
  from typing import Any, Callable, Dict, List, Optional, Type, Union
17
18
 
18
19
  from openai import AsyncAzureOpenAI, AsyncStream, AzureOpenAI, Stream
@@ -60,7 +61,8 @@ class AzureOpenAIModel(BaseModelBackend):
60
61
 
61
62
  Args:
62
63
  model_type (Union[ModelType, str]): Model for which a backend is
63
- created, one of GPT_* series.
64
+ created, Should be the deployment name you chose when you deployed
65
+ an azure model.
64
66
  model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
65
67
  that will be fed into:obj:`openai.ChatCompletion.create()`. If
66
68
  :obj:`None`, :obj:`ChatGPTConfig().as_dict()` will be used.
@@ -71,8 +73,6 @@ class AzureOpenAIModel(BaseModelBackend):
71
73
  (default: :obj:`None`)
72
74
  api_version (Optional[str], optional): The api version for the model.
73
75
  (default: :obj:`None`)
74
- azure_deployment_name (Optional[str], optional): The deployment name
75
- you chose when you deployed an azure model. (default: :obj:`None`)
76
76
  azure_ad_token (Optional[str], optional): Your Azure Active Directory
77
77
  token, https://www.microsoft.com/en-us/security/business/
78
78
  identity-access/microsoft-entra-id. (default: :obj:`None`)
@@ -88,8 +88,23 @@ class AzureOpenAIModel(BaseModelBackend):
88
88
  (default: :obj:`None`)
89
89
  max_retries (int, optional): Maximum number of retries for API calls.
90
90
  (default: :obj:`3`)
91
+ client (Optional[Any], optional): A custom synchronous AzureOpenAI
92
+ client instance. If provided, this client will be used instead of
93
+ creating a new one. Useful for RL frameworks like AReaL or rLLM
94
+ that provide Azure OpenAI-compatible clients. The client should
95
+ implement the AzureOpenAI client interface with
96
+ `.chat.completions.create()` and `.beta.chat.completions.parse()`
97
+ methods. (default: :obj:`None`)
98
+ async_client (Optional[Any], optional): A custom asynchronous
99
+ AzureOpenAI client instance. If provided, this client will be
100
+ used instead of creating a new one. The client should implement
101
+ the AsyncAzureOpenAI client interface. (default: :obj:`None`)
102
+ azure_deployment_name (Optional[str], optional): **Deprecated**.
103
+ Use `model_type` parameter instead. This parameter is kept for
104
+ backward compatibility and will be removed in a future version.
105
+ (default: :obj:`None`)
91
106
  **kwargs (Any): Additional arguments to pass to the client
92
- initialization.
107
+ initialization. Ignored if custom clients are provided.
93
108
 
94
109
  References:
95
110
  https://learn.microsoft.com/en-us/azure/ai-services/openai/
@@ -104,12 +119,35 @@ class AzureOpenAIModel(BaseModelBackend):
104
119
  timeout: Optional[float] = None,
105
120
  token_counter: Optional[BaseTokenCounter] = None,
106
121
  api_version: Optional[str] = None,
107
- azure_deployment_name: Optional[str] = None,
108
122
  azure_ad_token_provider: Optional["AzureADTokenProvider"] = None,
109
123
  azure_ad_token: Optional[str] = None,
110
124
  max_retries: int = 3,
125
+ client: Optional[Any] = None,
126
+ async_client: Optional[Any] = None,
127
+ azure_deployment_name: Optional[str] = None,
111
128
  **kwargs: Any,
112
129
  ) -> None:
130
+ # Handle deprecated azure_deployment_name parameter
131
+ if azure_deployment_name is not None:
132
+ warnings.warn(
133
+ "The 'azure_deployment_name' parameter is deprecated. "
134
+ "Please use 'model_type' parameter instead. "
135
+ "The 'azure_deployment_name' parameter is being ignored.",
136
+ DeprecationWarning,
137
+ stacklevel=2,
138
+ )
139
+
140
+ # Handle deprecated AZURE_DEPLOYMENT_NAME environment variable
141
+ if os.environ.get("AZURE_DEPLOYMENT_NAME") is not None:
142
+ warnings.warn(
143
+ "The 'AZURE_DEPLOYMENT_NAME' environment variable is "
144
+ "deprecated. Please use the 'model_type' parameter "
145
+ "instead. The 'AZURE_DEPLOYMENT_NAME' environment "
146
+ "variable is being ignored.",
147
+ DeprecationWarning,
148
+ stacklevel=2,
149
+ )
150
+
113
151
  if model_config_dict is None:
114
152
  model_config_dict = ChatGPTConfig().as_dict()
115
153
  api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY")
@@ -120,9 +158,6 @@ class AzureOpenAIModel(BaseModelBackend):
120
158
  )
121
159
 
122
160
  self.api_version = api_version or os.environ.get("AZURE_API_VERSION")
123
- self._azure_deployment_name = azure_deployment_name or os.environ.get(
124
- "AZURE_DEPLOYMENT_NAME"
125
- )
126
161
  self._azure_ad_token = azure_ad_token or os.environ.get(
127
162
  "AZURE_AD_TOKEN"
128
163
  )
@@ -132,62 +167,73 @@ class AzureOpenAIModel(BaseModelBackend):
132
167
  "Must provide either the `api_version` argument "
133
168
  "or `AZURE_API_VERSION` environment variable."
134
169
  )
135
- if self._azure_deployment_name is None:
136
- raise ValueError(
137
- "Must provide either the `azure_deployment_name` argument "
138
- "or `AZURE_DEPLOYMENT_NAME` environment variable."
139
- )
140
170
 
141
- if is_langfuse_available():
142
- from langfuse.openai import AsyncAzureOpenAI as LangfuseAsyncOpenAI
143
- from langfuse.openai import AzureOpenAI as LangfuseOpenAI
144
-
145
- self._client = LangfuseOpenAI(
146
- azure_endpoint=str(self._url),
147
- azure_deployment=self._azure_deployment_name,
148
- api_version=self.api_version,
149
- api_key=self._api_key,
150
- azure_ad_token=self._azure_ad_token,
151
- azure_ad_token_provider=self.azure_ad_token_provider,
152
- timeout=self._timeout,
153
- max_retries=max_retries,
154
- **kwargs,
155
- )
156
- self._async_client = LangfuseAsyncOpenAI(
157
- azure_endpoint=str(self._url),
158
- azure_deployment=self._azure_deployment_name,
159
- api_version=self.api_version,
160
- api_key=self._api_key,
161
- azure_ad_token=self._azure_ad_token,
162
- azure_ad_token_provider=self.azure_ad_token_provider,
163
- timeout=self._timeout,
164
- max_retries=max_retries,
165
- **kwargs,
166
- )
171
+ # Use custom clients if provided, otherwise create new ones
172
+ if client is not None:
173
+ # Use the provided custom sync client
174
+ self._client = client
167
175
  else:
168
- self._client = AzureOpenAI(
169
- azure_endpoint=str(self._url),
170
- azure_deployment=self._azure_deployment_name,
171
- api_version=self.api_version,
172
- api_key=self._api_key,
173
- azure_ad_token=self._azure_ad_token,
174
- azure_ad_token_provider=self.azure_ad_token_provider,
175
- timeout=self._timeout,
176
- max_retries=max_retries,
177
- **kwargs,
178
- )
176
+ # Create default sync client
177
+ if is_langfuse_available():
178
+ from langfuse.openai import AzureOpenAI as LangfuseOpenAI
179
+
180
+ self._client = LangfuseOpenAI(
181
+ azure_endpoint=str(self._url),
182
+ azure_deployment=str(self.model_type),
183
+ api_version=self.api_version,
184
+ api_key=self._api_key,
185
+ azure_ad_token=self._azure_ad_token,
186
+ azure_ad_token_provider=self.azure_ad_token_provider,
187
+ timeout=self._timeout,
188
+ max_retries=max_retries,
189
+ **kwargs,
190
+ )
191
+ else:
192
+ self._client = AzureOpenAI(
193
+ azure_endpoint=str(self._url),
194
+ azure_deployment=str(self.model_type),
195
+ api_version=self.api_version,
196
+ api_key=self._api_key,
197
+ azure_ad_token=self._azure_ad_token,
198
+ azure_ad_token_provider=self.azure_ad_token_provider,
199
+ timeout=self._timeout,
200
+ max_retries=max_retries,
201
+ **kwargs,
202
+ )
179
203
 
180
- self._async_client = AsyncAzureOpenAI(
181
- azure_endpoint=str(self._url),
182
- azure_deployment=self._azure_deployment_name,
183
- api_version=self.api_version,
184
- api_key=self._api_key,
185
- azure_ad_token=self._azure_ad_token,
186
- azure_ad_token_provider=self.azure_ad_token_provider,
187
- timeout=self._timeout,
188
- max_retries=max_retries,
189
- **kwargs,
190
- )
204
+ if async_client is not None:
205
+ # Use the provided custom async client
206
+ self._async_client = async_client
207
+ else:
208
+ # Create default async client
209
+ if is_langfuse_available():
210
+ from langfuse.openai import (
211
+ AsyncAzureOpenAI as LangfuseAsyncOpenAI,
212
+ )
213
+
214
+ self._async_client = LangfuseAsyncOpenAI(
215
+ azure_endpoint=str(self._url),
216
+ azure_deployment=str(self.model_type),
217
+ api_version=self.api_version,
218
+ api_key=self._api_key,
219
+ azure_ad_token=self._azure_ad_token,
220
+ azure_ad_token_provider=self.azure_ad_token_provider,
221
+ timeout=self._timeout,
222
+ max_retries=max_retries,
223
+ **kwargs,
224
+ )
225
+ else:
226
+ self._async_client = AsyncAzureOpenAI(
227
+ azure_endpoint=str(self._url),
228
+ azure_deployment=str(self.model_type),
229
+ api_version=self.api_version,
230
+ api_key=self._api_key,
231
+ azure_ad_token=self._azure_ad_token,
232
+ azure_ad_token_provider=self.azure_ad_token_provider,
233
+ timeout=self._timeout,
234
+ max_retries=max_retries,
235
+ **kwargs,
236
+ )
191
237
 
192
238
  @property
193
239
  def token_counter(self) -> BaseTokenCounter:
@@ -330,7 +376,7 @@ class AzureOpenAIModel(BaseModelBackend):
330
376
 
331
377
  return self._client.chat.completions.create(
332
378
  messages=messages,
333
- model=self._azure_deployment_name, # type:ignore[arg-type]
379
+ model=str(self.model_type),
334
380
  **request_config,
335
381
  )
336
382
 
@@ -346,7 +392,7 @@ class AzureOpenAIModel(BaseModelBackend):
346
392
 
347
393
  return await self._async_client.chat.completions.create(
348
394
  messages=messages,
349
- model=self._azure_deployment_name, # type:ignore[arg-type]
395
+ model=str(self.model_type),
350
396
  **request_config,
351
397
  )
352
398
 
@@ -367,7 +413,7 @@ class AzureOpenAIModel(BaseModelBackend):
367
413
 
368
414
  return self._client.beta.chat.completions.parse(
369
415
  messages=messages,
370
- model=self._azure_deployment_name, # type:ignore[arg-type]
416
+ model=str(self.model_type),
371
417
  **request_config,
372
418
  )
373
419
 
@@ -388,7 +434,7 @@ class AzureOpenAIModel(BaseModelBackend):
388
434
 
389
435
  return await self._async_client.beta.chat.completions.parse(
390
436
  messages=messages,
391
- model=self._azure_deployment_name, # type:ignore[arg-type]
437
+ model=str(self.model_type),
392
438
  **request_config,
393
439
  )
394
440
 
@@ -414,7 +460,7 @@ class AzureOpenAIModel(BaseModelBackend):
414
460
  # Use the beta streaming API for structured outputs
415
461
  return self._client.beta.chat.completions.stream(
416
462
  messages=messages,
417
- model=self.model_type,
463
+ model=str(self.model_type),
418
464
  response_format=response_format,
419
465
  **request_config,
420
466
  )
@@ -441,7 +487,7 @@ class AzureOpenAIModel(BaseModelBackend):
441
487
  # Use the beta streaming API for structured outputs
442
488
  return self._async_client.beta.chat.completions.stream(
443
489
  messages=messages,
444
- model=self.model_type,
490
+ model=str(self.model_type),
445
491
  response_format=response_format,
446
492
  **request_config,
447
493
  )
@@ -13,7 +13,7 @@
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
14
  import json
15
15
  import os
16
- from typing import ClassVar, Dict, Optional, Type, Union
16
+ from typing import Any, ClassVar, Dict, Optional, Type, Union
17
17
 
18
18
  from camel.models.aiml_model import AIMLModel
19
19
  from camel.models.amd_model import AMDModel
@@ -119,6 +119,8 @@ class ModelFactory:
119
119
  url: Optional[str] = None,
120
120
  timeout: Optional[float] = None,
121
121
  max_retries: int = 3,
122
+ client: Optional[Any] = None,
123
+ async_client: Optional[Any] = None,
122
124
  **kwargs,
123
125
  ) -> BaseModelBackend:
124
126
  r"""Creates an instance of `BaseModelBackend` of the specified type.
@@ -145,6 +147,14 @@ class ModelFactory:
145
147
  for API calls. (default: :obj:`None`)
146
148
  max_retries (int, optional): Maximum number of retries
147
149
  for API calls. (default: :obj:`3`)
150
+ client (Optional[Any], optional): A custom synchronous client
151
+ instance. Supported by models that use OpenAI-compatible APIs
152
+ . The client should implement the appropriate client interface
153
+ for the platform. (default: :obj:`None`)
154
+ async_client (Optional[Any], optional): A custom asynchronous
155
+ client instance. Supported by models that use OpenAI-compatible
156
+ APIs. The client should implement the appropriate async client
157
+ interface for the platform. (default: :obj:`None`)
148
158
  **kwargs: Additional model-specific parameters that will be passed
149
159
  to the model constructor. For example, Azure OpenAI models may
150
160
  require `api_version`, `azure_deployment_name`,
@@ -190,6 +200,12 @@ class ModelFactory:
190
200
  if model_class is None:
191
201
  raise ValueError(f"Unknown model platform `{model_platform}`")
192
202
 
203
+ # Pass client and async_client via kwargs if provided
204
+ if client is not None:
205
+ kwargs['client'] = client
206
+ if async_client is not None:
207
+ kwargs['async_client'] = async_client
208
+
193
209
  return model_class(
194
210
  model_type=model_type,
195
211
  model_config_dict=model_config_dict,
@@ -12,6 +12,7 @@
12
12
  # limitations under the License.
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
14
 
15
+ import copy
15
16
  import os
16
17
  from typing import Any, Dict, List, Optional, Type, Union
17
18
 
@@ -19,6 +20,7 @@ from openai import AsyncStream
19
20
  from pydantic import BaseModel
20
21
 
21
22
  from camel.configs import MoonshotConfig
23
+ from camel.logger import get_logger
22
24
  from camel.messages import OpenAIMessage
23
25
  from camel.models._utils import try_modify_message_with_format
24
26
  from camel.models.openai_compatible_model import OpenAICompatibleModel
@@ -34,6 +36,8 @@ from camel.utils import (
34
36
  update_langfuse_trace,
35
37
  )
36
38
 
39
+ logger = get_logger(__name__)
40
+
37
41
  if os.environ.get("LANGFUSE_ENABLED", "False").lower() == "true":
38
42
  try:
39
43
  from langfuse.decorators import observe
@@ -84,7 +88,7 @@ class MoonshotModel(OpenAICompatibleModel):
84
88
  model_type: Union[ModelType, str],
85
89
  model_config_dict: Optional[Dict[str, Any]] = None,
86
90
  api_key: Optional[str] = None,
87
- url: Optional[str] = "https://api.moonshot.ai/v1",
91
+ url: Optional[str] = None,
88
92
  token_counter: Optional[BaseTokenCounter] = None,
89
93
  timeout: Optional[float] = None,
90
94
  max_retries: int = 3,
@@ -93,7 +97,12 @@ class MoonshotModel(OpenAICompatibleModel):
93
97
  if model_config_dict is None:
94
98
  model_config_dict = MoonshotConfig().as_dict()
95
99
  api_key = api_key or os.environ.get("MOONSHOT_API_KEY")
96
- url = url or os.environ.get("MOONSHOT_API_BASE_URL")
100
+ # Preserve default URL if not provided
101
+ if url is None:
102
+ url = (
103
+ os.environ.get("MOONSHOT_API_BASE_URL")
104
+ or "https://api.moonshot.ai/v1"
105
+ )
97
106
  timeout = timeout or float(os.environ.get("MODEL_TIMEOUT", 180))
98
107
  super().__init__(
99
108
  model_type=model_type,
@@ -125,12 +134,12 @@ class MoonshotModel(OpenAICompatibleModel):
125
134
  Returns:
126
135
  Dict[str, Any]: The prepared request configuration.
127
136
  """
128
- import copy
129
-
130
137
  request_config = copy.deepcopy(self.model_config_dict)
131
138
 
132
139
  if tools:
133
- request_config["tools"] = tools
140
+ # Clean tools to remove null types (Moonshot API incompatibility)
141
+ cleaned_tools = self._clean_tool_schemas(tools)
142
+ request_config["tools"] = cleaned_tools
134
143
  elif response_format:
135
144
  # Use the same approach as DeepSeek for structured output
136
145
  try_modify_message_with_format(messages[-1], response_format)
@@ -138,6 +147,94 @@ class MoonshotModel(OpenAICompatibleModel):
138
147
 
139
148
  return request_config
140
149
 
150
+ def _clean_tool_schemas(
151
+ self, tools: List[Dict[str, Any]]
152
+ ) -> List[Dict[str, Any]]:
153
+ r"""Clean tool schemas to remove null types for Moonshot compatibility.
154
+
155
+ Moonshot API doesn't accept {"type": "null"} in anyOf schemas.
156
+ This method removes null type definitions from parameters.
157
+
158
+ Args:
159
+ tools (List[Dict[str, Any]]): Original tool schemas.
160
+
161
+ Returns:
162
+ List[Dict[str, Any]]: Cleaned tool schemas.
163
+ """
164
+
165
+ def remove_null_from_schema(schema: Any) -> Any:
166
+ """Recursively remove null types from schema."""
167
+ if isinstance(schema, dict):
168
+ # Create a copy to avoid modifying the original
169
+ result = {}
170
+
171
+ for key, value in schema.items():
172
+ if key == 'type' and isinstance(value, list):
173
+ # Handle type arrays like ["string", "null"]
174
+ filtered_types = [t for t in value if t != 'null']
175
+ if len(filtered_types) == 1:
176
+ # Single type remains, convert to string
177
+ result[key] = filtered_types[0]
178
+ elif len(filtered_types) > 1:
179
+ # Multiple types remain, keep as array
180
+ result[key] = filtered_types
181
+ else:
182
+ # All were null, use string as fallback
183
+ logger.warning(
184
+ "All types in tool schema type array "
185
+ "were null, falling back to 'string' "
186
+ "type for Moonshot API compatibility. "
187
+ "Original tool schema may need review."
188
+ )
189
+ result[key] = 'string'
190
+ elif key == 'anyOf':
191
+ # Handle anyOf with null types
192
+ filtered = [
193
+ item
194
+ for item in value
195
+ if not (
196
+ isinstance(item, dict)
197
+ and item.get('type') == 'null'
198
+ )
199
+ ]
200
+ if len(filtered) == 1:
201
+ # If only one type remains, flatten it
202
+ return remove_null_from_schema(filtered[0])
203
+ elif len(filtered) > 1:
204
+ result[key] = [
205
+ remove_null_from_schema(item)
206
+ for item in filtered
207
+ ]
208
+ else:
209
+ # All were null, return string type as fallback
210
+ logger.warning(
211
+ "All types in tool schema anyOf were null, "
212
+ "falling back to 'string' type for "
213
+ "Moonshot API compatibility. Original "
214
+ "tool schema may need review."
215
+ )
216
+ return {"type": "string"}
217
+ else:
218
+ # Recursively process other values
219
+ result[key] = remove_null_from_schema(value)
220
+
221
+ return result
222
+ elif isinstance(schema, list):
223
+ return [remove_null_from_schema(item) for item in schema]
224
+ else:
225
+ return schema
226
+
227
+ cleaned_tools = copy.deepcopy(tools)
228
+ for tool in cleaned_tools:
229
+ if 'function' in tool and 'parameters' in tool['function']:
230
+ params = tool['function']['parameters']
231
+ if 'properties' in params:
232
+ params['properties'] = remove_null_from_schema(
233
+ params['properties']
234
+ )
235
+
236
+ return cleaned_tools
237
+
141
238
  @observe()
142
239
  async def _arun(
143
240
  self,
@@ -78,9 +78,21 @@ class OpenAICompatibleModel(BaseModelBackend):
78
78
  (default: :obj:`None`)
79
79
  max_retries (int, optional): Maximum number of retries for API calls.
80
80
  (default: :obj:`3`)
81
+ client (Optional[Any], optional): A custom synchronous
82
+ OpenAI-compatible client instance. If provided, this client will
83
+ be used instead of creating a new one. Useful for RL frameworks
84
+ like AReaL or rLLM that provide OpenAI-compatible clients (e.g.,
85
+ ArealOpenAI). The client should implement the OpenAI client
86
+ interface with `.chat.completions.create()` and `.beta.chat.
87
+ completions.parse()` methods. (default: :obj:`None`)
88
+ async_client (Optional[Any], optional): A custom asynchronous
89
+ OpenAI-compatible client instance. If provided, this client will
90
+ be used instead of creating a new one. The client should implement
91
+ the AsyncOpenAI client interface. (default: :obj:`None`)
81
92
  **kwargs (Any): Additional arguments to pass to the
82
93
  OpenAI client initialization. These can include parameters like
83
94
  'organization', 'default_headers', 'http_client', etc.
95
+ Ignored if custom clients are provided.
84
96
  """
85
97
 
86
98
  def __init__(
@@ -92,6 +104,8 @@ class OpenAICompatibleModel(BaseModelBackend):
92
104
  token_counter: Optional[BaseTokenCounter] = None,
93
105
  timeout: Optional[float] = None,
94
106
  max_retries: int = 3,
107
+ client: Optional[Any] = None,
108
+ async_client: Optional[Any] = None,
95
109
  **kwargs: Any,
96
110
  ) -> None:
97
111
  api_key = api_key or os.environ.get("OPENAI_COMPATIBILITY_API_KEY")
@@ -107,39 +121,55 @@ class OpenAICompatibleModel(BaseModelBackend):
107
121
  timeout,
108
122
  max_retries,
109
123
  )
110
- if is_langfuse_available():
111
- from langfuse.openai import AsyncOpenAI as LangfuseAsyncOpenAI
112
- from langfuse.openai import OpenAI as LangfuseOpenAI
113
-
114
- self._client = LangfuseOpenAI(
115
- timeout=self._timeout,
116
- max_retries=max_retries,
117
- base_url=self._url,
118
- api_key=self._api_key,
119
- **kwargs,
120
- )
121
- self._async_client = LangfuseAsyncOpenAI(
122
- timeout=self._timeout,
123
- max_retries=max_retries,
124
- base_url=self._url,
125
- api_key=self._api_key,
126
- **kwargs,
127
- )
124
+
125
+ # Use custom clients if provided, otherwise create new ones
126
+ if client is not None:
127
+ # Use the provided custom sync client
128
+ self._client = client
128
129
  else:
129
- self._client = OpenAI(
130
- timeout=self._timeout,
131
- max_retries=max_retries,
132
- base_url=self._url,
133
- api_key=self._api_key,
134
- **kwargs,
135
- )
136
- self._async_client = AsyncOpenAI(
137
- timeout=self._timeout,
138
- max_retries=max_retries,
139
- base_url=self._url,
140
- api_key=self._api_key,
141
- **kwargs,
142
- )
130
+ # Create default sync client
131
+ if is_langfuse_available():
132
+ from langfuse.openai import OpenAI as LangfuseOpenAI
133
+
134
+ self._client = LangfuseOpenAI(
135
+ timeout=self._timeout,
136
+ max_retries=max_retries,
137
+ base_url=self._url,
138
+ api_key=self._api_key,
139
+ **kwargs,
140
+ )
141
+ else:
142
+ self._client = OpenAI(
143
+ timeout=self._timeout,
144
+ max_retries=max_retries,
145
+ base_url=self._url,
146
+ api_key=self._api_key,
147
+ **kwargs,
148
+ )
149
+
150
+ if async_client is not None:
151
+ # Use the provided custom async client
152
+ self._async_client = async_client
153
+ else:
154
+ # Create default async client
155
+ if is_langfuse_available():
156
+ from langfuse.openai import AsyncOpenAI as LangfuseAsyncOpenAI
157
+
158
+ self._async_client = LangfuseAsyncOpenAI(
159
+ timeout=self._timeout,
160
+ max_retries=max_retries,
161
+ base_url=self._url,
162
+ api_key=self._api_key,
163
+ **kwargs,
164
+ )
165
+ else:
166
+ self._async_client = AsyncOpenAI(
167
+ timeout=self._timeout,
168
+ max_retries=max_retries,
169
+ base_url=self._url,
170
+ api_key=self._api_key,
171
+ **kwargs,
172
+ )
143
173
 
144
174
  @observe()
145
175
  def _run(