langchain 0.3.23__py3-none-any.whl → 0.3.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/_api/module_import.py +3 -3
- langchain/agents/agent.py +104 -109
- langchain/agents/agent_iterator.py +11 -15
- langchain/agents/agent_toolkits/conversational_retrieval/openai_functions.py +2 -2
- langchain/agents/agent_toolkits/vectorstore/base.py +3 -3
- langchain/agents/agent_toolkits/vectorstore/toolkit.py +4 -6
- langchain/agents/chat/base.py +7 -6
- langchain/agents/chat/output_parser.py +2 -1
- langchain/agents/conversational/base.py +5 -4
- langchain/agents/conversational_chat/base.py +9 -8
- langchain/agents/format_scratchpad/log.py +1 -3
- langchain/agents/format_scratchpad/log_to_messages.py +3 -5
- langchain/agents/format_scratchpad/openai_functions.py +4 -4
- langchain/agents/format_scratchpad/tools.py +3 -3
- langchain/agents/format_scratchpad/xml.py +1 -3
- langchain/agents/initialize.py +2 -1
- langchain/agents/json_chat/base.py +3 -2
- langchain/agents/loading.py +5 -5
- langchain/agents/mrkl/base.py +6 -5
- langchain/agents/openai_assistant/base.py +13 -17
- langchain/agents/openai_functions_agent/agent_token_buffer_memory.py +6 -6
- langchain/agents/openai_functions_agent/base.py +13 -12
- langchain/agents/openai_functions_multi_agent/base.py +15 -14
- langchain/agents/openai_tools/base.py +2 -1
- langchain/agents/output_parsers/openai_functions.py +2 -2
- langchain/agents/output_parsers/openai_tools.py +6 -6
- langchain/agents/output_parsers/react_json_single_input.py +2 -1
- langchain/agents/output_parsers/self_ask.py +2 -1
- langchain/agents/output_parsers/tools.py +7 -7
- langchain/agents/react/agent.py +3 -2
- langchain/agents/react/base.py +4 -3
- langchain/agents/schema.py +3 -3
- langchain/agents/self_ask_with_search/base.py +2 -1
- langchain/agents/structured_chat/base.py +9 -8
- langchain/agents/structured_chat/output_parser.py +2 -1
- langchain/agents/tool_calling_agent/base.py +3 -2
- langchain/agents/tools.py +4 -4
- langchain/agents/types.py +3 -3
- langchain/agents/utils.py +1 -1
- langchain/agents/xml/base.py +7 -6
- langchain/callbacks/streaming_aiter.py +3 -2
- langchain/callbacks/streaming_aiter_final_only.py +3 -3
- langchain/callbacks/streaming_stdout_final_only.py +3 -3
- langchain/chains/api/base.py +11 -12
- langchain/chains/base.py +47 -50
- langchain/chains/combine_documents/base.py +23 -23
- langchain/chains/combine_documents/map_reduce.py +12 -12
- langchain/chains/combine_documents/map_rerank.py +16 -15
- langchain/chains/combine_documents/reduce.py +17 -17
- langchain/chains/combine_documents/refine.py +12 -12
- langchain/chains/combine_documents/stuff.py +10 -10
- langchain/chains/constitutional_ai/base.py +9 -9
- langchain/chains/conversation/base.py +2 -4
- langchain/chains/conversational_retrieval/base.py +30 -30
- langchain/chains/elasticsearch_database/base.py +13 -13
- langchain/chains/example_generator.py +1 -3
- langchain/chains/flare/base.py +13 -12
- langchain/chains/flare/prompts.py +2 -4
- langchain/chains/hyde/base.py +8 -8
- langchain/chains/llm.py +31 -30
- langchain/chains/llm_checker/base.py +6 -6
- langchain/chains/llm_math/base.py +10 -10
- langchain/chains/llm_summarization_checker/base.py +6 -6
- langchain/chains/loading.py +12 -14
- langchain/chains/mapreduce.py +7 -6
- langchain/chains/moderation.py +8 -8
- langchain/chains/natbot/base.py +6 -6
- langchain/chains/openai_functions/base.py +8 -10
- langchain/chains/openai_functions/citation_fuzzy_match.py +4 -4
- langchain/chains/openai_functions/extraction.py +3 -3
- langchain/chains/openai_functions/openapi.py +12 -12
- langchain/chains/openai_functions/qa_with_structure.py +4 -4
- langchain/chains/openai_functions/utils.py +2 -2
- langchain/chains/openai_tools/extraction.py +2 -2
- langchain/chains/prompt_selector.py +3 -3
- langchain/chains/qa_generation/base.py +5 -5
- langchain/chains/qa_with_sources/base.py +21 -21
- langchain/chains/qa_with_sources/loading.py +2 -1
- langchain/chains/qa_with_sources/retrieval.py +6 -6
- langchain/chains/qa_with_sources/vector_db.py +8 -8
- langchain/chains/query_constructor/base.py +4 -3
- langchain/chains/query_constructor/parser.py +5 -4
- langchain/chains/question_answering/chain.py +3 -2
- langchain/chains/retrieval.py +2 -2
- langchain/chains/retrieval_qa/base.py +16 -16
- langchain/chains/router/base.py +12 -11
- langchain/chains/router/embedding_router.py +12 -11
- langchain/chains/router/llm_router.py +12 -12
- langchain/chains/router/multi_prompt.py +3 -3
- langchain/chains/router/multi_retrieval_qa.py +5 -4
- langchain/chains/sequential.py +18 -18
- langchain/chains/sql_database/query.py +4 -4
- langchain/chains/structured_output/base.py +14 -13
- langchain/chains/summarize/chain.py +4 -3
- langchain/chains/transform.py +12 -11
- langchain/chat_models/base.py +27 -31
- langchain/embeddings/__init__.py +1 -1
- langchain/embeddings/base.py +4 -4
- langchain/embeddings/cache.py +19 -18
- langchain/evaluation/agents/trajectory_eval_chain.py +16 -19
- langchain/evaluation/comparison/eval_chain.py +10 -10
- langchain/evaluation/criteria/eval_chain.py +11 -10
- langchain/evaluation/embedding_distance/base.py +21 -21
- langchain/evaluation/exact_match/base.py +3 -3
- langchain/evaluation/loading.py +7 -8
- langchain/evaluation/qa/eval_chain.py +7 -6
- langchain/evaluation/regex_match/base.py +3 -3
- langchain/evaluation/schema.py +6 -5
- langchain/evaluation/scoring/eval_chain.py +9 -9
- langchain/evaluation/string_distance/base.py +23 -23
- langchain/hub.py +2 -1
- langchain/indexes/_sql_record_manager.py +8 -7
- langchain/indexes/vectorstore.py +11 -11
- langchain/llms/__init__.py +3 -3
- langchain/memory/buffer.py +13 -13
- langchain/memory/buffer_window.py +5 -5
- langchain/memory/chat_memory.py +5 -5
- langchain/memory/combined.py +10 -10
- langchain/memory/entity.py +8 -7
- langchain/memory/readonly.py +4 -4
- langchain/memory/simple.py +5 -5
- langchain/memory/summary.py +8 -8
- langchain/memory/summary_buffer.py +11 -11
- langchain/memory/token_buffer.py +5 -5
- langchain/memory/utils.py +2 -2
- langchain/memory/vectorstore.py +15 -14
- langchain/memory/vectorstore_token_buffer_memory.py +7 -7
- langchain/model_laboratory.py +4 -3
- langchain/output_parsers/combining.py +5 -5
- langchain/output_parsers/datetime.py +1 -2
- langchain/output_parsers/enum.py +4 -5
- langchain/output_parsers/pandas_dataframe.py +5 -5
- langchain/output_parsers/regex.py +4 -4
- langchain/output_parsers/regex_dict.py +4 -4
- langchain/output_parsers/retry.py +2 -2
- langchain/output_parsers/structured.py +5 -5
- langchain/output_parsers/yaml.py +3 -3
- langchain/pydantic_v1/__init__.py +1 -6
- langchain/pydantic_v1/dataclasses.py +1 -5
- langchain/pydantic_v1/main.py +1 -5
- langchain/retrievers/contextual_compression.py +3 -3
- langchain/retrievers/document_compressors/base.py +3 -2
- langchain/retrievers/document_compressors/chain_extract.py +4 -3
- langchain/retrievers/document_compressors/chain_filter.py +3 -2
- langchain/retrievers/document_compressors/cohere_rerank.py +4 -3
- langchain/retrievers/document_compressors/cross_encoder.py +1 -2
- langchain/retrievers/document_compressors/cross_encoder_rerank.py +2 -1
- langchain/retrievers/document_compressors/embeddings_filter.py +3 -2
- langchain/retrievers/document_compressors/listwise_rerank.py +6 -5
- langchain/retrievers/ensemble.py +15 -19
- langchain/retrievers/merger_retriever.py +7 -12
- langchain/retrievers/multi_query.py +14 -13
- langchain/retrievers/multi_vector.py +4 -4
- langchain/retrievers/parent_document_retriever.py +9 -8
- langchain/retrievers/re_phraser.py +2 -3
- langchain/retrievers/self_query/base.py +13 -12
- langchain/retrievers/time_weighted_retriever.py +14 -14
- langchain/runnables/openai_functions.py +4 -3
- langchain/smith/evaluation/config.py +7 -6
- langchain/smith/evaluation/progress.py +3 -2
- langchain/smith/evaluation/runner_utils.py +58 -61
- langchain/smith/evaluation/string_run_evaluator.py +29 -29
- langchain/storage/encoder_backed.py +7 -11
- langchain/storage/file_system.py +5 -4
- {langchain-0.3.23.dist-info → langchain-0.3.24.dist-info}/METADATA +2 -2
- {langchain-0.3.23.dist-info → langchain-0.3.24.dist-info}/RECORD +169 -169
- {langchain-0.3.23.dist-info → langchain-0.3.24.dist-info}/WHEEL +1 -1
- langchain-0.3.24.dist-info/entry_points.txt +4 -0
- langchain-0.3.23.dist-info/entry_points.txt +0 -5
- {langchain-0.3.23.dist-info → langchain-0.3.24.dist-info}/licenses/LICENSE +0 -0
langchain/callbacks/streaming_aiter.py
CHANGED

@@ -1,7 +1,8 @@
 from __future__ import annotations

 import asyncio
-from
+from collections.abc import AsyncIterator
+from typing import Any, Literal, Union, cast

 from langchain_core.callbacks import AsyncCallbackHandler
 from langchain_core.outputs import LLMResult
@@ -25,7 +26,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
         self.done = asyncio.Event()

     async def on_llm_start(
-        self, serialized:
+        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
     ) -> None:
         # If two calls are made in a row, this resets the state
         self.done.clear()
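Across these callback files (and most of the files listed above), the visible pattern is a typing cleanup: the removed import and signature lines are truncated in this diff view, but every added line annotates with the built-in generics (`dict[str, Any]`, `list[str]`, `Optional[list[str]]`) and imports iterator types from `collections.abc`. A minimal, standalone sketch of that annotation style (illustrative only, not langchain code; the function names here are made up):

from collections.abc import AsyncIterator
from typing import Any, Optional


async def aiter_tokens(tokens: list[str]) -> AsyncIterator[str]:
    # AsyncIterator now comes from collections.abc rather than typing
    for token in tokens:
        yield token


async def on_llm_start_style(
    serialized: dict[str, Any],
    prompts: list[str],
    answer_prefix_tokens: Optional[list[str]] = None,
) -> None:
    # Mirrors the parameter annotations in the added lines above: built-in
    # dict/list generics (PEP 585 style) are used directly instead of
    # capitalized aliases imported from typing.
    ...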
langchain/callbacks/streaming_aiter_final_only.py
CHANGED

@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Any,
+from typing import Any, Optional

 from langchain_core.outputs import LLMResult

@@ -30,7 +30,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     def __init__(
         self,
         *,
-        answer_prefix_tokens: Optional[
+        answer_prefix_tokens: Optional[list[str]] = None,
         strip_tokens: bool = True,
         stream_prefix: bool = False,
     ) -> None:
@@ -62,7 +62,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         self.answer_reached = False

     async def on_llm_start(
-        self, serialized:
+        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
     ) -> None:
         # If two calls are made in a row, this resets the state
         self.done.clear()
langchain/callbacks/streaming_stdout_final_only.py
CHANGED

@@ -1,7 +1,7 @@
 """Callback Handler streams to stdout on new llm token."""

 import sys
-from typing import Any,
+from typing import Any, Optional

 from langchain_core.callbacks import StreamingStdOutCallbackHandler

@@ -31,7 +31,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
     def __init__(
         self,
         *,
-        answer_prefix_tokens: Optional[
+        answer_prefix_tokens: Optional[list[str]] = None,
         strip_tokens: bool = True,
         stream_prefix: bool = False,
     ) -> None:
@@ -63,7 +63,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
         self.answer_reached = False

     def on_llm_start(
-        self, serialized:
+        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
     ) -> None:
         """Run when LLM starts running."""
         self.answer_reached = False
langchain/chains/api/base.py
CHANGED

@@ -2,7 +2,8 @@

 from __future__ import annotations

-from
+from collections.abc import Sequence
+from typing import Any, Optional
 from urllib.parse import urlparse

 from langchain_core._api import deprecated
@@ -20,7 +21,7 @@ from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain


-def _extract_scheme_and_domain(url: str) ->
+def _extract_scheme_and_domain(url: str) -> tuple[str, str]:
     """Extract the scheme + domain from a given URL.

     Args:
@@ -198,9 +199,7 @@ try:
        api_docs: str
        question_key: str = "question" #: :meta private:
        output_key: str = "output" #: :meta private:
-        limit_to_domains: Optional[Sequence[str]] = Field(
-            default_factory=list  # type: ignore
-        )
+        limit_to_domains: Optional[Sequence[str]] = Field(default_factory=list) # type: ignore[arg-type]
        """Use to limit the domains that can be accessed by the API chain.

        * For example, to limit to just the domain `https://www.example.com`, set
@@ -215,7 +214,7 @@ try:
        """

        @property
-        def input_keys(self) ->
+        def input_keys(self) -> list[str]:
            """Expect input key.

            :meta private:
@@ -223,7 +222,7 @@ try:
            return [self.question_key]

        @property
-        def output_keys(self) ->
+        def output_keys(self) -> list[str]:
            """Expect output key.

            :meta private:
@@ -243,7 +242,7 @@ try:

        @model_validator(mode="before")
        @classmethod
-        def validate_limit_to_domains(cls, values:
+        def validate_limit_to_domains(cls, values: dict) -> Any:
            """Check that allowed domains are valid."""
            # This check must be a pre=True check, so that a default of None
            # won't be set to limit_to_domains if it's not provided.
@@ -275,9 +274,9 @@ try:

        def _call(
            self,
-            inputs:
+            inputs: dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
-        ) ->
+        ) -> dict[str, str]:
            _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
            question = inputs[self.question_key]
            api_url = self.api_request_chain.predict(
@@ -308,9 +307,9 @@ try:

        async def _acall(
            self,
-            inputs:
+            inputs: dict[str, Any],
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-        ) ->
+        ) -> dict[str, str]:
            _run_manager = (
                run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
            )
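The api chain diff adds a concrete `tuple[str, str]` return type to `_extract_scheme_and_domain` and keeps the `limit_to_domains` allow-list. As a rough illustration of how a helper with that signature can pair with such an allow-list, here is a hedged sketch built only on `urllib.parse`; `is_domain_allowed` and its matching rule are invented for this example and are not langchain's actual check:

from collections.abc import Sequence
from typing import Optional
from urllib.parse import urlparse


def _extract_scheme_and_domain(url: str) -> tuple[str, str]:
    # Split a URL into its scheme and network location.
    parsed = urlparse(url)
    return parsed.scheme, parsed.netloc


def is_domain_allowed(url: str, limit_to_domains: Optional[Sequence[str]]) -> bool:
    # Hypothetical check: None means "no restriction"; otherwise the URL must
    # match an entry given either as "example.com" or "https://example.com".
    if limit_to_domains is None:
        return True
    scheme, domain = _extract_scheme_and_domain(url)
    return any(entry in (domain, f"{scheme}://{domain}") for entry in limit_to_domains)


print(_extract_scheme_and_domain("https://www.example.com/docs"))
# ('https', 'www.example.com')
print(is_domain_allowed("https://www.example.com/docs", ["https://www.example.com"]))
# True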
langchain/chains/base.py
CHANGED

@@ -1,12 +1,13 @@
 """Base interface that all chains should implement."""

+import builtins
 import inspect
 import json
 import logging
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any,
+from typing import Any, Optional, Union, cast

 import yaml
 from langchain_core._api import deprecated
@@ -46,7 +47,7 @@ def _get_verbosity() -> bool:
     return get_verbose()


-class Chain(RunnableSerializable[
+class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
     """Abstract base class for creating structured sequences of calls to components.

     Chains should be used to encode a sequence of calls to components like
@@ -86,13 +87,13 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
     """Whether or not run in verbose mode. In verbose mode, some intermediate logs
     will be printed to the console. Defaults to the global `verbose` value,
     accessible via `langchain.globals.get_verbose()`."""
-    tags: Optional[
+    tags: Optional[list[str]] = None
     """Optional list of tags associated with the chain. Defaults to None.
     These tags will be associated with each call to this chain,
     and passed as arguments to the handlers defined in `callbacks`.
     You can use these to eg identify a specific instance of a chain with its use case.
     """
-    metadata: Optional[
+    metadata: Optional[dict[str, Any]] = None
     """Optional metadata associated with the chain. Defaults to None.
     This metadata will be associated with each call to this chain,
     and passed as arguments to the handlers defined in `callbacks`.
@@ -107,26 +108,22 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

     def get_input_schema(
         self, config: Optional[RunnableConfig] = None
-    ) ->
+    ) -> type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
-        return create_model(
-            "ChainInput", **{k: (Any, None) for k in self.input_keys}
-        )
+        return create_model("ChainInput", **{k: (Any, None) for k in self.input_keys})

     def get_output_schema(
         self, config: Optional[RunnableConfig] = None
-    ) ->
+    ) -> type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
-        return create_model(
-            "ChainOutput", **{k: (Any, None) for k in self.output_keys}
-        )
+        return create_model("ChainOutput", **{k: (Any, None) for k in self.output_keys})

     def invoke(
         self,
-        input:
+        input: dict[str, Any],
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
-    ) ->
+    ) -> dict[str, Any]:
         config = ensure_config(config)
         callbacks = config.get("callbacks")
         tags = config.get("tags")
@@ -162,7 +159,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
                 else self._call(inputs)
             )

-            final_outputs:
+            final_outputs: dict[str, Any] = self.prep_outputs(
                 inputs, outputs, return_only_outputs
             )
         except BaseException as e:
@@ -176,10 +173,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

     async def ainvoke(
         self,
-        input:
+        input: dict[str, Any],
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
-    ) ->
+    ) -> dict[str, Any]:
         config = ensure_config(config)
         callbacks = config.get("callbacks")
         tags = config.get("tags")
@@ -213,7 +210,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
                 if new_arg_supported
                 else await self._acall(inputs)
             )
-            final_outputs:
+            final_outputs: dict[str, Any] = await self.aprep_outputs(
                 inputs, outputs, return_only_outputs
             )
         except BaseException as e:
@@ -231,7 +228,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

     @model_validator(mode="before")
     @classmethod
-    def raise_callback_manager_deprecation(cls, values:
+    def raise_callback_manager_deprecation(cls, values: dict) -> Any:
         """Raise deprecation warning if callback_manager is used."""
         if values.get("callback_manager") is not None:
             if values.get("callbacks") is not None:
@@ -261,15 +258,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

     @property
     @abstractmethod
-    def input_keys(self) ->
+    def input_keys(self) -> list[str]:
         """Keys expected to be in the chain input."""

     @property
     @abstractmethod
-    def output_keys(self) ->
+    def output_keys(self) -> list[str]:
         """Keys expected to be in the chain output."""

-    def _validate_inputs(self, inputs:
+    def _validate_inputs(self, inputs: dict[str, Any]) -> None:
         """Check that all inputs are present."""
         if not isinstance(inputs, dict):
             _input_keys = set(self.input_keys)
@@ -289,7 +286,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         if missing_keys:
             raise ValueError(f"Missing some input keys: {missing_keys}")

-    def _validate_outputs(self, outputs:
+    def _validate_outputs(self, outputs: dict[str, Any]) -> None:
         missing_keys = set(self.output_keys).difference(outputs)
         if missing_keys:
             raise ValueError(f"Missing some output keys: {missing_keys}")
@@ -297,9 +294,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
     @abstractmethod
     def _call(
         self,
-        inputs:
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) ->
+    ) -> dict[str, Any]:
         """Execute the chain.

         This is a private method that is not user-facing. It is only called within
@@ -319,9 +316,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

     async def _acall(
         self,
-        inputs:
+        inputs: dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) ->
+    ) -> dict[str, Any]:
         """Asynchronously execute the chain.

         This is a private method that is not user-facing. It is only called within
@@ -345,15 +342,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
     @deprecated("0.1.0", alternative="invoke", removal="1.0")
     def __call__(
         self,
-        inputs: Union[
+        inputs: Union[dict[str, Any], Any],
         return_only_outputs: bool = False,
         callbacks: Callbacks = None,
         *,
-        tags: Optional[
-        metadata: Optional[
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         run_name: Optional[str] = None,
         include_run_info: bool = False,
-    ) ->
+    ) -> dict[str, Any]:
         """Execute the chain.

         Args:
@@ -396,15 +393,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
     @deprecated("0.1.0", alternative="ainvoke", removal="1.0")
     async def acall(
         self,
-        inputs: Union[
+        inputs: Union[dict[str, Any], Any],
         return_only_outputs: bool = False,
         callbacks: Callbacks = None,
         *,
-        tags: Optional[
-        metadata: Optional[
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         run_name: Optional[str] = None,
         include_run_info: bool = False,
-    ) ->
+    ) -> dict[str, Any]:
         """Asynchronously execute the chain.

         Args:
@@ -445,10 +442,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

     def prep_outputs(
         self,
-        inputs:
-        outputs:
+        inputs: dict[str, str],
+        outputs: dict[str, str],
         return_only_outputs: bool = False,
-    ) ->
+    ) -> dict[str, str]:
         """Validate and prepare chain outputs, and save info about this run to memory.

         Args:
@@ -471,10 +468,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

     async def aprep_outputs(
         self,
-        inputs:
-        outputs:
+        inputs: dict[str, str],
+        outputs: dict[str, str],
         return_only_outputs: bool = False,
-    ) ->
+    ) -> dict[str, str]:
         """Validate and prepare chain outputs, and save info about this run to memory.

         Args:
@@ -495,7 +492,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         else:
             return {**inputs, **outputs}

-    def prep_inputs(self, inputs: Union[
+    def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
         """Prepare chain inputs, including adding inputs from memory.

         Args:
@@ -519,7 +516,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
             inputs = dict(inputs, **external_context)
         return inputs

-    async def aprep_inputs(self, inputs: Union[
+    async def aprep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
         """Prepare chain inputs, including adding inputs from memory.

         Args:
@@ -557,8 +554,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         self,
         *args: Any,
         callbacks: Callbacks = None,
-        tags: Optional[
-        metadata: Optional[
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Any:
         """Convenience method for executing chain.
@@ -628,8 +625,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         self,
         *args: Any,
         callbacks: Callbacks = None,
-        tags: Optional[
-        metadata: Optional[
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Any:
         """Convenience method for executing chain.
@@ -695,7 +692,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
             f" but not both. Got args: {args} and kwargs: {kwargs}."
         )

-    def dict(self, **kwargs: Any) ->
+    def dict(self, **kwargs: Any) -> dict:
         """Dictionary representation of chain.

         Expects `Chain._chain_type` property to be implemented and for memory to be
@@ -763,7 +760,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

     @deprecated("0.1.0", alternative="batch", removal="1.0")
     def apply(
-        self, input_list:
-    ) ->
+        self, input_list: list[builtins.dict[str, Any]], callbacks: Callbacks = None
+    ) -> list[builtins.dict[str, str]]:
         """Call the chain on all inputs in the list."""
         return [self(inputs, callbacks=callbacks) for inputs in input_list]
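The `Chain` base class now annotates its runnable interface as `RunnableSerializable[dict[str, Any], dict[str, Any]]` and builds its input/output schemas dynamically with `create_model`, as the collapsed one-liners above show. Below is a standalone sketch of that dynamic-schema pattern, assuming plain pydantic v2 (langchain may route this through its own `create_model` wrapper; `build_chain_input_schema` is a made-up helper name):

from typing import Any

from pydantic import BaseModel, create_model


def build_chain_input_schema(input_keys: list[str]) -> type[BaseModel]:
    # Each chain input key becomes an optional field of type Any, mirroring
    # create_model("ChainInput", **{k: (Any, None) for k in self.input_keys}).
    return create_model("ChainInput", **{k: (Any, None) for k in input_keys})


InputSchema = build_chain_input_schema(["question", "context"])
print(InputSchema(question="What changed in 0.3.24?"))
# question='What changed in 0.3.24?' context=None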
langchain/chains/combine_documents/base.py
CHANGED

@@ -1,7 +1,7 @@
 """Base interface for chains combining documents."""

 from abc import ABC, abstractmethod
-from typing import Any,
+from typing import Any, Optional

 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
@@ -47,22 +47,22 @@ class BaseCombineDocumentsChain(Chain, ABC):

     def get_input_schema(
         self, config: Optional[RunnableConfig] = None
-    ) ->
+    ) -> type[BaseModel]:
         return create_model(
             "CombineDocumentsInput",
-            **{self.input_key: (
+            **{self.input_key: (list[Document], None)},
         )

     def get_output_schema(
         self, config: Optional[RunnableConfig] = None
-    ) ->
+    ) -> type[BaseModel]:
         return create_model(
             "CombineDocumentsOutput",
-            **{self.output_key: (str, None)},
+            **{self.output_key: (str, None)},
         )

     @property
-    def input_keys(self) ->
+    def input_keys(self) -> list[str]:
         """Expect input key.

         :meta private:
@@ -70,14 +70,14 @@ class BaseCombineDocumentsChain(Chain, ABC):
         return [self.input_key]

     @property
-    def output_keys(self) ->
+    def output_keys(self) -> list[str]:
         """Return output key.

         :meta private:
         """
         return [self.output_key]

-    def prompt_length(self, docs:
+    def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]:
         """Return the prompt length given the documents passed in.

         This can be used by a caller to determine whether passing in a list
@@ -96,7 +96,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
         return None

     @abstractmethod
-    def combine_docs(self, docs:
+    def combine_docs(self, docs: list[Document], **kwargs: Any) -> tuple[str, dict]:
         """Combine documents into a single string.

         Args:
@@ -111,8 +111,8 @@ class BaseCombineDocumentsChain(Chain, ABC):

     @abstractmethod
     async def acombine_docs(
-        self, docs:
-    ) ->
+        self, docs: list[Document], **kwargs: Any
+    ) -> tuple[str, dict]:
         """Combine documents into a single string.

         Args:
@@ -127,9 +127,9 @@ class BaseCombineDocumentsChain(Chain, ABC):

     def _call(
         self,
-        inputs:
+        inputs: dict[str, list[Document]],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) ->
+    ) -> dict[str, str]:
         """Prepare inputs, call combine docs, prepare outputs."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         docs = inputs[self.input_key]
@@ -143,9 +143,9 @@ class BaseCombineDocumentsChain(Chain, ABC):

     async def _acall(
         self,
-        inputs:
+        inputs: dict[str, list[Document]],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) ->
+    ) -> dict[str, str]:
         """Prepare inputs, call combine docs, prepare outputs."""
         _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         docs = inputs[self.input_key]
@@ -229,7 +229,7 @@ class AnalyzeDocumentChain(Chain):
     combine_docs_chain: BaseCombineDocumentsChain

     @property
-    def input_keys(self) ->
+    def input_keys(self) -> list[str]:
         """Expect input key.

         :meta private:
@@ -237,7 +237,7 @@ class AnalyzeDocumentChain(Chain):
         return [self.input_key]

     @property
-    def output_keys(self) ->
+    def output_keys(self) -> list[str]:
         """Return output key.

         :meta private:
@@ -246,28 +246,28 @@ class AnalyzeDocumentChain(Chain):

     def get_input_schema(
         self, config: Optional[RunnableConfig] = None
-    ) ->
+    ) -> type[BaseModel]:
         return create_model(
             "AnalyzeDocumentChain",
-            **{self.input_key: (str, None)},
+            **{self.input_key: (str, None)},
         )

     def get_output_schema(
         self, config: Optional[RunnableConfig] = None
-    ) ->
+    ) -> type[BaseModel]:
         return self.combine_docs_chain.get_output_schema(config)

     def _call(
         self,
-        inputs:
+        inputs: dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) ->
+    ) -> dict[str, str]:
         """Split document into chunks and pass to CombineDocumentsChain."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         document = inputs[self.input_key]
         docs = self.text_splitter.create_documents([document])
         # Other keys are assumed to be needed for LLM prediction
-        other_keys:
+        other_keys: dict = {k: v for k, v in inputs.items() if k != self.input_key}
         other_keys[self.combine_docs_chain.input_key] = docs
         return self.combine_docs_chain(
             other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()
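The `combine_docs` contract is now spelled as `(docs: list[Document], **kwargs) -> tuple[str, dict]`: the combined text plus a dict of extra outputs. A hedged sketch of a function honoring that shape (illustrative only; the joining logic, the `separator` parameter, and the `num_docs` key are invented for this example and are not any concrete langchain combiner):

from typing import Any

from langchain_core.documents import Document


def combine_docs(
    docs: list[Document], separator: str = "\n\n", **kwargs: Any
) -> tuple[str, dict]:
    # Join page contents and report a little metadata alongside the text.
    combined = separator.join(doc.page_content for doc in docs)
    extra_info: dict = {"num_docs": len(docs)}
    return combined, extra_info


text, info = combine_docs([Document(page_content="first"), Document(page_content="second")])
print(info)  # {'num_docs': 2}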
langchain/chains/combine_documents/map_reduce.py
CHANGED

@@ -2,7 +2,7 @@

 from __future__ import annotations

-from typing import Any,
+from typing import Any, Optional

 from langchain_core._api import deprecated
 from langchain_core.callbacks import Callbacks
@@ -113,20 +113,20 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):

     def get_output_schema(
         self, config: Optional[RunnableConfig] = None
-    ) ->
+    ) -> type[BaseModel]:
         if self.return_intermediate_steps:
             return create_model(
                 "MapReduceDocumentsOutput",
                 **{
                     self.output_key: (str, None),
-                    "intermediate_steps": (
-                },
+                    "intermediate_steps": (list[str], None),
+                },
             )

         return super().get_output_schema(config)

     @property
-    def output_keys(self) ->
+    def output_keys(self) -> list[str]:
         """Expect input key.

         :meta private:
@@ -143,7 +143,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):

     @model_validator(mode="before")
     @classmethod
-    def get_reduce_chain(cls, values:
+    def get_reduce_chain(cls, values: dict) -> Any:
         """For backwards compatibility."""
         if "combine_document_chain" in values:
             if "reduce_documents_chain" in values:
@@ -167,7 +167,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):

     @model_validator(mode="before")
     @classmethod
-    def get_return_intermediate_steps(cls, values:
+    def get_return_intermediate_steps(cls, values: dict) -> Any:
         """For backwards compatibility."""
         if "return_map_steps" in values:
             values["return_intermediate_steps"] = values["return_map_steps"]
@@ -176,7 +176,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):

     @model_validator(mode="before")
     @classmethod
-    def get_default_document_variable_name(cls, values:
+    def get_default_document_variable_name(cls, values: dict) -> Any:
         """Get default document variable name, if not provided."""
         if "llm_chain" not in values:
             raise ValueError("llm_chain must be provided")
@@ -227,11 +227,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):

     def combine_docs(
         self,
-        docs:
+        docs: list[Document],
         token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
-    ) ->
+    ) -> tuple[str, dict]:
         """Combine documents in a map reduce manner.

         Combine by mapping first chain over all documents, then reducing the results.
@@ -258,11 +258,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):

     async def acombine_docs(
         self,
-        docs:
+        docs: list[Document],
         token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
-    ) ->
+    ) -> tuple[str, dict]:
         """Combine documents in a map reduce manner.

         Combine by mapping first chain over all documents, then reducing the results.