langchain-core 0.4.0.dev0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain-core might be problematic; see the package's registry page for more details.

Files changed (172)
  1. langchain_core/__init__.py +1 -1
  2. langchain_core/_api/__init__.py +3 -4
  3. langchain_core/_api/beta_decorator.py +45 -70
  4. langchain_core/_api/deprecation.py +80 -80
  5. langchain_core/_api/path.py +22 -8
  6. langchain_core/_import_utils.py +10 -4
  7. langchain_core/agents.py +25 -21
  8. langchain_core/caches.py +53 -63
  9. langchain_core/callbacks/__init__.py +1 -8
  10. langchain_core/callbacks/base.py +341 -348
  11. langchain_core/callbacks/file.py +55 -44
  12. langchain_core/callbacks/manager.py +546 -683
  13. langchain_core/callbacks/stdout.py +29 -30
  14. langchain_core/callbacks/streaming_stdout.py +35 -36
  15. langchain_core/callbacks/usage.py +65 -70
  16. langchain_core/chat_history.py +48 -55
  17. langchain_core/document_loaders/base.py +46 -21
  18. langchain_core/document_loaders/langsmith.py +39 -36
  19. langchain_core/documents/__init__.py +0 -1
  20. langchain_core/documents/base.py +96 -74
  21. langchain_core/documents/compressor.py +12 -9
  22. langchain_core/documents/transformers.py +29 -28
  23. langchain_core/embeddings/fake.py +56 -57
  24. langchain_core/env.py +2 -3
  25. langchain_core/example_selectors/base.py +12 -0
  26. langchain_core/example_selectors/length_based.py +1 -1
  27. langchain_core/example_selectors/semantic_similarity.py +21 -25
  28. langchain_core/exceptions.py +15 -9
  29. langchain_core/globals.py +4 -163
  30. langchain_core/indexing/api.py +132 -125
  31. langchain_core/indexing/base.py +64 -67
  32. langchain_core/indexing/in_memory.py +26 -6
  33. langchain_core/language_models/__init__.py +15 -27
  34. langchain_core/language_models/_utils.py +267 -117
  35. langchain_core/language_models/base.py +92 -177
  36. langchain_core/language_models/chat_models.py +547 -407
  37. langchain_core/language_models/fake.py +11 -11
  38. langchain_core/language_models/fake_chat_models.py +72 -118
  39. langchain_core/language_models/llms.py +168 -242
  40. langchain_core/load/dump.py +8 -11
  41. langchain_core/load/load.py +32 -28
  42. langchain_core/load/mapping.py +2 -4
  43. langchain_core/load/serializable.py +50 -56
  44. langchain_core/messages/__init__.py +36 -51
  45. langchain_core/messages/ai.py +377 -150
  46. langchain_core/messages/base.py +239 -47
  47. langchain_core/messages/block_translators/__init__.py +111 -0
  48. langchain_core/messages/block_translators/anthropic.py +470 -0
  49. langchain_core/messages/block_translators/bedrock.py +94 -0
  50. langchain_core/messages/block_translators/bedrock_converse.py +297 -0
  51. langchain_core/messages/block_translators/google_genai.py +530 -0
  52. langchain_core/messages/block_translators/google_vertexai.py +21 -0
  53. langchain_core/messages/block_translators/groq.py +143 -0
  54. langchain_core/messages/block_translators/langchain_v0.py +301 -0
  55. langchain_core/messages/block_translators/openai.py +1010 -0
  56. langchain_core/messages/chat.py +2 -3
  57. langchain_core/messages/content.py +1423 -0
  58. langchain_core/messages/function.py +7 -7
  59. langchain_core/messages/human.py +44 -38
  60. langchain_core/messages/modifier.py +3 -2
  61. langchain_core/messages/system.py +40 -27
  62. langchain_core/messages/tool.py +160 -58
  63. langchain_core/messages/utils.py +527 -638
  64. langchain_core/output_parsers/__init__.py +1 -14
  65. langchain_core/output_parsers/base.py +68 -104
  66. langchain_core/output_parsers/json.py +13 -17
  67. langchain_core/output_parsers/list.py +11 -33
  68. langchain_core/output_parsers/openai_functions.py +56 -74
  69. langchain_core/output_parsers/openai_tools.py +68 -109
  70. langchain_core/output_parsers/pydantic.py +15 -13
  71. langchain_core/output_parsers/string.py +6 -2
  72. langchain_core/output_parsers/transform.py +17 -60
  73. langchain_core/output_parsers/xml.py +34 -44
  74. langchain_core/outputs/__init__.py +1 -1
  75. langchain_core/outputs/chat_generation.py +26 -11
  76. langchain_core/outputs/chat_result.py +1 -3
  77. langchain_core/outputs/generation.py +17 -6
  78. langchain_core/outputs/llm_result.py +15 -8
  79. langchain_core/prompt_values.py +29 -123
  80. langchain_core/prompts/__init__.py +3 -27
  81. langchain_core/prompts/base.py +48 -63
  82. langchain_core/prompts/chat.py +259 -288
  83. langchain_core/prompts/dict.py +19 -11
  84. langchain_core/prompts/few_shot.py +84 -90
  85. langchain_core/prompts/few_shot_with_templates.py +14 -12
  86. langchain_core/prompts/image.py +19 -14
  87. langchain_core/prompts/loading.py +6 -8
  88. langchain_core/prompts/message.py +7 -8
  89. langchain_core/prompts/prompt.py +42 -43
  90. langchain_core/prompts/string.py +37 -16
  91. langchain_core/prompts/structured.py +43 -46
  92. langchain_core/rate_limiters.py +51 -60
  93. langchain_core/retrievers.py +52 -192
  94. langchain_core/runnables/base.py +1727 -1683
  95. langchain_core/runnables/branch.py +52 -73
  96. langchain_core/runnables/config.py +89 -103
  97. langchain_core/runnables/configurable.py +128 -130
  98. langchain_core/runnables/fallbacks.py +93 -82
  99. langchain_core/runnables/graph.py +127 -127
  100. langchain_core/runnables/graph_ascii.py +63 -41
  101. langchain_core/runnables/graph_mermaid.py +87 -70
  102. langchain_core/runnables/graph_png.py +31 -36
  103. langchain_core/runnables/history.py +145 -161
  104. langchain_core/runnables/passthrough.py +141 -144
  105. langchain_core/runnables/retry.py +84 -68
  106. langchain_core/runnables/router.py +33 -37
  107. langchain_core/runnables/schema.py +79 -72
  108. langchain_core/runnables/utils.py +95 -139
  109. langchain_core/stores.py +85 -131
  110. langchain_core/structured_query.py +11 -15
  111. langchain_core/sys_info.py +31 -32
  112. langchain_core/tools/__init__.py +1 -14
  113. langchain_core/tools/base.py +221 -247
  114. langchain_core/tools/convert.py +144 -161
  115. langchain_core/tools/render.py +10 -10
  116. langchain_core/tools/retriever.py +12 -19
  117. langchain_core/tools/simple.py +52 -29
  118. langchain_core/tools/structured.py +56 -60
  119. langchain_core/tracers/__init__.py +1 -9
  120. langchain_core/tracers/_streaming.py +6 -7
  121. langchain_core/tracers/base.py +103 -112
  122. langchain_core/tracers/context.py +29 -48
  123. langchain_core/tracers/core.py +142 -105
  124. langchain_core/tracers/evaluation.py +30 -34
  125. langchain_core/tracers/event_stream.py +162 -117
  126. langchain_core/tracers/langchain.py +34 -36
  127. langchain_core/tracers/log_stream.py +87 -49
  128. langchain_core/tracers/memory_stream.py +3 -3
  129. langchain_core/tracers/root_listeners.py +18 -34
  130. langchain_core/tracers/run_collector.py +8 -20
  131. langchain_core/tracers/schemas.py +0 -125
  132. langchain_core/tracers/stdout.py +3 -3
  133. langchain_core/utils/__init__.py +1 -4
  134. langchain_core/utils/_merge.py +47 -9
  135. langchain_core/utils/aiter.py +70 -66
  136. langchain_core/utils/env.py +12 -9
  137. langchain_core/utils/function_calling.py +139 -206
  138. langchain_core/utils/html.py +7 -8
  139. langchain_core/utils/input.py +6 -6
  140. langchain_core/utils/interactive_env.py +6 -2
  141. langchain_core/utils/iter.py +48 -45
  142. langchain_core/utils/json.py +14 -4
  143. langchain_core/utils/json_schema.py +159 -43
  144. langchain_core/utils/mustache.py +32 -25
  145. langchain_core/utils/pydantic.py +67 -40
  146. langchain_core/utils/strings.py +5 -5
  147. langchain_core/utils/usage.py +1 -1
  148. langchain_core/utils/utils.py +104 -62
  149. langchain_core/vectorstores/base.py +131 -179
  150. langchain_core/vectorstores/in_memory.py +113 -182
  151. langchain_core/vectorstores/utils.py +23 -17
  152. langchain_core/version.py +1 -1
  153. langchain_core-1.0.0.dist-info/METADATA +68 -0
  154. langchain_core-1.0.0.dist-info/RECORD +172 -0
  155. {langchain_core-0.4.0.dev0.dist-info → langchain_core-1.0.0.dist-info}/WHEEL +1 -1
  156. langchain_core/beta/__init__.py +0 -1
  157. langchain_core/beta/runnables/__init__.py +0 -1
  158. langchain_core/beta/runnables/context.py +0 -448
  159. langchain_core/memory.py +0 -116
  160. langchain_core/messages/content_blocks.py +0 -1435
  161. langchain_core/prompts/pipeline.py +0 -133
  162. langchain_core/pydantic_v1/__init__.py +0 -30
  163. langchain_core/pydantic_v1/dataclasses.py +0 -23
  164. langchain_core/pydantic_v1/main.py +0 -23
  165. langchain_core/tracers/langchain_v1.py +0 -23
  166. langchain_core/utils/loading.py +0 -31
  167. langchain_core/v1/__init__.py +0 -1
  168. langchain_core/v1/chat_models.py +0 -1047
  169. langchain_core/v1/messages.py +0 -755
  170. langchain_core-0.4.0.dev0.dist-info/METADATA +0 -108
  171. langchain_core-0.4.0.dev0.dist-info/RECORD +0 -177
  172. langchain_core-0.4.0.dev0.dist-info/entry_points.txt +0 -4
@@ -3,12 +3,24 @@
3
3
  Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py.
4
4
  """
5
5
 
6
+ from __future__ import annotations
7
+
6
8
  import math
7
9
  import os
8
10
  from collections.abc import Mapping, Sequence
9
- from typing import Any
11
+ from typing import TYPE_CHECKING, Any
12
+
13
+ try:
14
+ from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import-untyped]
15
+ from grandalf.layouts import SugiyamaLayout # type: ignore[import-untyped]
16
+ from grandalf.routing import route_with_lines # type: ignore[import-untyped]
17
+
18
+ _HAS_GRANDALF = True
19
+ except ImportError:
20
+ _HAS_GRANDALF = False
10
21
 
11
- from langchain_core.runnables.graph import Edge as LangEdge
22
+ if TYPE_CHECKING:
23
+ from langchain_core.runnables.graph import Edge as LangEdge
12
24
 
13
25
 
14
26
  class VertexViewer:
@@ -50,8 +62,11 @@ class AsciiCanvas:
50
62
  """Create an ASCII canvas.
51
63
 
52
64
  Args:
53
- cols (int): number of columns in the canvas. Should be > 1.
54
- lines (int): number of lines in the canvas. Should be > 1.
65
+ cols: number of columns in the canvas. Should be `> 1`.
66
+ lines: number of lines in the canvas. Should be `> 1`.
67
+
68
+ Raises:
69
+ ValueError: if canvas dimensions are invalid.
55
70
  """
56
71
  if cols <= 1 or lines <= 1:
57
72
  msg = "Canvas dimensions should be > 1"
@@ -63,7 +78,11 @@ class AsciiCanvas:
63
78
  self.canvas = [[" "] * cols for line in range(lines)]
64
79
 
65
80
  def draw(self) -> str:
66
- """Draws ASCII canvas on the screen."""
81
+ """Draws ASCII canvas on the screen.
82
+
83
+ Returns:
84
+ The ASCII canvas string.
85
+ """
67
86
  lines = map("".join, self.canvas)
68
87
  return os.linesep.join(lines)
69
88
 
@@ -71,12 +90,16 @@ class AsciiCanvas:
71
90
  """Create a point on ASCII canvas.
72
91
 
73
92
  Args:
74
- x (int): x coordinate. Should be >= 0 and < number of columns in
93
+ x: x coordinate. Should be `>= 0` and `<` number of columns in
75
94
  the canvas.
76
- y (int): y coordinate. Should be >= 0 an < number of lines in the
95
+ y: y coordinate. Should be `>= 0` an `<` number of lines in the
77
96
  canvas.
78
- char (str): character to place in the specified point on the
97
+ char: character to place in the specified point on the
79
98
  canvas.
99
+
100
+ Raises:
101
+ ValueError: if char is not a single character or if
102
+ coordinates are out of bounds.
80
103
  """
81
104
  if len(char) != 1:
82
105
  msg = "char should be a single character"
@@ -94,11 +117,11 @@ class AsciiCanvas:
94
117
  """Create a line on ASCII canvas.
95
118
 
96
119
  Args:
97
- x0 (int): x coordinate where the line should start.
98
- y0 (int): y coordinate where the line should start.
99
- x1 (int): x coordinate where the line should end.
100
- y1 (int): y coordinate where the line should end.
101
- char (str): character to draw the line with.
120
+ x0: x coordinate where the line should start.
121
+ y0: y coordinate where the line should start.
122
+ x1: x coordinate where the line should end.
123
+ y1: y coordinate where the line should end.
124
+ char: character to draw the line with.
102
125
  """
103
126
  if x0 > x1:
104
127
  x1, x0 = x0, x1
@@ -126,9 +149,9 @@ class AsciiCanvas:
126
149
  """Print a text on ASCII canvas.
127
150
 
128
151
  Args:
129
- x (int): x coordinate where the text should start.
130
- y (int): y coordinate where the text should start.
131
- text (str): string that should be printed.
152
+ x: x coordinate where the text should start.
153
+ y: y coordinate where the text should start.
154
+ text: string that should be printed.
132
155
  """
133
156
  for i, char in enumerate(text):
134
157
  self.point(x + i, y, char)
@@ -137,10 +160,10 @@ class AsciiCanvas:
137
160
  """Create a box on ASCII canvas.
138
161
 
139
162
  Args:
140
- x0 (int): x coordinate of the box corner.
141
- y0 (int): y coordinate of the box corner.
142
- width (int): box width.
143
- height (int): box height.
163
+ x0: x coordinate of the box corner.
164
+ y0: y coordinate of the box corner.
165
+ width: box width.
166
+ height: box height.
144
167
  """
145
168
  if width <= 1 or height <= 1:
146
169
  msg = "Box dimensions should be > 1"
@@ -174,13 +197,9 @@ class _EdgeViewer:
174
197
  def _build_sugiyama_layout(
175
198
  vertices: Mapping[str, str], edges: Sequence[LangEdge]
176
199
  ) -> Any:
177
- try:
178
- from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import-untyped]
179
- from grandalf.layouts import SugiyamaLayout # type: ignore[import-untyped]
180
- from grandalf.routing import route_with_lines # type: ignore[import-untyped]
181
- except ImportError as exc:
200
+ if not _HAS_GRANDALF:
182
201
  msg = "Install grandalf to draw graphs: `pip install grandalf`."
183
- raise ImportError(msg) from exc
202
+ raise ImportError(msg)
184
203
 
185
204
  #
186
205
  # Just a reminder about naming conventions:
@@ -225,28 +244,31 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
225
244
  """Build a DAG and draw it in ASCII.
226
245
 
227
246
  Args:
228
- vertices (list): list of graph vertices.
229
- edges (list): list of graph edges.
247
+ vertices: list of graph vertices.
248
+ edges: list of graph edges.
249
+
250
+ Raises:
251
+ ValueError: if the canvas dimensions are invalid or if
252
+ edge coordinates are invalid.
230
253
 
231
254
  Returns:
232
- str: ASCII representation
255
+ ASCII representation
233
256
 
234
257
  Example:
258
+ ```python
259
+ from langchain_core.runnables.graph_ascii import draw_ascii
235
260
 
236
- .. code-block:: python
237
-
238
- from langchain_core.runnables.graph_ascii import draw_ascii
261
+ vertices = {1: "1", 2: "2", 3: "3", 4: "4"}
262
+ edges = [
263
+ (source, target, None, None)
264
+ for source, target in [(1, 2), (2, 3), (2, 4), (1, 4)]
265
+ ]
239
266
 
240
- vertices = {1: "1", 2: "2", 3: "3", 4: "4"}
241
- edges = [
242
- (source, target, None, None)
243
- for source, target in [(1, 2), (2, 3), (2, 4), (1, 4)]
244
- ]
245
267
 
268
+ print(draw_ascii(vertices, edges))
269
+ ```
246
270
 
247
- print(draw_ascii(vertices, edges))
248
-
249
- .. code-block:: none
271
+ ```txt
250
272
 
251
273
  +---+
252
274
  | 1 |
@@ -263,7 +285,7 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
263
285
  +---+ +---+
264
286
  | 3 | | 4 |
265
287
  +---+ +---+
266
-
288
+ ```
267
289
  """
268
290
  # NOTE: coordinates might me negative, so we need to shift
269
291
  # everything to the positive plane before we actually draw it.
@@ -1,24 +1,43 @@
1
1
  """Mermaid graph drawing utilities."""
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  import asyncio
4
6
  import base64
5
7
  import random
6
8
  import re
9
+ import string
7
10
  import time
8
11
  from dataclasses import asdict
9
12
  from pathlib import Path
10
- from typing import Any, Literal, Optional
13
+ from typing import TYPE_CHECKING, Any, Literal
11
14
 
12
15
  import yaml
13
16
 
14
17
  from langchain_core.runnables.graph import (
15
18
  CurveStyle,
16
- Edge,
17
19
  MermaidDrawMethod,
18
- Node,
19
20
  NodeStyles,
20
21
  )
21
22
 
23
+ if TYPE_CHECKING:
24
+ from langchain_core.runnables.graph import Edge, Node
25
+
26
+
27
+ try:
28
+ import requests
29
+
30
+ _HAS_REQUESTS = True
31
+ except ImportError:
32
+ _HAS_REQUESTS = False
33
+
34
+ try:
35
+ from pyppeteer import launch # type: ignore[import-not-found]
36
+
37
+ _HAS_PYPPETEER = True
38
+ except ImportError:
39
+ _HAS_PYPPETEER = False
40
+
22
41
  MARKDOWN_SPECIAL_CHARS = "*_`"
23
42
 
24
43
 
@@ -26,50 +45,44 @@ def draw_mermaid(
26
45
  nodes: dict[str, Node],
27
46
  edges: list[Edge],
28
47
  *,
29
- first_node: Optional[str] = None,
30
- last_node: Optional[str] = None,
48
+ first_node: str | None = None,
49
+ last_node: str | None = None,
31
50
  with_styles: bool = True,
32
51
  curve_style: CurveStyle = CurveStyle.LINEAR,
33
- node_styles: Optional[NodeStyles] = None,
52
+ node_styles: NodeStyles | None = None,
34
53
  wrap_label_n_words: int = 9,
35
- frontmatter_config: Optional[dict[str, Any]] = None,
54
+ frontmatter_config: dict[str, Any] | None = None,
36
55
  ) -> str:
37
56
  """Draws a Mermaid graph using the provided graph data.
38
57
 
39
58
  Args:
40
- nodes (dict[str, str]): List of node ids.
41
- edges (list[Edge]): List of edges, object with a source,
42
- target and data.
43
- first_node (str, optional): Id of the first node. Defaults to None.
44
- last_node (str, optional): Id of the last node. Defaults to None.
45
- with_styles (bool, optional): Whether to include styles in the graph.
46
- Defaults to True.
47
- curve_style (CurveStyle, optional): Curve style for the edges.
48
- Defaults to CurveStyle.LINEAR.
49
- node_styles (NodeStyles, optional): Node colors for different types.
50
- Defaults to NodeStyles().
51
- wrap_label_n_words (int, optional): Words to wrap the edge labels.
52
- Defaults to 9.
53
- frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
59
+ nodes: List of node ids.
60
+ edges: List of edges, object with a source, target and data.
61
+ first_node: Id of the first node.
62
+ last_node: Id of the last node.
63
+ with_styles: Whether to include styles in the graph.
64
+ curve_style: Curve style for the edges.
65
+ node_styles: Node colors for different types.
66
+ wrap_label_n_words: Words to wrap the edge labels.
67
+ frontmatter_config: Mermaid frontmatter config.
54
68
  Can be used to customize theme and styles. Will be converted to YAML and
55
- added to the beginning of the mermaid graph. Defaults to None.
69
+ added to the beginning of the mermaid graph.
56
70
 
57
71
  See more here: https://mermaid.js.org/config/configuration.html.
58
72
 
59
73
  Example config:
60
74
 
61
- .. code-block:: python
62
-
75
+ ```python
63
76
  {
64
77
  "config": {
65
78
  "theme": "neutral",
66
79
  "look": "handDrawn",
67
- "themeVariables": { "primaryColor": "#e2e2e2"},
80
+ "themeVariables": {"primaryColor": "#e2e2e2"},
68
81
  }
69
82
  }
70
-
83
+ ```
71
84
  Returns:
72
- str: Mermaid graph syntax.
85
+ Mermaid graph syntax.
73
86
 
74
87
  """
75
88
  # Initialize Mermaid graph configuration
@@ -130,7 +143,7 @@ def draw_mermaid(
130
143
  + "</em></small>"
131
144
  )
132
145
  node_label = format_dict.get(key, format_dict[default_class_label]).format(
133
- _escape_node_label(key), label
146
+ _to_safe_id(key), label
134
147
  )
135
148
  return f"{indent}{node_label}\n"
136
149
 
@@ -145,7 +158,7 @@ def draw_mermaid(
145
158
  src_parts = edge.source.split(":")
146
159
  tgt_parts = edge.target.split(":")
147
160
  common_prefix = ":".join(
148
- src for src, tgt in zip(src_parts, tgt_parts) if src == tgt
161
+ src for src, tgt in zip(src_parts, tgt_parts, strict=False) if src == tgt
149
162
  )
150
163
  edge_groups.setdefault(common_prefix, []).append(edge)
151
164
 
@@ -155,7 +168,7 @@ def draw_mermaid(
155
168
  nonlocal mermaid_graph
156
169
  self_loop = len(edges) == 1 and edges[0].source == edges[0].target
157
170
  if prefix and not self_loop:
158
- subgraph = prefix.split(":")[-1]
171
+ subgraph = prefix.rsplit(":", maxsplit=1)[-1]
159
172
  if subgraph in seen_subgraphs:
160
173
  msg = (
161
174
  f"Found duplicate subgraph '{subgraph}' -- this likely means that "
@@ -193,8 +206,7 @@ def draw_mermaid(
193
206
  edge_label = " -.-> " if edge.conditional else " --> "
194
207
 
195
208
  mermaid_graph += (
196
- f"\t{_escape_node_label(source)}{edge_label}"
197
- f"{_escape_node_label(target)};\n"
209
+ f"\t{_to_safe_id(source)}{edge_label}{_to_safe_id(target)};\n"
198
210
  )
199
211
 
200
212
  # Recursively add nested subgraphs
@@ -214,7 +226,7 @@ def draw_mermaid(
214
226
 
215
227
  # Add remaining subgraphs with edges
216
228
  for prefix, edges_ in edge_groups.items():
217
- if ":" in prefix or prefix == "":
229
+ if not prefix or ":" in prefix:
218
230
  continue
219
231
  add_subgraph(edges_, prefix)
220
232
  seen_subgraphs.add(prefix)
@@ -238,9 +250,18 @@ def draw_mermaid(
238
250
  return mermaid_graph
239
251
 
240
252
 
241
- def _escape_node_label(node_label: str) -> str:
242
- """Escapes the node label for Mermaid syntax."""
243
- return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)
253
+ def _to_safe_id(label: str) -> str:
254
+ """Convert a string into a Mermaid-compatible node id.
255
+
256
+ Keep [a-zA-Z0-9_-] characters unchanged.
257
+ Map every other character -> backslash + lowercase hex codepoint.
258
+
259
+ Result is guaranteed to be unique and Mermaid-compatible,
260
+ so nodes with special characters always render correctly.
261
+ """
262
+ allowed = string.ascii_letters + string.digits + "_-"
263
+ out = [ch if ch in allowed else "\\" + format(ord(ch), "x") for ch in label]
264
+ return "".join(out)
244
265
 
245
266
 
246
267
  def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:
@@ -253,38 +274,33 @@ def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:
253
274
 
254
275
  def draw_mermaid_png(
255
276
  mermaid_syntax: str,
256
- output_file_path: Optional[str] = None,
277
+ output_file_path: str | None = None,
257
278
  draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
258
- background_color: Optional[str] = "white",
279
+ background_color: str | None = "white",
259
280
  padding: int = 10,
260
281
  max_retries: int = 1,
261
282
  retry_delay: float = 1.0,
283
+ base_url: str | None = None,
262
284
  ) -> bytes:
263
285
  """Draws a Mermaid graph as PNG using provided syntax.
264
286
 
265
287
  Args:
266
- mermaid_syntax (str): Mermaid graph syntax.
267
- output_file_path (str, optional): Path to save the PNG image.
268
- Defaults to None.
269
- draw_method (MermaidDrawMethod, optional): Method to draw the graph.
270
- Defaults to MermaidDrawMethod.API.
271
- background_color (str, optional): Background color of the image.
272
- Defaults to "white".
273
- padding (int, optional): Padding around the image. Defaults to 10.
274
- max_retries (int, optional): Maximum number of retries (MermaidDrawMethod.API).
275
- Defaults to 1.
276
- retry_delay (float, optional): Delay between retries (MermaidDrawMethod.API).
277
- Defaults to 1.0.
288
+ mermaid_syntax: Mermaid graph syntax.
289
+ output_file_path: Path to save the PNG image.
290
+ draw_method: Method to draw the graph.
291
+ background_color: Background color of the image.
292
+ padding: Padding around the image.
293
+ max_retries: Maximum number of retries (MermaidDrawMethod.API).
294
+ retry_delay: Delay between retries (MermaidDrawMethod.API).
295
+ base_url: Base URL for the Mermaid.ink API.
278
296
 
279
297
  Returns:
280
- bytes: PNG image bytes.
298
+ PNG image bytes.
281
299
 
282
300
  Raises:
283
301
  ValueError: If an invalid draw method is provided.
284
302
  """
285
303
  if draw_method == MermaidDrawMethod.PYPPETEER:
286
- import asyncio
287
-
288
304
  img_bytes = asyncio.run(
289
305
  _render_mermaid_using_pyppeteer(
290
306
  mermaid_syntax, output_file_path, background_color, padding
@@ -297,6 +313,7 @@ def draw_mermaid_png(
297
313
  background_color=background_color,
298
314
  max_retries=max_retries,
299
315
  retry_delay=retry_delay,
316
+ base_url=base_url,
300
317
  )
301
318
  else:
302
319
  supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
@@ -311,17 +328,15 @@ def draw_mermaid_png(
311
328
 
312
329
  async def _render_mermaid_using_pyppeteer(
313
330
  mermaid_syntax: str,
314
- output_file_path: Optional[str] = None,
315
- background_color: Optional[str] = "white",
331
+ output_file_path: str | None = None,
332
+ background_color: str | None = "white",
316
333
  padding: int = 10,
317
334
  device_scale_factor: int = 3,
318
335
  ) -> bytes:
319
336
  """Renders Mermaid graph using Pyppeteer."""
320
- try:
321
- from pyppeteer import launch # type: ignore[import-not-found]
322
- except ImportError as e:
337
+ if not _HAS_PYPPETEER:
323
338
  msg = "Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`."
324
- raise ImportError(msg) from e
339
+ raise ImportError(msg)
325
340
 
326
341
  browser = await launch()
327
342
  page = await browser.newPage()
@@ -385,21 +400,23 @@ async def _render_mermaid_using_pyppeteer(
385
400
  def _render_mermaid_using_api(
386
401
  mermaid_syntax: str,
387
402
  *,
388
- output_file_path: Optional[str] = None,
389
- background_color: Optional[str] = "white",
390
- file_type: Optional[Literal["jpeg", "png", "webp"]] = "png",
403
+ output_file_path: str | None = None,
404
+ background_color: str | None = "white",
405
+ file_type: Literal["jpeg", "png", "webp"] | None = "png",
391
406
  max_retries: int = 1,
392
407
  retry_delay: float = 1.0,
408
+ base_url: str | None = None,
393
409
  ) -> bytes:
394
410
  """Renders Mermaid graph using the Mermaid.INK API."""
395
- try:
396
- import requests
397
- except ImportError as e:
411
+ # Defaults to using the public mermaid.ink server.
412
+ base_url = base_url if base_url is not None else "https://mermaid.ink"
413
+
414
+ if not _HAS_REQUESTS:
398
415
  msg = (
399
416
  "Install the `requests` module to use the Mermaid.INK API: "
400
417
  "`pip install requests`."
401
418
  )
402
- raise ImportError(msg) from e
419
+ raise ImportError(msg)
403
420
 
404
421
  # Use Mermaid API to render the image
405
422
  mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode(
@@ -413,7 +430,7 @@ def _render_mermaid_using_api(
413
430
  background_color = f"!{background_color}"
414
431
 
415
432
  image_url = (
416
- f"https://mermaid.ink/img/{mermaid_syntax_encoded}"
433
+ f"{base_url}/img/{mermaid_syntax_encoded}"
417
434
  f"?type={file_type}&bgColor={background_color}"
418
435
  )
419
436
 
@@ -445,7 +462,7 @@ def _render_mermaid_using_api(
445
462
 
446
463
  # For other status codes, fail immediately
447
464
  msg = (
448
- "Failed to reach https://mermaid.ink/ API while trying to render "
465
+ f"Failed to reach {base_url} API while trying to render "
449
466
  f"your graph. Status code: {response.status_code}.\n\n"
450
467
  ) + error_msg_suffix
451
468
  raise ValueError(msg)
@@ -457,14 +474,14 @@ def _render_mermaid_using_api(
457
474
  time.sleep(sleep_time)
458
475
  else:
459
476
  msg = (
460
- "Failed to reach https://mermaid.ink/ API while trying to render "
477
+ f"Failed to reach {base_url} API while trying to render "
461
478
  f"your graph after {max_retries} retries. "
462
479
  ) + error_msg_suffix
463
480
  raise ValueError(msg) from e
464
481
 
465
482
  # This should not be reached, but just in case
466
483
  msg = (
467
- "Failed to reach https://mermaid.ink/ API while trying to render "
484
+ f"Failed to reach {base_url} API while trying to render "
468
485
  f"your graph after {max_retries} retries. "
469
486
  ) + error_msg_suffix
470
487
  raise ValueError(msg)
@@ -1,36 +1,31 @@
1
1
  """Helper class to draw a state graph into a PNG file."""
2
2
 
3
- from typing import Any, Optional
3
+ from typing import Any
4
4
 
5
5
  from langchain_core.runnables.graph import Graph, LabelsDict
6
6
 
7
+ try:
8
+ import pygraphviz as pgv # type: ignore[import-not-found]
9
+
10
+ _HAS_PYGRAPHVIZ = True
11
+ except ImportError:
12
+ _HAS_PYGRAPHVIZ = False
13
+
7
14
 
8
15
  class PngDrawer:
9
16
  """Helper class to draw a state graph into a PNG file.
10
17
 
11
18
  It requires `graphviz` and `pygraphviz` to be installed.
12
- :param fontname: The font to use for the labels
13
- :param labels: A dictionary of label overrides. The dictionary
14
- should have the following format:
15
- {
16
- "nodes": {
17
- "node1": "CustomLabel1",
18
- "node2": "CustomLabel2",
19
- "__end__": "End Node"
20
- },
21
- "edges": {
22
- "continue": "ContinueLabel",
23
- "end": "EndLabel"
24
- }
25
- }
26
- The keys are the original labels, and the values are the new labels.
27
- Usage:
19
+
20
+ Example:
21
+ ```python
28
22
  drawer = PngDrawer()
29
- drawer.draw(state_graph, 'graph.png')
23
+ drawer.draw(state_graph, "graph.png")
24
+ ```
30
25
  """
31
26
 
32
27
  def __init__(
33
- self, fontname: Optional[str] = None, labels: Optional[LabelsDict] = None
28
+ self, fontname: str | None = None, labels: LabelsDict | None = None
34
29
  ) -> None:
35
30
  """Initializes the PNG drawer.
36
31
 
@@ -50,7 +45,7 @@ class PngDrawer:
50
45
  }
51
46
  }
52
47
  The keys are the original labels, and the values are the new labels.
53
- Defaults to None.
48
+
54
49
  """
55
50
  self.fontname = fontname or "arial"
56
51
  self.labels = labels or LabelsDict(nodes={}, edges={})
@@ -85,9 +80,6 @@ class PngDrawer:
85
80
  Args:
86
81
  viz: The graphviz object.
87
82
  node: The node to add.
88
-
89
- Returns:
90
- None
91
83
  """
92
84
  viz.add_node(
93
85
  node,
@@ -103,7 +95,7 @@ class PngDrawer:
103
95
  viz: Any,
104
96
  source: str,
105
97
  target: str,
106
- label: Optional[str] = None,
98
+ label: str | None = None,
107
99
  conditional: bool = False, # noqa: FBT001,FBT002
108
100
  ) -> None:
109
101
  """Adds an edge to the graph.
@@ -112,11 +104,8 @@ class PngDrawer:
112
104
  viz: The graphviz object.
113
105
  source: The source node.
114
106
  target: The target node.
115
- label: The label for the edge. Defaults to None.
116
- conditional: Whether the edge is conditional. Defaults to False.
117
-
118
- Returns:
119
- None
107
+ label: The label for the edge.
108
+ conditional: Whether the edge is conditional.
120
109
  """
121
110
  viz.add_edge(
122
111
  source,
@@ -127,18 +116,24 @@ class PngDrawer:
127
116
  style="dotted" if conditional else "solid",
128
117
  )
129
118
 
130
- def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]:
119
+ def draw(self, graph: Graph, output_path: str | None = None) -> bytes | None:
131
120
  """Draw the given state graph into a PNG file.
132
121
 
133
122
  Requires `graphviz` and `pygraphviz` to be installed.
134
- :param graph: The graph to draw
135
- :param output_path: The path to save the PNG. If None, PNG bytes are returned.
123
+
124
+ Args:
125
+ graph: The graph to draw
126
+ output_path: The path to save the PNG. If `None`, PNG bytes are returned.
127
+
128
+ Raises:
129
+ ImportError: If `pygraphviz` is not installed.
130
+
131
+ Returns:
132
+ The PNG bytes if `output_path` is None, else None.
136
133
  """
137
- try:
138
- import pygraphviz as pgv # type: ignore[import-not-found]
139
- except ImportError as exc:
134
+ if not _HAS_PYGRAPHVIZ:
140
135
  msg = "Install pygraphviz to draw graphs: `pip install pygraphviz`."
141
- raise ImportError(msg) from exc
136
+ raise ImportError(msg)
142
137
 
143
138
  # Create a directed graph
144
139
  viz = pgv.AGraph(directed=True, nodesep=0.9, ranksep=1.0)