langchain 0.2.13__py3-none-any.whl → 0.2.15__py3-none-any.whl
This diff shows the content of publicly released package versions from one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- langchain/_api/module_import.py +2 -2
- langchain/agents/__init__.py +4 -3
- langchain/agents/agent.py +40 -23
- langchain/agents/agent_iterator.py +2 -2
- langchain/agents/agent_toolkits/__init__.py +1 -1
- langchain/agents/agent_toolkits/vectorstore/toolkit.py +2 -1
- langchain/agents/agent_types.py +2 -2
- langchain/agents/chat/base.py +1 -1
- langchain/agents/conversational/base.py +1 -1
- langchain/agents/conversational_chat/base.py +1 -1
- langchain/agents/initialize.py +1 -1
- langchain/agents/json_chat/base.py +1 -1
- langchain/agents/loading.py +2 -2
- langchain/agents/mrkl/base.py +3 -3
- langchain/agents/openai_assistant/base.py +11 -15
- langchain/agents/openai_functions_agent/base.py +1 -1
- langchain/agents/openai_functions_multi_agent/base.py +1 -1
- langchain/agents/react/agent.py +1 -1
- langchain/agents/react/base.py +4 -4
- langchain/agents/self_ask_with_search/base.py +2 -2
- langchain/agents/structured_chat/base.py +3 -2
- langchain/agents/tools.py +2 -2
- langchain/agents/xml/base.py +2 -2
- langchain/chains/base.py +5 -5
- langchain/chains/combine_documents/base.py +4 -4
- langchain/chains/combine_documents/stuff.py +1 -1
- langchain/chains/constitutional_ai/base.py +144 -1
- langchain/chains/conversation/base.py +1 -1
- langchain/chains/conversational_retrieval/base.py +1 -1
- langchain/chains/flare/base.py +42 -69
- langchain/chains/hyde/base.py +18 -8
- langchain/chains/llm_math/base.py +118 -1
- langchain/chains/natbot/base.py +24 -10
- langchain/chains/openai_functions/base.py +2 -2
- langchain/chains/openai_functions/extraction.py +2 -2
- langchain/chains/openai_tools/extraction.py +1 -1
- langchain/chains/query_constructor/parser.py +23 -0
- langchain/chains/structured_output/base.py +2 -2
- langchain/retrievers/document_compressors/chain_extract.py +19 -10
- langchain/retrievers/document_compressors/chain_filter.py +27 -10
- langchain/retrievers/document_compressors/cohere_rerank.py +1 -1
- langchain/retrievers/re_phraser.py +7 -7
- langchain/tools/__init__.py +14 -5
- langchain/tools/render.py +0 -2
- langchain/tools/retriever.py +0 -4
- {langchain-0.2.13.dist-info → langchain-0.2.15.dist-info}/METADATA +2 -2
- {langchain-0.2.13.dist-info → langchain-0.2.15.dist-info}/RECORD +50 -50
- {langchain-0.2.13.dist-info → langchain-0.2.15.dist-info}/LICENSE +0 -0
- {langchain-0.2.13.dist-info → langchain-0.2.15.dist-info}/WHEEL +0 -0
- {langchain-0.2.13.dist-info → langchain-0.2.15.dist-info}/entry_points.txt +0 -0
langchain/agents/tools.py
CHANGED
```diff
@@ -6,7 +6,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForToolRun,
     CallbackManagerForToolRun,
 )
-from langchain_core.tools import BaseTool,
+from langchain_core.tools import BaseTool, tool
 
 
 class InvalidTool(BaseTool):
@@ -44,4 +44,4 @@ class InvalidTool(BaseTool):
         )
 
 
-__all__ = ["InvalidTool", "
+__all__ = ["InvalidTool", "tool"]
```
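For context, `tool` is now imported and re-exported here from `langchain_core.tools`. A minimal usage sketch (the `word_count` function is illustrative, not part of the package):

```python
from langchain.agents.tools import tool  # re-exported from langchain_core.tools


@tool
def word_count(text: str) -> int:
    """Count the words in a piece of text."""
    return len(text.split())


# Tools built with the decorator take a dict of arguments via invoke().
print(word_count.invoke({"text": "hello there world"}))  # prints 3
```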
langchain/agents/xml/base.py
CHANGED
```diff
@@ -8,16 +8,16 @@ from langchain_core.prompts.base import BasePromptTemplate
 from langchain_core.prompts.chat import AIMessagePromptTemplate, ChatPromptTemplate
 from langchain_core.runnables import Runnable, RunnablePassthrough
 from langchain_core.tools import BaseTool
+from langchain_core.tools.render import ToolsRenderer, render_text_description
 
 from langchain.agents.agent import BaseSingleActionAgent
 from langchain.agents.format_scratchpad import format_xml
 from langchain.agents.output_parsers import XMLAgentOutputParser
 from langchain.agents.xml.prompt import agent_instructions
 from langchain.chains.llm import LLMChain
-from langchain.tools.render import ToolsRenderer, render_text_description
 
 
-@deprecated("0.1.0", alternative="create_xml_agent", removal="
+@deprecated("0.1.0", alternative="create_xml_agent", removal="1.0")
 class XMLAgent(BaseSingleActionAgent):
     """Agent that uses XML tags.
 
```
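A sketch of the `create_xml_agent` alternative named in the deprecation, assuming `llm`, `tools`, and an XML-style `prompt` (with `tools` and `agent_scratchpad` variables) are defined elsewhere:

```python
from langchain.agents import AgentExecutor, create_xml_agent

agent = create_xml_agent(llm, tools, prompt)  # replaces the deprecated XMLAgent class
agent_executor = AgentExecutor(agent=agent, tools=tools)
agent_executor.invoke({"input": "What are XML agents good for?"})
```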
langchain/chains/base.py
CHANGED
```diff
@@ -334,7 +334,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
             None, self._call, inputs, run_manager.get_sync() if run_manager else None
         )
 
-    @deprecated("0.1.0", alternative="invoke", removal="
+    @deprecated("0.1.0", alternative="invoke", removal="1.0")
     def __call__(
         self,
         inputs: Union[Dict[str, Any], Any],
@@ -385,7 +385,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
             include_run_info=include_run_info,
         )
 
-    @deprecated("0.1.0", alternative="ainvoke", removal="
+    @deprecated("0.1.0", alternative="ainvoke", removal="1.0")
     async def acall(
         self,
         inputs: Union[Dict[str, Any], Any],
@@ -544,7 +544,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         )
         return self.output_keys[0]
 
-    @deprecated("0.1.0", alternative="invoke", removal="
+    @deprecated("0.1.0", alternative="invoke", removal="1.0")
     def run(
         self,
         *args: Any,
@@ -615,7 +615,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
             f" but not both. Got args: {args} and kwargs: {kwargs}."
         )
 
-    @deprecated("0.1.0", alternative="ainvoke", removal="
+    @deprecated("0.1.0", alternative="ainvoke", removal="1.0")
     async def arun(
         self,
         *args: Any,
@@ -753,7 +753,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         else:
             raise ValueError(f"{save_path} must be json or yaml")
 
-    @deprecated("0.1.0", alternative="batch", removal="
+    @deprecated("0.1.0", alternative="batch", removal="1.0")
     def apply(
         self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
     ) -> List[Dict[str, str]]:
```
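All five deprecated entry points on `Chain` now target removal in 1.0. For reference, a migration sketch from the deprecated call styles to their Runnable replacements, assuming a `chain` whose single input key is `"question"`:

```python
# Deprecated since 0.1.0, slated for removal in 1.0:
result = chain.run("What is 2 + 2?")
results = chain.apply([{"question": "What is 2 + 2?"}])

# Runnable replacements:
result = chain.invoke({"question": "What is 2 + 2?"})
results = chain.batch([{"question": "What is 2 + 2?"}])
```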
langchain/chains/combine_documents/base.py
CHANGED
```diff
@@ -22,11 +22,11 @@ DOCUMENTS_KEY = "context"
 DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
 
 
-def _validate_prompt(prompt: BasePromptTemplate) -> None:
-    if
+def _validate_prompt(prompt: BasePromptTemplate, document_variable_name: str) -> None:
+    if document_variable_name not in prompt.input_variables:
         raise ValueError(
-            f"Prompt must accept {
-            f"with input variables: {prompt.input_variables}"
+            f"Prompt must accept {document_variable_name} as an input variable. "
+            f"Received prompt with input variables: {prompt.input_variables}"
         )
 
 
```
langchain/chains/combine_documents/stuff.py
CHANGED
```diff
@@ -76,7 +76,7 @@ def create_stuff_documents_chain(
         chain.invoke({"context": docs})
     """  # noqa: E501
 
-    _validate_prompt(prompt)
+    _validate_prompt(prompt, document_variable_name)
    _document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
     _output_parser = output_parser or StrOutputParser()
 
```
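With `document_variable_name` threaded through, validation now checks that the prompt declares the configured document variable (by default `"context"`). A sketch of the failure mode, assuming a chat model `llm` is defined:

```python
from langchain_core.prompts import ChatPromptTemplate

from langchain.chains.combine_documents import create_stuff_documents_chain

good_prompt = ChatPromptTemplate.from_template("Summarize:\n\n{context}")
chain = create_stuff_documents_chain(llm, good_prompt)  # passes validation

bad_prompt = ChatPromptTemplate.from_template("Summarize:\n\n{documents}")
# Raises ValueError: "Prompt must accept context as an input variable. ..."
chain = create_stuff_documents_chain(llm, bad_prompt)
```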
langchain/chains/constitutional_ai/base.py
CHANGED
```diff
@@ -2,6 +2,7 @@
 
 from typing import Any, Dict, List, Optional
 
+from langchain_core._api import deprecated
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts import BasePromptTemplate
@@ -13,9 +14,151 @@ from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION
 from langchain.chains.llm import LLMChain
 
 
+@deprecated(
+    since="0.2.13",
+    message=(
+        "This class is deprecated and will be removed in langchain 1.0. "
+        "See API reference for replacement: "
+        "https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.base.ConstitutionalChain.html"  # noqa: E501
+    ),
+    removal="1.0",
+)
 class ConstitutionalChain(Chain):
     """Chain for applying constitutional principles.
 
+    Note: this class is deprecated. See below for a replacement implementation
+    using LangGraph. The benefits of this implementation are:
+
+    - Uses LLM tool calling features instead of parsing string responses;
+    - Support for both token-by-token and step-by-step streaming;
+    - Support for checkpointing and memory of chat history;
+    - Easier to modify or extend (e.g., with additional tools, structured responses, etc.)
+
+    Install LangGraph with:
+
+    .. code-block:: bash
+
+        pip install -U langgraph
+
+    .. code-block:: python
+
+        from typing import List, Optional, Tuple
+
+        from langchain.chains.constitutional_ai.prompts import (
+            CRITIQUE_PROMPT,
+            REVISION_PROMPT,
+        )
+        from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
+        from langchain_core.output_parsers import StrOutputParser
+        from langchain_core.prompts import ChatPromptTemplate
+        from langchain_openai import ChatOpenAI
+        from langgraph.graph import END, START, StateGraph
+        from typing_extensions import Annotated, TypedDict
+
+        llm = ChatOpenAI(model="gpt-4o-mini")
+
+        class Critique(TypedDict):
+            \"\"\"Generate a critique, if needed.\"\"\"
+            critique_needed: Annotated[bool, ..., "Whether or not a critique is needed."]
+            critique: Annotated[str, ..., "If needed, the critique."]
+
+        critique_prompt = ChatPromptTemplate.from_template(
+            "Critique this response according to the critique request. "
+            "If no critique is needed, specify that.\\n\\n"
+            "Query: {query}\\n\\n"
+            "Response: {response}\\n\\n"
+            "Critique request: {critique_request}"
+        )
+
+        revision_prompt = ChatPromptTemplate.from_template(
+            "Revise this response according to the critique and reivsion request.\\n\\n"
+            "Query: {query}\\n\\n"
+            "Response: {response}\\n\\n"
+            "Critique request: {critique_request}\\n\\n"
+            "Critique: {critique}\\n\\n"
+            "If the critique does not identify anything worth changing, ignore the "
+            "revision request and return 'No revisions needed'. If the critique "
+            "does identify something worth changing, revise the response based on "
+            "the revision request.\\n\\n"
+            "Revision Request: {revision_request}"
+        )
+
+        chain = llm | StrOutputParser()
+        critique_chain = critique_prompt | llm.with_structured_output(Critique)
+        revision_chain = revision_prompt | llm | StrOutputParser()
+
+
+        class State(TypedDict):
+            query: str
+            constitutional_principles: List[ConstitutionalPrinciple]
+            initial_response: str
+            critiques_and_revisions: List[Tuple[str, str]]
+            response: str
+
+
+        async def generate_response(state: State):
+            \"\"\"Generate initial response.\"\"\"
+            response = await chain.ainvoke(state["query"])
+            return {"response": response, "initial_response": response}
+
+        async def critique_and_revise(state: State):
+            \"\"\"Critique and revise response according to principles.\"\"\"
+            critiques_and_revisions = []
+            response = state["initial_response"]
+            for principle in state["constitutional_principles"]:
+                critique = await critique_chain.ainvoke(
+                    {
+                        "query": state["query"],
+                        "response": response,
+                        "critique_request": principle.critique_request,
+                    }
+                )
+                if critique["critique_needed"]:
+                    revision = await revision_chain.ainvoke(
+                        {
+                            "query": state["query"],
+                            "response": response,
+                            "critique_request": principle.critique_request,
+                            "critique": critique["critique"],
+                            "revision_request": principle.revision_request,
+                        }
+                    )
+                    response = revision
+                    critiques_and_revisions.append((critique["critique"], revision))
+                else:
+                    critiques_and_revisions.append((critique["critique"], ""))
+            return {
+                "critiques_and_revisions": critiques_and_revisions,
+                "response": response,
+            }
+
+        graph = StateGraph(State)
+        graph.add_node("generate_response", generate_response)
+        graph.add_node("critique_and_revise", critique_and_revise)
+
+        graph.add_edge(START, "generate_response")
+        graph.add_edge("generate_response", "critique_and_revise")
+        graph.add_edge("critique_and_revise", END)
+        app = graph.compile()
+
+    .. code-block:: python
+
+        constitutional_principles=[
+            ConstitutionalPrinciple(
+                critique_request="Tell if this answer is good.",
+                revision_request="Give a better answer.",
+            )
+        ]
+
+        query = "What is the meaning of life? Answer in 10 words or fewer."
+
+        async for step in app.astream(
+            {"query": query, "constitutional_principles": constitutional_principles},
+            stream_mode="values",
+        ):
+            subset = ["initial_response", "critiques_and_revisions", "response"]
+            print({k: v for k, v in step.items() if k in subset})
+
     Example:
         .. code-block:: python
 
@@ -44,7 +187,7 @@ class ConstitutionalChain(Chain):
         )
 
         constitutional_chain.run(question="What is the meaning of life?")
-    """
+    """  # noqa: E501
 
     chain: LLMChain
     constitutional_principles: List[ConstitutionalPrinciple]
```
langchain/chains/conversation/base.py
CHANGED
```diff
@@ -16,7 +16,7 @@ from langchain.memory.buffer import ConversationBufferMemory
     since="0.2.7",
     alternative=(
         "RunnableWithMessageHistory: "
-        "https://
+        "https://python.langchain.com/v0.2/api_reference/core/runnables/langchain_core.runnables.history.RunnableWithMessageHistory.html"  # noqa: E501
     ),
     removal="1.0",
 )
```
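The deprecation message now carries the full `RunnableWithMessageHistory` API reference URL. A minimal sketch of that alternative, assuming a chat model `llm` is defined:

```python
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import HumanMessage
from langchain_core.runnables.history import RunnableWithMessageHistory

store = {}


def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
    # Keep one in-memory history per session id.
    if session_id not in store:
        store[session_id] = InMemoryChatMessageHistory()
    return store[session_id]


chat = RunnableWithMessageHistory(llm, get_session_history)
chat.invoke(
    [HumanMessage(content="Hi, I'm Bob.")],
    config={"configurable": {"session_id": "abc"}},
)
```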
langchain/chains/conversational_retrieval/base.py
CHANGED
```diff
@@ -242,7 +242,7 @@ class BaseConversationalRetrievalChain(Chain):
         "create_history_aware_retriever together with create_retrieval_chain "
         "(see example in docstring)"
     ),
-    removal="
+    removal="1.0",
 )
 class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
     """Chain for having a conversation based on retrieved documents.
```
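A sketch of the replacement named in the message, assuming `llm`, a `retriever`, and two prompts (`rephrase_prompt` with `input` and `chat_history` variables, `qa_prompt` with a `context` variable) are defined elsewhere:

```python
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

history_aware_retriever = create_history_aware_retriever(llm, retriever, rephrase_prompt)
combine_docs_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, combine_docs_chain)

rag_chain.invoke({"input": "What did we discuss?", "chat_history": []})
```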
langchain/chains/flare/base.py
CHANGED
```diff
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import re
-from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import numpy as np
@@ -9,10 +8,12 @@ from langchain_core.callbacks import (
     CallbackManagerForChainRun,
 )
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.
+from langchain_core.messages import AIMessage
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import BasePromptTemplate
 from langchain_core.pydantic_v1 import Field
 from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import Runnable
 
 from langchain.chains.base import Chain
 from langchain.chains.flare.prompts import (
@@ -23,51 +24,14 @@ from langchain.chains.flare.prompts import (
 from langchain.chains.llm import LLMChain
 
 
-
-"""
-
-
-
-
-
-
-
-    @property
-    def input_keys(self) -> List[str]:
-        return self.prompt.input_variables
-
-    def generate_tokens_and_log_probs(
-        self,
-        _input: Dict[str, Any],
-        *,
-        run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Tuple[Sequence[str], Sequence[float]]:
-        llm_result = self.generate([_input], run_manager=run_manager)
-        return self._extract_tokens_and_log_probs(llm_result.generations[0])
-
-    @abstractmethod
-    def _extract_tokens_and_log_probs(
-        self, generations: List[Generation]
-    ) -> Tuple[Sequence[str], Sequence[float]]:
-        """Extract tokens and log probs from response."""
-
-
-class _OpenAIResponseChain(_ResponseChain):
-    """Chain that generates responses from user input and context."""
-
-    llm: BaseLanguageModel
-
-    def _extract_tokens_and_log_probs(
-        self, generations: List[Generation]
-    ) -> Tuple[Sequence[str], Sequence[float]]:
-        tokens = []
-        log_probs = []
-        for gen in generations:
-            if gen.generation_info is None:
-                raise ValueError
-            tokens.extend(gen.generation_info["logprobs"]["tokens"])
-            log_probs.extend(gen.generation_info["logprobs"]["token_logprobs"])
-        return tokens, log_probs
+def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]:
+    """Extract tokens and log probabilities from chat model response."""
+    tokens = []
+    log_probs = []
+    for token in response.response_metadata["logprobs"]["content"]:
+        tokens.append(token["token"])
+        log_probs.append(token["logprob"])
+    return tokens, log_probs
 
 
 class QuestionGeneratorChain(LLMChain):
@@ -111,9 +75,9 @@
     """Chain that combines a retriever, a question generator,
     and a response generator."""
 
-    question_generator_chain:
+    question_generator_chain: Runnable
     """Chain that generates questions from uncertain spans."""
-    response_chain:
+    response_chain: Runnable
     """Chain that generates responses from user input and context."""
     output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
     """Parser that determines whether the chain is finished."""
@@ -152,12 +116,16 @@
         for question in questions:
             docs.extend(self.retriever.invoke(question))
         context = "\n\n".join(d.page_content for d in docs)
-        result = self.response_chain.
-
-
-
+        result = self.response_chain.invoke(
+            {
+                "user_input": user_input,
+                "context": context,
+                "response": response,
+            },
+            {"callbacks": callbacks},
         )
+        if isinstance(result, AIMessage):
+            result = result.content
         marginal, finished = self.output_parser.parse(result)
         return marginal, finished
 
@@ -178,13 +146,18 @@
             for span in low_confidence_spans
         ]
         callbacks = _run_manager.get_child()
-
-
-
-
-
-
-
+        if isinstance(self.question_generator_chain, LLMChain):
+            question_gen_outputs = self.question_generator_chain.apply(
+                question_gen_inputs, callbacks=callbacks
+            )
+            questions = [
+                output[self.question_generator_chain.output_keys[0]]
+                for output in question_gen_outputs
+            ]
+        else:
+            questions = self.question_generator_chain.batch(
+                question_gen_inputs, config={"callbacks": callbacks}
+            )
         _run_manager.on_text(
             f"Generated Questions: {questions}", color="yellow", end="\n"
         )
@@ -206,8 +179,10 @@
                 f"Current Response: {response}", color="blue", end="\n"
             )
             _input = {"user_input": user_input, "context": "", "response": response}
-            tokens, log_probs =
-
+            tokens, log_probs = _extract_tokens_and_log_probs(
+                self.response_chain.invoke(
+                    _input, {"callbacks": _run_manager.get_child()}
+                )
             )
             low_confidence_spans = _low_confidence_spans(
                 tokens,
@@ -251,18 +226,16 @@
         FlareChain class with the given language model.
         """
         try:
-            from langchain_openai import
+            from langchain_openai import ChatOpenAI
         except ImportError:
             raise ImportError(
                 "OpenAI is required for FlareChain. "
                 "Please install langchain-openai."
                 "pip install langchain-openai"
             )
-
-
-
-        )
-        response_chain = _OpenAIResponseChain(llm=response_llm)
+        llm = ChatOpenAI(max_tokens=max_generation_len, logprobs=True, temperature=0)
+        response_chain = PROMPT | llm
+        question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
         return cls(
             question_generator_chain=question_gen_chain,
             response_chain=response_chain,
```
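The rewrite swaps the abstract `_ResponseChain`/`_OpenAIResponseChain` hierarchy for a module-level helper that reads token logprobs out of a chat model's `response_metadata`. A sketch of the metadata shape it relies on, assuming `langchain-openai` is installed and an OpenAI API key is configured:

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", logprobs=True)
msg = llm.invoke("Say hello")

# response_metadata["logprobs"]["content"] is a list of per-token entries.
for entry in msg.response_metadata["logprobs"]["content"]:
    print(entry["token"], entry["logprob"])
```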
langchain/chains/hyde/base.py
CHANGED
```diff
@@ -11,7 +11,9 @@ import numpy as np
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import BasePromptTemplate
+from langchain_core.runnables import Runnable
 
 from langchain.chains.base import Chain
 from langchain.chains.hyde.prompts import PROMPT_MAP
@@ -25,7 +27,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     """
 
     base_embeddings: Embeddings
-    llm_chain:
+    llm_chain: Runnable
 
     class Config:
         arbitrary_types_allowed = True
@@ -34,12 +36,15 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     @property
     def input_keys(self) -> List[str]:
         """Input keys for Hyde's LLM chain."""
-        return self.llm_chain.
+        return self.llm_chain.input_schema.schema()["required"]
 
     @property
     def output_keys(self) -> List[str]:
         """Output keys for Hyde's LLM chain."""
-
+        if isinstance(self.llm_chain, LLMChain):
+            return self.llm_chain.output_keys
+        else:
+            return ["text"]
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Call the base embeddings."""
@@ -51,9 +56,12 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
 
     def embed_query(self, text: str) -> List[float]:
         """Generate a hypothetical document and embedded it."""
-        var_name = self.
-        result = self.llm_chain.
-
+        var_name = self.input_keys[0]
+        result = self.llm_chain.invoke({var_name: text})
+        if isinstance(self.llm_chain, LLMChain):
+            documents = [result[self.output_keys[0]]]
+        else:
+            documents = [result]
         embeddings = self.embed_documents(documents)
         return self.combine_embeddings(embeddings)
 
@@ -64,7 +72,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     ) -> Dict[str, str]:
         """Call the internal llm chain."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
-        return self.llm_chain
+        return self.llm_chain.invoke(
+            inputs, config={"callbacks": _run_manager.get_child()}
+        )
 
     @classmethod
     def from_llm(
@@ -86,7 +96,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
             f"of {list(PROMPT_MAP.keys())}."
         )
 
-        llm_chain =
+        llm_chain = prompt | llm | StrOutputParser()
         return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)
 
     @property
```
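`llm_chain` is now a generic `Runnable` (with `LLMChain` still accepted for backward compatibility), and `from_llm` builds an LCEL pipeline (`prompt | llm | StrOutputParser()`). A usage sketch, assuming OpenAI credentials are configured:

```python
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from langchain.chains.hyde.base import HypotheticalDocumentEmbedder

embedder = HypotheticalDocumentEmbedder.from_llm(
    llm=ChatOpenAI(),
    base_embeddings=OpenAIEmbeddings(),
    prompt_key="web_search",  # one of the keys in PROMPT_MAP
)
vector = embedder.embed_query("What is the tallest mountain in the world?")
```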