langchain 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/agents/agent.py +38 -21
- langchain/agents/agent_iterator.py +2 -2
- langchain/agents/agent_types.py +1 -1
- langchain/agents/openai_assistant/base.py +11 -15
- langchain/agents/openai_tools/base.py +8 -3
- langchain/chains/combine_documents/base.py +4 -4
- langchain/chains/combine_documents/stuff.py +1 -1
- langchain/chains/constitutional_ai/base.py +144 -1
- langchain/chains/conversation/base.py +1 -1
- langchain/chains/flare/base.py +46 -70
- langchain/chains/hyde/base.py +18 -8
- langchain/chains/llm_math/base.py +118 -1
- langchain/chains/moderation.py +8 -7
- langchain/chains/natbot/base.py +24 -10
- langchain/chains/query_constructor/parser.py +23 -0
- langchain/retrievers/document_compressors/chain_extract.py +19 -10
- langchain/retrievers/document_compressors/chain_filter.py +27 -10
- langchain/retrievers/re_phraser.py +7 -7
- langchain/retrievers/self_query/base.py +11 -2
- {langchain-0.2.14.dist-info → langchain-0.2.16.dist-info}/METADATA +2 -2
- {langchain-0.2.14.dist-info → langchain-0.2.16.dist-info}/RECORD +24 -24
- {langchain-0.2.14.dist-info → langchain-0.2.16.dist-info}/LICENSE +0 -0
- {langchain-0.2.14.dist-info → langchain-0.2.16.dist-info}/WHEEL +0 -0
- {langchain-0.2.14.dist-info → langchain-0.2.16.dist-info}/entry_points.txt +0 -0
langchain/chains/flare/base.py
CHANGED
```diff
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import re
-from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import numpy as np
@@ -9,10 +8,12 @@ from langchain_core.callbacks import (
     CallbackManagerForChainRun,
 )
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.outputs import Generation
+from langchain_core.messages import AIMessage
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import BasePromptTemplate
 from langchain_core.pydantic_v1 import Field
 from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import Runnable
 
 from langchain.chains.base import Chain
 from langchain.chains.flare.prompts import (
@@ -23,51 +24,14 @@ from langchain.chains.flare.prompts import (
 from langchain.chains.llm import LLMChain
 
 
-class _ResponseChain(LLMChain):
-    """Base class for chains that generate responses."""
-
-    prompt: BasePromptTemplate = PROMPT
-
-    @classmethod
-    def is_lc_serializable(cls) -> bool:
-        return False
-
-    @property
-    def input_keys(self) -> List[str]:
-        return self.prompt.input_variables
-
-    def generate_tokens_and_log_probs(
-        self,
-        _input: Dict[str, Any],
-        *,
-        run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Tuple[Sequence[str], Sequence[float]]:
-        llm_result = self.generate([_input], run_manager=run_manager)
-        return self._extract_tokens_and_log_probs(llm_result.generations[0])
-
-    @abstractmethod
-    def _extract_tokens_and_log_probs(
-        self, generations: List[Generation]
-    ) -> Tuple[Sequence[str], Sequence[float]]:
-        """Extract tokens and log probs from response."""
-
-
-class _OpenAIResponseChain(_ResponseChain):
-    """Chain that generates responses from user input and context."""
-
-    llm: BaseLanguageModel
-
-    def _extract_tokens_and_log_probs(
-        self, generations: List[Generation]
-    ) -> Tuple[Sequence[str], Sequence[float]]:
-        tokens = []
-        log_probs = []
-        for gen in generations:
-            if gen.generation_info is None:
-                raise ValueError
-            tokens.extend(gen.generation_info["logprobs"]["tokens"])
-            log_probs.extend(gen.generation_info["logprobs"]["token_logprobs"])
-        return tokens, log_probs
+def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]:
+    """Extract tokens and log probabilities from chat model response."""
+    tokens = []
+    log_probs = []
+    for token in response.response_metadata["logprobs"]["content"]:
+        tokens.append(token["token"])
+        log_probs.append(token["logprob"])
+    return tokens, log_probs
 
 
 class QuestionGeneratorChain(LLMChain):
@@ -109,11 +73,14 @@ def _low_confidence_spans(
 
 class FlareChain(Chain):
     """Chain that combines a retriever, a question generator,
-    and a response generator."""
+    and a response generator.
 
-    question_generator_chain: QuestionGeneratorChain
+    See [Active Retrieval Augmented Generation](https://arxiv.org/abs/2305.06983) paper.
+    """
+
+    question_generator_chain: Runnable
     """Chain that generates questions from uncertain spans."""
-    response_chain: _ResponseChain
+    response_chain: Runnable
     """Chain that generates responses from user input and context."""
     output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
     """Parser that determines whether the chain is finished."""
@@ -152,12 +119,16 @@ class FlareChain(Chain):
         for question in questions:
             docs.extend(self.retriever.invoke(question))
         context = "\n\n".join(d.page_content for d in docs)
-        result = self.response_chain.predict(
-            user_input=user_input,
-            context=context,
-            response=response,
-            callbacks=callbacks,
+        result = self.response_chain.invoke(
+            {
+                "user_input": user_input,
+                "context": context,
+                "response": response,
+            },
+            {"callbacks": callbacks},
         )
+        if isinstance(result, AIMessage):
+            result = result.content
         marginal, finished = self.output_parser.parse(result)
         return marginal, finished
 
@@ -178,13 +149,18 @@ class FlareChain(Chain):
             for span in low_confidence_spans
         ]
         callbacks = _run_manager.get_child()
-        question_gen_outputs = self.question_generator_chain.apply(
-            question_gen_inputs, callbacks=callbacks
-        )
-        questions = [
-            output[self.question_generator_chain.output_keys[0]]
-            for output in question_gen_outputs
-        ]
+        if isinstance(self.question_generator_chain, LLMChain):
+            question_gen_outputs = self.question_generator_chain.apply(
+                question_gen_inputs, callbacks=callbacks
+            )
+            questions = [
+                output[self.question_generator_chain.output_keys[0]]
+                for output in question_gen_outputs
+            ]
+        else:
+            questions = self.question_generator_chain.batch(
+                question_gen_inputs, config={"callbacks": callbacks}
+            )
         _run_manager.on_text(
             f"Generated Questions: {questions}", color="yellow", end="\n"
         )
@@ -206,8 +182,10 @@ class FlareChain(Chain):
             f"Current Response: {response}", color="blue", end="\n"
         )
         _input = {"user_input": user_input, "context": "", "response": response}
-        tokens, log_probs = self.response_chain.generate_tokens_and_log_probs(
-            _input, run_manager=_run_manager
+        tokens, log_probs = _extract_tokens_and_log_probs(
+            self.response_chain.invoke(
+                _input, {"callbacks": _run_manager.get_child()}
+            )
         )
         low_confidence_spans = _low_confidence_spans(
             tokens,
@@ -251,18 +229,16 @@
        FlareChain class with the given language model.
        """
        try:
-            from langchain_openai import OpenAI
+            from langchain_openai import ChatOpenAI
        except ImportError:
            raise ImportError(
                "OpenAI is required for FlareChain. "
                "Please install langchain-openai."
                "pip install langchain-openai"
            )
-        question_gen_chain = QuestionGeneratorChain(llm=llm)
-        response_llm = OpenAI(
-            max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
-        )
-        response_chain = _OpenAIResponseChain(llm=response_llm)
+        llm = ChatOpenAI(max_tokens=max_generation_len, logprobs=True, temperature=0)
+        response_chain = PROMPT | llm
+        question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
        return cls(
            question_generator_chain=question_gen_chain,
            response_chain=response_chain,
```
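For context on the new module-level helper: it assumes the logprobs payload that `ChatOpenAI(logprobs=True)` places in `response_metadata`. A minimal sketch of that shape and the extraction (the `AIMessage` below is fabricated for illustration, not real model output):

```python
from langchain_core.messages import AIMessage

def extract_tokens_and_log_probs(response: AIMessage):
    # Same walk as the new helper in flare/base.py.
    tokens, log_probs = [], []
    for token in response.response_metadata["logprobs"]["content"]:
        tokens.append(token["token"])
        log_probs.append(token["logprob"])
    return tokens, log_probs

# Fabricated message mimicking ChatOpenAI(logprobs=True) output.
msg = AIMessage(
    content="The answer is 42.",
    response_metadata={
        "logprobs": {
            "content": [
                {"token": "The", "logprob": -0.02},
                {"token": " answer", "logprob": -0.11},
                {"token": " is", "logprob": -0.05},
                {"token": " 42", "logprob": -1.37},
            ]
        }
    },
)
print(extract_tokens_and_log_probs(msg))
# (['The', ' answer', ' is', ' 42'], [-0.02, -0.11, -0.05, -1.37])
```

FLARE only needs token/log-prob pairs to locate low-confidence spans, which is why the `Generation`-based class hierarchy could be collapsed into this one function.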
langchain/chains/hyde/base.py
CHANGED
```diff
@@ -11,7 +11,9 @@ import numpy as np
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import BasePromptTemplate
+from langchain_core.runnables import Runnable
 
 from langchain.chains.base import Chain
 from langchain.chains.hyde.prompts import PROMPT_MAP
@@ -25,7 +27,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     """
 
     base_embeddings: Embeddings
-    llm_chain: LLMChain
+    llm_chain: Runnable
 
     class Config:
         arbitrary_types_allowed = True
@@ -34,12 +36,15 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     @property
     def input_keys(self) -> List[str]:
         """Input keys for Hyde's LLM chain."""
-        return self.llm_chain.input_keys
+        return self.llm_chain.input_schema.schema()["required"]
 
     @property
     def output_keys(self) -> List[str]:
         """Output keys for Hyde's LLM chain."""
-        return self.llm_chain.output_keys
+        if isinstance(self.llm_chain, LLMChain):
+            return self.llm_chain.output_keys
+        else:
+            return ["text"]
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Call the base embeddings."""
@@ -51,9 +56,12 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
 
     def embed_query(self, text: str) -> List[float]:
         """Generate a hypothetical document and embedded it."""
-        var_name = self.llm_chain.input_keys[0]
-        result = self.llm_chain.generate([{var_name: text}])
-        documents = [generation.text for generation in result.generations[0]]
+        var_name = self.input_keys[0]
+        result = self.llm_chain.invoke({var_name: text})
+        if isinstance(self.llm_chain, LLMChain):
+            documents = [result[self.output_keys[0]]]
+        else:
+            documents = [result]
         embeddings = self.embed_documents(documents)
         return self.combine_embeddings(embeddings)
 
@@ -64,7 +72,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     ) -> Dict[str, str]:
         """Call the internal llm chain."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
-        return self.llm_chain(inputs, callbacks=_run_manager.get_child())
+        return self.llm_chain.invoke(
+            inputs, config={"callbacks": _run_manager.get_child()}
+        )
 
     @classmethod
     def from_llm(
@@ -86,7 +96,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
                 f"of {list(PROMPT_MAP.keys())}."
             )
 
-        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        llm_chain = prompt | llm | StrOutputParser()
        return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)
 
     @property
```
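To see what the `Runnable`-typed `llm_chain` and the schema-derived `input_keys` imply, here is a small sketch. It assumes `FakeListLLM` from `langchain_core` as a stand-in model, and the prompt text only approximates the templates in `PROMPT_MAP`:

```python
from langchain_core.language_models.fake import FakeListLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

# The shape from_llm now builds: prompt | llm | StrOutputParser()
prompt = PromptTemplate.from_template(
    "Please write a passage to answer the question.\nQuestion: {QUESTION}\nPassage:"
)
llm_chain = prompt | FakeListLLM(responses=["A plausible passage."]) | StrOutputParser()

# input_keys is now read off the runnable's input schema:
print(llm_chain.input_schema.schema()["required"])  # ['QUESTION']

# invoke() returns a plain string, so embed_query wraps it as [result]:
print(llm_chain.invoke({"QUESTION": "What is HyDE?"}))  # A plausible passage.
```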
langchain/chains/llm_math/base.py
CHANGED
```diff
@@ -7,6 +7,7 @@ import re
 import warnings
 from typing import Any, Dict, List, Optional
 
+from langchain_core._api import deprecated
 from langchain_core.callbacks import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
@@ -20,16 +21,132 @@ from langchain.chains.llm import LLMChain
 from langchain.chains.llm_math.prompt import PROMPT
 
 
+@deprecated(
+    since="0.2.13",
+    message=(
+        "This class is deprecated and will be removed in langchain 1.0. "
+        "See API reference for replacement: "
+        "https://api.python.langchain.com/en/latest/chains/langchain.chains.llm_math.base.LLMMathChain.html"  # noqa: E501
+    ),
+    removal="1.0",
+)
 class LLMMathChain(Chain):
     """Chain that interprets a prompt and executes python code to do math.
 
+    Note: this class is deprecated. See below for a replacement implementation
+    using LangGraph. The benefits of this implementation are:
+
+    - Uses LLM tool calling features;
+    - Support for both token-by-token and step-by-step streaming;
+    - Support for checkpointing and memory of chat history;
+    - Easier to modify or extend (e.g., with additional tools, structured responses, etc.)
+
+    Install LangGraph with:
+
+    .. code-block:: bash
+
+        pip install -U langgraph
+
+    .. code-block:: python
+
+        import math
+        from typing import Annotated, Sequence
+
+        from langchain_core.messages import BaseMessage
+        from langchain_core.runnables import RunnableConfig
+        from langchain_core.tools import tool
+        from langchain_openai import ChatOpenAI
+        from langgraph.graph import END, StateGraph
+        from langgraph.graph.message import add_messages
+        from langgraph.prebuilt.tool_node import ToolNode
+        import numexpr
+        from typing_extensions import TypedDict
+
+        @tool
+        def calculator(expression: str) -> str:
+            \"\"\"Calculate expression using Python's numexpr library.
+
+            Expression should be a single line mathematical expression
+            that solves the problem.
+
+            Examples:
+                "37593 * 67" for "37593 times 67"
+                "37593**(1/5)" for "37593^(1/5)"
+            \"\"\"
+            local_dict = {"pi": math.pi, "e": math.e}
+            return str(
+                numexpr.evaluate(
+                    expression.strip(),
+                    global_dict={},  # restrict access to globals
+                    local_dict=local_dict,  # add common mathematical functions
+                )
+            )
+
+        llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
+        tools = [calculator]
+        llm_with_tools = llm.bind_tools(tools, tool_choice="any")
+
+        class ChainState(TypedDict):
+            \"\"\"LangGraph state.\"\"\"
+
+            messages: Annotated[Sequence[BaseMessage], add_messages]
+
+        async def acall_chain(state: ChainState, config: RunnableConfig):
+            last_message = state["messages"][-1]
+            response = await llm_with_tools.ainvoke(state["messages"], config)
+            return {"messages": [response]}
+
+        async def acall_model(state: ChainState, config: RunnableConfig):
+            response = await llm.ainvoke(state["messages"], config)
+            return {"messages": [response]}
+
+        graph_builder = StateGraph(ChainState)
+        graph_builder.add_node("call_tool", acall_chain)
+        graph_builder.add_node("execute_tool", ToolNode(tools))
+        graph_builder.add_node("call_model", acall_model)
+        graph_builder.set_entry_point("call_tool")
+        graph_builder.add_edge("call_tool", "execute_tool")
+        graph_builder.add_edge("execute_tool", "call_model")
+        graph_builder.add_edge("call_model", END)
+        chain = graph_builder.compile()
+
+    .. code-block:: python
+
+        example_query = "What is 551368 divided by 82"
+
+        events = chain.astream(
+            {"messages": [("user", example_query)]},
+            stream_mode="values",
+        )
+        async for event in events:
+            event["messages"][-1].pretty_print()
+
+    .. code-block:: none
+
+        ================================ Human Message =================================
+
+        What is 551368 divided by 82
+        ================================== Ai Message ==================================
+        Tool Calls:
+          calculator (call_MEiGXuJjJ7wGU4aOT86QuGJS)
+         Call ID: call_MEiGXuJjJ7wGU4aOT86QuGJS
+          Args:
+            expression: 551368 / 82
+        ================================= Tool Message =================================
+        Name: calculator
+
+        6724.0
+        ================================== Ai Message ==================================
+
+        551368 divided by 82 equals 6724.
+
     Example:
         .. code-block:: python
 
            from langchain.chains import LLMMathChain
            from langchain_community.llms import OpenAI
            llm_math = LLMMathChain.from_llm(OpenAI())
-    """
+    """  # noqa: E501
 
     llm_chain: LLMChain
     llm: Optional[BaseLanguageModel] = None
```
langchain/chains/moderation.py
CHANGED
```diff
@@ -38,7 +38,7 @@ class OpenAIModerationChain(Chain):
     output_key: str = "output"  #: :meta private:
     openai_api_key: Optional[str] = None
     openai_organization: Optional[str] = None
-    _openai_pre_1_0: bool = Field(default=None)
+    openai_pre_1_0: bool = Field(default=None)
 
     @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
@@ -58,16 +58,17 @@ class OpenAIModerationChain(Chain):
                 openai.api_key = openai_api_key
             if openai_organization:
                 openai.organization = openai_organization
-            values["_openai_pre_1_0"] = False
+            values["openai_pre_1_0"] = False
             try:
                 check_package_version("openai", gte_version="1.0")
             except ValueError:
-                values["_openai_pre_1_0"] = True
-            if values["_openai_pre_1_0"]:
+                values["openai_pre_1_0"] = True
+            if values["openai_pre_1_0"]:
                 values["client"] = openai.Moderation
             else:
                 values["client"] = openai.OpenAI()
                 values["async_client"] = openai.AsyncOpenAI()
+
         except ImportError:
             raise ImportError(
                 "Could not import openai python package. "
@@ -92,7 +93,7 @@ class OpenAIModerationChain(Chain):
         return [self.output_key]
 
     def _moderate(self, text: str, results: Any) -> str:
-        if self._openai_pre_1_0:
+        if self.openai_pre_1_0:
             condition = results["flagged"]
         else:
             condition = results.flagged
@@ -110,7 +111,7 @@ class OpenAIModerationChain(Chain):
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         text = inputs[self.input_key]
-        if self._openai_pre_1_0:
+        if self.openai_pre_1_0:
             results = self.client.create(text)
             output = self._moderate(text, results["results"][0])
         else:
@@ -123,7 +124,7 @@ class OpenAIModerationChain(Chain):
         inputs: Dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
-        if self._openai_pre_1_0:
+        if self.openai_pre_1_0:
             return await super()._acall(inputs, run_manager=run_manager)
         text = inputs[self.input_key]
         results = await self.async_client.moderations.create(input=text)
```
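The renamed `openai_pre_1_0` flag is set by probing the installed SDK version. A minimal sketch of that probe, mirroring the validator (the surrounding client setup is omitted):

```python
from langchain_core.utils import check_package_version

# Assume the modern client unless the installed openai package is older than 1.0;
# check_package_version raises ValueError when the constraint is not met.
openai_pre_1_0 = False
try:
    check_package_version("openai", gte_version="1.0")
except ValueError:
    openai_pre_1_0 = True

# Pre-1.0 installs use openai.Moderation and dict results (results["flagged"]);
# 1.0+ installs use openai.OpenAI().moderations and attribute access
# (results.flagged), matching the two branches in _moderate.
print(openai_pre_1_0)
```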
langchain/chains/natbot/base.py
CHANGED
```diff
@@ -5,15 +5,27 @@ from __future__ import annotations
 import warnings
 from typing import Any, Dict, List, Optional
 
+from langchain_core._api import deprecated
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.pydantic_v1 import root_validator
+from langchain_core.runnables import Runnable
 
 from langchain.chains.base import Chain
-from langchain.chains.llm import LLMChain
 from langchain.chains.natbot.prompt import PROMPT
 
 
+@deprecated(
+    since="0.2.13",
+    message=(
+        "Importing NatBotChain from langchain is deprecated and will be removed in "
+        "langchain 1.0. Please import from langchain_community instead: "
+        "from langchain_community.chains.natbot import NatBotChain. "
+        "You may need to pip install -U langchain-community."
+    ),
+    removal="1.0",
+)
 class NatBotChain(Chain):
     """Implement an LLM driven browser.
 
@@ -37,7 +49,7 @@ class NatBotChain(Chain):
         natbot = NatBotChain.from_default("Buy me a new hat.")
     """
 
-    llm_chain: LLMChain
+    llm_chain: Runnable
     objective: str
     """Objective that NatBot is tasked with completing."""
     llm: Optional[BaseLanguageModel] = None
@@ -60,7 +72,7 @@ class NatBotChain(Chain):
             "class method."
         )
         if "llm_chain" not in values and values["llm"] is not None:
-            values["llm_chain"] = LLMChain(llm=values["llm"], prompt=PROMPT)
+            values["llm_chain"] = PROMPT | values["llm"] | StrOutputParser()
         return values
 
     @classmethod
@@ -77,7 +89,7 @@ class NatBotChain(Chain):
         cls, llm: BaseLanguageModel, objective: str, **kwargs: Any
     ) -> NatBotChain:
         """Load from LLM."""
-        llm_chain = LLMChain(llm=llm, prompt=PROMPT)
+        llm_chain = PROMPT | llm | StrOutputParser()
         return cls(llm_chain=llm_chain, objective=objective, **kwargs)
 
     @property
@@ -104,12 +116,14 @@ class NatBotChain(Chain):
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         url = inputs[self.input_url_key]
         browser_content = inputs[self.input_browser_content_key]
-        llm_cmd = self.llm_chain.predict(
-            objective=self.objective,
-            url=url[:100],
-            previous_command=self.previous_command,
-            browser_content=browser_content[:4500],
-            callbacks=_run_manager.get_child(),
+        llm_cmd = self.llm_chain.invoke(
+            {
+                "objective": self.objective,
+                "url": url[:100],
+                "previous_command": self.previous_command,
+                "browser_content": browser_content[:4500],
+            },
+            config={"callbacks": _run_manager.get_child()},
         )
         llm_cmd = llm_cmd.strip()
         self.previous_command = llm_cmd
```
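Since `llm_chain` is now a bare `Runnable` ending in `StrOutputParser`, `_call` can treat the result as a plain string. A sketch with a stand-in prompt and a fake model (the real `PROMPT` lives in `langchain.chains.natbot.prompt`; the template below is only shaped like it):

```python
from langchain_core.language_models.fake import FakeListLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate.from_template(
    "Objective: {objective}\nURL: {url}\nPrevious command: {previous_command}\n"
    "Browser content: {browser_content}\nNext command:"
)
llm_chain = prompt | FakeListLLM(responses=["CLICK 42"]) | StrOutputParser()

llm_cmd = llm_chain.invoke(
    {
        "objective": "Buy me a new hat.",
        "url": "https://example.com/hats"[:100],           # same truncation as _call
        "previous_command": "",
        "browser_content": "<button id=42>Hats</button>"[:4500],
    }
)
print(llm_cmd.strip())  # CLICK 42 (already a str; no output-key indexing needed)
```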
langchain/chains/query_constructor/parser.py
CHANGED
```diff
@@ -35,6 +35,7 @@ GRAMMAR = r"""
     ?value: SIGNED_INT -> int
         | SIGNED_FLOAT -> float
         | DATE -> date
+        | DATETIME -> datetime
         | list
         | string
         | ("false" | "False" | "FALSE") -> false
@@ -42,6 +43,7 @@ GRAMMAR = r"""
 
     args: expr ("," expr)*
     DATE.2: /["']?(\d{4}-[01]\d-[0-3]\d)["']?/
+    DATETIME.2: /["']?\d{4}-[01]\d-[0-3]\dT[0-2]\d:[0-5]\d:[0-5]\d[Zz]?["']?/
     string: /'[^']*'/ | ESCAPED_STRING
     list: "[" [args] "]"
 
@@ -61,6 +63,13 @@ class ISO8601Date(TypedDict):
     type: Literal["date"]
 
 
+class ISO8601DateTime(TypedDict):
+    """A datetime in ISO 8601 format (YYYY-MM-DDTHH:MM:SS)."""
+
+    datetime: str
+    type: Literal["datetime"]
+
+
 @v_args(inline=True)
 class QueryTransformer(Transformer):
     """Transform a query string into an intermediate representation."""
@@ -149,6 +158,20 @@ class QueryTransformer(Transformer):
         )
         return {"date": item, "type": "date"}
 
+    def datetime(self, item: Any) -> ISO8601DateTime:
+        item = str(item).strip("\"'")
+        try:
+            # Parse full ISO 8601 datetime format
+            datetime.datetime.strptime(item, "%Y-%m-%dT%H:%M:%S%z")
+        except ValueError:
+            try:
+                datetime.datetime.strptime(item, "%Y-%m-%dT%H:%M:%S")
+            except ValueError:
+                raise ValueError(
+                    "Datetime values are expected to be in ISO 8601 format."
+                )
+        return {"datetime": item, "type": "datetime"}
+
     def string(self, item: Any) -> str:
         # Remove escaped quotes
         return str(item).strip("\"'")
```
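The new `datetime` handler only validates and passes the original string through: it tries the timezone-aware form first and falls back to the naive form. Restated as a standalone function for clarity (same logic, outside the Lark transformer):

```python
import datetime

def parse_iso_datetime(item: str) -> dict:
    # Mirrors QueryTransformer.datetime: strip quotes, validate, pass through.
    item = str(item).strip("\"'")
    try:
        # Timezone-aware, e.g. 2024-06-01T12:00:00Z or 2024-06-01T12:00:00+00:00
        datetime.datetime.strptime(item, "%Y-%m-%dT%H:%M:%S%z")
    except ValueError:
        try:
            # Naive, e.g. 2024-06-01T12:00:00
            datetime.datetime.strptime(item, "%Y-%m-%dT%H:%M:%S")
        except ValueError:
            raise ValueError("Datetime values are expected to be in ISO 8601 format.")
    return {"datetime": item, "type": "datetime"}

print(parse_iso_datetime("'2024-06-01T12:00:00'"))
# {'datetime': '2024-06-01T12:00:00', 'type': 'datetime'}
```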
langchain/retrievers/document_compressors/chain_extract.py
CHANGED
```diff
@@ -8,8 +8,9 @@ from typing import Any, Callable, Dict, Optional, Sequence, cast
 from langchain_core.callbacks.manager import Callbacks
 from langchain_core.documents import Document
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.output_parsers import BaseOutputParser
+from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
 from langchain_core.prompts import PromptTemplate
+from langchain_core.runnables import Runnable
 
 from langchain.chains.llm import LLMChain
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
@@ -49,12 +50,15 @@ class LLMChainExtractor(BaseDocumentCompressor):
     """Document compressor that uses an LLM chain to extract
     the relevant parts of documents."""
 
-    llm_chain: LLMChain
+    llm_chain: Runnable
     """LLM wrapper to use for compressing documents."""
 
     get_input: Callable[[str, Document], dict] = default_get_input
     """Callable for constructing the chain input from the query and a Document."""
 
+    class Config:
+        arbitrary_types_allowed = True
+
     def compress_documents(
         self,
         documents: Sequence[Document],
@@ -65,10 +69,13 @@ class LLMChainExtractor(BaseDocumentCompressor):
         compressed_docs = []
         for doc in documents:
             _input = self.get_input(query, doc)
-            output_ = self.llm_chain.predict_and_parse(**_input, callbacks=callbacks)
-            output = cast(str, output_)
-            if self.llm_chain.prompt.output_parser is not None:
-                output = self.llm_chain.prompt.output_parser.parse(output_)
+            output_ = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
+            if isinstance(self.llm_chain, LLMChain):
+                output = output_[self.llm_chain.output_key]
+                if self.llm_chain.prompt.output_parser is not None:
+                    output = self.llm_chain.prompt.output_parser.parse(output)
+            else:
+                output = output_
             if len(output) == 0:
                 continue
             compressed_docs.append(
@@ -85,9 +92,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
         """Compress page content of raw documents asynchronously."""
         outputs = await asyncio.gather(
             *[
-                self.llm_chain.apredict_and_parse(
-                    **self.get_input(query, doc), callbacks=callbacks
-                )
+                self.llm_chain.ainvoke(self.get_input(query, doc), callbacks=callbacks)
                 for doc in documents
             ]
         )
@@ -111,5 +116,9 @@ class LLMChainExtractor(BaseDocumentCompressor):
         """Initialize from LLM."""
         _prompt = prompt if prompt is not None else _get_default_chain_prompt()
         _get_input = get_input if get_input is not None else default_get_input
-        llm_chain = LLMChain(llm=llm, prompt=_prompt)
+        if _prompt.output_parser is not None:
+            parser = _prompt.output_parser
+        else:
+            parser = StrOutputParser()
+        llm_chain = _prompt | llm | parser
         return cls(llm_chain=llm_chain, get_input=_get_input)  # type: ignore[arg-type]
```