langchain 0.3.27__py3-none-any.whl → 0.4.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/agents/agent.py +16 -20
- langchain/agents/agent_iterator.py +19 -12
- langchain/agents/agent_toolkits/vectorstore/base.py +2 -0
- langchain/agents/chat/base.py +2 -0
- langchain/agents/conversational/base.py +2 -0
- langchain/agents/conversational_chat/base.py +2 -0
- langchain/agents/initialize.py +1 -1
- langchain/agents/json_chat/base.py +1 -0
- langchain/agents/mrkl/base.py +2 -0
- langchain/agents/openai_assistant/base.py +1 -1
- langchain/agents/openai_functions_agent/agent_token_buffer_memory.py +2 -0
- langchain/agents/openai_functions_agent/base.py +3 -2
- langchain/agents/openai_functions_multi_agent/base.py +1 -1
- langchain/agents/openai_tools/base.py +1 -0
- langchain/agents/output_parsers/json.py +2 -0
- langchain/agents/output_parsers/openai_functions.py +10 -3
- langchain/agents/output_parsers/openai_tools.py +8 -1
- langchain/agents/output_parsers/react_json_single_input.py +3 -0
- langchain/agents/output_parsers/react_single_input.py +3 -0
- langchain/agents/output_parsers/self_ask.py +2 -0
- langchain/agents/output_parsers/tools.py +16 -2
- langchain/agents/output_parsers/xml.py +3 -0
- langchain/agents/react/agent.py +1 -0
- langchain/agents/react/base.py +4 -0
- langchain/agents/react/output_parser.py +2 -0
- langchain/agents/schema.py +2 -0
- langchain/agents/self_ask_with_search/base.py +4 -0
- langchain/agents/structured_chat/base.py +5 -0
- langchain/agents/structured_chat/output_parser.py +13 -0
- langchain/agents/tool_calling_agent/base.py +1 -0
- langchain/agents/tools.py +3 -0
- langchain/agents/xml/base.py +7 -1
- langchain/callbacks/streaming_aiter.py +13 -2
- langchain/callbacks/streaming_aiter_final_only.py +11 -2
- langchain/callbacks/streaming_stdout_final_only.py +5 -0
- langchain/callbacks/tracers/logging.py +11 -0
- langchain/chains/api/base.py +5 -1
- langchain/chains/base.py +8 -2
- langchain/chains/combine_documents/base.py +7 -1
- langchain/chains/combine_documents/map_reduce.py +3 -0
- langchain/chains/combine_documents/map_rerank.py +6 -4
- langchain/chains/combine_documents/reduce.py +1 -0
- langchain/chains/combine_documents/refine.py +1 -0
- langchain/chains/combine_documents/stuff.py +5 -1
- langchain/chains/constitutional_ai/base.py +7 -0
- langchain/chains/conversation/base.py +4 -1
- langchain/chains/conversational_retrieval/base.py +67 -59
- langchain/chains/elasticsearch_database/base.py +2 -1
- langchain/chains/flare/base.py +2 -0
- langchain/chains/flare/prompts.py +2 -0
- langchain/chains/llm.py +7 -2
- langchain/chains/llm_bash/__init__.py +1 -1
- langchain/chains/llm_checker/base.py +12 -1
- langchain/chains/llm_math/base.py +9 -1
- langchain/chains/llm_summarization_checker/base.py +13 -1
- langchain/chains/llm_symbolic_math/__init__.py +1 -1
- langchain/chains/loading.py +4 -2
- langchain/chains/moderation.py +3 -0
- langchain/chains/natbot/base.py +3 -1
- langchain/chains/natbot/crawler.py +29 -0
- langchain/chains/openai_functions/base.py +2 -0
- langchain/chains/openai_functions/citation_fuzzy_match.py +9 -0
- langchain/chains/openai_functions/openapi.py +4 -0
- langchain/chains/openai_functions/qa_with_structure.py +3 -3
- langchain/chains/openai_functions/tagging.py +2 -0
- langchain/chains/qa_generation/base.py +4 -0
- langchain/chains/qa_with_sources/base.py +3 -0
- langchain/chains/qa_with_sources/retrieval.py +1 -1
- langchain/chains/qa_with_sources/vector_db.py +4 -2
- langchain/chains/query_constructor/base.py +4 -2
- langchain/chains/query_constructor/parser.py +64 -2
- langchain/chains/retrieval_qa/base.py +4 -0
- langchain/chains/router/base.py +14 -2
- langchain/chains/router/embedding_router.py +3 -0
- langchain/chains/router/llm_router.py +6 -4
- langchain/chains/router/multi_prompt.py +3 -0
- langchain/chains/router/multi_retrieval_qa.py +18 -0
- langchain/chains/sql_database/query.py +1 -0
- langchain/chains/structured_output/base.py +2 -0
- langchain/chains/transform.py +4 -0
- langchain/chat_models/base.py +55 -18
- langchain/document_loaders/blob_loaders/schema.py +1 -4
- langchain/embeddings/base.py +2 -0
- langchain/embeddings/cache.py +3 -3
- langchain/evaluation/agents/trajectory_eval_chain.py +3 -2
- langchain/evaluation/comparison/eval_chain.py +1 -0
- langchain/evaluation/criteria/eval_chain.py +3 -0
- langchain/evaluation/embedding_distance/base.py +11 -0
- langchain/evaluation/exact_match/base.py +14 -1
- langchain/evaluation/loading.py +1 -0
- langchain/evaluation/parsing/base.py +16 -3
- langchain/evaluation/parsing/json_distance.py +19 -8
- langchain/evaluation/parsing/json_schema.py +1 -4
- langchain/evaluation/qa/eval_chain.py +8 -0
- langchain/evaluation/qa/generate_chain.py +2 -0
- langchain/evaluation/regex_match/base.py +9 -1
- langchain/evaluation/scoring/eval_chain.py +1 -0
- langchain/evaluation/string_distance/base.py +6 -0
- langchain/memory/buffer.py +5 -0
- langchain/memory/buffer_window.py +2 -0
- langchain/memory/combined.py +1 -1
- langchain/memory/entity.py +47 -0
- langchain/memory/simple.py +3 -0
- langchain/memory/summary.py +30 -0
- langchain/memory/summary_buffer.py +3 -0
- langchain/memory/token_buffer.py +2 -0
- langchain/output_parsers/combining.py +4 -2
- langchain/output_parsers/enum.py +5 -1
- langchain/output_parsers/fix.py +8 -1
- langchain/output_parsers/pandas_dataframe.py +16 -1
- langchain/output_parsers/regex.py +2 -0
- langchain/output_parsers/retry.py +21 -1
- langchain/output_parsers/structured.py +10 -0
- langchain/output_parsers/yaml.py +4 -0
- langchain/pydantic_v1/__init__.py +1 -1
- langchain/retrievers/document_compressors/chain_extract.py +4 -2
- langchain/retrievers/document_compressors/cohere_rerank.py +2 -0
- langchain/retrievers/document_compressors/cross_encoder_rerank.py +2 -0
- langchain/retrievers/document_compressors/embeddings_filter.py +3 -0
- langchain/retrievers/document_compressors/listwise_rerank.py +1 -0
- langchain/retrievers/ensemble.py +2 -2
- langchain/retrievers/multi_query.py +3 -1
- langchain/retrievers/multi_vector.py +4 -1
- langchain/retrievers/parent_document_retriever.py +15 -0
- langchain/retrievers/self_query/base.py +19 -0
- langchain/retrievers/time_weighted_retriever.py +3 -0
- langchain/runnables/hub.py +12 -0
- langchain/runnables/openai_functions.py +6 -0
- langchain/smith/__init__.py +1 -0
- langchain/smith/evaluation/config.py +5 -22
- langchain/smith/evaluation/progress.py +12 -3
- langchain/smith/evaluation/runner_utils.py +240 -123
- langchain/smith/evaluation/string_run_evaluator.py +27 -0
- langchain/storage/encoder_backed.py +1 -0
- langchain/tools/python/__init__.py +1 -1
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/METADATA +2 -12
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/RECORD +140 -141
- langchain/smith/evaluation/utils.py +0 -0
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/WHEEL +0 -0
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/entry_points.txt +0 -0
- {langchain-0.3.27.dist-info → langchain-0.4.0.dev0.dist-info}/licenses/LICENSE +0 -0
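The bulk of the per-file additions below follow one pattern: each module gains `from typing_extensions import override` and decorates methods that override a base-class API with `@override`, alongside new docstrings and a handful of `# noqa` annotations. A minimal sketch of what that decorator buys; the classes here are hypothetical and only the decorator usage mirrors the changes (PEP 698: `@override` is a no-op at runtime, but a type checker flags a decorated method that does not actually override anything):

```python
# Minimal sketch of the @override pattern adopted throughout this diff.
# The classes below are hypothetical; they are not langchain's own.
from typing_extensions import override


class _DemoBaseMemory:
    def load_memory_variables(self, inputs: dict) -> dict:
        return {}


class _DemoBufferMemory(_DemoBaseMemory):
    @override
    def load_memory_variables(self, inputs: dict) -> dict:
        # OK: the name and signature match a method on the base class.
        return {"history": "..."}

    # @override
    # def load_memory_variable(self, inputs: dict) -> dict:
    #     ...  # a type checker (mypy/pyright) would reject this: nothing to override
```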
langchain/evaluation/string_distance/base.py
CHANGED
@@ -224,6 +224,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
         """
         return f"{self.distance.value}_distance"
 
+    @override
     def _call(
         self,
         inputs: dict[str, Any],
@@ -242,6 +243,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
         """
         return {"score": self.compute_metric(inputs["reference"], inputs["prediction"])}
 
+    @override
     async def _acall(
         self,
         inputs: dict[str, Any],
@@ -357,6 +359,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMixin):
         """
         return f"pairwise_{self.distance.value}_distance"
 
+    @override
     def _call(
         self,
         inputs: dict[str, Any],
@@ -377,6 +380,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMixin):
             "score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
         }
 
+    @override
     async def _acall(
         self,
         inputs: dict[str, Any],
@@ -397,6 +401,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMixin):
             "score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
         }
 
+    @override
     def _evaluate_string_pairs(
         self,
         *,
@@ -431,6 +436,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMixin):
         )
         return self._prepare_output(result)
 
+    @override
     async def _aevaluate_string_pairs(
         self,
         *,
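These hunks only annotate existing evaluator methods; the public behaviour is unchanged. For context, a hedged usage sketch of the string-distance evaluator being decorated, assuming the long-standing `load_evaluator` entry point still resolves to this chain and the optional `rapidfuzz` dependency is installed:

```python
# Hedged sketch: exercising StringDistanceEvalChain through the evaluation loader.
from langchain.evaluation import load_evaluator

evaluator = load_evaluator("string_distance")
result = evaluator.evaluate_strings(prediction="hello world", reference="hello there")
print(result)  # e.g. {"score": <normalized distance between the two strings>}
```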
langchain/memory/buffer.py
CHANGED
@@ -4,6 +4,7 @@ from langchain_core._api import deprecated
 from langchain_core.memory import BaseMemory
 from langchain_core.messages import BaseMessage, get_buffer_string
 from langchain_core.utils import pre_init
+from typing_extensions import override
 
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.utils import get_prompt_input_key
@@ -78,10 +79,12 @@ class ConversationBufferMemory(BaseChatMemory):
         """
         return [self.memory_key]
 
+    @override
     def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
         """Return history buffer."""
         return {self.memory_key: self.buffer}
 
+    @override
     async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
         """Return key-value pairs given the text input to the chain."""
         buffer = await self.abuffer()
@@ -132,6 +135,7 @@ class ConversationStringBufferMemory(BaseMemory):
         """
         return [self.memory_key]
 
+    @override
     def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]:
         """Return history buffer."""
         return {self.memory_key: self.buffer}
@@ -169,5 +173,6 @@ class ConversationStringBufferMemory(BaseMemory):
         """Clear memory contents."""
         self.buffer = ""
 
+    @override
     async def aclear(self) -> None:
         self.clear()
langchain/memory/buffer_window.py
CHANGED
@@ -2,6 +2,7 @@ from typing import Any, Union
 
 from langchain_core._api import deprecated
 from langchain_core.messages import BaseMessage, get_buffer_string
+from typing_extensions import override
 
 from langchain.memory.chat_memory import BaseChatMemory
 
@@ -55,6 +56,7 @@ class ConversationBufferWindowMemory(BaseChatMemory):
         """
         return [self.memory_key]
 
+    @override
     def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
         """Return history buffer."""
         return {self.memory_key: self.buffer}
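The decorated `load_memory_variables` methods above are what callers hit when reading these memories. A hedged usage sketch, assuming the deprecated-but-still-shipped `ConversationBufferMemory` API is otherwise unchanged in this release:

```python
# Hedged sketch of the read path annotated above.
from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory(memory_key="history")
memory.save_context({"input": "hi"}, {"output": "hello"})
print(memory.load_memory_variables({}))  # {'history': 'Human: hi\nAI: hello'}
```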
langchain/memory/combined.py
CHANGED
langchain/memory/entity.py
CHANGED
@@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
 from langchain_core.messages import BaseMessage, get_buffer_string
 from langchain_core.prompts import BasePromptTemplate
 from pydantic import BaseModel, ConfigDict, Field
+from typing_extensions import override
 
 from langchain.chains.llm import LLMChain
 from langchain.memory.chat_memory import BaseChatMemory
@@ -71,18 +72,23 @@ class InMemoryEntityStore(BaseEntityStore):
 
     store: dict[str, Optional[str]] = {}
 
+    @override
     def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
         return self.store.get(key, default)
 
+    @override
     def set(self, key: str, value: Optional[str]) -> None:
         self.store[key] = value
 
+    @override
     def delete(self, key: str) -> None:
         del self.store[key]
 
+    @override
     def exists(self, key: str) -> bool:
         return key in self.store
 
+    @override
     def clear(self) -> None:
         return self.store.clear()
 
@@ -113,6 +119,16 @@ class UpstashRedisEntityStore(BaseEntityStore):
         *args: Any,
         **kwargs: Any,
     ):
+        """Initializes the RedisEntityStore.
+
+        Args:
+            session_id: Unique identifier for the session.
+            url: URL of the Redis server.
+            token: Authentication token for the Redis server.
+            key_prefix: Prefix for keys in the Redis store.
+            ttl: Time-to-live for keys in seconds (default 1 day).
+            recall_ttl: Time-to-live extension for keys when recalled (default 3 days).
+        """
         try:
             from upstash_redis import Redis
         except ImportError as e:
@@ -138,8 +154,10 @@ class UpstashRedisEntityStore(BaseEntityStore):
 
     @property
     def full_key_prefix(self) -> str:
+        """Returns the full key prefix with session ID."""
         return f"{self.key_prefix}:{self.session_id}"
 
+    @override
     def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
         res = (
             self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
@@ -151,6 +169,7 @@ class UpstashRedisEntityStore(BaseEntityStore):
         )
         return res
 
+    @override
     def set(self, key: str, value: Optional[str]) -> None:
         if not value:
             return self.delete(key)
@@ -164,12 +183,15 @@ class UpstashRedisEntityStore(BaseEntityStore):
         )
         return None
 
+    @override
     def delete(self, key: str) -> None:
         self.redis_client.delete(f"{self.full_key_prefix}:{key}")
 
+    @override
     def exists(self, key: str) -> bool:
         return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
 
+    @override
     def clear(self) -> None:
         def scan_and_delete(cursor: int) -> int:
             cursor, keys_to_delete = self.redis_client.scan(
@@ -215,6 +237,16 @@ class RedisEntityStore(BaseEntityStore):
         *args: Any,
         **kwargs: Any,
     ):
+        """Initializes the RedisEntityStore.
+
+        Args:
+            session_id: Unique identifier for the session.
+            url: URL of the Redis server.
+            key_prefix: Prefix for keys in the Redis store.
+            ttl: Time-to-live for keys in seconds (default 1 day).
+            recall_ttl: Time-to-live extension for keys when recalled (default 3 days).
+        """
+
         try:
             import redis
         except ImportError as e:
@@ -247,8 +279,10 @@ class RedisEntityStore(BaseEntityStore):
 
     @property
     def full_key_prefix(self) -> str:
+        """Returns the full key prefix with session ID."""
         return f"{self.key_prefix}:{self.session_id}"
 
+    @override
     def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
         res = (
             self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
@@ -258,6 +292,7 @@ class RedisEntityStore(BaseEntityStore):
         logger.debug("REDIS MEM get '%s:%s': '%s'", self.full_key_prefix, key, res)
         return res
 
+    @override
     def set(self, key: str, value: Optional[str]) -> None:
         if not value:
             return self.delete(key)
@@ -271,12 +306,15 @@ class RedisEntityStore(BaseEntityStore):
         )
         return None
 
+    @override
     def delete(self, key: str) -> None:
         self.redis_client.delete(f"{self.full_key_prefix}:{key}")
 
+    @override
     def exists(self, key: str) -> bool:
         return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
 
+    @override
     def clear(self) -> None:
         # iterate a list in batches of size batch_size
         def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
@@ -318,6 +356,13 @@ class SQLiteEntityStore(BaseEntityStore):
         *args: Any,
         **kwargs: Any,
     ):
+        """Initializes the SQLiteEntityStore.
+
+        Args:
+            session_id: Unique identifier for the session.
+            db_file: Path to the SQLite database file.
+            table_name: Name of the table to store entities.
+        """
         super().__init__(*args, **kwargs)
         try:
             import sqlite3
@@ -341,6 +386,7 @@ class SQLiteEntityStore(BaseEntityStore):
 
     @property
     def full_table_name(self) -> str:
+        """Returns the full table name with session ID."""
         return f"{self.table_name}_{self.session_id}"
 
     def _execute_query(self, query: str, params: tuple = ()) -> "sqlite3.Cursor":
@@ -393,6 +439,7 @@ class SQLiteEntityStore(BaseEntityStore):
         cursor = self._execute_query(query, (key,))
         return cursor.fetchone() is not None
 
+    @override
     def clear(self) -> None:
         # Ignore S608 since we validate for malicious table/session names in `__init__`
         query = f"""
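All three entity stores gain the same docstrings and `@override` markers on the `BaseEntityStore` interface (`get`/`set`/`delete`/`exists`/`clear`). A hedged sketch of that interface using the dependency-free in-memory implementation defined in this same module:

```python
# Hedged sketch of the BaseEntityStore interface annotated above.
from langchain.memory.entity import InMemoryEntityStore

store = InMemoryEntityStore()
store.set("Alice", "Alice is a software engineer.")
print(store.exists("Alice"))      # True
print(store.get("Alice"))         # "Alice is a software engineer."
print(store.get("Bob", "n/a"))    # default returned for a missing key
store.delete("Alice")
store.clear()
```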
langchain/memory/simple.py
CHANGED
@@ -1,6 +1,7 @@
 from typing import Any
 
 from langchain_core.memory import BaseMemory
+from typing_extensions import override
 
 
 class SimpleMemory(BaseMemory):
@@ -11,9 +12,11 @@ class SimpleMemory(BaseMemory):
     memories: dict[str, Any] = {}
 
     @property
+    @override
     def memory_variables(self) -> list[str]:
         return list(self.memories.keys())
 
+    @override
     def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]:
         return self.memories
 
langchain/memory/summary.py
CHANGED
@@ -9,6 +9,7 @@ from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
 from langchain_core.prompts import BasePromptTemplate
 from langchain_core.utils import pre_init
 from pydantic import BaseModel
+from typing_extensions import override
 
 from langchain.chains.llm import LLMChain
 from langchain.memory.chat_memory import BaseChatMemory
@@ -37,6 +38,15 @@ class SummarizerMixin(BaseModel):
         messages: list[BaseMessage],
         existing_summary: str,
     ) -> str:
+        """Predict a new summary based on the messages and existing summary.
+
+        Args:
+            messages: List of messages to summarize.
+            existing_summary: Existing summary to build upon.
+
+        Returns:
+            A new summary string.
+        """
         new_lines = get_buffer_string(
             messages,
             human_prefix=self.human_prefix,
@@ -51,6 +61,15 @@ class SummarizerMixin(BaseModel):
         messages: list[BaseMessage],
         existing_summary: str,
     ) -> str:
+        """Predict a new summary based on the messages and existing summary.
+
+        Args:
+            messages: List of messages to summarize.
+            existing_summary: Existing summary to build upon.
+
+        Returns:
+            A new summary string.
+        """
         new_lines = get_buffer_string(
             messages,
             human_prefix=self.human_prefix,
@@ -89,6 +108,16 class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
         summarize_step: int = 2,
         **kwargs: Any,
     ) -> ConversationSummaryMemory:
+        """Create a ConversationSummaryMemory from a list of messages.
+
+        Args:
+            llm: The language model to use for summarization.
+            chat_memory: The chat history to summarize.
+            summarize_step: Number of messages to summarize at a time.
+            **kwargs: Additional keyword arguments to pass to the class.
+        Returns:
+            An instance of ConversationSummaryMemory with the summarized history.
+        """
         obj = cls(llm=llm, chat_memory=chat_memory, **kwargs)
         for i in range(0, len(obj.chat_memory.messages), summarize_step):
             obj.buffer = obj.predict_new_summary(
@@ -105,6 +134,7 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
         """
         return [self.memory_key]
 
+    @override
     def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
         """Return history buffer."""
         if self.return_messages:
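`from_messages` now carries a docstring; its behaviour (summarize the supplied history `summarize_step` messages at a time) is unchanged. A hedged, offline sketch, assuming the class is still importable from `langchain.memory` in this dev release and using langchain-core's fake LLM and in-memory chat history test helpers:

```python
# Hedged sketch of ConversationSummaryMemory.from_messages as documented above.
# FakeListLLM stands in for a real model so the snippet runs offline.
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.language_models import FakeListLLM
from langchain.memory import ConversationSummaryMemory

history = InMemoryChatMessageHistory()
history.add_user_message("hi, I'm Bob")
history.add_ai_message("hello Bob, how can I help?")

llm = FakeListLLM(responses=["Bob greeted the AI and the AI offered to help."])
memory = ConversationSummaryMemory.from_messages(
    llm=llm, chat_memory=history, summarize_step=2
)
print(memory.buffer)  # the running summary produced by the (fake) llm
```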
langchain/memory/summary_buffer.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Any, Union
 from langchain_core._api import deprecated
 from langchain_core.messages import BaseMessage, get_buffer_string
 from langchain_core.utils import pre_init
+from typing_extensions import override
 
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.summary import SummarizerMixin
@@ -46,6 +47,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
         """
         return [self.memory_key]
 
+    @override
     def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
         """Return history buffer."""
         buffer = self.chat_memory.messages
@@ -64,6 +66,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
         )
         return {self.memory_key: final_buffer}
 
+    @override
     async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
         """Asynchronously return key-value pairs given the text input to the chain."""
         buffer = await self.chat_memory.aget_messages()
langchain/memory/token_buffer.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Any
 from langchain_core._api import deprecated
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.messages import BaseMessage, get_buffer_string
+from typing_extensions import override
 
 from langchain.memory.chat_memory import BaseChatMemory
 
@@ -55,6 +56,7 @@ class ConversationTokenBufferMemory(BaseChatMemory):
         """
         return [self.memory_key]
 
+    @override
     def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
         """Return history buffer."""
         return {self.memory_key: self.buffer}
langchain/output_parsers/combining.py
CHANGED
@@ -4,6 +4,7 @@ from typing import Any
 
 from langchain_core.output_parsers import BaseOutputParser
 from langchain_core.utils import pre_init
+from typing_extensions import override
 
 _MIN_PARSERS = 2
 
@@ -14,6 +15,7 @@ class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
     parsers: list[BaseOutputParser]
 
     @classmethod
+    @override
     def is_lc_serializable(cls) -> bool:
         return True
 
@@ -25,10 +27,10 @@ class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
             msg = "Must have at least two parsers"
             raise ValueError(msg)
         for parser in parsers:
-            if parser._type == "combining":
+            if parser._type == "combining":  # noqa: SLF001
                 msg = "Cannot nest combining parsers"
                 raise ValueError(msg)
-            if parser._type == "list":
+            if parser._type == "list":  # noqa: SLF001
                 msg = "Cannot combine list parsers"
                 raise ValueError(msg)
         return values
langchain/output_parsers/enum.py
CHANGED
@@ -3,6 +3,7 @@ from enum import Enum
 from langchain_core.exceptions import OutputParserException
 from langchain_core.output_parsers import BaseOutputParser
 from langchain_core.utils import pre_init
+from typing_extensions import override
 
 
 class EnumOutputParser(BaseOutputParser[Enum]):
@@ -12,7 +13,7 @@ class EnumOutputParser(BaseOutputParser[Enum]):
     """The enum to parse. Its values must be strings."""
 
     @pre_init
-    def
+    def _raise_deprecation(cls, values: dict) -> dict:
         enum = values["enum"]
         if not all(isinstance(e.value, str) for e in enum):
             msg = "Enum values must be strings"
@@ -23,6 +24,7 @@ class EnumOutputParser(BaseOutputParser[Enum]):
     def _valid_values(self) -> list[str]:
         return [e.value for e in self.enum]
 
+    @override
     def parse(self, response: str) -> Enum:
         try:
             return self.enum(response.strip())
@@ -33,9 +35,11 @@ class EnumOutputParser(BaseOutputParser[Enum]):
             )
             raise OutputParserException(msg) from e
 
+    @override
     def get_format_instructions(self) -> str:
         return f"Select one of the following options: {', '.join(self._valid_values)}"
 
     @property
+    @override
     def OutputType(self) -> type[Enum]:
         return self.enum
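A hedged sketch of the `EnumOutputParser` surface annotated above (`parse` strips the response before looking it up in the enum, `get_format_instructions` lists the allowed values); the import path is assumed unchanged in this release:

```python
# Hedged sketch of EnumOutputParser, whose parse/get_format_instructions/OutputType
# members gain @override above. Behaviour shown matches the context lines in the diff.
from enum import Enum

from langchain.output_parsers import EnumOutputParser


class Color(Enum):
    RED = "red"
    GREEN = "green"


parser = EnumOutputParser(enum=Color)
print(parser.get_format_instructions())  # Select one of the following options: red, green
print(parser.parse(" red \n"))           # Color.RED (whitespace is stripped first)
```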
langchain/output_parsers/fix.py
CHANGED
@@ -7,7 +7,7 @@ from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
 from langchain_core.prompts import BasePromptTemplate
 from langchain_core.runnables import Runnable, RunnableSerializable
 from pydantic import SkipValidation
-from typing_extensions import TypedDict
+from typing_extensions import TypedDict, override
 
 from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
 
@@ -15,6 +15,8 @@ T = TypeVar("T")
 
 
 class OutputFixingParserRetryChainInput(TypedDict, total=False):
+    """Input for the retry chain of the OutputFixingParser."""
+
     instructions: str
     completion: str
     error: str
@@ -24,6 +26,7 @@ class OutputFixingParser(BaseOutputParser[T]):
     """Wrap a parser and try to fix parsing errors."""
 
     @classmethod
+    @override
     def is_lc_serializable(cls) -> bool:
         return True
 
@@ -62,6 +65,7 @@ class OutputFixingParser(BaseOutputParser[T]):
         chain = prompt | llm | StrOutputParser()
         return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
 
+    @override
     def parse(self, completion: str) -> T:
         retries = 0
 
@@ -99,6 +103,7 @@ class OutputFixingParser(BaseOutputParser[T]):
         msg = "Failed to parse"
         raise OutputParserException(msg)
 
+    @override
     async def aparse(self, completion: str) -> T:
         retries = 0
 
@@ -136,6 +141,7 @@ class OutputFixingParser(BaseOutputParser[T]):
         msg = "Failed to parse"
         raise OutputParserException(msg)
 
+    @override
     def get_format_instructions(self) -> str:
         return self.parser.get_format_instructions()
 
@@ -144,5 +150,6 @@ class OutputFixingParser(BaseOutputParser[T]):
         return "output_fixing"
 
     @property
+    @override
     def OutputType(self) -> type[T]:
         return self.parser.OutputType
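`OutputFixingParser` wraps another parser and, when that parser raises `OutputParserException`, asks an LLM (via the `retry_chain` built in `from_llm`, i.e. `prompt | llm | StrOutputParser()` as in the hunk above) to repair the completion, up to `max_retries` times. A hedged, offline sketch using langchain-core's `FakeListLLM` test helper in place of a real model, assuming these helpers are still exported in this release:

```python
# Hedged sketch of the OutputFixingParser retry loop annotated above.
# FakeListLLM plays the repair model; JsonOutputParser is the wrapped parser.
from langchain_core.language_models import FakeListLLM
from langchain_core.output_parsers import JsonOutputParser

from langchain.output_parsers import OutputFixingParser

llm = FakeListLLM(responses=['{"answer": 42}'])  # the "fixed" completion it will return
fixing = OutputFixingParser.from_llm(llm=llm, parser=JsonOutputParser(), max_retries=1)
print(fixing.parse("not json at all"))  # inner parse fails once, then {'answer': 42}
```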
langchain/output_parsers/pandas_dataframe.py
CHANGED
@@ -4,6 +4,7 @@ from typing import Any, Union
 from langchain_core.exceptions import OutputParserException
 from langchain_core.output_parsers.base import BaseOutputParser
 from pydantic import field_validator
+from typing_extensions import override
 
 from langchain.output_parsers.format_instructions import (
     PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS,
@@ -18,7 +19,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
 
     @field_validator("dataframe")
     @classmethod
-    def
+    def _validate_dataframe(cls, val: Any) -> Any:
         import pandas as pd
 
         if issubclass(type(val), pd.DataFrame):
@@ -36,6 +37,18 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
         array: str,
         original_request_params: str,
     ) -> tuple[list[Union[int, str]], str]:
+        """Parse the array from the request parameters.
+
+        Args:
+            array: The array string to parse.
+            original_request_params: The original request parameters string.
+
+        Returns:
+            A tuple containing the parsed array and the stripped request parameters.
+
+        Raises:
+            OutputParserException: If the array format is invalid or cannot be parsed.
+        """
         parsed_array: list[Union[int, str]] = []
 
         # Check if the format is [1,3,5]
@@ -76,6 +89,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
 
         return parsed_array, original_request_params.split("[")[0]
 
+    @override
     def parse(self, request: str) -> dict[str, Any]:
         stripped_request_params = None
         splitted_request = request.strip().split(":")
@@ -150,6 +164,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
 
         return result
 
+    @override
     def get_format_instructions(self) -> str:
         return PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS.format(
             columns=", ".join(self.dataframe.columns),
langchain/output_parsers/regex.py
CHANGED
@@ -4,12 +4,14 @@ import re
 from typing import Optional
 
 from langchain_core.output_parsers import BaseOutputParser
+from typing_extensions import override
 
 
 class RegexParser(BaseOutputParser[dict[str, str]]):
     """Parse the output of an LLM call using a regex."""
 
     @classmethod
+    @override
     def is_lc_serializable(cls) -> bool:
         return True
 
langchain/output_parsers/retry.py
CHANGED
@@ -9,7 +9,7 @@ from langchain_core.prompt_values import PromptValue
 from langchain_core.prompts import BasePromptTemplate, PromptTemplate
 from langchain_core.runnables import RunnableSerializable
 from pydantic import SkipValidation
-from typing_extensions import TypedDict
+from typing_extensions import TypedDict, override
 
 NAIVE_COMPLETION_RETRY = """Prompt:
 {prompt}
@@ -37,11 +37,15 @@ T = TypeVar("T")
 
 
 class RetryOutputParserRetryChainInput(TypedDict):
+    """Retry chain input for RetryOutputParser."""
+
     prompt: str
     completion: str
 
 
 class RetryWithErrorOutputParserRetryChainInput(TypedDict):
+    """Retry chain input for RetryWithErrorOutputParser."""
+
     prompt: str
     completion: str
     error: str
@@ -160,10 +164,12 @@ class RetryOutputParser(BaseOutputParser[T]):
         msg = "Failed to parse"
         raise OutputParserException(msg)
 
+    @override
     def parse(self, completion: str) -> T:
         msg = "This OutputParser can only be called by the `parse_with_prompt` method."
         raise NotImplementedError(msg)
 
+    @override
     def get_format_instructions(self) -> str:
         return self.parser.get_format_instructions()
 
@@ -172,6 +178,7 @@ class RetryOutputParser(BaseOutputParser[T]):
         return "retry"
 
     @property
+    @override
     def OutputType(self) -> type[T]:
         return self.parser.OutputType
 
@@ -224,6 +231,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
         chain = prompt | llm | StrOutputParser()
         return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
 
+    @override
     def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
         retries = 0
 
@@ -253,6 +261,15 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
         raise OutputParserException(msg)
 
     async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
+        """Parse the output of an LLM call using a wrapped parser.
+
+        Args:
+            completion: The chain completion to parse.
+            prompt_value: The prompt to use to parse the completion.
+
+        Returns:
+            The parsed completion.
+        """
         retries = 0
 
         while retries <= self.max_retries:
@@ -280,10 +297,12 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
         msg = "Failed to parse"
         raise OutputParserException(msg)
 
+    @override
     def parse(self, completion: str) -> T:
         msg = "This OutputParser can only be called by the `parse_with_prompt` method."
         raise NotImplementedError(msg)
 
+    @override
     def get_format_instructions(self) -> str:
         return self.parser.get_format_instructions()
 
@@ -292,5 +311,6 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
         return "retry_with_error"
 
     @property
+    @override
     def OutputType(self) -> type[T]:
         return self.parser.OutputType
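Unlike the fixing parser, the retry parsers need the original prompt: `parse` alone deliberately raises `NotImplementedError` (now marked `@override`), and callers go through `parse_with_prompt`. A hedged, offline sketch with fake components standing in for a real model, assuming the langchain-core test helpers are still exported:

```python
# Hedged sketch of RetryWithErrorOutputParser.parse_with_prompt as annotated above.
# The fake LLM supplies the corrected completion on the single retry.
from langchain_core.language_models import FakeListLLM
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate

from langchain.output_parsers import RetryWithErrorOutputParser

prompt_value = PromptTemplate.from_template("Answer in JSON: {q}").format_prompt(q="2 + 2?")
llm = FakeListLLM(responses=['{"answer": 4}'])
retry = RetryWithErrorOutputParser.from_llm(llm=llm, parser=JsonOutputParser(), max_retries=1)
print(retry.parse_with_prompt("four", prompt_value))  # {'answer': 4} after one retry
```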