langchain 0.3.27__py3-none-any.whl → 0.4.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/agents/agent.py +16 -20
- langchain/agents/agent_iterator.py +19 -12
- langchain/agents/agent_toolkits/vectorstore/base.py +2 -0
- langchain/agents/chat/base.py +2 -0
- langchain/agents/conversational/base.py +2 -0
- langchain/agents/conversational_chat/base.py +2 -0
- langchain/agents/initialize.py +1 -1
- langchain/agents/json_chat/base.py +1 -0
- langchain/agents/mrkl/base.py +2 -0
- langchain/agents/openai_assistant/base.py +1 -1
- langchain/agents/openai_functions_agent/agent_token_buffer_memory.py +2 -0
- langchain/agents/openai_functions_agent/base.py +3 -2
- langchain/agents/openai_functions_multi_agent/base.py +1 -1
- langchain/agents/openai_tools/base.py +1 -0
- langchain/agents/output_parsers/json.py +2 -0
- langchain/agents/output_parsers/openai_functions.py +10 -3
- langchain/agents/output_parsers/openai_tools.py +8 -1
- langchain/agents/output_parsers/react_json_single_input.py +3 -0
- langchain/agents/output_parsers/react_single_input.py +3 -0
- langchain/agents/output_parsers/self_ask.py +2 -0
- langchain/agents/output_parsers/tools.py +16 -2
- langchain/agents/output_parsers/xml.py +3 -0
- langchain/agents/react/agent.py +1 -0
- langchain/agents/react/base.py +4 -0
- langchain/agents/react/output_parser.py +2 -0
- langchain/agents/schema.py +2 -0
- langchain/agents/self_ask_with_search/base.py +4 -0
- langchain/agents/structured_chat/base.py +5 -0
- langchain/agents/structured_chat/output_parser.py +13 -0
- langchain/agents/tool_calling_agent/base.py +1 -0
- langchain/agents/tools.py +3 -0
- langchain/agents/xml/base.py +7 -1
- langchain/callbacks/streaming_aiter.py +13 -2
- langchain/callbacks/streaming_aiter_final_only.py +11 -2
- langchain/callbacks/streaming_stdout_final_only.py +5 -0
- langchain/callbacks/tracers/logging.py +11 -0
- langchain/chains/api/base.py +5 -1
- langchain/chains/base.py +8 -2
- langchain/chains/combine_documents/base.py +7 -1
- langchain/chains/combine_documents/map_reduce.py +3 -0
- langchain/chains/combine_documents/map_rerank.py +6 -4
- langchain/chains/combine_documents/reduce.py +1 -0
- langchain/chains/combine_documents/refine.py +1 -0
- langchain/chains/combine_documents/stuff.py +5 -1
- langchain/chains/constitutional_ai/base.py +7 -0
- langchain/chains/conversation/base.py +4 -1
- langchain/chains/conversational_retrieval/base.py +67 -59
- langchain/chains/elasticsearch_database/base.py +2 -1
- langchain/chains/flare/base.py +2 -0
- langchain/chains/flare/prompts.py +2 -0
- langchain/chains/llm.py +7 -2
- langchain/chains/llm_bash/__init__.py +1 -1
- langchain/chains/llm_checker/base.py +12 -1
- langchain/chains/llm_math/base.py +9 -1
- langchain/chains/llm_summarization_checker/base.py +13 -1
- langchain/chains/llm_symbolic_math/__init__.py +1 -1
- langchain/chains/loading.py +4 -2
- langchain/chains/moderation.py +3 -0
- langchain/chains/natbot/base.py +3 -1
- langchain/chains/natbot/crawler.py +29 -0
- langchain/chains/openai_functions/base.py +2 -0
- langchain/chains/openai_functions/citation_fuzzy_match.py +9 -0
- langchain/chains/openai_functions/openapi.py +4 -0
- langchain/chains/openai_functions/qa_with_structure.py +3 -3
- langchain/chains/openai_functions/tagging.py +2 -0
- langchain/chains/qa_generation/base.py +4 -0
- langchain/chains/qa_with_sources/base.py +3 -0
- langchain/chains/qa_with_sources/retrieval.py +1 -1
- langchain/chains/qa_with_sources/vector_db.py +4 -2
- langchain/chains/query_constructor/base.py +4 -2
- langchain/chains/query_constructor/parser.py +64 -2
- langchain/chains/retrieval_qa/base.py +4 -0
- langchain/chains/router/base.py +14 -2
- langchain/chains/router/embedding_router.py +3 -0
- langchain/chains/router/llm_router.py +6 -4
- langchain/chains/router/multi_prompt.py +3 -0
- langchain/chains/router/multi_retrieval_qa.py +18 -0
- langchain/chains/sql_database/query.py +1 -0
- langchain/chains/structured_output/base.py +2 -0
- langchain/chains/transform.py +4 -0
- langchain/chat_models/base.py +55 -18
- langchain/document_loaders/blob_loaders/schema.py +1 -4
- langchain/embeddings/base.py +2 -0
- langchain/embeddings/cache.py +3 -3
- langchain/evaluation/agents/trajectory_eval_chain.py +3 -2
- langchain/evaluation/comparison/eval_chain.py +1 -0
- langchain/evaluation/criteria/eval_chain.py +3 -0
- langchain/evaluation/embedding_distance/base.py +11 -0
- langchain/evaluation/exact_match/base.py +14 -1
- langchain/evaluation/loading.py +1 -0
- langchain/evaluation/parsing/base.py +16 -3
- langchain/evaluation/parsing/json_distance.py +19 -8
- langchain/evaluation/parsing/json_schema.py +1 -4
- langchain/evaluation/qa/eval_chain.py +8 -0
- langchain/evaluation/qa/generate_chain.py +2 -0
- langchain/evaluation/regex_match/base.py +9 -1
- langchain/evaluation/scoring/eval_chain.py +1 -0
- langchain/evaluation/string_distance/base.py +6 -0
- langchain/memory/buffer.py +5 -0
- langchain/memory/buffer_window.py +2 -0
- langchain/memory/combined.py +1 -1
- langchain/memory/entity.py +47 -0
- langchain/memory/simple.py +3 -0
- langchain/memory/summary.py +30 -0
- langchain/memory/summary_buffer.py +3 -0
- langchain/memory/token_buffer.py +2 -0
- langchain/output_parsers/combining.py +4 -2
- langchain/output_parsers/enum.py +5 -1
- langchain/output_parsers/fix.py +8 -1
- langchain/output_parsers/pandas_dataframe.py +16 -1
- langchain/output_parsers/regex.py +2 -0
- langchain/output_parsers/retry.py +21 -1
- langchain/output_parsers/structured.py +10 -0
- langchain/output_parsers/yaml.py +4 -0
- langchain/pydantic_v1/__init__.py +1 -1
- langchain/retrievers/document_compressors/chain_extract.py +4 -2
- langchain/retrievers/document_compressors/cohere_rerank.py +2 -0
- langchain/retrievers/document_compressors/cross_encoder_rerank.py +2 -0
- langchain/retrievers/document_compressors/embeddings_filter.py +3 -0
- langchain/retrievers/document_compressors/listwise_rerank.py +1 -0
- langchain/retrievers/ensemble.py +2 -2
- langchain/retrievers/multi_query.py +3 -1
- langchain/retrievers/multi_vector.py +4 -1
- langchain/retrievers/parent_document_retriever.py +15 -0
- langchain/retrievers/self_query/base.py +19 -0
- langchain/retrievers/time_weighted_retriever.py +3 -0
- langchain/runnables/hub.py +12 -0
- langchain/runnables/openai_functions.py +6 -0
- langchain/smith/__init__.py +1 -0
- langchain/smith/evaluation/config.py +5 -22
- langchain/smith/evaluation/progress.py +12 -3
- langchain/smith/evaluation/runner_utils.py +240 -123
- langchain/smith/evaluation/string_run_evaluator.py +27 -0
- langchain/storage/encoder_backed.py +1 -0
- langchain/tools/python/__init__.py +1 -1
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/METADATA +2 -12
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/RECORD +140 -141
- langchain/smith/evaluation/utils.py +0 -0
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/WHEEL +0 -0
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/entry_points.txt +0 -0
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,6 +3,7 @@ from typing import Literal, Optional, Union
|
|
|
3
3
|
|
|
4
4
|
from langchain_core.agents import AgentAction, AgentFinish
|
|
5
5
|
from pydantic import Field
|
|
6
|
+
from typing_extensions import override
|
|
6
7
|
|
|
7
8
|
from langchain.agents import AgentOutputParser
|
|
8
9
|
|
|
@@ -65,6 +66,7 @@ class XMLAgentOutputParser(AgentOutputParser):
|
|
|
65
66
|
None - no escaping is applied, which may lead to parsing conflicts.
|
|
66
67
|
"""
|
|
67
68
|
|
|
69
|
+
@override
|
|
68
70
|
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
|
69
71
|
# Check for tool invocation first
|
|
70
72
|
tool_matches = re.findall(r"<tool>(.*?)</tool>", text, re.DOTALL)
|
|
@@ -115,6 +117,7 @@ class XMLAgentOutputParser(AgentOutputParser):
|
|
|
115
117
|
)
|
|
116
118
|
raise ValueError(msg)
|
|
117
119
|
|
|
120
|
+
@override
|
|
118
121
|
def get_format_instructions(self) -> str:
|
|
119
122
|
raise NotImplementedError
|
|
120
123
|
|
langchain/agents/react/agent.py
CHANGED
|
@@ -116,6 +116,7 @@ def create_react_agent(
|
|
|
116
116
|
Thought:{agent_scratchpad}'''
|
|
117
117
|
|
|
118
118
|
prompt = PromptTemplate.from_template(template)
|
|
119
|
+
|
|
119
120
|
""" # noqa: E501
|
|
120
121
|
missing_vars = {"tools", "tool_names", "agent_scratchpad"}.difference(
|
|
121
122
|
prompt.input_variables + list(prompt.partial_variables),
|
langchain/agents/react/base.py
CHANGED
|
@@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
|
|
|
11
11
|
from langchain_core.prompts import BasePromptTemplate
|
|
12
12
|
from langchain_core.tools import BaseTool, Tool
|
|
13
13
|
from pydantic import Field
|
|
14
|
+
from typing_extensions import override
|
|
14
15
|
|
|
15
16
|
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
|
|
16
17
|
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
|
@@ -38,6 +39,7 @@ class ReActDocstoreAgent(Agent):
|
|
|
38
39
|
output_parser: AgentOutputParser = Field(default_factory=ReActOutputParser)
|
|
39
40
|
|
|
40
41
|
@classmethod
|
|
42
|
+
@override
|
|
41
43
|
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
|
42
44
|
return ReActOutputParser()
|
|
43
45
|
|
|
@@ -47,6 +49,7 @@ class ReActDocstoreAgent(Agent):
|
|
|
47
49
|
return AgentType.REACT_DOCSTORE
|
|
48
50
|
|
|
49
51
|
@classmethod
|
|
52
|
+
@override
|
|
50
53
|
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
|
|
51
54
|
"""Return default prompt."""
|
|
52
55
|
return WIKI_PROMPT
|
|
@@ -141,6 +144,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
|
|
|
141
144
|
"""Agent for the ReAct TextWorld chain."""
|
|
142
145
|
|
|
143
146
|
@classmethod
|
|
147
|
+
@override
|
|
144
148
|
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
|
|
145
149
|
"""Return default prompt."""
|
|
146
150
|
return TEXTWORLD_PROMPT
|
|
@@ -3,6 +3,7 @@ from typing import Union
|
|
|
3
3
|
|
|
4
4
|
from langchain_core.agents import AgentAction, AgentFinish
|
|
5
5
|
from langchain_core.exceptions import OutputParserException
|
|
6
|
+
from typing_extensions import override
|
|
6
7
|
|
|
7
8
|
from langchain.agents.agent import AgentOutputParser
|
|
8
9
|
|
|
@@ -10,6 +11,7 @@ from langchain.agents.agent import AgentOutputParser
|
|
|
10
11
|
class ReActOutputParser(AgentOutputParser):
|
|
11
12
|
"""Output parser for the ReAct agent."""
|
|
12
13
|
|
|
14
|
+
@override
|
|
13
15
|
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
|
14
16
|
action_prefix = "Action: "
|
|
15
17
|
if not text.strip().split("\n")[-1].startswith(action_prefix):
|
langchain/agents/schema.py
CHANGED
|
@@ -2,12 +2,14 @@ from typing import Any
|
|
|
2
2
|
|
|
3
3
|
from langchain_core.agents import AgentAction
|
|
4
4
|
from langchain_core.prompts.chat import ChatPromptTemplate
|
|
5
|
+
from typing_extensions import override
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):
|
|
8
9
|
"""Chat prompt template for the agent scratchpad."""
|
|
9
10
|
|
|
10
11
|
@classmethod
|
|
12
|
+
@override
|
|
11
13
|
def is_lc_serializable(cls) -> bool:
|
|
12
14
|
return False
|
|
13
15
|
|
|
@@ -11,6 +11,7 @@ from langchain_core.prompts import BasePromptTemplate
|
|
|
11
11
|
from langchain_core.runnables import Runnable, RunnablePassthrough
|
|
12
12
|
from langchain_core.tools import BaseTool, Tool
|
|
13
13
|
from pydantic import Field
|
|
14
|
+
from typing_extensions import override
|
|
14
15
|
|
|
15
16
|
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
|
16
17
|
from langchain.agents.agent_types import AgentType
|
|
@@ -32,6 +33,7 @@ class SelfAskWithSearchAgent(Agent):
|
|
|
32
33
|
output_parser: AgentOutputParser = Field(default_factory=SelfAskOutputParser)
|
|
33
34
|
|
|
34
35
|
@classmethod
|
|
36
|
+
@override
|
|
35
37
|
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
|
36
38
|
return SelfAskOutputParser()
|
|
37
39
|
|
|
@@ -41,6 +43,7 @@ class SelfAskWithSearchAgent(Agent):
|
|
|
41
43
|
return AgentType.SELF_ASK_WITH_SEARCH
|
|
42
44
|
|
|
43
45
|
@classmethod
|
|
46
|
+
@override
|
|
44
47
|
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
|
|
45
48
|
"""Prompt does not depend on tools."""
|
|
46
49
|
return PROMPT
|
|
@@ -182,6 +185,7 @@ def create_self_ask_with_search_agent(
|
|
|
182
185
|
Are followup questions needed here:{agent_scratchpad}'''
|
|
183
186
|
|
|
184
187
|
prompt = PromptTemplate.from_template(template)
|
|
188
|
+
|
|
185
189
|
""" # noqa: E501
|
|
186
190
|
missing_vars = {"agent_scratchpad"}.difference(
|
|
187
191
|
prompt.input_variables + list(prompt.partial_variables),
|
|
@@ -16,6 +16,7 @@ from langchain_core.runnables import Runnable, RunnablePassthrough
|
|
|
16
16
|
from langchain_core.tools import BaseTool
|
|
17
17
|
from langchain_core.tools.render import ToolsRenderer
|
|
18
18
|
from pydantic import Field
|
|
19
|
+
from typing_extensions import override
|
|
19
20
|
|
|
20
21
|
from langchain.agents.agent import Agent, AgentOutputParser
|
|
21
22
|
from langchain.agents.format_scratchpad import format_log_to_str
|
|
@@ -70,6 +71,7 @@ class StructuredChatAgent(Agent):
|
|
|
70
71
|
pass
|
|
71
72
|
|
|
72
73
|
@classmethod
|
|
74
|
+
@override
|
|
73
75
|
def _get_default_output_parser(
|
|
74
76
|
cls,
|
|
75
77
|
llm: Optional[BaseLanguageModel] = None,
|
|
@@ -78,10 +80,12 @@ class StructuredChatAgent(Agent):
|
|
|
78
80
|
return StructuredChatOutputParserWithRetries.from_llm(llm=llm)
|
|
79
81
|
|
|
80
82
|
@property
|
|
83
|
+
@override
|
|
81
84
|
def _stop(self) -> list[str]:
|
|
82
85
|
return ["Observation:"]
|
|
83
86
|
|
|
84
87
|
@classmethod
|
|
88
|
+
@override
|
|
85
89
|
def create_prompt(
|
|
86
90
|
cls,
|
|
87
91
|
tools: Sequence[BaseTool],
|
|
@@ -276,6 +280,7 @@ def create_structured_chat_agent(
|
|
|
276
280
|
("human", human),
|
|
277
281
|
]
|
|
278
282
|
)
|
|
283
|
+
|
|
279
284
|
""" # noqa: E501
|
|
280
285
|
missing_vars = {"tools", "tool_names", "agent_scratchpad"}.difference(
|
|
281
286
|
prompt.input_variables + list(prompt.partial_variables),
|
|
@@ -10,6 +10,7 @@ from langchain_core.agents import AgentAction, AgentFinish
|
|
|
10
10
|
from langchain_core.exceptions import OutputParserException
|
|
11
11
|
from langchain_core.language_models import BaseLanguageModel
|
|
12
12
|
from pydantic import Field
|
|
13
|
+
from typing_extensions import override
|
|
13
14
|
|
|
14
15
|
from langchain.agents.agent import AgentOutputParser
|
|
15
16
|
from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS
|
|
@@ -27,10 +28,12 @@ class StructuredChatOutputParser(AgentOutputParser):
|
|
|
27
28
|
pattern: Pattern = re.compile(r"```(?:json\s+)?(\W.*?)```", re.DOTALL)
|
|
28
29
|
"""Regex pattern to parse the output."""
|
|
29
30
|
|
|
31
|
+
@override
|
|
30
32
|
def get_format_instructions(self) -> str:
|
|
31
33
|
"""Returns formatting instructions for the given output parser."""
|
|
32
34
|
return self.format_instructions
|
|
33
35
|
|
|
36
|
+
@override
|
|
34
37
|
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
|
35
38
|
try:
|
|
36
39
|
action_match = self.pattern.search(text)
|
|
@@ -65,9 +68,11 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
|
|
|
65
68
|
output_fixing_parser: Optional[OutputFixingParser] = None
|
|
66
69
|
"""The output fixing parser to use."""
|
|
67
70
|
|
|
71
|
+
@override
|
|
68
72
|
def get_format_instructions(self) -> str:
|
|
69
73
|
return FORMAT_INSTRUCTIONS
|
|
70
74
|
|
|
75
|
+
@override
|
|
71
76
|
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
|
72
77
|
try:
|
|
73
78
|
if self.output_fixing_parser is not None:
|
|
@@ -83,6 +88,14 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
|
|
|
83
88
|
llm: Optional[BaseLanguageModel] = None,
|
|
84
89
|
base_parser: Optional[StructuredChatOutputParser] = None,
|
|
85
90
|
) -> StructuredChatOutputParserWithRetries:
|
|
91
|
+
"""Create a StructuredChatOutputParserWithRetries from a language model.
|
|
92
|
+
|
|
93
|
+
Args:
|
|
94
|
+
llm: The language model to use.
|
|
95
|
+
base_parser: An optional StructuredChatOutputParser to use.
|
|
96
|
+
Returns:
|
|
97
|
+
An instance of StructuredChatOutputParserWithRetries.
|
|
98
|
+
"""
|
|
86
99
|
if llm is not None:
|
|
87
100
|
base_parser = base_parser or StructuredChatOutputParser()
|
|
88
101
|
output_fixing_parser: OutputFixingParser = OutputFixingParser.from_llm(
|
|
@@ -85,6 +85,7 @@ def create_tool_calling_agent(
|
|
|
85
85
|
The agent prompt must have an `agent_scratchpad` key that is a
|
|
86
86
|
``MessagesPlaceholder``. Intermediate agent actions and tool output
|
|
87
87
|
messages will be passed in here.
|
|
88
|
+
|
|
88
89
|
"""
|
|
89
90
|
missing_vars = {"agent_scratchpad"}.difference(
|
|
90
91
|
prompt.input_variables + list(prompt.partial_variables),
|
langchain/agents/tools.py
CHANGED
|
@@ -7,6 +7,7 @@ from langchain_core.callbacks import (
|
|
|
7
7
|
CallbackManagerForToolRun,
|
|
8
8
|
)
|
|
9
9
|
from langchain_core.tools import BaseTool, tool
|
|
10
|
+
from typing_extensions import override
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
class InvalidTool(BaseTool):
|
|
@@ -17,6 +18,7 @@ class InvalidTool(BaseTool):
|
|
|
17
18
|
description: str = "Called when tool name is invalid. Suggests valid tool names."
|
|
18
19
|
"""Description of the tool."""
|
|
19
20
|
|
|
21
|
+
@override
|
|
20
22
|
def _run(
|
|
21
23
|
self,
|
|
22
24
|
requested_tool_name: str,
|
|
@@ -30,6 +32,7 @@ class InvalidTool(BaseTool):
|
|
|
30
32
|
f"try one of [{available_tool_names_str}]."
|
|
31
33
|
)
|
|
32
34
|
|
|
35
|
+
@override
|
|
33
36
|
async def _arun(
|
|
34
37
|
self,
|
|
35
38
|
requested_tool_name: str,
|
langchain/agents/xml/base.py
CHANGED
|
@@ -10,6 +10,7 @@ from langchain_core.prompts.chat import AIMessagePromptTemplate, ChatPromptTempl
|
|
|
10
10
|
from langchain_core.runnables import Runnable, RunnablePassthrough
|
|
11
11
|
from langchain_core.tools import BaseTool
|
|
12
12
|
from langchain_core.tools.render import ToolsRenderer, render_text_description
|
|
13
|
+
from typing_extensions import override
|
|
13
14
|
|
|
14
15
|
from langchain.agents.agent import BaseSingleActionAgent
|
|
15
16
|
from langchain.agents.format_scratchpad import format_xml
|
|
@@ -36,7 +37,6 @@ class XMLAgent(BaseSingleActionAgent):
|
|
|
36
37
|
tools = ...
|
|
37
38
|
model =
|
|
38
39
|
|
|
39
|
-
|
|
40
40
|
"""
|
|
41
41
|
|
|
42
42
|
tools: list[BaseTool]
|
|
@@ -45,11 +45,13 @@ class XMLAgent(BaseSingleActionAgent):
|
|
|
45
45
|
"""Chain to use to predict action."""
|
|
46
46
|
|
|
47
47
|
@property
|
|
48
|
+
@override
|
|
48
49
|
def input_keys(self) -> list[str]:
|
|
49
50
|
return ["input"]
|
|
50
51
|
|
|
51
52
|
@staticmethod
|
|
52
53
|
def get_default_prompt() -> ChatPromptTemplate:
|
|
54
|
+
"""Return the default prompt for the XML agent."""
|
|
53
55
|
base_prompt = ChatPromptTemplate.from_template(agent_instructions)
|
|
54
56
|
return base_prompt + AIMessagePromptTemplate.from_template(
|
|
55
57
|
"{intermediate_steps}",
|
|
@@ -57,8 +59,10 @@ class XMLAgent(BaseSingleActionAgent):
|
|
|
57
59
|
|
|
58
60
|
@staticmethod
|
|
59
61
|
def get_default_output_parser() -> XMLAgentOutputParser:
|
|
62
|
+
"""Return an XMLAgentOutputParser."""
|
|
60
63
|
return XMLAgentOutputParser()
|
|
61
64
|
|
|
65
|
+
@override
|
|
62
66
|
def plan(
|
|
63
67
|
self,
|
|
64
68
|
intermediate_steps: list[tuple[AgentAction, str]],
|
|
@@ -83,6 +87,7 @@ class XMLAgent(BaseSingleActionAgent):
|
|
|
83
87
|
response = self.llm_chain(inputs, callbacks=callbacks)
|
|
84
88
|
return response[self.llm_chain.output_key]
|
|
85
89
|
|
|
90
|
+
@override
|
|
86
91
|
async def aplan(
|
|
87
92
|
self,
|
|
88
93
|
intermediate_steps: list[tuple[AgentAction, str]],
|
|
@@ -203,6 +208,7 @@ def create_xml_agent(
|
|
|
203
208
|
Question: {input}
|
|
204
209
|
{agent_scratchpad}'''
|
|
205
210
|
prompt = PromptTemplate.from_template(template)
|
|
211
|
+
|
|
206
212
|
""" # noqa: E501
|
|
207
213
|
missing_vars = {"tools", "agent_scratchpad"}.difference(
|
|
208
214
|
prompt.input_variables + list(prompt.partial_variables),
|
|
@@ -6,6 +6,8 @@ from typing import Any, Literal, Union, cast
|
|
|
6
6
|
|
|
7
7
|
from langchain_core.callbacks import AsyncCallbackHandler
|
|
8
8
|
from langchain_core.outputs import LLMResult
|
|
9
|
+
from langchain_core.v1.messages import AIMessage
|
|
10
|
+
from typing_extensions import override
|
|
9
11
|
|
|
10
12
|
# TODO If used by two LLM runs in parallel this won't work as expected
|
|
11
13
|
|
|
@@ -19,12 +21,15 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
|
|
|
19
21
|
|
|
20
22
|
@property
|
|
21
23
|
def always_verbose(self) -> bool:
|
|
24
|
+
"""Always verbose."""
|
|
22
25
|
return True
|
|
23
26
|
|
|
24
27
|
def __init__(self) -> None:
|
|
28
|
+
"""Instantiate AsyncIteratorCallbackHandler."""
|
|
25
29
|
self.queue = asyncio.Queue()
|
|
26
30
|
self.done = asyncio.Event()
|
|
27
31
|
|
|
32
|
+
@override
|
|
28
33
|
async def on_llm_start(
|
|
29
34
|
self,
|
|
30
35
|
serialized: dict[str, Any],
|
|
@@ -34,19 +39,25 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
|
|
|
34
39
|
# If two calls are made in a row, this resets the state
|
|
35
40
|
self.done.clear()
|
|
36
41
|
|
|
42
|
+
@override
|
|
37
43
|
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
|
38
44
|
if token is not None and token != "":
|
|
39
45
|
self.queue.put_nowait(token)
|
|
40
46
|
|
|
41
|
-
|
|
47
|
+
@override
|
|
48
|
+
async def on_llm_end(
|
|
49
|
+
self, response: Union[LLMResult, AIMessage], **kwargs: Any
|
|
50
|
+
) -> None:
|
|
42
51
|
self.done.set()
|
|
43
52
|
|
|
53
|
+
@override
|
|
44
54
|
async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
|
45
55
|
self.done.set()
|
|
46
56
|
|
|
47
57
|
# TODO implement the other methods
|
|
48
58
|
|
|
49
59
|
async def aiter(self) -> AsyncIterator[str]:
|
|
60
|
+
"""Asynchronous iterator that yields tokens."""
|
|
50
61
|
while not self.queue.empty() or not self.done.is_set():
|
|
51
62
|
# Wait for the next token in the queue,
|
|
52
63
|
# but stop waiting if the done event is set
|
|
@@ -65,7 +76,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
|
|
|
65
76
|
other.pop().cancel()
|
|
66
77
|
|
|
67
78
|
# Extract the value of the first completed task
|
|
68
|
-
token_or_done = cast(Union[str, Literal[True]], done.pop().result())
|
|
79
|
+
token_or_done = cast("Union[str, Literal[True]]", done.pop().result())
|
|
69
80
|
|
|
70
81
|
# If the extracted value is the boolean True, the done event was set
|
|
71
82
|
if token_or_done is True:
|
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from typing import Any, Optional
|
|
3
|
+
from typing import Any, Optional, Union
|
|
4
4
|
|
|
5
5
|
from langchain_core.outputs import LLMResult
|
|
6
|
+
from langchain_core.v1.messages import AIMessage
|
|
7
|
+
from typing_extensions import override
|
|
6
8
|
|
|
7
9
|
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
|
|
8
10
|
|
|
@@ -15,6 +17,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
|
|
15
17
|
"""
|
|
16
18
|
|
|
17
19
|
def append_to_last_tokens(self, token: str) -> None:
|
|
20
|
+
"""Append token to the last tokens."""
|
|
18
21
|
self.last_tokens.append(token)
|
|
19
22
|
self.last_tokens_stripped.append(token.strip())
|
|
20
23
|
if len(self.last_tokens) > len(self.answer_prefix_tokens):
|
|
@@ -22,6 +25,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
|
|
22
25
|
self.last_tokens_stripped.pop(0)
|
|
23
26
|
|
|
24
27
|
def check_if_answer_reached(self) -> bool:
|
|
28
|
+
"""Check if the answer has been reached."""
|
|
25
29
|
if self.strip_tokens:
|
|
26
30
|
return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
|
|
27
31
|
return self.last_tokens == self.answer_prefix_tokens
|
|
@@ -60,6 +64,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
|
|
60
64
|
self.stream_prefix = stream_prefix
|
|
61
65
|
self.answer_reached = False
|
|
62
66
|
|
|
67
|
+
@override
|
|
63
68
|
async def on_llm_start(
|
|
64
69
|
self,
|
|
65
70
|
serialized: dict[str, Any],
|
|
@@ -70,10 +75,14 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
|
|
70
75
|
self.done.clear()
|
|
71
76
|
self.answer_reached = False
|
|
72
77
|
|
|
73
|
-
|
|
78
|
+
@override
|
|
79
|
+
async def on_llm_end(
|
|
80
|
+
self, response: Union[LLMResult, AIMessage], **kwargs: Any
|
|
81
|
+
) -> None:
|
|
74
82
|
if self.answer_reached:
|
|
75
83
|
self.done.set()
|
|
76
84
|
|
|
85
|
+
@override
|
|
77
86
|
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
|
78
87
|
# Remember the last n tokens, where n = len(answer_prefix_tokens)
|
|
79
88
|
self.append_to_last_tokens(token)
|
|
@@ -4,6 +4,7 @@ import sys
|
|
|
4
4
|
from typing import Any, Optional
|
|
5
5
|
|
|
6
6
|
from langchain_core.callbacks import StreamingStdOutCallbackHandler
|
|
7
|
+
from typing_extensions import override
|
|
7
8
|
|
|
8
9
|
DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
|
|
9
10
|
|
|
@@ -16,6 +17,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
|
|
16
17
|
"""
|
|
17
18
|
|
|
18
19
|
def append_to_last_tokens(self, token: str) -> None:
|
|
20
|
+
"""Append token to the last tokens."""
|
|
19
21
|
self.last_tokens.append(token)
|
|
20
22
|
self.last_tokens_stripped.append(token.strip())
|
|
21
23
|
if len(self.last_tokens) > len(self.answer_prefix_tokens):
|
|
@@ -23,6 +25,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
|
|
23
25
|
self.last_tokens_stripped.pop(0)
|
|
24
26
|
|
|
25
27
|
def check_if_answer_reached(self) -> bool:
|
|
28
|
+
"""Check if the answer has been reached."""
|
|
26
29
|
if self.strip_tokens:
|
|
27
30
|
return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
|
|
28
31
|
return self.last_tokens == self.answer_prefix_tokens
|
|
@@ -61,6 +64,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
|
|
61
64
|
self.stream_prefix = stream_prefix
|
|
62
65
|
self.answer_reached = False
|
|
63
66
|
|
|
67
|
+
@override
|
|
64
68
|
def on_llm_start(
|
|
65
69
|
self,
|
|
66
70
|
serialized: dict[str, Any],
|
|
@@ -70,6 +74,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
|
|
70
74
|
"""Run when LLM starts running."""
|
|
71
75
|
self.answer_reached = False
|
|
72
76
|
|
|
77
|
+
@override
|
|
73
78
|
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
|
74
79
|
"""Run on new LLM token. Only available when streaming is enabled."""
|
|
75
80
|
|
|
@@ -7,6 +7,7 @@ from uuid import UUID
|
|
|
7
7
|
from langchain_core.exceptions import TracerException
|
|
8
8
|
from langchain_core.tracers.stdout import FunctionCallbackHandler
|
|
9
9
|
from langchain_core.utils.input import get_bolded_text, get_colored_text
|
|
10
|
+
from typing_extensions import override
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
class LoggingCallbackHandler(FunctionCallbackHandler):
|
|
@@ -21,6 +22,15 @@ class LoggingCallbackHandler(FunctionCallbackHandler):
|
|
|
21
22
|
extra: Optional[dict] = None,
|
|
22
23
|
**kwargs: Any,
|
|
23
24
|
) -> None:
|
|
25
|
+
"""
|
|
26
|
+
Initialize the LoggingCallbackHandler.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
logger: the logger to use for logging
|
|
30
|
+
log_level: the logging level (default: logging.INFO)
|
|
31
|
+
extra: the extra context to log (default: None)
|
|
32
|
+
**kwargs:
|
|
33
|
+
"""
|
|
24
34
|
log_method = getattr(logger, logging.getLevelName(level=log_level).lower())
|
|
25
35
|
|
|
26
36
|
def callback(text: str) -> None:
|
|
@@ -28,6 +38,7 @@ class LoggingCallbackHandler(FunctionCallbackHandler):
|
|
|
28
38
|
|
|
29
39
|
super().__init__(function=callback, **kwargs)
|
|
30
40
|
|
|
41
|
+
@override
|
|
31
42
|
def on_text(
|
|
32
43
|
self,
|
|
33
44
|
text: str,
|
langchain/chains/api/base.py
CHANGED
|
@@ -191,6 +191,7 @@ try:
|
|
|
191
191
|
)
|
|
192
192
|
async for event in events:
|
|
193
193
|
event["messages"][-1].pretty_print()
|
|
194
|
+
|
|
194
195
|
""" # noqa: E501
|
|
195
196
|
|
|
196
197
|
api_request_chain: LLMChain
|
|
@@ -386,7 +387,10 @@ try:
|
|
|
386
387
|
except ImportError:
|
|
387
388
|
|
|
388
389
|
class APIChain: # type: ignore[no-redef]
|
|
389
|
-
|
|
390
|
+
"""Raise an ImportError if APIChain is used without langchain_community."""
|
|
391
|
+
|
|
392
|
+
def __init__(self, *_: Any, **__: Any) -> None:
|
|
393
|
+
"""Raise an ImportError if APIChain is used without langchain_community."""
|
|
390
394
|
msg = (
|
|
391
395
|
"To use the APIChain, you must install the langchain_community package."
|
|
392
396
|
"pip install langchain_community"
|
langchain/chains/base.py
CHANGED
|
@@ -108,6 +108,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|
|
108
108
|
arbitrary_types_allowed=True,
|
|
109
109
|
)
|
|
110
110
|
|
|
111
|
+
@override
|
|
111
112
|
def get_input_schema(
|
|
112
113
|
self,
|
|
113
114
|
config: Optional[RunnableConfig] = None,
|
|
@@ -115,6 +116,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|
|
115
116
|
# This is correct, but pydantic typings/mypy don't think so.
|
|
116
117
|
return create_model("ChainInput", **dict.fromkeys(self.input_keys, (Any, None)))
|
|
117
118
|
|
|
119
|
+
@override
|
|
118
120
|
def get_output_schema(
|
|
119
121
|
self,
|
|
120
122
|
config: Optional[RunnableConfig] = None,
|
|
@@ -409,7 +411,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|
|
409
411
|
|
|
410
412
|
return self.invoke(
|
|
411
413
|
inputs,
|
|
412
|
-
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
|
|
414
|
+
cast("RunnableConfig", {k: v for k, v in config.items() if v is not None}),
|
|
413
415
|
return_only_outputs=return_only_outputs,
|
|
414
416
|
include_run_info=include_run_info,
|
|
415
417
|
)
|
|
@@ -459,7 +461,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|
|
459
461
|
}
|
|
460
462
|
return await self.ainvoke(
|
|
461
463
|
inputs,
|
|
462
|
-
cast(RunnableConfig, {k: v for k, v in config.items() if k is not None}),
|
|
464
|
+
cast("RunnableConfig", {k: v for k, v in config.items() if k is not None}),
|
|
463
465
|
return_only_outputs=return_only_outputs,
|
|
464
466
|
include_run_info=include_run_info,
|
|
465
467
|
)
|
|
@@ -616,6 +618,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|
|
616
618
|
context = "Weather report for Boise, Idaho on 07/03/23..."
|
|
617
619
|
chain.run(question=question, context=context)
|
|
618
620
|
# -> "The temperature in Boise is..."
|
|
621
|
+
|
|
619
622
|
"""
|
|
620
623
|
# Run at start to make sure this is possible/defined
|
|
621
624
|
_output_key = self._run_output_key
|
|
@@ -690,6 +693,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|
|
690
693
|
context = "Weather report for Boise, Idaho on 07/03/23..."
|
|
691
694
|
await chain.arun(question=question, context=context)
|
|
692
695
|
# -> "The temperature in Boise is..."
|
|
696
|
+
|
|
693
697
|
"""
|
|
694
698
|
if len(self.output_keys) != 1:
|
|
695
699
|
msg = (
|
|
@@ -744,6 +748,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|
|
744
748
|
|
|
745
749
|
chain.dict(exclude_unset=True)
|
|
746
750
|
# -> {"_type": "foo", "verbose": False, ...}
|
|
751
|
+
|
|
747
752
|
"""
|
|
748
753
|
_dict = super().dict(**kwargs)
|
|
749
754
|
with contextlib.suppress(NotImplementedError):
|
|
@@ -763,6 +768,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|
|
763
768
|
.. code-block:: python
|
|
764
769
|
|
|
765
770
|
chain.save(file_path="path/chain.yaml")
|
|
771
|
+
|
|
766
772
|
"""
|
|
767
773
|
if self.memory is not None:
|
|
768
774
|
msg = "Saving of memory is not yet supported."
|
|
@@ -14,6 +14,7 @@ from langchain_core.runnables.config import RunnableConfig
|
|
|
14
14
|
from langchain_core.utils.pydantic import create_model
|
|
15
15
|
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
|
16
16
|
from pydantic import BaseModel, Field
|
|
17
|
+
from typing_extensions import override
|
|
17
18
|
|
|
18
19
|
from langchain.chains.base import Chain
|
|
19
20
|
|
|
@@ -46,6 +47,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
|
|
46
47
|
input_key: str = "input_documents" #: :meta private:
|
|
47
48
|
output_key: str = "output_text" #: :meta private:
|
|
48
49
|
|
|
50
|
+
@override
|
|
49
51
|
def get_input_schema(
|
|
50
52
|
self,
|
|
51
53
|
config: Optional[RunnableConfig] = None,
|
|
@@ -55,6 +57,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
|
|
55
57
|
**{self.input_key: (list[Document], None)},
|
|
56
58
|
)
|
|
57
59
|
|
|
60
|
+
@override
|
|
58
61
|
def get_output_schema(
|
|
59
62
|
self,
|
|
60
63
|
config: Optional[RunnableConfig] = None,
|
|
@@ -80,7 +83,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
|
|
80
83
|
"""
|
|
81
84
|
return [self.output_key]
|
|
82
85
|
|
|
83
|
-
def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]:
|
|
86
|
+
def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]: # noqa: ARG002
|
|
84
87
|
"""Return the prompt length given the documents passed in.
|
|
85
88
|
|
|
86
89
|
This can be used by a caller to determine whether passing in a list
|
|
@@ -231,6 +234,7 @@ class AnalyzeDocumentChain(Chain):
|
|
|
231
234
|
input_documents=itemgetter("input_document") | split_text,
|
|
232
235
|
) | chain.pick("output_text")
|
|
233
236
|
)
|
|
237
|
+
|
|
234
238
|
"""
|
|
235
239
|
|
|
236
240
|
input_key: str = "input_document" #: :meta private:
|
|
@@ -253,6 +257,7 @@ class AnalyzeDocumentChain(Chain):
|
|
|
253
257
|
"""
|
|
254
258
|
return self.combine_docs_chain.output_keys
|
|
255
259
|
|
|
260
|
+
@override
|
|
256
261
|
def get_input_schema(
|
|
257
262
|
self,
|
|
258
263
|
config: Optional[RunnableConfig] = None,
|
|
@@ -262,6 +267,7 @@ class AnalyzeDocumentChain(Chain):
|
|
|
262
267
|
**{self.input_key: (str, None)},
|
|
263
268
|
)
|
|
264
269
|
|
|
270
|
+
@override
|
|
265
271
|
def get_output_schema(
|
|
266
272
|
self,
|
|
267
273
|
config: Optional[RunnableConfig] = None,
|
|
@@ -10,6 +10,7 @@ from langchain_core.documents import Document
|
|
|
10
10
|
from langchain_core.runnables.config import RunnableConfig
|
|
11
11
|
from langchain_core.utils.pydantic import create_model
|
|
12
12
|
from pydantic import BaseModel, ConfigDict, model_validator
|
|
13
|
+
from typing_extensions import override
|
|
13
14
|
|
|
14
15
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
15
16
|
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
|
|
@@ -98,6 +99,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
|
|
98
99
|
llm_chain=llm_chain,
|
|
99
100
|
reduce_documents_chain=reduce_documents_chain,
|
|
100
101
|
)
|
|
102
|
+
|
|
101
103
|
"""
|
|
102
104
|
|
|
103
105
|
llm_chain: LLMChain
|
|
@@ -111,6 +113,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
|
|
111
113
|
return_intermediate_steps: bool = False
|
|
112
114
|
"""Return the results of the map steps in the output."""
|
|
113
115
|
|
|
116
|
+
@override
|
|
114
117
|
def get_output_schema(
|
|
115
118
|
self,
|
|
116
119
|
config: Optional[RunnableConfig] = None,
|