langchain-core 0.4.0.dev0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain-core might be problematic. Click here for more details.
- langchain_core/__init__.py +1 -1
- langchain_core/_api/__init__.py +3 -4
- langchain_core/_api/beta_decorator.py +45 -70
- langchain_core/_api/deprecation.py +80 -80
- langchain_core/_api/path.py +22 -8
- langchain_core/_import_utils.py +10 -4
- langchain_core/agents.py +25 -21
- langchain_core/caches.py +53 -63
- langchain_core/callbacks/__init__.py +1 -8
- langchain_core/callbacks/base.py +341 -348
- langchain_core/callbacks/file.py +55 -44
- langchain_core/callbacks/manager.py +546 -683
- langchain_core/callbacks/stdout.py +29 -30
- langchain_core/callbacks/streaming_stdout.py +35 -36
- langchain_core/callbacks/usage.py +65 -70
- langchain_core/chat_history.py +48 -55
- langchain_core/document_loaders/base.py +46 -21
- langchain_core/document_loaders/langsmith.py +39 -36
- langchain_core/documents/__init__.py +0 -1
- langchain_core/documents/base.py +96 -74
- langchain_core/documents/compressor.py +12 -9
- langchain_core/documents/transformers.py +29 -28
- langchain_core/embeddings/fake.py +56 -57
- langchain_core/env.py +2 -3
- langchain_core/example_selectors/base.py +12 -0
- langchain_core/example_selectors/length_based.py +1 -1
- langchain_core/example_selectors/semantic_similarity.py +21 -25
- langchain_core/exceptions.py +15 -9
- langchain_core/globals.py +4 -163
- langchain_core/indexing/api.py +132 -125
- langchain_core/indexing/base.py +64 -67
- langchain_core/indexing/in_memory.py +26 -6
- langchain_core/language_models/__init__.py +15 -27
- langchain_core/language_models/_utils.py +267 -117
- langchain_core/language_models/base.py +92 -177
- langchain_core/language_models/chat_models.py +547 -407
- langchain_core/language_models/fake.py +11 -11
- langchain_core/language_models/fake_chat_models.py +72 -118
- langchain_core/language_models/llms.py +168 -242
- langchain_core/load/dump.py +8 -11
- langchain_core/load/load.py +32 -28
- langchain_core/load/mapping.py +2 -4
- langchain_core/load/serializable.py +50 -56
- langchain_core/messages/__init__.py +36 -51
- langchain_core/messages/ai.py +377 -150
- langchain_core/messages/base.py +239 -47
- langchain_core/messages/block_translators/__init__.py +111 -0
- langchain_core/messages/block_translators/anthropic.py +470 -0
- langchain_core/messages/block_translators/bedrock.py +94 -0
- langchain_core/messages/block_translators/bedrock_converse.py +297 -0
- langchain_core/messages/block_translators/google_genai.py +530 -0
- langchain_core/messages/block_translators/google_vertexai.py +21 -0
- langchain_core/messages/block_translators/groq.py +143 -0
- langchain_core/messages/block_translators/langchain_v0.py +301 -0
- langchain_core/messages/block_translators/openai.py +1010 -0
- langchain_core/messages/chat.py +2 -3
- langchain_core/messages/content.py +1423 -0
- langchain_core/messages/function.py +7 -7
- langchain_core/messages/human.py +44 -38
- langchain_core/messages/modifier.py +3 -2
- langchain_core/messages/system.py +40 -27
- langchain_core/messages/tool.py +160 -58
- langchain_core/messages/utils.py +527 -638
- langchain_core/output_parsers/__init__.py +1 -14
- langchain_core/output_parsers/base.py +68 -104
- langchain_core/output_parsers/json.py +13 -17
- langchain_core/output_parsers/list.py +11 -33
- langchain_core/output_parsers/openai_functions.py +56 -74
- langchain_core/output_parsers/openai_tools.py +68 -109
- langchain_core/output_parsers/pydantic.py +15 -13
- langchain_core/output_parsers/string.py +6 -2
- langchain_core/output_parsers/transform.py +17 -60
- langchain_core/output_parsers/xml.py +34 -44
- langchain_core/outputs/__init__.py +1 -1
- langchain_core/outputs/chat_generation.py +26 -11
- langchain_core/outputs/chat_result.py +1 -3
- langchain_core/outputs/generation.py +17 -6
- langchain_core/outputs/llm_result.py +15 -8
- langchain_core/prompt_values.py +29 -123
- langchain_core/prompts/__init__.py +3 -27
- langchain_core/prompts/base.py +48 -63
- langchain_core/prompts/chat.py +259 -288
- langchain_core/prompts/dict.py +19 -11
- langchain_core/prompts/few_shot.py +84 -90
- langchain_core/prompts/few_shot_with_templates.py +14 -12
- langchain_core/prompts/image.py +19 -14
- langchain_core/prompts/loading.py +6 -8
- langchain_core/prompts/message.py +7 -8
- langchain_core/prompts/prompt.py +42 -43
- langchain_core/prompts/string.py +37 -16
- langchain_core/prompts/structured.py +43 -46
- langchain_core/rate_limiters.py +51 -60
- langchain_core/retrievers.py +52 -192
- langchain_core/runnables/base.py +1727 -1683
- langchain_core/runnables/branch.py +52 -73
- langchain_core/runnables/config.py +89 -103
- langchain_core/runnables/configurable.py +128 -130
- langchain_core/runnables/fallbacks.py +93 -82
- langchain_core/runnables/graph.py +127 -127
- langchain_core/runnables/graph_ascii.py +63 -41
- langchain_core/runnables/graph_mermaid.py +87 -70
- langchain_core/runnables/graph_png.py +31 -36
- langchain_core/runnables/history.py +145 -161
- langchain_core/runnables/passthrough.py +141 -144
- langchain_core/runnables/retry.py +84 -68
- langchain_core/runnables/router.py +33 -37
- langchain_core/runnables/schema.py +79 -72
- langchain_core/runnables/utils.py +95 -139
- langchain_core/stores.py +85 -131
- langchain_core/structured_query.py +11 -15
- langchain_core/sys_info.py +31 -32
- langchain_core/tools/__init__.py +1 -14
- langchain_core/tools/base.py +221 -247
- langchain_core/tools/convert.py +144 -161
- langchain_core/tools/render.py +10 -10
- langchain_core/tools/retriever.py +12 -19
- langchain_core/tools/simple.py +52 -29
- langchain_core/tools/structured.py +56 -60
- langchain_core/tracers/__init__.py +1 -9
- langchain_core/tracers/_streaming.py +6 -7
- langchain_core/tracers/base.py +103 -112
- langchain_core/tracers/context.py +29 -48
- langchain_core/tracers/core.py +142 -105
- langchain_core/tracers/evaluation.py +30 -34
- langchain_core/tracers/event_stream.py +162 -117
- langchain_core/tracers/langchain.py +34 -36
- langchain_core/tracers/log_stream.py +87 -49
- langchain_core/tracers/memory_stream.py +3 -3
- langchain_core/tracers/root_listeners.py +18 -34
- langchain_core/tracers/run_collector.py +8 -20
- langchain_core/tracers/schemas.py +0 -125
- langchain_core/tracers/stdout.py +3 -3
- langchain_core/utils/__init__.py +1 -4
- langchain_core/utils/_merge.py +47 -9
- langchain_core/utils/aiter.py +70 -66
- langchain_core/utils/env.py +12 -9
- langchain_core/utils/function_calling.py +139 -206
- langchain_core/utils/html.py +7 -8
- langchain_core/utils/input.py +6 -6
- langchain_core/utils/interactive_env.py +6 -2
- langchain_core/utils/iter.py +48 -45
- langchain_core/utils/json.py +14 -4
- langchain_core/utils/json_schema.py +159 -43
- langchain_core/utils/mustache.py +32 -25
- langchain_core/utils/pydantic.py +67 -40
- langchain_core/utils/strings.py +5 -5
- langchain_core/utils/usage.py +1 -1
- langchain_core/utils/utils.py +104 -62
- langchain_core/vectorstores/base.py +131 -179
- langchain_core/vectorstores/in_memory.py +113 -182
- langchain_core/vectorstores/utils.py +23 -17
- langchain_core/version.py +1 -1
- langchain_core-1.0.0.dist-info/METADATA +68 -0
- langchain_core-1.0.0.dist-info/RECORD +172 -0
- {langchain_core-0.4.0.dev0.dist-info → langchain_core-1.0.0.dist-info}/WHEEL +1 -1
- langchain_core/beta/__init__.py +0 -1
- langchain_core/beta/runnables/__init__.py +0 -1
- langchain_core/beta/runnables/context.py +0 -448
- langchain_core/memory.py +0 -116
- langchain_core/messages/content_blocks.py +0 -1435
- langchain_core/prompts/pipeline.py +0 -133
- langchain_core/pydantic_v1/__init__.py +0 -30
- langchain_core/pydantic_v1/dataclasses.py +0 -23
- langchain_core/pydantic_v1/main.py +0 -23
- langchain_core/tracers/langchain_v1.py +0 -23
- langchain_core/utils/loading.py +0 -31
- langchain_core/v1/__init__.py +0 -1
- langchain_core/v1/chat_models.py +0 -1047
- langchain_core/v1/messages.py +0 -755
- langchain_core-0.4.0.dev0.dist-info/METADATA +0 -108
- langchain_core-0.4.0.dev0.dist-info/RECORD +0 -177
- langchain_core-0.4.0.dev0.dist-info/entry_points.txt +0 -4
|
@@ -3,12 +3,24 @@
|
|
|
3
3
|
Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py.
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
6
8
|
import math
|
|
7
9
|
import os
|
|
8
10
|
from collections.abc import Mapping, Sequence
|
|
9
|
-
from typing import Any
|
|
11
|
+
from typing import TYPE_CHECKING, Any
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import-untyped]
|
|
15
|
+
from grandalf.layouts import SugiyamaLayout # type: ignore[import-untyped]
|
|
16
|
+
from grandalf.routing import route_with_lines # type: ignore[import-untyped]
|
|
17
|
+
|
|
18
|
+
_HAS_GRANDALF = True
|
|
19
|
+
except ImportError:
|
|
20
|
+
_HAS_GRANDALF = False
|
|
10
21
|
|
|
11
|
-
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from langchain_core.runnables.graph import Edge as LangEdge
|
|
12
24
|
|
|
13
25
|
|
|
14
26
|
class VertexViewer:
|
|
@@ -50,8 +62,11 @@ class AsciiCanvas:
|
|
|
50
62
|
"""Create an ASCII canvas.
|
|
51
63
|
|
|
52
64
|
Args:
|
|
53
|
-
cols
|
|
54
|
-
lines
|
|
65
|
+
cols: number of columns in the canvas. Should be `> 1`.
|
|
66
|
+
lines: number of lines in the canvas. Should be `> 1`.
|
|
67
|
+
|
|
68
|
+
Raises:
|
|
69
|
+
ValueError: if canvas dimensions are invalid.
|
|
55
70
|
"""
|
|
56
71
|
if cols <= 1 or lines <= 1:
|
|
57
72
|
msg = "Canvas dimensions should be > 1"
|
|
@@ -63,7 +78,11 @@ class AsciiCanvas:
|
|
|
63
78
|
self.canvas = [[" "] * cols for line in range(lines)]
|
|
64
79
|
|
|
65
80
|
def draw(self) -> str:
|
|
66
|
-
"""Draws ASCII canvas on the screen.
|
|
81
|
+
"""Draws ASCII canvas on the screen.
|
|
82
|
+
|
|
83
|
+
Returns:
|
|
84
|
+
The ASCII canvas string.
|
|
85
|
+
"""
|
|
67
86
|
lines = map("".join, self.canvas)
|
|
68
87
|
return os.linesep.join(lines)
|
|
69
88
|
|
|
@@ -71,12 +90,16 @@ class AsciiCanvas:
|
|
|
71
90
|
"""Create a point on ASCII canvas.
|
|
72
91
|
|
|
73
92
|
Args:
|
|
74
|
-
x
|
|
93
|
+
x: x coordinate. Should be `>= 0` and `<` number of columns in
|
|
75
94
|
the canvas.
|
|
76
|
-
y
|
|
95
|
+
y: y coordinate. Should be `>= 0` and `<` number of lines in the
|
|
77
96
|
canvas.
|
|
78
|
-
char
|
|
97
|
+
char: character to place in the specified point on the
|
|
79
98
|
canvas.
|
|
99
|
+
|
|
100
|
+
Raises:
|
|
101
|
+
ValueError: if char is not a single character or if
|
|
102
|
+
coordinates are out of bounds.
|
|
80
103
|
"""
|
|
81
104
|
if len(char) != 1:
|
|
82
105
|
msg = "char should be a single character"
|
|
@@ -94,11 +117,11 @@ class AsciiCanvas:
|
|
|
94
117
|
"""Create a line on ASCII canvas.
|
|
95
118
|
|
|
96
119
|
Args:
|
|
97
|
-
x0
|
|
98
|
-
y0
|
|
99
|
-
x1
|
|
100
|
-
y1
|
|
101
|
-
char
|
|
120
|
+
x0: x coordinate where the line should start.
|
|
121
|
+
y0: y coordinate where the line should start.
|
|
122
|
+
x1: x coordinate where the line should end.
|
|
123
|
+
y1: y coordinate where the line should end.
|
|
124
|
+
char: character to draw the line with.
|
|
102
125
|
"""
|
|
103
126
|
if x0 > x1:
|
|
104
127
|
x1, x0 = x0, x1
|
|
@@ -126,9 +149,9 @@ class AsciiCanvas:
|
|
|
126
149
|
"""Print a text on ASCII canvas.
|
|
127
150
|
|
|
128
151
|
Args:
|
|
129
|
-
x
|
|
130
|
-
y
|
|
131
|
-
text
|
|
152
|
+
x: x coordinate where the text should start.
|
|
153
|
+
y: y coordinate where the text should start.
|
|
154
|
+
text: string that should be printed.
|
|
132
155
|
"""
|
|
133
156
|
for i, char in enumerate(text):
|
|
134
157
|
self.point(x + i, y, char)
|
|
@@ -137,10 +160,10 @@ class AsciiCanvas:
|
|
|
137
160
|
"""Create a box on ASCII canvas.
|
|
138
161
|
|
|
139
162
|
Args:
|
|
140
|
-
x0
|
|
141
|
-
y0
|
|
142
|
-
width
|
|
143
|
-
height
|
|
163
|
+
x0: x coordinate of the box corner.
|
|
164
|
+
y0: y coordinate of the box corner.
|
|
165
|
+
width: box width.
|
|
166
|
+
height: box height.
|
|
144
167
|
"""
|
|
145
168
|
if width <= 1 or height <= 1:
|
|
146
169
|
msg = "Box dimensions should be > 1"
|
|
@@ -174,13 +197,9 @@ class _EdgeViewer:
|
|
|
174
197
|
def _build_sugiyama_layout(
|
|
175
198
|
vertices: Mapping[str, str], edges: Sequence[LangEdge]
|
|
176
199
|
) -> Any:
|
|
177
|
-
|
|
178
|
-
from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import-untyped]
|
|
179
|
-
from grandalf.layouts import SugiyamaLayout # type: ignore[import-untyped]
|
|
180
|
-
from grandalf.routing import route_with_lines # type: ignore[import-untyped]
|
|
181
|
-
except ImportError as exc:
|
|
200
|
+
if not _HAS_GRANDALF:
|
|
182
201
|
msg = "Install grandalf to draw graphs: `pip install grandalf`."
|
|
183
|
-
raise ImportError(msg)
|
|
202
|
+
raise ImportError(msg)
|
|
184
203
|
|
|
185
204
|
#
|
|
186
205
|
# Just a reminder about naming conventions:
|
|
@@ -225,28 +244,31 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
|
|
|
225
244
|
"""Build a DAG and draw it in ASCII.
|
|
226
245
|
|
|
227
246
|
Args:
|
|
228
|
-
vertices
|
|
229
|
-
edges
|
|
247
|
+
vertices: list of graph vertices.
|
|
248
|
+
edges: list of graph edges.
|
|
249
|
+
|
|
250
|
+
Raises:
|
|
251
|
+
ValueError: if the canvas dimensions are invalid or if
|
|
252
|
+
edge coordinates are invalid.
|
|
230
253
|
|
|
231
254
|
Returns:
|
|
232
|
-
|
|
255
|
+
ASCII representation
|
|
233
256
|
|
|
234
257
|
Example:
|
|
258
|
+
```python
|
|
259
|
+
from langchain_core.runnables.graph_ascii import draw_ascii
|
|
235
260
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
261
|
+
vertices = {1: "1", 2: "2", 3: "3", 4: "4"}
|
|
262
|
+
edges = [
|
|
263
|
+
(source, target, None, None)
|
|
264
|
+
for source, target in [(1, 2), (2, 3), (2, 4), (1, 4)]
|
|
265
|
+
]
|
|
239
266
|
|
|
240
|
-
vertices = {1: "1", 2: "2", 3: "3", 4: "4"}
|
|
241
|
-
edges = [
|
|
242
|
-
(source, target, None, None)
|
|
243
|
-
for source, target in [(1, 2), (2, 3), (2, 4), (1, 4)]
|
|
244
|
-
]
|
|
245
267
|
|
|
268
|
+
print(draw_ascii(vertices, edges))
|
|
269
|
+
```
|
|
246
270
|
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
.. code-block:: none
|
|
271
|
+
```txt
|
|
250
272
|
|
|
251
273
|
+---+
|
|
252
274
|
| 1 |
|
|
@@ -263,7 +285,7 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
|
|
|
263
285
|
+---+ +---+
|
|
264
286
|
| 3 | | 4 |
|
|
265
287
|
+---+ +---+
|
|
266
|
-
|
|
288
|
+
```
|
|
267
289
|
"""
|
|
268
290
|
# NOTE: coordinates might be negative, so we need to shift
|
|
269
291
|
# everything to the positive plane before we actually draw it.
|
|
@@ -1,24 +1,43 @@
|
|
|
1
1
|
"""Mermaid graph drawing utilities."""
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
3
5
|
import asyncio
|
|
4
6
|
import base64
|
|
5
7
|
import random
|
|
6
8
|
import re
|
|
9
|
+
import string
|
|
7
10
|
import time
|
|
8
11
|
from dataclasses import asdict
|
|
9
12
|
from pathlib import Path
|
|
10
|
-
from typing import Any, Literal
|
|
13
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
11
14
|
|
|
12
15
|
import yaml
|
|
13
16
|
|
|
14
17
|
from langchain_core.runnables.graph import (
|
|
15
18
|
CurveStyle,
|
|
16
|
-
Edge,
|
|
17
19
|
MermaidDrawMethod,
|
|
18
|
-
Node,
|
|
19
20
|
NodeStyles,
|
|
20
21
|
)
|
|
21
22
|
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from langchain_core.runnables.graph import Edge, Node
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
try:
|
|
28
|
+
import requests
|
|
29
|
+
|
|
30
|
+
_HAS_REQUESTS = True
|
|
31
|
+
except ImportError:
|
|
32
|
+
_HAS_REQUESTS = False
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
from pyppeteer import launch # type: ignore[import-not-found]
|
|
36
|
+
|
|
37
|
+
_HAS_PYPPETEER = True
|
|
38
|
+
except ImportError:
|
|
39
|
+
_HAS_PYPPETEER = False
|
|
40
|
+
|
|
22
41
|
MARKDOWN_SPECIAL_CHARS = "*_`"
|
|
23
42
|
|
|
24
43
|
|
|
@@ -26,50 +45,44 @@ def draw_mermaid(
|
|
|
26
45
|
nodes: dict[str, Node],
|
|
27
46
|
edges: list[Edge],
|
|
28
47
|
*,
|
|
29
|
-
first_node:
|
|
30
|
-
last_node:
|
|
48
|
+
first_node: str | None = None,
|
|
49
|
+
last_node: str | None = None,
|
|
31
50
|
with_styles: bool = True,
|
|
32
51
|
curve_style: CurveStyle = CurveStyle.LINEAR,
|
|
33
|
-
node_styles:
|
|
52
|
+
node_styles: NodeStyles | None = None,
|
|
34
53
|
wrap_label_n_words: int = 9,
|
|
35
|
-
frontmatter_config:
|
|
54
|
+
frontmatter_config: dict[str, Any] | None = None,
|
|
36
55
|
) -> str:
|
|
37
56
|
"""Draws a Mermaid graph using the provided graph data.
|
|
38
57
|
|
|
39
58
|
Args:
|
|
40
|
-
nodes
|
|
41
|
-
edges
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
node_styles (NodeStyles, optional): Node colors for different types.
|
|
50
|
-
Defaults to NodeStyles().
|
|
51
|
-
wrap_label_n_words (int, optional): Words to wrap the edge labels.
|
|
52
|
-
Defaults to 9.
|
|
53
|
-
frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
|
|
59
|
+
nodes: List of node ids.
|
|
60
|
+
edges: List of edges, object with a source, target and data.
|
|
61
|
+
first_node: Id of the first node.
|
|
62
|
+
last_node: Id of the last node.
|
|
63
|
+
with_styles: Whether to include styles in the graph.
|
|
64
|
+
curve_style: Curve style for the edges.
|
|
65
|
+
node_styles: Node colors for different types.
|
|
66
|
+
wrap_label_n_words: Words to wrap the edge labels.
|
|
67
|
+
frontmatter_config: Mermaid frontmatter config.
|
|
54
68
|
Can be used to customize theme and styles. Will be converted to YAML and
|
|
55
|
-
added to the beginning of the mermaid graph.
|
|
69
|
+
added to the beginning of the mermaid graph.
|
|
56
70
|
|
|
57
71
|
See more here: https://mermaid.js.org/config/configuration.html.
|
|
58
72
|
|
|
59
73
|
Example config:
|
|
60
74
|
|
|
61
|
-
|
|
62
|
-
|
|
75
|
+
```python
|
|
63
76
|
{
|
|
64
77
|
"config": {
|
|
65
78
|
"theme": "neutral",
|
|
66
79
|
"look": "handDrawn",
|
|
67
|
-
"themeVariables": {
|
|
80
|
+
"themeVariables": {"primaryColor": "#e2e2e2"},
|
|
68
81
|
}
|
|
69
82
|
}
|
|
70
|
-
|
|
83
|
+
```
|
|
71
84
|
Returns:
|
|
72
|
-
|
|
85
|
+
Mermaid graph syntax.
|
|
73
86
|
|
|
74
87
|
"""
|
|
75
88
|
# Initialize Mermaid graph configuration
|
|
@@ -130,7 +143,7 @@ def draw_mermaid(
|
|
|
130
143
|
+ "</em></small>"
|
|
131
144
|
)
|
|
132
145
|
node_label = format_dict.get(key, format_dict[default_class_label]).format(
|
|
133
|
-
|
|
146
|
+
_to_safe_id(key), label
|
|
134
147
|
)
|
|
135
148
|
return f"{indent}{node_label}\n"
|
|
136
149
|
|
|
@@ -145,7 +158,7 @@ def draw_mermaid(
|
|
|
145
158
|
src_parts = edge.source.split(":")
|
|
146
159
|
tgt_parts = edge.target.split(":")
|
|
147
160
|
common_prefix = ":".join(
|
|
148
|
-
src for src, tgt in zip(src_parts, tgt_parts) if src == tgt
|
|
161
|
+
src for src, tgt in zip(src_parts, tgt_parts, strict=False) if src == tgt
|
|
149
162
|
)
|
|
150
163
|
edge_groups.setdefault(common_prefix, []).append(edge)
|
|
151
164
|
|
|
@@ -155,7 +168,7 @@ def draw_mermaid(
|
|
|
155
168
|
nonlocal mermaid_graph
|
|
156
169
|
self_loop = len(edges) == 1 and edges[0].source == edges[0].target
|
|
157
170
|
if prefix and not self_loop:
|
|
158
|
-
subgraph = prefix.
|
|
171
|
+
subgraph = prefix.rsplit(":", maxsplit=1)[-1]
|
|
159
172
|
if subgraph in seen_subgraphs:
|
|
160
173
|
msg = (
|
|
161
174
|
f"Found duplicate subgraph '{subgraph}' -- this likely means that "
|
|
@@ -193,8 +206,7 @@ def draw_mermaid(
|
|
|
193
206
|
edge_label = " -.-> " if edge.conditional else " --> "
|
|
194
207
|
|
|
195
208
|
mermaid_graph += (
|
|
196
|
-
f"\t{
|
|
197
|
-
f"{_escape_node_label(target)};\n"
|
|
209
|
+
f"\t{_to_safe_id(source)}{edge_label}{_to_safe_id(target)};\n"
|
|
198
210
|
)
|
|
199
211
|
|
|
200
212
|
# Recursively add nested subgraphs
|
|
@@ -214,7 +226,7 @@ def draw_mermaid(
|
|
|
214
226
|
|
|
215
227
|
# Add remaining subgraphs with edges
|
|
216
228
|
for prefix, edges_ in edge_groups.items():
|
|
217
|
-
if ":" in prefix
|
|
229
|
+
if not prefix or ":" in prefix:
|
|
218
230
|
continue
|
|
219
231
|
add_subgraph(edges_, prefix)
|
|
220
232
|
seen_subgraphs.add(prefix)
|
|
@@ -238,9 +250,18 @@ def draw_mermaid(
|
|
|
238
250
|
return mermaid_graph
|
|
239
251
|
|
|
240
252
|
|
|
241
|
-
def
|
|
242
|
-
"""
|
|
243
|
-
|
|
253
|
+
def _to_safe_id(label: str) -> str:
|
|
254
|
+
"""Convert a string into a Mermaid-compatible node id.
|
|
255
|
+
|
|
256
|
+
Keep [a-zA-Z0-9_-] characters unchanged.
|
|
257
|
+
Map every other character -> backslash + lowercase hex codepoint.
|
|
258
|
+
|
|
259
|
+
Result is guaranteed to be unique and Mermaid-compatible,
|
|
260
|
+
so nodes with special characters always render correctly.
|
|
261
|
+
"""
|
|
262
|
+
allowed = string.ascii_letters + string.digits + "_-"
|
|
263
|
+
out = [ch if ch in allowed else "\\" + format(ord(ch), "x") for ch in label]
|
|
264
|
+
return "".join(out)
|
|
244
265
|
|
|
245
266
|
|
|
246
267
|
def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:
|
|
@@ -253,38 +274,33 @@ def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:
|
|
|
253
274
|
|
|
254
275
|
def draw_mermaid_png(
|
|
255
276
|
mermaid_syntax: str,
|
|
256
|
-
output_file_path:
|
|
277
|
+
output_file_path: str | None = None,
|
|
257
278
|
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
|
|
258
|
-
background_color:
|
|
279
|
+
background_color: str | None = "white",
|
|
259
280
|
padding: int = 10,
|
|
260
281
|
max_retries: int = 1,
|
|
261
282
|
retry_delay: float = 1.0,
|
|
283
|
+
base_url: str | None = None,
|
|
262
284
|
) -> bytes:
|
|
263
285
|
"""Draws a Mermaid graph as PNG using provided syntax.
|
|
264
286
|
|
|
265
287
|
Args:
|
|
266
|
-
mermaid_syntax
|
|
267
|
-
output_file_path
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
max_retries (int, optional): Maximum number of retries (MermaidDrawMethod.API).
|
|
275
|
-
Defaults to 1.
|
|
276
|
-
retry_delay (float, optional): Delay between retries (MermaidDrawMethod.API).
|
|
277
|
-
Defaults to 1.0.
|
|
288
|
+
mermaid_syntax: Mermaid graph syntax.
|
|
289
|
+
output_file_path: Path to save the PNG image.
|
|
290
|
+
draw_method: Method to draw the graph.
|
|
291
|
+
background_color: Background color of the image.
|
|
292
|
+
padding: Padding around the image.
|
|
293
|
+
max_retries: Maximum number of retries (MermaidDrawMethod.API).
|
|
294
|
+
retry_delay: Delay between retries (MermaidDrawMethod.API).
|
|
295
|
+
base_url: Base URL for the Mermaid.ink API.
|
|
278
296
|
|
|
279
297
|
Returns:
|
|
280
|
-
|
|
298
|
+
PNG image bytes.
|
|
281
299
|
|
|
282
300
|
Raises:
|
|
283
301
|
ValueError: If an invalid draw method is provided.
|
|
284
302
|
"""
|
|
285
303
|
if draw_method == MermaidDrawMethod.PYPPETEER:
|
|
286
|
-
import asyncio
|
|
287
|
-
|
|
288
304
|
img_bytes = asyncio.run(
|
|
289
305
|
_render_mermaid_using_pyppeteer(
|
|
290
306
|
mermaid_syntax, output_file_path, background_color, padding
|
|
@@ -297,6 +313,7 @@ def draw_mermaid_png(
|
|
|
297
313
|
background_color=background_color,
|
|
298
314
|
max_retries=max_retries,
|
|
299
315
|
retry_delay=retry_delay,
|
|
316
|
+
base_url=base_url,
|
|
300
317
|
)
|
|
301
318
|
else:
|
|
302
319
|
supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
|
|
@@ -311,17 +328,15 @@ def draw_mermaid_png(
|
|
|
311
328
|
|
|
312
329
|
async def _render_mermaid_using_pyppeteer(
|
|
313
330
|
mermaid_syntax: str,
|
|
314
|
-
output_file_path:
|
|
315
|
-
background_color:
|
|
331
|
+
output_file_path: str | None = None,
|
|
332
|
+
background_color: str | None = "white",
|
|
316
333
|
padding: int = 10,
|
|
317
334
|
device_scale_factor: int = 3,
|
|
318
335
|
) -> bytes:
|
|
319
336
|
"""Renders Mermaid graph using Pyppeteer."""
|
|
320
|
-
|
|
321
|
-
from pyppeteer import launch # type: ignore[import-not-found]
|
|
322
|
-
except ImportError as e:
|
|
337
|
+
if not _HAS_PYPPETEER:
|
|
323
338
|
msg = "Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`."
|
|
324
|
-
raise ImportError(msg)
|
|
339
|
+
raise ImportError(msg)
|
|
325
340
|
|
|
326
341
|
browser = await launch()
|
|
327
342
|
page = await browser.newPage()
|
|
@@ -385,21 +400,23 @@ async def _render_mermaid_using_pyppeteer(
|
|
|
385
400
|
def _render_mermaid_using_api(
|
|
386
401
|
mermaid_syntax: str,
|
|
387
402
|
*,
|
|
388
|
-
output_file_path:
|
|
389
|
-
background_color:
|
|
390
|
-
file_type:
|
|
403
|
+
output_file_path: str | None = None,
|
|
404
|
+
background_color: str | None = "white",
|
|
405
|
+
file_type: Literal["jpeg", "png", "webp"] | None = "png",
|
|
391
406
|
max_retries: int = 1,
|
|
392
407
|
retry_delay: float = 1.0,
|
|
408
|
+
base_url: str | None = None,
|
|
393
409
|
) -> bytes:
|
|
394
410
|
"""Renders Mermaid graph using the Mermaid.INK API."""
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
411
|
+
# Defaults to using the public mermaid.ink server.
|
|
412
|
+
base_url = base_url if base_url is not None else "https://mermaid.ink"
|
|
413
|
+
|
|
414
|
+
if not _HAS_REQUESTS:
|
|
398
415
|
msg = (
|
|
399
416
|
"Install the `requests` module to use the Mermaid.INK API: "
|
|
400
417
|
"`pip install requests`."
|
|
401
418
|
)
|
|
402
|
-
raise ImportError(msg)
|
|
419
|
+
raise ImportError(msg)
|
|
403
420
|
|
|
404
421
|
# Use Mermaid API to render the image
|
|
405
422
|
mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode(
|
|
@@ -413,7 +430,7 @@ def _render_mermaid_using_api(
|
|
|
413
430
|
background_color = f"!{background_color}"
|
|
414
431
|
|
|
415
432
|
image_url = (
|
|
416
|
-
f"
|
|
433
|
+
f"{base_url}/img/{mermaid_syntax_encoded}"
|
|
417
434
|
f"?type={file_type}&bgColor={background_color}"
|
|
418
435
|
)
|
|
419
436
|
|
|
@@ -445,7 +462,7 @@ def _render_mermaid_using_api(
|
|
|
445
462
|
|
|
446
463
|
# For other status codes, fail immediately
|
|
447
464
|
msg = (
|
|
448
|
-
"Failed to reach
|
|
465
|
+
f"Failed to reach {base_url} API while trying to render "
|
|
449
466
|
f"your graph. Status code: {response.status_code}.\n\n"
|
|
450
467
|
) + error_msg_suffix
|
|
451
468
|
raise ValueError(msg)
|
|
@@ -457,14 +474,14 @@ def _render_mermaid_using_api(
|
|
|
457
474
|
time.sleep(sleep_time)
|
|
458
475
|
else:
|
|
459
476
|
msg = (
|
|
460
|
-
"Failed to reach
|
|
477
|
+
f"Failed to reach {base_url} API while trying to render "
|
|
461
478
|
f"your graph after {max_retries} retries. "
|
|
462
479
|
) + error_msg_suffix
|
|
463
480
|
raise ValueError(msg) from e
|
|
464
481
|
|
|
465
482
|
# This should not be reached, but just in case
|
|
466
483
|
msg = (
|
|
467
|
-
"Failed to reach
|
|
484
|
+
f"Failed to reach {base_url} API while trying to render "
|
|
468
485
|
f"your graph after {max_retries} retries. "
|
|
469
486
|
) + error_msg_suffix
|
|
470
487
|
raise ValueError(msg)
|
|
@@ -1,36 +1,31 @@
|
|
|
1
1
|
"""Helper class to draw a state graph into a PNG file."""
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any
|
|
4
4
|
|
|
5
5
|
from langchain_core.runnables.graph import Graph, LabelsDict
|
|
6
6
|
|
|
7
|
+
try:
|
|
8
|
+
import pygraphviz as pgv # type: ignore[import-not-found]
|
|
9
|
+
|
|
10
|
+
_HAS_PYGRAPHVIZ = True
|
|
11
|
+
except ImportError:
|
|
12
|
+
_HAS_PYGRAPHVIZ = False
|
|
13
|
+
|
|
7
14
|
|
|
8
15
|
class PngDrawer:
|
|
9
16
|
"""Helper class to draw a state graph into a PNG file.
|
|
10
17
|
|
|
11
18
|
It requires `graphviz` and `pygraphviz` to be installed.
|
|
12
|
-
|
|
13
|
-
:
|
|
14
|
-
|
|
15
|
-
{
|
|
16
|
-
"nodes": {
|
|
17
|
-
"node1": "CustomLabel1",
|
|
18
|
-
"node2": "CustomLabel2",
|
|
19
|
-
"__end__": "End Node"
|
|
20
|
-
},
|
|
21
|
-
"edges": {
|
|
22
|
-
"continue": "ContinueLabel",
|
|
23
|
-
"end": "EndLabel"
|
|
24
|
-
}
|
|
25
|
-
}
|
|
26
|
-
The keys are the original labels, and the values are the new labels.
|
|
27
|
-
Usage:
|
|
19
|
+
|
|
20
|
+
Example:
|
|
21
|
+
```python
|
|
28
22
|
drawer = PngDrawer()
|
|
29
|
-
drawer.draw(state_graph,
|
|
23
|
+
drawer.draw(state_graph, "graph.png")
|
|
24
|
+
```
|
|
30
25
|
"""
|
|
31
26
|
|
|
32
27
|
def __init__(
|
|
33
|
-
self, fontname:
|
|
28
|
+
self, fontname: str | None = None, labels: LabelsDict | None = None
|
|
34
29
|
) -> None:
|
|
35
30
|
"""Initializes the PNG drawer.
|
|
36
31
|
|
|
@@ -50,7 +45,7 @@ class PngDrawer:
|
|
|
50
45
|
}
|
|
51
46
|
}
|
|
52
47
|
The keys are the original labels, and the values are the new labels.
|
|
53
|
-
|
|
48
|
+
|
|
54
49
|
"""
|
|
55
50
|
self.fontname = fontname or "arial"
|
|
56
51
|
self.labels = labels or LabelsDict(nodes={}, edges={})
|
|
@@ -85,9 +80,6 @@ class PngDrawer:
|
|
|
85
80
|
Args:
|
|
86
81
|
viz: The graphviz object.
|
|
87
82
|
node: The node to add.
|
|
88
|
-
|
|
89
|
-
Returns:
|
|
90
|
-
None
|
|
91
83
|
"""
|
|
92
84
|
viz.add_node(
|
|
93
85
|
node,
|
|
@@ -103,7 +95,7 @@ class PngDrawer:
|
|
|
103
95
|
viz: Any,
|
|
104
96
|
source: str,
|
|
105
97
|
target: str,
|
|
106
|
-
label:
|
|
98
|
+
label: str | None = None,
|
|
107
99
|
conditional: bool = False, # noqa: FBT001,FBT002
|
|
108
100
|
) -> None:
|
|
109
101
|
"""Adds an edge to the graph.
|
|
@@ -112,11 +104,8 @@ class PngDrawer:
|
|
|
112
104
|
viz: The graphviz object.
|
|
113
105
|
source: The source node.
|
|
114
106
|
target: The target node.
|
|
115
|
-
label: The label for the edge.
|
|
116
|
-
conditional: Whether the edge is conditional.
|
|
117
|
-
|
|
118
|
-
Returns:
|
|
119
|
-
None
|
|
107
|
+
label: The label for the edge.
|
|
108
|
+
conditional: Whether the edge is conditional.
|
|
120
109
|
"""
|
|
121
110
|
viz.add_edge(
|
|
122
111
|
source,
|
|
@@ -127,18 +116,24 @@ class PngDrawer:
|
|
|
127
116
|
style="dotted" if conditional else "solid",
|
|
128
117
|
)
|
|
129
118
|
|
|
130
|
-
def draw(self, graph: Graph, output_path:
|
|
119
|
+
def draw(self, graph: Graph, output_path: str | None = None) -> bytes | None:
|
|
131
120
|
"""Draw the given state graph into a PNG file.
|
|
132
121
|
|
|
133
122
|
Requires `graphviz` and `pygraphviz` to be installed.
|
|
134
|
-
|
|
135
|
-
:
|
|
123
|
+
|
|
124
|
+
Args:
|
|
125
|
+
graph: The graph to draw.
|
|
126
|
+
output_path: The path to save the PNG. If `None`, PNG bytes are returned.
|
|
127
|
+
|
|
128
|
+
Raises:
|
|
129
|
+
ImportError: If `pygraphviz` is not installed.
|
|
130
|
+
|
|
131
|
+
Returns:
|
|
132
|
+
The PNG bytes if `output_path` is None, else None.
|
|
136
133
|
"""
|
|
137
|
-
|
|
138
|
-
import pygraphviz as pgv # type: ignore[import-not-found]
|
|
139
|
-
except ImportError as exc:
|
|
134
|
+
if not _HAS_PYGRAPHVIZ:
|
|
140
135
|
msg = "Install pygraphviz to draw graphs: `pip install pygraphviz`."
|
|
141
|
-
raise ImportError(msg)
|
|
136
|
+
raise ImportError(msg)
|
|
142
137
|
|
|
143
138
|
# Create a directed graph
|
|
144
139
|
viz = pgv.AGraph(directed=True, nodesep=0.9, ranksep=1.0)
|