langchain 0.2.14__py3-none-any.whl → 0.2.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/agents/agent.py +38 -21
- langchain/agents/agent_iterator.py +2 -2
- langchain/agents/agent_types.py +1 -1
- langchain/agents/openai_assistant/base.py +11 -15
- langchain/chains/combine_documents/base.py +4 -4
- langchain/chains/combine_documents/stuff.py +1 -1
- langchain/chains/constitutional_ai/base.py +144 -1
- langchain/chains/conversation/base.py +1 -1
- langchain/chains/flare/base.py +42 -69
- langchain/chains/hyde/base.py +18 -8
- langchain/chains/llm_math/base.py +118 -1
- langchain/chains/natbot/base.py +24 -10
- langchain/chains/query_constructor/parser.py +23 -0
- langchain/retrievers/document_compressors/chain_extract.py +19 -10
- langchain/retrievers/document_compressors/chain_filter.py +27 -10
- langchain/retrievers/re_phraser.py +7 -7
- {langchain-0.2.14.dist-info → langchain-0.2.15.dist-info}/METADATA +2 -2
- {langchain-0.2.14.dist-info → langchain-0.2.15.dist-info}/RECORD +21 -21
- {langchain-0.2.14.dist-info → langchain-0.2.15.dist-info}/LICENSE +0 -0
- {langchain-0.2.14.dist-info → langchain-0.2.15.dist-info}/WHEEL +0 -0
- {langchain-0.2.14.dist-info → langchain-0.2.15.dist-info}/entry_points.txt +0 -0
langchain/chains/flare/base.py
CHANGED
```diff
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import re
-from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import numpy as np
@@ -9,10 +8,12 @@ from langchain_core.callbacks import (
     CallbackManagerForChainRun,
 )
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.outputs import Generation
+from langchain_core.messages import AIMessage
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import BasePromptTemplate
 from langchain_core.pydantic_v1 import Field
 from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import Runnable
 
 from langchain.chains.base import Chain
 from langchain.chains.flare.prompts import (
@@ -23,51 +24,14 @@ from langchain.chains.flare.prompts import (
 from langchain.chains.llm import LLMChain
 
 
-class _ResponseChain(LLMChain):
-    """Base class for chains that generate responses."""
-
-    prompt: BasePromptTemplate = PROMPT
-
-    @classmethod
-    def is_lc_serializable(cls) -> bool:
-        return False
-
-    @property
-    def input_keys(self) -> List[str]:
-        return self.prompt.input_variables
-
-    def generate_tokens_and_log_probs(
-        self,
-        _input: Dict[str, Any],
-        *,
-        run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Tuple[Sequence[str], Sequence[float]]:
-        llm_result = self.generate([_input], run_manager=run_manager)
-        return self._extract_tokens_and_log_probs(llm_result.generations[0])
-
-    @abstractmethod
-    def _extract_tokens_and_log_probs(
-        self, generations: List[Generation]
-    ) -> Tuple[Sequence[str], Sequence[float]]:
-        """Extract tokens and log probs from response."""
-
-
-class _OpenAIResponseChain(_ResponseChain):
-    """Chain that generates responses from user input and context."""
-
-    llm: BaseLanguageModel
-
-    def _extract_tokens_and_log_probs(
-        self, generations: List[Generation]
-    ) -> Tuple[Sequence[str], Sequence[float]]:
-        tokens = []
-        log_probs = []
-        for gen in generations:
-            if gen.generation_info is None:
-                raise ValueError
-            tokens.extend(gen.generation_info["logprobs"]["tokens"])
-            log_probs.extend(gen.generation_info["logprobs"]["token_logprobs"])
-        return tokens, log_probs
+def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]:
+    """Extract tokens and log probabilities from chat model response."""
+    tokens = []
+    log_probs = []
+    for token in response.response_metadata["logprobs"]["content"]:
+        tokens.append(token["token"])
+        log_probs.append(token["logprob"])
+    return tokens, log_probs
 
 
 class QuestionGeneratorChain(LLMChain):
@@ -111,9 +75,9 @@ class FlareChain(Chain):
     """Chain that combines a retriever, a question generator,
     and a response generator."""
 
-    question_generator_chain: QuestionGeneratorChain
+    question_generator_chain: Runnable
     """Chain that generates questions from uncertain spans."""
-    response_chain: _ResponseChain
+    response_chain: Runnable
     """Chain that generates responses from user input and context."""
     output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
     """Parser that determines whether the chain is finished."""
@@ -152,12 +116,16 @@ class FlareChain(Chain):
         for question in questions:
            docs.extend(self.retriever.invoke(question))
         context = "\n\n".join(d.page_content for d in docs)
-        result = self.response_chain.predict(
-            user_input=user_input,
-            context=context,
-            response=response,
-            callbacks=callbacks,
+        result = self.response_chain.invoke(
+            {
+                "user_input": user_input,
+                "context": context,
+                "response": response,
+            },
+            {"callbacks": callbacks},
         )
+        if isinstance(result, AIMessage):
+            result = result.content
         marginal, finished = self.output_parser.parse(result)
         return marginal, finished
 
@@ -178,13 +146,18 @@ class FlareChain(Chain):
             for span in low_confidence_spans
         ]
         callbacks = _run_manager.get_child()
-        question_gen_outputs = self.question_generator_chain.apply(
-            question_gen_inputs, callbacks=callbacks
-        )
-        questions = [
-            output[self.question_generator_chain.output_keys[0]]
-            for output in question_gen_outputs
-        ]
+        if isinstance(self.question_generator_chain, LLMChain):
+            question_gen_outputs = self.question_generator_chain.apply(
+                question_gen_inputs, callbacks=callbacks
+            )
+            questions = [
+                output[self.question_generator_chain.output_keys[0]]
+                for output in question_gen_outputs
+            ]
+        else:
+            questions = self.question_generator_chain.batch(
+                question_gen_inputs, config={"callbacks": callbacks}
+            )
         _run_manager.on_text(
             f"Generated Questions: {questions}", color="yellow", end="\n"
         )
@@ -206,8 +179,10 @@ class FlareChain(Chain):
             f"Current Response: {response}", color="blue", end="\n"
         )
         _input = {"user_input": user_input, "context": "", "response": response}
-        tokens, log_probs = self.response_chain.generate_tokens_and_log_probs(
-            _input, run_manager=_run_manager
+        tokens, log_probs = _extract_tokens_and_log_probs(
+            self.response_chain.invoke(
+                _input, {"callbacks": _run_manager.get_child()}
+            )
         )
         low_confidence_spans = _low_confidence_spans(
             tokens,
@@ -251,18 +226,16 @@
         FlareChain class with the given language model.
         """
         try:
-            from langchain_openai import OpenAI
+            from langchain_openai import ChatOpenAI
         except ImportError:
             raise ImportError(
                 "OpenAI is required for FlareChain. "
                 "Please install langchain-openai."
                 "pip install langchain-openai"
             )
-        question_gen_chain = QuestionGeneratorChain(llm=llm)
-        response_llm = OpenAI(
-            max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
-        )
-        response_chain = _OpenAIResponseChain(llm=response_llm)
+        llm = ChatOpenAI(max_tokens=max_generation_len, logprobs=True, temperature=0)
+        response_chain = PROMPT | llm
+        question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
         return cls(
             question_generator_chain=question_gen_chain,
             response_chain=response_chain,
```
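Net effect: `FlareChain`'s response chain is now a plain LCEL pipeline (`PROMPT | ChatOpenAI(logprobs=True, ...)`), and token log-probabilities are read from the chat message's `response_metadata` rather than from `Generation.generation_info`. A minimal sketch of the new helper's contract, using a hand-built `AIMessage` (the token/logprob values are made up for illustration; the private helper is imported only to demonstrate the shape it expects):

```python
from langchain_core.messages import AIMessage

from langchain.chains.flare.base import _extract_tokens_and_log_probs

# Mimic the metadata that ChatOpenAI(logprobs=True) attaches to a response.
response = AIMessage(
    content="The capital of France is Paris.",
    response_metadata={
        "logprobs": {
            "content": [
                {"token": "The", "logprob": -0.01},
                {"token": " capital", "logprob": -0.23},
            ]
        }
    },
)

tokens, log_probs = _extract_tokens_and_log_probs(response)
print(tokens)     # ['The', ' capital']
print(log_probs)  # [-0.01, -0.23]
```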
langchain/chains/hyde/base.py
CHANGED
```diff
@@ -11,7 +11,9 @@ import numpy as np
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import BasePromptTemplate
+from langchain_core.runnables import Runnable
 
 from langchain.chains.base import Chain
 from langchain.chains.hyde.prompts import PROMPT_MAP
@@ -25,7 +27,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     """
 
     base_embeddings: Embeddings
-    llm_chain: LLMChain
+    llm_chain: Runnable
 
     class Config:
         arbitrary_types_allowed = True
@@ -34,12 +36,15 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     @property
     def input_keys(self) -> List[str]:
         """Input keys for Hyde's LLM chain."""
-        return self.llm_chain.input_keys
+        return self.llm_chain.input_schema.schema()["required"]
 
     @property
     def output_keys(self) -> List[str]:
         """Output keys for Hyde's LLM chain."""
-        return self.llm_chain.output_keys
+        if isinstance(self.llm_chain, LLMChain):
+            return self.llm_chain.output_keys
+        else:
+            return ["text"]
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Call the base embeddings."""
@@ -51,9 +56,12 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
 
     def embed_query(self, text: str) -> List[float]:
         """Generate a hypothetical document and embedded it."""
-        var_name = self.llm_chain.input_keys[0]
-        result = self.llm_chain.generate([{var_name: text}])
-        documents = [generation.text for generation in result.generations[0]]
+        var_name = self.input_keys[0]
+        result = self.llm_chain.invoke({var_name: text})
+        if isinstance(self.llm_chain, LLMChain):
+            documents = [result[self.output_keys[0]]]
+        else:
+            documents = [result]
         embeddings = self.embed_documents(documents)
         return self.combine_embeddings(embeddings)
 
@@ -64,7 +72,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     ) -> Dict[str, str]:
         """Call the internal llm chain."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
-        return self.llm_chain(inputs, callbacks=_run_manager.get_child())
+        return self.llm_chain.invoke(
+            inputs, config={"callbacks": _run_manager.get_child()}
+        )
 
     @classmethod
     def from_llm(
@@ -86,7 +96,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
                 f"of {list(PROMPT_MAP.keys())}."
             )
 
-        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        llm_chain = prompt | llm | StrOutputParser()
         return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)
 
     @property
```
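Because `llm_chain` is now typed as `Runnable`, an LCEL pipeline can be passed in directly, while existing `LLMChain` instances keep working through the `isinstance` branches above. A minimal sketch using the fake models shipped with `langchain_core` (prompt wording, responses, and embedding size are arbitrary):

```python
from langchain_core.embeddings import FakeEmbeddings
from langchain_core.language_models import FakeListLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

from langchain.chains.hyde.base import HypotheticalDocumentEmbedder

prompt = PromptTemplate.from_template("Write a passage answering: {question}")
llm = FakeListLLM(responses=["A hypothetical passage."])
llm_chain = prompt | llm | StrOutputParser()

embedder = HypotheticalDocumentEmbedder(
    base_embeddings=FakeEmbeddings(size=8),
    llm_chain=llm_chain,
)
# input_keys is derived from the runnable's input schema -> ["question"]
vector = embedder.embed_query("What is HyDE?")  # embeds the generated passage
```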
langchain/chains/llm_math/base.py
CHANGED

```diff
@@ -7,6 +7,7 @@ import re
 import warnings
 from typing import Any, Dict, List, Optional
 
+from langchain_core._api import deprecated
 from langchain_core.callbacks import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
@@ -20,16 +21,132 @@ from langchain.chains.llm import LLMChain
 from langchain.chains.llm_math.prompt import PROMPT
 
 
+@deprecated(
+    since="0.2.13",
+    message=(
+        "This class is deprecated and will be removed in langchain 1.0. "
+        "See API reference for replacement: "
+        "https://api.python.langchain.com/en/latest/chains/langchain.chains.llm_math.base.LLMMathChain.html"  # noqa: E501
+    ),
+    removal="1.0",
+)
 class LLMMathChain(Chain):
     """Chain that interprets a prompt and executes python code to do math.
 
+    Note: this class is deprecated. See below for a replacement implementation
+    using LangGraph. The benefits of this implementation are:
+
+    - Uses LLM tool calling features;
+    - Support for both token-by-token and step-by-step streaming;
+    - Support for checkpointing and memory of chat history;
+    - Easier to modify or extend (e.g., with additional tools, structured responses, etc.)
+
+    Install LangGraph with:
+
+    .. code-block:: bash
+
+        pip install -U langgraph
+
+    .. code-block:: python
+
+        import math
+        from typing import Annotated, Sequence
+
+        from langchain_core.messages import BaseMessage
+        from langchain_core.runnables import RunnableConfig
+        from langchain_core.tools import tool
+        from langchain_openai import ChatOpenAI
+        from langgraph.graph import END, StateGraph
+        from langgraph.graph.message import add_messages
+        from langgraph.prebuilt.tool_node import ToolNode
+        import numexpr
+        from typing_extensions import TypedDict
+
+        @tool
+        def calculator(expression: str) -> str:
+            \"\"\"Calculate expression using Python's numexpr library.
+
+            Expression should be a single line mathematical expression
+            that solves the problem.
+
+            Examples:
+                "37593 * 67" for "37593 times 67"
+                "37593**(1/5)" for "37593^(1/5)"
+            \"\"\"
+            local_dict = {"pi": math.pi, "e": math.e}
+            return str(
+                numexpr.evaluate(
+                    expression.strip(),
+                    global_dict={},  # restrict access to globals
+                    local_dict=local_dict,  # add common mathematical functions
+                )
+            )
+
+        llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
+        tools = [calculator]
+        llm_with_tools = llm.bind_tools(tools, tool_choice="any")
+
+        class ChainState(TypedDict):
+            \"\"\"LangGraph state.\"\"\"
+
+            messages: Annotated[Sequence[BaseMessage], add_messages]
+
+        async def acall_chain(state: ChainState, config: RunnableConfig):
+            last_message = state["messages"][-1]
+            response = await llm_with_tools.ainvoke(state["messages"], config)
+            return {"messages": [response]}
+
+        async def acall_model(state: ChainState, config: RunnableConfig):
+            response = await llm.ainvoke(state["messages"], config)
+            return {"messages": [response]}
+
+        graph_builder = StateGraph(ChainState)
+        graph_builder.add_node("call_tool", acall_chain)
+        graph_builder.add_node("execute_tool", ToolNode(tools))
+        graph_builder.add_node("call_model", acall_model)
+        graph_builder.set_entry_point("call_tool")
+        graph_builder.add_edge("call_tool", "execute_tool")
+        graph_builder.add_edge("execute_tool", "call_model")
+        graph_builder.add_edge("call_model", END)
+        chain = graph_builder.compile()
+
+    .. code-block:: python
+
+        example_query = "What is 551368 divided by 82"
+
+        events = chain.astream(
+            {"messages": [("user", example_query)]},
+            stream_mode="values",
+        )
+        async for event in events:
+            event["messages"][-1].pretty_print()
+
+    .. code-block:: none
+
+        ================================ Human Message =================================
+
+        What is 551368 divided by 82
+        ================================== Ai Message ==================================
+        Tool Calls:
+          calculator (call_MEiGXuJjJ7wGU4aOT86QuGJS)
+         Call ID: call_MEiGXuJjJ7wGU4aOT86QuGJS
+          Args:
+            expression: 551368 / 82
+        ================================= Tool Message =================================
+        Name: calculator
+
+        6724.0
+        ================================== Ai Message ==================================
+
+        551368 divided by 82 equals 6724.
+
     Example:
         .. code-block:: python
 
             from langchain.chains import LLMMathChain
             from langchain_community.llms import OpenAI
             llm_math = LLMMathChain.from_llm(OpenAI())
-    """
+    """  # noqa: E501
 
     llm_chain: LLMChain
     llm: Optional[BaseLanguageModel] = None
```
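Apart from the decorator and the expanded docstring, the class body is unchanged; constructing the chain now emits a `LangChainDeprecationWarning`. A small sketch of what callers will observe (using `FakeListLLM` so no API key is needed):

```python
import warnings

from langchain_core.language_models import FakeListLLM

from langchain.chains.llm_math.base import LLMMathChain

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    chain = LLMMathChain.from_llm(FakeListLLM(responses=["Answer: 4"]))

# The deprecation machinery fires on construction, ahead of the 1.0 removal.
print(caught[0].category.__name__)  # LangChainDeprecationWarning
```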
langchain/chains/natbot/base.py
CHANGED
```diff
@@ -5,15 +5,27 @@ from __future__ import annotations
 import warnings
 from typing import Any, Dict, List, Optional
 
+from langchain_core._api import deprecated
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.pydantic_v1 import root_validator
+from langchain_core.runnables import Runnable
 
 from langchain.chains.base import Chain
-from langchain.chains.llm import LLMChain
 from langchain.chains.natbot.prompt import PROMPT
 
 
+@deprecated(
+    since="0.2.13",
+    message=(
+        "Importing NatBotChain from langchain is deprecated and will be removed in "
+        "langchain 1.0. Please import from langchain_community instead: "
+        "from langchain_community.chains.natbot import NatBotChain. "
+        "You may need to pip install -U langchain-community."
+    ),
+    removal="1.0",
+)
 class NatBotChain(Chain):
     """Implement an LLM driven browser.
 
@@ -37,7 +49,7 @@ class NatBotChain(Chain):
         natbot = NatBotChain.from_default("Buy me a new hat.")
     """
 
-    llm_chain: LLMChain
+    llm_chain: Runnable
     objective: str
     """Objective that NatBot is tasked with completing."""
     llm: Optional[BaseLanguageModel] = None
@@ -60,7 +72,7 @@ class NatBotChain(Chain):
                 "class method."
             )
         if "llm_chain" not in values and values["llm"] is not None:
-            values["llm_chain"] = LLMChain(llm=values["llm"], prompt=PROMPT)
+            values["llm_chain"] = PROMPT | values["llm"] | StrOutputParser()
         return values
 
     @classmethod
@@ -77,7 +89,7 @@ class NatBotChain(Chain):
         cls, llm: BaseLanguageModel, objective: str, **kwargs: Any
     ) -> NatBotChain:
         """Load from LLM."""
-        llm_chain = LLMChain(llm=llm, prompt=PROMPT)
+        llm_chain = PROMPT | llm | StrOutputParser()
         return cls(llm_chain=llm_chain, objective=objective, **kwargs)
 
     @property
@@ -104,12 +116,14 @@ class NatBotChain(Chain):
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         url = inputs[self.input_url_key]
         browser_content = inputs[self.input_browser_content_key]
-        llm_cmd = self.llm_chain.predict(
-            objective=self.objective,
-            url=url[:100],
-            previous_command=self.previous_command,
-            browser_content=browser_content[:4500],
-            callbacks=_run_manager.get_child(),
+        llm_cmd = self.llm_chain.invoke(
+            {
+                "objective": self.objective,
+                "url": url[:100],
+                "previous_command": self.previous_command,
+                "browser_content": browser_content[:4500],
+            },
+            config={"callbacks": _run_manager.get_child()},
         )
         llm_cmd = llm_cmd.strip()
         self.previous_command = llm_cmd
```
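`NatBotChain` now drives a `PROMPT | llm | StrOutputParser()` pipeline, so the chain's `invoke` yields the browser command string directly. A sketch with a canned fake LLM (the command and page snippet are invented):

```python
from langchain_core.language_models import FakeListLLM

from langchain.chains.natbot.base import NatBotChain

llm = FakeListLLM(responses=["CLICK 42"])
natbot = NatBotChain.from_llm(llm, objective="Buy me a new hat.")

# execute() feeds url/browser_content through the LCEL pipeline shown above.
command = natbot.execute("https://example.com", "<button id=42>Buy</button>")
print(command)  # CLICK 42
```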
langchain/chains/query_constructor/parser.py
CHANGED

```diff
@@ -35,6 +35,7 @@ GRAMMAR = r"""
     ?value: SIGNED_INT -> int
         | SIGNED_FLOAT -> float
         | DATE -> date
+        | DATETIME -> datetime
         | list
         | string
         | ("false" | "False" | "FALSE") -> false
@@ -42,6 +43,7 @@ GRAMMAR = r"""
 
     args: expr ("," expr)*
     DATE.2: /["']?(\d{4}-[01]\d-[0-3]\d)["']?/
+    DATETIME.2: /["']?\d{4}-[01]\d-[0-3]\dT[0-2]\d:[0-5]\d:[0-5]\d[Zz]?["']?/
     string: /'[^']*'/ | ESCAPED_STRING
     list: "[" [args] "]"
 
@@ -61,6 +63,13 @@ class ISO8601Date(TypedDict):
     type: Literal["date"]
 
 
+class ISO8601DateTime(TypedDict):
+    """A datetime in ISO 8601 format (YYYY-MM-DDTHH:MM:SS)."""
+
+    datetime: str
+    type: Literal["datetime"]
+
+
 @v_args(inline=True)
 class QueryTransformer(Transformer):
     """Transform a query string into an intermediate representation."""
@@ -149,6 +158,20 @@ class QueryTransformer(Transformer):
         )
         return {"date": item, "type": "date"}
 
+    def datetime(self, item: Any) -> ISO8601DateTime:
+        item = str(item).strip("\"'")
+        try:
+            # Parse full ISO 8601 datetime format
+            datetime.datetime.strptime(item, "%Y-%m-%dT%H:%M:%S%z")
+        except ValueError:
+            try:
+                datetime.datetime.strptime(item, "%Y-%m-%dT%H:%M:%S")
+            except ValueError:
+                raise ValueError(
+                    "Datetime values are expected to be in ISO 8601 format."
+                )
+        return {"datetime": item, "type": "datetime"}
+
     def string(self, item: Any) -> str:
         # Remove escaped quotes
         return str(item).strip("\"'")
```
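With the new `DATETIME` terminal and transformer rule, filter expressions can compare against full ISO 8601 timestamps rather than bare dates. A sketch using the module's `get_parser` helper (attribute name and timestamp are invented; the printed repr is approximate):

```python
from langchain.chains.query_constructor.parser import get_parser

parser = get_parser()
comparison = parser.parse('gt("last_updated", "2024-08-01T12:00:00")')
print(comparison)
# Roughly: Comparison(comparator=<Comparator.GT: 'gt'>, attribute='last_updated',
#          value={'datetime': '2024-08-01T12:00:00', 'type': 'datetime'})
```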
langchain/retrievers/document_compressors/chain_extract.py
CHANGED

```diff
@@ -8,8 +8,9 @@ from typing import Any, Callable, Dict, Optional, Sequence, cast
 from langchain_core.callbacks.manager import Callbacks
 from langchain_core.documents import Document
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.output_parsers import BaseOutputParser
+from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
 from langchain_core.prompts import PromptTemplate
+from langchain_core.runnables import Runnable
 
 from langchain.chains.llm import LLMChain
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
@@ -49,12 +50,15 @@ class LLMChainExtractor(BaseDocumentCompressor):
     """Document compressor that uses an LLM chain to extract
     the relevant parts of documents."""
 
-    llm_chain: LLMChain
+    llm_chain: Runnable
     """LLM wrapper to use for compressing documents."""
 
     get_input: Callable[[str, Document], dict] = default_get_input
     """Callable for constructing the chain input from the query and a Document."""
 
+    class Config:
+        arbitrary_types_allowed = True
+
     def compress_documents(
         self,
         documents: Sequence[Document],
@@ -65,10 +69,13 @@ class LLMChainExtractor(BaseDocumentCompressor):
         compressed_docs = []
         for doc in documents:
             _input = self.get_input(query, doc)
-            output_ = self.llm_chain.predict_and_parse(**_input, callbacks=callbacks)
-            output = ""
-            if isinstance(output_, str):
-                output = output_
+            output_ = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
+            if isinstance(self.llm_chain, LLMChain):
+                output = output_[self.llm_chain.output_key]
+                if self.llm_chain.prompt.output_parser is not None:
+                    output = self.llm_chain.prompt.output_parser.parse(output)
+            else:
+                output = output_
             if len(output) == 0:
                 continue
             compressed_docs.append(
@@ -85,9 +92,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
         """Compress page content of raw documents asynchronously."""
         outputs = await asyncio.gather(
             *[
-                self.llm_chain.apredict_and_parse(
-                    **self.get_input(query, doc), callbacks=callbacks
-                )
+                self.llm_chain.ainvoke(self.get_input(query, doc), callbacks=callbacks)
                 for doc in documents
             ]
         )
@@ -111,5 +116,9 @@ class LLMChainExtractor(BaseDocumentCompressor):
         """Initialize from LLM."""
         _prompt = prompt if prompt is not None else _get_default_chain_prompt()
         _get_input = get_input if get_input is not None else default_get_input
-        llm_chain = LLMChain(llm=llm, prompt=_prompt)
+        if _prompt.output_parser is not None:
+            parser = _prompt.output_parser
+        else:
+            parser = StrOutputParser()
+        llm_chain = _prompt | llm | parser
         return cls(llm_chain=llm_chain, get_input=_get_input)  # type: ignore[arg-type]
```
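`from_llm` now assembles `prompt | llm | parser`, so for non-`LLMChain` pipelines `compress_documents` receives the parsed string directly. A sketch with a canned fake LLM standing in for a real model:

```python
from langchain_core.documents import Document
from langchain_core.language_models import FakeListLLM

from langchain.retrievers.document_compressors.chain_extract import LLMChainExtractor

llm = FakeListLLM(responses=["Relevant extracted sentence."])
compressor = LLMChainExtractor.from_llm(llm)

docs = [Document(page_content="Long page with one relevant sentence.")]
compressed = compressor.compress_documents(docs, query="What is relevant?")
print(compressed[0].page_content)  # Relevant extracted sentence.
```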
langchain/retrievers/document_compressors/chain_filter.py
CHANGED

```diff
@@ -5,7 +5,9 @@ from typing import Any, Callable, Dict, Optional, Sequence
 from langchain_core.callbacks.manager import Callbacks
 from langchain_core.documents import Document
 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import BasePromptTemplate, PromptTemplate
+from langchain_core.runnables import Runnable
 from langchain_core.runnables.config import RunnableConfig
 
 from langchain.chains import LLMChain
@@ -32,13 +34,16 @@ def default_get_input(query: str, doc: Document) -> Dict[str, Any]:
 class LLMChainFilter(BaseDocumentCompressor):
     """Filter that drops documents that aren't relevant to the query."""
 
-    llm_chain: LLMChain
+    llm_chain: Runnable
     """LLM wrapper to use for filtering documents.
     The chain prompt is expected to have a BooleanOutputParser."""
 
     get_input: Callable[[str, Document], dict] = default_get_input
     """Callable for constructing the chain input from the query and a Document."""
 
+    class Config:
+        arbitrary_types_allowed = True
+
     def compress_documents(
         self,
         documents: Sequence[Document],
@@ -56,11 +61,15 @@ class LLMChainFilter(BaseDocumentCompressor):
             documents,
         )
 
-        for output_dict, doc in outputs:
+        for output_, doc in outputs:
             include_doc = None
-            output = output_dict[self.llm_chain.output_key]
-            if self.llm_chain.prompt.output_parser is not None:
-                include_doc = self.llm_chain.prompt.output_parser.parse(output)
+            if isinstance(self.llm_chain, LLMChain):
+                output = output_[self.llm_chain.output_key]
+                if self.llm_chain.prompt.output_parser is not None:
+                    include_doc = self.llm_chain.prompt.output_parser.parse(output)
+            else:
+                if isinstance(output_, bool):
+                    include_doc = output_
             if include_doc:
                 filtered_docs.append(doc)
 
@@ -82,11 +91,15 @@ class LLMChainFilter(BaseDocumentCompressor):
             ),
             documents,
         )
-        for output_dict, doc in outputs:
+        for output_, doc in outputs:
             include_doc = None
-            output = output_dict[self.llm_chain.output_key]
-            if self.llm_chain.prompt.output_parser is not None:
-                include_doc = self.llm_chain.prompt.output_parser.parse(output)
+            if isinstance(self.llm_chain, LLMChain):
+                output = output_[self.llm_chain.output_key]
+                if self.llm_chain.prompt.output_parser is not None:
+                    include_doc = self.llm_chain.prompt.output_parser.parse(output)
+            else:
+                if isinstance(output_, bool):
+                    include_doc = output_
             if include_doc:
                 filtered_docs.append(doc)
 
@@ -110,5 +123,9 @@ class LLMChainFilter(BaseDocumentCompressor):
         A LLMChainFilter that uses the given language model.
         """
         _prompt = prompt if prompt is not None else _get_default_chain_prompt()
-        llm_chain = LLMChain(llm=llm, prompt=_prompt)
+        if _prompt.output_parser is not None:
+            parser = _prompt.output_parser
+        else:
+            parser = StrOutputParser()
+        llm_chain = _prompt | llm | parser
         return cls(llm_chain=llm_chain, **kwargs)
```
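Because the default prompt carries a `BooleanOutputParser`, the LCEL chain yields one `bool` per document, and the `isinstance(output_, bool)` branch above decides inclusion. A sketch with a canned fake LLM answering YES for the first document and NO for the second:

```python
from langchain_core.documents import Document
from langchain_core.language_models import FakeListLLM

from langchain.retrievers.document_compressors.chain_filter import LLMChainFilter

llm = FakeListLLM(responses=["YES", "NO"])
doc_filter = LLMChainFilter.from_llm(llm)

docs = [
    Document(page_content="Answers the query directly."),
    Document(page_content="Completely off-topic."),
]
kept = doc_filter.compress_documents(docs, query="What does the page say?")
print(len(kept))  # 1
```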