langchain-dev-utils 1.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain_dev_utils/__init__.py +1 -0
- langchain_dev_utils/_utils.py +131 -0
- langchain_dev_utils/agents/__init__.py +4 -0
- langchain_dev_utils/agents/factory.py +99 -0
- langchain_dev_utils/agents/file_system.py +252 -0
- langchain_dev_utils/agents/middleware/__init__.py +21 -0
- langchain_dev_utils/agents/middleware/format_prompt.py +66 -0
- langchain_dev_utils/agents/middleware/handoffs.py +214 -0
- langchain_dev_utils/agents/middleware/model_fallback.py +49 -0
- langchain_dev_utils/agents/middleware/model_router.py +200 -0
- langchain_dev_utils/agents/middleware/plan.py +367 -0
- langchain_dev_utils/agents/middleware/summarization.py +85 -0
- langchain_dev_utils/agents/middleware/tool_call_repair.py +96 -0
- langchain_dev_utils/agents/middleware/tool_emulator.py +60 -0
- langchain_dev_utils/agents/middleware/tool_selection.py +82 -0
- langchain_dev_utils/agents/plan.py +188 -0
- langchain_dev_utils/agents/wrap.py +324 -0
- langchain_dev_utils/chat_models/__init__.py +11 -0
- langchain_dev_utils/chat_models/adapters/__init__.py +3 -0
- langchain_dev_utils/chat_models/adapters/create_utils.py +53 -0
- langchain_dev_utils/chat_models/adapters/openai_compatible.py +715 -0
- langchain_dev_utils/chat_models/adapters/register_profiles.py +15 -0
- langchain_dev_utils/chat_models/base.py +282 -0
- langchain_dev_utils/chat_models/types.py +27 -0
- langchain_dev_utils/embeddings/__init__.py +11 -0
- langchain_dev_utils/embeddings/adapters/__init__.py +3 -0
- langchain_dev_utils/embeddings/adapters/create_utils.py +45 -0
- langchain_dev_utils/embeddings/adapters/openai_compatible.py +91 -0
- langchain_dev_utils/embeddings/base.py +234 -0
- langchain_dev_utils/message_convert/__init__.py +15 -0
- langchain_dev_utils/message_convert/content.py +201 -0
- langchain_dev_utils/message_convert/format.py +69 -0
- langchain_dev_utils/pipeline/__init__.py +7 -0
- langchain_dev_utils/pipeline/parallel.py +135 -0
- langchain_dev_utils/pipeline/sequential.py +101 -0
- langchain_dev_utils/pipeline/types.py +3 -0
- langchain_dev_utils/py.typed +0 -0
- langchain_dev_utils/tool_calling/__init__.py +14 -0
- langchain_dev_utils/tool_calling/human_in_the_loop.py +284 -0
- langchain_dev_utils/tool_calling/utils.py +81 -0
- langchain_dev_utils-1.3.7.dist-info/METADATA +103 -0
- langchain_dev_utils-1.3.7.dist-info/RECORD +44 -0
- langchain_dev_utils-1.3.7.dist-info/WHEEL +4 -0
- langchain_dev_utils-1.3.7.dist-info/licenses/LICENSE +21 -0
langchain_dev_utils/chat_models/adapters/register_profiles.py
@@ -0,0 +1,15 @@
+from typing import Any
+
+_PROFILES = {}
+
+
+def _register_profile_with_provider(
+    provider_name: str, profile: dict[str, Any]
+) -> None:
+    _PROFILES.update({provider_name: profile})
+
+
+def _get_profile_by_provider_and_model(
+    provider_name: str, model_name: str
+) -> dict[str, Any]:
+    return _PROFILES.get(provider_name, {}).get(model_name, {})
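
The two private helpers above form a tiny in-process registry: each provider maps to a dict of per-model profiles, and lookups fall back to `{}` rather than raising. A minimal sketch of the resulting behavior (the provider, model, and profile values below are illustrative, not part of the package):

    from langchain_dev_utils.chat_models.adapters.register_profiles import (
        _get_profile_by_provider_and_model,
        _register_profile_with_provider,
    )

    # Store one profile dict per provider, keyed by model name.
    _register_profile_with_provider(
        "vllm", {"qwen3-4b": {"max_input_tokens": 32768, "tool_calling": True}}
    )

    _get_profile_by_provider_and_model("vllm", "qwen3-4b")
    # -> {"max_input_tokens": 32768, "tool_calling": True}
    _get_profile_by_provider_and_model("vllm", "no-such-model")
    # -> {} (an unknown provider or model falls back to an empty dict)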
langchain_dev_utils/chat_models/base.py
@@ -0,0 +1,282 @@
+from typing import Any, Optional, cast
+
+from langchain.chat_models.base import _SUPPORTED_PROVIDERS, _init_chat_model_helper
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.utils import from_env
+
+from langchain_dev_utils._utils import (
+    _check_pkg_install,
+    _get_base_url_field_name,
+    _validate_provider_name,
+)
+
+from .types import ChatModelProvider, ChatModelType, CompatibilityOptions
+
+_MODEL_PROVIDERS_DICT = {}
+
+
+def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]:
+    """Parse model string and provider.
+
+    Args:
+        model: Model name string, potentially including a provider prefix
+        model_provider: Optional provider name
+
+    Returns:
+        Tuple of (model_name, provider_name)
+
+    Raises:
+        ValueError: If unable to infer the model provider
+    """
+    support_providers = list(_MODEL_PROVIDERS_DICT.keys()) + list(_SUPPORTED_PROVIDERS)
+    if not model_provider and ":" in model and model.split(":")[0] in support_providers:
+        model_provider = model.split(":")[0]
+        model = ":".join(model.split(":")[1:])
+    if not model_provider:
+        msg = (
+            f"Unable to infer model provider for {model=}, please specify "
+            f"model_provider directly."
+        )
+        raise ValueError(msg)
+    model_provider = model_provider.replace("-", "_").lower()
+    return model, model_provider
+
+
+def _load_chat_model_helper(
+    model: str,
+    model_provider: Optional[str] = None,
+    **kwargs: Any,
+) -> BaseChatModel:
+    """Helper function for loading a chat model.
+
+    Args:
+        model: Model name
+        model_provider: Optional provider name
+        **kwargs: Additional arguments for model initialization
+
+    Returns:
+        BaseChatModel: Initialized chat model instance
+    """
+    model, model_provider = _parse_model(model, model_provider)
+    if model_provider in _MODEL_PROVIDERS_DICT:
+        chat_model = _MODEL_PROVIDERS_DICT[model_provider]["chat_model"]
+        if base_url := _MODEL_PROVIDERS_DICT[model_provider].get("base_url"):
+            url_key = _get_base_url_field_name(chat_model)
+            if url_key:
+                kwargs.update({url_key: base_url})
+        if model_profiles := _MODEL_PROVIDERS_DICT[model_provider].get(
+            "model_profiles"
+        ):
+            if model in model_profiles and "profile" not in kwargs:
+                kwargs.update({"profile": model_profiles[model]})
+        return chat_model(model=model, **kwargs)
+
+    return _init_chat_model_helper(model, model_provider=model_provider, **kwargs)
+
+
+def register_model_provider(
+    provider_name: str,
+    chat_model: ChatModelType,
+    base_url: Optional[str] = None,
+    model_profiles: Optional[dict[str, dict[str, Any]]] = None,
+    compatibility_options: Optional[CompatibilityOptions] = None,
+):
+    """Register a new model provider.
+
+    This function allows you to register custom chat model providers that can be used
+    with the load_chat_model function. It supports both custom model classes and
+    string identifiers for supported providers.
+
+    Args:
+        provider_name: The name of the model provider, used as an identifier for
+            loading models later.
+        chat_model: The chat model, which can be either a `ChatModel` class or
+            a string (currently only `"openai-compatible"` is supported).
+        base_url: The API endpoint URL of the model provider (optional; applicable
+            to both `chat_model` types, but primarily used when `chat_model` is a
+            string with value `"openai-compatible"`).
+        model_profiles: Declares the capabilities and parameters supported by each
+            model provided by this provider (optional; applicable to both `chat_model`
+            types). The configuration corresponding to the `model_name` will be loaded
+            and assigned to `model.profile` (e.g., fields such as `max_input_tokens`,
+            `tool_calling`, etc.).
+        compatibility_options: Compatibility options for the model provider (optional;
+            only effective when `chat_model` is a string with value `"openai-compatible"`).
+            Used to declare support for OpenAI-compatible features (e.g., `tool_choice`
+            strategies, JSON mode, etc.) to ensure correct functional adaptation.
+    Raises:
+        ValueError: If base_url is not provided when chat_model is a string,
+            or if the chat_model string is not in the supported providers
+
+    Example:
+        Basic usage with a custom model class:
+        >>> from langchain_dev_utils.chat_models import register_model_provider, load_chat_model
+        >>> from langchain_core.language_models.fake_chat_models import FakeChatModel
+        >>>
+        # Register a custom model provider
+        >>> register_model_provider("fakechat", FakeChatModel)
+        >>> model = load_chat_model(model="fakechat:fake-model")
+        >>> model.invoke("Hello")
+        >>>
+        # Using with an OpenAI-compatible API:
+        >>> register_model_provider(
+        ...     provider_name="vllm",
+        ...     chat_model="openai-compatible",
+        ...     base_url="http://localhost:8000/v1",
+        ... )
+        >>> model = load_chat_model(model="vllm:qwen3-4b")
+        >>> model.invoke("Hello")
+    """
+    _validate_provider_name(provider_name)
+    base_url = base_url or from_env(f"{provider_name.upper()}_API_BASE", default=None)()
+    if isinstance(chat_model, str):
+        _check_pkg_install("langchain_openai")
+        from .adapters.openai_compatible import _create_openai_compatible_model
+
+        if base_url is None:
+            raise ValueError(
+                f"base_url must be provided or the {provider_name.upper()}_API_BASE environment variable must be set when chat_model is a string"
+            )
+
+        if chat_model != "openai-compatible":
+            raise ValueError(
+                "when chat_model is a string, the value must be 'openai-compatible'"
+            )
+        chat_model = _create_openai_compatible_model(
+            provider=provider_name,
+            base_url=base_url,
+            compatibility_options=compatibility_options,
+            profiles=model_profiles,
+        )
+        _MODEL_PROVIDERS_DICT.update({provider_name: {"chat_model": chat_model}})
+    else:
+        if base_url is not None:
+            _MODEL_PROVIDERS_DICT.update(
+                {
+                    provider_name: {
+                        "chat_model": chat_model,
+                        "base_url": base_url,
+                        "model_profiles": model_profiles,
+                    }
+                }
+            )
+        else:
+            _MODEL_PROVIDERS_DICT.update(
+                {
+                    provider_name: {
+                        "chat_model": chat_model,
+                        "model_profiles": model_profiles,
+                    }
+                }
+            )
+
+
+def batch_register_model_provider(
+    providers: list[ChatModelProvider],
+):
+    """Batch register model providers.
+
+    This function allows you to register multiple model providers at once, which is
+    useful when setting up applications that need to work with multiple model services.
+
+    Args:
+        providers: List of ChatModelProvider dictionaries, each containing:
+            - provider_name (str): The name of the model provider, used as an
+              identifier for loading models later.
+            - chat_model (ChatModel | str): The chat model, which can be either
+              a `ChatModel` class or a string (currently only `"openai-compatible"`
+              is supported).
+            - base_url (str, optional): The API endpoint URL of the model provider.
+              Applicable to both `chat_model` types, but primarily used when `chat_model`
+              is `"openai-compatible"`.
+            - model_profiles (dict, optional): Declares the capabilities and parameters
+              supported by each model. The configuration will be loaded and assigned to
+              `model.profile` (e.g., `max_input_tokens`, `tool_calling`, etc.).
+            - compatibility_options (CompatibilityOptions, optional): Compatibility
+              options for the model provider. Only effective when `chat_model` is
+              `"openai-compatible"`. Used to declare support for OpenAI-compatible features
+              (e.g., `tool_choice` strategies, JSON mode, etc.).
+
+    Raises:
+        ValueError: If any of the providers are invalid
+
+    Example:
+        Register multiple providers at once::
+
+        >>> from langchain_dev_utils.chat_models import batch_register_model_provider, load_chat_model
+        >>> from langchain_core.language_models.fake_chat_models import FakeChatModel
+        >>>
+        # Register multiple providers
+        >>> batch_register_model_provider([
+        ...     {
+        ...         "provider_name": "fakechat",
+        ...         "chat_model": FakeChatModel,
+        ...     },
+        ...     {
+        ...         "provider_name": "vllm",
+        ...         "chat_model": "openai-compatible",
+        ...         "base_url": "http://localhost:8000/v1",
+        ...     },
+        ... ])
+        >>>
+        # Use the registered providers
+        >>> model = load_chat_model("fakechat:fake-model")
+        >>> model.invoke("Hello")
+        >>>
+        >>> model = load_chat_model("vllm:qwen3-4b")
+        >>> model.invoke("Hello")
+    """
+
+    for provider in providers:
+        register_model_provider(
+            provider["provider_name"],
+            provider["chat_model"],
+            provider.get("base_url"),
+            model_profiles=provider.get("model_profiles"),
+            compatibility_options=provider.get("compatibility_options"),
+        )
+
+
+def load_chat_model(
+    model: str,
+    *,
+    model_provider: Optional[str] = None,
+    **kwargs: Any,
+) -> BaseChatModel:
+    """Load a chat model.
+
+    This function loads a chat model from the registered providers. The model parameter
+    can be specified in two ways:
+    1. "provider:model-name" - when model_provider is not specified
+    2. "model-name" - when model_provider is specified separately
+
+    Args:
+        model: Model name, either as "provider:model-name" or just "model-name"
+        model_provider: Optional provider name (if not included in the model parameter)
+        **kwargs: Additional arguments for model initialization (e.g., temperature, api_key)
+
+    Returns:
+        BaseChatModel: Initialized chat model instance
+
+    Example:
+        # Load a model with a provider prefix:
+        >>> from langchain_dev_utils.chat_models import load_chat_model
+        >>> model = load_chat_model("vllm:qwen3-4b")
+        >>> model.invoke("hello")
+
+        # Load a model with a separate provider parameter:
+        >>> model = load_chat_model("qwen3-4b", model_provider="vllm")
+        >>> model.invoke("hello")
+
+        # Load a model with additional parameters:
+        >>> model = load_chat_model(
+        ...     "vllm:qwen3-4b",
+        ...     temperature=0.7
+        ... )
+        >>> model.invoke("Hello, how are you?")
+    """
+    return _load_chat_model_helper(
+        cast(str, model),
+        model_provider=model_provider,
+        **kwargs,
+    )
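
Putting the pieces of `base.py` together: when a provider is registered with a model class plus `model_profiles`, `_load_chat_model_helper` injects the matching profile as a `profile` kwarg unless the caller passes one explicitly. A sketch mirroring the module's own doctest names (`fakechat` and `fake-model` are illustrative):

    from langchain_core.language_models.fake_chat_models import FakeChatModel

    from langchain_dev_utils.chat_models import load_chat_model, register_model_provider

    register_model_provider(
        "fakechat",
        FakeChatModel,
        model_profiles={"fake-model": {"max_input_tokens": 32768}},
    )

    # "fakechat:fake-model" is split on the first ":" by _parse_model; since
    # "fake-model" has a registered profile and no explicit profile kwarg is
    # given, the helper adds profile={"max_input_tokens": 32768} to kwargs.
    model = load_chat_model("fakechat:fake-model")

Note that for string-based (`"openai-compatible"`) registrations, the profiles are handed to the adapter at class-creation time instead: only `{"chat_model": ...}` is stored in `_MODEL_PROVIDERS_DICT` on that path.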
langchain_dev_utils/chat_models/types.py
@@ -0,0 +1,27 @@
+from typing import Any, Literal, NotRequired, TypedDict, Union
+
+from langchain_core.language_models.chat_models import BaseChatModel
+
+ChatModelType = Union[type[BaseChatModel], Literal["openai-compatible"]]
+
+
+ToolChoiceType = list[Literal["auto", "none", "required", "specific"]]
+
+ResponseFormatType = list[Literal["json_schema", "json_mode"]]
+
+ReasoningKeepPolicy = Literal["never", "current", "all"]
+
+
+class CompatibilityOptions(TypedDict):
+    supported_tool_choice: NotRequired[ToolChoiceType]
+    supported_response_format: NotRequired[ResponseFormatType]
+    reasoning_keep_policy: NotRequired[ReasoningKeepPolicy]
+    include_usage: NotRequired[bool]
+
+
+class ChatModelProvider(TypedDict):
+    provider_name: str
+    chat_model: ChatModelType
+    base_url: NotRequired[str]
+    model_profiles: NotRequired[dict[str, dict[str, Any]]]
+    compatibility_options: NotRequired[CompatibilityOptions]
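
Because `CompatibilityOptions` and `ChatModelProvider` are `TypedDict`s, they are plain dicts at runtime and only constrain static type checking. A sketch of a fully populated provider entry, assuming the types module is imported directly (all values below are illustrative):

    from langchain_dev_utils.chat_models.types import (
        ChatModelProvider,
        CompatibilityOptions,
    )

    compat: CompatibilityOptions = {
        "supported_tool_choice": ["auto", "required"],
        "supported_response_format": ["json_mode"],
        "reasoning_keep_policy": "current",
        "include_usage": True,
    }

    provider: ChatModelProvider = {
        "provider_name": "vllm",
        "chat_model": "openai-compatible",
        "base_url": "http://localhost:8000/v1",
        "compatibility_options": compat,
    }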
langchain_dev_utils/embeddings/adapters/create_utils.py
@@ -0,0 +1,45 @@
+from typing import Optional, cast
+
+from langchain_core.utils import from_env
+
+from langchain_dev_utils._utils import _check_pkg_install
+
+
+def create_openai_compatible_embedding(
+    embedding_provider: str,
+    base_url: Optional[str] = None,
+    embedding_model_cls_name: Optional[str] = None,
+):
+    """Factory function for creating provider-specific OpenAI-compatible embedding classes.
+
+    Dynamically generates embedding classes for different OpenAI-compatible providers,
+    configuring environment variable mappings and default base URLs specific to each provider.
+
+    Args:
+        embedding_provider (str): Identifier for the OpenAI-compatible provider (e.g. `vllm`, `moonshot`)
+        base_url (Optional[str], optional): Default API base URL for the provider. Defaults to None; if not provided, falls back to the `{PROVIDER}_API_BASE` environment variable.
+        embedding_model_cls_name (Optional[str], optional): Optional custom class name for the generated embedding class. Defaults to None.
+    Returns:
+        Type[_BaseEmbeddingOpenAICompatible]: Configured embedding class ready for instantiation with provider-specific settings
+
+    Examples:
+        >>> from langchain_dev_utils.embeddings.adapters import create_openai_compatible_embedding
+        >>> VLLMEmbedding = create_openai_compatible_embedding(
+        ...     "vllm",
+        ...     base_url="http://localhost:8000",
+        ...     embedding_model_cls_name="VLLMEmbedding",
+        ... )
+        >>> model = VLLMEmbedding(model="qwen3-embedding-8b")
+        >>> model.embed_query("hello")
+    """
+    _check_pkg_install("langchain_openai")
+    from .openai_compatible import _create_openai_compatible_embedding
+
+    base_url = (
+        base_url or from_env(f"{embedding_provider.upper()}_API_BASE", default=None)()
+    )
+    return _create_openai_compatible_embedding(
+        provider=embedding_provider,
+        base_url=cast(str, base_url),
+        embeddings_cls_name=embedding_model_cls_name,
+    )
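
Since the factory resolves `base_url` as the explicit argument first and the `{PROVIDER}_API_BASE` environment variable second, the docstring example can equivalently be driven by the environment. A sketch under that assumption (the URL and model name are illustrative):

    import os

    from langchain_dev_utils.embeddings.adapters import create_openai_compatible_embedding

    # Stands in for passing base_url explicitly.
    os.environ["VLLM_API_BASE"] = "http://localhost:8000/v1"

    VLLMEmbedding = create_openai_compatible_embedding("vllm")
    embedder = VLLMEmbedding(model="qwen3-embedding-8b")
    vector = embedder.embed_query("hello")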
langchain_dev_utils/embeddings/adapters/openai_compatible.py
@@ -0,0 +1,91 @@
+from typing import Optional, Type
+
+from langchain_core.utils import from_env, secret_from_env
+from langchain_openai.embeddings import OpenAIEmbeddings
+from pydantic import Field, SecretStr, create_model
+
+from ..._utils import (
+    _validate_base_url,
+    _validate_model_cls_name,
+    _validate_provider_name,
+)
+
+
+class _BaseEmbeddingOpenAICompatible(OpenAIEmbeddings):
+    """Base class for OpenAI-compatible embeddings.
+
+    This class extends the OpenAIEmbeddings class to support
+    custom API keys and base URLs for OpenAI-compatible models.
+
+    Note: This is a template class and should not be exported or instantiated
+    directly. Instead, use it as a base class and provide the specific provider
+    name through inheritance or the factory function
+    `create_openai_compatible_embedding()`.
+    """
+
+    openai_api_key: Optional[SecretStr] = Field(
+        default_factory=secret_from_env("OPENAI_COMPATIBLE_API_KEY", default=None),
+        alias="api_key",
+    )
+    """OpenAI-compatible API key"""
+    openai_api_base: str = Field(
+        default_factory=from_env("OPENAI_COMPATIBLE_API_BASE", default=""),
+        alias="base_url",
+    )
+    """OpenAI-compatible API base URL"""
+
+    check_embedding_ctx_length: bool = False
+    """Whether to check the token length of inputs and automatically split inputs
+    longer than embedding_ctx_length. Defaults to False."""
+
+
+def _create_openai_compatible_embedding(
+    provider: str,
+    base_url: str,
+    embeddings_cls_name: Optional[str] = None,
+) -> Type[_BaseEmbeddingOpenAICompatible]:
+    """Factory function for creating provider-specific OpenAI-compatible embeddings classes.
+
+    Dynamically generates embeddings classes for different OpenAI-compatible providers,
+    configuring environment variable mappings and default base URLs specific to each provider.
+
+    Args:
+        provider: Provider identifier (e.g. `vllm`)
+        base_url: Default API base URL for the provider
+        embeddings_cls_name: Optional custom class name for the generated embeddings class. Defaults to None.
+
+    Returns:
+        Configured embeddings class ready for instantiation with provider-specific settings
+    """
+    embeddings_cls_name = embeddings_cls_name or f"{provider.title()}Embeddings"
+
+    if len(provider) >= 20:
+        raise ValueError(
+            f"provider must be less than 20 characters. Received: {provider}"
+        )
+
+    _validate_model_cls_name(embeddings_cls_name)
+    _validate_provider_name(provider)
+
+    _validate_base_url(base_url)
+
+    return create_model(
+        embeddings_cls_name,
+        __base__=_BaseEmbeddingOpenAICompatible,
+        openai_api_base=(
+            str,
+            Field(
+                default_factory=from_env(
+                    f"{provider.upper()}_API_BASE", default=base_url
+                ),
+            ),
+        ),
+        openai_api_key=(
+            str,
+            Field(
+                default_factory=secret_from_env(
+                    f"{provider.upper()}_API_KEY", default=None
+                ),
+            ),
+        ),
+    )
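
One consequence of the `default_factory=from_env(..., default=base_url)` fields above: on the generated class, the `{PROVIDER}_API_BASE` and `{PROVIDER}_API_KEY` environment variables take precedence over the defaults baked in at creation time. A sketch of that precedence using the private factory (all values illustrative):

    import os

    from langchain_dev_utils.embeddings.adapters.openai_compatible import (
        _create_openai_compatible_embedding,
    )

    os.environ["VLLM_API_KEY"] = "EMPTY"                    # read by the generated key field
    os.environ["VLLM_API_BASE"] = "http://gpu-box:8000/v1"  # overrides the baked-in default

    VllmEmbeddings = _create_openai_compatible_embedding(
        provider="vllm", base_url="http://localhost:8000/v1"
    )
    embedder = VllmEmbeddings(model="qwen3-embedding-8b")
    assert embedder.openai_api_base == "http://gpu-box:8000/v1"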