pydantic-graph 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_graph-0.0.1/.gitignore +15 -0
- pydantic_graph-0.0.1/PKG-INFO +47 -0
- pydantic_graph-0.0.1/README.md +16 -0
- pydantic_graph-0.0.1/pydantic_graph/__init__.py +14 -0
- pydantic_graph-0.0.1/pydantic_graph/_utils.py +70 -0
- pydantic_graph-0.0.1/pydantic_graph/graph.py +135 -0
- pydantic_graph-0.0.1/pydantic_graph/mermaid.py +210 -0
- pydantic_graph-0.0.1/pydantic_graph/nodes.py +93 -0
- pydantic_graph-0.0.1/pydantic_graph/py.typed +0 -0
- pydantic_graph-0.0.1/pydantic_graph/state.py +82 -0
- pydantic_graph-0.0.1/pyproject.toml +43 -0
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pydantic-graph
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Graph and state machine library
|
|
5
|
+
Author-email: Samuel Colvin <samuel@pydantic.dev>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Classifier: Development Status :: 4 - Beta
|
|
8
|
+
Classifier: Environment :: Console
|
|
9
|
+
Classifier: Environment :: MacOS X
|
|
10
|
+
Classifier: Intended Audience :: Developers
|
|
11
|
+
Classifier: Intended Audience :: Information Technology
|
|
12
|
+
Classifier: Intended Audience :: System Administrators
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Operating System :: POSIX :: Linux
|
|
15
|
+
Classifier: Operating System :: Unix
|
|
16
|
+
Classifier: Programming Language :: Python
|
|
17
|
+
Classifier: Programming Language :: Python :: 3
|
|
18
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
21
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
22
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
23
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
24
|
+
Classifier: Topic :: Internet
|
|
25
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
26
|
+
Requires-Python: >=3.9
|
|
27
|
+
Requires-Dist: httpx>=0.27.2
|
|
28
|
+
Requires-Dist: logfire-api>=1.2.0
|
|
29
|
+
Requires-Dist: pydantic>=2.10
|
|
30
|
+
Description-Content-Type: text/markdown
|
|
31
|
+
|
|
32
|
+
# Pydantic Graph
|
|
33
|
+
|
|
34
|
+
[CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain)
|
|
35
|
+
[Coverage](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai)
|
|
36
|
+
[PyPI](https://pypi.python.org/pypi/pydantic-graph)
|
|
37
|
+
[GitHub](https://github.com/pydantic/pydantic-ai)
|
|
38
|
+
[License](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE)
|
|
39
|
+
|
|
40
|
+
Graph and finite state machine library.
|
|
41
|
+
|
|
42
|
+
This library is developed as part of [PydanticAI](https://ai.pydantic.dev); however, it has no dependency
|
|
43
|
+
on `pydantic-ai` or related packages and can be considered a pure graph library.
|
|
44
|
+
|
|
45
|
+
As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax.
|
|
46
|
+
|
|
47
|
+
`pydantic-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes.
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Pydantic Graph
|
|
2
|
+
|
|
3
|
+
[CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain)
|
|
4
|
+
[Coverage](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai)
|
|
5
|
+
[PyPI](https://pypi.python.org/pypi/pydantic-graph)
|
|
6
|
+
[GitHub](https://github.com/pydantic/pydantic-ai)
|
|
7
|
+
[License](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE)
|
|
8
|
+
|
|
9
|
+
Graph and finite state machine library.
|
|
10
|
+
|
|
11
|
+
This library is developed as part of [PydanticAI](https://ai.pydantic.dev); however, it has no dependency
|
|
12
|
+
on `pydantic-ai` or related packages and can be considered a pure graph library.
|
|
13
|
+
|
|
14
|
+
As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax.
|
|
15
|
+
|
|
16
|
+
`pydantic-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes.
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from .graph import Graph
|
|
2
|
+
from .nodes import BaseNode, End, GraphContext
|
|
3
|
+
from .state import AbstractState, EndEvent, HistoryStep, NodeEvent
|
|
4
|
+
|
|
5
|
+
# Public API of the package: names re-exported from the `graph`, `nodes` and `state` submodules.
__all__ = (
    'Graph',
    'BaseNode',
    'End',
    'GraphContext',
    'AbstractState',
    'EndEvent',
    'HistoryStep',
    'NodeEvent',
)
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import types
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
from typing import Any, Union, get_args, get_origin
|
|
7
|
+
|
|
8
|
+
from typing_extensions import TypeAliasType
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_union_args(tp: Any) -> tuple[Any, ...]:
    """Return the member types of ``tp`` if it is a union, otherwise a one-element tuple ``(tp,)``.

    A ``TypeAliasType`` alias is first replaced by the value it names, so an alias of a union is
    expanded just like the union itself.
    """
    # similar to `pydantic_ai_slim/pydantic_ai/_result.py:get_union_args`
    resolved = tp.__value__ if isinstance(tp, TypeAliasType) else tp

    if not origin_is_union(get_origin(resolved)):
        return (resolved,)
    return get_args(resolved)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# same as `pydantic_ai_slim/pydantic_ai/_result.py:origin_is_union`
if sys.version_info >= (3, 10):

    def origin_is_union(tp: type[Any] | None) -> bool:
        """Return True if ``tp`` is ``typing.Union`` or the PEP 604 ``types.UnionType`` (``X | Y``)."""
        if tp is Union:
            return True
        return tp is types.UnionType

else:

    def origin_is_union(tp: type[Any] | None) -> bool:
        """Return True if ``tp`` is ``typing.Union`` (``types.UnionType`` does not exist before 3.10)."""
        return tp is Union
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def comma_and(items: list[str]) -> str:
    """Join ``items`` into an English list, using "and" before the last item.

    Three or more items get a serial (Oxford) comma, ``['a', 'b', 'c']`` -> ``'a, b, and c'``.
    Two items are joined with just "and" (``'a and b'``), a single item is returned unchanged,
    and an empty list gives an empty string (previously it raised ``IndexError``).
    """
    if not items:
        return ''
    if len(items) == 1:
        return items[0]
    if len(items) == 2:
        # a serial comma is only grammatical with three or more items ("a, and b" is wrong)
        return f'{items[0]} and {items[1]}'
    # oxford comma ¯\_(ツ)_/¯
    return ', '.join(items[:-1]) + ', and ' + items[-1]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
_NoneType = type(None)


def type_arg_name(arg: Any) -> str:
    """Return a printable name for the type argument ``arg``, rendering ``NoneType`` as ``'None'``."""
    return 'None' if arg is _NoneType else arg.__name__
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None:
    """Attempt to get the namespace where the graph was defined.

    Returns the ``f_locals`` of the frame one level above ``frame``. If the graph is defined with
    generics, ``Graph[a, b](...)``, the call passes through a frame in ``typing.py`` first; such
    frames are skipped (recursively) so the caller's namespace is returned.

    Args:
        frame: The frame to start from, e.g. ``inspect.currentframe()`` inside the constructor.

    Returns:
        The locals of the first frame above ``frame`` that is not in ``typing.py``, or ``None`` if
        ``frame`` is ``None`` or has no parent frame.
    """
    if frame is None:
        return None
    back = frame.f_back
    if back is None:
        return None
    # Accept both POSIX and Windows path separators: matching only '/typing.py' never skipped the
    # `typing` frame on Windows, so type hints were resolved against `typing`'s locals instead.
    if back.f_code.co_filename.endswith(('/typing.py', '\\typing.py')):
        return get_parent_namespace(back)
    return back.f_locals
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def now_utc() -> datetime:
    """Return the current moment as a timezone-aware ``datetime`` in UTC."""
    utc = timezone.utc
    return datetime.now(utc)
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from time import perf_counter
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Generic
|
|
9
|
+
|
|
10
|
+
import logfire_api
|
|
11
|
+
from typing_extensions import Never, ParamSpec, TypeVar, Unpack, assert_never
|
|
12
|
+
|
|
13
|
+
from . import _utils, mermaid
|
|
14
|
+
from ._utils import get_parent_namespace
|
|
15
|
+
from .nodes import BaseNode, End, GraphContext, NodeDef
|
|
16
|
+
from .state import EndEvent, HistoryStep, NodeEvent, StateT
|
|
17
|
+
|
|
18
|
+
__all__ = ('Graph',)

# Logfire instance used for the spans emitted while running graphs, with the OpenTelemetry
# instrumentation scope set to 'pydantic-graph'.
_logfire = logfire_api.Logfire(otel_scope='pydantic-graph')

# NOTE(review): `RunSignatureT` and `NodeRunEndT` are not referenced elsewhere in this module —
# confirm whether they are still needed.
RunSignatureT = ParamSpec('RunSignatureT')
# Result type of a graph run, i.e. the type carried by `End` (defaults to None).
RunEndT = TypeVar('RunEndT', default=None)
NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(init=False)
class Graph(Generic[StateT, RunEndT]):
    """Definition of a graph.

    A graph is a set of node classes (subclasses of `BaseNode`). The edges between them are derived
    from the return type hints of each node's `run` method and validated when the graph is built.
    """

    name: str | None
    nodes: tuple[type[BaseNode[StateT, RunEndT]], ...]
    node_defs: dict[str, NodeDef[StateT, RunEndT]]

    def __init__(
        self,
        *,
        nodes: Sequence[type[BaseNode[StateT, RunEndT]]],
        state_type: type[StateT] | None = None,
        name: str | None = None,
    ):
        """Create a graph from a sequence of node classes.

        Args:
            nodes: The node classes making up the graph. Listing the same class more than once is
                allowed (duplicates are dropped); two distinct classes with the same ID are an error.
            state_type: The type of the state. Not used by this constructor at runtime — presumably
                only there so type checkers can infer `StateT`.
            name: Optional name of the graph, used in the span emitted by `run`.

        Raises:
            ValueError: If two different node classes share an ID, or if a node's return type hint
                references a node that is not part of the graph.
        """
        self.name = name

        _nodes_by_id: dict[str, type[BaseNode[StateT, RunEndT]]] = {}
        for node in nodes:
            node_id = node.get_id()
            if (existing_node := _nodes_by_id.get(node_id)) and existing_node is not node:
                raise ValueError(f'Node ID "{node_id}" is not unique — found in {existing_node} and {node}')
            else:
                _nodes_by_id[node_id] = node
        self.nodes = tuple(_nodes_by_id.values())

        # Resolve the nodes' return type hints against the namespace of the code that created the
        # graph, so that locally-defined nodes can be referenced. This must stay in `__init__` itself:
        # the lookup walks up from this exact frame.
        parent_namespace = get_parent_namespace(inspect.currentframe())
        self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {}
        for node in self.nodes:
            self.node_defs[node.get_id()] = node.get_node_def(parent_namespace)

        self._validate_edges()

    def _validate_edges(self):
        """Check that every node referenced by a return type hint is part of the graph.

        Raises:
            ValueError: Naming each unknown node ID and the nodes that reference it.
        """
        known_node_ids = set(self.node_defs.keys())
        bad_edges: dict[str, list[str]] = {}

        for node_id, node_def in self.node_defs.items():
            # `next_node_ids` is a set of strings whose iteration order changes between interpreter
            # runs (hash randomization); sort it so the error message is deterministic.
            for bad_edge in sorted(node_def.next_node_ids - known_node_ids):
                bad_edges.setdefault(bad_edge, []).append(f'"{node_id}"')

        if bad_edges:
            bad_edges_list = [f'"{k}" is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()]
            if len(bad_edges_list) == 1:
                raise ValueError(f'{bad_edges_list[0]} but not included in the graph.')
            else:
                b = '\n'.join(f' {be}' for be in bad_edges_list)
                raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}')

    async def next(
        self, state: StateT, node: BaseNode[StateT, RunEndT], history: list[HistoryStep[StateT, RunEndT]]
    ) -> BaseNode[StateT, Any] | End[RunEndT]:
        """Run a single node and return what it returns: the next node, or `End`.

        A `NodeEvent` for this step is appended to `history` before the node runs; its `duration`
        (seconds, from `perf_counter`) is set once `run` returns.

        Raises:
            TypeError: If `node`'s class is not part of this graph.
        """
        node_id = node.get_id()
        if node_id not in self.node_defs:
            raise TypeError(f'Node "{node}" is not in the graph.')

        history_step: NodeEvent[StateT, RunEndT] = NodeEvent(state, node)
        history.append(history_step)

        ctx = GraphContext(state)
        with _logfire.span('run node {node_id}', node_id=node_id, node=node):
            start = perf_counter()
            next_node = await node.run(ctx)
            history_step.duration = perf_counter() - start
        return next_node

    async def run(
        self,
        state: StateT,
        node: BaseNode[StateT, RunEndT],
    ) -> tuple[End[RunEndT], list[HistoryStep[StateT, RunEndT]]]:
        """Run the graph starting from `node` until a node returns `End`.

        Args:
            state: The state, passed to every node through `GraphContext`.
            node: The node instance to start from.

        Returns:
            The `End` returned by the last node, and the run history: one `NodeEvent` per node
            executed, followed by a final `EndEvent`.

        Raises:
            TypeError: If a node returns something other than a `BaseNode` or `End`, or returns a
                node that is not part of this graph.
        """
        history: list[HistoryStep[StateT, RunEndT]] = []

        with _logfire.span(
            '{graph_name} run {start=}',
            graph_name=self.name or 'graph',
            start=node,
        ) as run_span:
            while True:
                next_node = await self.next(state, node, history=history)
                if isinstance(next_node, End):
                    history.append(EndEvent(state, next_node))
                    run_span.set_attribute('history', history)
                    return next_node, history
                elif isinstance(next_node, BaseNode):
                    node = next_node
                else:
                    if TYPE_CHECKING:
                        assert_never(next_node)
                    else:
                        raise TypeError(f'Invalid node type: {type(next_node)}. Expected `BaseNode` or `End`.')

    def mermaid_code(
        self,
        *,
        start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
        highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
        highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS,
    ) -> str:
        """Return Mermaid flowchart code for this graph (delegates to `mermaid.generate_code`)."""
        return mermaid.generate_code(
            self, start_node=start_node, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css
        )

    def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes:
        """Render this graph to image bytes (delegates to `mermaid.request_image`)."""
        return mermaid.request_image(self, **kwargs)

    def mermaid_save(self, path: Path | str, /, **kwargs: Unpack[mermaid.MermaidConfig]) -> None:
        """Render this graph and write the image to `path` (delegates to `mermaid.save_image`)."""
        mermaid.save_image(path, self, **kwargs)
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
from collections.abc import Iterable, Sequence
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
|
7
|
+
|
|
8
|
+
from annotated_types import Ge, Le
|
|
9
|
+
from typing_extensions import TypeAlias, TypedDict, Unpack
|
|
10
|
+
|
|
11
|
+
from .nodes import BaseNode
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from .graph import Graph
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Ways to identify a node: the node class, a node instance, or the node ID string.
NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | BaseNode[Any, Any] | str'
# Default CSS for the `highlighted` Mermaid classDef (a yellow fill).
DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32'
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def generate_code(
    graph: Graph[Any, Any],
    /,
    *,
    start_node: Sequence[NodeIdent] | NodeIdent | None = None,
    highlighted_nodes: Sequence[NodeIdent] | NodeIdent | None = None,
    highlight_css: str = DEFAULT_HIGHLIGHT_CSS,
) -> str:
    """Generate Mermaid code for a graph.

    Each node is drawn as a rounded box, and its outgoing edges come from its node definition:
    - A node whose `run` is annotated as returning a bare `BaseNode` gets an edge to every node.
    - Any other node gets an edge to each node named in its return type, in graph order.
    - Nodes that can return `End` get an edge to `END`.
    - Start nodes get an edge from `START`.

    Args:
        graph: The graph to generate the code for.
        start_node: Identifiers of nodes that start the graph.
        highlighted_nodes: Identifiers of nodes to highlight.
        highlight_css: CSS to use for highlighting nodes.

    Returns: The Mermaid code for the graph.

    Raises:
        LookupError: If a start node or a highlighted node is not in the graph.
    """
    start_node_ids = set(node_ids(start_node or ()))
    for node_id in start_node_ids:
        if node_id not in graph.node_defs:
            raise LookupError(f'Start node "{node_id}" is not in the graph.')

    # position of each node in the graph, used to emit edges in a stable, graph-defined order
    node_order = {node_id: index for index, node_id in enumerate(graph.node_defs)}

    lines = ['graph TD']
    for node in graph.nodes:
        node_id = node.get_id()
        node_def = graph.node_defs[node_id]

        # we use round brackets (rounded box) for nodes other than the start and end
        mermaid_name = f'({node_id})'
        if node_id in start_node_ids:
            lines.append(f' START --> {node_id}{mermaid_name}')
        if node_def.returns_base_node:
            # Any node may come next, so draw an edge to every node in the graph. Use the node ID:
            # `graph.nodes` holds classes, whose repr ("<class '...'>") is not a valid Mermaid node.
            for next_node in graph.nodes:
                lines.append(f' {node_id}{mermaid_name} --> {next_node.get_id()}')
        else:
            for next_node_id in sorted(node_def.next_node_ids, key=node_order.__getitem__):
                lines.append(f' {node_id}{mermaid_name} --> {next_node_id}')
        if node_def.returns_end:
            lines.append(f' {node_id}{mermaid_name} --> END')

    if highlighted_nodes:
        lines.append('')
        lines.append(f'classDef highlighted {highlight_css}')
        for node_id in node_ids(highlighted_nodes):
            if node_id not in graph.node_defs:
                raise LookupError(f'Highlighted node "{node_id}" is not in the graph.')
            lines.append(f'class {node_id} highlighted')

    return '\n'.join(lines)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]:
    """Lazily yield the node ID for each identifier in ``node_idents``.

    ``node_idents`` may be one identifier or a sequence of them. A plain string counts as a single
    ID, not as a sequence of characters. String identifiers are yielded as-is; any other
    identifier (node class or instance) is resolved with its ``get_id()`` method.
    """
    is_many = isinstance(node_idents, Sequence) and not isinstance(node_idents, str)
    idents = node_idents if is_many else (node_idents,)

    for ident in idents:
        yield ident if isinstance(ident, str) else ident.get_id()
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class MermaidConfig(TypedDict, total=False):
    """Parameters to configure mermaid chart generation.

    Every key is optional (``total=False``). Used as the ``**kwargs`` of ``request_image`` and
    ``save_image``.
    """

    start_node: Sequence[NodeIdent] | NodeIdent
    """Identifiers of nodes that start the graph."""
    highlighted_nodes: Sequence[NodeIdent] | NodeIdent
    """Identifiers of nodes to highlight."""
    highlight_css: str
    """CSS to use for highlighting nodes."""
    image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf']
    """The image type to generate. If unspecified, the default behavior is `'jpeg'`."""
    pdf_fit: bool
    """When using image_type='pdf', whether to fit the diagram to the PDF page."""
    pdf_landscape: bool
    """When using image_type='pdf', whether to use landscape orientation for the PDF.

    This has no effect if using `pdf_fit`.
    """
    pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6']
    """When using image_type='pdf', the paper size of the PDF."""
    background_color: str
    """The background color of the diagram.

    If unspecified (or empty), the default transparent background is used. The color value is interpreted
    as a hexadecimal color code by default (and should not have a leading '#'), but you can also use named
    colors by prefixing the value with `'!'`. For example, valid choices include `background_color='!white'`
    or `background_color='FF0000'`.
    """
    theme: Literal['default', 'neutral', 'dark', 'forest']
    """The theme of the diagram. Defaults to 'default'."""
    width: int
    """The width of the diagram."""
    height: int
    """The height of the diagram."""
    scale: Annotated[float, Ge(1), Le(3)]
    """The scale of the diagram.

    The scale must be a number between 1 and 3, and you can only set a scale if one or both of width and height are set.
    """
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def request_image(
    graph: Graph[Any, Any],
    /,
    **kwargs: Unpack[MermaidConfig],
) -> bytes:
    """Generate an image of a Mermaid diagram using [mermaid.ink](https://mermaid.ink).

    The diagram code is base64-encoded into the URL path. The endpoint depends on `image_type`:
    - `'pdf'` uses `/pdf/`, with the `fit`, `landscape` and `paper` query parameters.
    - `'svg'` uses `/svg/`.
    - Every other type uses `/img/`, with the format passed as the `type` query parameter.

    Args:
        graph: The graph to generate the image for.
        **kwargs: Additional parameters to configure mermaid chart generation.

    Returns: The image data.

    Raises:
        httpx.HTTPStatusError: If mermaid.ink responds with an error status.
    """
    import httpx

    code = generate_code(
        graph,
        start_node=kwargs.get('start_node'),
        highlighted_nodes=kwargs.get('highlighted_nodes'),
        highlight_css=kwargs.get('highlight_css', DEFAULT_HIGHLIGHT_CSS),
    )
    code_base64 = base64.b64encode(code.encode()).decode()

    image_type = kwargs.get('image_type')
    params: dict[str, str | bool] = {}
    if image_type == 'pdf':
        url = f'https://mermaid.ink/pdf/{code_base64}'
        if kwargs.get('pdf_fit'):
            params['fit'] = True
        if kwargs.get('pdf_landscape'):
            params['landscape'] = True
        if pdf_paper := kwargs.get('pdf_paper'):
            params['paper'] = pdf_paper
    elif image_type == 'svg':
        url = f'https://mermaid.ink/svg/{code_base64}'
    else:
        url = f'https://mermaid.ink/img/{code_base64}'
        # The `/pdf/` and `/svg/` endpoints already fix the output format through the path, so the
        # `type` parameter is only sent to `/img/` (it was previously also sent as type=pdf/svg).
        if image_type:
            params['type'] = image_type

    if background_color := kwargs.get('background_color'):
        params['bgColor'] = background_color
    if theme := kwargs.get('theme'):
        params['theme'] = theme
    if width := kwargs.get('width'):
        params['width'] = str(width)
    if height := kwargs.get('height'):
        params['height'] = str(height)
    if scale := kwargs.get('scale'):
        params['scale'] = str(scale)

    response = httpx.get(url, params=params)
    response.raise_for_status()
    return response.content
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def save_image(
    path: Path | str,
    graph: Graph[Any, Any],
    /,
    **kwargs: Unpack[MermaidConfig],
) -> None:
    """Generate an image of a Mermaid diagram using [mermaid.ink](https://mermaid.ink) and save it to a local file.

    If ``image_type`` is not passed, it is taken from the file extension when that extension
    (compared case-insensitively) is ``png``, ``webp``, ``svg`` or ``pdf``. Any other extension
    leaves ``image_type`` unset.

    Args:
        path: The path to save the image to.
        graph: The graph to generate the image for.
        **kwargs: Additional parameters to configure mermaid chart generation.
    """
    target = Path(path) if isinstance(path, str) else path

    # .jpeg/.jpg need no mapping: JPEG is already what is produced when no image type is given
    if 'image_type' not in kwargs:
        extension = target.suffix.lower().removeprefix('.')
        if extension in {'png', 'webp', 'svg', 'pdf'}:
            kwargs['image_type'] = extension

    target.write_bytes(request_image(graph, **kwargs))
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
from abc import abstractmethod
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from functools import cache
|
|
6
|
+
from typing import Any, Generic, get_origin, get_type_hints
|
|
7
|
+
|
|
8
|
+
from typing_extensions import Never, TypeVar
|
|
9
|
+
|
|
10
|
+
from . import _utils
|
|
11
|
+
from .state import StateT
|
|
12
|
+
|
|
13
|
+
__all__ = 'GraphContext', 'BaseNode', 'End', 'NodeDef'

# Type of the data carried by `End` (defaults to None).
RunEndT = TypeVar('RunEndT', default=None)
# End type a node may produce; covariant, defaults to `Never`.
NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class GraphContext(Generic[StateT]):
    """Context for a graph.

    An instance is passed to each node's `run` method; it carries the state of the current run.
    """

    state: StateT  # the run's state object (may be None, StateT's default)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class BaseNode(Generic[StateT, NodeRunEndT]):
    """Base class for a node.

    Subclasses implement `run`. The return type hint of `run` declares the node's outgoing edges:
    each `BaseNode` subclass in the hint is a possible next node, and `End` means the node can end
    the run.
    """

    @abstractmethod
    async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]:
        """Run the node and return the next node to run, or `End` to finish the run."""
        ...

    @classmethod
    @cache
    def get_id(cls) -> str:
        """Return the ID of the node: the class name."""
        return cls.__name__

    @classmethod
    def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRunEndT]:
        """Build the `NodeDef` for this node by inspecting the return type hint of `run`.

        Args:
            local_ns: Local namespace used when resolving the type hints of `run`, e.g. to find
                nodes defined inside a function.

        Raises:
            TypeError: If `run` has no return type hint, or a member of the return type is not
                `End`, `BaseNode` or a `BaseNode` subclass.
        """
        type_hints = get_type_hints(cls.run, localns=local_ns)
        try:
            return_hint = type_hints['return']
        except KeyError:
            raise TypeError(f'Node {cls} is missing a return type hint on its `run` method') from None

        next_node_ids: set[str] = set()
        returns_end: bool = False
        returns_base_node: bool = False
        for return_type in _utils.get_union_args(return_hint):
            # unwrap parametrized generics such as `End[int]` or `MyNode[State]` to their class
            return_type_origin = get_origin(return_type) or return_type
            if return_type_origin is End:
                returns_end = True
            elif return_type_origin is BaseNode:
                # TODO: Should we disallow this?
                returns_base_node = True
            elif isinstance(return_type_origin, type) and issubclass(return_type_origin, BaseNode):
                next_node_ids.add(return_type_origin.get_id())
            else:
                # Non-class members (e.g. `Literal[...]`) used to make `issubclass` raise its own,
                # unrelated TypeError; report them as invalid return types instead.
                raise TypeError(f'Invalid return type: {return_type}')

        return NodeDef(
            cls,
            cls.get_id(),
            next_node_ids,
            returns_end,
            returns_base_node,
        )
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass
class End(Generic[RunEndT]):
    """Type to return from a node to signal the end of the graph.

    When a node returns an `End`, `Graph.run` stops and returns it to the caller.
    """

    data: RunEndT  # the result of the run
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass
class NodeDef(Generic[StateT, NodeRunEndT]):
    """Definition of a node.

    Used by [`Graph`][pydantic_graph.graph.Graph] to store information about a node, and when generating
    mermaid graphs.
    """

    node: type[BaseNode[StateT, NodeRunEndT]]
    """The node class itself."""
    node_id: str
    """ID of the node."""
    next_node_ids: set[str]
    """IDs of the nodes that can be called next."""
    returns_end: bool
    """The node definition returns an `End`, hence the node can end the run."""
    returns_base_node: bool
    """The node definition returns a `BaseNode`, hence any node in the graph can be called next."""
|
|
File without changes
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import TYPE_CHECKING, Generic, Literal, Union

from typing_extensions import Never, Self, TypeVar
|
|
10
|
+
|
|
11
|
+
from . import _utils
|
|
12
|
+
|
|
13
|
+
# Public names of this module.
__all__ = 'AbstractState', 'StateT', 'NodeEvent', 'EndEvent', 'HistoryStep'
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from pydantic_graph import BaseNode
|
|
17
|
+
from pydantic_graph.nodes import End
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class AbstractState(ABC):
    """Abstract base class for graph state objects.

    Subclasses must implement `serialize`; `deep_copy` falls back to `copy.deepcopy`.
    """

    @abstractmethod
    def serialize(self) -> bytes | None:
        """Serialize the state object."""
        raise NotImplementedError

    def deep_copy(self) -> Self:
        """Return an independent deep copy of this state object."""
        duplicate = copy.deepcopy(self)
        return duplicate
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Result type of a run, as carried by `End` (defaults to None).
RunEndT = TypeVar('RunEndT', default=None)
# Covariant end type for nodes (defaults to `Never`).
NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never)
# Type of the graph state: None (the default) or an `AbstractState` subclass.
StateT = TypeVar('StateT', bound=Union[None, AbstractState], default=None)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class NodeEvent(Generic[StateT, RunEndT]):
    """History step describing the execution of a node in a graph."""

    state: StateT  # copy of the state taken when the step is created (see `__post_init__`)
    node: BaseNode[StateT, RunEndT]  # the node instance that was run
    start_ts: datetime = field(default_factory=_utils.now_utc)  # creation time, timezone-aware UTC
    duration: float | None = None  # time taken by the node; None until set after the node runs

    kind: Literal['step'] = 'step'  # tag distinguishing node steps from `EndEvent`s in a history

    def __post_init__(self):
        # Copy the state to prevent it from being modified by other code
        self.state = _deep_copy_state(self.state)

    def summary(self) -> str:
        """Return a short description of this step: the string form of the node."""
        return str(self.node)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class EndEvent(Generic[StateT, RunEndT]):
    """History step describing the end of a graph run."""

    state: StateT  # copy of the state taken when the event is created (see `__post_init__`)
    result: End[RunEndT]  # the `End` returned by the last node
    ts: datetime = field(default_factory=_utils.now_utc)  # creation time, timezone-aware UTC

    kind: Literal['end'] = 'end'  # tag distinguishing the end event from `NodeEvent`s in a history

    def __post_init__(self):
        # Copy the state to prevent it from being modified by other code
        self.state = _deep_copy_state(self.state)

    def summary(self) -> str:
        """Return a short description of this event: the string form of the result."""
        return str(self.result)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _deep_copy_state(state: StateT) -> StateT:
    """Return ``None`` unchanged; otherwise return the copy produced by ``state.deep_copy()``."""
    return state if state is None else state.deep_copy()
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# One entry in a run's history: a `NodeEvent` for an executed node, or the final `EndEvent`.
HistoryStep = Union[NodeEvent[StateT, RunEndT], EndEvent[StateT, RunEndT]]
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "pydantic-graph"
|
|
7
|
+
version = "0.0.1"
|
|
8
|
+
description = "Graph and state machine library"
|
|
9
|
+
authors = [
|
|
10
|
+
{ name = "Samuel Colvin", email = "samuel@pydantic.dev" },
|
|
11
|
+
]
|
|
12
|
+
license = "MIT"
|
|
13
|
+
readme = "README.md"
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Development Status :: 4 - Beta",
|
|
16
|
+
"Programming Language :: Python",
|
|
17
|
+
"Programming Language :: Python :: 3",
|
|
18
|
+
"Programming Language :: Python :: 3 :: Only",
|
|
19
|
+
"Programming Language :: Python :: 3.9",
|
|
20
|
+
"Programming Language :: Python :: 3.10",
|
|
21
|
+
"Programming Language :: Python :: 3.11",
|
|
22
|
+
"Programming Language :: Python :: 3.12",
|
|
23
|
+
"Programming Language :: Python :: 3.13",
|
|
24
|
+
"Intended Audience :: Developers",
|
|
25
|
+
"Intended Audience :: Information Technology",
|
|
26
|
+
"Intended Audience :: System Administrators",
|
|
27
|
+
"License :: OSI Approved :: MIT License",
|
|
28
|
+
"Operating System :: Unix",
|
|
29
|
+
"Operating System :: POSIX :: Linux",
|
|
30
|
+
"Environment :: Console",
|
|
31
|
+
"Environment :: MacOS X",
|
|
32
|
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
|
33
|
+
"Topic :: Internet",
|
|
34
|
+
]
|
|
35
|
+
requires-python = ">=3.9"
|
|
36
|
+
dependencies = [
|
|
37
|
+
"httpx>=0.27.2",
|
|
38
|
+
"logfire-api>=1.2.0",
|
|
39
|
+
"pydantic>=2.10",
|
|
40
|
+
]
|
|
41
|
+
|
|
42
|
+
[tool.hatch.build.targets.wheel]
|
|
43
|
+
packages = ["pydantic_graph"]
|