langchain-core 1.0.0a1__py3-none-any.whl → 1.0.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain-core might be problematic; see the package's registry page for more details.

Files changed (131)
  1. langchain_core/_api/beta_decorator.py +17 -40
  2. langchain_core/_api/deprecation.py +20 -7
  3. langchain_core/_api/path.py +19 -2
  4. langchain_core/_import_utils.py +7 -0
  5. langchain_core/agents.py +10 -6
  6. langchain_core/callbacks/base.py +28 -15
  7. langchain_core/callbacks/manager.py +81 -69
  8. langchain_core/callbacks/usage.py +4 -2
  9. langchain_core/chat_history.py +29 -21
  10. langchain_core/document_loaders/base.py +34 -9
  11. langchain_core/document_loaders/langsmith.py +3 -0
  12. langchain_core/documents/base.py +35 -10
  13. langchain_core/documents/transformers.py +4 -2
  14. langchain_core/embeddings/fake.py +8 -5
  15. langchain_core/env.py +2 -3
  16. langchain_core/example_selectors/base.py +12 -0
  17. langchain_core/exceptions.py +7 -0
  18. langchain_core/globals.py +17 -28
  19. langchain_core/indexing/api.py +57 -45
  20. langchain_core/indexing/base.py +5 -8
  21. langchain_core/indexing/in_memory.py +23 -3
  22. langchain_core/language_models/__init__.py +6 -2
  23. langchain_core/language_models/_utils.py +28 -4
  24. langchain_core/language_models/base.py +33 -21
  25. langchain_core/language_models/chat_models.py +103 -29
  26. langchain_core/language_models/fake_chat_models.py +5 -7
  27. langchain_core/language_models/llms.py +54 -20
  28. langchain_core/load/dump.py +2 -3
  29. langchain_core/load/load.py +15 -1
  30. langchain_core/load/serializable.py +38 -43
  31. langchain_core/memory.py +7 -3
  32. langchain_core/messages/__init__.py +7 -17
  33. langchain_core/messages/ai.py +41 -34
  34. langchain_core/messages/base.py +16 -7
  35. langchain_core/messages/block_translators/__init__.py +10 -8
  36. langchain_core/messages/block_translators/anthropic.py +3 -1
  37. langchain_core/messages/block_translators/bedrock.py +3 -1
  38. langchain_core/messages/block_translators/bedrock_converse.py +3 -1
  39. langchain_core/messages/block_translators/google_genai.py +3 -1
  40. langchain_core/messages/block_translators/google_vertexai.py +3 -1
  41. langchain_core/messages/block_translators/groq.py +3 -1
  42. langchain_core/messages/block_translators/langchain_v0.py +3 -136
  43. langchain_core/messages/block_translators/ollama.py +3 -1
  44. langchain_core/messages/block_translators/openai.py +252 -10
  45. langchain_core/messages/content.py +26 -124
  46. langchain_core/messages/human.py +2 -13
  47. langchain_core/messages/system.py +2 -6
  48. langchain_core/messages/tool.py +34 -14
  49. langchain_core/messages/utils.py +189 -74
  50. langchain_core/output_parsers/base.py +5 -2
  51. langchain_core/output_parsers/json.py +4 -4
  52. langchain_core/output_parsers/list.py +7 -22
  53. langchain_core/output_parsers/openai_functions.py +3 -0
  54. langchain_core/output_parsers/openai_tools.py +6 -1
  55. langchain_core/output_parsers/pydantic.py +4 -0
  56. langchain_core/output_parsers/string.py +5 -1
  57. langchain_core/output_parsers/xml.py +19 -19
  58. langchain_core/outputs/chat_generation.py +18 -7
  59. langchain_core/outputs/generation.py +14 -3
  60. langchain_core/outputs/llm_result.py +8 -1
  61. langchain_core/prompt_values.py +10 -4
  62. langchain_core/prompts/base.py +6 -11
  63. langchain_core/prompts/chat.py +88 -60
  64. langchain_core/prompts/dict.py +16 -8
  65. langchain_core/prompts/few_shot.py +9 -11
  66. langchain_core/prompts/few_shot_with_templates.py +5 -1
  67. langchain_core/prompts/image.py +12 -5
  68. langchain_core/prompts/loading.py +2 -2
  69. langchain_core/prompts/message.py +5 -6
  70. langchain_core/prompts/pipeline.py +13 -8
  71. langchain_core/prompts/prompt.py +22 -8
  72. langchain_core/prompts/string.py +18 -10
  73. langchain_core/prompts/structured.py +7 -2
  74. langchain_core/rate_limiters.py +2 -2
  75. langchain_core/retrievers.py +7 -6
  76. langchain_core/runnables/base.py +387 -246
  77. langchain_core/runnables/branch.py +11 -28
  78. langchain_core/runnables/config.py +20 -17
  79. langchain_core/runnables/configurable.py +34 -19
  80. langchain_core/runnables/fallbacks.py +20 -13
  81. langchain_core/runnables/graph.py +48 -38
  82. langchain_core/runnables/graph_ascii.py +40 -17
  83. langchain_core/runnables/graph_mermaid.py +54 -25
  84. langchain_core/runnables/graph_png.py +27 -31
  85. langchain_core/runnables/history.py +55 -58
  86. langchain_core/runnables/passthrough.py +44 -21
  87. langchain_core/runnables/retry.py +44 -23
  88. langchain_core/runnables/router.py +9 -8
  89. langchain_core/runnables/schema.py +9 -0
  90. langchain_core/runnables/utils.py +53 -90
  91. langchain_core/stores.py +19 -31
  92. langchain_core/sys_info.py +9 -8
  93. langchain_core/tools/base.py +36 -27
  94. langchain_core/tools/convert.py +25 -14
  95. langchain_core/tools/simple.py +36 -8
  96. langchain_core/tools/structured.py +25 -12
  97. langchain_core/tracers/base.py +2 -2
  98. langchain_core/tracers/context.py +5 -1
  99. langchain_core/tracers/core.py +110 -46
  100. langchain_core/tracers/evaluation.py +22 -26
  101. langchain_core/tracers/event_stream.py +97 -42
  102. langchain_core/tracers/langchain.py +12 -3
  103. langchain_core/tracers/langchain_v1.py +10 -2
  104. langchain_core/tracers/log_stream.py +56 -17
  105. langchain_core/tracers/root_listeners.py +4 -20
  106. langchain_core/tracers/run_collector.py +6 -16
  107. langchain_core/tracers/schemas.py +5 -1
  108. langchain_core/utils/aiter.py +14 -6
  109. langchain_core/utils/env.py +3 -0
  110. langchain_core/utils/function_calling.py +46 -20
  111. langchain_core/utils/interactive_env.py +6 -2
  112. langchain_core/utils/iter.py +12 -5
  113. langchain_core/utils/json.py +12 -3
  114. langchain_core/utils/json_schema.py +156 -40
  115. langchain_core/utils/loading.py +5 -1
  116. langchain_core/utils/mustache.py +25 -16
  117. langchain_core/utils/pydantic.py +38 -9
  118. langchain_core/utils/utils.py +25 -9
  119. langchain_core/vectorstores/base.py +7 -20
  120. langchain_core/vectorstores/in_memory.py +20 -14
  121. langchain_core/vectorstores/utils.py +18 -12
  122. langchain_core/version.py +1 -1
  123. langchain_core-1.0.0a3.dist-info/METADATA +77 -0
  124. langchain_core-1.0.0a3.dist-info/RECORD +181 -0
  125. langchain_core/beta/__init__.py +0 -1
  126. langchain_core/beta/runnables/__init__.py +0 -1
  127. langchain_core/beta/runnables/context.py +0 -448
  128. langchain_core-1.0.0a1.dist-info/METADATA +0 -106
  129. langchain_core-1.0.0a1.dist-info/RECORD +0 -184
  130. {langchain_core-1.0.0a1.dist-info → langchain_core-1.0.0a3.dist-info}/WHEEL +0 -0
  131. {langchain_core-1.0.0a1.dist-info → langchain_core-1.0.0a3.dist-info}/entry_points.txt +0 -0
@@ -3,12 +3,24 @@
3
3
  Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py.
4
4
  """
5
5
 
6
+ from __future__ import annotations
7
+
6
8
  import math
7
9
  import os
8
10
  from collections.abc import Mapping, Sequence
9
- from typing import Any
11
+ from typing import TYPE_CHECKING, Any
12
+
13
+ try:
14
+ from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import-untyped]
15
+ from grandalf.layouts import SugiyamaLayout # type: ignore[import-untyped]
16
+ from grandalf.routing import route_with_lines # type: ignore[import-untyped]
10
17
 
11
- from langchain_core.runnables.graph import Edge as LangEdge
18
+ _HAS_GRANDALF = True
19
+ except ImportError:
20
+ _HAS_GRANDALF = False
21
+
22
+ if TYPE_CHECKING:
23
+ from langchain_core.runnables.graph import Edge as LangEdge
12
24
 
13
25
 
14
26
  class VertexViewer:
@@ -50,8 +62,11 @@ class AsciiCanvas:
50
62
  """Create an ASCII canvas.
51
63
 
52
64
  Args:
53
- cols (int): number of columns in the canvas. Should be > 1.
54
- lines (int): number of lines in the canvas. Should be > 1.
65
+ cols: number of columns in the canvas. Should be ``> 1``.
66
+ lines: number of lines in the canvas. Should be ``> 1``.
67
+
68
+ Raises:
69
+ ValueError: if canvas dimensions are invalid.
55
70
  """
56
71
  if cols <= 1 or lines <= 1:
57
72
  msg = "Canvas dimensions should be > 1"
@@ -63,7 +78,11 @@ class AsciiCanvas:
63
78
  self.canvas = [[" "] * cols for line in range(lines)]
64
79
 
65
80
  def draw(self) -> str:
66
- """Draws ASCII canvas on the screen."""
81
+ """Draws ASCII canvas on the screen.
82
+
83
+ Returns:
84
+ The ASCII canvas string.
85
+ """
67
86
  lines = map("".join, self.canvas)
68
87
  return os.linesep.join(lines)
69
88
 
@@ -71,12 +90,16 @@ class AsciiCanvas:
71
90
  """Create a point on ASCII canvas.
72
91
 
73
92
  Args:
74
- x (int): x coordinate. Should be >= 0 and < number of columns in
93
+ x: x coordinate. Should be ``>= 0`` and ``<`` number of columns in
75
94
  the canvas.
76
- y (int): y coordinate. Should be >= 0 an < number of lines in the
95
+ y: y coordinate. Should be ``>= 0`` an ``<`` number of lines in the
77
96
  canvas.
78
- char (str): character to place in the specified point on the
97
+ char: character to place in the specified point on the
79
98
  canvas.
99
+
100
+ Raises:
101
+ ValueError: if char is not a single character or if
102
+ coordinates are out of bounds.
80
103
  """
81
104
  if len(char) != 1:
82
105
  msg = "char should be a single character"
@@ -174,13 +197,9 @@ class _EdgeViewer:
174
197
  def _build_sugiyama_layout(
175
198
  vertices: Mapping[str, str], edges: Sequence[LangEdge]
176
199
  ) -> Any:
177
- try:
178
- from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import-untyped]
179
- from grandalf.layouts import SugiyamaLayout # type: ignore[import-untyped]
180
- from grandalf.routing import route_with_lines # type: ignore[import-untyped]
181
- except ImportError as exc:
200
+ if not _HAS_GRANDALF:
182
201
  msg = "Install grandalf to draw graphs: `pip install grandalf`."
183
- raise ImportError(msg) from exc
202
+ raise ImportError(msg)
184
203
 
185
204
  #
186
205
  # Just a reminder about naming conventions:
@@ -225,11 +244,15 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
225
244
  """Build a DAG and draw it in ASCII.
226
245
 
227
246
  Args:
228
- vertices (list): list of graph vertices.
229
- edges (list): list of graph edges.
247
+ vertices: list of graph vertices.
248
+ edges: list of graph edges.
249
+
250
+ Raises:
251
+ ValueError: if the canvas dimensions are invalid or if
252
+ edge coordinates are invalid.
230
253
 
231
254
  Returns:
232
- str: ASCII representation
255
+ ASCII representation
233
256
 
234
257
  Example:
235
258
 
@@ -1,24 +1,43 @@
1
1
  """Mermaid graph drawing utilities."""
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  import asyncio
4
6
  import base64
5
7
  import random
6
8
  import re
9
+ import string
7
10
  import time
8
11
  from dataclasses import asdict
9
12
  from pathlib import Path
10
- from typing import Any, Literal, Optional
13
+ from typing import TYPE_CHECKING, Any, Literal, Optional
11
14
 
12
15
  import yaml
13
16
 
14
17
  from langchain_core.runnables.graph import (
15
18
  CurveStyle,
16
- Edge,
17
19
  MermaidDrawMethod,
18
- Node,
19
20
  NodeStyles,
20
21
  )
21
22
 
23
+ if TYPE_CHECKING:
24
+ from langchain_core.runnables.graph import Edge, Node
25
+
26
+
27
+ try:
28
+ import requests
29
+
30
+ _HAS_REQUESTS = True
31
+ except ImportError:
32
+ _HAS_REQUESTS = False
33
+
34
+ try:
35
+ from pyppeteer import launch # type: ignore[import-not-found]
36
+
37
+ _HAS_PYPPETEER = True
38
+ except ImportError:
39
+ _HAS_PYPPETEER = False
40
+
22
41
  MARKDOWN_SPECIAL_CHARS = "*_`"
23
42
 
24
43
 
@@ -130,7 +149,7 @@ def draw_mermaid(
130
149
  + "</em></small>"
131
150
  )
132
151
  node_label = format_dict.get(key, format_dict[default_class_label]).format(
133
- _escape_node_label(key), label
152
+ _to_safe_id(key), label
134
153
  )
135
154
  return f"{indent}{node_label}\n"
136
155
 
@@ -155,7 +174,7 @@ def draw_mermaid(
155
174
  nonlocal mermaid_graph
156
175
  self_loop = len(edges) == 1 and edges[0].source == edges[0].target
157
176
  if prefix and not self_loop:
158
- subgraph = prefix.split(":")[-1]
177
+ subgraph = prefix.rsplit(":", maxsplit=1)[-1]
159
178
  if subgraph in seen_subgraphs:
160
179
  msg = (
161
180
  f"Found duplicate subgraph '{subgraph}' -- this likely means that "
@@ -193,8 +212,7 @@ def draw_mermaid(
193
212
  edge_label = " -.-> " if edge.conditional else " --> "
194
213
 
195
214
  mermaid_graph += (
196
- f"\t{_escape_node_label(source)}{edge_label}"
197
- f"{_escape_node_label(target)};\n"
215
+ f"\t{_to_safe_id(source)}{edge_label}{_to_safe_id(target)};\n"
198
216
  )
199
217
 
200
218
  # Recursively add nested subgraphs
@@ -214,7 +232,7 @@ def draw_mermaid(
214
232
 
215
233
  # Add remaining subgraphs with edges
216
234
  for prefix, edges_ in edge_groups.items():
217
- if ":" in prefix or prefix == "":
235
+ if not prefix or ":" in prefix:
218
236
  continue
219
237
  add_subgraph(edges_, prefix)
220
238
  seen_subgraphs.add(prefix)
@@ -238,9 +256,18 @@ def draw_mermaid(
238
256
  return mermaid_graph
239
257
 
240
258
 
241
- def _escape_node_label(node_label: str) -> str:
242
- """Escapes the node label for Mermaid syntax."""
243
- return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)
259
+ def _to_safe_id(label: str) -> str:
260
+ """Convert a string into a Mermaid-compatible node id.
261
+
262
+ Keep [a-zA-Z0-9_-] characters unchanged.
263
+ Map every other character -> backslash + lowercase hex codepoint.
264
+
265
+ Result is guaranteed to be unique and Mermaid-compatible,
266
+ so nodes with special characters always render correctly.
267
+ """
268
+ allowed = string.ascii_letters + string.digits + "_-"
269
+ out = [ch if ch in allowed else "\\" + format(ord(ch), "x") for ch in label]
270
+ return "".join(out)
244
271
 
245
272
 
246
273
  def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:
@@ -259,6 +286,7 @@ def draw_mermaid_png(
259
286
  padding: int = 10,
260
287
  max_retries: int = 1,
261
288
  retry_delay: float = 1.0,
289
+ base_url: Optional[str] = None,
262
290
  ) -> bytes:
263
291
  """Draws a Mermaid graph as PNG using provided syntax.
264
292
 
@@ -275,6 +303,8 @@ def draw_mermaid_png(
275
303
  Defaults to 1.
276
304
  retry_delay (float, optional): Delay between retries (MermaidDrawMethod.API).
277
305
  Defaults to 1.0.
306
+ base_url (str, optional): Base URL for the Mermaid.ink API.
307
+ Defaults to None.
278
308
 
279
309
  Returns:
280
310
  bytes: PNG image bytes.
@@ -283,8 +313,6 @@ def draw_mermaid_png(
283
313
  ValueError: If an invalid draw method is provided.
284
314
  """
285
315
  if draw_method == MermaidDrawMethod.PYPPETEER:
286
- import asyncio
287
-
288
316
  img_bytes = asyncio.run(
289
317
  _render_mermaid_using_pyppeteer(
290
318
  mermaid_syntax, output_file_path, background_color, padding
@@ -297,6 +325,7 @@ def draw_mermaid_png(
297
325
  background_color=background_color,
298
326
  max_retries=max_retries,
299
327
  retry_delay=retry_delay,
328
+ base_url=base_url,
300
329
  )
301
330
  else:
302
331
  supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
@@ -317,11 +346,9 @@ async def _render_mermaid_using_pyppeteer(
317
346
  device_scale_factor: int = 3,
318
347
  ) -> bytes:
319
348
  """Renders Mermaid graph using Pyppeteer."""
320
- try:
321
- from pyppeteer import launch # type: ignore[import-not-found]
322
- except ImportError as e:
349
+ if not _HAS_PYPPETEER:
323
350
  msg = "Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`."
324
- raise ImportError(msg) from e
351
+ raise ImportError(msg)
325
352
 
326
353
  browser = await launch()
327
354
  page = await browser.newPage()
@@ -390,16 +417,18 @@ def _render_mermaid_using_api(
390
417
  file_type: Optional[Literal["jpeg", "png", "webp"]] = "png",
391
418
  max_retries: int = 1,
392
419
  retry_delay: float = 1.0,
420
+ base_url: Optional[str] = None,
393
421
  ) -> bytes:
394
422
  """Renders Mermaid graph using the Mermaid.INK API."""
395
- try:
396
- import requests
397
- except ImportError as e:
423
+ # Defaults to using the public mermaid.ink server.
424
+ base_url = base_url if base_url is not None else "https://mermaid.ink"
425
+
426
+ if not _HAS_REQUESTS:
398
427
  msg = (
399
428
  "Install the `requests` module to use the Mermaid.INK API: "
400
429
  "`pip install requests`."
401
430
  )
402
- raise ImportError(msg) from e
431
+ raise ImportError(msg)
403
432
 
404
433
  # Use Mermaid API to render the image
405
434
  mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode(
@@ -413,7 +442,7 @@ def _render_mermaid_using_api(
413
442
  background_color = f"!{background_color}"
414
443
 
415
444
  image_url = (
416
- f"https://mermaid.ink/img/{mermaid_syntax_encoded}"
445
+ f"{base_url}/img/{mermaid_syntax_encoded}"
417
446
  f"?type={file_type}&bgColor={background_color}"
418
447
  )
419
448
 
@@ -445,7 +474,7 @@ def _render_mermaid_using_api(
445
474
 
446
475
  # For other status codes, fail immediately
447
476
  msg = (
448
- "Failed to reach https://mermaid.ink/ API while trying to render "
477
+ f"Failed to reach {base_url} API while trying to render "
449
478
  f"your graph. Status code: {response.status_code}.\n\n"
450
479
  ) + error_msg_suffix
451
480
  raise ValueError(msg)
@@ -457,14 +486,14 @@ def _render_mermaid_using_api(
457
486
  time.sleep(sleep_time)
458
487
  else:
459
488
  msg = (
460
- "Failed to reach https://mermaid.ink/ API while trying to render "
489
+ f"Failed to reach {base_url} API while trying to render "
461
490
  f"your graph after {max_retries} retries. "
462
491
  ) + error_msg_suffix
463
492
  raise ValueError(msg) from e
464
493
 
465
494
  # This should not be reached, but just in case
466
495
  msg = (
467
- "Failed to reach https://mermaid.ink/ API while trying to render "
496
+ f"Failed to reach {base_url} API while trying to render "
468
497
  f"your graph after {max_retries} retries. "
469
498
  ) + error_msg_suffix
470
499
  raise ValueError(msg)
@@ -4,29 +4,25 @@ from typing import Any, Optional
4
4
 
5
5
  from langchain_core.runnables.graph import Graph, LabelsDict
6
6
 
7
+ try:
8
+ import pygraphviz as pgv # type: ignore[import-not-found]
9
+
10
+ _HAS_PYGRAPHVIZ = True
11
+ except ImportError:
12
+ _HAS_PYGRAPHVIZ = False
13
+
7
14
 
8
15
  class PngDrawer:
9
16
  """Helper class to draw a state graph into a PNG file.
10
17
 
11
- It requires `graphviz` and `pygraphviz` to be installed.
12
- :param fontname: The font to use for the labels
13
- :param labels: A dictionary of label overrides. The dictionary
14
- should have the following format:
15
- {
16
- "nodes": {
17
- "node1": "CustomLabel1",
18
- "node2": "CustomLabel2",
19
- "__end__": "End Node"
20
- },
21
- "edges": {
22
- "continue": "ContinueLabel",
23
- "end": "EndLabel"
24
- }
25
- }
26
- The keys are the original labels, and the values are the new labels.
27
- Usage:
28
- drawer = PngDrawer()
29
- drawer.draw(state_graph, 'graph.png')
18
+ It requires ``graphviz`` and ``pygraphviz`` to be installed.
19
+
20
+ Example:
21
+
22
+ .. code-block:: python
23
+
24
+ drawer = PngDrawer()
25
+ drawer.draw(state_graph, "graph.png")
30
26
  """
31
27
 
32
28
  def __init__(
@@ -85,9 +81,6 @@ class PngDrawer:
85
81
  Args:
86
82
  viz: The graphviz object.
87
83
  node: The node to add.
88
-
89
- Returns:
90
- None
91
84
  """
92
85
  viz.add_node(
93
86
  node,
@@ -114,9 +107,6 @@ class PngDrawer:
114
107
  target: The target node.
115
108
  label: The label for the edge. Defaults to None.
116
109
  conditional: Whether the edge is conditional. Defaults to False.
117
-
118
- Returns:
119
- None
120
110
  """
121
111
  viz.add_edge(
122
112
  source,
@@ -131,14 +121,20 @@ class PngDrawer:
131
121
  """Draw the given state graph into a PNG file.
132
122
 
133
123
  Requires `graphviz` and `pygraphviz` to be installed.
134
- :param graph: The graph to draw
135
- :param output_path: The path to save the PNG. If None, PNG bytes are returned.
124
+
125
+ Args:
126
+ graph: The graph to draw
127
+ output_path: The path to save the PNG. If None, PNG bytes are returned.
128
+
129
+ Raises:
130
+ ImportError: If ``pygraphviz`` is not installed.
131
+
132
+ Returns:
133
+ The PNG bytes if ``output_path`` is None, else None.
136
134
  """
137
- try:
138
- import pygraphviz as pgv # type: ignore[import-not-found]
139
- except ImportError as exc:
135
+ if not _HAS_PYGRAPHVIZ:
140
136
  msg = "Install pygraphviz to draw graphs: `pip install pygraphviz`."
141
- raise ImportError(msg) from exc
137
+ raise ImportError(msg)
142
138
 
143
139
  # Create a directed graph
144
140
  viz = pgv.AGraph(directed=True, nodesep=0.9, ranksep=1.0)
@@ -18,6 +18,7 @@ from typing_extensions import override
18
18
 
19
19
  from langchain_core.chat_history import BaseChatMessageHistory
20
20
  from langchain_core.load.load import load
21
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
21
22
  from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
22
23
  from langchain_core.runnables.passthrough import RunnablePassthrough
23
24
  from langchain_core.runnables.utils import (
@@ -29,7 +30,6 @@ from langchain_core.utils.pydantic import create_model_v2
29
30
 
30
31
  if TYPE_CHECKING:
31
32
  from langchain_core.language_models.base import LanguageModelLike
32
- from langchain_core.messages.base import BaseMessage
33
33
  from langchain_core.runnables.config import RunnableConfig
34
34
  from langchain_core.tracers.schemas import Run
35
35
 
@@ -72,20 +72,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
72
72
  For production use cases, you will want to use a persistent implementation
73
73
  of chat message history, such as ``RedisChatMessageHistory``.
74
74
 
75
- Parameters:
76
- get_session_history: Function that returns a new BaseChatMessageHistory.
77
- This function should either take a single positional argument
78
- `session_id` of type string and return a corresponding
79
- chat message history instance.
80
- input_messages_key: Must be specified if the base runnable accepts a dict
81
- as input. The key in the input dict that contains the messages.
82
- output_messages_key: Must be specified if the base Runnable returns a dict
83
- as output. The key in the output dict that contains the messages.
84
- history_messages_key: Must be specified if the base runnable accepts a dict
85
- as input and expects a separate key for historical messages.
86
- history_factory_config: Configure fields that should be passed to the
87
- chat history factory. See ``ConfigurableFieldSpec`` for more details.
88
-
89
75
  Example: Chat message history with an in-memory implementation for testing.
90
76
 
91
77
  .. code-block:: python
@@ -145,11 +131,13 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
145
131
  from langchain_core.runnables.history import RunnableWithMessageHistory
146
132
 
147
133
 
148
- prompt = ChatPromptTemplate.from_messages([
149
- ("system", "You're an assistant who's good at {ability}"),
150
- MessagesPlaceholder(variable_name="history"),
151
- ("human", "{question}"),
152
- ])
134
+ prompt = ChatPromptTemplate.from_messages(
135
+ [
136
+ ("system", "You're an assistant who's good at {ability}"),
137
+ MessagesPlaceholder(variable_name="history"),
138
+ ("human", "{question}"),
139
+ ]
140
+ )
153
141
 
154
142
  chain = prompt | ChatAnthropic(model="claude-2")
155
143
 
@@ -162,18 +150,22 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
162
150
  history_messages_key="history",
163
151
  )
164
152
 
165
- print(chain_with_history.invoke( # noqa: T201
166
- {"ability": "math", "question": "What does cosine mean?"},
167
- config={"configurable": {"session_id": "foo"}}
168
- ))
153
+ print(
154
+ chain_with_history.invoke( # noqa: T201
155
+ {"ability": "math", "question": "What does cosine mean?"},
156
+ config={"configurable": {"session_id": "foo"}},
157
+ )
158
+ )
169
159
 
170
160
  # Uses the store defined in the example above.
171
161
  print(store) # noqa: T201
172
162
 
173
- print(chain_with_history.invoke( # noqa: T201
174
- {"ability": "math", "question": "What's its inverse"},
175
- config={"configurable": {"session_id": "foo"}}
176
- ))
163
+ print(
164
+ chain_with_history.invoke( # noqa: T201
165
+ {"ability": "math", "question": "What's its inverse"},
166
+ config={"configurable": {"session_id": "foo"}},
167
+ )
168
+ )
177
169
 
178
170
  print(store) # noqa: T201
179
171
 
@@ -184,6 +176,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
184
176
 
185
177
  store = {}
186
178
 
179
+
187
180
  def get_session_history(
188
181
  user_id: str, conversation_id: str
189
182
  ) -> BaseChatMessageHistory:
@@ -191,11 +184,14 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
191
184
  store[(user_id, conversation_id)] = InMemoryHistory()
192
185
  return store[(user_id, conversation_id)]
193
186
 
194
- prompt = ChatPromptTemplate.from_messages([
195
- ("system", "You're an assistant who's good at {ability}"),
196
- MessagesPlaceholder(variable_name="history"),
197
- ("human", "{question}"),
198
- ])
187
+
188
+ prompt = ChatPromptTemplate.from_messages(
189
+ [
190
+ ("system", "You're an assistant who's good at {ability}"),
191
+ MessagesPlaceholder(variable_name="history"),
192
+ ("human", "{question}"),
193
+ ]
194
+ )
199
195
 
200
196
  chain = prompt | ChatAnthropic(model="claude-2")
201
197
 
@@ -226,16 +222,27 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
226
222
 
227
223
  with_message_history.invoke(
228
224
  {"ability": "math", "question": "What does cosine mean?"},
229
- config={"configurable": {"user_id": "123", "conversation_id": "1"}}
225
+ config={"configurable": {"user_id": "123", "conversation_id": "1"}},
230
226
  )
231
227
 
232
228
  """
233
229
 
234
230
  get_session_history: GetSessionHistoryCallable
231
+ """Function that returns a new BaseChatMessageHistory.
232
+ This function should either take a single positional argument ``session_id`` of type
233
+ string and return a corresponding chat message history instance"""
235
234
  input_messages_key: Optional[str] = None
235
+ """Must be specified if the base runnable accepts a dict as input.
236
+ The key in the input dict that contains the messages."""
236
237
  output_messages_key: Optional[str] = None
238
+ """Must be specified if the base Runnable returns a dict as output.
239
+ The key in the output dict that contains the messages."""
237
240
  history_messages_key: Optional[str] = None
241
+ """Must be specified if the base runnable accepts a dict as input and expects a
242
+ separate key for historical messages."""
238
243
  history_factory_config: Sequence[ConfigurableFieldSpec]
244
+ """Configure fields that should be passed to the chat history factory.
245
+ See ``ConfigurableFieldSpec`` for more details."""
239
246
 
240
247
  def __init__(
241
248
  self,
@@ -261,17 +268,21 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
261
268
  """Initialize RunnableWithMessageHistory.
262
269
 
263
270
  Args:
264
- runnable: The base Runnable to be wrapped. Must take as input one of:
265
- 1. A list of BaseMessages
271
+ runnable: The base Runnable to be wrapped.
272
+ Must take as input one of:
273
+
274
+ 1. A list of ``BaseMessage``
266
275
  2. A dict with one key for all messages
267
276
  3. A dict with one key for the current input string/message(s) and
268
- a separate key for historical messages. If the input key points
269
- to a string, it will be treated as a HumanMessage in history.
277
+ a separate key for historical messages. If the input key points
278
+ to a string, it will be treated as a ``HumanMessage`` in history.
270
279
 
271
280
  Must return as output one of:
272
- 1. A string which can be treated as an AIMessage
273
- 2. A BaseMessage or sequence of BaseMessages
274
- 3. A dict with a key for a BaseMessage or sequence of BaseMessages
281
+
282
+ 1. A string which can be treated as an ``AIMessage``
283
+ 2. A ``BaseMessage`` or sequence of ``BaseMessage``
284
+ 3. A dict with a key for a ``BaseMessage`` or sequence of
285
+ ``BaseMessage``
275
286
 
276
287
  get_session_history: Function that returns a new BaseChatMessageHistory.
277
288
  This function should either take a single positional argument
@@ -280,11 +291,8 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
280
291
  .. code-block:: python
281
292
 
282
293
  def get_session_history(
283
- session_id: str,
284
- *,
285
- user_id: Optional[str]=None
286
- ) -> BaseChatMessageHistory:
287
- ...
294
+ session_id: str, *, user_id: Optional[str] = None
295
+ ) -> BaseChatMessageHistory: ...
288
296
 
289
297
  Or it should take keyword arguments that match the keys of
290
298
  `session_history_config_specs` and return a corresponding
@@ -296,8 +304,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
296
304
  *,
297
305
  user_id: str,
298
306
  thread_id: str,
299
- ) -> BaseChatMessageHistory:
300
- ...
307
+ ) -> BaseChatMessageHistory: ...
301
308
 
302
309
  input_messages_key: Must be specified if the base runnable accepts a dict
303
310
  as input. Default is None.
@@ -377,8 +384,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
377
384
  def get_input_schema(
378
385
  self, config: Optional[RunnableConfig] = None
379
386
  ) -> type[BaseModel]:
380
- from langchain_core.messages import BaseMessage
381
-
382
387
  fields: dict = {}
383
388
  if self.input_messages_key and self.history_messages_key:
384
389
  fields[self.input_messages_key] = (
@@ -440,8 +445,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
440
445
  def _get_input_messages(
441
446
  self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
442
447
  ) -> list[BaseMessage]:
443
- from langchain_core.messages import BaseMessage
444
-
445
448
  # If dictionary, try to pluck the single key representing messages
446
449
  if isinstance(input_val, dict):
447
450
  if self.input_messages_key:
@@ -454,8 +457,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
454
457
 
455
458
  # If value is a string, convert to a human message
456
459
  if isinstance(input_val, str):
457
- from langchain_core.messages import HumanMessage
458
-
459
460
  return [HumanMessage(content=input_val)]
460
461
  # If value is a single message, convert to a list
461
462
  if isinstance(input_val, BaseMessage):
@@ -482,8 +483,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
482
483
  def _get_output_messages(
483
484
  self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
484
485
  ) -> list[BaseMessage]:
485
- from langchain_core.messages import BaseMessage
486
-
487
486
  # If dictionary, try to pluck the single key representing messages
488
487
  if isinstance(output_val, dict):
489
488
  if self.output_messages_key:
@@ -500,8 +499,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
500
499
  output_val = output_val[key]
501
500
 
502
501
  if isinstance(output_val, str):
503
- from langchain_core.messages import AIMessage
504
-
505
502
  return [AIMessage(content=output_val)]
506
503
  # If value is a single message, convert to a list
507
504
  if isinstance(output_val, BaseMessage):