langchain 0.2.16__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain might be problematic. Click here for more details.
- langchain/agents/agent.py +23 -19
- langchain/agents/agent_toolkits/vectorstore/toolkit.py +10 -7
- langchain/agents/chat/base.py +1 -1
- langchain/agents/conversational/base.py +1 -1
- langchain/agents/conversational_chat/base.py +1 -1
- langchain/agents/mrkl/base.py +1 -1
- langchain/agents/openai_assistant/base.py +8 -7
- langchain/agents/openai_functions_agent/base.py +6 -5
- langchain/agents/openai_functions_multi_agent/base.py +6 -5
- langchain/agents/react/base.py +1 -1
- langchain/agents/self_ask_with_search/base.py +1 -1
- langchain/agents/structured_chat/base.py +1 -1
- langchain/agents/structured_chat/output_parser.py +1 -1
- langchain/chains/api/base.py +14 -12
- langchain/chains/base.py +17 -9
- langchain/chains/combine_documents/base.py +1 -1
- langchain/chains/combine_documents/map_reduce.py +14 -10
- langchain/chains/combine_documents/map_rerank.py +17 -14
- langchain/chains/combine_documents/reduce.py +5 -3
- langchain/chains/combine_documents/refine.py +11 -8
- langchain/chains/combine_documents/stuff.py +8 -6
- langchain/chains/constitutional_ai/models.py +1 -1
- langchain/chains/conversation/base.py +13 -11
- langchain/chains/conversational_retrieval/base.py +10 -8
- langchain/chains/elasticsearch_database/base.py +11 -9
- langchain/chains/flare/base.py +1 -1
- langchain/chains/hyde/base.py +6 -4
- langchain/chains/llm.py +7 -7
- langchain/chains/llm_checker/base.py +8 -6
- langchain/chains/llm_math/base.py +8 -6
- langchain/chains/llm_summarization_checker/base.py +8 -6
- langchain/chains/mapreduce.py +5 -3
- langchain/chains/moderation.py +6 -5
- langchain/chains/natbot/base.py +8 -6
- langchain/chains/openai_functions/base.py +3 -3
- langchain/chains/openai_functions/citation_fuzzy_match.py +1 -1
- langchain/chains/openai_functions/extraction.py +8 -4
- langchain/chains/openai_functions/qa_with_structure.py +5 -2
- langchain/chains/openai_functions/tagging.py +5 -2
- langchain/chains/openai_tools/extraction.py +2 -2
- langchain/chains/prompt_selector.py +1 -1
- langchain/chains/qa_generation/base.py +1 -1
- langchain/chains/qa_with_sources/base.py +8 -6
- langchain/chains/qa_with_sources/retrieval.py +1 -1
- langchain/chains/qa_with_sources/vector_db.py +4 -3
- langchain/chains/query_constructor/schema.py +5 -4
- langchain/chains/retrieval_qa/base.py +12 -9
- langchain/chains/router/base.py +5 -3
- langchain/chains/router/embedding_router.py +5 -3
- langchain/chains/router/llm_router.py +6 -5
- langchain/chains/sequential.py +17 -13
- langchain/chains/structured_output/base.py +8 -8
- langchain/chains/transform.py +1 -1
- langchain/chat_models/base.py +2 -2
- langchain/evaluation/agents/trajectory_eval_chain.py +4 -3
- langchain/evaluation/comparison/eval_chain.py +4 -3
- langchain/evaluation/criteria/eval_chain.py +4 -3
- langchain/evaluation/embedding_distance/base.py +4 -3
- langchain/evaluation/qa/eval_chain.py +7 -4
- langchain/evaluation/qa/generate_chain.py +1 -1
- langchain/evaluation/scoring/eval_chain.py +4 -3
- langchain/evaluation/string_distance/base.py +1 -1
- langchain/indexes/vectorstore.py +9 -7
- langchain/memory/chat_memory.py +1 -1
- langchain/memory/combined.py +5 -3
- langchain/memory/entity.py +4 -3
- langchain/memory/summary.py +1 -1
- langchain/memory/vectorstore.py +1 -1
- langchain/memory/vectorstore_token_buffer_memory.py +1 -1
- langchain/output_parsers/fix.py +3 -2
- langchain/output_parsers/pandas_dataframe.py +3 -2
- langchain/output_parsers/retry.py +4 -3
- langchain/output_parsers/structured.py +1 -1
- langchain/output_parsers/yaml.py +5 -2
- langchain/pydantic_v1/__init__.py +20 -0
- langchain/pydantic_v1/dataclasses.py +20 -0
- langchain/pydantic_v1/main.py +20 -0
- langchain/retrievers/contextual_compression.py +4 -2
- langchain/retrievers/document_compressors/base.py +4 -2
- langchain/retrievers/document_compressors/chain_extract.py +4 -2
- langchain/retrievers/document_compressors/chain_filter.py +4 -2
- langchain/retrievers/document_compressors/cohere_rerank.py +8 -6
- langchain/retrievers/document_compressors/cross_encoder_rerank.py +5 -3
- langchain/retrievers/document_compressors/embeddings_filter.py +5 -4
- langchain/retrievers/document_compressors/listwise_rerank.py +4 -3
- langchain/retrievers/ensemble.py +18 -14
- langchain/retrievers/multi_vector.py +5 -4
- langchain/retrievers/self_query/base.py +8 -6
- langchain/retrievers/time_weighted_retriever.py +4 -3
- langchain/smith/evaluation/config.py +7 -5
- {langchain-0.2.16.dist-info → langchain-0.3.0.dist-info}/METADATA +5 -5
- {langchain-0.2.16.dist-info → langchain-0.3.0.dist-info}/RECORD +95 -95
- {langchain-0.2.16.dist-info → langchain-0.3.0.dist-info}/LICENSE +0 -0
- {langchain-0.2.16.dist-info → langchain-0.3.0.dist-info}/WHEEL +0 -0
- {langchain-0.2.16.dist-info → langchain-0.3.0.dist-info}/entry_points.txt +0 -0
langchain/agents/agent.py
CHANGED
|
@@ -40,11 +40,12 @@ from langchain_core.output_parsers import BaseOutputParser
|
|
|
40
40
|
from langchain_core.prompts import BasePromptTemplate
|
|
41
41
|
from langchain_core.prompts.few_shot import FewShotPromptTemplate
|
|
42
42
|
from langchain_core.prompts.prompt import PromptTemplate
|
|
43
|
-
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
|
44
43
|
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
|
|
45
44
|
from langchain_core.runnables.utils import AddableDict
|
|
46
45
|
from langchain_core.tools import BaseTool
|
|
47
46
|
from langchain_core.utils.input import get_color_mapping
|
|
47
|
+
from pydantic import BaseModel, ConfigDict, model_validator
|
|
48
|
+
from typing_extensions import Self
|
|
48
49
|
|
|
49
50
|
from langchain.agents.agent_iterator import AgentExecutorIterator
|
|
50
51
|
from langchain.agents.agent_types import AgentType
|
|
@@ -175,7 +176,7 @@ class BaseSingleActionAgent(BaseModel):
|
|
|
175
176
|
Returns:
|
|
176
177
|
Dict: Dictionary representation of agent.
|
|
177
178
|
"""
|
|
178
|
-
_dict = super().
|
|
179
|
+
_dict = super().model_dump()
|
|
179
180
|
try:
|
|
180
181
|
_type = self._agent_type
|
|
181
182
|
except NotImplementedError:
|
|
@@ -323,7 +324,7 @@ class BaseMultiActionAgent(BaseModel):
|
|
|
323
324
|
|
|
324
325
|
def dict(self, **kwargs: Any) -> Dict:
|
|
325
326
|
"""Return dictionary representation of agent."""
|
|
326
|
-
_dict = super().
|
|
327
|
+
_dict = super().model_dump()
|
|
327
328
|
try:
|
|
328
329
|
_dict["_type"] = str(self._agent_type)
|
|
329
330
|
except NotImplementedError:
|
|
@@ -420,8 +421,9 @@ class RunnableAgent(BaseSingleActionAgent):
|
|
|
420
421
|
individual LLM tokens will not be available in stream_log.
|
|
421
422
|
"""
|
|
422
423
|
|
|
423
|
-
|
|
424
|
-
arbitrary_types_allowed
|
|
424
|
+
model_config = ConfigDict(
|
|
425
|
+
arbitrary_types_allowed=True,
|
|
426
|
+
)
|
|
425
427
|
|
|
426
428
|
@property
|
|
427
429
|
def return_values(self) -> List[str]:
|
|
@@ -528,8 +530,9 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
|
|
|
528
530
|
individual LLM tokens will not be available in stream_log.
|
|
529
531
|
"""
|
|
530
532
|
|
|
531
|
-
|
|
532
|
-
arbitrary_types_allowed
|
|
533
|
+
model_config = ConfigDict(
|
|
534
|
+
arbitrary_types_allowed=True,
|
|
535
|
+
)
|
|
533
536
|
|
|
534
537
|
@property
|
|
535
538
|
def return_values(self) -> List[str]:
|
|
@@ -854,8 +857,8 @@ class Agent(BaseSingleActionAgent):
|
|
|
854
857
|
"""
|
|
855
858
|
return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"})
|
|
856
859
|
|
|
857
|
-
@
|
|
858
|
-
def validate_prompt(
|
|
860
|
+
@model_validator(mode="after")
|
|
861
|
+
def validate_prompt(self) -> Self:
|
|
859
862
|
"""Validate that prompt matches format.
|
|
860
863
|
|
|
861
864
|
Args:
|
|
@@ -868,7 +871,7 @@ class Agent(BaseSingleActionAgent):
|
|
|
868
871
|
ValueError: If `agent_scratchpad` is not in prompt.input_variables
|
|
869
872
|
and prompt is not a FewShotPromptTemplate or a PromptTemplate.
|
|
870
873
|
"""
|
|
871
|
-
prompt =
|
|
874
|
+
prompt = self.llm_chain.prompt
|
|
872
875
|
if "agent_scratchpad" not in prompt.input_variables:
|
|
873
876
|
logger.warning(
|
|
874
877
|
"`agent_scratchpad` should be a variable in prompt.input_variables."
|
|
@@ -881,7 +884,7 @@ class Agent(BaseSingleActionAgent):
|
|
|
881
884
|
prompt.suffix += "\n{agent_scratchpad}"
|
|
882
885
|
else:
|
|
883
886
|
raise ValueError(f"Got unexpected prompt type {type(prompt)}")
|
|
884
|
-
return
|
|
887
|
+
return self
|
|
885
888
|
|
|
886
889
|
@property
|
|
887
890
|
@abstractmethod
|
|
@@ -1120,8 +1123,8 @@ class AgentExecutor(Chain):
|
|
|
1120
1123
|
**kwargs,
|
|
1121
1124
|
)
|
|
1122
1125
|
|
|
1123
|
-
@
|
|
1124
|
-
def validate_tools(
|
|
1126
|
+
@model_validator(mode="after")
|
|
1127
|
+
def validate_tools(self) -> Self:
|
|
1125
1128
|
"""Validate that tools are compatible with agent.
|
|
1126
1129
|
|
|
1127
1130
|
Args:
|
|
@@ -1133,19 +1136,20 @@ class AgentExecutor(Chain):
|
|
|
1133
1136
|
Raises:
|
|
1134
1137
|
ValueError: If allowed tools are different than provided tools.
|
|
1135
1138
|
"""
|
|
1136
|
-
agent =
|
|
1137
|
-
tools =
|
|
1138
|
-
allowed_tools = agent.get_allowed_tools()
|
|
1139
|
+
agent = self.agent
|
|
1140
|
+
tools = self.tools
|
|
1141
|
+
allowed_tools = agent.get_allowed_tools() # type: ignore
|
|
1139
1142
|
if allowed_tools is not None:
|
|
1140
1143
|
if set(allowed_tools) != set([tool.name for tool in tools]):
|
|
1141
1144
|
raise ValueError(
|
|
1142
1145
|
f"Allowed tools ({allowed_tools}) different than "
|
|
1143
1146
|
f"provided tools ({[tool.name for tool in tools]})"
|
|
1144
1147
|
)
|
|
1145
|
-
return
|
|
1148
|
+
return self
|
|
1146
1149
|
|
|
1147
|
-
@
|
|
1148
|
-
|
|
1150
|
+
@model_validator(mode="before")
|
|
1151
|
+
@classmethod
|
|
1152
|
+
def validate_runnable_agent(cls, values: Dict) -> Any:
|
|
1149
1153
|
"""Convert runnable to agent if passed in.
|
|
1150
1154
|
|
|
1151
1155
|
Args:
|
|
@@ -3,10 +3,10 @@
|
|
|
3
3
|
from typing import List
|
|
4
4
|
|
|
5
5
|
from langchain_core.language_models import BaseLanguageModel
|
|
6
|
-
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
7
6
|
from langchain_core.tools import BaseTool
|
|
8
7
|
from langchain_core.tools.base import BaseToolkit
|
|
9
8
|
from langchain_core.vectorstores import VectorStore
|
|
9
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class VectorStoreInfo(BaseModel):
|
|
@@ -16,8 +16,9 @@ class VectorStoreInfo(BaseModel):
|
|
|
16
16
|
name: str
|
|
17
17
|
description: str
|
|
18
18
|
|
|
19
|
-
|
|
20
|
-
arbitrary_types_allowed
|
|
19
|
+
model_config = ConfigDict(
|
|
20
|
+
arbitrary_types_allowed=True,
|
|
21
|
+
)
|
|
21
22
|
|
|
22
23
|
|
|
23
24
|
class VectorStoreToolkit(BaseToolkit):
|
|
@@ -26,8 +27,9 @@ class VectorStoreToolkit(BaseToolkit):
|
|
|
26
27
|
vectorstore_info: VectorStoreInfo = Field(exclude=True)
|
|
27
28
|
llm: BaseLanguageModel
|
|
28
29
|
|
|
29
|
-
|
|
30
|
-
arbitrary_types_allowed
|
|
30
|
+
model_config = ConfigDict(
|
|
31
|
+
arbitrary_types_allowed=True,
|
|
32
|
+
)
|
|
31
33
|
|
|
32
34
|
def get_tools(self) -> List[BaseTool]:
|
|
33
35
|
"""Get the tools in the toolkit."""
|
|
@@ -67,8 +69,9 @@ class VectorStoreRouterToolkit(BaseToolkit):
|
|
|
67
69
|
vectorstores: List[VectorStoreInfo] = Field(exclude=True)
|
|
68
70
|
llm: BaseLanguageModel
|
|
69
71
|
|
|
70
|
-
|
|
71
|
-
arbitrary_types_allowed
|
|
72
|
+
model_config = ConfigDict(
|
|
73
|
+
arbitrary_types_allowed=True,
|
|
74
|
+
)
|
|
72
75
|
|
|
73
76
|
def get_tools(self) -> List[BaseTool]:
|
|
74
77
|
"""Get the tools in the toolkit."""
|
langchain/agents/chat/base.py
CHANGED
|
@@ -10,8 +10,8 @@ from langchain_core.prompts.chat import (
|
|
|
10
10
|
HumanMessagePromptTemplate,
|
|
11
11
|
SystemMessagePromptTemplate,
|
|
12
12
|
)
|
|
13
|
-
from langchain_core.pydantic_v1 import Field
|
|
14
13
|
from langchain_core.tools import BaseTool
|
|
14
|
+
from pydantic import Field
|
|
15
15
|
|
|
16
16
|
from langchain.agents.agent import Agent, AgentOutputParser
|
|
17
17
|
from langchain.agents.chat.output_parser import ChatOutputParser
|
|
@@ -8,8 +8,8 @@ from langchain_core._api import deprecated
|
|
|
8
8
|
from langchain_core.callbacks import BaseCallbackManager
|
|
9
9
|
from langchain_core.language_models import BaseLanguageModel
|
|
10
10
|
from langchain_core.prompts import PromptTemplate
|
|
11
|
-
from langchain_core.pydantic_v1 import Field
|
|
12
11
|
from langchain_core.tools import BaseTool
|
|
12
|
+
from pydantic import Field
|
|
13
13
|
|
|
14
14
|
from langchain.agents.agent import Agent, AgentOutputParser
|
|
15
15
|
from langchain.agents.agent_types import AgentType
|
|
@@ -17,8 +17,8 @@ from langchain_core.prompts.chat import (
|
|
|
17
17
|
MessagesPlaceholder,
|
|
18
18
|
SystemMessagePromptTemplate,
|
|
19
19
|
)
|
|
20
|
-
from langchain_core.pydantic_v1 import Field
|
|
21
20
|
from langchain_core.tools import BaseTool
|
|
21
|
+
from pydantic import Field
|
|
22
22
|
|
|
23
23
|
from langchain.agents.agent import Agent, AgentOutputParser
|
|
24
24
|
from langchain.agents.conversational_chat.output_parser import ConvoOutputParser
|
langchain/agents/mrkl/base.py
CHANGED
|
@@ -8,9 +8,9 @@ from langchain_core._api import deprecated
|
|
|
8
8
|
from langchain_core.callbacks import BaseCallbackManager
|
|
9
9
|
from langchain_core.language_models import BaseLanguageModel
|
|
10
10
|
from langchain_core.prompts import PromptTemplate
|
|
11
|
-
from langchain_core.pydantic_v1 import Field
|
|
12
11
|
from langchain_core.tools import BaseTool, Tool
|
|
13
12
|
from langchain_core.tools.render import render_text_description
|
|
13
|
+
from pydantic import Field
|
|
14
14
|
|
|
15
15
|
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
|
16
16
|
from langchain.agents.agent_types import AgentType
|
|
@@ -20,10 +20,11 @@ from typing import (
|
|
|
20
20
|
from langchain_core.agents import AgentAction, AgentFinish
|
|
21
21
|
from langchain_core.callbacks import CallbackManager
|
|
22
22
|
from langchain_core.load import dumpd
|
|
23
|
-
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
|
24
23
|
from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
|
|
25
24
|
from langchain_core.tools import BaseTool
|
|
26
25
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
|
26
|
+
from pydantic import BaseModel, Field, model_validator
|
|
27
|
+
from typing_extensions import Self
|
|
27
28
|
|
|
28
29
|
if TYPE_CHECKING:
|
|
29
30
|
import openai
|
|
@@ -232,14 +233,14 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
|
|
232
233
|
as_agent: bool = False
|
|
233
234
|
"""Use as a LangChain agent, compatible with the AgentExecutor."""
|
|
234
235
|
|
|
235
|
-
@
|
|
236
|
-
def validate_async_client(
|
|
237
|
-
if
|
|
236
|
+
@model_validator(mode="after")
|
|
237
|
+
def validate_async_client(self) -> Self:
|
|
238
|
+
if self.async_client is None:
|
|
238
239
|
import openai
|
|
239
240
|
|
|
240
|
-
api_key =
|
|
241
|
-
|
|
242
|
-
return
|
|
241
|
+
api_key = self.client.api_key
|
|
242
|
+
self.async_client = openai.AsyncOpenAI(api_key=api_key)
|
|
243
|
+
return self
|
|
243
244
|
|
|
244
245
|
@classmethod
|
|
245
246
|
def create_assistant(
|
|
@@ -17,10 +17,11 @@ from langchain_core.prompts.chat import (
|
|
|
17
17
|
HumanMessagePromptTemplate,
|
|
18
18
|
MessagesPlaceholder,
|
|
19
19
|
)
|
|
20
|
-
from langchain_core.pydantic_v1 import root_validator
|
|
21
20
|
from langchain_core.runnables import Runnable, RunnablePassthrough
|
|
22
21
|
from langchain_core.tools import BaseTool
|
|
23
22
|
from langchain_core.utils.function_calling import convert_to_openai_function
|
|
23
|
+
from pydantic import model_validator
|
|
24
|
+
from typing_extensions import Self
|
|
24
25
|
|
|
25
26
|
from langchain.agents import BaseSingleActionAgent
|
|
26
27
|
from langchain.agents.format_scratchpad.openai_functions import (
|
|
@@ -58,8 +59,8 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
|
|
58
59
|
"""Get allowed tools."""
|
|
59
60
|
return [t.name for t in self.tools]
|
|
60
61
|
|
|
61
|
-
@
|
|
62
|
-
def validate_prompt(
|
|
62
|
+
@model_validator(mode="after")
|
|
63
|
+
def validate_prompt(self) -> Self:
|
|
63
64
|
"""Validate prompt.
|
|
64
65
|
|
|
65
66
|
Args:
|
|
@@ -71,13 +72,13 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
|
|
71
72
|
Raises:
|
|
72
73
|
ValueError: If `agent_scratchpad` is not in the prompt.
|
|
73
74
|
"""
|
|
74
|
-
prompt: BasePromptTemplate =
|
|
75
|
+
prompt: BasePromptTemplate = self.prompt
|
|
75
76
|
if "agent_scratchpad" not in prompt.input_variables:
|
|
76
77
|
raise ValueError(
|
|
77
78
|
"`agent_scratchpad` should be one of the variables in the prompt, "
|
|
78
79
|
f"got {prompt.input_variables}"
|
|
79
80
|
)
|
|
80
|
-
return
|
|
81
|
+
return self
|
|
81
82
|
|
|
82
83
|
@property
|
|
83
84
|
def input_keys(self) -> List[str]:
|
|
@@ -21,8 +21,9 @@ from langchain_core.prompts.chat import (
|
|
|
21
21
|
HumanMessagePromptTemplate,
|
|
22
22
|
MessagesPlaceholder,
|
|
23
23
|
)
|
|
24
|
-
from langchain_core.pydantic_v1 import root_validator
|
|
25
24
|
from langchain_core.tools import BaseTool
|
|
25
|
+
from pydantic import model_validator
|
|
26
|
+
from typing_extensions import Self
|
|
26
27
|
|
|
27
28
|
from langchain.agents import BaseMultiActionAgent
|
|
28
29
|
from langchain.agents.format_scratchpad.openai_functions import (
|
|
@@ -115,15 +116,15 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
|
|
|
115
116
|
"""Get allowed tools."""
|
|
116
117
|
return [t.name for t in self.tools]
|
|
117
118
|
|
|
118
|
-
@
|
|
119
|
-
def validate_prompt(
|
|
120
|
-
prompt: BasePromptTemplate =
|
|
119
|
+
@model_validator(mode="after")
|
|
120
|
+
def validate_prompt(self) -> Self:
|
|
121
|
+
prompt: BasePromptTemplate = self.prompt
|
|
121
122
|
if "agent_scratchpad" not in prompt.input_variables:
|
|
122
123
|
raise ValueError(
|
|
123
124
|
"`agent_scratchpad` should be one of the variables in the prompt, "
|
|
124
125
|
f"got {prompt.input_variables}"
|
|
125
126
|
)
|
|
126
|
-
return
|
|
127
|
+
return self
|
|
127
128
|
|
|
128
129
|
@property
|
|
129
130
|
def input_keys(self) -> List[str]:
|
langchain/agents/react/base.py
CHANGED
|
@@ -8,8 +8,8 @@ from langchain_core._api import deprecated
|
|
|
8
8
|
from langchain_core.documents import Document
|
|
9
9
|
from langchain_core.language_models import BaseLanguageModel
|
|
10
10
|
from langchain_core.prompts import BasePromptTemplate
|
|
11
|
-
from langchain_core.pydantic_v1 import Field
|
|
12
11
|
from langchain_core.tools import BaseTool, Tool
|
|
12
|
+
from pydantic import Field
|
|
13
13
|
|
|
14
14
|
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
|
15
15
|
from langchain.agents.agent_types import AgentType
|
|
@@ -7,9 +7,9 @@ from typing import TYPE_CHECKING, Any, Sequence, Union
|
|
|
7
7
|
from langchain_core._api import deprecated
|
|
8
8
|
from langchain_core.language_models import BaseLanguageModel
|
|
9
9
|
from langchain_core.prompts import BasePromptTemplate
|
|
10
|
-
from langchain_core.pydantic_v1 import Field
|
|
11
10
|
from langchain_core.runnables import Runnable, RunnablePassthrough
|
|
12
11
|
from langchain_core.tools import BaseTool, Tool
|
|
12
|
+
from pydantic import Field
|
|
13
13
|
|
|
14
14
|
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
|
15
15
|
from langchain.agents.agent_types import AgentType
|
|
@@ -11,10 +11,10 @@ from langchain_core.prompts.chat import (
|
|
|
11
11
|
HumanMessagePromptTemplate,
|
|
12
12
|
SystemMessagePromptTemplate,
|
|
13
13
|
)
|
|
14
|
-
from langchain_core.pydantic_v1 import Field
|
|
15
14
|
from langchain_core.runnables import Runnable, RunnablePassthrough
|
|
16
15
|
from langchain_core.tools import BaseTool
|
|
17
16
|
from langchain_core.tools.render import ToolsRenderer
|
|
17
|
+
from pydantic import Field
|
|
18
18
|
|
|
19
19
|
from langchain.agents.agent import Agent, AgentOutputParser
|
|
20
20
|
from langchain.agents.format_scratchpad import format_log_to_str
|
|
@@ -8,7 +8,7 @@ from typing import Optional, Pattern, Union
|
|
|
8
8
|
from langchain_core.agents import AgentAction, AgentFinish
|
|
9
9
|
from langchain_core.exceptions import OutputParserException
|
|
10
10
|
from langchain_core.language_models import BaseLanguageModel
|
|
11
|
-
from
|
|
11
|
+
from pydantic import Field
|
|
12
12
|
|
|
13
13
|
from langchain.agents.agent import AgentOutputParser
|
|
14
14
|
from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS
|
langchain/chains/api/base.py
CHANGED
|
@@ -12,7 +12,8 @@ from langchain_core.callbacks import (
|
|
|
12
12
|
)
|
|
13
13
|
from langchain_core.language_models import BaseLanguageModel
|
|
14
14
|
from langchain_core.prompts import BasePromptTemplate
|
|
15
|
-
from
|
|
15
|
+
from pydantic import Field, model_validator
|
|
16
|
+
from typing_extensions import Self
|
|
16
17
|
|
|
17
18
|
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
|
|
18
19
|
from langchain.chains.base import Chain
|
|
@@ -197,7 +198,7 @@ try:
|
|
|
197
198
|
api_docs: str
|
|
198
199
|
question_key: str = "question" #: :meta private:
|
|
199
200
|
output_key: str = "output" #: :meta private:
|
|
200
|
-
limit_to_domains: Optional[Sequence[str]]
|
|
201
|
+
limit_to_domains: Optional[Sequence[str]] = Field(default_factory=list)
|
|
201
202
|
"""Use to limit the domains that can be accessed by the API chain.
|
|
202
203
|
|
|
203
204
|
* For example, to limit to just the domain `https://www.example.com`, set
|
|
@@ -227,19 +228,20 @@ try:
|
|
|
227
228
|
"""
|
|
228
229
|
return [self.output_key]
|
|
229
230
|
|
|
230
|
-
@
|
|
231
|
-
def validate_api_request_prompt(
|
|
231
|
+
@model_validator(mode="after")
|
|
232
|
+
def validate_api_request_prompt(self) -> Self:
|
|
232
233
|
"""Check that api request prompt expects the right variables."""
|
|
233
|
-
input_vars =
|
|
234
|
+
input_vars = self.api_request_chain.prompt.input_variables
|
|
234
235
|
expected_vars = {"question", "api_docs"}
|
|
235
236
|
if set(input_vars) != expected_vars:
|
|
236
237
|
raise ValueError(
|
|
237
238
|
f"Input variables should be {expected_vars}, got {input_vars}"
|
|
238
239
|
)
|
|
239
|
-
return
|
|
240
|
+
return self
|
|
240
241
|
|
|
241
|
-
@
|
|
242
|
-
|
|
242
|
+
@model_validator(mode="before")
|
|
243
|
+
@classmethod
|
|
244
|
+
def validate_limit_to_domains(cls, values: Dict) -> Any:
|
|
243
245
|
"""Check that allowed domains are valid."""
|
|
244
246
|
# This check must be a pre=True check, so that a default of None
|
|
245
247
|
# won't be set to limit_to_domains if it's not provided.
|
|
@@ -258,16 +260,16 @@ try:
|
|
|
258
260
|
)
|
|
259
261
|
return values
|
|
260
262
|
|
|
261
|
-
@
|
|
262
|
-
def validate_api_answer_prompt(
|
|
263
|
+
@model_validator(mode="after")
|
|
264
|
+
def validate_api_answer_prompt(self) -> Self:
|
|
263
265
|
"""Check that api answer prompt expects the right variables."""
|
|
264
|
-
input_vars =
|
|
266
|
+
input_vars = self.api_answer_chain.prompt.input_variables
|
|
265
267
|
expected_vars = {"question", "api_docs", "api_url", "api_response"}
|
|
266
268
|
if set(input_vars) != expected_vars:
|
|
267
269
|
raise ValueError(
|
|
268
270
|
f"Input variables should be {expected_vars}, got {input_vars}"
|
|
269
271
|
)
|
|
270
|
-
return
|
|
272
|
+
return self
|
|
271
273
|
|
|
272
274
|
def _call(
|
|
273
275
|
self,
|
langchain/chains/base.py
CHANGED
|
@@ -18,10 +18,8 @@ from langchain_core.callbacks import (
|
|
|
18
18
|
CallbackManagerForChainRun,
|
|
19
19
|
Callbacks,
|
|
20
20
|
)
|
|
21
|
-
from langchain_core.load.dump import dumpd
|
|
22
21
|
from langchain_core.memory import BaseMemory
|
|
23
22
|
from langchain_core.outputs import RunInfo
|
|
24
|
-
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, validator
|
|
25
23
|
from langchain_core.runnables import (
|
|
26
24
|
RunnableConfig,
|
|
27
25
|
RunnableSerializable,
|
|
@@ -29,6 +27,13 @@ from langchain_core.runnables import (
|
|
|
29
27
|
run_in_executor,
|
|
30
28
|
)
|
|
31
29
|
from langchain_core.runnables.utils import create_model
|
|
30
|
+
from pydantic import (
|
|
31
|
+
BaseModel,
|
|
32
|
+
ConfigDict,
|
|
33
|
+
Field,
|
|
34
|
+
field_validator,
|
|
35
|
+
model_validator,
|
|
36
|
+
)
|
|
32
37
|
|
|
33
38
|
from langchain.schema import RUN_KEY
|
|
34
39
|
|
|
@@ -96,8 +101,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|
|
96
101
|
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
|
97
102
|
"""[DEPRECATED] Use `callbacks` instead."""
|
|
98
103
|
|
|
99
|
-
|
|
100
|
-
arbitrary_types_allowed
|
|
104
|
+
model_config = ConfigDict(
|
|
105
|
+
arbitrary_types_allowed=True,
|
|
106
|
+
)
|
|
101
107
|
|
|
102
108
|
def get_input_schema(
|
|
103
109
|
self, config: Optional[RunnableConfig] = None
|
|
@@ -143,7 +149,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|
|
143
149
|
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
|
144
150
|
|
|
145
151
|
run_manager = callback_manager.on_chain_start(
|
|
146
|
-
|
|
152
|
+
None,
|
|
147
153
|
inputs,
|
|
148
154
|
run_id,
|
|
149
155
|
name=run_name,
|
|
@@ -195,7 +201,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|
|
195
201
|
)
|
|
196
202
|
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
|
|
197
203
|
run_manager = await callback_manager.on_chain_start(
|
|
198
|
-
|
|
204
|
+
None,
|
|
199
205
|
inputs,
|
|
200
206
|
run_id,
|
|
201
207
|
name=run_name,
|
|
@@ -223,8 +229,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|
|
223
229
|
def _chain_type(self) -> str:
|
|
224
230
|
raise NotImplementedError("Saving not supported for this chain type.")
|
|
225
231
|
|
|
226
|
-
@
|
|
227
|
-
|
|
232
|
+
@model_validator(mode="before")
|
|
233
|
+
@classmethod
|
|
234
|
+
def raise_callback_manager_deprecation(cls, values: Dict) -> Any:
|
|
228
235
|
"""Raise deprecation warning if callback_manager is used."""
|
|
229
236
|
if values.get("callback_manager") is not None:
|
|
230
237
|
if values.get("callbacks") is not None:
|
|
@@ -240,7 +247,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|
|
240
247
|
values["callbacks"] = values.pop("callback_manager", None)
|
|
241
248
|
return values
|
|
242
249
|
|
|
243
|
-
@
|
|
250
|
+
@field_validator("verbose", mode="before")
|
|
251
|
+
@classmethod
|
|
244
252
|
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
|
245
253
|
"""Set the chain verbosity.
|
|
246
254
|
|
|
@@ -10,10 +10,10 @@ from langchain_core.callbacks import (
|
|
|
10
10
|
)
|
|
11
11
|
from langchain_core.documents import Document
|
|
12
12
|
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
|
13
|
-
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
14
13
|
from langchain_core.runnables.config import RunnableConfig
|
|
15
14
|
from langchain_core.runnables.utils import create_model
|
|
16
15
|
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
|
16
|
+
from pydantic import BaseModel, Field
|
|
17
17
|
|
|
18
18
|
from langchain.chains.base import Chain
|
|
19
19
|
|
|
@@ -6,9 +6,9 @@ from typing import Any, Dict, List, Optional, Tuple, Type
|
|
|
6
6
|
|
|
7
7
|
from langchain_core.callbacks import Callbacks
|
|
8
8
|
from langchain_core.documents import Document
|
|
9
|
-
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
|
10
9
|
from langchain_core.runnables.config import RunnableConfig
|
|
11
10
|
from langchain_core.runnables.utils import create_model
|
|
11
|
+
from pydantic import BaseModel, ConfigDict, model_validator
|
|
12
12
|
|
|
13
13
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
14
14
|
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
|
|
@@ -126,12 +126,14 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
|
|
126
126
|
_output_keys = _output_keys + ["intermediate_steps"]
|
|
127
127
|
return _output_keys
|
|
128
128
|
|
|
129
|
-
|
|
130
|
-
arbitrary_types_allowed
|
|
131
|
-
extra
|
|
129
|
+
model_config = ConfigDict(
|
|
130
|
+
arbitrary_types_allowed=True,
|
|
131
|
+
extra="forbid",
|
|
132
|
+
)
|
|
132
133
|
|
|
133
|
-
@
|
|
134
|
-
|
|
134
|
+
@model_validator(mode="before")
|
|
135
|
+
@classmethod
|
|
136
|
+
def get_reduce_chain(cls, values: Dict) -> Any:
|
|
135
137
|
"""For backwards compatibility."""
|
|
136
138
|
if "combine_document_chain" in values:
|
|
137
139
|
if "reduce_documents_chain" in values:
|
|
@@ -153,16 +155,18 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
|
|
153
155
|
|
|
154
156
|
return values
|
|
155
157
|
|
|
156
|
-
@
|
|
157
|
-
|
|
158
|
+
@model_validator(mode="before")
|
|
159
|
+
@classmethod
|
|
160
|
+
def get_return_intermediate_steps(cls, values: Dict) -> Any:
|
|
158
161
|
"""For backwards compatibility."""
|
|
159
162
|
if "return_map_steps" in values:
|
|
160
163
|
values["return_intermediate_steps"] = values["return_map_steps"]
|
|
161
164
|
del values["return_map_steps"]
|
|
162
165
|
return values
|
|
163
166
|
|
|
164
|
-
@
|
|
165
|
-
|
|
167
|
+
@model_validator(mode="before")
|
|
168
|
+
@classmethod
|
|
169
|
+
def get_default_document_variable_name(cls, values: Dict) -> Any:
|
|
166
170
|
"""Get default document variable name, if not provided."""
|
|
167
171
|
if "llm_chain" not in values:
|
|
168
172
|
raise ValueError("llm_chain must be provided")
|
|
@@ -6,9 +6,10 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
|
|
|
6
6
|
|
|
7
7
|
from langchain_core.callbacks import Callbacks
|
|
8
8
|
from langchain_core.documents import Document
|
|
9
|
-
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
|
10
9
|
from langchain_core.runnables.config import RunnableConfig
|
|
11
10
|
from langchain_core.runnables.utils import create_model
|
|
11
|
+
from pydantic import BaseModel, ConfigDict, model_validator
|
|
12
|
+
from typing_extensions import Self
|
|
12
13
|
|
|
13
14
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
14
15
|
from langchain.chains.llm import LLMChain
|
|
@@ -74,9 +75,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
|
|
74
75
|
"""Return intermediate steps.
|
|
75
76
|
Intermediate steps include the results of calling llm_chain on each document."""
|
|
76
77
|
|
|
77
|
-
|
|
78
|
-
arbitrary_types_allowed
|
|
79
|
-
extra
|
|
78
|
+
model_config = ConfigDict(
|
|
79
|
+
arbitrary_types_allowed=True,
|
|
80
|
+
extra="forbid",
|
|
81
|
+
)
|
|
80
82
|
|
|
81
83
|
def get_output_schema(
|
|
82
84
|
self, config: Optional[RunnableConfig] = None
|
|
@@ -104,30 +106,31 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
|
|
104
106
|
_output_keys += self.metadata_keys
|
|
105
107
|
return _output_keys
|
|
106
108
|
|
|
107
|
-
@
|
|
108
|
-
def validate_llm_output(
|
|
109
|
+
@model_validator(mode="after")
|
|
110
|
+
def validate_llm_output(self) -> Self:
|
|
109
111
|
"""Validate that the combine chain outputs a dictionary."""
|
|
110
|
-
output_parser =
|
|
112
|
+
output_parser = self.llm_chain.prompt.output_parser
|
|
111
113
|
if not isinstance(output_parser, RegexParser):
|
|
112
114
|
raise ValueError(
|
|
113
115
|
"Output parser of llm_chain should be a RegexParser,"
|
|
114
116
|
f" got {output_parser}"
|
|
115
117
|
)
|
|
116
118
|
output_keys = output_parser.output_keys
|
|
117
|
-
if
|
|
119
|
+
if self.rank_key not in output_keys:
|
|
118
120
|
raise ValueError(
|
|
119
|
-
f"Got {
|
|
121
|
+
f"Got {self.rank_key} as key to rank on, but did not find "
|
|
120
122
|
f"it in the llm_chain output keys ({output_keys})"
|
|
121
123
|
)
|
|
122
|
-
if
|
|
124
|
+
if self.answer_key not in output_keys:
|
|
123
125
|
raise ValueError(
|
|
124
|
-
f"Got {
|
|
126
|
+
f"Got {self.answer_key} as key to return, but did not find "
|
|
125
127
|
f"it in the llm_chain output keys ({output_keys})"
|
|
126
128
|
)
|
|
127
|
-
return
|
|
129
|
+
return self
|
|
128
130
|
|
|
129
|
-
@
|
|
130
|
-
|
|
131
|
+
@model_validator(mode="before")
|
|
132
|
+
@classmethod
|
|
133
|
+
def get_default_document_variable_name(cls, values: Dict) -> Any:
|
|
131
134
|
"""Get default document variable name, if not provided."""
|
|
132
135
|
if "llm_chain" not in values:
|
|
133
136
|
raise ValueError("llm_chain must be provided")
|