langchain 0.3.27__py3-none-any.whl → 0.4.0.dev0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (141)
  1. langchain/agents/agent.py +16 -20
  2. langchain/agents/agent_iterator.py +19 -12
  3. langchain/agents/agent_toolkits/vectorstore/base.py +2 -0
  4. langchain/agents/chat/base.py +2 -0
  5. langchain/agents/conversational/base.py +2 -0
  6. langchain/agents/conversational_chat/base.py +2 -0
  7. langchain/agents/initialize.py +1 -1
  8. langchain/agents/json_chat/base.py +1 -0
  9. langchain/agents/mrkl/base.py +2 -0
  10. langchain/agents/openai_assistant/base.py +1 -1
  11. langchain/agents/openai_functions_agent/agent_token_buffer_memory.py +2 -0
  12. langchain/agents/openai_functions_agent/base.py +3 -2
  13. langchain/agents/openai_functions_multi_agent/base.py +1 -1
  14. langchain/agents/openai_tools/base.py +1 -0
  15. langchain/agents/output_parsers/json.py +2 -0
  16. langchain/agents/output_parsers/openai_functions.py +10 -3
  17. langchain/agents/output_parsers/openai_tools.py +8 -1
  18. langchain/agents/output_parsers/react_json_single_input.py +3 -0
  19. langchain/agents/output_parsers/react_single_input.py +3 -0
  20. langchain/agents/output_parsers/self_ask.py +2 -0
  21. langchain/agents/output_parsers/tools.py +16 -2
  22. langchain/agents/output_parsers/xml.py +3 -0
  23. langchain/agents/react/agent.py +1 -0
  24. langchain/agents/react/base.py +4 -0
  25. langchain/agents/react/output_parser.py +2 -0
  26. langchain/agents/schema.py +2 -0
  27. langchain/agents/self_ask_with_search/base.py +4 -0
  28. langchain/agents/structured_chat/base.py +5 -0
  29. langchain/agents/structured_chat/output_parser.py +13 -0
  30. langchain/agents/tool_calling_agent/base.py +1 -0
  31. langchain/agents/tools.py +3 -0
  32. langchain/agents/xml/base.py +7 -1
  33. langchain/callbacks/streaming_aiter.py +13 -2
  34. langchain/callbacks/streaming_aiter_final_only.py +11 -2
  35. langchain/callbacks/streaming_stdout_final_only.py +5 -0
  36. langchain/callbacks/tracers/logging.py +11 -0
  37. langchain/chains/api/base.py +5 -1
  38. langchain/chains/base.py +8 -2
  39. langchain/chains/combine_documents/base.py +7 -1
  40. langchain/chains/combine_documents/map_reduce.py +3 -0
  41. langchain/chains/combine_documents/map_rerank.py +6 -4
  42. langchain/chains/combine_documents/reduce.py +1 -0
  43. langchain/chains/combine_documents/refine.py +1 -0
  44. langchain/chains/combine_documents/stuff.py +5 -1
  45. langchain/chains/constitutional_ai/base.py +7 -0
  46. langchain/chains/conversation/base.py +4 -1
  47. langchain/chains/conversational_retrieval/base.py +67 -59
  48. langchain/chains/elasticsearch_database/base.py +2 -1
  49. langchain/chains/flare/base.py +2 -0
  50. langchain/chains/flare/prompts.py +2 -0
  51. langchain/chains/llm.py +7 -2
  52. langchain/chains/llm_bash/__init__.py +1 -1
  53. langchain/chains/llm_checker/base.py +12 -1
  54. langchain/chains/llm_math/base.py +9 -1
  55. langchain/chains/llm_summarization_checker/base.py +13 -1
  56. langchain/chains/llm_symbolic_math/__init__.py +1 -1
  57. langchain/chains/loading.py +4 -2
  58. langchain/chains/moderation.py +3 -0
  59. langchain/chains/natbot/base.py +3 -1
  60. langchain/chains/natbot/crawler.py +29 -0
  61. langchain/chains/openai_functions/base.py +2 -0
  62. langchain/chains/openai_functions/citation_fuzzy_match.py +9 -0
  63. langchain/chains/openai_functions/openapi.py +4 -0
  64. langchain/chains/openai_functions/qa_with_structure.py +3 -3
  65. langchain/chains/openai_functions/tagging.py +2 -0
  66. langchain/chains/qa_generation/base.py +4 -0
  67. langchain/chains/qa_with_sources/base.py +3 -0
  68. langchain/chains/qa_with_sources/retrieval.py +1 -1
  69. langchain/chains/qa_with_sources/vector_db.py +4 -2
  70. langchain/chains/query_constructor/base.py +4 -2
  71. langchain/chains/query_constructor/parser.py +64 -2
  72. langchain/chains/retrieval_qa/base.py +4 -0
  73. langchain/chains/router/base.py +14 -2
  74. langchain/chains/router/embedding_router.py +3 -0
  75. langchain/chains/router/llm_router.py +6 -4
  76. langchain/chains/router/multi_prompt.py +3 -0
  77. langchain/chains/router/multi_retrieval_qa.py +18 -0
  78. langchain/chains/sql_database/query.py +1 -0
  79. langchain/chains/structured_output/base.py +2 -0
  80. langchain/chains/transform.py +4 -0
  81. langchain/chat_models/base.py +55 -18
  82. langchain/document_loaders/blob_loaders/schema.py +1 -4
  83. langchain/embeddings/base.py +2 -0
  84. langchain/embeddings/cache.py +3 -3
  85. langchain/evaluation/agents/trajectory_eval_chain.py +3 -2
  86. langchain/evaluation/comparison/eval_chain.py +1 -0
  87. langchain/evaluation/criteria/eval_chain.py +3 -0
  88. langchain/evaluation/embedding_distance/base.py +11 -0
  89. langchain/evaluation/exact_match/base.py +14 -1
  90. langchain/evaluation/loading.py +1 -0
  91. langchain/evaluation/parsing/base.py +16 -3
  92. langchain/evaluation/parsing/json_distance.py +19 -8
  93. langchain/evaluation/parsing/json_schema.py +1 -4
  94. langchain/evaluation/qa/eval_chain.py +8 -0
  95. langchain/evaluation/qa/generate_chain.py +2 -0
  96. langchain/evaluation/regex_match/base.py +9 -1
  97. langchain/evaluation/scoring/eval_chain.py +1 -0
  98. langchain/evaluation/string_distance/base.py +6 -0
  99. langchain/memory/buffer.py +5 -0
  100. langchain/memory/buffer_window.py +2 -0
  101. langchain/memory/combined.py +1 -1
  102. langchain/memory/entity.py +47 -0
  103. langchain/memory/simple.py +3 -0
  104. langchain/memory/summary.py +30 -0
  105. langchain/memory/summary_buffer.py +3 -0
  106. langchain/memory/token_buffer.py +2 -0
  107. langchain/output_parsers/combining.py +4 -2
  108. langchain/output_parsers/enum.py +5 -1
  109. langchain/output_parsers/fix.py +8 -1
  110. langchain/output_parsers/pandas_dataframe.py +16 -1
  111. langchain/output_parsers/regex.py +2 -0
  112. langchain/output_parsers/retry.py +21 -1
  113. langchain/output_parsers/structured.py +10 -0
  114. langchain/output_parsers/yaml.py +4 -0
  115. langchain/pydantic_v1/__init__.py +1 -1
  116. langchain/retrievers/document_compressors/chain_extract.py +4 -2
  117. langchain/retrievers/document_compressors/cohere_rerank.py +2 -0
  118. langchain/retrievers/document_compressors/cross_encoder_rerank.py +2 -0
  119. langchain/retrievers/document_compressors/embeddings_filter.py +3 -0
  120. langchain/retrievers/document_compressors/listwise_rerank.py +1 -0
  121. langchain/retrievers/ensemble.py +2 -2
  122. langchain/retrievers/multi_query.py +3 -1
  123. langchain/retrievers/multi_vector.py +4 -1
  124. langchain/retrievers/parent_document_retriever.py +15 -0
  125. langchain/retrievers/self_query/base.py +19 -0
  126. langchain/retrievers/time_weighted_retriever.py +3 -0
  127. langchain/runnables/hub.py +12 -0
  128. langchain/runnables/openai_functions.py +6 -0
  129. langchain/smith/__init__.py +1 -0
  130. langchain/smith/evaluation/config.py +5 -22
  131. langchain/smith/evaluation/progress.py +12 -3
  132. langchain/smith/evaluation/runner_utils.py +240 -123
  133. langchain/smith/evaluation/string_run_evaluator.py +27 -0
  134. langchain/storage/encoder_backed.py +1 -0
  135. langchain/tools/python/__init__.py +1 -1
  136. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/METADATA +2 -12
  137. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/RECORD +140 -141
  138. langchain/smith/evaluation/utils.py +0 -0
  139. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/WHEEL +0 -0
  140. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/entry_points.txt +0 -0
  141. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/licenses/LICENSE +0 -0
langchain/chains/natbot/crawler.py
@@ -61,6 +61,7 @@ class Crawler:
     """
 
     def __init__(self) -> None:
+        """Initialize the crawler."""
         try:
             from playwright.sync_api import sync_playwright
         except ImportError as e:
@@ -78,11 +79,22 @@ class Crawler:
         self.client: CDPSession
 
     def go_to_page(self, url: str) -> None:
+        """Navigate to the given URL.
+
+        Args:
+            url: The URL to navigate to. If it does not contain a scheme, it will be
+                prefixed with "http://".
+        """
         self.page.goto(url=url if "://" in url else "http://" + url)
         self.client = self.page.context.new_cdp_session(self.page)
         self.page_element_buffer = {}
 
     def scroll(self, direction: str) -> None:
+        """Scroll the page in the given direction.
+
+        Args:
+            direction: The direction to scroll in, either "up" or "down".
+        """
         if direction == "up":
             self.page.evaluate(
                 "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop - window.innerHeight;"  # noqa: E501
@@ -93,6 +105,11 @@ class Crawler:
             )
 
     def click(self, id_: Union[str, int]) -> None:
+        """Click on an element with the given id.
+
+        Args:
+            id_: The id of the element to click on.
+        """
         # Inject javascript into the page which removes the target= attribute from links
         js = """
         links = document.getElementsByTagName("a");
@@ -112,13 +129,25 @@ class Crawler:
             print("Could not find element")  # noqa: T201
 
     def type(self, id_: Union[str, int], text: str) -> None:
+        """Type text into an element with the given id.
+
+        Args:
+            id_: The id of the element to type into.
+            text: The text to type into the element.
+        """
         self.click(id_)
         self.page.keyboard.type(text)
 
     def enter(self) -> None:
+        """Press the Enter key."""
         self.page.keyboard.press("Enter")
 
     def crawl(self) -> list[str]:
+        """Crawl the current page.
+
+        Returns:
+            A list of the elements in the viewport.
+        """
         page = self.page
         page_element_buffer = self.page_element_buffer
         start = time.time()
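The Crawler hunks above add docstrings only, but they document real Playwright behavior: bare hostnames get an "http://" prefix before navigation, and a CDP session is opened for the page. A minimal standalone sketch of that pattern, assuming `playwright` and a Chromium install are available (the URL is illustrative):

from playwright.sync_api import sync_playwright

# Mirrors Crawler.go_to_page: bare hostnames get an "http://" scheme prefix
# before navigation, and a CDP session is opened for low-level browser access.
with sync_playwright() as p:
    browser = p.chromium.launch(headless=True)
    page = browser.new_page()
    url = "example.com"
    page.goto(url if "://" in url else "http://" + url)
    client = page.context.new_cdp_session(page)
    print(page.title())  # -> "Example Domain"
    browser.close()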
langchain/chains/openai_functions/base.py
@@ -121,6 +121,7 @@ def create_openai_fn_chain(
             chain = create_openai_fn_chain([RecordPerson, RecordDog], llm, prompt)
             chain.run("Harry was a chubby brown beagle who loved chicken")
             # -> RecordDog(name="Harry", color="brown", fav_food="chicken")
+
     """  # noqa: E501
     if not functions:
         msg = "Need to pass in at least one function. Received zero."
@@ -203,6 +204,7 @@ def create_structured_output_chain(
             chain = create_structured_output_chain(Dog, llm, prompt)
             chain.run("Harry was a chubby brown beagle who loved chicken")
             # -> Dog(name="Harry", color="brown", fav_food="chicken")
+
     """  # noqa: E501
     if isinstance(output_schema, dict):
         function: Any = {
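These lone `+` blank lines before closing docstring quotes, repeated in many hunks below, appear to satisfy a docstring lint rule; pydocstyle's "missing blank line after last section" check (ruff D413) is the likely candidate, though the diff does not name it. A before/after sketch with a hypothetical function:

def create_example_chain() -> None:
    """Build an example chain.

    Returns:
        Nothing; this exists only to illustrate the docstring formatting.

    """
    # The blank line above, before the closing quotes, is what these
    # hunks add after the last docstring section.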
langchain/chains/openai_functions/citation_fuzzy_match.py
@@ -45,6 +45,14 @@ class FactWithEvidence(BaseModel):
             yield from s.spans()
 
     def get_spans(self, context: str) -> Iterator[str]:
+        """Get spans of the substring quote in the context.
+
+        Args:
+            context: The context in which to find the spans of the substring quote.
+
+        Returns:
+            An iterator over the spans of the substring quote in the context.
+        """
         for quote in self.substring_quote:
             yield from self._get_span(quote, context)
 
@@ -86,6 +94,7 @@ def create_citation_fuzzy_match_runnable(llm: BaseChatModel) -> Runnable:
 
     Returns:
         Runnable that can be used to answer questions with citations.
+
     """
     if llm.bind_tools is BaseChatModel.bind_tools:
         msg = "Language model must implement bind_tools to use this function."

langchain/chains/openai_functions/openapi.py
@@ -13,6 +13,7 @@ from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
 from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
 from langchain_core.utils.input import get_colored_text
 from requests import Response
+from typing_extensions import override
 
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
@@ -202,10 +203,12 @@ class SimpleRequestChain(Chain):
     """Key to use for the input of the request."""
 
     @property
+    @override
     def input_keys(self) -> list[str]:
         return [self.input_key]
 
     @property
+    @override
     def output_keys(self) -> list[str]:
         return [self.output_key]
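The `@override` markers added throughout this release come from `typing_extensions` (PEP 698): they ask the type checker to verify that the decorated method actually overrides something in a base class, so a rename in the base no longer silently orphans subclass methods. A minimal sketch, assuming `typing_extensions` is installed (the class names are illustrative):

from typing_extensions import override


class BaseChain:
    @property
    def input_keys(self) -> list[str]:
        return []


class SimpleRequestChainSketch(BaseChain):
    # Decorator order matches the diff: @property wraps the overridden getter.
    @property
    @override  # a type checker errors here if BaseChain loses `input_keys`
    def input_keys(self) -> list[str]:
        return ["request"]


print(SimpleRequestChainSketch().input_keys)  # -> ['request']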
langchain/chains/openai_functions/openapi.py
@@ -342,6 +345,7 @@ def get_openapi_chain(
             `ChatOpenAI(model="gpt-3.5-turbo-0613")`.
         prompt: Main prompt template to use.
         request_chain: Chain for taking the functions output and executing the request.
+
     """  # noqa: E501
     try:
         from langchain_community.utilities.openapi import OpenAPISpec

langchain/chains/openai_functions/qa_with_structure.py
@@ -76,11 +76,11 @@ def create_qa_with_structure_chain(
         raise ValueError(msg)
     if isinstance(schema, type) and is_basemodel_subclass(schema):
         if hasattr(schema, "model_json_schema"):
-            schema_dict = cast(dict, schema.model_json_schema())
+            schema_dict = cast("dict", schema.model_json_schema())
         else:
-            schema_dict = cast(dict, schema.schema())
+            schema_dict = cast("dict", schema.schema())
     else:
-        schema_dict = cast(dict, schema)
+        schema_dict = cast("dict", schema)
     function = {
         "name": schema_dict["title"],
         "description": schema_dict["description"],
langchain/chains/openai_functions/tagging.py
@@ -86,6 +86,7 @@ def create_tagging_chain(
 
     Returns:
         Chain (LLMChain) that can be used to extract information from a passage.
+
     """
     function = _get_tagging_function(schema)
     prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
@@ -154,6 +155,7 @@ def create_tagging_chain_pydantic(
 
     Returns:
         Chain (LLMChain) that can be used to extract information from a passage.
+
     """
     if hasattr(pydantic_schema, "model_json_schema"):
         openai_schema = pydantic_schema.model_json_schema()

langchain/chains/qa_generation/base.py
@@ -9,6 +9,7 @@ from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts import BasePromptTemplate
 from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
 from pydantic import Field
+from typing_extensions import override
 
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
@@ -61,6 +62,7 @@ class QAGenerationChain(Chain):
             split_text | RunnableEach(bound=prompt | llm | JsonOutputParser())
         )
     )
+
     """
 
     llm_chain: LLMChain
@@ -103,10 +105,12 @@ class QAGenerationChain(Chain):
         raise NotImplementedError
 
     @property
+    @override
     def input_keys(self) -> list[str]:
         return [self.input_key]
 
     @property
+    @override
     def output_keys(self) -> list[str]:
         return [self.output_key]
 
langchain/chains/qa_with_sources/base.py
@@ -16,6 +16,7 @@ from langchain_core.documents import Document
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts import BasePromptTemplate
 from pydantic import ConfigDict, model_validator
+from typing_extensions import override
 
 from langchain.chains import ReduceDocumentsChain
 from langchain.chains.base import Chain
@@ -240,6 +241,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
         """
         return [self.input_docs_key, self.question_key]
 
+    @override
     def _get_docs(
         self,
         inputs: dict[str, Any],
@@ -249,6 +251,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
         """Get docs to run questioning over."""
         return inputs.pop(self.input_docs_key)
 
+    @override
    async def _aget_docs(
         self,
         inputs: dict[str, Any],

langchain/chains/qa_with_sources/retrieval.py
@@ -33,7 +33,7 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
         StuffDocumentsChain,
     ):
         tokens = [
-            self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)
+            self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)  # noqa: SLF001
             for doc in docs
         ]
         token_count = sum(tokens[:num_docs])
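The `# noqa: SLF001` markers acknowledge ruff's private-member-access rule at the call sites that deliberately reach into `llm_chain._get_num_tokens`. A toy illustration with a stand-in class (the names here are hypothetical):

class TokenCountingChain:
    def _get_num_tokens(self, text: str) -> int:
        # Crude stand-in for a real tokenizer.
        return len(text.split())


chain = TokenCountingChain()
# Accessing a _private method from outside its class trips SLF001;
# the inline noqa suppresses it where the access is intentional.
token_count = chain._get_num_tokens("how many tokens is this")  # noqa: SLF001
print(token_count)  # -> 5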
langchain/chains/qa_with_sources/vector_db.py
@@ -10,6 +10,7 @@ from langchain_core.callbacks import (
 from langchain_core.documents import Document
 from langchain_core.vectorstores import VectorStore
 from pydantic import Field, model_validator
+from typing_extensions import override
 
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
@@ -38,7 +39,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
         StuffDocumentsChain,
     ):
         tokens = [
-            self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)
+            self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)  # noqa: SLF001
             for doc in docs
         ]
         token_count = sum(tokens[:num_docs])
@@ -48,6 +49,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
 
         return docs[:num_docs]
 
+    @override
     def _get_docs(
         self,
         inputs: dict[str, Any],
@@ -73,7 +75,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
 
     @model_validator(mode="before")
     @classmethod
-    def raise_deprecation(cls, values: dict) -> Any:
+    def _raise_deprecation(cls, values: dict) -> Any:
         warnings.warn(
             "`VectorDBQAWithSourcesChain` is deprecated - "
             "please use `from langchain.chains import RetrievalQAWithSourcesChain`",
langchain/chains/query_constructor/base.py
@@ -22,6 +22,7 @@ from langchain_core.structured_query import (
     Operator,
     StructuredQuery,
 )
+from typing_extensions import override
 
 from langchain.chains.llm import LLMChain
 from langchain.chains.query_constructor.parser import get_parser
@@ -46,6 +47,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
     ast_parse: Callable
     """Callable that parses dict into internal representation of query language."""
 
+    @override
     def parse(self, text: str) -> StructuredQuery:
         try:
             expected_keys = ["query", "filter"]
@@ -89,7 +91,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
 
     def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
         filter_directive = cast(
-            Optional[FilterDirective],
+            "Optional[FilterDirective]",
             get_parser().parse(raw_filter),
         )
         return fix_filter_directive(
@@ -142,7 +144,7 @@ def fix_filter_directive(
         return None
     args = [
         cast(
-            FilterDirective,
+            "FilterDirective",
             fix_filter_directive(
                 arg,
                 allowed_comparators=allowed_comparators,

langchain/chains/query_constructor/parser.py
@@ -11,7 +11,7 @@ try:
     from lark import Lark, Transformer, v_args
 except ImportError:
 
-    def v_args(*args: Any, **kwargs: Any) -> Any:  # type: ignore[misc]
+    def v_args(*_: Any, **__: Any) -> Any:  # type: ignore[misc]
         """Dummy decorator for when lark is not installed."""
         return lambda _: None
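Renaming the dummy decorator's parameters from `*args, **kwargs` to `*_`, `**__` changes nothing at call time; it just marks the arguments as intentionally unused, the convention that unused-argument lint checks respect. A sketch of the fallback pattern; note the real fallback returns `lambda _: None`, while this variant keeps the decorated function usable for illustration:

from typing import Any


def v_args(*_: Any, **__: Any) -> Any:
    """Fallback decorator for when lark is not installed."""
    # Underscore-named varargs: accepted and deliberately ignored.
    return lambda fn: fn


@v_args(inline=True)
def shout(text: str) -> str:
    return text.upper()


print(shout("ok"))  # -> OK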
langchain/chains/query_constructor/parser.py
@@ -83,15 +83,35 @@ class QueryTransformer(Transformer):
         allowed_attributes: Optional[Sequence[str]] = None,
         **kwargs: Any,
     ):
+        """Initialize the QueryTransformer.
+
+        Args:
+            allowed_comparators: Optional sequence of allowed comparators.
+            allowed_operators: Optional sequence of allowed operators.
+            allowed_attributes: Optional sequence of allowed attributes for comparators.
+            **kwargs: Additional keyword arguments.
+        """
         super().__init__(*args, **kwargs)
         self.allowed_comparators = allowed_comparators
         self.allowed_operators = allowed_operators
         self.allowed_attributes = allowed_attributes
 
     def program(self, *items: Any) -> tuple:
+        """Transform the items into a tuple."""
         return items
 
     def func_call(self, func_name: Any, args: list) -> FilterDirective:
+        """Transform a function name and args into a FilterDirective.
+
+        Args:
+            func_name: The name of the function.
+            args: The arguments passed to the function.
+        Returns:
+            FilterDirective: The filter directive.
+        Raises:
+            ValueError: If the function is a comparator and the first arg is not in the
+                allowed attributes.
+        """
         func = self._match_func_name(str(func_name))
         if isinstance(func, Comparator):
             if self.allowed_attributes and args[0] not in self.allowed_attributes:
@@ -135,26 +155,55 @@ class QueryTransformer(Transformer):
         raise ValueError(msg)
 
     def args(self, *items: Any) -> tuple:
+        """Transforms items into a tuple.
+
+        Args:
+            items: The items to transform.
+        """
         return items
 
     def false(self) -> bool:
+        """Returns false."""
         return False
 
     def true(self) -> bool:
+        """Returns true."""
         return True
 
     def list(self, item: Any) -> list:
+        """Transforms an item into a list.
+
+        Args:
+            item: The item to transform.
+        """
         if item is None:
             return []
         return list(item)
 
     def int(self, item: Any) -> int:
+        """Transforms an item into an int.
+
+        Args:
+            item: The item to transform.
+        """
         return int(item)
 
     def float(self, item: Any) -> float:
+        """Transforms an item into a float.
+
+        Args:
+            item: The item to transform.
+        """
         return float(item)
 
     def date(self, item: Any) -> ISO8601Date:
+        """Transforms an item into a ISO8601Date object.
+
+        Args:
+            item: The item to transform.
+        Raises:
+            ValueError: If the item is not in ISO 8601 date format.
+        """
         item = str(item).strip("\"'")
         try:
             datetime.datetime.strptime(item, "%Y-%m-%d")  # noqa: DTZ007
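For context on the `date` callback documented above: it validates ISO 8601 input with `strptime` and returns a small typed dict. A standalone sketch of the same check (the function name is mine, not the library's):

import datetime


def parse_iso_date(item: str) -> dict:
    """Validate an ISO 8601 date string, as QueryTransformer.date does."""
    item = item.strip("\"'")  # drop surrounding quotes from the parsed token
    try:
        datetime.datetime.strptime(item, "%Y-%m-%d")
    except ValueError as e:
        msg = f"Dates are expected to be in ISO 8601 format, but got: {item}"
        raise ValueError(msg) from e
    return {"date": item, "type": "date"}


print(parse_iso_date("'2024-05-01'"))  # -> {'date': '2024-05-01', 'type': 'date'}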
langchain/chains/query_constructor/parser.py
@@ -167,6 +216,13 @@ class QueryTransformer(Transformer):
         return {"date": item, "type": "date"}
 
     def datetime(self, item: Any) -> ISO8601DateTime:
+        """Transforms an item into a ISO8601DateTime object.
+
+        Args:
+            item: The item to transform.
+        Raises:
+            ValueError: If the item is not in ISO 8601 datetime format.
+        """
         item = str(item).strip("\"'")
         try:
             # Parse full ISO 8601 datetime format
@@ -180,7 +236,13 @@ class QueryTransformer(Transformer):
         return {"datetime": item, "type": "datetime"}
 
     def string(self, item: Any) -> str:
-        # Remove escaped quotes
+        """Transforms an item into a string.
+
+        Removes escaped quotes.
+
+        Args:
+            item: The item to transform.
+        """
         return str(item).strip("\"'")
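All of the newly documented methods above (`program`, `args`, `int`, `float`, `date`, ...) are lark `Transformer` callbacks: lark calls the method whose name matches each grammar rule, bottom-up. A toy self-contained example, assuming `lark` is installed (the grammar is mine, not the query language's):

from lark import Lark, Transformer

grammar = r"""
    start: value
    value: NUMBER
    NUMBER: /[0-9]+/
    %ignore " "
"""


class IntTransformer(Transformer):
    # Rule callbacks receive the list of (already transformed) children.
    def value(self, children: list) -> int:
        return int(children[0])

    def start(self, children: list) -> int:
        return children[0]


tree = Lark(grammar).parse(" 42 ")
print(IntTransformer().transform(tree))  # -> 42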
langchain/chains/retrieval_qa/base.py
@@ -18,6 +18,7 @@ from langchain_core.prompts import PromptTemplate
 from langchain_core.retrievers import BaseRetriever
 from langchain_core.vectorstores import VectorStore
 from pydantic import ConfigDict, Field, model_validator
+from typing_extensions import override
 
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@@ -146,6 +147,7 @@ class BaseRetrievalQA(Chain):
 
             res = indexqa({'query': 'This is my query'})
             answer, docs = res['result'], res['source_documents']
+
         """
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         question = inputs[self.input_key]
@@ -190,6 +192,7 @@ class BaseRetrievalQA(Chain):
 
             res = indexqa({'query': 'This is my query'})
             answer, docs = res['result'], res['source_documents']
+
         """
         _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         question = inputs[self.input_key]
@@ -330,6 +333,7 @@ class VectorDBQA(BaseRetrievalQA):
             raise ValueError(msg)
         return values
 
+    @override
     def _get_docs(
         self,
         question: str,

langchain/chains/router/base.py
@@ -12,11 +12,14 @@ from langchain_core.callbacks import (
     Callbacks,
 )
 from pydantic import ConfigDict
+from typing_extensions import override
 
 from langchain.chains.base import Chain
 
 
 class Route(NamedTuple):
+    """A route to a destination chain."""
+
     destination: Optional[str]
     next_inputs: dict[str, Any]
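The new one-line docstring on `Route` documents a plain `NamedTuple`, so instances unpack positionally, which is how `route`/`aroute` below build their return value from the chain output. A small sketch:

from typing import Any, NamedTuple, Optional


class Route(NamedTuple):
    """A route to a destination chain."""

    destination: Optional[str]
    next_inputs: dict[str, Any]


result = {"destination": "physics", "next_inputs": {"input": "Why is the sky blue?"}}
route = Route(result["destination"], result["next_inputs"])
destination, next_inputs = route  # NamedTuple fields unpack positionally
print(destination)  # -> physics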
langchain/chains/router/base.py
@@ -25,12 +28,12 @@ class RouterChain(Chain, ABC):
     """Chain that outputs the name of a destination chain and the inputs to it."""
 
     @property
+    @override
     def output_keys(self) -> list[str]:
         return ["destination", "next_inputs"]
 
     def route(self, inputs: dict[str, Any], callbacks: Callbacks = None) -> Route:
-        """
-        Route inputs to a destination chain.
+        """Route inputs to a destination chain.
 
         Args:
             inputs: inputs to the chain
@@ -47,6 +50,15 @@
         inputs: dict[str, Any],
         callbacks: Callbacks = None,
     ) -> Route:
+        """Route inputs to a destination chain.
+
+        Args:
+            inputs: inputs to the chain
+            callbacks: callbacks to use for the chain
+
+        Returns:
+            a Route object
+        """
         result = await self.acall(inputs, callbacks=callbacks)
         return Route(result["destination"], result["next_inputs"])
 
langchain/chains/router/embedding_router.py
@@ -11,6 +11,7 @@ from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VectorStore
 from pydantic import ConfigDict
+from typing_extensions import override
 
 from langchain.chains.router.base import RouterChain
 
@@ -34,6 +35,7 @@ class EmbeddingRouterChain(RouterChain):
         """
         return self.routing_keys
 
+    @override
     def _call(
         self,
         inputs: dict[str, Any],
@@ -43,6 +45,7 @@ class EmbeddingRouterChain(RouterChain):
         results = self.vectorstore.similarity_search(_input, k=1)
         return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
 
+    @override
     async def _acall(
         self,
         inputs: dict[str, Any],

langchain/chains/router/llm_router.py
@@ -15,7 +15,7 @@ from langchain_core.output_parsers import BaseOutputParser
 from langchain_core.prompts import BasePromptTemplate
 from langchain_core.utils.json import parse_and_check_json_markdown
 from pydantic import model_validator
-from typing_extensions import Self
+from typing_extensions import Self, override
 
 from langchain.chains import LLMChain
 from langchain.chains.router.base import RouterChain
@@ -96,13 +96,14 @@ class LLMRouterChain(RouterChain):
         )
 
         chain.invoke({"query": "what color are carrots"})
+
     """  # noqa: E501
 
     llm_chain: LLMChain
     """LLM chain used to perform routing"""
 
     @model_validator(mode="after")
-    def validate_prompt(self) -> Self:
+    def _validate_prompt(self) -> Self:
         prompt = self.llm_chain.prompt
         if prompt.output_parser is None:
             msg = (
@@ -137,7 +138,7 @@ class LLMRouterChain(RouterChain):
 
         prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
         return cast(
-            dict[str, Any],
+            "dict[str, Any]",
             self.llm_chain.prompt.output_parser.parse(prediction),
         )
 
@@ -149,7 +150,7 @@ class LLMRouterChain(RouterChain):
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         callbacks = _run_manager.get_child()
         return cast(
-            dict[str, Any],
+            "dict[str, Any]",
             await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
         )
 
@@ -172,6 +173,7 @@ class RouterOutputParser(BaseOutputParser[dict[str, str]]):
     next_inputs_type: type = str
     next_inputs_inner_key: str = "input"
 
+    @override
     def parse(self, text: str) -> dict[str, Any]:
         try:
             expected_keys = ["destination", "next_inputs"]

langchain/chains/router/multi_prompt.py
@@ -7,6 +7,7 @@ from typing import Any, Optional
 from langchain_core._api import deprecated
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts import PromptTemplate
+from typing_extensions import override
 
 from langchain.chains import ConversationChain
 from langchain.chains.base import Chain
@@ -139,9 +140,11 @@ class MultiPromptChain(MultiRouteChain):
         result = await app.ainvoke({"query": "what color are carrots"})
         print(result["destination"])
         print(result["answer"])
+
     """  # noqa: E501
 
     @property
+    @override
     def output_keys(self) -> list[str]:
         return ["text"]
 
langchain/chains/router/multi_retrieval_qa.py
@@ -8,6 +8,7 @@ from typing import Any, Optional
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts import PromptTemplate
 from langchain_core.retrievers import BaseRetriever
+from typing_extensions import override
 
 from langchain.chains import ConversationChain
 from langchain.chains.base import Chain
@@ -32,6 +33,7 @@ class MultiRetrievalQAChain(MultiRouteChain):
     """Default chain to use when router doesn't map input to one of the destinations."""
 
     @property
+    @override
     def output_keys(self) -> list[str]:
         return ["result"]
 
@@ -47,6 +49,22 @@ class MultiRetrievalQAChain(MultiRouteChain):
         default_chain_llm: Optional[BaseLanguageModel] = None,
         **kwargs: Any,
     ) -> MultiRetrievalQAChain:
+        """Create a multi retrieval qa chain from an LLM and a default chain.
+
+        Args:
+            llm: The language model to use.
+            retriever_infos: Dictionaries containing retriever information.
+            default_retriever: Optional default retriever to use if no default chain
+                is provided.
+            default_prompt: Optional prompt template to use for the default retriever.
+            default_chain: Optional default chain to use when router doesn't map input
+                to one of the destinations.
+            default_chain_llm: Optional language model to use if no default chain and
+                no default retriever are provided.
+            **kwargs: Additional keyword arguments to pass to the chain.
+        Returns:
+            An instance of the multi retrieval qa chain.
+        """
         if default_prompt and not default_retriever:
             msg = (
                 "`default_retriever` must be specified if `default_prompt` is "

langchain/chains/sql_database/query.py
@@ -113,6 +113,7 @@ def create_sql_query_chain(
 
     Question: {input}'''
     prompt = PromptTemplate.from_template(template)
+
     """  # noqa: E501
     if prompt is not None:
         prompt_to_use = prompt
langchain/chains/structured_output/base.py
@@ -132,6 +132,7 @@ def create_openai_fn_runnable(
             structured_llm = create_openai_fn_runnable([RecordPerson, RecordDog], llm)
             structured_llm.invoke("Harry was a chubby brown beagle who loved chicken")
             # -> RecordDog(name="Harry", color="brown", fav_food="chicken")
+
     """  # noqa: E501
     if not functions:
         msg = "Need to pass in at least one function. Received zero."
@@ -390,6 +391,7 @@ def create_structured_output_runnable(
             )
             chain = prompt | structured_llm
             chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
+
     """  # noqa: E501
     # for backwards compatibility
     force_function_usage = kwargs.get(