langchain 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import re
-from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import numpy as np
@@ -9,10 +8,12 @@ from langchain_core.callbacks import (
     CallbackManagerForChainRun,
 )
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.outputs import Generation
+from langchain_core.messages import AIMessage
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import BasePromptTemplate
 from langchain_core.pydantic_v1 import Field
 from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import Runnable
 
 from langchain.chains.base import Chain
 from langchain.chains.flare.prompts import (
@@ -23,51 +24,14 @@ from langchain.chains.flare.prompts import (
 from langchain.chains.llm import LLMChain
 
 
-class _ResponseChain(LLMChain):
-    """Base class for chains that generate responses."""
-
-    prompt: BasePromptTemplate = PROMPT
-
-    @classmethod
-    def is_lc_serializable(cls) -> bool:
-        return False
-
-    @property
-    def input_keys(self) -> List[str]:
-        return self.prompt.input_variables
-
-    def generate_tokens_and_log_probs(
-        self,
-        _input: Dict[str, Any],
-        *,
-        run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Tuple[Sequence[str], Sequence[float]]:
-        llm_result = self.generate([_input], run_manager=run_manager)
-        return self._extract_tokens_and_log_probs(llm_result.generations[0])
-
-    @abstractmethod
-    def _extract_tokens_and_log_probs(
-        self, generations: List[Generation]
-    ) -> Tuple[Sequence[str], Sequence[float]]:
-        """Extract tokens and log probs from response."""
-
-
-class _OpenAIResponseChain(_ResponseChain):
-    """Chain that generates responses from user input and context."""
-
-    llm: BaseLanguageModel
-
-    def _extract_tokens_and_log_probs(
-        self, generations: List[Generation]
-    ) -> Tuple[Sequence[str], Sequence[float]]:
-        tokens = []
-        log_probs = []
-        for gen in generations:
-            if gen.generation_info is None:
-                raise ValueError
-            tokens.extend(gen.generation_info["logprobs"]["tokens"])
-            log_probs.extend(gen.generation_info["logprobs"]["token_logprobs"])
-        return tokens, log_probs
+def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]:
+    """Extract tokens and log probabilities from chat model response."""
+    tokens = []
+    log_probs = []
+    for token in response.response_metadata["logprobs"]["content"]:
+        tokens.append(token["token"])
+        log_probs.append(token["logprob"])
+    return tokens, log_probs
 
 
 class QuestionGeneratorChain(LLMChain):
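The `_ResponseChain`/`_OpenAIResponseChain` hierarchy collapses into a single module-level helper that reads OpenAI-style logprobs from `AIMessage.response_metadata` instead of `Generation.generation_info`. A minimal sketch of the payload shape the new helper expects; the metadata layout is assumed from what `ChatOpenAI(logprobs=True)` returns:

```python
from langchain_core.messages import AIMessage

# Assumed response_metadata layout for a ChatOpenAI(logprobs=True) response.
response = AIMessage(
    content="Hello world",
    response_metadata={
        "logprobs": {
            "content": [
                {"token": "Hello", "logprob": -0.12},
                {"token": " world", "logprob": -0.87},
            ]
        }
    },
)
tokens, log_probs = _extract_tokens_and_log_probs(response)
# tokens == ["Hello", " world"]; log_probs == [-0.12, -0.87]
```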
@@ -109,11 +73,14 @@ def _low_confidence_spans(
 
 class FlareChain(Chain):
     """Chain that combines a retriever, a question generator,
-    and a response generator."""
+    and a response generator.
 
-    question_generator_chain: QuestionGeneratorChain
+    See [Active Retrieval Augmented Generation](https://arxiv.org/abs/2305.06983) paper.
+    """
+
+    question_generator_chain: Runnable
     """Chain that generates questions from uncertain spans."""
-    response_chain: _ResponseChain
+    response_chain: Runnable
     """Chain that generates responses from user input and context."""
     output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
     """Parser that determines whether the chain is finished."""
@@ -152,12 +119,16 @@ class FlareChain(Chain):
         for question in questions:
            docs.extend(self.retriever.invoke(question))
         context = "\n\n".join(d.page_content for d in docs)
-        result = self.response_chain.predict(
-            user_input=user_input,
-            context=context,
-            response=response,
-            callbacks=callbacks,
+        result = self.response_chain.invoke(
+            {
+                "user_input": user_input,
+                "context": context,
+                "response": response,
+            },
+            {"callbacks": callbacks},
         )
+        if isinstance(result, AIMessage):
+            result = result.content
         marginal, finished = self.output_parser.parse(result)
         return marginal, finished
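`Runnable.invoke` takes the input mapping as its first argument and an optional `RunnableConfig` as its second, which replaces `LLMChain.predict(..., callbacks=...)`; chat models return an `AIMessage`, hence the unwrap to `result.content`. A standalone sketch of the pattern (the prompt text and inputs are illustrative):

```python
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI  # assumption: any chat model works here

prompt = ChatPromptTemplate.from_template("{user_input}\n\n{context}\n\n{response}")
chain = prompt | ChatOpenAI(temperature=0)

message = chain.invoke(
    {"user_input": "Who proposed FLARE?", "context": "", "response": ""},
    {"callbacks": []},  # RunnableConfig; propagated to every step of the chain
)
text = message.content  # AIMessage -> str, mirroring the isinstance check above
```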
 
@@ -178,13 +149,18 @@ class FlareChain(Chain):
             for span in low_confidence_spans
         ]
         callbacks = _run_manager.get_child()
-        question_gen_outputs = self.question_generator_chain.apply(
-            question_gen_inputs, callbacks=callbacks
-        )
-        questions = [
-            output[self.question_generator_chain.output_keys[0]]
-            for output in question_gen_outputs
-        ]
+        if isinstance(self.question_generator_chain, LLMChain):
+            question_gen_outputs = self.question_generator_chain.apply(
+                question_gen_inputs, callbacks=callbacks
+            )
+            questions = [
+                output[self.question_generator_chain.output_keys[0]]
+                for output in question_gen_outputs
+            ]
+        else:
+            questions = self.question_generator_chain.batch(
+                question_gen_inputs, config={"callbacks": callbacks}
+            )
         _run_manager.on_text(
             f"Generated Questions: {questions}", color="yellow", end="\n"
         )
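The `isinstance` branch keeps backward compatibility with `LLMChain`-based question generators; for LCEL chains, `batch` runs all spans concurrently and, thanks to the trailing `StrOutputParser`, returns plain strings with no output-key lookup. A self-contained sketch (the prompt wording and input keys are assumptions standing in for `QUESTION_GENERATOR_PROMPT`):

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI  # assumption: any chat model works here

# Illustrative stand-in for QUESTION_GENERATOR_PROMPT; variable names assumed.
question_prompt = PromptTemplate.from_template(
    "Given '{user_input}' and the draft '{current_response}', "
    "write a question whose answer would clarify '{uncertain_span}'."
)
question_generator_chain = (
    question_prompt | ChatOpenAI(temperature=0) | StrOutputParser()
)

questions = question_generator_chain.batch(
    [
        {"user_input": "...", "current_response": "...", "uncertain_span": span}
        for span in ("retrieval trigger", "confidence threshold")
    ],
    config={"callbacks": []},
)
# questions: List[str], one generated question per low-confidence span
```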
@@ -206,8 +182,10 @@ class FlareChain(Chain):
             f"Current Response: {response}", color="blue", end="\n"
         )
         _input = {"user_input": user_input, "context": "", "response": response}
-        tokens, log_probs = self.response_chain.generate_tokens_and_log_probs(
-            _input, run_manager=_run_manager
+        tokens, log_probs = _extract_tokens_and_log_probs(
+            self.response_chain.invoke(
+                _input, {"callbacks": _run_manager.get_child()}
+            )
         )
         low_confidence_spans = _low_confidence_spans(
             tokens,
@@ -251,18 +229,16 @@ class FlareChain(Chain):
         FlareChain class with the given language model.
         """
         try:
-            from langchain_openai import OpenAI
+            from langchain_openai import ChatOpenAI
         except ImportError:
             raise ImportError(
                 "OpenAI is required for FlareChain. "
                 "Please install langchain-openai."
                 "pip install langchain-openai"
             )
-        question_gen_chain = QuestionGeneratorChain(llm=llm)
-        response_llm = OpenAI(
-            max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
-        )
-        response_chain = _OpenAIResponseChain(llm=response_llm)
+        llm = ChatOpenAI(max_tokens=max_generation_len, logprobs=True, temperature=0)
+        response_chain = PROMPT | llm
+        question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
         return cls(
             question_generator_chain=question_gen_chain,
             response_chain=response_chain,
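Taken together, `from_llm` now wires up two LCEL pipelines instead of the private chain subclasses. A hedged usage sketch; note that the factory constructs its own `ChatOpenAI(logprobs=True)` for the response chain, since FLARE needs token log-probabilities:

```python
from langchain.chains import FlareChain
from langchain_openai import ChatOpenAI

flare = FlareChain.from_llm(
    ChatOpenAI(temperature=0),
    retriever=retriever,  # hypothetical: any BaseRetriever you already have
    max_generation_len=164,
    min_prob=0.3,
)
answer = flare.run("Explain how FLARE decides when to retrieve.")
```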
@@ -11,7 +11,9 @@ import numpy as np
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import BasePromptTemplate
+from langchain_core.runnables import Runnable
 
 from langchain.chains.base import Chain
 from langchain.chains.hyde.prompts import PROMPT_MAP
@@ -25,7 +27,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     """
 
     base_embeddings: Embeddings
-    llm_chain: LLMChain
+    llm_chain: Runnable
 
     class Config:
         arbitrary_types_allowed = True
@@ -34,12 +36,15 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     @property
     def input_keys(self) -> List[str]:
         """Input keys for Hyde's LLM chain."""
-        return self.llm_chain.input_keys
+        return self.llm_chain.input_schema.schema()["required"]
 
     @property
     def output_keys(self) -> List[str]:
         """Output keys for Hyde's LLM chain."""
-        return self.llm_chain.output_keys
+        if isinstance(self.llm_chain, LLMChain):
+            return self.llm_chain.output_keys
+        else:
+            return ["text"]
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Call the base embeddings."""
@@ -51,9 +56,12 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
 
     def embed_query(self, text: str) -> List[float]:
         """Generate a hypothetical document and embedded it."""
-        var_name = self.llm_chain.input_keys[0]
-        result = self.llm_chain.generate([{var_name: text}])
-        documents = [generation.text for generation in result.generations[0]]
+        var_name = self.input_keys[0]
+        result = self.llm_chain.invoke({var_name: text})
+        if isinstance(self.llm_chain, LLMChain):
+            documents = [result[self.output_keys[0]]]
+        else:
+            documents = [result]
         embeddings = self.embed_documents(documents)
         return self.combine_embeddings(embeddings)
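`embed_query` now resolves its single variable name through `input_keys`, which for an LCEL chain is derived from the runnable's `input_schema`. A sketch of what that property returns, assuming a simple prompt-led chain whose JSON schema lists the prompt's mandatory variables under `required`:

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI  # assumption: any chat model works here

chain = (
    PromptTemplate.from_template("Answer the question: {QUESTION}")
    | ChatOpenAI(temperature=0)
    | StrOutputParser()
)
print(chain.input_schema.schema()["required"])  # ['QUESTION']
```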
 
@@ -64,7 +72,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     ) -> Dict[str, str]:
         """Call the internal llm chain."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
-        return self.llm_chain(inputs, callbacks=_run_manager.get_child())
+        return self.llm_chain.invoke(
+            inputs, config={"callbacks": _run_manager.get_child()}
+        )
 
     @classmethod
     def from_llm(
@@ -86,7 +96,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
                 f"of {list(PROMPT_MAP.keys())}."
             )
 
-        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        llm_chain = prompt | llm | StrOutputParser()
         return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)
 
     @property
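A hedged end-to-end sketch of the HyDE embedder built through `from_llm`, which now produces a `prompt | llm | StrOutputParser()` pipeline (`FakeEmbeddings` stands in for real embeddings):

```python
from langchain.chains import HypotheticalDocumentEmbedder
from langchain_community.embeddings import FakeEmbeddings  # placeholder embeddings
from langchain_openai import ChatOpenAI  # assumption: any chat model works here

hyde = HypotheticalDocumentEmbedder.from_llm(
    llm=ChatOpenAI(temperature=0),
    base_embeddings=FakeEmbeddings(size=256),  # swap in your real embeddings
    prompt_key="web_search",  # must be one of PROMPT_MAP's keys
)
vector = hyde.embed_query("What are the health benefits of green tea?")
```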
@@ -7,6 +7,7 @@ import re
 import warnings
 from typing import Any, Dict, List, Optional
 
+from langchain_core._api import deprecated
 from langchain_core.callbacks import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
@@ -20,16 +21,132 @@ from langchain.chains.llm import LLMChain
 from langchain.chains.llm_math.prompt import PROMPT
 
 
+@deprecated(
+    since="0.2.13",
+    message=(
+        "This class is deprecated and will be removed in langchain 1.0. "
+        "See API reference for replacement: "
+        "https://api.python.langchain.com/en/latest/chains/langchain.chains.llm_math.base.LLMMathChain.html"  # noqa: E501
+    ),
+    removal="1.0",
+)
 class LLMMathChain(Chain):
     """Chain that interprets a prompt and executes python code to do math.
 
+    Note: this class is deprecated. See below for a replacement implementation
+    using LangGraph. The benefits of this implementation are:
+
+    - Uses LLM tool calling features;
+    - Support for both token-by-token and step-by-step streaming;
+    - Support for checkpointing and memory of chat history;
+    - Easier to modify or extend (e.g., with additional tools, structured responses, etc.)
+
+    Install LangGraph with:
+
+    .. code-block:: bash
+
+        pip install -U langgraph
+
+    .. code-block:: python
+
+        import math
+        from typing import Annotated, Sequence
+
+        from langchain_core.messages import BaseMessage
+        from langchain_core.runnables import RunnableConfig
+        from langchain_core.tools import tool
+        from langchain_openai import ChatOpenAI
+        from langgraph.graph import END, StateGraph
+        from langgraph.graph.message import add_messages
+        from langgraph.prebuilt.tool_node import ToolNode
+        import numexpr
+        from typing_extensions import TypedDict
+
+        @tool
+        def calculator(expression: str) -> str:
+            \"\"\"Calculate expression using Python's numexpr library.
+
+            Expression should be a single line mathematical expression
+            that solves the problem.
+
+            Examples:
+                "37593 * 67" for "37593 times 67"
+                "37593**(1/5)" for "37593^(1/5)"
+            \"\"\"
+            local_dict = {"pi": math.pi, "e": math.e}
+            return str(
+                numexpr.evaluate(
+                    expression.strip(),
+                    global_dict={},  # restrict access to globals
+                    local_dict=local_dict,  # add common mathematical functions
+                )
+            )
+
+        llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
+        tools = [calculator]
+        llm_with_tools = llm.bind_tools(tools, tool_choice="any")
+
+        class ChainState(TypedDict):
+            \"\"\"LangGraph state.\"\"\"
+
+            messages: Annotated[Sequence[BaseMessage], add_messages]
+
+        async def acall_chain(state: ChainState, config: RunnableConfig):
+            last_message = state["messages"][-1]
+            response = await llm_with_tools.ainvoke(state["messages"], config)
+            return {"messages": [response]}
+
+        async def acall_model(state: ChainState, config: RunnableConfig):
+            response = await llm.ainvoke(state["messages"], config)
+            return {"messages": [response]}
+
+        graph_builder = StateGraph(ChainState)
+        graph_builder.add_node("call_tool", acall_chain)
+        graph_builder.add_node("execute_tool", ToolNode(tools))
+        graph_builder.add_node("call_model", acall_model)
+        graph_builder.set_entry_point("call_tool")
+        graph_builder.add_edge("call_tool", "execute_tool")
+        graph_builder.add_edge("execute_tool", "call_model")
+        graph_builder.add_edge("call_model", END)
+        chain = graph_builder.compile()
+
+    .. code-block:: python
+
+        example_query = "What is 551368 divided by 82"
+
+        events = chain.astream(
+            {"messages": [("user", example_query)]},
+            stream_mode="values",
+        )
+        async for event in events:
+            event["messages"][-1].pretty_print()
+
+    .. code-block:: none
+
+        ================================ Human Message =================================
+
+        What is 551368 divided by 82
+        ================================== Ai Message ==================================
+        Tool Calls:
+          calculator (call_MEiGXuJjJ7wGU4aOT86QuGJS)
+         Call ID: call_MEiGXuJjJ7wGU4aOT86QuGJS
+          Args:
+            expression: 551368 / 82
+        ================================= Tool Message =================================
+        Name: calculator
+
+        6724.0
+        ================================== Ai Message ==================================
+
+        551368 divided by 82 equals 6724.
+
     Example:
         .. code-block:: python
 
             from langchain.chains import LLMMathChain
             from langchain_community.llms import OpenAI
             llm_math = LLMMathChain.from_llm(OpenAI())
-    """
+    """  # noqa: E501
 
     llm_chain: LLMChain
     llm: Optional[BaseLanguageModel] = None
@@ -38,7 +38,7 @@ class OpenAIModerationChain(Chain):
     output_key: str = "output"  #: :meta private:
     openai_api_key: Optional[str] = None
     openai_organization: Optional[str] = None
-    _openai_pre_1_0: bool = Field(default=None)
+    openai_pre_1_0: bool = Field(default=None)
 
     @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
@@ -58,16 +58,17 @@ class OpenAIModerationChain(Chain):
             openai.api_key = openai_api_key
             if openai_organization:
                 openai.organization = openai_organization
-            values["_openai_pre_1_0"] = False
+            values["openai_pre_1_0"] = False
             try:
                 check_package_version("openai", gte_version="1.0")
             except ValueError:
-                values["_openai_pre_1_0"] = True
-            if values["_openai_pre_1_0"]:
+                values["openai_pre_1_0"] = True
+            if values["openai_pre_1_0"]:
                 values["client"] = openai.Moderation
             else:
                 values["client"] = openai.OpenAI()
                 values["async_client"] = openai.AsyncOpenAI()
+
         except ImportError:
             raise ImportError(
                 "Could not import openai python package. "
@@ -92,7 +93,7 @@ class OpenAIModerationChain(Chain):
         return [self.output_key]
 
     def _moderate(self, text: str, results: Any) -> str:
-        if self._openai_pre_1_0:
+        if self.openai_pre_1_0:
             condition = results["flagged"]
         else:
             condition = results.flagged
@@ -110,7 +111,7 @@ class OpenAIModerationChain(Chain):
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         text = inputs[self.input_key]
-        if self._openai_pre_1_0:
+        if self.openai_pre_1_0:
             results = self.client.create(text)
             output = self._moderate(text, results["results"][0])
         else:
@@ -123,7 +124,7 @@ class OpenAIModerationChain(Chain):
         inputs: Dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
-        if self._openai_pre_1_0:
+        if self.openai_pre_1_0:
             return await super()._acall(inputs, run_manager=run_manager)
         text = inputs[self.input_key]
         results = await self.async_client.moderations.create(input=text)
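The underscore prefix is dropped most likely because pydantic v1 treats underscore-named attributes as private rather than as model fields, so the root validator's writes to `values["_openai_pre_1_0"]` never reached a usable field; `openai_pre_1_0` behaves as a normal field. A hedged usage sketch (requires the `openai` package and an `OPENAI_API_KEY`):

```python
from langchain.chains import OpenAIModerationChain

moderation = OpenAIModerationChain()
print(moderation.openai_pre_1_0)  # False with openai>=1.0 installed
result = moderation.invoke({"input": "some user-provided text"})
print(result["output"])  # the text itself, or a policy-violation message if flagged
```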
@@ -5,15 +5,27 @@ from __future__ import annotations
 import warnings
 from typing import Any, Dict, List, Optional
 
+from langchain_core._api import deprecated
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.pydantic_v1 import root_validator
+from langchain_core.runnables import Runnable
 
 from langchain.chains.base import Chain
-from langchain.chains.llm import LLMChain
 from langchain.chains.natbot.prompt import PROMPT
 
 
+@deprecated(
+    since="0.2.13",
+    message=(
+        "Importing NatBotChain from langchain is deprecated and will be removed in "
+        "langchain 1.0. Please import from langchain_community instead: "
+        "from langchain_community.chains.natbot import NatBotChain. "
+        "You may need to pip install -U langchain-community."
+    ),
+    removal="1.0",
+)
 class NatBotChain(Chain):
     """Implement an LLM driven browser.
 
@@ -37,7 +49,7 @@ class NatBotChain(Chain):
         natbot = NatBotChain.from_default("Buy me a new hat.")
     """
 
-    llm_chain: LLMChain
+    llm_chain: Runnable
     objective: str
     """Objective that NatBot is tasked with completing."""
     llm: Optional[BaseLanguageModel] = None
@@ -60,7 +72,7 @@ class NatBotChain(Chain):
             "class method."
         )
         if "llm_chain" not in values and values["llm"] is not None:
-            values["llm_chain"] = LLMChain(llm=values["llm"], prompt=PROMPT)
+            values["llm_chain"] = PROMPT | values["llm"] | StrOutputParser()
         return values
 
     @classmethod
@@ -77,7 +89,7 @@ class NatBotChain(Chain):
         cls, llm: BaseLanguageModel, objective: str, **kwargs: Any
     ) -> NatBotChain:
         """Load from LLM."""
-        llm_chain = LLMChain(llm=llm, prompt=PROMPT)
+        llm_chain = PROMPT | llm | StrOutputParser()
         return cls(llm_chain=llm_chain, objective=objective, **kwargs)
 
     @property
@@ -104,12 +116,14 @@ class NatBotChain(Chain):
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         url = inputs[self.input_url_key]
         browser_content = inputs[self.input_browser_content_key]
-        llm_cmd = self.llm_chain.predict(
-            objective=self.objective,
-            url=url[:100],
-            previous_command=self.previous_command,
-            browser_content=browser_content[:4500],
-            callbacks=_run_manager.get_child(),
+        llm_cmd = self.llm_chain.invoke(
+            {
+                "objective": self.objective,
+                "url": url[:100],
+                "previous_command": self.previous_command,
+                "browser_content": browser_content[:4500],
+            },
+            config={"callbacks": _run_manager.get_child()},
         )
         llm_cmd = llm_cmd.strip()
         self.previous_command = llm_cmd
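Because the LCEL chain ends in `StrOutputParser`, `invoke` returns a plain string, so the following `llm_cmd.strip()` keeps working with chat models too. A hedged usage sketch (the URL and DOM snippet are illustrative):

```python
from langchain.chains import NatBotChain
from langchain_openai import ChatOpenAI  # assumption: any BaseLanguageModel works

natbot = NatBotChain.from_llm(ChatOpenAI(temperature=0), objective="Buy me a new hat.")
command = natbot.execute(
    "https://example.com/shop",    # truncated to 100 chars by the chain
    "<button id=1>Hats</button>",  # simplified DOM; truncated to 4500 chars
)
```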
@@ -35,6 +35,7 @@ GRAMMAR = r"""
     ?value: SIGNED_INT -> int
         | SIGNED_FLOAT -> float
         | DATE -> date
+        | DATETIME -> datetime
         | list
         | string
         | ("false" | "False" | "FALSE") -> false
@@ -42,6 +43,7 @@ GRAMMAR = r"""
 
     args: expr ("," expr)*
     DATE.2: /["']?(\d{4}-[01]\d-[0-3]\d)["']?/
+    DATETIME.2: /["']?\d{4}-[01]\d-[0-3]\dT[0-2]\d:[0-5]\d:[0-5]\d[Zz]?["']?/
     string: /'[^']*'/ | ESCAPED_STRING
     list: "[" [args] "]"
@@ -61,6 +63,13 @@ class ISO8601Date(TypedDict):
     type: Literal["date"]
 
 
+class ISO8601DateTime(TypedDict):
+    """A datetime in ISO 8601 format (YYYY-MM-DDTHH:MM:SS)."""
+
+    datetime: str
+    type: Literal["datetime"]
+
+
 @v_args(inline=True)
 class QueryTransformer(Transformer):
     """Transform a query string into an intermediate representation."""
@@ -149,6 +158,20 @@ class QueryTransformer(Transformer):
         )
         return {"date": item, "type": "date"}
 
+    def datetime(self, item: Any) -> ISO8601DateTime:
+        item = str(item).strip("\"'")
+        try:
+            # Parse full ISO 8601 datetime format
+            datetime.datetime.strptime(item, "%Y-%m-%dT%H:%M:%S%z")
+        except ValueError:
+            try:
+                datetime.datetime.strptime(item, "%Y-%m-%dT%H:%M:%S")
+            except ValueError:
+                raise ValueError(
+                    "Datetime values are expected to be in ISO 8601 format."
+                )
+        return {"datetime": item, "type": "datetime"}
+
     def string(self, item: Any) -> str:
         # Remove escaped quotes
         return str(item).strip("\"'")
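Together, the grammar terminal and this transformer method let self-query filters compare against timestamps, not just dates. A hedged round-trip sketch using the module's `get_parser` helper:

```python
from langchain.chains.query_constructor.parser import get_parser

parser = get_parser()
comparison = parser.parse('gt("published_at", "2024-01-15T09:30:00Z")')
# comparison.value is roughly {"datetime": "2024-01-15T09:30:00Z", "type": "datetime"}
```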
@@ -8,8 +8,9 @@ from typing import Any, Callable, Dict, Optional, Sequence, cast
 from langchain_core.callbacks.manager import Callbacks
 from langchain_core.documents import Document
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.output_parsers import BaseOutputParser
+from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
 from langchain_core.prompts import PromptTemplate
+from langchain_core.runnables import Runnable
 
 from langchain.chains.llm import LLMChain
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
@@ -49,12 +50,15 @@ class LLMChainExtractor(BaseDocumentCompressor):
     """Document compressor that uses an LLM chain to extract
     the relevant parts of documents."""
 
-    llm_chain: LLMChain
+    llm_chain: Runnable
     """LLM wrapper to use for compressing documents."""
 
     get_input: Callable[[str, Document], dict] = default_get_input
     """Callable for constructing the chain input from the query and a Document."""
 
+    class Config:
+        arbitrary_types_allowed = True
+
     def compress_documents(
         self,
         documents: Sequence[Document],
@@ -65,10 +69,13 @@ class LLMChainExtractor(BaseDocumentCompressor):
         compressed_docs = []
         for doc in documents:
             _input = self.get_input(query, doc)
-            output_dict = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
-            output = output_dict[self.llm_chain.output_key]
-            if self.llm_chain.prompt.output_parser is not None:
-                output = self.llm_chain.prompt.output_parser.parse(output)
+            output_ = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
+            if isinstance(self.llm_chain, LLMChain):
+                output = output_[self.llm_chain.output_key]
+                if self.llm_chain.prompt.output_parser is not None:
+                    output = self.llm_chain.prompt.output_parser.parse(output)
+            else:
+                output = output_
             if len(output) == 0:
                 continue
             compressed_docs.append(
@@ -85,9 +92,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
         """Compress page content of raw documents asynchronously."""
         outputs = await asyncio.gather(
             *[
-                self.llm_chain.apredict_and_parse(
-                    **self.get_input(query, doc), callbacks=callbacks
-                )
+                self.llm_chain.ainvoke(self.get_input(query, doc), callbacks=callbacks)
                 for doc in documents
             ]
         )
@@ -111,5 +116,9 @@ class LLMChainExtractor(BaseDocumentCompressor):
         """Initialize from LLM."""
         _prompt = prompt if prompt is not None else _get_default_chain_prompt()
         _get_input = get_input if get_input is not None else default_get_input
-        llm_chain = LLMChain(llm=llm, prompt=_prompt, **(llm_chain_kwargs or {}))
+        if _prompt.output_parser is not None:
+            parser = _prompt.output_parser
+        else:
+            parser = StrOutputParser()
+        llm_chain = _prompt | llm | parser
         return cls(llm_chain=llm_chain, get_input=_get_input)  # type: ignore[arg-type]
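With the default prompt, `from_llm` keeps the prompt's own output parser (the module's `NO_OUTPUT`-stripping parser); for custom prompts without one it falls back to `StrOutputParser`. A hedged usage sketch:

```python
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI  # assumption: any chat model works here

compressor = LLMChainExtractor.from_llm(ChatOpenAI(temperature=0))
docs = [Document(page_content="LangChain helps build LLM apps. Bananas are yellow.")]
compressed = compressor.compress_documents(docs, query="What is LangChain for?")
# Each returned Document keeps only the extracted, query-relevant text.
```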