langchain 1.0.5__py3-none-any.whl → 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/__init__.py +1 -7
- langchain/agents/factory.py +153 -79
- langchain/agents/middleware/__init__.py +18 -23
- langchain/agents/middleware/_execution.py +29 -32
- langchain/agents/middleware/_redaction.py +108 -22
- langchain/agents/middleware/_retry.py +123 -0
- langchain/agents/middleware/context_editing.py +47 -25
- langchain/agents/middleware/file_search.py +19 -14
- langchain/agents/middleware/human_in_the_loop.py +87 -57
- langchain/agents/middleware/model_call_limit.py +64 -18
- langchain/agents/middleware/model_fallback.py +7 -9
- langchain/agents/middleware/model_retry.py +307 -0
- langchain/agents/middleware/pii.py +82 -29
- langchain/agents/middleware/shell_tool.py +254 -107
- langchain/agents/middleware/summarization.py +469 -95
- langchain/agents/middleware/todo.py +129 -31
- langchain/agents/middleware/tool_call_limit.py +105 -71
- langchain/agents/middleware/tool_emulator.py +47 -38
- langchain/agents/middleware/tool_retry.py +183 -164
- langchain/agents/middleware/tool_selection.py +81 -37
- langchain/agents/middleware/types.py +856 -427
- langchain/agents/structured_output.py +65 -42
- langchain/chat_models/__init__.py +1 -7
- langchain/chat_models/base.py +253 -196
- langchain/embeddings/__init__.py +0 -5
- langchain/embeddings/base.py +79 -65
- langchain/messages/__init__.py +0 -5
- langchain/tools/__init__.py +1 -7
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/METADATA +5 -7
- langchain-1.2.4.dist-info/RECORD +36 -0
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/WHEEL +1 -1
- langchain-1.0.5.dist-info/RECORD +0 -34
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/licenses/LICENSE +0 -0
langchain/chat_models/base.py
CHANGED
|
@@ -2,17 +2,27 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import functools
|
|
6
|
+
import importlib
|
|
5
7
|
import warnings
|
|
6
|
-
from
|
|
7
|
-
|
|
8
|
+
from typing import (
|
|
9
|
+
TYPE_CHECKING,
|
|
10
|
+
Any,
|
|
11
|
+
Literal,
|
|
12
|
+
TypeAlias,
|
|
13
|
+
cast,
|
|
14
|
+
overload,
|
|
15
|
+
)
|
|
8
16
|
|
|
9
17
|
from langchain_core.language_models import BaseChatModel, LanguageModelInput
|
|
10
18
|
from langchain_core.messages import AIMessage, AnyMessage
|
|
19
|
+
from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue
|
|
11
20
|
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
|
|
12
21
|
from typing_extensions import override
|
|
13
22
|
|
|
14
23
|
if TYPE_CHECKING:
|
|
15
24
|
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
|
|
25
|
+
from types import ModuleType
|
|
16
26
|
|
|
17
27
|
from langchain_core.runnables.schema import StreamEvent
|
|
18
28
|
from langchain_core.tools import BaseTool
|
|
@@ -20,6 +30,131 @@ if TYPE_CHECKING:
|
|
|
20
30
|
from pydantic import BaseModel
|
|
21
31
|
|
|
22
32
|
|
|
33
|
+
def _call(cls: type[BaseChatModel], **kwargs: Any) -> BaseChatModel:
|
|
34
|
+
# TODO: replace with operator.call when lower bounding to Python 3.11
|
|
35
|
+
return cls(**kwargs)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _create_from_model_id(cls: Any, model: str, **kwargs: Any) -> Any:
    """Build a chat model through the class's `from_model_id` alternate constructor.

    Used by the `huggingface` provider entry, whose model is passed as `model_id`.
    """
    return cls.from_model_id(model_id=model, **kwargs)


def _create_with_model_id(cls: Any, model: str, **kwargs: Any) -> Any:
    """Build a chat model, passing the model name as the `model_id` keyword.

    Used by the `ibm` provider entry.
    """
    return cls(model_id=model, **kwargs)


_SUPPORTED_PROVIDERS: dict[str, tuple[str, str, Callable[..., BaseChatModel]]] = {
    "anthropic": ("langchain_anthropic", "ChatAnthropic", _call),
    "azure_ai": ("langchain_azure_ai.chat_models", "AzureAIChatCompletionsModel", _call),
    "azure_openai": ("langchain_openai", "AzureChatOpenAI", _call),
    "bedrock": ("langchain_aws", "ChatBedrock", _call),
    "bedrock_converse": ("langchain_aws", "ChatBedrockConverse", _call),
    "cohere": ("langchain_cohere", "ChatCohere", _call),
    "deepseek": ("langchain_deepseek", "ChatDeepSeek", _call),
    "fireworks": ("langchain_fireworks", "ChatFireworks", _call),
    "google_anthropic_vertex": (
        "langchain_google_vertexai.model_garden",
        "ChatAnthropicVertex",
        _call,
    ),
    "google_genai": ("langchain_google_genai", "ChatGoogleGenerativeAI", _call),
    "google_vertexai": ("langchain_google_vertexai", "ChatVertexAI", _call),
    "groq": ("langchain_groq", "ChatGroq", _call),
    "huggingface": ("langchain_huggingface", "ChatHuggingFace", _create_from_model_id),
    "ibm": ("langchain_ibm", "ChatWatsonx", _create_with_model_id),
    "mistralai": ("langchain_mistralai", "ChatMistralAI", _call),
    "nvidia": ("langchain_nvidia_ai_endpoints", "ChatNVIDIA", _call),
    "ollama": ("langchain_ollama", "ChatOllama", _call),
    "openai": ("langchain_openai", "ChatOpenAI", _call),
    "perplexity": ("langchain_perplexity", "ChatPerplexity", _call),
    "together": ("langchain_together", "ChatTogether", _call),
    "upstage": ("langchain_upstage", "ChatUpstage", _call),
    "xai": ("langchain_xai", "ChatXAI", _call),
}
"""Registry of supported providers and how to import and build their chat models.

Keys are normalized provider names. Insertion order matters: it is the order in
which providers are listed in "unsupported provider" error messages.

Each value is a `(module_path, class_name, creator_func)` tuple:

- `module_path`: Module that defines the chat model class. This may be a submodule
    (e.g., `'langchain_azure_ai.chat_models'`) when the class is not exported from
    the package root.
- `class_name`: Name of the chat model class inside `module_path`.
- `creator_func`: Callable invoked as `creator_func(cls=..., model=..., **kwargs)`
    that returns the chat model instance.
"""
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _import_module(module: str, class_name: str) -> ModuleType:
|
|
88
|
+
"""Import a module by name.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
module: The fully qualified module name to import (e.g., `'langchain_openai'`).
|
|
92
|
+
class_name: The name of the class being imported, used for error messages.
|
|
93
|
+
|
|
94
|
+
Returns:
|
|
95
|
+
The imported module.
|
|
96
|
+
|
|
97
|
+
Raises:
|
|
98
|
+
ImportError: If the module cannot be imported, with a message suggesting
|
|
99
|
+
the pip package to install.
|
|
100
|
+
"""
|
|
101
|
+
try:
|
|
102
|
+
return importlib.import_module(module)
|
|
103
|
+
except ImportError as e:
|
|
104
|
+
# Extract package name from module path (e.g., "langchain_azure_ai.chat_models"
|
|
105
|
+
# becomes "langchain-azure-ai")
|
|
106
|
+
pkg = module.split(".", maxsplit=1)[0].replace("_", "-")
|
|
107
|
+
msg = (
|
|
108
|
+
f"Initializing {class_name} requires the {pkg} package. Please install it "
|
|
109
|
+
f"with `pip install {pkg}`"
|
|
110
|
+
)
|
|
111
|
+
raise ImportError(msg) from e
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
def _get_chat_model_creator(
    provider: str,
) -> Callable[..., BaseChatModel]:
    """Resolve a provider name to a callable that builds its chat model.

    Results are memoized per provider so each integration module is imported at
    most once per cache entry.

    Args:
        provider: Normalized provider key (e.g., `'openai'`, `'anthropic'`); must be
            a key of `_SUPPORTED_PROVIDERS`.

    Returns:
        The provider's `creator_func` with the chat model class pre-bound as the
        `cls` keyword; call it with `model=...` plus any model kwargs.

    Raises:
        ValueError: If `provider` is not a key of `_SUPPORTED_PROVIDERS`.
        ImportError: If the provider's integration package cannot be imported.
    """
    registry_entry = _SUPPORTED_PROVIDERS.get(provider)
    if registry_entry is None:
        supported = ", ".join(_SUPPORTED_PROVIDERS.keys())
        msg = f"Unsupported {provider=}.\n\nSupported model providers are: {supported}"
        raise ValueError(msg)

    module_path, class_name, creator_func = registry_entry
    try:
        module = _import_module(module_path, class_name)
    except ImportError as primary_error:
        if provider != "ollama":
            raise
        # For backwards compatibility, `ollama` may still come from langchain-community.
        try:
            module = _import_module("langchain_community.chat_models", class_name)
        except ImportError:
            # If both langchain-ollama and langchain-community aren't available,
            # surface the error about langchain-ollama rather than the fallback.
            raise primary_error from None

    model_cls = getattr(module, class_name)
    return functools.partial(creator_func, cls=model_cls)
|
|
156
|
+
|
|
157
|
+
|
|
23
158
|
@overload
|
|
24
159
|
def init_chat_model(
|
|
25
160
|
model: str,
|
|
@@ -73,7 +208,7 @@ def init_chat_model(
|
|
|
73
208
|
runtime via `config`. Makes it easy to switch between models/providers without
|
|
74
209
|
changing your code
|
|
75
210
|
|
|
76
|
-
!!! note
|
|
211
|
+
!!! note "Installation requirements"
|
|
77
212
|
Requires the integration package for the chosen model provider to be installed.
|
|
78
213
|
|
|
79
214
|
See the `model_provider` parameter below for specific package names
|
|
@@ -83,10 +218,26 @@ def init_chat_model(
|
|
|
83
218
|
for supported model parameters to use as `**kwargs`.
|
|
84
219
|
|
|
85
220
|
Args:
|
|
86
|
-
model: The name
|
|
221
|
+
model: The model name, optionally prefixed with provider (e.g., `'openai:gpt-4o'`).
|
|
222
|
+
|
|
223
|
+
Prefer exact model IDs from provider docs over aliases for reliable behavior
|
|
224
|
+
(e.g., dated versions like `'...-20250514'` instead of `'...-latest'`).
|
|
225
|
+
|
|
226
|
+
Will attempt to infer `model_provider` from model if not specified.
|
|
87
227
|
|
|
88
|
-
|
|
89
|
-
|
|
228
|
+
The following providers will be inferred based on these model prefixes:
|
|
229
|
+
|
|
230
|
+
- `gpt-...` | `o1...` | `o3...` -> `openai`
|
|
231
|
+
- `claude...` -> `anthropic`
|
|
232
|
+
- `amazon...` -> `bedrock`
|
|
233
|
+
- `gemini...` -> `google_vertexai`
|
|
234
|
+
- `command...` -> `cohere`
|
|
235
|
+
- `accounts/fireworks...` -> `fireworks`
|
|
236
|
+
- `mistral...` -> `mistralai`
|
|
237
|
+
- `deepseek...` -> `deepseek`
|
|
238
|
+
- `grok...` -> `xai`
|
|
239
|
+
- `sonar...` -> `perplexity`
|
|
240
|
+
- `solar...` -> `upstage`
|
|
90
241
|
model_provider: The model provider if not specified as part of the model arg
|
|
91
242
|
(see above).
|
|
92
243
|
|
|
@@ -110,24 +261,12 @@ def init_chat_model(
|
|
|
110
261
|
- `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama)
|
|
111
262
|
- `google_anthropic_vertex` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
|
|
112
263
|
- `deepseek` -> [`langchain-deepseek`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
|
|
113
|
-
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/
|
|
264
|
+
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/ibm)
|
|
114
265
|
- `nvidia` -> [`langchain-nvidia-ai-endpoints`](https://docs.langchain.com/oss/python/integrations/providers/nvidia)
|
|
115
266
|
- `xai` -> [`langchain-xai`](https://docs.langchain.com/oss/python/integrations/providers/xai)
|
|
116
267
|
- `perplexity` -> [`langchain-perplexity`](https://docs.langchain.com/oss/python/integrations/providers/perplexity)
|
|
268
|
+
- `upstage` -> [`langchain-upstage`](https://docs.langchain.com/oss/python/integrations/providers/upstage)
|
|
117
269
|
|
|
118
|
-
Will attempt to infer `model_provider` from model if not specified. The
|
|
119
|
-
following providers will be inferred based on these model prefixes:
|
|
120
|
-
|
|
121
|
-
- `gpt-...` | `o1...` | `o3...` -> `openai`
|
|
122
|
-
- `claude...` -> `anthropic`
|
|
123
|
-
- `amazon...` -> `bedrock`
|
|
124
|
-
- `gemini...` -> `google_vertexai`
|
|
125
|
-
- `command...` -> `cohere`
|
|
126
|
-
- `accounts/fireworks...` -> `fireworks`
|
|
127
|
-
- `mistral...` -> `mistralai`
|
|
128
|
-
- `deepseek...` -> `deepseek`
|
|
129
|
-
- `grok...` -> `xai`
|
|
130
|
-
- `sonar...` -> `perplexity`
|
|
131
270
|
configurable_fields: Which model parameters are configurable at runtime:
|
|
132
271
|
|
|
133
272
|
- `None`: No configurable fields (i.e., a fixed model).
|
|
@@ -142,6 +281,7 @@ def init_chat_model(
|
|
|
142
281
|
If `model` is not specified, then defaults to `("model", "model_provider")`.
|
|
143
282
|
|
|
144
283
|
!!! warning "Security note"
|
|
284
|
+
|
|
145
285
|
Setting `configurable_fields="any"` means fields like `api_key`,
|
|
146
286
|
`base_url`, etc., can be altered at runtime, potentially redirecting
|
|
147
287
|
model requests to a different service/user.
|
|
@@ -331,195 +471,108 @@ def _init_chat_model_helper(
|
|
|
331
471
|
**kwargs: Any,
|
|
332
472
|
) -> BaseChatModel:
|
|
333
473
|
model, model_provider = _parse_model(model, model_provider)
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
from langchain_openai import ChatOpenAI
|
|
337
|
-
|
|
338
|
-
return ChatOpenAI(model=model, **kwargs)
|
|
339
|
-
if model_provider == "anthropic":
|
|
340
|
-
_check_pkg("langchain_anthropic")
|
|
341
|
-
from langchain_anthropic import ChatAnthropic
|
|
342
|
-
|
|
343
|
-
return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
|
344
|
-
if model_provider == "azure_openai":
|
|
345
|
-
_check_pkg("langchain_openai")
|
|
346
|
-
from langchain_openai import AzureChatOpenAI
|
|
347
|
-
|
|
348
|
-
return AzureChatOpenAI(model=model, **kwargs)
|
|
349
|
-
if model_provider == "azure_ai":
|
|
350
|
-
_check_pkg("langchain_azure_ai")
|
|
351
|
-
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
|
|
352
|
-
|
|
353
|
-
return AzureAIChatCompletionsModel(model=model, **kwargs)
|
|
354
|
-
if model_provider == "cohere":
|
|
355
|
-
_check_pkg("langchain_cohere")
|
|
356
|
-
from langchain_cohere import ChatCohere
|
|
357
|
-
|
|
358
|
-
return ChatCohere(model=model, **kwargs)
|
|
359
|
-
if model_provider == "google_vertexai":
|
|
360
|
-
_check_pkg("langchain_google_vertexai")
|
|
361
|
-
from langchain_google_vertexai import ChatVertexAI
|
|
362
|
-
|
|
363
|
-
return ChatVertexAI(model=model, **kwargs)
|
|
364
|
-
if model_provider == "google_genai":
|
|
365
|
-
_check_pkg("langchain_google_genai")
|
|
366
|
-
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
367
|
-
|
|
368
|
-
return ChatGoogleGenerativeAI(model=model, **kwargs)
|
|
369
|
-
if model_provider == "fireworks":
|
|
370
|
-
_check_pkg("langchain_fireworks")
|
|
371
|
-
from langchain_fireworks import ChatFireworks
|
|
372
|
-
|
|
373
|
-
return ChatFireworks(model=model, **kwargs)
|
|
374
|
-
if model_provider == "ollama":
|
|
375
|
-
try:
|
|
376
|
-
_check_pkg("langchain_ollama")
|
|
377
|
-
from langchain_ollama import ChatOllama
|
|
378
|
-
except ImportError:
|
|
379
|
-
# For backwards compatibility
|
|
380
|
-
try:
|
|
381
|
-
_check_pkg("langchain_community")
|
|
382
|
-
from langchain_community.chat_models import ChatOllama
|
|
383
|
-
except ImportError:
|
|
384
|
-
# If both langchain-ollama and langchain-community aren't available,
|
|
385
|
-
# raise an error related to langchain-ollama
|
|
386
|
-
_check_pkg("langchain_ollama")
|
|
387
|
-
|
|
388
|
-
return ChatOllama(model=model, **kwargs)
|
|
389
|
-
if model_provider == "together":
|
|
390
|
-
_check_pkg("langchain_together")
|
|
391
|
-
from langchain_together import ChatTogether
|
|
392
|
-
|
|
393
|
-
return ChatTogether(model=model, **kwargs)
|
|
394
|
-
if model_provider == "mistralai":
|
|
395
|
-
_check_pkg("langchain_mistralai")
|
|
396
|
-
from langchain_mistralai import ChatMistralAI
|
|
397
|
-
|
|
398
|
-
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
|
399
|
-
if model_provider == "huggingface":
|
|
400
|
-
_check_pkg("langchain_huggingface")
|
|
401
|
-
from langchain_huggingface import ChatHuggingFace
|
|
402
|
-
|
|
403
|
-
return ChatHuggingFace(model_id=model, **kwargs)
|
|
404
|
-
if model_provider == "groq":
|
|
405
|
-
_check_pkg("langchain_groq")
|
|
406
|
-
from langchain_groq import ChatGroq
|
|
407
|
-
|
|
408
|
-
return ChatGroq(model=model, **kwargs)
|
|
409
|
-
if model_provider == "bedrock":
|
|
410
|
-
_check_pkg("langchain_aws")
|
|
411
|
-
from langchain_aws import ChatBedrock
|
|
412
|
-
|
|
413
|
-
return ChatBedrock(model_id=model, **kwargs)
|
|
414
|
-
if model_provider == "bedrock_converse":
|
|
415
|
-
_check_pkg("langchain_aws")
|
|
416
|
-
from langchain_aws import ChatBedrockConverse
|
|
417
|
-
|
|
418
|
-
return ChatBedrockConverse(model=model, **kwargs)
|
|
419
|
-
if model_provider == "google_anthropic_vertex":
|
|
420
|
-
_check_pkg("langchain_google_vertexai")
|
|
421
|
-
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
|
|
422
|
-
|
|
423
|
-
return ChatAnthropicVertex(model=model, **kwargs)
|
|
424
|
-
if model_provider == "deepseek":
|
|
425
|
-
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
|
|
426
|
-
from langchain_deepseek import ChatDeepSeek
|
|
427
|
-
|
|
428
|
-
return ChatDeepSeek(model=model, **kwargs)
|
|
429
|
-
if model_provider == "nvidia":
|
|
430
|
-
_check_pkg("langchain_nvidia_ai_endpoints")
|
|
431
|
-
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
|
432
|
-
|
|
433
|
-
return ChatNVIDIA(model=model, **kwargs)
|
|
434
|
-
if model_provider == "ibm":
|
|
435
|
-
_check_pkg("langchain_ibm")
|
|
436
|
-
from langchain_ibm import ChatWatsonx
|
|
437
|
-
|
|
438
|
-
return ChatWatsonx(model_id=model, **kwargs)
|
|
439
|
-
if model_provider == "xai":
|
|
440
|
-
_check_pkg("langchain_xai")
|
|
441
|
-
from langchain_xai import ChatXAI
|
|
442
|
-
|
|
443
|
-
return ChatXAI(model=model, **kwargs)
|
|
444
|
-
if model_provider == "perplexity":
|
|
445
|
-
_check_pkg("langchain_perplexity")
|
|
446
|
-
from langchain_perplexity import ChatPerplexity
|
|
447
|
-
|
|
448
|
-
return ChatPerplexity(model=model, **kwargs)
|
|
449
|
-
supported = ", ".join(_SUPPORTED_PROVIDERS)
|
|
450
|
-
msg = f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
|
|
451
|
-
raise ValueError(msg)
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
_SUPPORTED_PROVIDERS = {
|
|
455
|
-
"openai",
|
|
456
|
-
"anthropic",
|
|
457
|
-
"azure_openai",
|
|
458
|
-
"azure_ai",
|
|
459
|
-
"cohere",
|
|
460
|
-
"google_vertexai",
|
|
461
|
-
"google_genai",
|
|
462
|
-
"fireworks",
|
|
463
|
-
"ollama",
|
|
464
|
-
"together",
|
|
465
|
-
"mistralai",
|
|
466
|
-
"huggingface",
|
|
467
|
-
"groq",
|
|
468
|
-
"bedrock",
|
|
469
|
-
"bedrock_converse",
|
|
470
|
-
"google_anthropic_vertex",
|
|
471
|
-
"deepseek",
|
|
472
|
-
"ibm",
|
|
473
|
-
"xai",
|
|
474
|
-
"perplexity",
|
|
475
|
-
}
|
|
474
|
+
creator_func = _get_chat_model_creator(model_provider)
|
|
475
|
+
return creator_func(model=model, **kwargs)
|
|
476
476
|
|
|
477
477
|
|
|
478
478
|
def _attempt_infer_model_provider(model_name: str) -> str | None:
|
|
479
|
-
|
|
479
|
+
"""Attempt to infer model provider from model name.
|
|
480
|
+
|
|
481
|
+
Args:
|
|
482
|
+
model_name: The name of the model to infer provider for.
|
|
483
|
+
|
|
484
|
+
Returns:
|
|
485
|
+
The inferred provider name, or `None` if no provider could be inferred.
|
|
486
|
+
"""
|
|
487
|
+
model_lower = model_name.lower()
|
|
488
|
+
|
|
489
|
+
# OpenAI models (including newer models and aliases)
|
|
490
|
+
if any(
|
|
491
|
+
model_lower.startswith(pre)
|
|
492
|
+
for pre in (
|
|
493
|
+
"gpt-",
|
|
494
|
+
"o1",
|
|
495
|
+
"o3",
|
|
496
|
+
"chatgpt",
|
|
497
|
+
"text-davinci",
|
|
498
|
+
)
|
|
499
|
+
):
|
|
480
500
|
return "openai"
|
|
481
|
-
|
|
501
|
+
|
|
502
|
+
# Anthropic models
|
|
503
|
+
if model_lower.startswith("claude"):
|
|
482
504
|
return "anthropic"
|
|
483
|
-
|
|
505
|
+
|
|
506
|
+
# Cohere models
|
|
507
|
+
if model_lower.startswith("command"):
|
|
484
508
|
return "cohere"
|
|
485
|
-
|
|
509
|
+
|
|
510
|
+
# Fireworks models
|
|
511
|
+
if model_lower.startswith("accounts/fireworks"):
|
|
486
512
|
return "fireworks"
|
|
487
|
-
|
|
513
|
+
|
|
514
|
+
# Google models
|
|
515
|
+
if model_lower.startswith("gemini"):
|
|
488
516
|
return "google_vertexai"
|
|
489
|
-
|
|
517
|
+
|
|
518
|
+
# AWS Bedrock models
|
|
519
|
+
if model_lower.startswith(("amazon.", "anthropic.", "meta.")):
|
|
490
520
|
return "bedrock"
|
|
491
|
-
|
|
521
|
+
|
|
522
|
+
# Mistral models
|
|
523
|
+
if model_lower.startswith(("mistral", "mixtral")):
|
|
492
524
|
return "mistralai"
|
|
493
|
-
|
|
525
|
+
|
|
526
|
+
# DeepSeek models
|
|
527
|
+
if model_lower.startswith("deepseek"):
|
|
494
528
|
return "deepseek"
|
|
495
|
-
|
|
529
|
+
|
|
530
|
+
# xAI models
|
|
531
|
+
if model_lower.startswith("grok"):
|
|
496
532
|
return "xai"
|
|
497
|
-
|
|
533
|
+
|
|
534
|
+
# Perplexity models
|
|
535
|
+
if model_lower.startswith("sonar"):
|
|
498
536
|
return "perplexity"
|
|
537
|
+
|
|
538
|
+
# Upstage models
|
|
539
|
+
if model_lower.startswith("solar"):
|
|
540
|
+
return "upstage"
|
|
541
|
+
|
|
499
542
|
return None
|
|
500
543
|
|
|
501
544
|
|
|
502
545
|
def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
    """Parse model name and provider, inferring provider if necessary.

    Args:
        model: The model name, optionally prefixed with a provider as
            `'<provider>:<model>'` (e.g., `'openai:gpt-4o'`).
        model_provider: Explicit provider. When given, `model` is used verbatim and
            is not split on `':'`.

    Returns:
        A `(model, model_provider)` tuple; the provider is lowercased with `'-'`
        replaced by `'_'`.

    Raises:
        ValueError: If no provider is given, none is encoded as a supported prefix
            of `model`, and none can be inferred from the model name.
    """
    # Handle provider:model format. Normalize the candidate prefix exactly like an
    # explicit `model_provider` is normalized below, so that `'azure-openai:gpt-4o'`
    # or `'OpenAI:gpt-4o'` are recognized just like `model_provider='azure-openai'`.
    if not model_provider and ":" in model:
        prefix, _, remainder = model.partition(":")
        candidate = prefix.replace("-", "_").lower()
        if candidate in _SUPPORTED_PROVIDERS:
            model_provider = candidate
            model = remainder

    # Attempt to infer provider if not specified
    model_provider = model_provider or _attempt_infer_model_provider(model)

    if not model_provider:
        # Enhanced error message with suggestions
        supported_list = ", ".join(sorted(_SUPPORTED_PROVIDERS))
        msg = (
            f"Unable to infer model provider for {model=}. "
            f"Please specify 'model_provider' directly.\n\n"
            f"Supported providers: {supported_list}\n\n"
            f"For help with specific providers, see: "
            f"https://docs.langchain.com/oss/python/integrations/providers"
        )
        raise ValueError(msg)

    # Normalize provider name
    model_provider = model_provider.replace("-", "_").lower()
    return model, model_provider
|
|
514
574
|
|
|
515
575
|
|
|
516
|
-
def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None:
|
|
517
|
-
if not util.find_spec(pkg):
|
|
518
|
-
pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
|
|
519
|
-
msg = f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
|
|
520
|
-
raise ImportError(msg)
|
|
521
|
-
|
|
522
|
-
|
|
523
576
|
def _remove_prefix(s: str, prefix: str) -> str:
|
|
524
577
|
return s.removeprefix(prefix)
|
|
525
578
|
|
|
@@ -531,22 +584,24 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
531
584
|
def __init__(
    self,
    *,
    default_config: dict[str, Any] | None = None,
    configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = "any",
    config_prefix: str = "",
    queued_declarative_operations: Sequence[
        tuple[str, tuple[Any, ...], dict[str, Any]]
    ] = (),
) -> None:
    """Store the settings used to build the underlying chat model on demand.

    Args:
        default_config: Base keyword arguments for the model; runtime configurable
            values are merged on top of these when the model is built.
        configurable_fields: `"any"`, or the names of the fields that may be set at
            runtime (a tuple is copied into a list).
        config_prefix: Prefix for configurable keys; a trailing `"_"` separator is
            appended when the prefix is non-empty and does not already end with one.
        queued_declarative_operations: `(method_name, args, kwargs)` triples that
            are replayed, in order, on the model each time it is built.
    """
    self._default_config: dict[str, Any] = default_config or {}

    self._configurable_fields: Literal["any"] | list[str]
    if configurable_fields == "any":
        self._configurable_fields = "any"
    else:
        self._configurable_fields = list(configurable_fields)

    needs_separator = bool(config_prefix) and not config_prefix.endswith("_")
    self._config_prefix = f"{config_prefix}_" if needs_separator else config_prefix

    # Copy into a fresh list so later appends never touch the caller's sequence.
    self._queued_declarative_operations: list[
        tuple[str, tuple[Any, ...], dict[str, Any]]
    ] = list(queued_declarative_operations)
|
|
551
606
|
|
|
552
607
|
def __getattr__(self, name: str) -> Any:
|
|
@@ -579,14 +634,14 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
579
634
|
msg += "."
|
|
580
635
|
raise AttributeError(msg)
|
|
581
636
|
|
|
582
|
-
def _model(self, config: RunnableConfig | None = None) -> Runnable[Any, Any]:
    """Build the concrete chat model for `config` and replay queued operations.

    Args:
        config: Runnable config whose configurable values override the defaults.

    Returns:
        The chat model (or the runnable produced by the last queued operation).
    """
    merged_params = {**self._default_config, **self._model_params(config)}
    runnable = _init_chat_model_helper(**merged_params)
    # Re-apply declarative calls (e.g. `bind_tools`) in the order they were queued.
    for method_name, method_args, method_kwargs in self._queued_declarative_operations:
        bound_method = getattr(runnable, method_name)
        runnable = bound_method(*method_args, **method_kwargs)
    return runnable
|
|
588
643
|
|
|
589
|
-
def _model_params(self, config: RunnableConfig | None) -> dict:
|
|
644
|
+
def _model_params(self, config: RunnableConfig | None) -> dict[str, Any]:
|
|
590
645
|
config = ensure_config(config)
|
|
591
646
|
model_params = {
|
|
592
647
|
_remove_prefix(k, self._config_prefix): v
|
|
@@ -602,8 +657,9 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
602
657
|
config: RunnableConfig | None = None,
|
|
603
658
|
**kwargs: Any,
|
|
604
659
|
) -> _ConfigurableModel:
|
|
605
|
-
"""Bind config to a `Runnable`, returning a new `Runnable`."""
|
|
606
660
|
config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
|
|
661
|
+
# Ensure config is not None after creation
|
|
662
|
+
config = ensure_config(config)
|
|
607
663
|
model_params = self._model_params(config)
|
|
608
664
|
remaining_config = {k: v for k, v in config.items() if k != "configurable"}
|
|
609
665
|
remaining_config["configurable"] = {
|
|
@@ -630,10 +686,9 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
630
686
|
)
|
|
631
687
|
|
|
632
688
|
@property
|
|
689
|
+
@override
|
|
633
690
|
def InputType(self) -> TypeAlias:
|
|
634
691
|
"""Get the input type for this `Runnable`."""
|
|
635
|
-
from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue
|
|
636
|
-
|
|
637
692
|
# This is a version of LanguageModelInput which replaces the abstract
|
|
638
693
|
# base class BaseMessage with a union of its subclasses, which makes
|
|
639
694
|
# for a much better schema.
|
|
@@ -814,6 +869,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
814
869
|
yield x
|
|
815
870
|
|
|
816
871
|
@overload
|
|
872
|
+
@override
|
|
817
873
|
def astream_log(
|
|
818
874
|
self,
|
|
819
875
|
input: Any,
|
|
@@ -831,6 +887,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
831
887
|
) -> AsyncIterator[RunLogPatch]: ...
|
|
832
888
|
|
|
833
889
|
@overload
|
|
890
|
+
@override
|
|
834
891
|
def astream_log(
|
|
835
892
|
self,
|
|
836
893
|
input: Any,
|
|
@@ -910,7 +967,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
910
967
|
# Explicitly added to satisfy downstream linters.
|
|
911
968
|
def bind_tools(
    self,
    tools: Sequence[dict[str, Any] | type[BaseModel] | Callable[..., Any] | BaseTool],
    **kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
    """Bind `tools` to the model by delegating to this class's `__getattr__`.

    Args:
        tools: Tool definitions to bind.
        **kwargs: Extra keyword arguments forwarded to `bind_tools`.

    Returns:
        Whatever `__getattr__("bind_tools")` returns when called.
    """
    # Call `__getattr__` explicitly: normal attribute lookup would find this very
    # method and recurse forever.
    delegate = self.__getattr__("bind_tools")
    return delegate(tools, **kwargs)
|
|
@@ -918,7 +975,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
918
975
|
# Explicitly added to satisfy downstream linters.
|
|
919
976
|
def with_structured_output(
    self,
    schema: dict[str, Any] | type[BaseModel],
    **kwargs: Any,
) -> Runnable[LanguageModelInput, dict[str, Any] | BaseModel]:
    """Request structured output by delegating to this class's `__getattr__`.

    Args:
        schema: JSON-schema dict or pydantic model describing the output.
        **kwargs: Extra keyword arguments forwarded to `with_structured_output`.

    Returns:
        Whatever `__getattr__("with_structured_output")` returns when called.
    """
    # Call `__getattr__` explicitly: normal attribute lookup would find this very
    # method and recurse forever.
    delegate = self.__getattr__("with_structured_output")
    return delegate(schema, **kwargs)
|
langchain/embeddings/__init__.py
CHANGED
|
@@ -1,10 +1,5 @@
|
|
|
1
1
|
"""Embeddings models.
|
|
2
2
|
|
|
3
|
-
!!! warning "Reference docs"
|
|
4
|
-
This page contains **reference documentation** for Embeddings. See
|
|
5
|
-
[the docs](https://docs.langchain.com/oss/python/langchain/retrieval#embedding-models)
|
|
6
|
-
for conceptual guides, tutorials, and examples on using Embeddings.
|
|
7
|
-
|
|
8
3
|
!!! warning "Modules moved"
|
|
9
4
|
With the release of `langchain 1.0.0`, several embeddings modules were moved to
|
|
10
5
|
`langchain-classic`, such as `CacheBackedEmbeddings` and all community
|