langchain 0.3.27__py3-none-any.whl → 0.4.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. langchain/agents/agent.py +16 -20
  2. langchain/agents/agent_iterator.py +19 -12
  3. langchain/agents/agent_toolkits/vectorstore/base.py +2 -0
  4. langchain/agents/chat/base.py +2 -0
  5. langchain/agents/conversational/base.py +2 -0
  6. langchain/agents/conversational_chat/base.py +2 -0
  7. langchain/agents/initialize.py +1 -1
  8. langchain/agents/json_chat/base.py +1 -0
  9. langchain/agents/mrkl/base.py +2 -0
  10. langchain/agents/openai_assistant/base.py +1 -1
  11. langchain/agents/openai_functions_agent/agent_token_buffer_memory.py +2 -0
  12. langchain/agents/openai_functions_agent/base.py +3 -2
  13. langchain/agents/openai_functions_multi_agent/base.py +1 -1
  14. langchain/agents/openai_tools/base.py +1 -0
  15. langchain/agents/output_parsers/json.py +2 -0
  16. langchain/agents/output_parsers/openai_functions.py +10 -3
  17. langchain/agents/output_parsers/openai_tools.py +8 -1
  18. langchain/agents/output_parsers/react_json_single_input.py +3 -0
  19. langchain/agents/output_parsers/react_single_input.py +3 -0
  20. langchain/agents/output_parsers/self_ask.py +2 -0
  21. langchain/agents/output_parsers/tools.py +16 -2
  22. langchain/agents/output_parsers/xml.py +3 -0
  23. langchain/agents/react/agent.py +1 -0
  24. langchain/agents/react/base.py +4 -0
  25. langchain/agents/react/output_parser.py +2 -0
  26. langchain/agents/schema.py +2 -0
  27. langchain/agents/self_ask_with_search/base.py +4 -0
  28. langchain/agents/structured_chat/base.py +5 -0
  29. langchain/agents/structured_chat/output_parser.py +13 -0
  30. langchain/agents/tool_calling_agent/base.py +1 -0
  31. langchain/agents/tools.py +3 -0
  32. langchain/agents/xml/base.py +7 -1
  33. langchain/callbacks/streaming_aiter.py +13 -2
  34. langchain/callbacks/streaming_aiter_final_only.py +11 -2
  35. langchain/callbacks/streaming_stdout_final_only.py +5 -0
  36. langchain/callbacks/tracers/logging.py +11 -0
  37. langchain/chains/api/base.py +5 -1
  38. langchain/chains/base.py +8 -2
  39. langchain/chains/combine_documents/base.py +7 -1
  40. langchain/chains/combine_documents/map_reduce.py +3 -0
  41. langchain/chains/combine_documents/map_rerank.py +6 -4
  42. langchain/chains/combine_documents/reduce.py +1 -0
  43. langchain/chains/combine_documents/refine.py +1 -0
  44. langchain/chains/combine_documents/stuff.py +5 -1
  45. langchain/chains/constitutional_ai/base.py +7 -0
  46. langchain/chains/conversation/base.py +4 -1
  47. langchain/chains/conversational_retrieval/base.py +67 -59
  48. langchain/chains/elasticsearch_database/base.py +2 -1
  49. langchain/chains/flare/base.py +2 -0
  50. langchain/chains/flare/prompts.py +2 -0
  51. langchain/chains/llm.py +7 -2
  52. langchain/chains/llm_bash/__init__.py +1 -1
  53. langchain/chains/llm_checker/base.py +12 -1
  54. langchain/chains/llm_math/base.py +9 -1
  55. langchain/chains/llm_summarization_checker/base.py +13 -1
  56. langchain/chains/llm_symbolic_math/__init__.py +1 -1
  57. langchain/chains/loading.py +4 -2
  58. langchain/chains/moderation.py +3 -0
  59. langchain/chains/natbot/base.py +3 -1
  60. langchain/chains/natbot/crawler.py +29 -0
  61. langchain/chains/openai_functions/base.py +2 -0
  62. langchain/chains/openai_functions/citation_fuzzy_match.py +9 -0
  63. langchain/chains/openai_functions/openapi.py +4 -0
  64. langchain/chains/openai_functions/qa_with_structure.py +3 -3
  65. langchain/chains/openai_functions/tagging.py +2 -0
  66. langchain/chains/qa_generation/base.py +4 -0
  67. langchain/chains/qa_with_sources/base.py +3 -0
  68. langchain/chains/qa_with_sources/retrieval.py +1 -1
  69. langchain/chains/qa_with_sources/vector_db.py +4 -2
  70. langchain/chains/query_constructor/base.py +4 -2
  71. langchain/chains/query_constructor/parser.py +64 -2
  72. langchain/chains/retrieval_qa/base.py +4 -0
  73. langchain/chains/router/base.py +14 -2
  74. langchain/chains/router/embedding_router.py +3 -0
  75. langchain/chains/router/llm_router.py +6 -4
  76. langchain/chains/router/multi_prompt.py +3 -0
  77. langchain/chains/router/multi_retrieval_qa.py +18 -0
  78. langchain/chains/sql_database/query.py +1 -0
  79. langchain/chains/structured_output/base.py +2 -0
  80. langchain/chains/transform.py +4 -0
  81. langchain/chat_models/base.py +55 -18
  82. langchain/document_loaders/blob_loaders/schema.py +1 -4
  83. langchain/embeddings/base.py +2 -0
  84. langchain/embeddings/cache.py +3 -3
  85. langchain/evaluation/agents/trajectory_eval_chain.py +3 -2
  86. langchain/evaluation/comparison/eval_chain.py +1 -0
  87. langchain/evaluation/criteria/eval_chain.py +3 -0
  88. langchain/evaluation/embedding_distance/base.py +11 -0
  89. langchain/evaluation/exact_match/base.py +14 -1
  90. langchain/evaluation/loading.py +1 -0
  91. langchain/evaluation/parsing/base.py +16 -3
  92. langchain/evaluation/parsing/json_distance.py +19 -8
  93. langchain/evaluation/parsing/json_schema.py +1 -4
  94. langchain/evaluation/qa/eval_chain.py +8 -0
  95. langchain/evaluation/qa/generate_chain.py +2 -0
  96. langchain/evaluation/regex_match/base.py +9 -1
  97. langchain/evaluation/scoring/eval_chain.py +1 -0
  98. langchain/evaluation/string_distance/base.py +6 -0
  99. langchain/memory/buffer.py +5 -0
  100. langchain/memory/buffer_window.py +2 -0
  101. langchain/memory/combined.py +1 -1
  102. langchain/memory/entity.py +47 -0
  103. langchain/memory/simple.py +3 -0
  104. langchain/memory/summary.py +30 -0
  105. langchain/memory/summary_buffer.py +3 -0
  106. langchain/memory/token_buffer.py +2 -0
  107. langchain/output_parsers/combining.py +4 -2
  108. langchain/output_parsers/enum.py +5 -1
  109. langchain/output_parsers/fix.py +8 -1
  110. langchain/output_parsers/pandas_dataframe.py +16 -1
  111. langchain/output_parsers/regex.py +2 -0
  112. langchain/output_parsers/retry.py +21 -1
  113. langchain/output_parsers/structured.py +10 -0
  114. langchain/output_parsers/yaml.py +4 -0
  115. langchain/pydantic_v1/__init__.py +1 -1
  116. langchain/retrievers/document_compressors/chain_extract.py +4 -2
  117. langchain/retrievers/document_compressors/cohere_rerank.py +2 -0
  118. langchain/retrievers/document_compressors/cross_encoder_rerank.py +2 -0
  119. langchain/retrievers/document_compressors/embeddings_filter.py +3 -0
  120. langchain/retrievers/document_compressors/listwise_rerank.py +1 -0
  121. langchain/retrievers/ensemble.py +2 -2
  122. langchain/retrievers/multi_query.py +3 -1
  123. langchain/retrievers/multi_vector.py +4 -1
  124. langchain/retrievers/parent_document_retriever.py +15 -0
  125. langchain/retrievers/self_query/base.py +19 -0
  126. langchain/retrievers/time_weighted_retriever.py +3 -0
  127. langchain/runnables/hub.py +12 -0
  128. langchain/runnables/openai_functions.py +6 -0
  129. langchain/smith/__init__.py +1 -0
  130. langchain/smith/evaluation/config.py +5 -22
  131. langchain/smith/evaluation/progress.py +12 -3
  132. langchain/smith/evaluation/runner_utils.py +240 -123
  133. langchain/smith/evaluation/string_run_evaluator.py +27 -0
  134. langchain/storage/encoder_backed.py +1 -0
  135. langchain/tools/python/__init__.py +1 -1
  136. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/METADATA +2 -12
  137. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/RECORD +140 -141
  138. langchain/smith/evaluation/utils.py +0 -0
  139. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/WHEEL +0 -0
  140. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/entry_points.txt +0 -0
  141. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/licenses/LICENSE +0 -0
@@ -3,6 +3,7 @@ from typing import Literal, Optional, Union
3
3
 
4
4
  from langchain_core.agents import AgentAction, AgentFinish
5
5
  from pydantic import Field
6
+ from typing_extensions import override
6
7
 
7
8
  from langchain.agents import AgentOutputParser
8
9
 
@@ -65,6 +66,7 @@ class XMLAgentOutputParser(AgentOutputParser):
65
66
  None - no escaping is applied, which may lead to parsing conflicts.
66
67
  """
67
68
 
69
+ @override
68
70
  def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
69
71
  # Check for tool invocation first
70
72
  tool_matches = re.findall(r"<tool>(.*?)</tool>", text, re.DOTALL)
@@ -115,6 +117,7 @@ class XMLAgentOutputParser(AgentOutputParser):
115
117
  )
116
118
  raise ValueError(msg)
117
119
 
120
+ @override
118
121
  def get_format_instructions(self) -> str:
119
122
  raise NotImplementedError
120
123
 
@@ -116,6 +116,7 @@ def create_react_agent(
116
116
  Thought:{agent_scratchpad}'''
117
117
 
118
118
  prompt = PromptTemplate.from_template(template)
119
+
119
120
  """ # noqa: E501
120
121
  missing_vars = {"tools", "tool_names", "agent_scratchpad"}.difference(
121
122
  prompt.input_variables + list(prompt.partial_variables),
@@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
11
11
  from langchain_core.prompts import BasePromptTemplate
12
12
  from langchain_core.tools import BaseTool, Tool
13
13
  from pydantic import Field
14
+ from typing_extensions import override
14
15
 
15
16
  from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
16
17
  from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
@@ -38,6 +39,7 @@ class ReActDocstoreAgent(Agent):
38
39
  output_parser: AgentOutputParser = Field(default_factory=ReActOutputParser)
39
40
 
40
41
  @classmethod
42
+ @override
41
43
  def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
42
44
  return ReActOutputParser()
43
45
 
@@ -47,6 +49,7 @@ class ReActDocstoreAgent(Agent):
47
49
  return AgentType.REACT_DOCSTORE
48
50
 
49
51
  @classmethod
52
+ @override
50
53
  def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
51
54
  """Return default prompt."""
52
55
  return WIKI_PROMPT
@@ -141,6 +144,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
141
144
  """Agent for the ReAct TextWorld chain."""
142
145
 
143
146
  @classmethod
147
+ @override
144
148
  def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
145
149
  """Return default prompt."""
146
150
  return TEXTWORLD_PROMPT
@@ -3,6 +3,7 @@ from typing import Union
3
3
 
4
4
  from langchain_core.agents import AgentAction, AgentFinish
5
5
  from langchain_core.exceptions import OutputParserException
6
+ from typing_extensions import override
6
7
 
7
8
  from langchain.agents.agent import AgentOutputParser
8
9
 
@@ -10,6 +11,7 @@ from langchain.agents.agent import AgentOutputParser
10
11
  class ReActOutputParser(AgentOutputParser):
11
12
  """Output parser for the ReAct agent."""
12
13
 
14
+ @override
13
15
  def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
14
16
  action_prefix = "Action: "
15
17
  if not text.strip().split("\n")[-1].startswith(action_prefix):
@@ -2,12 +2,14 @@ from typing import Any
2
2
 
3
3
  from langchain_core.agents import AgentAction
4
4
  from langchain_core.prompts.chat import ChatPromptTemplate
5
+ from typing_extensions import override
5
6
 
6
7
 
7
8
  class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):
8
9
  """Chat prompt template for the agent scratchpad."""
9
10
 
10
11
  @classmethod
12
+ @override
11
13
  def is_lc_serializable(cls) -> bool:
12
14
  return False
13
15
 
@@ -11,6 +11,7 @@ from langchain_core.prompts import BasePromptTemplate
11
11
  from langchain_core.runnables import Runnable, RunnablePassthrough
12
12
  from langchain_core.tools import BaseTool, Tool
13
13
  from pydantic import Field
14
+ from typing_extensions import override
14
15
 
15
16
  from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
16
17
  from langchain.agents.agent_types import AgentType
@@ -32,6 +33,7 @@ class SelfAskWithSearchAgent(Agent):
32
33
  output_parser: AgentOutputParser = Field(default_factory=SelfAskOutputParser)
33
34
 
34
35
  @classmethod
36
+ @override
35
37
  def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
36
38
  return SelfAskOutputParser()
37
39
 
@@ -41,6 +43,7 @@ class SelfAskWithSearchAgent(Agent):
41
43
  return AgentType.SELF_ASK_WITH_SEARCH
42
44
 
43
45
  @classmethod
46
+ @override
44
47
  def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
45
48
  """Prompt does not depend on tools."""
46
49
  return PROMPT
@@ -182,6 +185,7 @@ def create_self_ask_with_search_agent(
182
185
  Are followup questions needed here:{agent_scratchpad}'''
183
186
 
184
187
  prompt = PromptTemplate.from_template(template)
188
+
185
189
  """ # noqa: E501
186
190
  missing_vars = {"agent_scratchpad"}.difference(
187
191
  prompt.input_variables + list(prompt.partial_variables),
@@ -16,6 +16,7 @@ from langchain_core.runnables import Runnable, RunnablePassthrough
16
16
  from langchain_core.tools import BaseTool
17
17
  from langchain_core.tools.render import ToolsRenderer
18
18
  from pydantic import Field
19
+ from typing_extensions import override
19
20
 
20
21
  from langchain.agents.agent import Agent, AgentOutputParser
21
22
  from langchain.agents.format_scratchpad import format_log_to_str
@@ -70,6 +71,7 @@ class StructuredChatAgent(Agent):
70
71
  pass
71
72
 
72
73
  @classmethod
74
+ @override
73
75
  def _get_default_output_parser(
74
76
  cls,
75
77
  llm: Optional[BaseLanguageModel] = None,
@@ -78,10 +80,12 @@ class StructuredChatAgent(Agent):
78
80
  return StructuredChatOutputParserWithRetries.from_llm(llm=llm)
79
81
 
80
82
  @property
83
+ @override
81
84
  def _stop(self) -> list[str]:
82
85
  return ["Observation:"]
83
86
 
84
87
  @classmethod
88
+ @override
85
89
  def create_prompt(
86
90
  cls,
87
91
  tools: Sequence[BaseTool],
@@ -276,6 +280,7 @@ def create_structured_chat_agent(
276
280
  ("human", human),
277
281
  ]
278
282
  )
283
+
279
284
  """ # noqa: E501
280
285
  missing_vars = {"tools", "tool_names", "agent_scratchpad"}.difference(
281
286
  prompt.input_variables + list(prompt.partial_variables),
@@ -10,6 +10,7 @@ from langchain_core.agents import AgentAction, AgentFinish
10
10
  from langchain_core.exceptions import OutputParserException
11
11
  from langchain_core.language_models import BaseLanguageModel
12
12
  from pydantic import Field
13
+ from typing_extensions import override
13
14
 
14
15
  from langchain.agents.agent import AgentOutputParser
15
16
  from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS
@@ -27,10 +28,12 @@ class StructuredChatOutputParser(AgentOutputParser):
27
28
  pattern: Pattern = re.compile(r"```(?:json\s+)?(\W.*?)```", re.DOTALL)
28
29
  """Regex pattern to parse the output."""
29
30
 
31
+ @override
30
32
  def get_format_instructions(self) -> str:
31
33
  """Returns formatting instructions for the given output parser."""
32
34
  return self.format_instructions
33
35
 
36
+ @override
34
37
  def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
35
38
  try:
36
39
  action_match = self.pattern.search(text)
@@ -65,9 +68,11 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
65
68
  output_fixing_parser: Optional[OutputFixingParser] = None
66
69
  """The output fixing parser to use."""
67
70
 
71
+ @override
68
72
  def get_format_instructions(self) -> str:
69
73
  return FORMAT_INSTRUCTIONS
70
74
 
75
+ @override
71
76
  def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
72
77
  try:
73
78
  if self.output_fixing_parser is not None:
@@ -83,6 +88,14 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
83
88
  llm: Optional[BaseLanguageModel] = None,
84
89
  base_parser: Optional[StructuredChatOutputParser] = None,
85
90
  ) -> StructuredChatOutputParserWithRetries:
91
+ """Create a StructuredChatOutputParserWithRetries from a language model.
92
+
93
+ Args:
94
+ llm: The language model to use.
95
+ base_parser: An optional StructuredChatOutputParser to use.
96
+ Returns:
97
+ An instance of StructuredChatOutputParserWithRetries.
98
+ """
86
99
  if llm is not None:
87
100
  base_parser = base_parser or StructuredChatOutputParser()
88
101
  output_fixing_parser: OutputFixingParser = OutputFixingParser.from_llm(
@@ -85,6 +85,7 @@ def create_tool_calling_agent(
85
85
  The agent prompt must have an `agent_scratchpad` key that is a
86
86
  ``MessagesPlaceholder``. Intermediate agent actions and tool output
87
87
  messages will be passed in here.
88
+
88
89
  """
89
90
  missing_vars = {"agent_scratchpad"}.difference(
90
91
  prompt.input_variables + list(prompt.partial_variables),
langchain/agents/tools.py CHANGED
@@ -7,6 +7,7 @@ from langchain_core.callbacks import (
7
7
  CallbackManagerForToolRun,
8
8
  )
9
9
  from langchain_core.tools import BaseTool, tool
10
+ from typing_extensions import override
10
11
 
11
12
 
12
13
  class InvalidTool(BaseTool):
@@ -17,6 +18,7 @@ class InvalidTool(BaseTool):
17
18
  description: str = "Called when tool name is invalid. Suggests valid tool names."
18
19
  """Description of the tool."""
19
20
 
21
+ @override
20
22
  def _run(
21
23
  self,
22
24
  requested_tool_name: str,
@@ -30,6 +32,7 @@ class InvalidTool(BaseTool):
30
32
  f"try one of [{available_tool_names_str}]."
31
33
  )
32
34
 
35
+ @override
33
36
  async def _arun(
34
37
  self,
35
38
  requested_tool_name: str,
@@ -10,6 +10,7 @@ from langchain_core.prompts.chat import AIMessagePromptTemplate, ChatPromptTempl
10
10
  from langchain_core.runnables import Runnable, RunnablePassthrough
11
11
  from langchain_core.tools import BaseTool
12
12
  from langchain_core.tools.render import ToolsRenderer, render_text_description
13
+ from typing_extensions import override
13
14
 
14
15
  from langchain.agents.agent import BaseSingleActionAgent
15
16
  from langchain.agents.format_scratchpad import format_xml
@@ -36,7 +37,6 @@ class XMLAgent(BaseSingleActionAgent):
36
37
  tools = ...
37
38
  model =
38
39
 
39
-
40
40
  """
41
41
 
42
42
  tools: list[BaseTool]
@@ -45,11 +45,13 @@ class XMLAgent(BaseSingleActionAgent):
45
45
  """Chain to use to predict action."""
46
46
 
47
47
  @property
48
+ @override
48
49
  def input_keys(self) -> list[str]:
49
50
  return ["input"]
50
51
 
51
52
  @staticmethod
52
53
  def get_default_prompt() -> ChatPromptTemplate:
54
+ """Return the default prompt for the XML agent."""
53
55
  base_prompt = ChatPromptTemplate.from_template(agent_instructions)
54
56
  return base_prompt + AIMessagePromptTemplate.from_template(
55
57
  "{intermediate_steps}",
@@ -57,8 +59,10 @@ class XMLAgent(BaseSingleActionAgent):
57
59
 
58
60
  @staticmethod
59
61
  def get_default_output_parser() -> XMLAgentOutputParser:
62
+ """Return an XMLAgentOutputParser."""
60
63
  return XMLAgentOutputParser()
61
64
 
65
+ @override
62
66
  def plan(
63
67
  self,
64
68
  intermediate_steps: list[tuple[AgentAction, str]],
@@ -83,6 +87,7 @@ class XMLAgent(BaseSingleActionAgent):
83
87
  response = self.llm_chain(inputs, callbacks=callbacks)
84
88
  return response[self.llm_chain.output_key]
85
89
 
90
+ @override
86
91
  async def aplan(
87
92
  self,
88
93
  intermediate_steps: list[tuple[AgentAction, str]],
@@ -203,6 +208,7 @@ def create_xml_agent(
203
208
  Question: {input}
204
209
  {agent_scratchpad}'''
205
210
  prompt = PromptTemplate.from_template(template)
211
+
206
212
  """ # noqa: E501
207
213
  missing_vars = {"tools", "agent_scratchpad"}.difference(
208
214
  prompt.input_variables + list(prompt.partial_variables),
@@ -6,6 +6,8 @@ from typing import Any, Literal, Union, cast
6
6
 
7
7
  from langchain_core.callbacks import AsyncCallbackHandler
8
8
  from langchain_core.outputs import LLMResult
9
+ from langchain_core.v1.messages import AIMessage
10
+ from typing_extensions import override
9
11
 
10
12
  # TODO If used by two LLM runs in parallel this won't work as expected
11
13
 
@@ -19,12 +21,15 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
19
21
 
20
22
  @property
21
23
  def always_verbose(self) -> bool:
24
+ """Always verbose."""
22
25
  return True
23
26
 
24
27
  def __init__(self) -> None:
28
+ """Instantiate AsyncIteratorCallbackHandler."""
25
29
  self.queue = asyncio.Queue()
26
30
  self.done = asyncio.Event()
27
31
 
32
+ @override
28
33
  async def on_llm_start(
29
34
  self,
30
35
  serialized: dict[str, Any],
@@ -34,19 +39,25 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
34
39
  # If two calls are made in a row, this resets the state
35
40
  self.done.clear()
36
41
 
42
+ @override
37
43
  async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
38
44
  if token is not None and token != "":
39
45
  self.queue.put_nowait(token)
40
46
 
41
- async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
47
+ @override
48
+ async def on_llm_end(
49
+ self, response: Union[LLMResult, AIMessage], **kwargs: Any
50
+ ) -> None:
42
51
  self.done.set()
43
52
 
53
+ @override
44
54
  async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
45
55
  self.done.set()
46
56
 
47
57
  # TODO implement the other methods
48
58
 
49
59
  async def aiter(self) -> AsyncIterator[str]:
60
+ """Asynchronous iterator that yields tokens."""
50
61
  while not self.queue.empty() or not self.done.is_set():
51
62
  # Wait for the next token in the queue,
52
63
  # but stop waiting if the done event is set
@@ -65,7 +76,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
65
76
  other.pop().cancel()
66
77
 
67
78
  # Extract the value of the first completed task
68
- token_or_done = cast(Union[str, Literal[True]], done.pop().result())
79
+ token_or_done = cast("Union[str, Literal[True]]", done.pop().result())
69
80
 
70
81
  # If the extracted value is the boolean True, the done event was set
71
82
  if token_or_done is True:
@@ -1,8 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, Optional
3
+ from typing import Any, Optional, Union
4
4
 
5
5
  from langchain_core.outputs import LLMResult
6
+ from langchain_core.v1.messages import AIMessage
7
+ from typing_extensions import override
6
8
 
7
9
  from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
8
10
 
@@ -15,6 +17,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
15
17
  """
16
18
 
17
19
  def append_to_last_tokens(self, token: str) -> None:
20
+ """Append token to the last tokens."""
18
21
  self.last_tokens.append(token)
19
22
  self.last_tokens_stripped.append(token.strip())
20
23
  if len(self.last_tokens) > len(self.answer_prefix_tokens):
@@ -22,6 +25,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
22
25
  self.last_tokens_stripped.pop(0)
23
26
 
24
27
  def check_if_answer_reached(self) -> bool:
28
+ """Check if the answer has been reached."""
25
29
  if self.strip_tokens:
26
30
  return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
27
31
  return self.last_tokens == self.answer_prefix_tokens
@@ -60,6 +64,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
60
64
  self.stream_prefix = stream_prefix
61
65
  self.answer_reached = False
62
66
 
67
+ @override
63
68
  async def on_llm_start(
64
69
  self,
65
70
  serialized: dict[str, Any],
@@ -70,10 +75,14 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
70
75
  self.done.clear()
71
76
  self.answer_reached = False
72
77
 
73
- async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
78
+ @override
79
+ async def on_llm_end(
80
+ self, response: Union[LLMResult, AIMessage], **kwargs: Any
81
+ ) -> None:
74
82
  if self.answer_reached:
75
83
  self.done.set()
76
84
 
85
+ @override
77
86
  async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
78
87
  # Remember the last n tokens, where n = len(answer_prefix_tokens)
79
88
  self.append_to_last_tokens(token)
@@ -4,6 +4,7 @@ import sys
4
4
  from typing import Any, Optional
5
5
 
6
6
  from langchain_core.callbacks import StreamingStdOutCallbackHandler
7
+ from typing_extensions import override
7
8
 
8
9
  DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
9
10
 
@@ -16,6 +17,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
16
17
  """
17
18
 
18
19
  def append_to_last_tokens(self, token: str) -> None:
20
+ """Append token to the last tokens."""
19
21
  self.last_tokens.append(token)
20
22
  self.last_tokens_stripped.append(token.strip())
21
23
  if len(self.last_tokens) > len(self.answer_prefix_tokens):
@@ -23,6 +25,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
23
25
  self.last_tokens_stripped.pop(0)
24
26
 
25
27
  def check_if_answer_reached(self) -> bool:
28
+ """Check if the answer has been reached."""
26
29
  if self.strip_tokens:
27
30
  return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
28
31
  return self.last_tokens == self.answer_prefix_tokens
@@ -61,6 +64,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
61
64
  self.stream_prefix = stream_prefix
62
65
  self.answer_reached = False
63
66
 
67
+ @override
64
68
  def on_llm_start(
65
69
  self,
66
70
  serialized: dict[str, Any],
@@ -70,6 +74,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
70
74
  """Run when LLM starts running."""
71
75
  self.answer_reached = False
72
76
 
77
+ @override
73
78
  def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
74
79
  """Run on new LLM token. Only available when streaming is enabled."""
75
80
 
@@ -7,6 +7,7 @@ from uuid import UUID
7
7
  from langchain_core.exceptions import TracerException
8
8
  from langchain_core.tracers.stdout import FunctionCallbackHandler
9
9
  from langchain_core.utils.input import get_bolded_text, get_colored_text
10
+ from typing_extensions import override
10
11
 
11
12
 
12
13
  class LoggingCallbackHandler(FunctionCallbackHandler):
@@ -21,6 +22,15 @@ class LoggingCallbackHandler(FunctionCallbackHandler):
21
22
  extra: Optional[dict] = None,
22
23
  **kwargs: Any,
23
24
  ) -> None:
25
+ """
26
+ Initialize the LoggingCallbackHandler.
27
+
28
+ Args:
29
+ logger: the logger to use for logging
30
+ log_level: the logging level (default: logging.INFO)
31
+ extra: the extra context to log (default: None)
32
+ **kwargs:
33
+ """
24
34
  log_method = getattr(logger, logging.getLevelName(level=log_level).lower())
25
35
 
26
36
  def callback(text: str) -> None:
@@ -28,6 +38,7 @@ class LoggingCallbackHandler(FunctionCallbackHandler):
28
38
 
29
39
  super().__init__(function=callback, **kwargs)
30
40
 
41
+ @override
31
42
  def on_text(
32
43
  self,
33
44
  text: str,
@@ -191,6 +191,7 @@ try:
191
191
  )
192
192
  async for event in events:
193
193
  event["messages"][-1].pretty_print()
194
+
194
195
  """ # noqa: E501
195
196
 
196
197
  api_request_chain: LLMChain
@@ -386,7 +387,10 @@ try:
386
387
  except ImportError:
387
388
 
388
389
  class APIChain: # type: ignore[no-redef]
389
- def __init__(self, *args: Any, **kwargs: Any) -> None:
390
+ """Raise an ImportError if APIChain is used without langchain_community."""
391
+
392
+ def __init__(self, *_: Any, **__: Any) -> None:
393
+ """Raise an ImportError if APIChain is used without langchain_community."""
390
394
  msg = (
391
395
  "To use the APIChain, you must install the langchain_community package."
392
396
  "pip install langchain_community"
langchain/chains/base.py CHANGED
@@ -108,6 +108,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
108
108
  arbitrary_types_allowed=True,
109
109
  )
110
110
 
111
+ @override
111
112
  def get_input_schema(
112
113
  self,
113
114
  config: Optional[RunnableConfig] = None,
@@ -115,6 +116,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
115
116
  # This is correct, but pydantic typings/mypy don't think so.
116
117
  return create_model("ChainInput", **dict.fromkeys(self.input_keys, (Any, None)))
117
118
 
119
+ @override
118
120
  def get_output_schema(
119
121
  self,
120
122
  config: Optional[RunnableConfig] = None,
@@ -409,7 +411,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
409
411
 
410
412
  return self.invoke(
411
413
  inputs,
412
- cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
414
+ cast("RunnableConfig", {k: v for k, v in config.items() if v is not None}),
413
415
  return_only_outputs=return_only_outputs,
414
416
  include_run_info=include_run_info,
415
417
  )
@@ -459,7 +461,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
459
461
  }
460
462
  return await self.ainvoke(
461
463
  inputs,
462
- cast(RunnableConfig, {k: v for k, v in config.items() if k is not None}),
464
+ cast("RunnableConfig", {k: v for k, v in config.items() if k is not None}),
463
465
  return_only_outputs=return_only_outputs,
464
466
  include_run_info=include_run_info,
465
467
  )
@@ -616,6 +618,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
616
618
  context = "Weather report for Boise, Idaho on 07/03/23..."
617
619
  chain.run(question=question, context=context)
618
620
  # -> "The temperature in Boise is..."
621
+
619
622
  """
620
623
  # Run at start to make sure this is possible/defined
621
624
  _output_key = self._run_output_key
@@ -690,6 +693,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
690
693
  context = "Weather report for Boise, Idaho on 07/03/23..."
691
694
  await chain.arun(question=question, context=context)
692
695
  # -> "The temperature in Boise is..."
696
+
693
697
  """
694
698
  if len(self.output_keys) != 1:
695
699
  msg = (
@@ -744,6 +748,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
744
748
 
745
749
  chain.dict(exclude_unset=True)
746
750
  # -> {"_type": "foo", "verbose": False, ...}
751
+
747
752
  """
748
753
  _dict = super().dict(**kwargs)
749
754
  with contextlib.suppress(NotImplementedError):
@@ -763,6 +768,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
763
768
  .. code-block:: python
764
769
 
765
770
  chain.save(file_path="path/chain.yaml")
771
+
766
772
  """
767
773
  if self.memory is not None:
768
774
  msg = "Saving of memory is not yet supported."
@@ -14,6 +14,7 @@ from langchain_core.runnables.config import RunnableConfig
14
14
  from langchain_core.utils.pydantic import create_model
15
15
  from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
16
16
  from pydantic import BaseModel, Field
17
+ from typing_extensions import override
17
18
 
18
19
  from langchain.chains.base import Chain
19
20
 
@@ -46,6 +47,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
46
47
  input_key: str = "input_documents" #: :meta private:
47
48
  output_key: str = "output_text" #: :meta private:
48
49
 
50
+ @override
49
51
  def get_input_schema(
50
52
  self,
51
53
  config: Optional[RunnableConfig] = None,
@@ -55,6 +57,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
55
57
  **{self.input_key: (list[Document], None)},
56
58
  )
57
59
 
60
+ @override
58
61
  def get_output_schema(
59
62
  self,
60
63
  config: Optional[RunnableConfig] = None,
@@ -80,7 +83,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
80
83
  """
81
84
  return [self.output_key]
82
85
 
83
- def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]:
86
+ def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]: # noqa: ARG002
84
87
  """Return the prompt length given the documents passed in.
85
88
 
86
89
  This can be used by a caller to determine whether passing in a list
@@ -231,6 +234,7 @@ class AnalyzeDocumentChain(Chain):
231
234
  input_documents=itemgetter("input_document") | split_text,
232
235
  ) | chain.pick("output_text")
233
236
  )
237
+
234
238
  """
235
239
 
236
240
  input_key: str = "input_document" #: :meta private:
@@ -253,6 +257,7 @@ class AnalyzeDocumentChain(Chain):
253
257
  """
254
258
  return self.combine_docs_chain.output_keys
255
259
 
260
+ @override
256
261
  def get_input_schema(
257
262
  self,
258
263
  config: Optional[RunnableConfig] = None,
@@ -262,6 +267,7 @@ class AnalyzeDocumentChain(Chain):
262
267
  **{self.input_key: (str, None)},
263
268
  )
264
269
 
270
+ @override
265
271
  def get_output_schema(
266
272
  self,
267
273
  config: Optional[RunnableConfig] = None,
@@ -10,6 +10,7 @@ from langchain_core.documents import Document
10
10
  from langchain_core.runnables.config import RunnableConfig
11
11
  from langchain_core.utils.pydantic import create_model
12
12
  from pydantic import BaseModel, ConfigDict, model_validator
13
+ from typing_extensions import override
13
14
 
14
15
  from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
15
16
  from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
@@ -98,6 +99,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
98
99
  llm_chain=llm_chain,
99
100
  reduce_documents_chain=reduce_documents_chain,
100
101
  )
102
+
101
103
  """
102
104
 
103
105
  llm_chain: LLMChain
@@ -111,6 +113,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
111
113
  return_intermediate_steps: bool = False
112
114
  """Return the results of the map steps in the output."""
113
115
 
116
+ @override
114
117
  def get_output_schema(
115
118
  self,
116
119
  config: Optional[RunnableConfig] = None,