langchain 1.0.4__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/__init__.py +1 -7
- langchain/agents/factory.py +100 -41
- langchain/agents/middleware/__init__.py +5 -7
- langchain/agents/middleware/_execution.py +21 -20
- langchain/agents/middleware/_redaction.py +27 -12
- langchain/agents/middleware/_retry.py +123 -0
- langchain/agents/middleware/context_editing.py +26 -22
- langchain/agents/middleware/file_search.py +18 -13
- langchain/agents/middleware/human_in_the_loop.py +60 -54
- langchain/agents/middleware/model_call_limit.py +63 -17
- langchain/agents/middleware/model_fallback.py +7 -9
- langchain/agents/middleware/model_retry.py +300 -0
- langchain/agents/middleware/pii.py +80 -27
- langchain/agents/middleware/shell_tool.py +230 -103
- langchain/agents/middleware/summarization.py +439 -90
- langchain/agents/middleware/todo.py +111 -27
- langchain/agents/middleware/tool_call_limit.py +105 -71
- langchain/agents/middleware/tool_emulator.py +42 -33
- langchain/agents/middleware/tool_retry.py +171 -159
- langchain/agents/middleware/tool_selection.py +37 -27
- langchain/agents/middleware/types.py +754 -392
- langchain/agents/structured_output.py +22 -12
- langchain/chat_models/__init__.py +1 -7
- langchain/chat_models/base.py +234 -185
- langchain/embeddings/__init__.py +0 -5
- langchain/embeddings/base.py +80 -66
- langchain/messages/__init__.py +0 -5
- langchain/tools/__init__.py +1 -7
- {langchain-1.0.4.dist-info → langchain-1.2.3.dist-info}/METADATA +3 -5
- langchain-1.2.3.dist-info/RECORD +36 -0
- {langchain-1.0.4.dist-info → langchain-1.2.3.dist-info}/WHEEL +1 -1
- langchain-1.0.4.dist-info/RECORD +0 -34
- {langchain-1.0.4.dist-info → langchain-1.2.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import json
|
|
5
6
|
import uuid
|
|
6
7
|
from dataclasses import dataclass, is_dataclass
|
|
7
8
|
from types import UnionType
|
|
@@ -128,7 +129,7 @@ class _SchemaSpec(Generic[SchemaT]):
|
|
|
128
129
|
json_schema: dict[str, Any]
|
|
129
130
|
"""JSON schema associated with the schema."""
|
|
130
131
|
|
|
131
|
-
strict: bool =
|
|
132
|
+
strict: bool | None = None
|
|
132
133
|
"""Whether to enforce strict validation of the schema."""
|
|
133
134
|
|
|
134
135
|
def __init__(
|
|
@@ -137,7 +138,7 @@ class _SchemaSpec(Generic[SchemaT]):
|
|
|
137
138
|
*,
|
|
138
139
|
name: str | None = None,
|
|
139
140
|
description: str | None = None,
|
|
140
|
-
strict: bool =
|
|
141
|
+
strict: bool | None = None,
|
|
141
142
|
) -> None:
|
|
142
143
|
"""Initialize SchemaSpec with schema and optional parameters."""
|
|
143
144
|
self.schema = schema
|
|
@@ -227,7 +228,7 @@ class ToolStrategy(Generic[SchemaT]):
|
|
|
227
228
|
|
|
228
229
|
def _iter_variants(schema: Any) -> Iterable[Any]:
|
|
229
230
|
"""Yield leaf variants from Union and JSON Schema oneOf."""
|
|
230
|
-
if get_origin(schema) in
|
|
231
|
+
if get_origin(schema) in {UnionType, Union}:
|
|
231
232
|
for arg in get_args(schema):
|
|
232
233
|
yield from _iter_variants(arg)
|
|
233
234
|
return
|
|
@@ -255,21 +256,32 @@ class ProviderStrategy(Generic[SchemaT]):
|
|
|
255
256
|
def __init__(
|
|
256
257
|
self,
|
|
257
258
|
schema: type[SchemaT],
|
|
259
|
+
*,
|
|
260
|
+
strict: bool | None = None,
|
|
258
261
|
) -> None:
|
|
259
|
-
"""Initialize ProviderStrategy with schema.
|
|
262
|
+
"""Initialize ProviderStrategy with schema.
|
|
263
|
+
|
|
264
|
+
Args:
|
|
265
|
+
schema: Schema to enforce via the provider's native structured output.
|
|
266
|
+
strict: Whether to request strict provider-side schema enforcement.
|
|
267
|
+
"""
|
|
260
268
|
self.schema = schema
|
|
261
|
-
self.schema_spec = _SchemaSpec(schema)
|
|
269
|
+
self.schema_spec = _SchemaSpec(schema, strict=strict)
|
|
262
270
|
|
|
263
271
|
def to_model_kwargs(self) -> dict[str, Any]:
|
|
264
272
|
"""Convert to kwargs to bind to a model to force structured output."""
|
|
265
273
|
# OpenAI:
|
|
266
274
|
# - see https://platform.openai.com/docs/guides/structured-outputs
|
|
267
|
-
|
|
275
|
+
json_schema: dict[str, Any] = {
|
|
276
|
+
"name": self.schema_spec.name,
|
|
277
|
+
"schema": self.schema_spec.json_schema,
|
|
278
|
+
}
|
|
279
|
+
if self.schema_spec.strict:
|
|
280
|
+
json_schema["strict"] = True
|
|
281
|
+
|
|
282
|
+
response_format: dict[str, Any] = {
|
|
268
283
|
"type": "json_schema",
|
|
269
|
-
"json_schema":
|
|
270
|
-
"name": self.schema_spec.name,
|
|
271
|
-
"schema": self.schema_spec.json_schema,
|
|
272
|
-
},
|
|
284
|
+
"json_schema": json_schema,
|
|
273
285
|
}
|
|
274
286
|
return {"response_format": response_format}
|
|
275
287
|
|
|
@@ -374,8 +386,6 @@ class ProviderStrategyBinding(Generic[SchemaT]):
|
|
|
374
386
|
# Extract text content from AIMessage and parse as JSON
|
|
375
387
|
raw_text = self._extract_text_content_from_message(response)
|
|
376
388
|
|
|
377
|
-
import json
|
|
378
|
-
|
|
379
389
|
try:
|
|
380
390
|
data = json.loads(raw_text)
|
|
381
391
|
except Exception as e:
|
|
@@ -1,10 +1,4 @@
|
|
|
1
|
-
"""Entrypoint to using [chat models](https://docs.langchain.com/oss/python/langchain/models) in LangChain.
|
|
2
|
-
|
|
3
|
-
!!! warning "Reference docs"
|
|
4
|
-
This page contains **reference documentation** for chat models. See
|
|
5
|
-
[the docs](https://docs.langchain.com/oss/python/langchain/models) for conceptual
|
|
6
|
-
guides, tutorials, and examples on using chat models.
|
|
7
|
-
""" # noqa: E501
|
|
1
|
+
"""Entrypoint to using [chat models](https://docs.langchain.com/oss/python/langchain/models) in LangChain.""" # noqa: E501
|
|
8
2
|
|
|
9
3
|
from langchain_core.language_models import BaseChatModel
|
|
10
4
|
|
langchain/chat_models/base.py
CHANGED
|
@@ -2,17 +2,27 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import functools
|
|
6
|
+
import importlib
|
|
5
7
|
import warnings
|
|
6
|
-
from
|
|
7
|
-
|
|
8
|
+
from typing import (
|
|
9
|
+
TYPE_CHECKING,
|
|
10
|
+
Any,
|
|
11
|
+
Literal,
|
|
12
|
+
TypeAlias,
|
|
13
|
+
cast,
|
|
14
|
+
overload,
|
|
15
|
+
)
|
|
8
16
|
|
|
9
17
|
from langchain_core.language_models import BaseChatModel, LanguageModelInput
|
|
10
18
|
from langchain_core.messages import AIMessage, AnyMessage
|
|
19
|
+
from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue
|
|
11
20
|
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
|
|
12
21
|
from typing_extensions import override
|
|
13
22
|
|
|
14
23
|
if TYPE_CHECKING:
|
|
15
24
|
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
|
|
25
|
+
from types import ModuleType
|
|
16
26
|
|
|
17
27
|
from langchain_core.runnables.schema import StreamEvent
|
|
18
28
|
from langchain_core.tools import BaseTool
|
|
@@ -20,6 +30,127 @@ if TYPE_CHECKING:
|
|
|
20
30
|
from pydantic import BaseModel
|
|
21
31
|
|
|
22
32
|
|
|
33
|
+
def _call(cls: type[BaseChatModel], **kwargs: Any) -> BaseChatModel:
|
|
34
|
+
# TODO: replace with operator.call when lower bounding to Python 3.11
|
|
35
|
+
return cls(**kwargs)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
_SUPPORTED_PROVIDERS: dict[str, tuple[str, str, Callable[..., BaseChatModel]]] = {
|
|
39
|
+
"anthropic": ("langchain_anthropic", "ChatAnthropic", _call),
|
|
40
|
+
"azure_ai": ("langchain_azure_ai.chat_models", "AzureAIChatCompletionsModel", _call),
|
|
41
|
+
"azure_openai": ("langchain_openai", "AzureChatOpenAI", _call),
|
|
42
|
+
"bedrock": ("langchain_aws", "ChatBedrock", _call),
|
|
43
|
+
"bedrock_converse": ("langchain_aws", "ChatBedrockConverse", _call),
|
|
44
|
+
"cohere": ("langchain_cohere", "ChatCohere", _call),
|
|
45
|
+
"deepseek": ("langchain_deepseek", "ChatDeepSeek", _call),
|
|
46
|
+
"fireworks": ("langchain_fireworks", "ChatFireworks", _call),
|
|
47
|
+
"google_anthropic_vertex": (
|
|
48
|
+
"langchain_google_vertexai.model_garden",
|
|
49
|
+
"ChatAnthropicVertex",
|
|
50
|
+
_call,
|
|
51
|
+
),
|
|
52
|
+
"google_genai": ("langchain_google_genai", "ChatGoogleGenerativeAI", _call),
|
|
53
|
+
"google_vertexai": ("langchain_google_vertexai", "ChatVertexAI", _call),
|
|
54
|
+
"groq": ("langchain_groq", "ChatGroq", _call),
|
|
55
|
+
"huggingface": (
|
|
56
|
+
"langchain_huggingface",
|
|
57
|
+
"ChatHuggingFace",
|
|
58
|
+
lambda cls, model, **kwargs: cls.from_model_id(model_id=model, **kwargs),
|
|
59
|
+
),
|
|
60
|
+
"ibm": (
|
|
61
|
+
"langchain_ibm",
|
|
62
|
+
"ChatWatsonx",
|
|
63
|
+
lambda cls, model, **kwargs: cls(model_id=model, **kwargs),
|
|
64
|
+
),
|
|
65
|
+
"mistralai": ("langchain_mistralai", "ChatMistralAI", _call),
|
|
66
|
+
"nvidia": ("langchain_nvidia_ai_endpoints", "ChatNVIDIA", _call),
|
|
67
|
+
"ollama": ("langchain_ollama", "ChatOllama", _call),
|
|
68
|
+
"openai": ("langchain_openai", "ChatOpenAI", _call),
|
|
69
|
+
"perplexity": ("langchain_perplexity", "ChatPerplexity", _call),
|
|
70
|
+
"together": ("langchain_together", "ChatTogether", _call),
|
|
71
|
+
"upstage": ("langchain_upstage", "ChatUpstage", _call),
|
|
72
|
+
"xai": ("langchain_xai", "ChatXAI", _call),
|
|
73
|
+
}
|
|
74
|
+
"""Registry mapping provider names to their import configuration.
|
|
75
|
+
|
|
76
|
+
Each entry maps a provider key to a tuple of:
|
|
77
|
+
|
|
78
|
+
- `module_path`: The Python module path containing the chat model class.
|
|
79
|
+
|
|
80
|
+
This may be a submodule (e.g., `'langchain_azure_ai.chat_models'`) if the class is
|
|
81
|
+
not exported from the package root.
|
|
82
|
+
- `class_name`: The name of the chat model class to import.
|
|
83
|
+
- `creator_func`: A callable that instantiates the class with provided kwargs.
|
|
84
|
+
"""
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _import_module(module: str) -> ModuleType:
|
|
88
|
+
"""Import a module by name.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
module: The fully qualified module name to import (e.g., `'langchain_openai'`).
|
|
92
|
+
|
|
93
|
+
Returns:
|
|
94
|
+
The imported module.
|
|
95
|
+
|
|
96
|
+
Raises:
|
|
97
|
+
ImportError: If the module cannot be imported, with a message suggesting
|
|
98
|
+
the pip package to install.
|
|
99
|
+
"""
|
|
100
|
+
try:
|
|
101
|
+
return importlib.import_module(module)
|
|
102
|
+
except ImportError as e:
|
|
103
|
+
# Extract package name from module path (e.g., "langchain_azure_ai.chat_models"
|
|
104
|
+
# becomes "langchain-azure-ai")
|
|
105
|
+
pkg = module.split(".", maxsplit=1)[0].replace("_", "-")
|
|
106
|
+
msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
|
|
107
|
+
raise ImportError(msg) from e
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
|
|
111
|
+
def _get_chat_model_creator(
|
|
112
|
+
provider: str,
|
|
113
|
+
) -> Callable[..., BaseChatModel]:
|
|
114
|
+
"""Return a factory function that creates a chat model for the given provider.
|
|
115
|
+
|
|
116
|
+
This function is cached to avoid repeated module imports.
|
|
117
|
+
|
|
118
|
+
Args:
|
|
119
|
+
provider: The name of the model provider (e.g., `'openai'`, `'anthropic'`).
|
|
120
|
+
|
|
121
|
+
Must be a key in `_SUPPORTED_PROVIDERS`.
|
|
122
|
+
|
|
123
|
+
Returns:
|
|
124
|
+
A callable that accepts model kwargs and returns a `BaseChatModel` instance for
|
|
125
|
+
the specified provider.
|
|
126
|
+
|
|
127
|
+
Raises:
|
|
128
|
+
ValueError: If the provider is not in `_SUPPORTED_PROVIDERS`.
|
|
129
|
+
ImportError: If the provider's integration package is not installed.
|
|
130
|
+
"""
|
|
131
|
+
if provider not in _SUPPORTED_PROVIDERS:
|
|
132
|
+
supported = ", ".join(_SUPPORTED_PROVIDERS.keys())
|
|
133
|
+
msg = f"Unsupported {provider=}.\n\nSupported model providers are: {supported}"
|
|
134
|
+
raise ValueError(msg)
|
|
135
|
+
|
|
136
|
+
pkg, class_name, creator_func = _SUPPORTED_PROVIDERS[provider]
|
|
137
|
+
try:
|
|
138
|
+
module = _import_module(pkg)
|
|
139
|
+
except ImportError as e:
|
|
140
|
+
if provider != "ollama":
|
|
141
|
+
raise
|
|
142
|
+
# For backwards compatibility
|
|
143
|
+
try:
|
|
144
|
+
module = _import_module("langchain_community.chat_models")
|
|
145
|
+
except ImportError:
|
|
146
|
+
# If both langchain-ollama and langchain-community aren't available,
|
|
147
|
+
# raise an error related to langchain-ollama
|
|
148
|
+
raise e from None
|
|
149
|
+
|
|
150
|
+
cls = getattr(module, class_name)
|
|
151
|
+
return functools.partial(creator_func, cls=cls)
|
|
152
|
+
|
|
153
|
+
|
|
23
154
|
@overload
|
|
24
155
|
def init_chat_model(
|
|
25
156
|
model: str,
|
|
@@ -73,7 +204,7 @@ def init_chat_model(
|
|
|
73
204
|
runtime via `config`. Makes it easy to switch between models/providers without
|
|
74
205
|
changing your code
|
|
75
206
|
|
|
76
|
-
!!! note
|
|
207
|
+
!!! note "Installation requirements"
|
|
77
208
|
Requires the integration package for the chosen model provider to be installed.
|
|
78
209
|
|
|
79
210
|
See the `model_provider` parameter below for specific package names
|
|
@@ -83,10 +214,23 @@ def init_chat_model(
|
|
|
83
214
|
for supported model parameters to use as `**kwargs`.
|
|
84
215
|
|
|
85
216
|
Args:
|
|
86
|
-
model: The name
|
|
217
|
+
model: The model name, optionally prefixed with provider (e.g., `'openai:gpt-4o'`).
|
|
218
|
+
|
|
219
|
+
Will attempt to infer `model_provider` from model if not specified.
|
|
220
|
+
|
|
221
|
+
The following providers will be inferred based on these model prefixes:
|
|
87
222
|
|
|
88
|
-
|
|
89
|
-
`
|
|
223
|
+
- `gpt-...` | `o1...` | `o3...` -> `openai`
|
|
224
|
+
- `claude...` -> `anthropic`
|
|
225
|
+
- `amazon...` -> `bedrock`
|
|
226
|
+
- `gemini...` -> `google_vertexai`
|
|
227
|
+
- `command...` -> `cohere`
|
|
228
|
+
- `accounts/fireworks...` -> `fireworks`
|
|
229
|
+
- `mistral...` -> `mistralai`
|
|
230
|
+
- `deepseek...` -> `deepseek`
|
|
231
|
+
- `grok...` -> `xai`
|
|
232
|
+
- `sonar...` -> `perplexity`
|
|
233
|
+
- `solar...` -> `upstage`
|
|
90
234
|
model_provider: The model provider if not specified as part of the model arg
|
|
91
235
|
(see above).
|
|
92
236
|
|
|
@@ -110,24 +254,12 @@ def init_chat_model(
|
|
|
110
254
|
- `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama)
|
|
111
255
|
- `google_anthropic_vertex` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
|
|
112
256
|
- `deepseek` -> [`langchain-deepseek`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
|
|
113
|
-
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/
|
|
257
|
+
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/ibm)
|
|
114
258
|
- `nvidia` -> [`langchain-nvidia-ai-endpoints`](https://docs.langchain.com/oss/python/integrations/providers/nvidia)
|
|
115
259
|
- `xai` -> [`langchain-xai`](https://docs.langchain.com/oss/python/integrations/providers/xai)
|
|
116
260
|
- `perplexity` -> [`langchain-perplexity`](https://docs.langchain.com/oss/python/integrations/providers/perplexity)
|
|
261
|
+
- `upstage` -> [`langchain-upstage`](https://docs.langchain.com/oss/python/integrations/providers/upstage)
|
|
117
262
|
|
|
118
|
-
Will attempt to infer `model_provider` from model if not specified. The
|
|
119
|
-
following providers will be inferred based on these model prefixes:
|
|
120
|
-
|
|
121
|
-
- `gpt-...` | `o1...` | `o3...` -> `openai`
|
|
122
|
-
- `claude...` -> `anthropic`
|
|
123
|
-
- `amazon...` -> `bedrock`
|
|
124
|
-
- `gemini...` -> `google_vertexai`
|
|
125
|
-
- `command...` -> `cohere`
|
|
126
|
-
- `accounts/fireworks...` -> `fireworks`
|
|
127
|
-
- `mistral...` -> `mistralai`
|
|
128
|
-
- `deepseek...` -> `deepseek`
|
|
129
|
-
- `grok...` -> `xai`
|
|
130
|
-
- `sonar...` -> `perplexity`
|
|
131
263
|
configurable_fields: Which model parameters are configurable at runtime:
|
|
132
264
|
|
|
133
265
|
- `None`: No configurable fields (i.e., a fixed model).
|
|
@@ -142,6 +274,7 @@ def init_chat_model(
|
|
|
142
274
|
If `model` is not specified, then defaults to `("model", "model_provider")`.
|
|
143
275
|
|
|
144
276
|
!!! warning "Security note"
|
|
277
|
+
|
|
145
278
|
Setting `configurable_fields="any"` means fields like `api_key`,
|
|
146
279
|
`base_url`, etc., can be altered at runtime, potentially redirecting
|
|
147
280
|
model requests to a different service/user.
|
|
@@ -331,195 +464,108 @@ def _init_chat_model_helper(
|
|
|
331
464
|
**kwargs: Any,
|
|
332
465
|
) -> BaseChatModel:
|
|
333
466
|
model, model_provider = _parse_model(model, model_provider)
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
from langchain_openai import ChatOpenAI
|
|
337
|
-
|
|
338
|
-
return ChatOpenAI(model=model, **kwargs)
|
|
339
|
-
if model_provider == "anthropic":
|
|
340
|
-
_check_pkg("langchain_anthropic")
|
|
341
|
-
from langchain_anthropic import ChatAnthropic
|
|
342
|
-
|
|
343
|
-
return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
|
344
|
-
if model_provider == "azure_openai":
|
|
345
|
-
_check_pkg("langchain_openai")
|
|
346
|
-
from langchain_openai import AzureChatOpenAI
|
|
347
|
-
|
|
348
|
-
return AzureChatOpenAI(model=model, **kwargs)
|
|
349
|
-
if model_provider == "azure_ai":
|
|
350
|
-
_check_pkg("langchain_azure_ai")
|
|
351
|
-
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
|
|
352
|
-
|
|
353
|
-
return AzureAIChatCompletionsModel(model=model, **kwargs)
|
|
354
|
-
if model_provider == "cohere":
|
|
355
|
-
_check_pkg("langchain_cohere")
|
|
356
|
-
from langchain_cohere import ChatCohere
|
|
357
|
-
|
|
358
|
-
return ChatCohere(model=model, **kwargs)
|
|
359
|
-
if model_provider == "google_vertexai":
|
|
360
|
-
_check_pkg("langchain_google_vertexai")
|
|
361
|
-
from langchain_google_vertexai import ChatVertexAI
|
|
362
|
-
|
|
363
|
-
return ChatVertexAI(model=model, **kwargs)
|
|
364
|
-
if model_provider == "google_genai":
|
|
365
|
-
_check_pkg("langchain_google_genai")
|
|
366
|
-
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
367
|
-
|
|
368
|
-
return ChatGoogleGenerativeAI(model=model, **kwargs)
|
|
369
|
-
if model_provider == "fireworks":
|
|
370
|
-
_check_pkg("langchain_fireworks")
|
|
371
|
-
from langchain_fireworks import ChatFireworks
|
|
372
|
-
|
|
373
|
-
return ChatFireworks(model=model, **kwargs)
|
|
374
|
-
if model_provider == "ollama":
|
|
375
|
-
try:
|
|
376
|
-
_check_pkg("langchain_ollama")
|
|
377
|
-
from langchain_ollama import ChatOllama
|
|
378
|
-
except ImportError:
|
|
379
|
-
# For backwards compatibility
|
|
380
|
-
try:
|
|
381
|
-
_check_pkg("langchain_community")
|
|
382
|
-
from langchain_community.chat_models import ChatOllama
|
|
383
|
-
except ImportError:
|
|
384
|
-
# If both langchain-ollama and langchain-community aren't available,
|
|
385
|
-
# raise an error related to langchain-ollama
|
|
386
|
-
_check_pkg("langchain_ollama")
|
|
387
|
-
|
|
388
|
-
return ChatOllama(model=model, **kwargs)
|
|
389
|
-
if model_provider == "together":
|
|
390
|
-
_check_pkg("langchain_together")
|
|
391
|
-
from langchain_together import ChatTogether
|
|
392
|
-
|
|
393
|
-
return ChatTogether(model=model, **kwargs)
|
|
394
|
-
if model_provider == "mistralai":
|
|
395
|
-
_check_pkg("langchain_mistralai")
|
|
396
|
-
from langchain_mistralai import ChatMistralAI
|
|
397
|
-
|
|
398
|
-
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
|
399
|
-
if model_provider == "huggingface":
|
|
400
|
-
_check_pkg("langchain_huggingface")
|
|
401
|
-
from langchain_huggingface import ChatHuggingFace
|
|
402
|
-
|
|
403
|
-
return ChatHuggingFace(model_id=model, **kwargs)
|
|
404
|
-
if model_provider == "groq":
|
|
405
|
-
_check_pkg("langchain_groq")
|
|
406
|
-
from langchain_groq import ChatGroq
|
|
407
|
-
|
|
408
|
-
return ChatGroq(model=model, **kwargs)
|
|
409
|
-
if model_provider == "bedrock":
|
|
410
|
-
_check_pkg("langchain_aws")
|
|
411
|
-
from langchain_aws import ChatBedrock
|
|
412
|
-
|
|
413
|
-
return ChatBedrock(model_id=model, **kwargs)
|
|
414
|
-
if model_provider == "bedrock_converse":
|
|
415
|
-
_check_pkg("langchain_aws")
|
|
416
|
-
from langchain_aws import ChatBedrockConverse
|
|
417
|
-
|
|
418
|
-
return ChatBedrockConverse(model=model, **kwargs)
|
|
419
|
-
if model_provider == "google_anthropic_vertex":
|
|
420
|
-
_check_pkg("langchain_google_vertexai")
|
|
421
|
-
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
|
|
422
|
-
|
|
423
|
-
return ChatAnthropicVertex(model=model, **kwargs)
|
|
424
|
-
if model_provider == "deepseek":
|
|
425
|
-
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
|
|
426
|
-
from langchain_deepseek import ChatDeepSeek
|
|
427
|
-
|
|
428
|
-
return ChatDeepSeek(model=model, **kwargs)
|
|
429
|
-
if model_provider == "nvidia":
|
|
430
|
-
_check_pkg("langchain_nvidia_ai_endpoints")
|
|
431
|
-
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
|
432
|
-
|
|
433
|
-
return ChatNVIDIA(model=model, **kwargs)
|
|
434
|
-
if model_provider == "ibm":
|
|
435
|
-
_check_pkg("langchain_ibm")
|
|
436
|
-
from langchain_ibm import ChatWatsonx
|
|
437
|
-
|
|
438
|
-
return ChatWatsonx(model_id=model, **kwargs)
|
|
439
|
-
if model_provider == "xai":
|
|
440
|
-
_check_pkg("langchain_xai")
|
|
441
|
-
from langchain_xai import ChatXAI
|
|
442
|
-
|
|
443
|
-
return ChatXAI(model=model, **kwargs)
|
|
444
|
-
if model_provider == "perplexity":
|
|
445
|
-
_check_pkg("langchain_perplexity")
|
|
446
|
-
from langchain_perplexity import ChatPerplexity
|
|
447
|
-
|
|
448
|
-
return ChatPerplexity(model=model, **kwargs)
|
|
449
|
-
supported = ", ".join(_SUPPORTED_PROVIDERS)
|
|
450
|
-
msg = f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
|
|
451
|
-
raise ValueError(msg)
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
_SUPPORTED_PROVIDERS = {
|
|
455
|
-
"openai",
|
|
456
|
-
"anthropic",
|
|
457
|
-
"azure_openai",
|
|
458
|
-
"azure_ai",
|
|
459
|
-
"cohere",
|
|
460
|
-
"google_vertexai",
|
|
461
|
-
"google_genai",
|
|
462
|
-
"fireworks",
|
|
463
|
-
"ollama",
|
|
464
|
-
"together",
|
|
465
|
-
"mistralai",
|
|
466
|
-
"huggingface",
|
|
467
|
-
"groq",
|
|
468
|
-
"bedrock",
|
|
469
|
-
"bedrock_converse",
|
|
470
|
-
"google_anthropic_vertex",
|
|
471
|
-
"deepseek",
|
|
472
|
-
"ibm",
|
|
473
|
-
"xai",
|
|
474
|
-
"perplexity",
|
|
475
|
-
}
|
|
467
|
+
creator_func = _get_chat_model_creator(model_provider)
|
|
468
|
+
return creator_func(model=model, **kwargs)
|
|
476
469
|
|
|
477
470
|
|
|
478
471
|
def _attempt_infer_model_provider(model_name: str) -> str | None:
|
|
479
|
-
|
|
472
|
+
"""Attempt to infer model provider from model name.
|
|
473
|
+
|
|
474
|
+
Args:
|
|
475
|
+
model_name: The name of the model to infer provider for.
|
|
476
|
+
|
|
477
|
+
Returns:
|
|
478
|
+
The inferred provider name, or `None` if no provider could be inferred.
|
|
479
|
+
"""
|
|
480
|
+
model_lower = model_name.lower()
|
|
481
|
+
|
|
482
|
+
# OpenAI models (including newer models and aliases)
|
|
483
|
+
if any(
|
|
484
|
+
model_lower.startswith(pre)
|
|
485
|
+
for pre in (
|
|
486
|
+
"gpt-",
|
|
487
|
+
"o1",
|
|
488
|
+
"o3",
|
|
489
|
+
"chatgpt",
|
|
490
|
+
"text-davinci",
|
|
491
|
+
)
|
|
492
|
+
):
|
|
480
493
|
return "openai"
|
|
481
|
-
|
|
494
|
+
|
|
495
|
+
# Anthropic models
|
|
496
|
+
if model_lower.startswith("claude"):
|
|
482
497
|
return "anthropic"
|
|
483
|
-
|
|
498
|
+
|
|
499
|
+
# Cohere models
|
|
500
|
+
if model_lower.startswith("command"):
|
|
484
501
|
return "cohere"
|
|
502
|
+
|
|
503
|
+
# Fireworks models
|
|
485
504
|
if model_name.startswith("accounts/fireworks"):
|
|
486
505
|
return "fireworks"
|
|
487
|
-
|
|
506
|
+
|
|
507
|
+
# Google models
|
|
508
|
+
if model_lower.startswith("gemini"):
|
|
488
509
|
return "google_vertexai"
|
|
489
|
-
|
|
510
|
+
|
|
511
|
+
# AWS Bedrock models
|
|
512
|
+
if model_name.startswith("amazon.") or model_lower.startswith(("anthropic.", "meta.")):
|
|
490
513
|
return "bedrock"
|
|
491
|
-
|
|
514
|
+
|
|
515
|
+
# Mistral models
|
|
516
|
+
if model_lower.startswith(("mistral", "mixtral")):
|
|
492
517
|
return "mistralai"
|
|
493
|
-
|
|
518
|
+
|
|
519
|
+
# DeepSeek models
|
|
520
|
+
if model_lower.startswith("deepseek"):
|
|
494
521
|
return "deepseek"
|
|
495
|
-
|
|
522
|
+
|
|
523
|
+
# xAI models
|
|
524
|
+
if model_lower.startswith("grok"):
|
|
496
525
|
return "xai"
|
|
497
|
-
|
|
526
|
+
|
|
527
|
+
# Perplexity models
|
|
528
|
+
if model_lower.startswith("sonar"):
|
|
498
529
|
return "perplexity"
|
|
530
|
+
|
|
531
|
+
# Upstage models
|
|
532
|
+
if model_lower.startswith("solar"):
|
|
533
|
+
return "upstage"
|
|
534
|
+
|
|
499
535
|
return None
|
|
500
536
|
|
|
501
537
|
|
|
502
538
|
def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
|
|
503
|
-
|
|
504
|
-
|
|
539
|
+
"""Parse model name and provider, inferring provider if necessary."""
|
|
540
|
+
# Handle provider:model format
|
|
541
|
+
if (
|
|
542
|
+
not model_provider
|
|
543
|
+
and ":" in model
|
|
544
|
+
and model.split(":", maxsplit=1)[0] in _SUPPORTED_PROVIDERS
|
|
545
|
+
):
|
|
546
|
+
model_provider = model.split(":", maxsplit=1)[0]
|
|
505
547
|
model = ":".join(model.split(":")[1:])
|
|
548
|
+
|
|
549
|
+
# Attempt to infer provider if not specified
|
|
506
550
|
model_provider = model_provider or _attempt_infer_model_provider(model)
|
|
551
|
+
|
|
507
552
|
if not model_provider:
|
|
553
|
+
# Enhanced error message with suggestions
|
|
554
|
+
supported_list = ", ".join(sorted(_SUPPORTED_PROVIDERS))
|
|
508
555
|
msg = (
|
|
509
|
-
f"Unable to infer model provider for {model=}
|
|
556
|
+
f"Unable to infer model provider for {model=}. "
|
|
557
|
+
f"Please specify 'model_provider' directly.\n\n"
|
|
558
|
+
f"Supported providers: {supported_list}\n\n"
|
|
559
|
+
f"For help with specific providers, see: "
|
|
560
|
+
f"https://docs.langchain.com/oss/python/integrations/providers"
|
|
510
561
|
)
|
|
511
562
|
raise ValueError(msg)
|
|
563
|
+
|
|
564
|
+
# Normalize provider name
|
|
512
565
|
model_provider = model_provider.replace("-", "_").lower()
|
|
513
566
|
return model, model_provider
|
|
514
567
|
|
|
515
568
|
|
|
516
|
-
def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None:
|
|
517
|
-
if not util.find_spec(pkg):
|
|
518
|
-
pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
|
|
519
|
-
msg = f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
|
|
520
|
-
raise ImportError(msg)
|
|
521
|
-
|
|
522
|
-
|
|
523
569
|
def _remove_prefix(s: str, prefix: str) -> str:
|
|
524
570
|
return s.removeprefix(prefix)
|
|
525
571
|
|
|
@@ -538,7 +584,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
538
584
|
) -> None:
|
|
539
585
|
self._default_config: dict = default_config or {}
|
|
540
586
|
self._configurable_fields: Literal["any"] | list[str] = (
|
|
541
|
-
|
|
587
|
+
"any" if configurable_fields == "any" else list(configurable_fields)
|
|
542
588
|
)
|
|
543
589
|
self._config_prefix = (
|
|
544
590
|
config_prefix + "_"
|
|
@@ -602,8 +648,10 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
602
648
|
config: RunnableConfig | None = None,
|
|
603
649
|
**kwargs: Any,
|
|
604
650
|
) -> _ConfigurableModel:
|
|
605
|
-
"""Bind config to a Runnable
|
|
651
|
+
"""Bind config to a `Runnable`, returning a new `Runnable`."""
|
|
606
652
|
config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
|
|
653
|
+
# Ensure config is not None after creation
|
|
654
|
+
config = ensure_config(config)
|
|
607
655
|
model_params = self._model_params(config)
|
|
608
656
|
remaining_config = {k: v for k, v in config.items() if k != "configurable"}
|
|
609
657
|
remaining_config["configurable"] = {
|
|
@@ -630,10 +678,9 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
630
678
|
)
|
|
631
679
|
|
|
632
680
|
@property
|
|
681
|
+
@override
|
|
633
682
|
def InputType(self) -> TypeAlias:
|
|
634
683
|
"""Get the input type for this `Runnable`."""
|
|
635
|
-
from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue
|
|
636
|
-
|
|
637
684
|
# This is a version of LanguageModelInput which replaces the abstract
|
|
638
685
|
# base class BaseMessage with a union of its subclasses, which makes
|
|
639
686
|
# for a much better schema.
|
|
@@ -814,6 +861,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
814
861
|
yield x
|
|
815
862
|
|
|
816
863
|
@overload
|
|
864
|
+
@override
|
|
817
865
|
def astream_log(
|
|
818
866
|
self,
|
|
819
867
|
input: Any,
|
|
@@ -831,6 +879,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
831
879
|
) -> AsyncIterator[RunLogPatch]: ...
|
|
832
880
|
|
|
833
881
|
@overload
|
|
882
|
+
@override
|
|
834
883
|
def astream_log(
|
|
835
884
|
self,
|
|
836
885
|
input: Any,
|
langchain/embeddings/__init__.py
CHANGED
|
@@ -1,10 +1,5 @@
|
|
|
1
1
|
"""Embeddings models.
|
|
2
2
|
|
|
3
|
-
!!! warning "Reference docs"
|
|
4
|
-
This page contains **reference documentation** for Embeddings. See
|
|
5
|
-
[the docs](https://docs.langchain.com/oss/python/langchain/retrieval#embedding-models)
|
|
6
|
-
for conceptual guides, tutorials, and examples on using Embeddings.
|
|
7
|
-
|
|
8
3
|
!!! warning "Modules moved"
|
|
9
4
|
With the release of `langchain 1.0.0`, several embeddings modules were moved to
|
|
10
5
|
`langchain-classic`, such as `CacheBackedEmbeddings` and all community
|