langchain 0.2.15__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain might be problematic. Click here for more details.
- langchain/agents/agent.py +23 -19
- langchain/agents/agent_toolkits/vectorstore/toolkit.py +10 -7
- langchain/agents/chat/base.py +1 -1
- langchain/agents/conversational/base.py +1 -1
- langchain/agents/conversational_chat/base.py +1 -1
- langchain/agents/mrkl/base.py +1 -1
- langchain/agents/openai_assistant/base.py +8 -7
- langchain/agents/openai_functions_agent/base.py +6 -5
- langchain/agents/openai_functions_multi_agent/base.py +6 -5
- langchain/agents/openai_tools/base.py +8 -3
- langchain/agents/react/base.py +1 -1
- langchain/agents/self_ask_with_search/base.py +1 -1
- langchain/agents/structured_chat/base.py +1 -1
- langchain/agents/structured_chat/output_parser.py +1 -1
- langchain/chains/api/base.py +14 -12
- langchain/chains/base.py +17 -9
- langchain/chains/combine_documents/base.py +1 -1
- langchain/chains/combine_documents/map_reduce.py +14 -10
- langchain/chains/combine_documents/map_rerank.py +17 -14
- langchain/chains/combine_documents/reduce.py +5 -3
- langchain/chains/combine_documents/refine.py +11 -8
- langchain/chains/combine_documents/stuff.py +8 -6
- langchain/chains/constitutional_ai/models.py +1 -1
- langchain/chains/conversation/base.py +13 -11
- langchain/chains/conversational_retrieval/base.py +10 -8
- langchain/chains/elasticsearch_database/base.py +11 -9
- langchain/chains/flare/base.py +5 -2
- langchain/chains/hyde/base.py +6 -4
- langchain/chains/llm.py +7 -7
- langchain/chains/llm_checker/base.py +8 -6
- langchain/chains/llm_math/base.py +8 -6
- langchain/chains/llm_summarization_checker/base.py +8 -6
- langchain/chains/mapreduce.py +5 -3
- langchain/chains/moderation.py +14 -12
- langchain/chains/natbot/base.py +8 -6
- langchain/chains/openai_functions/base.py +3 -3
- langchain/chains/openai_functions/citation_fuzzy_match.py +1 -1
- langchain/chains/openai_functions/extraction.py +8 -4
- langchain/chains/openai_functions/qa_with_structure.py +5 -2
- langchain/chains/openai_functions/tagging.py +5 -2
- langchain/chains/openai_tools/extraction.py +2 -2
- langchain/chains/prompt_selector.py +1 -1
- langchain/chains/qa_generation/base.py +1 -1
- langchain/chains/qa_with_sources/base.py +8 -6
- langchain/chains/qa_with_sources/retrieval.py +1 -1
- langchain/chains/qa_with_sources/vector_db.py +4 -3
- langchain/chains/query_constructor/schema.py +5 -4
- langchain/chains/retrieval_qa/base.py +12 -9
- langchain/chains/router/base.py +5 -3
- langchain/chains/router/embedding_router.py +5 -3
- langchain/chains/router/llm_router.py +6 -5
- langchain/chains/sequential.py +17 -13
- langchain/chains/structured_output/base.py +8 -8
- langchain/chains/transform.py +1 -1
- langchain/chat_models/base.py +2 -2
- langchain/evaluation/agents/trajectory_eval_chain.py +4 -3
- langchain/evaluation/comparison/eval_chain.py +4 -3
- langchain/evaluation/criteria/eval_chain.py +4 -3
- langchain/evaluation/embedding_distance/base.py +4 -3
- langchain/evaluation/qa/eval_chain.py +7 -4
- langchain/evaluation/qa/generate_chain.py +1 -1
- langchain/evaluation/scoring/eval_chain.py +4 -3
- langchain/evaluation/string_distance/base.py +1 -1
- langchain/indexes/vectorstore.py +9 -7
- langchain/memory/chat_memory.py +1 -1
- langchain/memory/combined.py +5 -3
- langchain/memory/entity.py +4 -3
- langchain/memory/summary.py +1 -1
- langchain/memory/vectorstore.py +1 -1
- langchain/memory/vectorstore_token_buffer_memory.py +1 -1
- langchain/output_parsers/fix.py +3 -2
- langchain/output_parsers/pandas_dataframe.py +3 -2
- langchain/output_parsers/retry.py +4 -3
- langchain/output_parsers/structured.py +1 -1
- langchain/output_parsers/yaml.py +5 -2
- langchain/pydantic_v1/__init__.py +20 -0
- langchain/pydantic_v1/dataclasses.py +20 -0
- langchain/pydantic_v1/main.py +20 -0
- langchain/retrievers/contextual_compression.py +4 -2
- langchain/retrievers/document_compressors/base.py +4 -2
- langchain/retrievers/document_compressors/chain_extract.py +4 -2
- langchain/retrievers/document_compressors/chain_filter.py +4 -2
- langchain/retrievers/document_compressors/cohere_rerank.py +8 -6
- langchain/retrievers/document_compressors/cross_encoder_rerank.py +5 -3
- langchain/retrievers/document_compressors/embeddings_filter.py +5 -4
- langchain/retrievers/document_compressors/listwise_rerank.py +4 -3
- langchain/retrievers/ensemble.py +18 -14
- langchain/retrievers/multi_vector.py +5 -4
- langchain/retrievers/self_query/base.py +19 -8
- langchain/retrievers/time_weighted_retriever.py +4 -3
- langchain/smith/evaluation/config.py +7 -5
- {langchain-0.2.15.dist-info → langchain-0.3.0.dist-info}/METADATA +5 -5
- {langchain-0.2.15.dist-info → langchain-0.3.0.dist-info}/RECORD +96 -96
- {langchain-0.2.15.dist-info → langchain-0.3.0.dist-info}/LICENSE +0 -0
- {langchain-0.2.15.dist-info → langchain-0.3.0.dist-info}/WHEEL +0 -0
- {langchain-0.2.15.dist-info → langchain-0.3.0.dist-info}/entry_points.txt +0 -0
|
@@ -6,9 +6,10 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
|
|
|
6
6
|
|
|
7
7
|
from langchain_core.callbacks import Callbacks
|
|
8
8
|
from langchain_core.documents import Document
|
|
9
|
-
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
|
10
9
|
from langchain_core.runnables.config import RunnableConfig
|
|
11
10
|
from langchain_core.runnables.utils import create_model
|
|
11
|
+
from pydantic import BaseModel, ConfigDict, model_validator
|
|
12
|
+
from typing_extensions import Self
|
|
12
13
|
|
|
13
14
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
14
15
|
from langchain.chains.llm import LLMChain
|
|
@@ -74,9 +75,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
|
|
74
75
|
"""Return intermediate steps.
|
|
75
76
|
Intermediate steps include the results of calling llm_chain on each document."""
|
|
76
77
|
|
|
77
|
-
|
|
78
|
-
arbitrary_types_allowed
|
|
79
|
-
extra
|
|
78
|
+
model_config = ConfigDict(
|
|
79
|
+
arbitrary_types_allowed=True,
|
|
80
|
+
extra="forbid",
|
|
81
|
+
)
|
|
80
82
|
|
|
81
83
|
def get_output_schema(
|
|
82
84
|
self, config: Optional[RunnableConfig] = None
|
|
@@ -104,30 +106,31 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
|
|
104
106
|
_output_keys += self.metadata_keys
|
|
105
107
|
return _output_keys
|
|
106
108
|
|
|
107
|
-
@
|
|
108
|
-
def validate_llm_output(
|
|
109
|
+
@model_validator(mode="after")
|
|
110
|
+
def validate_llm_output(self) -> Self:
|
|
109
111
|
"""Validate that the combine chain outputs a dictionary."""
|
|
110
|
-
output_parser =
|
|
112
|
+
output_parser = self.llm_chain.prompt.output_parser
|
|
111
113
|
if not isinstance(output_parser, RegexParser):
|
|
112
114
|
raise ValueError(
|
|
113
115
|
"Output parser of llm_chain should be a RegexParser,"
|
|
114
116
|
f" got {output_parser}"
|
|
115
117
|
)
|
|
116
118
|
output_keys = output_parser.output_keys
|
|
117
|
-
if
|
|
119
|
+
if self.rank_key not in output_keys:
|
|
118
120
|
raise ValueError(
|
|
119
|
-
f"Got {
|
|
121
|
+
f"Got {self.rank_key} as key to rank on, but did not find "
|
|
120
122
|
f"it in the llm_chain output keys ({output_keys})"
|
|
121
123
|
)
|
|
122
|
-
if
|
|
124
|
+
if self.answer_key not in output_keys:
|
|
123
125
|
raise ValueError(
|
|
124
|
-
f"Got {
|
|
126
|
+
f"Got {self.answer_key} as key to return, but did not find "
|
|
125
127
|
f"it in the llm_chain output keys ({output_keys})"
|
|
126
128
|
)
|
|
127
|
-
return
|
|
129
|
+
return self
|
|
128
130
|
|
|
129
|
-
@
|
|
130
|
-
|
|
131
|
+
@model_validator(mode="before")
|
|
132
|
+
@classmethod
|
|
133
|
+
def get_default_document_variable_name(cls, values: Dict) -> Any:
|
|
131
134
|
"""Get default document variable name, if not provided."""
|
|
132
135
|
if "llm_chain" not in values:
|
|
133
136
|
raise ValueError("llm_chain must be provided")
|
|
@@ -6,6 +6,7 @@ from typing import Any, Callable, List, Optional, Protocol, Tuple
|
|
|
6
6
|
|
|
7
7
|
from langchain_core.callbacks import Callbacks
|
|
8
8
|
from langchain_core.documents import Document
|
|
9
|
+
from pydantic import ConfigDict
|
|
9
10
|
|
|
10
11
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
11
12
|
|
|
@@ -204,9 +205,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
|
|
|
204
205
|
If None, it will keep trying to collapse documents to fit token_max.
|
|
205
206
|
Otherwise, after it reaches the max number, it will throw an error"""
|
|
206
207
|
|
|
207
|
-
|
|
208
|
-
arbitrary_types_allowed
|
|
209
|
-
extra
|
|
208
|
+
model_config = ConfigDict(
|
|
209
|
+
arbitrary_types_allowed=True,
|
|
210
|
+
extra="forbid",
|
|
211
|
+
)
|
|
210
212
|
|
|
211
213
|
@property
|
|
212
214
|
def _collapse_chain(self) -> BaseCombineDocumentsChain:
|
|
@@ -8,7 +8,7 @@ from langchain_core.callbacks import Callbacks
|
|
|
8
8
|
from langchain_core.documents import Document
|
|
9
9
|
from langchain_core.prompts import BasePromptTemplate, format_document
|
|
10
10
|
from langchain_core.prompts.prompt import PromptTemplate
|
|
11
|
-
from
|
|
11
|
+
from pydantic import ConfigDict, Field, model_validator
|
|
12
12
|
|
|
13
13
|
from langchain.chains.combine_documents.base import (
|
|
14
14
|
BaseCombineDocumentsChain,
|
|
@@ -98,20 +98,23 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
|
|
98
98
|
_output_keys = _output_keys + ["intermediate_steps"]
|
|
99
99
|
return _output_keys
|
|
100
100
|
|
|
101
|
-
|
|
102
|
-
arbitrary_types_allowed
|
|
103
|
-
extra
|
|
101
|
+
model_config = ConfigDict(
|
|
102
|
+
arbitrary_types_allowed=True,
|
|
103
|
+
extra="forbid",
|
|
104
|
+
)
|
|
104
105
|
|
|
105
|
-
@
|
|
106
|
-
|
|
106
|
+
@model_validator(mode="before")
|
|
107
|
+
@classmethod
|
|
108
|
+
def get_return_intermediate_steps(cls, values: Dict) -> Any:
|
|
107
109
|
"""For backwards compatibility."""
|
|
108
110
|
if "return_refine_steps" in values:
|
|
109
111
|
values["return_intermediate_steps"] = values["return_refine_steps"]
|
|
110
112
|
del values["return_refine_steps"]
|
|
111
113
|
return values
|
|
112
114
|
|
|
113
|
-
@
|
|
114
|
-
|
|
115
|
+
@model_validator(mode="before")
|
|
116
|
+
@classmethod
|
|
117
|
+
def get_default_document_variable_name(cls, values: Dict) -> Any:
|
|
115
118
|
"""Get default document variable name, if not provided."""
|
|
116
119
|
if "initial_llm_chain" not in values:
|
|
117
120
|
raise ValueError("initial_llm_chain must be provided")
|
|
@@ -8,8 +8,8 @@ from langchain_core.documents import Document
|
|
|
8
8
|
from langchain_core.language_models import LanguageModelLike
|
|
9
9
|
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
|
|
10
10
|
from langchain_core.prompts import BasePromptTemplate, format_document
|
|
11
|
-
from langchain_core.pydantic_v1 import Field, root_validator
|
|
12
11
|
from langchain_core.runnables import Runnable, RunnablePassthrough
|
|
12
|
+
from pydantic import ConfigDict, Field, model_validator
|
|
13
13
|
|
|
14
14
|
from langchain.chains.combine_documents.base import (
|
|
15
15
|
DEFAULT_DOCUMENT_PROMPT,
|
|
@@ -156,12 +156,14 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
|
|
156
156
|
document_separator: str = "\n\n"
|
|
157
157
|
"""The string with which to join the formatted documents"""
|
|
158
158
|
|
|
159
|
-
|
|
160
|
-
arbitrary_types_allowed
|
|
161
|
-
extra
|
|
159
|
+
model_config = ConfigDict(
|
|
160
|
+
arbitrary_types_allowed=True,
|
|
161
|
+
extra="forbid",
|
|
162
|
+
)
|
|
162
163
|
|
|
163
|
-
@
|
|
164
|
-
|
|
164
|
+
@model_validator(mode="before")
|
|
165
|
+
@classmethod
|
|
166
|
+
def get_default_document_variable_name(cls, values: Dict) -> Any:
|
|
165
167
|
"""Get default document variable name, if not provided.
|
|
166
168
|
|
|
167
169
|
If only one variable is present in the llm_chain.prompt,
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
"""Chain that carries on a conversation and calls an LLM."""
|
|
2
2
|
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import List
|
|
4
4
|
|
|
5
5
|
from langchain_core._api import deprecated
|
|
6
6
|
from langchain_core.memory import BaseMemory
|
|
7
7
|
from langchain_core.prompts import BasePromptTemplate
|
|
8
|
-
from
|
|
8
|
+
from pydantic import ConfigDict, Field, model_validator
|
|
9
|
+
from typing_extensions import Self
|
|
9
10
|
|
|
10
11
|
from langchain.chains.conversation.prompt import PROMPT
|
|
11
12
|
from langchain.chains.llm import LLMChain
|
|
@@ -110,9 +111,10 @@ class ConversationChain(LLMChain):
|
|
|
110
111
|
input_key: str = "input" #: :meta private:
|
|
111
112
|
output_key: str = "response" #: :meta private:
|
|
112
113
|
|
|
113
|
-
|
|
114
|
-
arbitrary_types_allowed
|
|
115
|
-
extra
|
|
114
|
+
model_config = ConfigDict(
|
|
115
|
+
arbitrary_types_allowed=True,
|
|
116
|
+
extra="forbid",
|
|
117
|
+
)
|
|
116
118
|
|
|
117
119
|
@classmethod
|
|
118
120
|
def is_lc_serializable(cls) -> bool:
|
|
@@ -123,17 +125,17 @@ class ConversationChain(LLMChain):
|
|
|
123
125
|
"""Use this since so some prompt vars come from history."""
|
|
124
126
|
return [self.input_key]
|
|
125
127
|
|
|
126
|
-
@
|
|
127
|
-
def validate_prompt_input_variables(
|
|
128
|
+
@model_validator(mode="after")
|
|
129
|
+
def validate_prompt_input_variables(self) -> Self:
|
|
128
130
|
"""Validate that prompt input variables are consistent."""
|
|
129
|
-
memory_keys =
|
|
130
|
-
input_key =
|
|
131
|
+
memory_keys = self.memory.memory_variables
|
|
132
|
+
input_key = self.input_key
|
|
131
133
|
if input_key in memory_keys:
|
|
132
134
|
raise ValueError(
|
|
133
135
|
f"The input key {input_key} was also found in the memory keys "
|
|
134
136
|
f"({memory_keys}) - please provide keys that don't overlap."
|
|
135
137
|
)
|
|
136
|
-
prompt_variables =
|
|
138
|
+
prompt_variables = self.prompt.input_variables
|
|
137
139
|
expected_keys = memory_keys + [input_key]
|
|
138
140
|
if set(expected_keys) != set(prompt_variables):
|
|
139
141
|
raise ValueError(
|
|
@@ -141,4 +143,4 @@ class ConversationChain(LLMChain):
|
|
|
141
143
|
f"{prompt_variables}, but got {memory_keys} as inputs from "
|
|
142
144
|
f"memory, and {input_key} as the normal input key."
|
|
143
145
|
)
|
|
144
|
-
return
|
|
146
|
+
return self
|
|
@@ -18,10 +18,10 @@ from langchain_core.documents import Document
|
|
|
18
18
|
from langchain_core.language_models import BaseLanguageModel
|
|
19
19
|
from langchain_core.messages import BaseMessage
|
|
20
20
|
from langchain_core.prompts import BasePromptTemplate
|
|
21
|
-
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
|
22
21
|
from langchain_core.retrievers import BaseRetriever
|
|
23
22
|
from langchain_core.runnables import RunnableConfig
|
|
24
23
|
from langchain_core.vectorstores import VectorStore
|
|
24
|
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
25
25
|
|
|
26
26
|
from langchain.chains.base import Chain
|
|
27
27
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
@@ -92,14 +92,15 @@ class BaseConversationalRetrievalChain(Chain):
|
|
|
92
92
|
get_chat_history: Optional[Callable[[List[CHAT_TURN_TYPE]], str]] = None
|
|
93
93
|
"""An optional function to get a string of the chat history.
|
|
94
94
|
If None is provided, will use a default."""
|
|
95
|
-
response_if_no_docs_found: Optional[str]
|
|
95
|
+
response_if_no_docs_found: Optional[str] = None
|
|
96
96
|
"""If specified, the chain will return a fixed response if no docs
|
|
97
97
|
are found for the question. """
|
|
98
98
|
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
arbitrary_types_allowed
|
|
102
|
-
extra
|
|
99
|
+
model_config = ConfigDict(
|
|
100
|
+
populate_by_name=True,
|
|
101
|
+
arbitrary_types_allowed=True,
|
|
102
|
+
extra="forbid",
|
|
103
|
+
)
|
|
103
104
|
|
|
104
105
|
@property
|
|
105
106
|
def input_keys(self) -> List[str]:
|
|
@@ -482,8 +483,9 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
|
|
|
482
483
|
def _chain_type(self) -> str:
|
|
483
484
|
return "chat-vector-db"
|
|
484
485
|
|
|
485
|
-
@
|
|
486
|
-
|
|
486
|
+
@model_validator(mode="before")
|
|
487
|
+
@classmethod
|
|
488
|
+
def raise_deprecation(cls, values: Dict) -> Any:
|
|
487
489
|
warnings.warn(
|
|
488
490
|
"`ChatVectorDBChain` is deprecated - "
|
|
489
491
|
"please use `from langchain.chains import ConversationalRetrievalChain`"
|
|
@@ -9,8 +9,9 @@ from langchain_core.language_models import BaseLanguageModel
|
|
|
9
9
|
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
|
|
10
10
|
from langchain_core.output_parsers.json import SimpleJsonOutputParser
|
|
11
11
|
from langchain_core.prompts import BasePromptTemplate
|
|
12
|
-
from langchain_core.pydantic_v1 import root_validator
|
|
13
12
|
from langchain_core.runnables import Runnable
|
|
13
|
+
from pydantic import ConfigDict, model_validator
|
|
14
|
+
from typing_extensions import Self
|
|
14
15
|
|
|
15
16
|
from langchain.chains.base import Chain
|
|
16
17
|
from langchain.chains.elasticsearch_database.prompts import ANSWER_PROMPT, DSL_PROMPT
|
|
@@ -39,7 +40,7 @@ class ElasticsearchDatabaseChain(Chain):
|
|
|
39
40
|
"""Chain for creating the ES query."""
|
|
40
41
|
answer_chain: Runnable
|
|
41
42
|
"""Chain for answering the user question."""
|
|
42
|
-
database: Any
|
|
43
|
+
database: Any = None
|
|
43
44
|
"""Elasticsearch database to connect to of type elasticsearch.Elasticsearch."""
|
|
44
45
|
top_k: int = 10
|
|
45
46
|
"""Number of results to return from the query"""
|
|
@@ -51,17 +52,18 @@ class ElasticsearchDatabaseChain(Chain):
|
|
|
51
52
|
return_intermediate_steps: bool = False
|
|
52
53
|
"""Whether or not to return the intermediate steps along with the final answer."""
|
|
53
54
|
|
|
54
|
-
|
|
55
|
-
arbitrary_types_allowed
|
|
56
|
-
extra
|
|
55
|
+
model_config = ConfigDict(
|
|
56
|
+
arbitrary_types_allowed=True,
|
|
57
|
+
extra="forbid",
|
|
58
|
+
)
|
|
57
59
|
|
|
58
|
-
@
|
|
59
|
-
def validate_indices(
|
|
60
|
-
if
|
|
60
|
+
@model_validator(mode="after")
|
|
61
|
+
def validate_indices(self) -> Self:
|
|
62
|
+
if self.include_indices and self.ignore_indices:
|
|
61
63
|
raise ValueError(
|
|
62
64
|
"Cannot specify both 'include_indices' and 'ignore_indices'."
|
|
63
65
|
)
|
|
64
|
-
return
|
|
66
|
+
return self
|
|
65
67
|
|
|
66
68
|
@property
|
|
67
69
|
def input_keys(self) -> List[str]:
|
langchain/chains/flare/base.py
CHANGED
|
@@ -11,9 +11,9 @@ from langchain_core.language_models import BaseLanguageModel
|
|
|
11
11
|
from langchain_core.messages import AIMessage
|
|
12
12
|
from langchain_core.output_parsers import StrOutputParser
|
|
13
13
|
from langchain_core.prompts import BasePromptTemplate
|
|
14
|
-
from langchain_core.pydantic_v1 import Field
|
|
15
14
|
from langchain_core.retrievers import BaseRetriever
|
|
16
15
|
from langchain_core.runnables import Runnable
|
|
16
|
+
from pydantic import Field
|
|
17
17
|
|
|
18
18
|
from langchain.chains.base import Chain
|
|
19
19
|
from langchain.chains.flare.prompts import (
|
|
@@ -73,7 +73,10 @@ def _low_confidence_spans(
|
|
|
73
73
|
|
|
74
74
|
class FlareChain(Chain):
|
|
75
75
|
"""Chain that combines a retriever, a question generator,
|
|
76
|
-
and a response generator.
|
|
76
|
+
and a response generator.
|
|
77
|
+
|
|
78
|
+
See [Active Retrieval Augmented Generation](https://arxiv.org/abs/2305.06983) paper.
|
|
79
|
+
"""
|
|
77
80
|
|
|
78
81
|
question_generator_chain: Runnable
|
|
79
82
|
"""Chain that generates questions from uncertain spans."""
|
langchain/chains/hyde/base.py
CHANGED
|
@@ -14,6 +14,7 @@ from langchain_core.language_models import BaseLanguageModel
|
|
|
14
14
|
from langchain_core.output_parsers import StrOutputParser
|
|
15
15
|
from langchain_core.prompts import BasePromptTemplate
|
|
16
16
|
from langchain_core.runnables import Runnable
|
|
17
|
+
from pydantic import ConfigDict
|
|
17
18
|
|
|
18
19
|
from langchain.chains.base import Chain
|
|
19
20
|
from langchain.chains.hyde.prompts import PROMPT_MAP
|
|
@@ -29,14 +30,15 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
|
|
29
30
|
base_embeddings: Embeddings
|
|
30
31
|
llm_chain: Runnable
|
|
31
32
|
|
|
32
|
-
|
|
33
|
-
arbitrary_types_allowed
|
|
34
|
-
extra
|
|
33
|
+
model_config = ConfigDict(
|
|
34
|
+
arbitrary_types_allowed=True,
|
|
35
|
+
extra="forbid",
|
|
36
|
+
)
|
|
35
37
|
|
|
36
38
|
@property
|
|
37
39
|
def input_keys(self) -> List[str]:
|
|
38
40
|
"""Input keys for Hyde's LLM chain."""
|
|
39
|
-
return self.llm_chain.input_schema.
|
|
41
|
+
return self.llm_chain.input_schema.model_json_schema()["required"]
|
|
40
42
|
|
|
41
43
|
@property
|
|
42
44
|
def output_keys(self) -> List[str]:
|
langchain/chains/llm.py
CHANGED
|
@@ -17,13 +17,11 @@ from langchain_core.language_models import (
|
|
|
17
17
|
BaseLanguageModel,
|
|
18
18
|
LanguageModelInput,
|
|
19
19
|
)
|
|
20
|
-
from langchain_core.load.dump import dumpd
|
|
21
20
|
from langchain_core.messages import BaseMessage
|
|
22
21
|
from langchain_core.output_parsers import BaseLLMOutputParser, StrOutputParser
|
|
23
22
|
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
|
|
24
23
|
from langchain_core.prompt_values import PromptValue
|
|
25
24
|
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
|
26
|
-
from langchain_core.pydantic_v1 import Field
|
|
27
25
|
from langchain_core.runnables import (
|
|
28
26
|
Runnable,
|
|
29
27
|
RunnableBinding,
|
|
@@ -32,6 +30,7 @@ from langchain_core.runnables import (
|
|
|
32
30
|
)
|
|
33
31
|
from langchain_core.runnables.configurable import DynamicRunnable
|
|
34
32
|
from langchain_core.utils.input import get_colored_text
|
|
33
|
+
from pydantic import ConfigDict, Field
|
|
35
34
|
|
|
36
35
|
from langchain.chains.base import Chain
|
|
37
36
|
|
|
@@ -95,9 +94,10 @@ class LLMChain(Chain):
|
|
|
95
94
|
If false, will return a bunch of extra information about the generation."""
|
|
96
95
|
llm_kwargs: dict = Field(default_factory=dict)
|
|
97
96
|
|
|
98
|
-
|
|
99
|
-
arbitrary_types_allowed
|
|
100
|
-
extra
|
|
97
|
+
model_config = ConfigDict(
|
|
98
|
+
arbitrary_types_allowed=True,
|
|
99
|
+
extra="forbid",
|
|
100
|
+
)
|
|
101
101
|
|
|
102
102
|
@property
|
|
103
103
|
def input_keys(self) -> List[str]:
|
|
@@ -240,7 +240,7 @@ class LLMChain(Chain):
|
|
|
240
240
|
callbacks, self.callbacks, self.verbose
|
|
241
241
|
)
|
|
242
242
|
run_manager = callback_manager.on_chain_start(
|
|
243
|
-
|
|
243
|
+
None,
|
|
244
244
|
{"input_list": input_list},
|
|
245
245
|
)
|
|
246
246
|
try:
|
|
@@ -260,7 +260,7 @@ class LLMChain(Chain):
|
|
|
260
260
|
callbacks, self.callbacks, self.verbose
|
|
261
261
|
)
|
|
262
262
|
run_manager = await callback_manager.on_chain_start(
|
|
263
|
-
|
|
263
|
+
None,
|
|
264
264
|
{"input_list": input_list},
|
|
265
265
|
)
|
|
266
266
|
try:
|
|
@@ -9,7 +9,7 @@ from langchain_core._api import deprecated
|
|
|
9
9
|
from langchain_core.callbacks import CallbackManagerForChainRun
|
|
10
10
|
from langchain_core.language_models import BaseLanguageModel
|
|
11
11
|
from langchain_core.prompts import PromptTemplate
|
|
12
|
-
from
|
|
12
|
+
from pydantic import ConfigDict, model_validator
|
|
13
13
|
|
|
14
14
|
from langchain.chains.base import Chain
|
|
15
15
|
from langchain.chains.llm import LLMChain
|
|
@@ -100,12 +100,14 @@ class LLMCheckerChain(Chain):
|
|
|
100
100
|
input_key: str = "query" #: :meta private:
|
|
101
101
|
output_key: str = "result" #: :meta private:
|
|
102
102
|
|
|
103
|
-
|
|
104
|
-
arbitrary_types_allowed
|
|
105
|
-
extra
|
|
103
|
+
model_config = ConfigDict(
|
|
104
|
+
arbitrary_types_allowed=True,
|
|
105
|
+
extra="forbid",
|
|
106
|
+
)
|
|
106
107
|
|
|
107
|
-
@
|
|
108
|
-
|
|
108
|
+
@model_validator(mode="before")
|
|
109
|
+
@classmethod
|
|
110
|
+
def raise_deprecation(cls, values: Dict) -> Any:
|
|
109
111
|
if "llm" in values:
|
|
110
112
|
warnings.warn(
|
|
111
113
|
"Directly instantiating an LLMCheckerChain with an llm is deprecated. "
|
|
@@ -14,7 +14,7 @@ from langchain_core.callbacks import (
|
|
|
14
14
|
)
|
|
15
15
|
from langchain_core.language_models import BaseLanguageModel
|
|
16
16
|
from langchain_core.prompts import BasePromptTemplate
|
|
17
|
-
from
|
|
17
|
+
from pydantic import ConfigDict, model_validator
|
|
18
18
|
|
|
19
19
|
from langchain.chains.base import Chain
|
|
20
20
|
from langchain.chains.llm import LLMChain
|
|
@@ -156,12 +156,14 @@ class LLMMathChain(Chain):
|
|
|
156
156
|
input_key: str = "question" #: :meta private:
|
|
157
157
|
output_key: str = "answer" #: :meta private:
|
|
158
158
|
|
|
159
|
-
|
|
160
|
-
arbitrary_types_allowed
|
|
161
|
-
extra
|
|
159
|
+
model_config = ConfigDict(
|
|
160
|
+
arbitrary_types_allowed=True,
|
|
161
|
+
extra="forbid",
|
|
162
|
+
)
|
|
162
163
|
|
|
163
|
-
@
|
|
164
|
-
|
|
164
|
+
@model_validator(mode="before")
|
|
165
|
+
@classmethod
|
|
166
|
+
def raise_deprecation(cls, values: Dict) -> Any:
|
|
165
167
|
try:
|
|
166
168
|
import numexpr # noqa: F401
|
|
167
169
|
except ImportError:
|
|
@@ -10,7 +10,7 @@ from langchain_core._api import deprecated
|
|
|
10
10
|
from langchain_core.callbacks import CallbackManagerForChainRun
|
|
11
11
|
from langchain_core.language_models import BaseLanguageModel
|
|
12
12
|
from langchain_core.prompts.prompt import PromptTemplate
|
|
13
|
-
from
|
|
13
|
+
from pydantic import ConfigDict, model_validator
|
|
14
14
|
|
|
15
15
|
from langchain.chains.base import Chain
|
|
16
16
|
from langchain.chains.llm import LLMChain
|
|
@@ -105,12 +105,14 @@ class LLMSummarizationCheckerChain(Chain):
|
|
|
105
105
|
max_checks: int = 2
|
|
106
106
|
"""Maximum number of times to check the assertions. Default to double-checking."""
|
|
107
107
|
|
|
108
|
-
|
|
109
|
-
arbitrary_types_allowed
|
|
110
|
-
extra
|
|
108
|
+
model_config = ConfigDict(
|
|
109
|
+
arbitrary_types_allowed=True,
|
|
110
|
+
extra="forbid",
|
|
111
|
+
)
|
|
111
112
|
|
|
112
|
-
@
|
|
113
|
-
|
|
113
|
+
@model_validator(mode="before")
|
|
114
|
+
@classmethod
|
|
115
|
+
def raise_deprecation(cls, values: Dict) -> Any:
|
|
114
116
|
if "llm" in values:
|
|
115
117
|
warnings.warn(
|
|
116
118
|
"Directly instantiating an LLMSummarizationCheckerChain with an llm is "
|
langchain/chains/mapreduce.py
CHANGED
|
@@ -14,6 +14,7 @@ from langchain_core.documents import Document
|
|
|
14
14
|
from langchain_core.language_models import BaseLanguageModel
|
|
15
15
|
from langchain_core.prompts import BasePromptTemplate
|
|
16
16
|
from langchain_text_splitters import TextSplitter
|
|
17
|
+
from pydantic import ConfigDict
|
|
17
18
|
|
|
18
19
|
from langchain.chains import ReduceDocumentsChain
|
|
19
20
|
from langchain.chains.base import Chain
|
|
@@ -77,9 +78,10 @@ class MapReduceChain(Chain):
|
|
|
77
78
|
**kwargs,
|
|
78
79
|
)
|
|
79
80
|
|
|
80
|
-
|
|
81
|
-
arbitrary_types_allowed
|
|
82
|
-
extra
|
|
81
|
+
model_config = ConfigDict(
|
|
82
|
+
arbitrary_types_allowed=True,
|
|
83
|
+
extra="forbid",
|
|
84
|
+
)
|
|
83
85
|
|
|
84
86
|
@property
|
|
85
87
|
def input_keys(self) -> List[str]:
|
langchain/chains/moderation.py
CHANGED
|
@@ -6,8 +6,8 @@ from langchain_core.callbacks import (
|
|
|
6
6
|
AsyncCallbackManagerForChainRun,
|
|
7
7
|
CallbackManagerForChainRun,
|
|
8
8
|
)
|
|
9
|
-
from langchain_core.pydantic_v1 import Field, root_validator
|
|
10
9
|
from langchain_core.utils import check_package_version, get_from_dict_or_env
|
|
10
|
+
from pydantic import Field, model_validator
|
|
11
11
|
|
|
12
12
|
from langchain.chains.base import Chain
|
|
13
13
|
|
|
@@ -28,8 +28,8 @@ class OpenAIModerationChain(Chain):
|
|
|
28
28
|
moderation = OpenAIModerationChain()
|
|
29
29
|
"""
|
|
30
30
|
|
|
31
|
-
client: Any #: :meta private:
|
|
32
|
-
async_client: Any #: :meta private:
|
|
31
|
+
client: Any = None #: :meta private:
|
|
32
|
+
async_client: Any = None #: :meta private:
|
|
33
33
|
model_name: Optional[str] = None
|
|
34
34
|
"""Moderation model name to use."""
|
|
35
35
|
error: bool = False
|
|
@@ -38,10 +38,11 @@ class OpenAIModerationChain(Chain):
|
|
|
38
38
|
output_key: str = "output" #: :meta private:
|
|
39
39
|
openai_api_key: Optional[str] = None
|
|
40
40
|
openai_organization: Optional[str] = None
|
|
41
|
-
|
|
41
|
+
openai_pre_1_0: bool = Field(default=None)
|
|
42
42
|
|
|
43
|
-
@
|
|
44
|
-
|
|
43
|
+
@model_validator(mode="before")
|
|
44
|
+
@classmethod
|
|
45
|
+
def validate_environment(cls, values: Dict) -> Any:
|
|
45
46
|
"""Validate that api key and python package exists in environment."""
|
|
46
47
|
openai_api_key = get_from_dict_or_env(
|
|
47
48
|
values, "openai_api_key", "OPENAI_API_KEY"
|
|
@@ -58,16 +59,17 @@ class OpenAIModerationChain(Chain):
|
|
|
58
59
|
openai.api_key = openai_api_key
|
|
59
60
|
if openai_organization:
|
|
60
61
|
openai.organization = openai_organization
|
|
61
|
-
values["
|
|
62
|
+
values["openai_pre_1_0"] = False
|
|
62
63
|
try:
|
|
63
64
|
check_package_version("openai", gte_version="1.0")
|
|
64
65
|
except ValueError:
|
|
65
|
-
values["
|
|
66
|
-
if values["
|
|
66
|
+
values["openai_pre_1_0"] = True
|
|
67
|
+
if values["openai_pre_1_0"]:
|
|
67
68
|
values["client"] = openai.Moderation
|
|
68
69
|
else:
|
|
69
70
|
values["client"] = openai.OpenAI()
|
|
70
71
|
values["async_client"] = openai.AsyncOpenAI()
|
|
72
|
+
|
|
71
73
|
except ImportError:
|
|
72
74
|
raise ImportError(
|
|
73
75
|
"Could not import openai python package. "
|
|
@@ -92,7 +94,7 @@ class OpenAIModerationChain(Chain):
|
|
|
92
94
|
return [self.output_key]
|
|
93
95
|
|
|
94
96
|
def _moderate(self, text: str, results: Any) -> str:
|
|
95
|
-
if self.
|
|
97
|
+
if self.openai_pre_1_0:
|
|
96
98
|
condition = results["flagged"]
|
|
97
99
|
else:
|
|
98
100
|
condition = results.flagged
|
|
@@ -110,7 +112,7 @@ class OpenAIModerationChain(Chain):
|
|
|
110
112
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
|
111
113
|
) -> Dict[str, Any]:
|
|
112
114
|
text = inputs[self.input_key]
|
|
113
|
-
if self.
|
|
115
|
+
if self.openai_pre_1_0:
|
|
114
116
|
results = self.client.create(text)
|
|
115
117
|
output = self._moderate(text, results["results"][0])
|
|
116
118
|
else:
|
|
@@ -123,7 +125,7 @@ class OpenAIModerationChain(Chain):
|
|
|
123
125
|
inputs: Dict[str, Any],
|
|
124
126
|
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
|
125
127
|
) -> Dict[str, Any]:
|
|
126
|
-
if self.
|
|
128
|
+
if self.openai_pre_1_0:
|
|
127
129
|
return await super()._acall(inputs, run_manager=run_manager)
|
|
128
130
|
text = inputs[self.input_key]
|
|
129
131
|
results = await self.async_client.moderations.create(input=text)
|
langchain/chains/natbot/base.py
CHANGED
|
@@ -9,8 +9,8 @@ from langchain_core._api import deprecated
|
|
|
9
9
|
from langchain_core.callbacks import CallbackManagerForChainRun
|
|
10
10
|
from langchain_core.language_models import BaseLanguageModel
|
|
11
11
|
from langchain_core.output_parsers import StrOutputParser
|
|
12
|
-
from langchain_core.pydantic_v1 import root_validator
|
|
13
12
|
from langchain_core.runnables import Runnable
|
|
13
|
+
from pydantic import ConfigDict, model_validator
|
|
14
14
|
|
|
15
15
|
from langchain.chains.base import Chain
|
|
16
16
|
from langchain.chains.natbot.prompt import PROMPT
|
|
@@ -59,12 +59,14 @@ class NatBotChain(Chain):
|
|
|
59
59
|
previous_command: str = "" #: :meta private:
|
|
60
60
|
output_key: str = "command" #: :meta private:
|
|
61
61
|
|
|
62
|
-
|
|
63
|
-
arbitrary_types_allowed
|
|
64
|
-
extra
|
|
62
|
+
model_config = ConfigDict(
|
|
63
|
+
arbitrary_types_allowed=True,
|
|
64
|
+
extra="forbid",
|
|
65
|
+
)
|
|
65
66
|
|
|
66
|
-
@
|
|
67
|
-
|
|
67
|
+
@model_validator(mode="before")
|
|
68
|
+
@classmethod
|
|
69
|
+
def raise_deprecation(cls, values: Dict) -> Any:
|
|
68
70
|
if "llm" in values:
|
|
69
71
|
warnings.warn(
|
|
70
72
|
"Directly instantiating an NatBotChain with an llm is deprecated. "
|