langchain 0.3.27__py3-none-any.whl → 0.4.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. langchain/agents/agent.py +16 -20
  2. langchain/agents/agent_iterator.py +19 -12
  3. langchain/agents/agent_toolkits/vectorstore/base.py +2 -0
  4. langchain/agents/chat/base.py +2 -0
  5. langchain/agents/conversational/base.py +2 -0
  6. langchain/agents/conversational_chat/base.py +2 -0
  7. langchain/agents/initialize.py +1 -1
  8. langchain/agents/json_chat/base.py +1 -0
  9. langchain/agents/mrkl/base.py +2 -0
  10. langchain/agents/openai_assistant/base.py +1 -1
  11. langchain/agents/openai_functions_agent/agent_token_buffer_memory.py +2 -0
  12. langchain/agents/openai_functions_agent/base.py +3 -2
  13. langchain/agents/openai_functions_multi_agent/base.py +1 -1
  14. langchain/agents/openai_tools/base.py +1 -0
  15. langchain/agents/output_parsers/json.py +2 -0
  16. langchain/agents/output_parsers/openai_functions.py +10 -3
  17. langchain/agents/output_parsers/openai_tools.py +8 -1
  18. langchain/agents/output_parsers/react_json_single_input.py +3 -0
  19. langchain/agents/output_parsers/react_single_input.py +3 -0
  20. langchain/agents/output_parsers/self_ask.py +2 -0
  21. langchain/agents/output_parsers/tools.py +16 -2
  22. langchain/agents/output_parsers/xml.py +3 -0
  23. langchain/agents/react/agent.py +1 -0
  24. langchain/agents/react/base.py +4 -0
  25. langchain/agents/react/output_parser.py +2 -0
  26. langchain/agents/schema.py +2 -0
  27. langchain/agents/self_ask_with_search/base.py +4 -0
  28. langchain/agents/structured_chat/base.py +5 -0
  29. langchain/agents/structured_chat/output_parser.py +13 -0
  30. langchain/agents/tool_calling_agent/base.py +1 -0
  31. langchain/agents/tools.py +3 -0
  32. langchain/agents/xml/base.py +7 -1
  33. langchain/callbacks/streaming_aiter.py +13 -2
  34. langchain/callbacks/streaming_aiter_final_only.py +11 -2
  35. langchain/callbacks/streaming_stdout_final_only.py +5 -0
  36. langchain/callbacks/tracers/logging.py +11 -0
  37. langchain/chains/api/base.py +5 -1
  38. langchain/chains/base.py +8 -2
  39. langchain/chains/combine_documents/base.py +7 -1
  40. langchain/chains/combine_documents/map_reduce.py +3 -0
  41. langchain/chains/combine_documents/map_rerank.py +6 -4
  42. langchain/chains/combine_documents/reduce.py +1 -0
  43. langchain/chains/combine_documents/refine.py +1 -0
  44. langchain/chains/combine_documents/stuff.py +5 -1
  45. langchain/chains/constitutional_ai/base.py +7 -0
  46. langchain/chains/conversation/base.py +4 -1
  47. langchain/chains/conversational_retrieval/base.py +67 -59
  48. langchain/chains/elasticsearch_database/base.py +2 -1
  49. langchain/chains/flare/base.py +2 -0
  50. langchain/chains/flare/prompts.py +2 -0
  51. langchain/chains/llm.py +7 -2
  52. langchain/chains/llm_bash/__init__.py +1 -1
  53. langchain/chains/llm_checker/base.py +12 -1
  54. langchain/chains/llm_math/base.py +9 -1
  55. langchain/chains/llm_summarization_checker/base.py +13 -1
  56. langchain/chains/llm_symbolic_math/__init__.py +1 -1
  57. langchain/chains/loading.py +4 -2
  58. langchain/chains/moderation.py +3 -0
  59. langchain/chains/natbot/base.py +3 -1
  60. langchain/chains/natbot/crawler.py +29 -0
  61. langchain/chains/openai_functions/base.py +2 -0
  62. langchain/chains/openai_functions/citation_fuzzy_match.py +9 -0
  63. langchain/chains/openai_functions/openapi.py +4 -0
  64. langchain/chains/openai_functions/qa_with_structure.py +3 -3
  65. langchain/chains/openai_functions/tagging.py +2 -0
  66. langchain/chains/qa_generation/base.py +4 -0
  67. langchain/chains/qa_with_sources/base.py +3 -0
  68. langchain/chains/qa_with_sources/retrieval.py +1 -1
  69. langchain/chains/qa_with_sources/vector_db.py +4 -2
  70. langchain/chains/query_constructor/base.py +4 -2
  71. langchain/chains/query_constructor/parser.py +64 -2
  72. langchain/chains/retrieval_qa/base.py +4 -0
  73. langchain/chains/router/base.py +14 -2
  74. langchain/chains/router/embedding_router.py +3 -0
  75. langchain/chains/router/llm_router.py +6 -4
  76. langchain/chains/router/multi_prompt.py +3 -0
  77. langchain/chains/router/multi_retrieval_qa.py +18 -0
  78. langchain/chains/sql_database/query.py +1 -0
  79. langchain/chains/structured_output/base.py +2 -0
  80. langchain/chains/transform.py +4 -0
  81. langchain/chat_models/base.py +55 -18
  82. langchain/document_loaders/blob_loaders/schema.py +1 -4
  83. langchain/embeddings/base.py +2 -0
  84. langchain/embeddings/cache.py +3 -3
  85. langchain/evaluation/agents/trajectory_eval_chain.py +3 -2
  86. langchain/evaluation/comparison/eval_chain.py +1 -0
  87. langchain/evaluation/criteria/eval_chain.py +3 -0
  88. langchain/evaluation/embedding_distance/base.py +11 -0
  89. langchain/evaluation/exact_match/base.py +14 -1
  90. langchain/evaluation/loading.py +1 -0
  91. langchain/evaluation/parsing/base.py +16 -3
  92. langchain/evaluation/parsing/json_distance.py +19 -8
  93. langchain/evaluation/parsing/json_schema.py +1 -4
  94. langchain/evaluation/qa/eval_chain.py +8 -0
  95. langchain/evaluation/qa/generate_chain.py +2 -0
  96. langchain/evaluation/regex_match/base.py +9 -1
  97. langchain/evaluation/scoring/eval_chain.py +1 -0
  98. langchain/evaluation/string_distance/base.py +6 -0
  99. langchain/memory/buffer.py +5 -0
  100. langchain/memory/buffer_window.py +2 -0
  101. langchain/memory/combined.py +1 -1
  102. langchain/memory/entity.py +47 -0
  103. langchain/memory/simple.py +3 -0
  104. langchain/memory/summary.py +30 -0
  105. langchain/memory/summary_buffer.py +3 -0
  106. langchain/memory/token_buffer.py +2 -0
  107. langchain/output_parsers/combining.py +4 -2
  108. langchain/output_parsers/enum.py +5 -1
  109. langchain/output_parsers/fix.py +8 -1
  110. langchain/output_parsers/pandas_dataframe.py +16 -1
  111. langchain/output_parsers/regex.py +2 -0
  112. langchain/output_parsers/retry.py +21 -1
  113. langchain/output_parsers/structured.py +10 -0
  114. langchain/output_parsers/yaml.py +4 -0
  115. langchain/pydantic_v1/__init__.py +1 -1
  116. langchain/retrievers/document_compressors/chain_extract.py +4 -2
  117. langchain/retrievers/document_compressors/cohere_rerank.py +2 -0
  118. langchain/retrievers/document_compressors/cross_encoder_rerank.py +2 -0
  119. langchain/retrievers/document_compressors/embeddings_filter.py +3 -0
  120. langchain/retrievers/document_compressors/listwise_rerank.py +1 -0
  121. langchain/retrievers/ensemble.py +2 -2
  122. langchain/retrievers/multi_query.py +3 -1
  123. langchain/retrievers/multi_vector.py +4 -1
  124. langchain/retrievers/parent_document_retriever.py +15 -0
  125. langchain/retrievers/self_query/base.py +19 -0
  126. langchain/retrievers/time_weighted_retriever.py +3 -0
  127. langchain/runnables/hub.py +12 -0
  128. langchain/runnables/openai_functions.py +6 -0
  129. langchain/smith/__init__.py +1 -0
  130. langchain/smith/evaluation/config.py +5 -22
  131. langchain/smith/evaluation/progress.py +12 -3
  132. langchain/smith/evaluation/runner_utils.py +240 -123
  133. langchain/smith/evaluation/string_run_evaluator.py +27 -0
  134. langchain/storage/encoder_backed.py +1 -0
  135. langchain/tools/python/__init__.py +1 -1
  136. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/METADATA +2 -12
  137. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/RECORD +140 -141
  138. langchain/smith/evaluation/utils.py +0 -0
  139. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/WHEEL +0 -0
  140. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/entry_points.txt +0 -0
  141. {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/licenses/LICENSE +0 -0
@@ -187,6 +187,7 @@ class ScoreStringEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
187
187
  )
188
188
 
189
189
  @classmethod
190
+ @override
190
191
  def is_lc_serializable(cls) -> bool:
191
192
  return False
192
193
 
@@ -224,6 +224,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
224
224
  """
225
225
  return f"{self.distance.value}_distance"
226
226
 
227
+ @override
227
228
  def _call(
228
229
  self,
229
230
  inputs: dict[str, Any],
@@ -242,6 +243,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
242
243
  """
243
244
  return {"score": self.compute_metric(inputs["reference"], inputs["prediction"])}
244
245
 
246
+ @override
245
247
  async def _acall(
246
248
  self,
247
249
  inputs: dict[str, Any],
@@ -357,6 +359,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
357
359
  """
358
360
  return f"pairwise_{self.distance.value}_distance"
359
361
 
362
+ @override
360
363
  def _call(
361
364
  self,
362
365
  inputs: dict[str, Any],
@@ -377,6 +380,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
377
380
  "score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
378
381
  }
379
382
 
383
+ @override
380
384
  async def _acall(
381
385
  self,
382
386
  inputs: dict[str, Any],
@@ -397,6 +401,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
397
401
  "score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
398
402
  }
399
403
 
404
+ @override
400
405
  def _evaluate_string_pairs(
401
406
  self,
402
407
  *,
@@ -431,6 +436,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
431
436
  )
432
437
  return self._prepare_output(result)
433
438
 
439
+ @override
434
440
  async def _aevaluate_string_pairs(
435
441
  self,
436
442
  *,
@@ -4,6 +4,7 @@ from langchain_core._api import deprecated
4
4
  from langchain_core.memory import BaseMemory
5
5
  from langchain_core.messages import BaseMessage, get_buffer_string
6
6
  from langchain_core.utils import pre_init
7
+ from typing_extensions import override
7
8
 
8
9
  from langchain.memory.chat_memory import BaseChatMemory
9
10
  from langchain.memory.utils import get_prompt_input_key
@@ -78,10 +79,12 @@ class ConversationBufferMemory(BaseChatMemory):
78
79
  """
79
80
  return [self.memory_key]
80
81
 
82
+ @override
81
83
  def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
82
84
  """Return history buffer."""
83
85
  return {self.memory_key: self.buffer}
84
86
 
87
+ @override
85
88
  async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
86
89
  """Return key-value pairs given the text input to the chain."""
87
90
  buffer = await self.abuffer()
@@ -132,6 +135,7 @@ class ConversationStringBufferMemory(BaseMemory):
132
135
  """
133
136
  return [self.memory_key]
134
137
 
138
+ @override
135
139
  def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]:
136
140
  """Return history buffer."""
137
141
  return {self.memory_key: self.buffer}
@@ -169,5 +173,6 @@ class ConversationStringBufferMemory(BaseMemory):
169
173
  """Clear memory contents."""
170
174
  self.buffer = ""
171
175
 
176
+ @override
172
177
  async def aclear(self) -> None:
173
178
  self.clear()
@@ -2,6 +2,7 @@ from typing import Any, Union
2
2
 
3
3
  from langchain_core._api import deprecated
4
4
  from langchain_core.messages import BaseMessage, get_buffer_string
5
+ from typing_extensions import override
5
6
 
6
7
  from langchain.memory.chat_memory import BaseChatMemory
7
8
 
@@ -55,6 +56,7 @@ class ConversationBufferWindowMemory(BaseChatMemory):
55
56
  """
56
57
  return [self.memory_key]
57
58
 
59
+ @override
58
60
  def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
59
61
  """Return history buffer."""
60
62
  return {self.memory_key: self.buffer}
@@ -15,7 +15,7 @@ class CombinedMemory(BaseMemory):
15
15
 
16
16
  @field_validator("memories")
17
17
  @classmethod
18
- def check_repeated_memory_variable(
18
+ def _check_repeated_memory_variable(
19
19
  cls,
20
20
  value: list[BaseMemory],
21
21
  ) -> list[BaseMemory]:
@@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
11
11
  from langchain_core.messages import BaseMessage, get_buffer_string
12
12
  from langchain_core.prompts import BasePromptTemplate
13
13
  from pydantic import BaseModel, ConfigDict, Field
14
+ from typing_extensions import override
14
15
 
15
16
  from langchain.chains.llm import LLMChain
16
17
  from langchain.memory.chat_memory import BaseChatMemory
@@ -71,18 +72,23 @@ class InMemoryEntityStore(BaseEntityStore):
71
72
 
72
73
  store: dict[str, Optional[str]] = {}
73
74
 
75
+ @override
74
76
  def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
75
77
  return self.store.get(key, default)
76
78
 
79
+ @override
77
80
  def set(self, key: str, value: Optional[str]) -> None:
78
81
  self.store[key] = value
79
82
 
83
+ @override
80
84
  def delete(self, key: str) -> None:
81
85
  del self.store[key]
82
86
 
87
+ @override
83
88
  def exists(self, key: str) -> bool:
84
89
  return key in self.store
85
90
 
91
+ @override
86
92
  def clear(self) -> None:
87
93
  return self.store.clear()
88
94
 
@@ -113,6 +119,16 @@ class UpstashRedisEntityStore(BaseEntityStore):
113
119
  *args: Any,
114
120
  **kwargs: Any,
115
121
  ):
122
+ """Initializes the RedisEntityStore.
123
+
124
+ Args:
125
+ session_id: Unique identifier for the session.
126
+ url: URL of the Redis server.
127
+ token: Authentication token for the Redis server.
128
+ key_prefix: Prefix for keys in the Redis store.
129
+ ttl: Time-to-live for keys in seconds (default 1 day).
130
+ recall_ttl: Time-to-live extension for keys when recalled (default 3 days).
131
+ """
116
132
  try:
117
133
  from upstash_redis import Redis
118
134
  except ImportError as e:
@@ -138,8 +154,10 @@ class UpstashRedisEntityStore(BaseEntityStore):
138
154
 
139
155
  @property
140
156
  def full_key_prefix(self) -> str:
157
+ """Returns the full key prefix with session ID."""
141
158
  return f"{self.key_prefix}:{self.session_id}"
142
159
 
160
+ @override
143
161
  def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
144
162
  res = (
145
163
  self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
@@ -151,6 +169,7 @@ class UpstashRedisEntityStore(BaseEntityStore):
151
169
  )
152
170
  return res
153
171
 
172
+ @override
154
173
  def set(self, key: str, value: Optional[str]) -> None:
155
174
  if not value:
156
175
  return self.delete(key)
@@ -164,12 +183,15 @@ class UpstashRedisEntityStore(BaseEntityStore):
164
183
  )
165
184
  return None
166
185
 
186
+ @override
167
187
  def delete(self, key: str) -> None:
168
188
  self.redis_client.delete(f"{self.full_key_prefix}:{key}")
169
189
 
190
+ @override
170
191
  def exists(self, key: str) -> bool:
171
192
  return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
172
193
 
194
+ @override
173
195
  def clear(self) -> None:
174
196
  def scan_and_delete(cursor: int) -> int:
175
197
  cursor, keys_to_delete = self.redis_client.scan(
@@ -215,6 +237,16 @@ class RedisEntityStore(BaseEntityStore):
215
237
  *args: Any,
216
238
  **kwargs: Any,
217
239
  ):
240
+ """Initializes the RedisEntityStore.
241
+
242
+ Args:
243
+ session_id: Unique identifier for the session.
244
+ url: URL of the Redis server.
245
+ key_prefix: Prefix for keys in the Redis store.
246
+ ttl: Time-to-live for keys in seconds (default 1 day).
247
+ recall_ttl: Time-to-live extension for keys when recalled (default 3 days).
248
+ """
249
+
218
250
  try:
219
251
  import redis
220
252
  except ImportError as e:
@@ -247,8 +279,10 @@ class RedisEntityStore(BaseEntityStore):
247
279
 
248
280
  @property
249
281
  def full_key_prefix(self) -> str:
282
+ """Returns the full key prefix with session ID."""
250
283
  return f"{self.key_prefix}:{self.session_id}"
251
284
 
285
+ @override
252
286
  def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
253
287
  res = (
254
288
  self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
@@ -258,6 +292,7 @@ class RedisEntityStore(BaseEntityStore):
258
292
  logger.debug("REDIS MEM get '%s:%s': '%s'", self.full_key_prefix, key, res)
259
293
  return res
260
294
 
295
+ @override
261
296
  def set(self, key: str, value: Optional[str]) -> None:
262
297
  if not value:
263
298
  return self.delete(key)
@@ -271,12 +306,15 @@ class RedisEntityStore(BaseEntityStore):
271
306
  )
272
307
  return None
273
308
 
309
+ @override
274
310
  def delete(self, key: str) -> None:
275
311
  self.redis_client.delete(f"{self.full_key_prefix}:{key}")
276
312
 
313
+ @override
277
314
  def exists(self, key: str) -> bool:
278
315
  return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
279
316
 
317
+ @override
280
318
  def clear(self) -> None:
281
319
  # iterate a list in batches of size batch_size
282
320
  def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
@@ -318,6 +356,13 @@ class SQLiteEntityStore(BaseEntityStore):
318
356
  *args: Any,
319
357
  **kwargs: Any,
320
358
  ):
359
+ """Initializes the SQLiteEntityStore.
360
+
361
+ Args:
362
+ session_id: Unique identifier for the session.
363
+ db_file: Path to the SQLite database file.
364
+ table_name: Name of the table to store entities.
365
+ """
321
366
  super().__init__(*args, **kwargs)
322
367
  try:
323
368
  import sqlite3
@@ -341,6 +386,7 @@ class SQLiteEntityStore(BaseEntityStore):
341
386
 
342
387
  @property
343
388
  def full_table_name(self) -> str:
389
+ """Returns the full table name with session ID."""
344
390
  return f"{self.table_name}_{self.session_id}"
345
391
 
346
392
  def _execute_query(self, query: str, params: tuple = ()) -> "sqlite3.Cursor":
@@ -393,6 +439,7 @@ class SQLiteEntityStore(BaseEntityStore):
393
439
  cursor = self._execute_query(query, (key,))
394
440
  return cursor.fetchone() is not None
395
441
 
442
+ @override
396
443
  def clear(self) -> None:
397
444
  # Ignore S608 since we validate for malicious table/session names in `__init__`
398
445
  query = f"""
@@ -1,6 +1,7 @@
1
1
  from typing import Any
2
2
 
3
3
  from langchain_core.memory import BaseMemory
4
+ from typing_extensions import override
4
5
 
5
6
 
6
7
  class SimpleMemory(BaseMemory):
@@ -11,9 +12,11 @@ class SimpleMemory(BaseMemory):
11
12
  memories: dict[str, Any] = {}
12
13
 
13
14
  @property
15
+ @override
14
16
  def memory_variables(self) -> list[str]:
15
17
  return list(self.memories.keys())
16
18
 
19
+ @override
17
20
  def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]:
18
21
  return self.memories
19
22
 
@@ -9,6 +9,7 @@ from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_strin
9
9
  from langchain_core.prompts import BasePromptTemplate
10
10
  from langchain_core.utils import pre_init
11
11
  from pydantic import BaseModel
12
+ from typing_extensions import override
12
13
 
13
14
  from langchain.chains.llm import LLMChain
14
15
  from langchain.memory.chat_memory import BaseChatMemory
@@ -37,6 +38,15 @@ class SummarizerMixin(BaseModel):
37
38
  messages: list[BaseMessage],
38
39
  existing_summary: str,
39
40
  ) -> str:
41
+ """Predict a new summary based on the messages and existing summary.
42
+
43
+ Args:
44
+ messages: List of messages to summarize.
45
+ existing_summary: Existing summary to build upon.
46
+
47
+ Returns:
48
+ A new summary string.
49
+ """
40
50
  new_lines = get_buffer_string(
41
51
  messages,
42
52
  human_prefix=self.human_prefix,
@@ -51,6 +61,15 @@ class SummarizerMixin(BaseModel):
51
61
  messages: list[BaseMessage],
52
62
  existing_summary: str,
53
63
  ) -> str:
64
+ """Predict a new summary based on the messages and existing summary.
65
+
66
+ Args:
67
+ messages: List of messages to summarize.
68
+ existing_summary: Existing summary to build upon.
69
+
70
+ Returns:
71
+ A new summary string.
72
+ """
54
73
  new_lines = get_buffer_string(
55
74
  messages,
56
75
  human_prefix=self.human_prefix,
@@ -89,6 +108,16 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
89
108
  summarize_step: int = 2,
90
109
  **kwargs: Any,
91
110
  ) -> ConversationSummaryMemory:
111
+ """Create a ConversationSummaryMemory from a list of messages.
112
+
113
+ Args:
114
+ llm: The language model to use for summarization.
115
+ chat_memory: The chat history to summarize.
116
+ summarize_step: Number of messages to summarize at a time.
117
+ **kwargs: Additional keyword arguments to pass to the class.
118
+ Returns:
119
+ An instance of ConversationSummaryMemory with the summarized history.
120
+ """
92
121
  obj = cls(llm=llm, chat_memory=chat_memory, **kwargs)
93
122
  for i in range(0, len(obj.chat_memory.messages), summarize_step):
94
123
  obj.buffer = obj.predict_new_summary(
@@ -105,6 +134,7 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
105
134
  """
106
135
  return [self.memory_key]
107
136
 
137
+ @override
108
138
  def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
109
139
  """Return history buffer."""
110
140
  if self.return_messages:
@@ -3,6 +3,7 @@ from typing import Any, Union
3
3
  from langchain_core._api import deprecated
4
4
  from langchain_core.messages import BaseMessage, get_buffer_string
5
5
  from langchain_core.utils import pre_init
6
+ from typing_extensions import override
6
7
 
7
8
  from langchain.memory.chat_memory import BaseChatMemory
8
9
  from langchain.memory.summary import SummarizerMixin
@@ -46,6 +47,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
46
47
  """
47
48
  return [self.memory_key]
48
49
 
50
+ @override
49
51
  def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
50
52
  """Return history buffer."""
51
53
  buffer = self.chat_memory.messages
@@ -64,6 +66,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
64
66
  )
65
67
  return {self.memory_key: final_buffer}
66
68
 
69
+ @override
67
70
  async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
68
71
  """Asynchronously return key-value pairs given the text input to the chain."""
69
72
  buffer = await self.chat_memory.aget_messages()
@@ -3,6 +3,7 @@ from typing import Any
3
3
  from langchain_core._api import deprecated
4
4
  from langchain_core.language_models import BaseLanguageModel
5
5
  from langchain_core.messages import BaseMessage, get_buffer_string
6
+ from typing_extensions import override
6
7
 
7
8
  from langchain.memory.chat_memory import BaseChatMemory
8
9
 
@@ -55,6 +56,7 @@ class ConversationTokenBufferMemory(BaseChatMemory):
55
56
  """
56
57
  return [self.memory_key]
57
58
 
59
+ @override
58
60
  def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
59
61
  """Return history buffer."""
60
62
  return {self.memory_key: self.buffer}
@@ -4,6 +4,7 @@ from typing import Any
4
4
 
5
5
  from langchain_core.output_parsers import BaseOutputParser
6
6
  from langchain_core.utils import pre_init
7
+ from typing_extensions import override
7
8
 
8
9
  _MIN_PARSERS = 2
9
10
 
@@ -14,6 +15,7 @@ class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
14
15
  parsers: list[BaseOutputParser]
15
16
 
16
17
  @classmethod
18
+ @override
17
19
  def is_lc_serializable(cls) -> bool:
18
20
  return True
19
21
 
@@ -25,10 +27,10 @@ class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
25
27
  msg = "Must have at least two parsers"
26
28
  raise ValueError(msg)
27
29
  for parser in parsers:
28
- if parser._type == "combining":
30
+ if parser._type == "combining": # noqa: SLF001
29
31
  msg = "Cannot nest combining parsers"
30
32
  raise ValueError(msg)
31
- if parser._type == "list":
33
+ if parser._type == "list": # noqa: SLF001
32
34
  msg = "Cannot combine list parsers"
33
35
  raise ValueError(msg)
34
36
  return values
@@ -3,6 +3,7 @@ from enum import Enum
3
3
  from langchain_core.exceptions import OutputParserException
4
4
  from langchain_core.output_parsers import BaseOutputParser
5
5
  from langchain_core.utils import pre_init
6
+ from typing_extensions import override
6
7
 
7
8
 
8
9
  class EnumOutputParser(BaseOutputParser[Enum]):
@@ -12,7 +13,7 @@ class EnumOutputParser(BaseOutputParser[Enum]):
12
13
  """The enum to parse. Its values must be strings."""
13
14
 
14
15
  @pre_init
15
- def raise_deprecation(cls, values: dict) -> dict:
16
+ def _raise_deprecation(cls, values: dict) -> dict:
16
17
  enum = values["enum"]
17
18
  if not all(isinstance(e.value, str) for e in enum):
18
19
  msg = "Enum values must be strings"
@@ -23,6 +24,7 @@ class EnumOutputParser(BaseOutputParser[Enum]):
23
24
  def _valid_values(self) -> list[str]:
24
25
  return [e.value for e in self.enum]
25
26
 
27
+ @override
26
28
  def parse(self, response: str) -> Enum:
27
29
  try:
28
30
  return self.enum(response.strip())
@@ -33,9 +35,11 @@ class EnumOutputParser(BaseOutputParser[Enum]):
33
35
  )
34
36
  raise OutputParserException(msg) from e
35
37
 
38
+ @override
36
39
  def get_format_instructions(self) -> str:
37
40
  return f"Select one of the following options: {', '.join(self._valid_values)}"
38
41
 
39
42
  @property
43
+ @override
40
44
  def OutputType(self) -> type[Enum]:
41
45
  return self.enum
@@ -7,7 +7,7 @@ from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
7
7
  from langchain_core.prompts import BasePromptTemplate
8
8
  from langchain_core.runnables import Runnable, RunnableSerializable
9
9
  from pydantic import SkipValidation
10
- from typing_extensions import TypedDict
10
+ from typing_extensions import TypedDict, override
11
11
 
12
12
  from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
13
13
 
@@ -15,6 +15,8 @@ T = TypeVar("T")
15
15
 
16
16
 
17
17
  class OutputFixingParserRetryChainInput(TypedDict, total=False):
18
+ """Input for the retry chain of the OutputFixingParser."""
19
+
18
20
  instructions: str
19
21
  completion: str
20
22
  error: str
@@ -24,6 +26,7 @@ class OutputFixingParser(BaseOutputParser[T]):
24
26
  """Wrap a parser and try to fix parsing errors."""
25
27
 
26
28
  @classmethod
29
+ @override
27
30
  def is_lc_serializable(cls) -> bool:
28
31
  return True
29
32
 
@@ -62,6 +65,7 @@ class OutputFixingParser(BaseOutputParser[T]):
62
65
  chain = prompt | llm | StrOutputParser()
63
66
  return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
64
67
 
68
+ @override
65
69
  def parse(self, completion: str) -> T:
66
70
  retries = 0
67
71
 
@@ -99,6 +103,7 @@ class OutputFixingParser(BaseOutputParser[T]):
99
103
  msg = "Failed to parse"
100
104
  raise OutputParserException(msg)
101
105
 
106
+ @override
102
107
  async def aparse(self, completion: str) -> T:
103
108
  retries = 0
104
109
 
@@ -136,6 +141,7 @@ class OutputFixingParser(BaseOutputParser[T]):
136
141
  msg = "Failed to parse"
137
142
  raise OutputParserException(msg)
138
143
 
144
+ @override
139
145
  def get_format_instructions(self) -> str:
140
146
  return self.parser.get_format_instructions()
141
147
 
@@ -144,5 +150,6 @@ class OutputFixingParser(BaseOutputParser[T]):
144
150
  return "output_fixing"
145
151
 
146
152
  @property
153
+ @override
147
154
  def OutputType(self) -> type[T]:
148
155
  return self.parser.OutputType
@@ -4,6 +4,7 @@ from typing import Any, Union
4
4
  from langchain_core.exceptions import OutputParserException
5
5
  from langchain_core.output_parsers.base import BaseOutputParser
6
6
  from pydantic import field_validator
7
+ from typing_extensions import override
7
8
 
8
9
  from langchain.output_parsers.format_instructions import (
9
10
  PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS,
@@ -18,7 +19,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
18
19
 
19
20
  @field_validator("dataframe")
20
21
  @classmethod
21
- def validate_dataframe(cls, val: Any) -> Any:
22
+ def _validate_dataframe(cls, val: Any) -> Any:
22
23
  import pandas as pd
23
24
 
24
25
  if issubclass(type(val), pd.DataFrame):
@@ -36,6 +37,18 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
36
37
  array: str,
37
38
  original_request_params: str,
38
39
  ) -> tuple[list[Union[int, str]], str]:
40
+ """Parse the array from the request parameters.
41
+
42
+ Args:
43
+ array: The array string to parse.
44
+ original_request_params: The original request parameters string.
45
+
46
+ Returns:
47
+ A tuple containing the parsed array and the stripped request parameters.
48
+
49
+ Raises:
50
+ OutputParserException: If the array format is invalid or cannot be parsed.
51
+ """
39
52
  parsed_array: list[Union[int, str]] = []
40
53
 
41
54
  # Check if the format is [1,3,5]
@@ -76,6 +89,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
76
89
 
77
90
  return parsed_array, original_request_params.split("[")[0]
78
91
 
92
+ @override
79
93
  def parse(self, request: str) -> dict[str, Any]:
80
94
  stripped_request_params = None
81
95
  splitted_request = request.strip().split(":")
@@ -150,6 +164,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
150
164
 
151
165
  return result
152
166
 
167
+ @override
153
168
  def get_format_instructions(self) -> str:
154
169
  return PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS.format(
155
170
  columns=", ".join(self.dataframe.columns),
@@ -4,12 +4,14 @@ import re
4
4
  from typing import Optional
5
5
 
6
6
  from langchain_core.output_parsers import BaseOutputParser
7
+ from typing_extensions import override
7
8
 
8
9
 
9
10
  class RegexParser(BaseOutputParser[dict[str, str]]):
10
11
  """Parse the output of an LLM call using a regex."""
11
12
 
12
13
  @classmethod
14
+ @override
13
15
  def is_lc_serializable(cls) -> bool:
14
16
  return True
15
17
 
@@ -9,7 +9,7 @@ from langchain_core.prompt_values import PromptValue
9
9
  from langchain_core.prompts import BasePromptTemplate, PromptTemplate
10
10
  from langchain_core.runnables import RunnableSerializable
11
11
  from pydantic import SkipValidation
12
- from typing_extensions import TypedDict
12
+ from typing_extensions import TypedDict, override
13
13
 
14
14
  NAIVE_COMPLETION_RETRY = """Prompt:
15
15
  {prompt}
@@ -37,11 +37,15 @@ T = TypeVar("T")
37
37
 
38
38
 
39
39
  class RetryOutputParserRetryChainInput(TypedDict):
40
+ """Retry chain input for RetryOutputParser."""
41
+
40
42
  prompt: str
41
43
  completion: str
42
44
 
43
45
 
44
46
  class RetryWithErrorOutputParserRetryChainInput(TypedDict):
47
+ """Retry chain input for RetryWithErrorOutputParser."""
48
+
45
49
  prompt: str
46
50
  completion: str
47
51
  error: str
@@ -160,10 +164,12 @@ class RetryOutputParser(BaseOutputParser[T]):
160
164
  msg = "Failed to parse"
161
165
  raise OutputParserException(msg)
162
166
 
167
+ @override
163
168
  def parse(self, completion: str) -> T:
164
169
  msg = "This OutputParser can only be called by the `parse_with_prompt` method."
165
170
  raise NotImplementedError(msg)
166
171
 
172
+ @override
167
173
  def get_format_instructions(self) -> str:
168
174
  return self.parser.get_format_instructions()
169
175
 
@@ -172,6 +178,7 @@ class RetryOutputParser(BaseOutputParser[T]):
172
178
  return "retry"
173
179
 
174
180
  @property
181
+ @override
175
182
  def OutputType(self) -> type[T]:
176
183
  return self.parser.OutputType
177
184
 
@@ -224,6 +231,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
224
231
  chain = prompt | llm | StrOutputParser()
225
232
  return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
226
233
 
234
+ @override
227
235
  def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
228
236
  retries = 0
229
237
 
@@ -253,6 +261,15 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
253
261
  raise OutputParserException(msg)
254
262
 
255
263
  async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
264
+ """Parse the output of an LLM call using a wrapped parser.
265
+
266
+ Args:
267
+ completion: The chain completion to parse.
268
+ prompt_value: The prompt to use to parse the completion.
269
+
270
+ Returns:
271
+ The parsed completion.
272
+ """
256
273
  retries = 0
257
274
 
258
275
  while retries <= self.max_retries:
@@ -280,10 +297,12 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
280
297
  msg = "Failed to parse"
281
298
  raise OutputParserException(msg)
282
299
 
300
+ @override
283
301
  def parse(self, completion: str) -> T:
284
302
  msg = "This OutputParser can only be called by the `parse_with_prompt` method."
285
303
  raise NotImplementedError(msg)
286
304
 
305
+ @override
287
306
  def get_format_instructions(self) -> str:
288
307
  return self.parser.get_format_instructions()
289
308
 
@@ -292,5 +311,6 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
292
311
  return "retry_with_error"
293
312
 
294
313
  @property
314
+ @override
295
315
  def OutputType(self) -> type[T]:
296
316
  return self.parser.OutputType