langchain 0.3.27__py3-none-any.whl → 0.4.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/agents/agent.py +16 -20
- langchain/agents/agent_iterator.py +19 -12
- langchain/agents/agent_toolkits/vectorstore/base.py +2 -0
- langchain/agents/chat/base.py +2 -0
- langchain/agents/conversational/base.py +2 -0
- langchain/agents/conversational_chat/base.py +2 -0
- langchain/agents/initialize.py +1 -1
- langchain/agents/json_chat/base.py +1 -0
- langchain/agents/mrkl/base.py +2 -0
- langchain/agents/openai_assistant/base.py +1 -1
- langchain/agents/openai_functions_agent/agent_token_buffer_memory.py +2 -0
- langchain/agents/openai_functions_agent/base.py +3 -2
- langchain/agents/openai_functions_multi_agent/base.py +1 -1
- langchain/agents/openai_tools/base.py +1 -0
- langchain/agents/output_parsers/json.py +2 -0
- langchain/agents/output_parsers/openai_functions.py +10 -3
- langchain/agents/output_parsers/openai_tools.py +8 -1
- langchain/agents/output_parsers/react_json_single_input.py +3 -0
- langchain/agents/output_parsers/react_single_input.py +3 -0
- langchain/agents/output_parsers/self_ask.py +2 -0
- langchain/agents/output_parsers/tools.py +16 -2
- langchain/agents/output_parsers/xml.py +3 -0
- langchain/agents/react/agent.py +1 -0
- langchain/agents/react/base.py +4 -0
- langchain/agents/react/output_parser.py +2 -0
- langchain/agents/schema.py +2 -0
- langchain/agents/self_ask_with_search/base.py +4 -0
- langchain/agents/structured_chat/base.py +5 -0
- langchain/agents/structured_chat/output_parser.py +13 -0
- langchain/agents/tool_calling_agent/base.py +1 -0
- langchain/agents/tools.py +3 -0
- langchain/agents/xml/base.py +7 -1
- langchain/callbacks/streaming_aiter.py +13 -2
- langchain/callbacks/streaming_aiter_final_only.py +11 -2
- langchain/callbacks/streaming_stdout_final_only.py +5 -0
- langchain/callbacks/tracers/logging.py +11 -0
- langchain/chains/api/base.py +5 -1
- langchain/chains/base.py +8 -2
- langchain/chains/combine_documents/base.py +7 -1
- langchain/chains/combine_documents/map_reduce.py +3 -0
- langchain/chains/combine_documents/map_rerank.py +6 -4
- langchain/chains/combine_documents/reduce.py +1 -0
- langchain/chains/combine_documents/refine.py +1 -0
- langchain/chains/combine_documents/stuff.py +5 -1
- langchain/chains/constitutional_ai/base.py +7 -0
- langchain/chains/conversation/base.py +4 -1
- langchain/chains/conversational_retrieval/base.py +67 -59
- langchain/chains/elasticsearch_database/base.py +2 -1
- langchain/chains/flare/base.py +2 -0
- langchain/chains/flare/prompts.py +2 -0
- langchain/chains/llm.py +7 -2
- langchain/chains/llm_bash/__init__.py +1 -1
- langchain/chains/llm_checker/base.py +12 -1
- langchain/chains/llm_math/base.py +9 -1
- langchain/chains/llm_summarization_checker/base.py +13 -1
- langchain/chains/llm_symbolic_math/__init__.py +1 -1
- langchain/chains/loading.py +4 -2
- langchain/chains/moderation.py +3 -0
- langchain/chains/natbot/base.py +3 -1
- langchain/chains/natbot/crawler.py +29 -0
- langchain/chains/openai_functions/base.py +2 -0
- langchain/chains/openai_functions/citation_fuzzy_match.py +9 -0
- langchain/chains/openai_functions/openapi.py +4 -0
- langchain/chains/openai_functions/qa_with_structure.py +3 -3
- langchain/chains/openai_functions/tagging.py +2 -0
- langchain/chains/qa_generation/base.py +4 -0
- langchain/chains/qa_with_sources/base.py +3 -0
- langchain/chains/qa_with_sources/retrieval.py +1 -1
- langchain/chains/qa_with_sources/vector_db.py +4 -2
- langchain/chains/query_constructor/base.py +4 -2
- langchain/chains/query_constructor/parser.py +64 -2
- langchain/chains/retrieval_qa/base.py +4 -0
- langchain/chains/router/base.py +14 -2
- langchain/chains/router/embedding_router.py +3 -0
- langchain/chains/router/llm_router.py +6 -4
- langchain/chains/router/multi_prompt.py +3 -0
- langchain/chains/router/multi_retrieval_qa.py +18 -0
- langchain/chains/sql_database/query.py +1 -0
- langchain/chains/structured_output/base.py +2 -0
- langchain/chains/transform.py +4 -0
- langchain/chat_models/base.py +55 -18
- langchain/document_loaders/blob_loaders/schema.py +1 -4
- langchain/embeddings/base.py +2 -0
- langchain/embeddings/cache.py +3 -3
- langchain/evaluation/agents/trajectory_eval_chain.py +3 -2
- langchain/evaluation/comparison/eval_chain.py +1 -0
- langchain/evaluation/criteria/eval_chain.py +3 -0
- langchain/evaluation/embedding_distance/base.py +11 -0
- langchain/evaluation/exact_match/base.py +14 -1
- langchain/evaluation/loading.py +1 -0
- langchain/evaluation/parsing/base.py +16 -3
- langchain/evaluation/parsing/json_distance.py +19 -8
- langchain/evaluation/parsing/json_schema.py +1 -4
- langchain/evaluation/qa/eval_chain.py +8 -0
- langchain/evaluation/qa/generate_chain.py +2 -0
- langchain/evaluation/regex_match/base.py +9 -1
- langchain/evaluation/scoring/eval_chain.py +1 -0
- langchain/evaluation/string_distance/base.py +6 -0
- langchain/memory/buffer.py +5 -0
- langchain/memory/buffer_window.py +2 -0
- langchain/memory/combined.py +1 -1
- langchain/memory/entity.py +47 -0
- langchain/memory/simple.py +3 -0
- langchain/memory/summary.py +30 -0
- langchain/memory/summary_buffer.py +3 -0
- langchain/memory/token_buffer.py +2 -0
- langchain/output_parsers/combining.py +4 -2
- langchain/output_parsers/enum.py +5 -1
- langchain/output_parsers/fix.py +8 -1
- langchain/output_parsers/pandas_dataframe.py +16 -1
- langchain/output_parsers/regex.py +2 -0
- langchain/output_parsers/retry.py +21 -1
- langchain/output_parsers/structured.py +10 -0
- langchain/output_parsers/yaml.py +4 -0
- langchain/pydantic_v1/__init__.py +1 -1
- langchain/retrievers/document_compressors/chain_extract.py +4 -2
- langchain/retrievers/document_compressors/cohere_rerank.py +2 -0
- langchain/retrievers/document_compressors/cross_encoder_rerank.py +2 -0
- langchain/retrievers/document_compressors/embeddings_filter.py +3 -0
- langchain/retrievers/document_compressors/listwise_rerank.py +1 -0
- langchain/retrievers/ensemble.py +2 -2
- langchain/retrievers/multi_query.py +3 -1
- langchain/retrievers/multi_vector.py +4 -1
- langchain/retrievers/parent_document_retriever.py +15 -0
- langchain/retrievers/self_query/base.py +19 -0
- langchain/retrievers/time_weighted_retriever.py +3 -0
- langchain/runnables/hub.py +12 -0
- langchain/runnables/openai_functions.py +6 -0
- langchain/smith/__init__.py +1 -0
- langchain/smith/evaluation/config.py +5 -22
- langchain/smith/evaluation/progress.py +12 -3
- langchain/smith/evaluation/runner_utils.py +240 -123
- langchain/smith/evaluation/string_run_evaluator.py +27 -0
- langchain/storage/encoder_backed.py +1 -0
- langchain/tools/python/__init__.py +1 -1
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/METADATA +2 -12
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/RECORD +140 -141
- langchain/smith/evaluation/utils.py +0 -0
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/WHEEL +0 -0
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/entry_points.txt +0 -0
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/licenses/LICENSE +0 -0
|
@@ -61,6 +61,7 @@ class Crawler:
|
|
|
61
61
|
"""
|
|
62
62
|
|
|
63
63
|
def __init__(self) -> None:
|
|
64
|
+
"""Initialize the crawler."""
|
|
64
65
|
try:
|
|
65
66
|
from playwright.sync_api import sync_playwright
|
|
66
67
|
except ImportError as e:
|
|
@@ -78,11 +79,22 @@ class Crawler:
|
|
|
78
79
|
self.client: CDPSession
|
|
79
80
|
|
|
80
81
|
def go_to_page(self, url: str) -> None:
|
|
82
|
+
"""Navigate to the given URL.
|
|
83
|
+
|
|
84
|
+
Args:
|
|
85
|
+
url: The URL to navigate to. If it does not contain a scheme, it will be
|
|
86
|
+
prefixed with "http://".
|
|
87
|
+
"""
|
|
81
88
|
self.page.goto(url=url if "://" in url else "http://" + url)
|
|
82
89
|
self.client = self.page.context.new_cdp_session(self.page)
|
|
83
90
|
self.page_element_buffer = {}
|
|
84
91
|
|
|
85
92
|
def scroll(self, direction: str) -> None:
|
|
93
|
+
"""Scroll the page in the given direction.
|
|
94
|
+
|
|
95
|
+
Args:
|
|
96
|
+
direction: The direction to scroll in, either "up" or "down".
|
|
97
|
+
"""
|
|
86
98
|
if direction == "up":
|
|
87
99
|
self.page.evaluate(
|
|
88
100
|
"(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop - window.innerHeight;" # noqa: E501
|
|
@@ -93,6 +105,11 @@ class Crawler:
|
|
|
93
105
|
)
|
|
94
106
|
|
|
95
107
|
def click(self, id_: Union[str, int]) -> None:
|
|
108
|
+
"""Click on an element with the given id.
|
|
109
|
+
|
|
110
|
+
Args:
|
|
111
|
+
id_: The id of the element to click on.
|
|
112
|
+
"""
|
|
96
113
|
# Inject javascript into the page which removes the target= attribute from links
|
|
97
114
|
js = """
|
|
98
115
|
links = document.getElementsByTagName("a");
|
|
@@ -112,13 +129,25 @@ class Crawler:
|
|
|
112
129
|
print("Could not find element") # noqa: T201
|
|
113
130
|
|
|
114
131
|
def type(self, id_: Union[str, int], text: str) -> None:
|
|
132
|
+
"""Type text into an element with the given id.
|
|
133
|
+
|
|
134
|
+
Args:
|
|
135
|
+
id_: The id of the element to type into.
|
|
136
|
+
text: The text to type into the element.
|
|
137
|
+
"""
|
|
115
138
|
self.click(id_)
|
|
116
139
|
self.page.keyboard.type(text)
|
|
117
140
|
|
|
118
141
|
def enter(self) -> None:
|
|
142
|
+
"""Press the Enter key."""
|
|
119
143
|
self.page.keyboard.press("Enter")
|
|
120
144
|
|
|
121
145
|
def crawl(self) -> list[str]:
|
|
146
|
+
"""Crawl the current page.
|
|
147
|
+
|
|
148
|
+
Returns:
|
|
149
|
+
A list of the elements in the viewport.
|
|
150
|
+
"""
|
|
122
151
|
page = self.page
|
|
123
152
|
page_element_buffer = self.page_element_buffer
|
|
124
153
|
start = time.time()
|
|
@@ -121,6 +121,7 @@ def create_openai_fn_chain(
|
|
|
121
121
|
chain = create_openai_fn_chain([RecordPerson, RecordDog], llm, prompt)
|
|
122
122
|
chain.run("Harry was a chubby brown beagle who loved chicken")
|
|
123
123
|
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
|
|
124
|
+
|
|
124
125
|
""" # noqa: E501
|
|
125
126
|
if not functions:
|
|
126
127
|
msg = "Need to pass in at least one function. Received zero."
|
|
@@ -203,6 +204,7 @@ def create_structured_output_chain(
|
|
|
203
204
|
chain = create_structured_output_chain(Dog, llm, prompt)
|
|
204
205
|
chain.run("Harry was a chubby brown beagle who loved chicken")
|
|
205
206
|
# -> Dog(name="Harry", color="brown", fav_food="chicken")
|
|
207
|
+
|
|
206
208
|
""" # noqa: E501
|
|
207
209
|
if isinstance(output_schema, dict):
|
|
208
210
|
function: Any = {
|
|
@@ -45,6 +45,14 @@ class FactWithEvidence(BaseModel):
|
|
|
45
45
|
yield from s.spans()
|
|
46
46
|
|
|
47
47
|
def get_spans(self, context: str) -> Iterator[str]:
|
|
48
|
+
"""Get spans of the substring quote in the context.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
context: The context in which to find the spans of the substring quote.
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
An iterator over the spans of the substring quote in the context.
|
|
55
|
+
"""
|
|
48
56
|
for quote in self.substring_quote:
|
|
49
57
|
yield from self._get_span(quote, context)
|
|
50
58
|
|
|
@@ -86,6 +94,7 @@ def create_citation_fuzzy_match_runnable(llm: BaseChatModel) -> Runnable:
|
|
|
86
94
|
|
|
87
95
|
Returns:
|
|
88
96
|
Runnable that can be used to answer questions with citations.
|
|
97
|
+
|
|
89
98
|
"""
|
|
90
99
|
if llm.bind_tools is BaseChatModel.bind_tools:
|
|
91
100
|
msg = "Language model must implement bind_tools to use this function."
|
|
@@ -13,6 +13,7 @@ from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsPa
|
|
|
13
13
|
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
|
|
14
14
|
from langchain_core.utils.input import get_colored_text
|
|
15
15
|
from requests import Response
|
|
16
|
+
from typing_extensions import override
|
|
16
17
|
|
|
17
18
|
from langchain.chains.base import Chain
|
|
18
19
|
from langchain.chains.llm import LLMChain
|
|
@@ -202,10 +203,12 @@ class SimpleRequestChain(Chain):
|
|
|
202
203
|
"""Key to use for the input of the request."""
|
|
203
204
|
|
|
204
205
|
@property
|
|
206
|
+
@override
|
|
205
207
|
def input_keys(self) -> list[str]:
|
|
206
208
|
return [self.input_key]
|
|
207
209
|
|
|
208
210
|
@property
|
|
211
|
+
@override
|
|
209
212
|
def output_keys(self) -> list[str]:
|
|
210
213
|
return [self.output_key]
|
|
211
214
|
|
|
@@ -342,6 +345,7 @@ def get_openapi_chain(
|
|
|
342
345
|
`ChatOpenAI(model="gpt-3.5-turbo-0613")`.
|
|
343
346
|
prompt: Main prompt template to use.
|
|
344
347
|
request_chain: Chain for taking the functions output and executing the request.
|
|
348
|
+
|
|
345
349
|
""" # noqa: E501
|
|
346
350
|
try:
|
|
347
351
|
from langchain_community.utilities.openapi import OpenAPISpec
|
|
@@ -76,11 +76,11 @@ def create_qa_with_structure_chain(
|
|
|
76
76
|
raise ValueError(msg)
|
|
77
77
|
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
|
78
78
|
if hasattr(schema, "model_json_schema"):
|
|
79
|
-
schema_dict = cast(dict, schema.model_json_schema())
|
|
79
|
+
schema_dict = cast("dict", schema.model_json_schema())
|
|
80
80
|
else:
|
|
81
|
-
schema_dict = cast(dict, schema.schema())
|
|
81
|
+
schema_dict = cast("dict", schema.schema())
|
|
82
82
|
else:
|
|
83
|
-
schema_dict = cast(dict, schema)
|
|
83
|
+
schema_dict = cast("dict", schema)
|
|
84
84
|
function = {
|
|
85
85
|
"name": schema_dict["title"],
|
|
86
86
|
"description": schema_dict["description"],
|
|
@@ -86,6 +86,7 @@ def create_tagging_chain(
|
|
|
86
86
|
|
|
87
87
|
Returns:
|
|
88
88
|
Chain (LLMChain) that can be used to extract information from a passage.
|
|
89
|
+
|
|
89
90
|
"""
|
|
90
91
|
function = _get_tagging_function(schema)
|
|
91
92
|
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
|
@@ -154,6 +155,7 @@ def create_tagging_chain_pydantic(
|
|
|
154
155
|
|
|
155
156
|
Returns:
|
|
156
157
|
Chain (LLMChain) that can be used to extract information from a passage.
|
|
158
|
+
|
|
157
159
|
"""
|
|
158
160
|
if hasattr(pydantic_schema, "model_json_schema"):
|
|
159
161
|
openai_schema = pydantic_schema.model_json_schema()
|
|
@@ -9,6 +9,7 @@ from langchain_core.language_models import BaseLanguageModel
|
|
|
9
9
|
from langchain_core.prompts import BasePromptTemplate
|
|
10
10
|
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
|
11
11
|
from pydantic import Field
|
|
12
|
+
from typing_extensions import override
|
|
12
13
|
|
|
13
14
|
from langchain.chains.base import Chain
|
|
14
15
|
from langchain.chains.llm import LLMChain
|
|
@@ -61,6 +62,7 @@ class QAGenerationChain(Chain):
|
|
|
61
62
|
split_text | RunnableEach(bound=prompt | llm | JsonOutputParser())
|
|
62
63
|
)
|
|
63
64
|
)
|
|
65
|
+
|
|
64
66
|
"""
|
|
65
67
|
|
|
66
68
|
llm_chain: LLMChain
|
|
@@ -103,10 +105,12 @@ class QAGenerationChain(Chain):
|
|
|
103
105
|
raise NotImplementedError
|
|
104
106
|
|
|
105
107
|
@property
|
|
108
|
+
@override
|
|
106
109
|
def input_keys(self) -> list[str]:
|
|
107
110
|
return [self.input_key]
|
|
108
111
|
|
|
109
112
|
@property
|
|
113
|
+
@override
|
|
110
114
|
def output_keys(self) -> list[str]:
|
|
111
115
|
return [self.output_key]
|
|
112
116
|
|
|
@@ -16,6 +16,7 @@ from langchain_core.documents import Document
|
|
|
16
16
|
from langchain_core.language_models import BaseLanguageModel
|
|
17
17
|
from langchain_core.prompts import BasePromptTemplate
|
|
18
18
|
from pydantic import ConfigDict, model_validator
|
|
19
|
+
from typing_extensions import override
|
|
19
20
|
|
|
20
21
|
from langchain.chains import ReduceDocumentsChain
|
|
21
22
|
from langchain.chains.base import Chain
|
|
@@ -240,6 +241,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
|
|
|
240
241
|
"""
|
|
241
242
|
return [self.input_docs_key, self.question_key]
|
|
242
243
|
|
|
244
|
+
@override
|
|
243
245
|
def _get_docs(
|
|
244
246
|
self,
|
|
245
247
|
inputs: dict[str, Any],
|
|
@@ -249,6 +251,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
|
|
|
249
251
|
"""Get docs to run questioning over."""
|
|
250
252
|
return inputs.pop(self.input_docs_key)
|
|
251
253
|
|
|
254
|
+
@override
|
|
252
255
|
async def _aget_docs(
|
|
253
256
|
self,
|
|
254
257
|
inputs: dict[str, Any],
|
|
@@ -33,7 +33,7 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
|
|
|
33
33
|
StuffDocumentsChain,
|
|
34
34
|
):
|
|
35
35
|
tokens = [
|
|
36
|
-
self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)
|
|
36
|
+
self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content) # noqa: SLF001
|
|
37
37
|
for doc in docs
|
|
38
38
|
]
|
|
39
39
|
token_count = sum(tokens[:num_docs])
|
|
@@ -10,6 +10,7 @@ from langchain_core.callbacks import (
|
|
|
10
10
|
from langchain_core.documents import Document
|
|
11
11
|
from langchain_core.vectorstores import VectorStore
|
|
12
12
|
from pydantic import Field, model_validator
|
|
13
|
+
from typing_extensions import override
|
|
13
14
|
|
|
14
15
|
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
|
15
16
|
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
|
@@ -38,7 +39,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
|
|
|
38
39
|
StuffDocumentsChain,
|
|
39
40
|
):
|
|
40
41
|
tokens = [
|
|
41
|
-
self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)
|
|
42
|
+
self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content) # noqa: SLF001
|
|
42
43
|
for doc in docs
|
|
43
44
|
]
|
|
44
45
|
token_count = sum(tokens[:num_docs])
|
|
@@ -48,6 +49,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
|
|
|
48
49
|
|
|
49
50
|
return docs[:num_docs]
|
|
50
51
|
|
|
52
|
+
@override
|
|
51
53
|
def _get_docs(
|
|
52
54
|
self,
|
|
53
55
|
inputs: dict[str, Any],
|
|
@@ -73,7 +75,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
|
|
|
73
75
|
|
|
74
76
|
@model_validator(mode="before")
|
|
75
77
|
@classmethod
|
|
76
|
-
def
|
|
78
|
+
def _raise_deprecation(cls, values: dict) -> Any:
|
|
77
79
|
warnings.warn(
|
|
78
80
|
"`VectorDBQAWithSourcesChain` is deprecated - "
|
|
79
81
|
"please use `from langchain.chains import RetrievalQAWithSourcesChain`",
|
|
@@ -22,6 +22,7 @@ from langchain_core.structured_query import (
|
|
|
22
22
|
Operator,
|
|
23
23
|
StructuredQuery,
|
|
24
24
|
)
|
|
25
|
+
from typing_extensions import override
|
|
25
26
|
|
|
26
27
|
from langchain.chains.llm import LLMChain
|
|
27
28
|
from langchain.chains.query_constructor.parser import get_parser
|
|
@@ -46,6 +47,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
|
|
46
47
|
ast_parse: Callable
|
|
47
48
|
"""Callable that parses dict into internal representation of query language."""
|
|
48
49
|
|
|
50
|
+
@override
|
|
49
51
|
def parse(self, text: str) -> StructuredQuery:
|
|
50
52
|
try:
|
|
51
53
|
expected_keys = ["query", "filter"]
|
|
@@ -89,7 +91,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
|
|
89
91
|
|
|
90
92
|
def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
|
|
91
93
|
filter_directive = cast(
|
|
92
|
-
Optional[FilterDirective],
|
|
94
|
+
"Optional[FilterDirective]",
|
|
93
95
|
get_parser().parse(raw_filter),
|
|
94
96
|
)
|
|
95
97
|
return fix_filter_directive(
|
|
@@ -142,7 +144,7 @@ def fix_filter_directive(
|
|
|
142
144
|
return None
|
|
143
145
|
args = [
|
|
144
146
|
cast(
|
|
145
|
-
FilterDirective,
|
|
147
|
+
"FilterDirective",
|
|
146
148
|
fix_filter_directive(
|
|
147
149
|
arg,
|
|
148
150
|
allowed_comparators=allowed_comparators,
|
|
@@ -11,7 +11,7 @@ try:
|
|
|
11
11
|
from lark import Lark, Transformer, v_args
|
|
12
12
|
except ImportError:
|
|
13
13
|
|
|
14
|
-
def v_args(*
|
|
14
|
+
def v_args(*_: Any, **__: Any) -> Any: # type: ignore[misc]
|
|
15
15
|
"""Dummy decorator for when lark is not installed."""
|
|
16
16
|
return lambda _: None
|
|
17
17
|
|
|
@@ -83,15 +83,35 @@ class QueryTransformer(Transformer):
|
|
|
83
83
|
allowed_attributes: Optional[Sequence[str]] = None,
|
|
84
84
|
**kwargs: Any,
|
|
85
85
|
):
|
|
86
|
+
"""Initialize the QueryTransformer.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
allowed_comparators: Optional sequence of allowed comparators.
|
|
90
|
+
allowed_operators: Optional sequence of allowed operators.
|
|
91
|
+
allowed_attributes: Optional sequence of allowed attributes for comparators.
|
|
92
|
+
**kwargs: Additional keyword arguments.
|
|
93
|
+
"""
|
|
86
94
|
super().__init__(*args, **kwargs)
|
|
87
95
|
self.allowed_comparators = allowed_comparators
|
|
88
96
|
self.allowed_operators = allowed_operators
|
|
89
97
|
self.allowed_attributes = allowed_attributes
|
|
90
98
|
|
|
91
99
|
def program(self, *items: Any) -> tuple:
|
|
100
|
+
"""Transform the items into a tuple."""
|
|
92
101
|
return items
|
|
93
102
|
|
|
94
103
|
def func_call(self, func_name: Any, args: list) -> FilterDirective:
|
|
104
|
+
"""Transform a function name and args into a FilterDirective.
|
|
105
|
+
|
|
106
|
+
Args:
|
|
107
|
+
func_name: The name of the function.
|
|
108
|
+
args: The arguments passed to the function.
|
|
109
|
+
Returns:
|
|
110
|
+
FilterDirective: The filter directive.
|
|
111
|
+
Raises:
|
|
112
|
+
ValueError: If the function is a comparator and the first arg is not in the
|
|
113
|
+
allowed attributes.
|
|
114
|
+
"""
|
|
95
115
|
func = self._match_func_name(str(func_name))
|
|
96
116
|
if isinstance(func, Comparator):
|
|
97
117
|
if self.allowed_attributes and args[0] not in self.allowed_attributes:
|
|
@@ -135,26 +155,55 @@ class QueryTransformer(Transformer):
|
|
|
135
155
|
raise ValueError(msg)
|
|
136
156
|
|
|
137
157
|
def args(self, *items: Any) -> tuple:
|
|
158
|
+
"""Transforms items into a tuple.
|
|
159
|
+
|
|
160
|
+
Args:
|
|
161
|
+
items: The items to transform.
|
|
162
|
+
"""
|
|
138
163
|
return items
|
|
139
164
|
|
|
140
165
|
def false(self) -> bool:
|
|
166
|
+
"""Returns false."""
|
|
141
167
|
return False
|
|
142
168
|
|
|
143
169
|
def true(self) -> bool:
|
|
170
|
+
"""Returns true."""
|
|
144
171
|
return True
|
|
145
172
|
|
|
146
173
|
def list(self, item: Any) -> list:
|
|
174
|
+
"""Transforms an item into a list.
|
|
175
|
+
|
|
176
|
+
Args:
|
|
177
|
+
item: The item to transform.
|
|
178
|
+
"""
|
|
147
179
|
if item is None:
|
|
148
180
|
return []
|
|
149
181
|
return list(item)
|
|
150
182
|
|
|
151
183
|
def int(self, item: Any) -> int:
|
|
184
|
+
"""Transforms an item into an int.
|
|
185
|
+
|
|
186
|
+
Args:
|
|
187
|
+
item: The item to transform.
|
|
188
|
+
"""
|
|
152
189
|
return int(item)
|
|
153
190
|
|
|
154
191
|
def float(self, item: Any) -> float:
|
|
192
|
+
"""Transforms an item into a float.
|
|
193
|
+
|
|
194
|
+
Args:
|
|
195
|
+
item: The item to transform.
|
|
196
|
+
"""
|
|
155
197
|
return float(item)
|
|
156
198
|
|
|
157
199
|
def date(self, item: Any) -> ISO8601Date:
|
|
200
|
+
"""Transforms an item into a ISO8601Date object.
|
|
201
|
+
|
|
202
|
+
Args:
|
|
203
|
+
item: The item to transform.
|
|
204
|
+
Raises:
|
|
205
|
+
ValueError: If the item is not in ISO 8601 date format.
|
|
206
|
+
"""
|
|
158
207
|
item = str(item).strip("\"'")
|
|
159
208
|
try:
|
|
160
209
|
datetime.datetime.strptime(item, "%Y-%m-%d") # noqa: DTZ007
|
|
@@ -167,6 +216,13 @@ class QueryTransformer(Transformer):
|
|
|
167
216
|
return {"date": item, "type": "date"}
|
|
168
217
|
|
|
169
218
|
def datetime(self, item: Any) -> ISO8601DateTime:
|
|
219
|
+
"""Transforms an item into a ISO8601DateTime object.
|
|
220
|
+
|
|
221
|
+
Args:
|
|
222
|
+
item: The item to transform.
|
|
223
|
+
Raises:
|
|
224
|
+
ValueError: If the item is not in ISO 8601 datetime format.
|
|
225
|
+
"""
|
|
170
226
|
item = str(item).strip("\"'")
|
|
171
227
|
try:
|
|
172
228
|
# Parse full ISO 8601 datetime format
|
|
@@ -180,7 +236,13 @@ class QueryTransformer(Transformer):
|
|
|
180
236
|
return {"datetime": item, "type": "datetime"}
|
|
181
237
|
|
|
182
238
|
def string(self, item: Any) -> str:
|
|
183
|
-
|
|
239
|
+
"""Transforms an item into a string.
|
|
240
|
+
|
|
241
|
+
Removes escaped quotes.
|
|
242
|
+
|
|
243
|
+
Args:
|
|
244
|
+
item: The item to transform.
|
|
245
|
+
"""
|
|
184
246
|
return str(item).strip("\"'")
|
|
185
247
|
|
|
186
248
|
|
|
@@ -18,6 +18,7 @@ from langchain_core.prompts import PromptTemplate
|
|
|
18
18
|
from langchain_core.retrievers import BaseRetriever
|
|
19
19
|
from langchain_core.vectorstores import VectorStore
|
|
20
20
|
from pydantic import ConfigDict, Field, model_validator
|
|
21
|
+
from typing_extensions import override
|
|
21
22
|
|
|
22
23
|
from langchain.chains.base import Chain
|
|
23
24
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
@@ -146,6 +147,7 @@ class BaseRetrievalQA(Chain):
|
|
|
146
147
|
|
|
147
148
|
res = indexqa({'query': 'This is my query'})
|
|
148
149
|
answer, docs = res['result'], res['source_documents']
|
|
150
|
+
|
|
149
151
|
"""
|
|
150
152
|
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
|
151
153
|
question = inputs[self.input_key]
|
|
@@ -190,6 +192,7 @@ class BaseRetrievalQA(Chain):
|
|
|
190
192
|
|
|
191
193
|
res = indexqa({'query': 'This is my query'})
|
|
192
194
|
answer, docs = res['result'], res['source_documents']
|
|
195
|
+
|
|
193
196
|
"""
|
|
194
197
|
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
|
195
198
|
question = inputs[self.input_key]
|
|
@@ -330,6 +333,7 @@ class VectorDBQA(BaseRetrievalQA):
|
|
|
330
333
|
raise ValueError(msg)
|
|
331
334
|
return values
|
|
332
335
|
|
|
336
|
+
@override
|
|
333
337
|
def _get_docs(
|
|
334
338
|
self,
|
|
335
339
|
question: str,
|
langchain/chains/router/base.py
CHANGED
|
@@ -12,11 +12,14 @@ from langchain_core.callbacks import (
|
|
|
12
12
|
Callbacks,
|
|
13
13
|
)
|
|
14
14
|
from pydantic import ConfigDict
|
|
15
|
+
from typing_extensions import override
|
|
15
16
|
|
|
16
17
|
from langchain.chains.base import Chain
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class Route(NamedTuple):
|
|
21
|
+
"""A route to a destination chain."""
|
|
22
|
+
|
|
20
23
|
destination: Optional[str]
|
|
21
24
|
next_inputs: dict[str, Any]
|
|
22
25
|
|
|
@@ -25,12 +28,12 @@ class RouterChain(Chain, ABC):
|
|
|
25
28
|
"""Chain that outputs the name of a destination chain and the inputs to it."""
|
|
26
29
|
|
|
27
30
|
@property
|
|
31
|
+
@override
|
|
28
32
|
def output_keys(self) -> list[str]:
|
|
29
33
|
return ["destination", "next_inputs"]
|
|
30
34
|
|
|
31
35
|
def route(self, inputs: dict[str, Any], callbacks: Callbacks = None) -> Route:
|
|
32
|
-
"""
|
|
33
|
-
Route inputs to a destination chain.
|
|
36
|
+
"""Route inputs to a destination chain.
|
|
34
37
|
|
|
35
38
|
Args:
|
|
36
39
|
inputs: inputs to the chain
|
|
@@ -47,6 +50,15 @@ class RouterChain(Chain, ABC):
|
|
|
47
50
|
inputs: dict[str, Any],
|
|
48
51
|
callbacks: Callbacks = None,
|
|
49
52
|
) -> Route:
|
|
53
|
+
"""Route inputs to a destination chain.
|
|
54
|
+
|
|
55
|
+
Args:
|
|
56
|
+
inputs: inputs to the chain
|
|
57
|
+
callbacks: callbacks to use for the chain
|
|
58
|
+
|
|
59
|
+
Returns:
|
|
60
|
+
a Route object
|
|
61
|
+
"""
|
|
50
62
|
result = await self.acall(inputs, callbacks=callbacks)
|
|
51
63
|
return Route(result["destination"], result["next_inputs"])
|
|
52
64
|
|
|
@@ -11,6 +11,7 @@ from langchain_core.documents import Document
|
|
|
11
11
|
from langchain_core.embeddings import Embeddings
|
|
12
12
|
from langchain_core.vectorstores import VectorStore
|
|
13
13
|
from pydantic import ConfigDict
|
|
14
|
+
from typing_extensions import override
|
|
14
15
|
|
|
15
16
|
from langchain.chains.router.base import RouterChain
|
|
16
17
|
|
|
@@ -34,6 +35,7 @@ class EmbeddingRouterChain(RouterChain):
|
|
|
34
35
|
"""
|
|
35
36
|
return self.routing_keys
|
|
36
37
|
|
|
38
|
+
@override
|
|
37
39
|
def _call(
|
|
38
40
|
self,
|
|
39
41
|
inputs: dict[str, Any],
|
|
@@ -43,6 +45,7 @@ class EmbeddingRouterChain(RouterChain):
|
|
|
43
45
|
results = self.vectorstore.similarity_search(_input, k=1)
|
|
44
46
|
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
|
|
45
47
|
|
|
48
|
+
@override
|
|
46
49
|
async def _acall(
|
|
47
50
|
self,
|
|
48
51
|
inputs: dict[str, Any],
|
|
@@ -15,7 +15,7 @@ from langchain_core.output_parsers import BaseOutputParser
|
|
|
15
15
|
from langchain_core.prompts import BasePromptTemplate
|
|
16
16
|
from langchain_core.utils.json import parse_and_check_json_markdown
|
|
17
17
|
from pydantic import model_validator
|
|
18
|
-
from typing_extensions import Self
|
|
18
|
+
from typing_extensions import Self, override
|
|
19
19
|
|
|
20
20
|
from langchain.chains import LLMChain
|
|
21
21
|
from langchain.chains.router.base import RouterChain
|
|
@@ -96,13 +96,14 @@ class LLMRouterChain(RouterChain):
|
|
|
96
96
|
)
|
|
97
97
|
|
|
98
98
|
chain.invoke({"query": "what color are carrots"})
|
|
99
|
+
|
|
99
100
|
""" # noqa: E501
|
|
100
101
|
|
|
101
102
|
llm_chain: LLMChain
|
|
102
103
|
"""LLM chain used to perform routing"""
|
|
103
104
|
|
|
104
105
|
@model_validator(mode="after")
|
|
105
|
-
def
|
|
106
|
+
def _validate_prompt(self) -> Self:
|
|
106
107
|
prompt = self.llm_chain.prompt
|
|
107
108
|
if prompt.output_parser is None:
|
|
108
109
|
msg = (
|
|
@@ -137,7 +138,7 @@ class LLMRouterChain(RouterChain):
|
|
|
137
138
|
|
|
138
139
|
prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
|
|
139
140
|
return cast(
|
|
140
|
-
dict[str, Any],
|
|
141
|
+
"dict[str, Any]",
|
|
141
142
|
self.llm_chain.prompt.output_parser.parse(prediction),
|
|
142
143
|
)
|
|
143
144
|
|
|
@@ -149,7 +150,7 @@ class LLMRouterChain(RouterChain):
|
|
|
149
150
|
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
|
150
151
|
callbacks = _run_manager.get_child()
|
|
151
152
|
return cast(
|
|
152
|
-
dict[str, Any],
|
|
153
|
+
"dict[str, Any]",
|
|
153
154
|
await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
|
|
154
155
|
)
|
|
155
156
|
|
|
@@ -172,6 +173,7 @@ class RouterOutputParser(BaseOutputParser[dict[str, str]]):
|
|
|
172
173
|
next_inputs_type: type = str
|
|
173
174
|
next_inputs_inner_key: str = "input"
|
|
174
175
|
|
|
176
|
+
@override
|
|
175
177
|
def parse(self, text: str) -> dict[str, Any]:
|
|
176
178
|
try:
|
|
177
179
|
expected_keys = ["destination", "next_inputs"]
|
|
@@ -7,6 +7,7 @@ from typing import Any, Optional
|
|
|
7
7
|
from langchain_core._api import deprecated
|
|
8
8
|
from langchain_core.language_models import BaseLanguageModel
|
|
9
9
|
from langchain_core.prompts import PromptTemplate
|
|
10
|
+
from typing_extensions import override
|
|
10
11
|
|
|
11
12
|
from langchain.chains import ConversationChain
|
|
12
13
|
from langchain.chains.base import Chain
|
|
@@ -139,9 +140,11 @@ class MultiPromptChain(MultiRouteChain):
|
|
|
139
140
|
result = await app.ainvoke({"query": "what color are carrots"})
|
|
140
141
|
print(result["destination"])
|
|
141
142
|
print(result["answer"])
|
|
143
|
+
|
|
142
144
|
""" # noqa: E501
|
|
143
145
|
|
|
144
146
|
@property
|
|
147
|
+
@override
|
|
145
148
|
def output_keys(self) -> list[str]:
|
|
146
149
|
return ["text"]
|
|
147
150
|
|
|
@@ -8,6 +8,7 @@ from typing import Any, Optional
|
|
|
8
8
|
from langchain_core.language_models import BaseLanguageModel
|
|
9
9
|
from langchain_core.prompts import PromptTemplate
|
|
10
10
|
from langchain_core.retrievers import BaseRetriever
|
|
11
|
+
from typing_extensions import override
|
|
11
12
|
|
|
12
13
|
from langchain.chains import ConversationChain
|
|
13
14
|
from langchain.chains.base import Chain
|
|
@@ -32,6 +33,7 @@ class MultiRetrievalQAChain(MultiRouteChain):
|
|
|
32
33
|
"""Default chain to use when router doesn't map input to one of the destinations."""
|
|
33
34
|
|
|
34
35
|
@property
|
|
36
|
+
@override
|
|
35
37
|
def output_keys(self) -> list[str]:
|
|
36
38
|
return ["result"]
|
|
37
39
|
|
|
@@ -47,6 +49,22 @@ class MultiRetrievalQAChain(MultiRouteChain):
|
|
|
47
49
|
default_chain_llm: Optional[BaseLanguageModel] = None,
|
|
48
50
|
**kwargs: Any,
|
|
49
51
|
) -> MultiRetrievalQAChain:
|
|
52
|
+
"""Create a multi retrieval qa chain from an LLM and a default chain.
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
llm: The language model to use.
|
|
56
|
+
retriever_infos: Dictionaries containing retriever information.
|
|
57
|
+
default_retriever: Optional default retriever to use if no default chain
|
|
58
|
+
is provided.
|
|
59
|
+
default_prompt: Optional prompt template to use for the default retriever.
|
|
60
|
+
default_chain: Optional default chain to use when router doesn't map input
|
|
61
|
+
to one of the destinations.
|
|
62
|
+
default_chain_llm: Optional language model to use if no default chain and
|
|
63
|
+
no default retriever are provided.
|
|
64
|
+
**kwargs: Additional keyword arguments to pass to the chain.
|
|
65
|
+
Returns:
|
|
66
|
+
An instance of the multi retrieval qa chain.
|
|
67
|
+
"""
|
|
50
68
|
if default_prompt and not default_retriever:
|
|
51
69
|
msg = (
|
|
52
70
|
"`default_retriever` must be specified if `default_prompt` is "
|
|
@@ -132,6 +132,7 @@ def create_openai_fn_runnable(
|
|
|
132
132
|
structured_llm = create_openai_fn_runnable([RecordPerson, RecordDog], llm)
|
|
133
133
|
structured_llm.invoke("Harry was a chubby brown beagle who loved chicken)
|
|
134
134
|
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
|
|
135
|
+
|
|
135
136
|
""" # noqa: E501
|
|
136
137
|
if not functions:
|
|
137
138
|
msg = "Need to pass in at least one function. Received zero."
|
|
@@ -390,6 +391,7 @@ def create_structured_output_runnable(
|
|
|
390
391
|
)
|
|
391
392
|
chain = prompt | structured_llm
|
|
392
393
|
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
|
|
394
|
+
|
|
393
395
|
""" # noqa: E501
|
|
394
396
|
# for backwards compatibility
|
|
395
397
|
force_function_usage = kwargs.get(
|