langchain-0.2.13-py3-none-any.whl → langchain-0.2.15-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. langchain/_api/module_import.py +2 -2
  2. langchain/agents/__init__.py +4 -3
  3. langchain/agents/agent.py +40 -23
  4. langchain/agents/agent_iterator.py +2 -2
  5. langchain/agents/agent_toolkits/__init__.py +1 -1
  6. langchain/agents/agent_toolkits/vectorstore/toolkit.py +2 -1
  7. langchain/agents/agent_types.py +2 -2
  8. langchain/agents/chat/base.py +1 -1
  9. langchain/agents/conversational/base.py +1 -1
  10. langchain/agents/conversational_chat/base.py +1 -1
  11. langchain/agents/initialize.py +1 -1
  12. langchain/agents/json_chat/base.py +1 -1
  13. langchain/agents/loading.py +2 -2
  14. langchain/agents/mrkl/base.py +3 -3
  15. langchain/agents/openai_assistant/base.py +11 -15
  16. langchain/agents/openai_functions_agent/base.py +1 -1
  17. langchain/agents/openai_functions_multi_agent/base.py +1 -1
  18. langchain/agents/react/agent.py +1 -1
  19. langchain/agents/react/base.py +4 -4
  20. langchain/agents/self_ask_with_search/base.py +2 -2
  21. langchain/agents/structured_chat/base.py +3 -2
  22. langchain/agents/tools.py +2 -2
  23. langchain/agents/xml/base.py +2 -2
  24. langchain/chains/base.py +5 -5
  25. langchain/chains/combine_documents/base.py +4 -4
  26. langchain/chains/combine_documents/stuff.py +1 -1
  27. langchain/chains/constitutional_ai/base.py +144 -1
  28. langchain/chains/conversation/base.py +1 -1
  29. langchain/chains/conversational_retrieval/base.py +1 -1
  30. langchain/chains/flare/base.py +42 -69
  31. langchain/chains/hyde/base.py +18 -8
  32. langchain/chains/llm_math/base.py +118 -1
  33. langchain/chains/natbot/base.py +24 -10
  34. langchain/chains/openai_functions/base.py +2 -2
  35. langchain/chains/openai_functions/extraction.py +2 -2
  36. langchain/chains/openai_tools/extraction.py +1 -1
  37. langchain/chains/query_constructor/parser.py +23 -0
  38. langchain/chains/structured_output/base.py +2 -2
  39. langchain/retrievers/document_compressors/chain_extract.py +19 -10
  40. langchain/retrievers/document_compressors/chain_filter.py +27 -10
  41. langchain/retrievers/document_compressors/cohere_rerank.py +1 -1
  42. langchain/retrievers/re_phraser.py +7 -7
  43. langchain/tools/__init__.py +14 -5
  44. langchain/tools/render.py +0 -2
  45. langchain/tools/retriever.py +0 -4
  46. {langchain-0.2.13.dist-info → langchain-0.2.15.dist-info}/METADATA +2 -2
  47. {langchain-0.2.13.dist-info → langchain-0.2.15.dist-info}/RECORD +50 -50
  48. {langchain-0.2.13.dist-info → langchain-0.2.15.dist-info}/LICENSE +0 -0
  49. {langchain-0.2.13.dist-info → langchain-0.2.15.dist-info}/WHEEL +0 -0
  50. {langchain-0.2.13.dist-info → langchain-0.2.15.dist-info}/entry_points.txt +0 -0
langchain/agents/tools.py CHANGED
@@ -6,7 +6,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForToolRun,
     CallbackManagerForToolRun,
 )
-from langchain_core.tools import BaseTool, Tool, tool
+from langchain_core.tools import BaseTool, tool


 class InvalidTool(BaseTool):
@@ -44,4 +44,4 @@ class InvalidTool(BaseTool):
         )


-__all__ = ["InvalidTool", "BaseTool", "tool", "Tool"]
+__all__ = ["InvalidTool", "tool"]
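
Note on this change: code that previously imported Tool via langchain.agents.tools must now import it from langchain_core.tools, since the re-export was dropped from __all__. A minimal migration sketch; the word_count helper is illustrative only:

    from langchain_core.tools import Tool

    def word_count(text: str) -> str:
        """Count the words in the input text."""
        return str(len(text.split()))

    word_count_tool = Tool(
        name="word_count",
        description="Counts the words in a piece of text.",
        func=word_count,
    )
    print(word_count_tool.invoke("hello brave new world"))  # prints "4"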
langchain/agents/xml/base.py CHANGED
@@ -8,16 +8,16 @@ from langchain_core.prompts.base import BasePromptTemplate
 from langchain_core.prompts.chat import AIMessagePromptTemplate, ChatPromptTemplate
 from langchain_core.runnables import Runnable, RunnablePassthrough
 from langchain_core.tools import BaseTool
+from langchain_core.tools.render import ToolsRenderer, render_text_description

 from langchain.agents.agent import BaseSingleActionAgent
 from langchain.agents.format_scratchpad import format_xml
 from langchain.agents.output_parsers import XMLAgentOutputParser
 from langchain.agents.xml.prompt import agent_instructions
 from langchain.chains.llm import LLMChain
-from langchain.tools.render import ToolsRenderer, render_text_description


-@deprecated("0.1.0", alternative="create_xml_agent", removal="0.3.0")
+@deprecated("0.1.0", alternative="create_xml_agent", removal="1.0")
 class XMLAgent(BaseSingleActionAgent):
     """Agent that uses XML tags.

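ToolsRenderer and render_text_description now come from langchain_core.tools.render rather than langchain.tools.render, consistent with the +0 -2 entry for langchain/tools/render.py in the file list. A short sketch of the renderer XMLAgent uses to inline tool descriptions into its prompt; exact description formatting can vary by version:

    from langchain_core.tools import tool
    from langchain_core.tools.render import render_text_description

    @tool
    def search(query: str) -> str:
        """Look up information on the web."""
        return "stub result"

    # One "name: description" line per tool,
    # e.g. "search: Look up information on the web."
    print(render_text_description([search]))
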
langchain/chains/base.py CHANGED
@@ -334,7 +334,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
             None, self._call, inputs, run_manager.get_sync() if run_manager else None
         )

-    @deprecated("0.1.0", alternative="invoke", removal="0.3.0")
+    @deprecated("0.1.0", alternative="invoke", removal="1.0")
     def __call__(
         self,
         inputs: Union[Dict[str, Any], Any],
@@ -385,7 +385,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
             include_run_info=include_run_info,
         )

-    @deprecated("0.1.0", alternative="ainvoke", removal="0.3.0")
+    @deprecated("0.1.0", alternative="ainvoke", removal="1.0")
     async def acall(
         self,
         inputs: Union[Dict[str, Any], Any],
@@ -544,7 +544,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         )
         return self.output_keys[0]

-    @deprecated("0.1.0", alternative="invoke", removal="0.3.0")
+    @deprecated("0.1.0", alternative="invoke", removal="1.0")
     def run(
         self,
         *args: Any,
@@ -615,7 +615,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
             f" but not both. Got args: {args} and kwargs: {kwargs}."
         )

-    @deprecated("0.1.0", alternative="ainvoke", removal="0.3.0")
+    @deprecated("0.1.0", alternative="ainvoke", removal="1.0")
     async def arun(
         self,
         *args: Any,
@@ -753,7 +753,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         else:
             raise ValueError(f"{save_path} must be json or yaml")

-    @deprecated("0.1.0", alternative="batch", removal="0.3.0")
+    @deprecated("0.1.0", alternative="batch", removal="1.0")
     def apply(
         self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
     ) -> List[Dict[str, str]]:
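
All five hunks above do the same thing: the removal target for the pre-Runnable entry points (__call__, acall, run, arun, apply) moves from 0.3.0 to 1.0. The migration path itself is unchanged; a runnable sketch with FakeListLLM standing in for a real model:

    from langchain_core.language_models import FakeListLLM
    from langchain_core.prompts import PromptTemplate
    from langchain.chains.llm import LLMChain

    llm = FakeListLLM(responses=["42", "Paris"])
    chain = LLMChain(llm=llm, prompt=PromptTemplate.from_template("Q: {question}\nA:"))

    # Deprecated, now slated for removal in 1.0 rather than 0.3.0:
    #   chain.run(question="What is the meaning of life?")
    #   chain.apply([{"question": "Capital of France?"}])

    # Runnable replacements:
    print(chain.invoke({"question": "What is the meaning of life?"}))
    print(chain.batch([{"question": "Capital of France?"}]))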
langchain/chains/combine_documents/base.py CHANGED
@@ -22,11 +22,11 @@ DOCUMENTS_KEY = "context"
 DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")


-def _validate_prompt(prompt: BasePromptTemplate) -> None:
-    if DOCUMENTS_KEY not in prompt.input_variables:
+def _validate_prompt(prompt: BasePromptTemplate, document_variable_name: str) -> None:
+    if document_variable_name not in prompt.input_variables:
         raise ValueError(
-            f"Prompt must accept {DOCUMENTS_KEY} as an input variable. Received prompt "
-            f"with input variables: {prompt.input_variables}"
+            f"Prompt must accept {document_variable_name} as an input variable. "
+            f"Received prompt with input variables: {prompt.input_variables}"
         )

langchain/chains/combine_documents/stuff.py CHANGED
@@ -76,7 +76,7 @@ def create_stuff_documents_chain(
             chain.invoke({"context": docs})
     """  # noqa: E501

-    _validate_prompt(prompt)
+    _validate_prompt(prompt, document_variable_name)
    _document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
    _output_parser = output_parser or StrOutputParser()

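Taken together, these two hunks make prompt validation respect a caller-supplied document_variable_name instead of hard-coding the "context" key. A sketch of what this enables; FakeListLLM is a stand-in for a real model:

    from langchain_core.documents import Document
    from langchain_core.language_models import FakeListLLM
    from langchain_core.prompts import ChatPromptTemplate
    from langchain.chains.combine_documents import create_stuff_documents_chain

    prompt = ChatPromptTemplate.from_template("Summarize:\n\n{source_docs}")
    chain = create_stuff_documents_chain(
        FakeListLLM(responses=["A one-line summary."]),
        prompt,
        # Validation previously insisted on "context"; any name matching
        # this argument now passes.
        document_variable_name="source_docs",
    )
    docs = [Document(page_content="langchain 0.2.15 widens prompt validation.")]
    print(chain.invoke({"source_docs": docs}))
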
langchain/chains/constitutional_ai/base.py CHANGED
@@ -2,6 +2,7 @@

 from typing import Any, Dict, List, Optional

+from langchain_core._api import deprecated
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts import BasePromptTemplate
@@ -13,9 +14,151 @@ from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION
 from langchain.chains.llm import LLMChain


+@deprecated(
+    since="0.2.13",
+    message=(
+        "This class is deprecated and will be removed in langchain 1.0. "
+        "See API reference for replacement: "
+        "https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.base.ConstitutionalChain.html"  # noqa: E501
+    ),
+    removal="1.0",
+)
 class ConstitutionalChain(Chain):
     """Chain for applying constitutional principles.

+    Note: this class is deprecated. See below for a replacement implementation
+    using LangGraph. The benefits of this implementation are:
+
+    - Uses LLM tool calling features instead of parsing string responses;
+    - Support for both token-by-token and step-by-step streaming;
+    - Support for checkpointing and memory of chat history;
+    - Easier to modify or extend (e.g., with additional tools, structured responses, etc.)
+
+    Install LangGraph with:
+
+    .. code-block:: bash
+
+        pip install -U langgraph
+
+    .. code-block:: python
+
+        from typing import List, Optional, Tuple
+
+        from langchain.chains.constitutional_ai.prompts import (
+            CRITIQUE_PROMPT,
+            REVISION_PROMPT,
+        )
+        from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
+        from langchain_core.output_parsers import StrOutputParser
+        from langchain_core.prompts import ChatPromptTemplate
+        from langchain_openai import ChatOpenAI
+        from langgraph.graph import END, START, StateGraph
+        from typing_extensions import Annotated, TypedDict
+
+        llm = ChatOpenAI(model="gpt-4o-mini")
+
+        class Critique(TypedDict):
+            \"\"\"Generate a critique, if needed.\"\"\"
+            critique_needed: Annotated[bool, ..., "Whether or not a critique is needed."]
+            critique: Annotated[str, ..., "If needed, the critique."]
+
+        critique_prompt = ChatPromptTemplate.from_template(
+            "Critique this response according to the critique request. "
+            "If no critique is needed, specify that.\\n\\n"
+            "Query: {query}\\n\\n"
+            "Response: {response}\\n\\n"
+            "Critique request: {critique_request}"
+        )
+
+        revision_prompt = ChatPromptTemplate.from_template(
+            "Revise this response according to the critique and revision request.\\n\\n"
+            "Query: {query}\\n\\n"
+            "Response: {response}\\n\\n"
+            "Critique request: {critique_request}\\n\\n"
+            "Critique: {critique}\\n\\n"
+            "If the critique does not identify anything worth changing, ignore the "
+            "revision request and return 'No revisions needed'. If the critique "
+            "does identify something worth changing, revise the response based on "
+            "the revision request.\\n\\n"
+            "Revision Request: {revision_request}"
+        )
+
+        chain = llm | StrOutputParser()
+        critique_chain = critique_prompt | llm.with_structured_output(Critique)
+        revision_chain = revision_prompt | llm | StrOutputParser()
+
+
+        class State(TypedDict):
+            query: str
+            constitutional_principles: List[ConstitutionalPrinciple]
+            initial_response: str
+            critiques_and_revisions: List[Tuple[str, str]]
+            response: str
+
+
+        async def generate_response(state: State):
+            \"\"\"Generate initial response.\"\"\"
+            response = await chain.ainvoke(state["query"])
+            return {"response": response, "initial_response": response}
+
+        async def critique_and_revise(state: State):
+            \"\"\"Critique and revise response according to principles.\"\"\"
+            critiques_and_revisions = []
+            response = state["initial_response"]
+            for principle in state["constitutional_principles"]:
+                critique = await critique_chain.ainvoke(
+                    {
+                        "query": state["query"],
+                        "response": response,
+                        "critique_request": principle.critique_request,
+                    }
+                )
+                if critique["critique_needed"]:
+                    revision = await revision_chain.ainvoke(
+                        {
+                            "query": state["query"],
+                            "response": response,
+                            "critique_request": principle.critique_request,
+                            "critique": critique["critique"],
+                            "revision_request": principle.revision_request,
+                        }
+                    )
+                    response = revision
+                    critiques_and_revisions.append((critique["critique"], revision))
+                else:
+                    critiques_and_revisions.append((critique["critique"], ""))
+            return {
+                "critiques_and_revisions": critiques_and_revisions,
+                "response": response,
+            }
+
+        graph = StateGraph(State)
+        graph.add_node("generate_response", generate_response)
+        graph.add_node("critique_and_revise", critique_and_revise)
+
+        graph.add_edge(START, "generate_response")
+        graph.add_edge("generate_response", "critique_and_revise")
+        graph.add_edge("critique_and_revise", END)
+        app = graph.compile()
+
+    .. code-block:: python
+
+        constitutional_principles=[
+            ConstitutionalPrinciple(
+                critique_request="Tell if this answer is good.",
+                revision_request="Give a better answer.",
+            )
+        ]
+
+        query = "What is the meaning of life? Answer in 10 words or fewer."
+
+        async for step in app.astream(
+            {"query": query, "constitutional_principles": constitutional_principles},
+            stream_mode="values",
+        ):
+            subset = ["initial_response", "critiques_and_revisions", "response"]
+            print({k: v for k, v in step.items() if k in subset})
+
     Example:
         .. code-block:: python

@@ -44,7 +187,7 @@ class ConstitutionalChain(Chain):
             )

             constitutional_chain.run(question="What is the meaning of life?")
-    """
+    """  # noqa: E501

     chain: LLMChain
     constitutional_principles: List[ConstitutionalPrinciple]
langchain/chains/conversation/base.py CHANGED
@@ -16,7 +16,7 @@ from langchain.memory.buffer import ConversationBufferMemory
     since="0.2.7",
     alternative=(
         "RunnableWithMessageHistory: "
-        "https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.history.RunnableWithMessageHistory.html"  # noqa: E501
+        "https://python.langchain.com/v0.2/api_reference/core/runnables/langchain_core.runnables.history.RunnableWithMessageHistory.html"  # noqa: E501
     ),
     removal="1.0",
 )
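
The deprecation notice for ConversationChain keeps pointing at RunnableWithMessageHistory; only the documentation URL changes. For reference, a minimal sketch of that replacement, with FakeListChatModel standing in for a real chat model:

    from typing import Dict

    from langchain_core.chat_history import InMemoryChatMessageHistory
    from langchain_core.language_models import FakeListChatModel
    from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
    from langchain_core.runnables.history import RunnableWithMessageHistory

    prompt = ChatPromptTemplate.from_messages(
        [MessagesPlaceholder("history"), ("human", "{input}")]
    )
    chain = prompt | FakeListChatModel(responses=["Hi!", "Still here."])

    # In-memory session store; a real app would use a persistent backend.
    store: Dict[str, InMemoryChatMessageHistory] = {}

    def get_history(session_id: str) -> InMemoryChatMessageHistory:
        return store.setdefault(session_id, InMemoryChatMessageHistory())

    chat = RunnableWithMessageHistory(
        chain,
        get_history,
        input_messages_key="input",
        history_messages_key="history",
    )
    config = {"configurable": {"session_id": "demo"}}
    print(chat.invoke({"input": "Hello"}, config).content)
    print(chat.invoke({"input": "Are you there?"}, config).content)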
langchain/chains/conversational_retrieval/base.py CHANGED
@@ -242,7 +242,7 @@ class BaseConversationalRetrievalChain(Chain):
         "create_history_aware_retriever together with create_retrieval_chain "
         "(see example in docstring)"
     ),
-    removal="0.3.0",
+    removal="1.0",
 )
 class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
     """Chain for having a conversation based on retrieved documents.
langchain/chains/flare/base.py CHANGED
@@ -1,7 +1,6 @@
 from __future__ import annotations

 import re
-from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Sequence, Tuple

 import numpy as np
@@ -9,10 +8,12 @@ from langchain_core.callbacks import (
     CallbackManagerForChainRun,
 )
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.outputs import Generation
+from langchain_core.messages import AIMessage
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import BasePromptTemplate
 from langchain_core.pydantic_v1 import Field
 from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import Runnable

 from langchain.chains.base import Chain
 from langchain.chains.flare.prompts import (
@@ -23,51 +24,14 @@ from langchain.chains.flare.prompts import (
 from langchain.chains.llm import LLMChain


-class _ResponseChain(LLMChain):
-    """Base class for chains that generate responses."""
-
-    prompt: BasePromptTemplate = PROMPT
-
-    @classmethod
-    def is_lc_serializable(cls) -> bool:
-        return False
-
-    @property
-    def input_keys(self) -> List[str]:
-        return self.prompt.input_variables
-
-    def generate_tokens_and_log_probs(
-        self,
-        _input: Dict[str, Any],
-        *,
-        run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Tuple[Sequence[str], Sequence[float]]:
-        llm_result = self.generate([_input], run_manager=run_manager)
-        return self._extract_tokens_and_log_probs(llm_result.generations[0])
-
-    @abstractmethod
-    def _extract_tokens_and_log_probs(
-        self, generations: List[Generation]
-    ) -> Tuple[Sequence[str], Sequence[float]]:
-        """Extract tokens and log probs from response."""
-
-
-class _OpenAIResponseChain(_ResponseChain):
-    """Chain that generates responses from user input and context."""
-
-    llm: BaseLanguageModel
-
-    def _extract_tokens_and_log_probs(
-        self, generations: List[Generation]
-    ) -> Tuple[Sequence[str], Sequence[float]]:
-        tokens = []
-        log_probs = []
-        for gen in generations:
-            if gen.generation_info is None:
-                raise ValueError
-            tokens.extend(gen.generation_info["logprobs"]["tokens"])
-            log_probs.extend(gen.generation_info["logprobs"]["token_logprobs"])
-        return tokens, log_probs
+def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]:
+    """Extract tokens and log probabilities from chat model response."""
+    tokens = []
+    log_probs = []
+    for token in response.response_metadata["logprobs"]["content"]:
+        tokens.append(token["token"])
+        log_probs.append(token["logprob"])
+    return tokens, log_probs


 class QuestionGeneratorChain(LLMChain):
@@ -111,9 +75,9 @@ class FlareChain(Chain):
     """Chain that combines a retriever, a question generator,
     and a response generator."""

-    question_generator_chain: QuestionGeneratorChain
+    question_generator_chain: Runnable
     """Chain that generates questions from uncertain spans."""
-    response_chain: _ResponseChain
+    response_chain: Runnable
     """Chain that generates responses from user input and context."""
     output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
     """Parser that determines whether the chain is finished."""
@@ -152,12 +116,16 @@ class FlareChain(Chain):
         for question in questions:
             docs.extend(self.retriever.invoke(question))
         context = "\n\n".join(d.page_content for d in docs)
-        result = self.response_chain.predict(
-            user_input=user_input,
-            context=context,
-            response=response,
-            callbacks=callbacks,
+        result = self.response_chain.invoke(
+            {
+                "user_input": user_input,
+                "context": context,
+                "response": response,
+            },
+            {"callbacks": callbacks},
         )
+        if isinstance(result, AIMessage):
+            result = result.content
         marginal, finished = self.output_parser.parse(result)
         return marginal, finished

@@ -178,13 +146,18 @@ class FlareChain(Chain):
             for span in low_confidence_spans
         ]
         callbacks = _run_manager.get_child()
-        question_gen_outputs = self.question_generator_chain.apply(
-            question_gen_inputs, callbacks=callbacks
-        )
-        questions = [
-            output[self.question_generator_chain.output_keys[0]]
-            for output in question_gen_outputs
-        ]
+        if isinstance(self.question_generator_chain, LLMChain):
+            question_gen_outputs = self.question_generator_chain.apply(
+                question_gen_inputs, callbacks=callbacks
+            )
+            questions = [
+                output[self.question_generator_chain.output_keys[0]]
+                for output in question_gen_outputs
+            ]
+        else:
+            questions = self.question_generator_chain.batch(
+                question_gen_inputs, config={"callbacks": callbacks}
+            )
         _run_manager.on_text(
             f"Generated Questions: {questions}", color="yellow", end="\n"
         )
@@ -206,8 +179,10 @@ class FlareChain(Chain):
                 f"Current Response: {response}", color="blue", end="\n"
             )
             _input = {"user_input": user_input, "context": "", "response": response}
-            tokens, log_probs = self.response_chain.generate_tokens_and_log_probs(
-                _input, run_manager=_run_manager
+            tokens, log_probs = _extract_tokens_and_log_probs(
+                self.response_chain.invoke(
+                    _input, {"callbacks": _run_manager.get_child()}
+                )
             )
             low_confidence_spans = _low_confidence_spans(
                 tokens,
@@ -251,18 +226,16 @@ class FlareChain(Chain):
             FlareChain class with the given language model.
         """
         try:
-            from langchain_openai import OpenAI
+            from langchain_openai import ChatOpenAI
         except ImportError:
             raise ImportError(
                 "OpenAI is required for FlareChain. "
                 "Please install langchain-openai."
                 "pip install langchain-openai"
             )
-        question_gen_chain = QuestionGeneratorChain(llm=llm)
-        response_llm = OpenAI(
-            max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
-        )
-        response_chain = _OpenAIResponseChain(llm=response_llm)
+        llm = ChatOpenAI(max_tokens=max_generation_len, logprobs=True, temperature=0)
+        response_chain = PROMPT | llm
+        question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
        return cls(
            question_generator_chain=question_gen_chain,
            response_chain=response_chain,
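
The rewritten helper reads token logprobs from AIMessage.response_metadata (the OpenAI chat completions shape) rather than from Generation.generation_info. A sketch of inspecting that metadata directly; the model name is an assumption and an OPENAI_API_KEY is required:

    from langchain_openai import ChatOpenAI

    # logprobs=True mirrors what FlareChain.from_llm now configures.
    llm = ChatOpenAI(model="gpt-4o-mini", logprobs=True, temperature=0)
    msg = llm.invoke("Say hi")
    for entry in msg.response_metadata["logprobs"]["content"]:
        print(entry["token"], entry["logprob"])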
langchain/chains/hyde/base.py CHANGED
@@ -11,7 +11,9 @@ import numpy as np
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import BasePromptTemplate
+from langchain_core.runnables import Runnable

 from langchain.chains.base import Chain
 from langchain.chains.hyde.prompts import PROMPT_MAP
@@ -25,7 +27,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     """

     base_embeddings: Embeddings
-    llm_chain: LLMChain
+    llm_chain: Runnable

     class Config:
         arbitrary_types_allowed = True
@@ -34,12 +36,15 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     @property
     def input_keys(self) -> List[str]:
         """Input keys for Hyde's LLM chain."""
-        return self.llm_chain.input_keys
+        return self.llm_chain.input_schema.schema()["required"]

     @property
     def output_keys(self) -> List[str]:
         """Output keys for Hyde's LLM chain."""
-        return self.llm_chain.output_keys
+        if isinstance(self.llm_chain, LLMChain):
+            return self.llm_chain.output_keys
+        else:
+            return ["text"]

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Call the base embeddings."""
@@ -51,9 +56,12 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):

     def embed_query(self, text: str) -> List[float]:
         """Generate a hypothetical document and embed it."""
-        var_name = self.llm_chain.input_keys[0]
-        result = self.llm_chain.generate([{var_name: text}])
-        documents = [generation.text for generation in result.generations[0]]
+        var_name = self.input_keys[0]
+        result = self.llm_chain.invoke({var_name: text})
+        if isinstance(self.llm_chain, LLMChain):
+            documents = [result[self.output_keys[0]]]
+        else:
+            documents = [result]
         embeddings = self.embed_documents(documents)
         return self.combine_embeddings(embeddings)

@@ -64,7 +72,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     ) -> Dict[str, str]:
         """Call the internal llm chain."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
-        return self.llm_chain(inputs, callbacks=_run_manager.get_child())
+        return self.llm_chain.invoke(
+            inputs, config={"callbacks": _run_manager.get_child()}
+        )

     @classmethod
     def from_llm(
@@ -86,7 +96,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
                 f"of {list(PROMPT_MAP.keys())}."
             )

-        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        llm_chain = prompt | llm | StrOutputParser()
         return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)

     @property
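
With this change HypotheticalDocumentEmbedder.from_llm builds an LCEL pipeline (prompt | llm | StrOutputParser()) instead of an LLMChain, while the isinstance checks keep old LLMChain instances working. A construction sketch; FakeListLLM and DeterministicFakeEmbedding are stand-ins, and "web_search" is one of the built-in PROMPT_MAP keys:

    from langchain_core.embeddings import DeterministicFakeEmbedding
    from langchain_core.language_models import FakeListLLM
    from langchain.chains.hyde.base import HypotheticalDocumentEmbedder

    embedder = HypotheticalDocumentEmbedder.from_llm(
        llm=FakeListLLM(responses=["A plausible hypothetical answer passage."]),
        base_embeddings=DeterministicFakeEmbedding(size=256),
        prompt_key="web_search",
    )
    # embed_query now routes through prompt | llm | StrOutputParser()
    vector = embedder.embed_query("What changed in langchain 0.2.15?")
    print(len(vector))  # 256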