pydantic-graph 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,15 @@
1
+ site
2
+ .python-version
3
+ .venv
4
+ dist
5
+ __pycache__
6
+ *.env
7
+ /scratch/
8
+ /.coverage
9
+ env*/
10
+ /TODO.md
11
+ /postgres-data/
12
+ .DS_Store
13
+ examples/pydantic_ai_examples/.chat_app_messages.sqlite
14
+ .cache/
15
+ .vscode/
@@ -0,0 +1,47 @@
1
+ Metadata-Version: 2.4
2
+ Name: pydantic-graph
3
+ Version: 0.0.1
4
+ Summary: Graph and state machine library
5
+ Author-email: Samuel Colvin <samuel@pydantic.dev>
6
+ License-Expression: MIT
7
+ Classifier: Development Status :: 4 - Beta
8
+ Classifier: Environment :: Console
9
+ Classifier: Environment :: MacOS X
10
+ Classifier: Intended Audience :: Developers
11
+ Classifier: Intended Audience :: Information Technology
12
+ Classifier: Intended Audience :: System Administrators
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Operating System :: POSIX :: Linux
15
+ Classifier: Operating System :: Unix
16
+ Classifier: Programming Language :: Python
17
+ Classifier: Programming Language :: Python :: 3
18
+ Classifier: Programming Language :: Python :: 3 :: Only
19
+ Classifier: Programming Language :: Python :: 3.9
20
+ Classifier: Programming Language :: Python :: 3.10
21
+ Classifier: Programming Language :: Python :: 3.11
22
+ Classifier: Programming Language :: Python :: 3.12
23
+ Classifier: Programming Language :: Python :: 3.13
24
+ Classifier: Topic :: Internet
25
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
26
+ Requires-Python: >=3.9
27
+ Requires-Dist: httpx>=0.27.2
28
+ Requires-Dist: logfire-api>=1.2.0
29
+ Requires-Dist: pydantic>=2.10
30
+ Description-Content-Type: text/markdown
31
+
32
+ # Pydantic Graph
33
+
34
+ [![CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain)
35
+ [![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic-ai.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai)
36
+ [![PyPI](https://img.shields.io/pypi/v/pydantic-graph.svg)](https://pypi.python.org/pypi/pydantic-graph)
37
+ [![versions](https://img.shields.io/pypi/pyversions/pydantic-graph.svg)](https://github.com/pydantic/pydantic-ai)
38
+ [![license](https://img.shields.io/github/license/pydantic/pydantic-ai.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE)
39
+
40
+ Graph and finite state machine library.
41
+
42
+ This library is developed as part of [PydanticAI](https://ai.pydantic.dev); however, it has no dependency
43
+ on `pydantic-ai` or related packages and can be considered a pure graph library.
44
+
45
+ As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax.
46
+
47
+ `pydantic-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes.
@@ -0,0 +1,16 @@
1
+ # Pydantic Graph
2
+
3
+ [![CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain)
4
+ [![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic-ai.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai)
5
+ [![PyPI](https://img.shields.io/pypi/v/pydantic-graph.svg)](https://pypi.python.org/pypi/pydantic-graph)
6
+ [![versions](https://img.shields.io/pypi/pyversions/pydantic-graph.svg)](https://github.com/pydantic/pydantic-ai)
7
+ [![license](https://img.shields.io/github/license/pydantic/pydantic-ai.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE)
8
+
9
+ Graph and finite state machine library.
10
+
11
+ This library is developed as part of [PydanticAI](https://ai.pydantic.dev); however, it has no dependency
12
+ on `pydantic-ai` or related packages and can be considered a pure graph library.
13
+
14
+ As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax.
15
+
16
+ `pydantic-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes.
@@ -0,0 +1,14 @@
1
+ from .graph import Graph
2
+ from .nodes import BaseNode, End, GraphContext
3
+ from .state import AbstractState, EndEvent, HistoryStep, NodeEvent
4
+
5
# Public API of the package: the names re-exported from the `graph`, `nodes` and `state` submodules.
__all__ = (
    'Graph',
    'BaseNode',
    'End',
    'GraphContext',
    'AbstractState',
    'EndEvent',
    'HistoryStep',
    'NodeEvent',
)
@@ -0,0 +1,70 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import sys
4
+ import types
5
+ from datetime import datetime, timezone
6
+ from typing import Any, Union, get_args, get_origin
7
+
8
+ from typing_extensions import TypeAliasType
9
+
10
+
11
def get_union_args(tp: Any) -> tuple[Any, ...]:
    """Return the members of `tp` if it is a union, otherwise a 1-tuple containing `tp`.

    A `TypeAliasType` (e.g. a `type X = ...` alias) is unwrapped to its value first.
    """
    # similar to `pydantic_ai_slim/pydantic_ai/_result.py:get_union_args`
    unaliased = tp.__value__ if isinstance(tp, TypeAliasType) else tp

    if not origin_is_union(get_origin(unaliased)):
        return (unaliased,)
    return get_args(unaliased)
22
+
23
+
24
# same as `pydantic_ai_slim/pydantic_ai/_result.py:origin_is_union`
# `X | Y` unions (PEP 604) only exist from Python 3.10; before that `typing.Union` is the only union origin.
if sys.version_info >= (3, 10):
    _UNION_ORIGINS = (Union, types.UnionType)
else:
    _UNION_ORIGINS = (Union,)


def origin_is_union(tp: type[Any] | None) -> bool:
    """Return whether `tp` (typically the result of `get_origin`) is a union origin."""
    return any(tp is union_origin for union_origin in _UNION_ORIGINS)
34
+
35
+
36
def comma_and(items: list[str]) -> str:
    """Join items as an English list: `a`, `a and b`, `a, b, and c`.

    An Oxford comma is used for three or more items. An empty list gives an empty string
    (previously it raised `IndexError`).
    """
    if len(items) <= 1:
        return ''.join(items)
    if len(items) == 2:
        # no comma between exactly two items ("a and b", not "a, and b")
        return f'{items[0]} and {items[1]}'
    # oxford comma ¯\_(ツ)_/¯
    return ', '.join(items[:-1]) + ', and ' + items[-1]
43
+
44
+
45
_NoneType = type(None)


def type_arg_name(arg: Any) -> str:
    """Return a display name for a type argument: `'None'` for `NoneType`, otherwise `arg.__name__`."""
    return 'None' if arg is _NoneType else arg.__name__
53
+
54
+
55
def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None:
    """Attempt to get the namespace where the graph was defined.

    Returns the locals of the frame that called `frame`'s function. If the graph is defined with generics
    `Graph[a, b]` then another frame (from a file named `typing.py`) is inserted, and we have to skip that
    to get the correct namespace.

    Returns `None` if `frame` is `None` or has no calling frame.
    """
    if frame is None:
        return None
    back = frame.f_back
    if back is None:
        return None
    # check both path separators so the `typing.py` frame is also skipped on Windows ('\\')
    if back.f_code.co_filename.endswith(('/typing.py', '\\typing.py')):
        return get_parent_namespace(back)
    return back.f_locals
67
+
68
+
69
def now_utc() -> datetime:
    """Return the current time as a timezone-aware `datetime` in UTC."""
    utc = timezone.utc
    return datetime.now(utc)
@@ -0,0 +1,135 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import inspect
4
+ from collections.abc import Sequence
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from time import perf_counter
8
+ from typing import TYPE_CHECKING, Any, Generic
9
+
10
+ import logfire_api
11
+ from typing_extensions import Never, ParamSpec, TypeVar, Unpack, assert_never
12
+
13
+ from . import _utils, mermaid
14
+ from ._utils import get_parent_namespace
15
+ from .nodes import BaseNode, End, GraphContext, NodeDef
16
+ from .state import EndEvent, HistoryStep, NodeEvent, StateT
17
+
18
__all__ = ('Graph',)

# Logfire instance used for the spans emitted by `Graph.next` and `Graph.run`.
_logfire = logfire_api.Logfire(otel_scope='pydantic-graph')

# NOTE(review): `RunSignatureT` is not referenced anywhere in this module — confirm it is still needed.
RunSignatureT = ParamSpec('RunSignatureT')
# Type of the `End.data` result of a graph run; defaults to `None`.
RunEndT = TypeVar('RunEndT', default=None)
# NOTE(review): not referenced in this module (nodes.py defines its own `NodeRunEndT`) — confirm it is needed.
NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never)
25
+
26
+
27
@dataclass(init=False)
class Graph(Generic[StateT, RunEndT]):
    """Definition of a graph.

    A graph is a set of node classes. The edges between nodes are derived from the return type hint of
    each node's `run` method (see `BaseNode.get_node_def`).
    """

    name: str | None
    nodes: tuple[type[BaseNode[StateT, RunEndT]], ...]
    node_defs: dict[str, NodeDef[StateT, RunEndT]]

    def __init__(
        self,
        *,
        nodes: Sequence[type[BaseNode[StateT, RunEndT]]],
        state_type: type[StateT] | None = None,
        name: str | None = None,
    ):
        """Create a graph from its node classes.

        Args:
            nodes: The node classes that make up the graph. Listing the same class more than once is
                allowed; duplicates are collapsed.
            state_type: The type of the graph state; not otherwise used by this method.
            name: Optional name of the graph, used in the span emitted by `run`.

        Raises:
            ValueError: If two different node classes share the same node ID, or if a node's return type
                hint references a node that is not part of the graph.
        """
        self.name = name

        _nodes_by_id: dict[str, type[BaseNode[StateT, RunEndT]]] = {}
        for node in nodes:
            node_id = node.get_id()
            if (existing_node := _nodes_by_id.get(node_id)) and existing_node is not node:
                raise ValueError(f'Node ID "{node_id}" is not unique — found in {existing_node} and {node}')
            else:
                _nodes_by_id[node_id] = node
        self.nodes = tuple(_nodes_by_id.values())

        # Type hints on `run` are resolved against the namespace that created the graph, so nodes defined
        # locally (e.g. inside a function) can be referenced. This call must stay directly in `__init__`:
        # `get_parent_namespace` reads the locals of the frame that called this one.
        parent_namespace = get_parent_namespace(inspect.currentframe())
        self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {}
        for node in self.nodes:
            self.node_defs[node.get_id()] = node.get_node_def(parent_namespace)

        self._validate_edges()

    def _validate_edges(self):
        """Check that every edge points at a node that is part of this graph.

        Raises:
            ValueError: Listing each unknown node ID together with the nodes that reference it.
        """
        known_node_ids = set(self.node_defs.keys())
        bad_edges: dict[str, list[str]] = {}

        for node_id, node_def in self.node_defs.items():
            # `next_node_ids` is a set; sort so the error message is deterministic across runs
            for bad_edge in sorted(node_def.next_node_ids - known_node_ids):
                bad_edges.setdefault(bad_edge, []).append(f'"{node_id}"')

        if bad_edges:
            bad_edges_list = [f'"{k}" is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()]
            if len(bad_edges_list) == 1:
                raise ValueError(f'{bad_edges_list[0]} but not included in the graph.')
            else:
                b = '\n'.join(f' {be}' for be in bad_edges_list)
                raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}')

    async def next(
        self, state: StateT, node: BaseNode[StateT, RunEndT], history: list[HistoryStep[StateT, RunEndT]]
    ) -> BaseNode[StateT, Any] | End[RunEndT]:
        """Run a single node and return what it produced (the next node, or `End`).

        A `NodeEvent` is appended to `history` before the node runs; its `duration` is set after the
        node's `run` returns.

        Raises:
            TypeError: If the node's class is not part of this graph.
        """
        node_id = node.get_id()
        if node_id not in self.node_defs:
            raise TypeError(f'Node "{node}" is not in the graph.')

        history_step: NodeEvent[StateT, RunEndT] = NodeEvent(state, node)
        history.append(history_step)

        ctx = GraphContext(state)
        with _logfire.span('run node {node_id}', node_id=node_id, node=node):
            start = perf_counter()
            next_node = await node.run(ctx)
            history_step.duration = perf_counter() - start
        return next_node

    async def run(
        self,
        state: StateT,
        node: BaseNode[StateT, RunEndT],
    ) -> tuple[End[RunEndT], list[HistoryStep[StateT, RunEndT]]]:
        """Run the graph starting from `node` until a node returns `End`.

        Args:
            state: The state passed to every node (wrapped in a `GraphContext`).
            node: The node to start from.

        Returns:
            The `End` returned by the last node, and the run history: a `NodeEvent` per executed node
            followed by a final `EndEvent`.

        Raises:
            TypeError: If a node is not part of the graph, or returns something other than a `BaseNode`
                or `End`.
        """
        history: list[HistoryStep[StateT, RunEndT]] = []

        with _logfire.span(
            '{graph_name} run {start=}',
            graph_name=self.name or 'graph',
            start=node,
        ) as run_span:
            while True:
                next_node = await self.next(state, node, history=history)
                if isinstance(next_node, End):
                    history.append(EndEvent(state, next_node))
                    run_span.set_attribute('history', history)
                    return next_node, history
                elif isinstance(next_node, BaseNode):
                    node = next_node
                else:
                    if TYPE_CHECKING:
                        assert_never(next_node)
                    else:
                        raise TypeError(f'Invalid node type: {type(next_node)}. Expected `BaseNode` or `End`.')

    def mermaid_code(
        self,
        *,
        start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
        highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
        highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS,
    ) -> str:
        """Generate Mermaid code for this graph; see `mermaid.generate_code` for the arguments."""
        return mermaid.generate_code(
            self, start_node=start_node, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css
        )

    def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes:
        """Render this graph to image bytes via mermaid.ink; see `mermaid.request_image`."""
        return mermaid.request_image(self, **kwargs)

    def mermaid_save(self, path: Path | str, /, **kwargs: Unpack[mermaid.MermaidConfig]) -> None:
        """Render this graph via mermaid.ink and write the image to `path`; see `mermaid.save_image`."""
        mermaid.save_image(path, self, **kwargs)
@@ -0,0 +1,210 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import base64
4
+ from collections.abc import Iterable, Sequence
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Annotated, Any, Literal
7
+
8
+ from annotated_types import Ge, Le
9
+ from typing_extensions import TypeAlias, TypedDict, Unpack
10
+
11
+ from .nodes import BaseNode
12
+
13
+ if TYPE_CHECKING:
14
+ from .graph import Graph
15
+
16
+
17
+ NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | BaseNode[Any, Any] | str'
18
+ DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32'
19
+
20
+
21
+ def generate_code(
22
+ graph: Graph[Any, Any],
23
+ /,
24
+ *,
25
+ start_node: Sequence[NodeIdent] | NodeIdent | None = None,
26
+ highlighted_nodes: Sequence[NodeIdent] | NodeIdent | None = None,
27
+ highlight_css: str = DEFAULT_HIGHLIGHT_CSS,
28
+ ) -> str:
29
+ """Generate Mermaid code for a graph.
30
+
31
+ Args:
32
+ graph: The graph to generate the image for.
33
+ start_node: Identifiers of nodes that start the graph.
34
+ highlighted_nodes: Identifiers of nodes to highlight.
35
+ highlight_css: CSS to use for highlighting nodes.
36
+
37
+ Returns: The Mermaid code for the graph.
38
+ """
39
+ start_node_ids = set(node_ids(start_node or ()))
40
+ for node_id in start_node_ids:
41
+ if node_id not in graph.node_defs:
42
+ raise LookupError(f'Start node "{node_id}" is not in the graph.')
43
+
44
+ node_order = {node_id: index for index, node_id in enumerate(graph.node_defs)}
45
+
46
+ lines = ['graph TD']
47
+ for node in graph.nodes:
48
+ node_id = node.get_id()
49
+ node_def = graph.node_defs[node_id]
50
+
51
+ # we use round brackets (rounded box) for nodes other than the start and end
52
+ mermaid_name = f'({node_id})'
53
+ if node_id in start_node_ids:
54
+ lines.append(f' START --> {node_id}{mermaid_name}')
55
+ if node_def.returns_base_node:
56
+ for next_node_id in graph.nodes:
57
+ lines.append(f' {node_id}{mermaid_name} --> {next_node_id}')
58
+ else:
59
+ for _, next_node_id in sorted((node_order[node_id], node_id) for node_id in node_def.next_node_ids):
60
+ lines.append(f' {node_id}{mermaid_name} --> {next_node_id}')
61
+ if node_def.returns_end:
62
+ lines.append(f' {node_id}{mermaid_name} --> END')
63
+
64
+ if highlighted_nodes:
65
+ lines.append('')
66
+ lines.append(f'classDef highlighted {highlight_css}')
67
+ for node_id in node_ids(highlighted_nodes):
68
+ if node_id not in graph.node_defs:
69
+ raise LookupError(f'Highlighted node "{node_id}" is not in the graph.')
70
+ lines.append(f'class {node_id} highlighted')
71
+
72
+ return '\n'.join(lines)
73
+
74
+
75
+ def node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]:
76
+ """Get the node IDs from a sequence of node identifiers."""
77
+ if isinstance(node_idents, str):
78
+ node_iter = (node_idents,)
79
+ elif isinstance(node_idents, Sequence):
80
+ node_iter = node_idents
81
+ else:
82
+ node_iter = (node_idents,)
83
+
84
+ for node in node_iter:
85
+ if isinstance(node, str):
86
+ yield node
87
+ else:
88
+ yield node.get_id()
89
+
90
+
91
class MermaidConfig(TypedDict, total=False):
    """Parameters to configure mermaid chart generation.

    Every key is optional (`total=False`).
    """

    start_node: Sequence[NodeIdent] | NodeIdent
    """Identifiers of nodes that start the graph."""
    highlighted_nodes: Sequence[NodeIdent] | NodeIdent
    """Identifiers of nodes to highlight."""
    highlight_css: str
    """CSS to use for highlighting nodes."""
    image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf']
    """The image type to generate. If unspecified, the default behavior is `'jpeg'`."""
    pdf_fit: bool
    """When using image_type='pdf', whether to fit the diagram to the PDF page."""
    pdf_landscape: bool
    """When using image_type='pdf', whether to use landscape orientation for the PDF.

    This has no effect if using `pdf_fit`.
    """
    pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6']
    """When using image_type='pdf', the paper size of the PDF."""
    background_color: str
    """The background color of the diagram.

    If unspecified, the default transparent background is used. The color value is interpreted as a hexadecimal
    color code by default (and should not have a leading '#'), but you can also use named colors by prefixing
    the value with `'!'`. For example, valid choices include `background_color='!white'` or
    `background_color='FF0000'`.
    """
    theme: Literal['default', 'neutral', 'dark', 'forest']
    """The theme of the diagram. Defaults to 'default'."""
    width: int
    """The width of the diagram."""
    height: int
    """The height of the diagram."""
    scale: Annotated[float, Ge(1), Le(3)]
    """The scale of the diagram.

    The scale must be a number between 1 and 3, and you can only set a scale if one or both of width and height are set.
    """
129
+
130
+
131
def request_image(
    graph: Graph[Any, Any],
    /,
    **kwargs: Unpack[MermaidConfig],
) -> bytes:
    """Generate an image of a Mermaid diagram using [mermaid.ink](https://mermaid.ink).

    Args:
        graph: The graph to generate the image for.
        **kwargs: Additional parameters to configure mermaid chart generation.

    Returns: The image data.

    Raises:
        httpx.HTTPStatusError: If mermaid.ink responds with an error status (via `raise_for_status`).
    """
    import httpx

    code = generate_code(
        graph,
        start_node=kwargs.get('start_node'),
        highlighted_nodes=kwargs.get('highlighted_nodes'),
        highlight_css=kwargs.get('highlight_css', DEFAULT_HIGHLIGHT_CSS),
    )
    # The code is embedded in the URL *path*, so use the URL-safe alphabet: standard base64 can contain
    # '/', which would split the path into extra segments.
    code_base64 = base64.urlsafe_b64encode(code.encode()).decode()

    image_type = kwargs.get('image_type')
    params: dict[str, str | bool] = {}
    if image_type == 'pdf':
        url = f'https://mermaid.ink/pdf/{code_base64}'
        if kwargs.get('pdf_fit'):
            params['fit'] = True
        if kwargs.get('pdf_landscape'):
            params['landscape'] = True
        if pdf_paper := kwargs.get('pdf_paper'):
            params['paper'] = pdf_paper
    elif image_type == 'svg':
        url = f'https://mermaid.ink/svg/{code_base64}'
    else:
        url = f'https://mermaid.ink/img/{code_base64}'
        # the `type` parameter selects the raster format, which only applies to the `/img` endpoint;
        # PDF and SVG are selected by the endpoint path above
        if image_type:
            params['type'] = image_type

    if background_color := kwargs.get('background_color'):
        params['bgColor'] = background_color
    if theme := kwargs.get('theme'):
        params['theme'] = theme
    if width := kwargs.get('width'):
        params['width'] = str(width)
    if height := kwargs.get('height'):
        params['height'] = str(height)
    if scale := kwargs.get('scale'):
        params['scale'] = str(scale)

    response = httpx.get(url, params=params)
    response.raise_for_status()
    return response.content
185
+
186
+
187
def save_image(
    path: Path | str,
    graph: Graph[Any, Any],
    /,
    **kwargs: Unpack[MermaidConfig],
) -> None:
    """Render the graph with [mermaid.ink](https://mermaid.ink) and write the image to a local file.

    Args:
        path: Destination file. When `image_type` is not given, it is taken from the file extension if that
            is `.png`, `.webp`, `.svg` or `.pdf`; any other extension leaves the default (JPEG).
        graph: The graph to generate the image for.
        **kwargs: Additional parameters to configure mermaid chart generation.
    """
    target = Path(path) if isinstance(path, str) else path

    if 'image_type' not in kwargs:
        inferred = target.suffix.lower().removeprefix('.')
        # .jpg/.jpeg need no mapping: JPEG is the default when no image type is requested
        if inferred in {'png', 'webp', 'svg', 'pdf'}:
            kwargs['image_type'] = inferred

    target.write_bytes(request_image(graph, **kwargs))
@@ -0,0 +1,93 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ from abc import abstractmethod
4
+ from dataclasses import dataclass
5
+ from functools import cache
6
+ from typing import Any, Generic, get_origin, get_type_hints
7
+
8
+ from typing_extensions import Never, TypeVar
9
+
10
+ from . import _utils
11
+ from .state import StateT
12
+
13
__all__ = 'GraphContext', 'BaseNode', 'End', 'NodeDef'

# Type of the `data` carried by `End`; defaults to `None`.
RunEndT = TypeVar('RunEndT', default=None)
# Type a node may end the run with (via `End[NodeRunEndT]`); the `Never` default presumably means
# "this node does not end the run" — TODO confirm intent.
NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never)
17
+
18
+
19
@dataclass
class GraphContext(Generic[StateT]):
    """Context handed to a node's `run` method; currently it only carries the state of the graph run."""

    state: StateT  # the state object of the current graph run
24
+
25
+
26
class BaseNode(Generic[StateT, NodeRunEndT]):
    """Base class for a node.

    Subclasses implement `run`. The return type hint of `run` defines the node's outgoing edges: each
    `BaseNode` subclass in the hint is a possible next node, and `End` means the node can end the run.
    """

    # NOTE: `BaseNode` does not derive from `ABC`, so `abstractmethod` here documents intent but is not
    # enforced when a subclass is instantiated.
    @abstractmethod
    async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: ...

    @classmethod
    @cache
    def get_id(cls) -> str:
        """Return the ID of the node: the class name."""
        return cls.__name__

    @classmethod
    def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRunEndT]:
        """Build the `NodeDef` of this node by inspecting the return type hint of `run`.

        Args:
            local_ns: Local namespace used to resolve forward references in the type hints of `run`.

        Raises:
            TypeError: If `run` has no return type hint, or the hint contains a member that is not
                `End`, `BaseNode`, or a `BaseNode` subclass.
        """
        type_hints = get_type_hints(cls.run, localns=local_ns)
        try:
            return_hint = type_hints['return']
        except KeyError:
            raise TypeError(f'Node {cls} is missing a return type hint on its `run` method') from None

        next_node_ids: set[str] = set()
        returns_end: bool = False
        returns_base_node: bool = False
        for return_type in _utils.get_union_args(return_hint):
            # parametrized hints such as `End[int]` or `MyNode[State]` are compared via their origin
            return_type_origin = get_origin(return_type) or return_type
            if return_type_origin is End:
                returns_end = True
            elif return_type_origin is BaseNode:
                # TODO: Should we disallow this?
                returns_base_node = True
            elif isinstance(return_type_origin, type) and issubclass(return_type_origin, BaseNode):
                next_node_ids.add(return_type_origin.get_id())
            else:
                # also reached for non-class hints (e.g. `Literal[...]`), where a bare `issubclass` call
                # would raise an unhelpful "arg 1 must be a class" error instead
                raise TypeError(f'Invalid return type: {return_type}')

        return NodeDef(
            cls,
            cls.get_id(),
            next_node_ids,
            returns_end,
            returns_base_node,
        )
67
+
68
+
69
@dataclass
class End(Generic[RunEndT]):
    """Returned from a node's `run` method to signal that the graph run is finished."""

    data: RunEndT  # the result value the run ends with
75
+
76
@dataclass
class NodeDef(Generic[StateT, NodeRunEndT]):
    """Definition of a node.

    Used by [`Graph`][pydantic_graph.graph.Graph] to store information about a node, and when generating
    mermaid graphs. Instances are built by `BaseNode.get_node_def`.
    """

    node: type[BaseNode[StateT, NodeRunEndT]]
    """The node definition itself."""
    node_id: str
    """ID of the node."""
    next_node_ids: set[str]
    """IDs of the nodes that can be called next."""
    returns_end: bool
    """The node definition returns an `End`, hence the node can end the run."""
    returns_base_node: bool
    """The node definition returns `BaseNode` itself, hence any node in the graph can be called next."""
File without changes
@@ -0,0 +1,82 @@
1
from __future__ import annotations as _annotations

import copy
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Generic, Literal, Union

# `Self` comes from typing_extensions: `typing.Self` only exists on Python >= 3.11, and this package supports 3.9+.
from typing_extensions import Never, Self, TypeVar

from . import _utils
12
+
13
__all__ = 'AbstractState', 'StateT', 'NodeEvent', 'EndEvent', 'HistoryStep'

# Only needed for annotations. A runtime import would be circular, because nodes.py imports `StateT` from
# this module.
if TYPE_CHECKING:
    from pydantic_graph import BaseNode
    from pydantic_graph.nodes import End
18
+
19
+
20
class AbstractState(ABC):
    """Base class for user-defined state objects passed through a graph run.

    Subclasses must implement `serialize`; `deep_copy` falls back to `copy.deepcopy`.
    """

    @abstractmethod
    def serialize(self) -> bytes | None:
        """Serialize the state object."""
        raise NotImplementedError

    def deep_copy(self) -> Self:
        """Return an independent deep copy of this state object (via `copy.deepcopy`)."""
        duplicate = copy.deepcopy(self)
        return duplicate
31
+
32
+
33
# Type of the `End.data` result of a run; defaults to `None`.
RunEndT = TypeVar('RunEndT', default=None)
# NOTE(review): not referenced in this module — presumably kept for parity with nodes.py; confirm it is needed.
NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never)
# Type of the graph state: `None` (the default, i.e. no state) or an `AbstractState` subclass.
StateT = TypeVar('StateT', bound=Union[None, AbstractState], default=None)
36
+
37
+
38
@dataclass
class NodeEvent(Generic[StateT, RunEndT]):
    """History step recording the execution of one node: the node, a state snapshot, and timing."""

    state: StateT  # copied in `__post_init__`
    node: BaseNode[StateT, RunEndT]
    start_ts: datetime = field(default_factory=_utils.now_utc)  # defaults to the creation time (UTC)
    duration: float | None = None  # seconds; `Graph.next` sets this once the node has run

    kind: Literal['step'] = 'step'

    def __post_init__(self):
        # keep a private snapshot so later mutation of the live state doesn't rewrite history
        snapshot = _deep_copy_state(self.state)
        self.state = snapshot

    def summary(self) -> str:
        """Return a short, human-readable description of this step."""
        executed = self.node
        return str(executed)
55
+
56
+
57
@dataclass
class EndEvent(Generic[StateT, RunEndT]):
    """History step recording the end of a graph run: the `End` result and a state snapshot."""

    state: StateT  # copied in `__post_init__`
    result: End[RunEndT]
    ts: datetime = field(default_factory=_utils.now_utc)  # defaults to the creation time (UTC)

    kind: Literal['end'] = 'end'

    def __post_init__(self):
        # keep a private snapshot so later mutation of the live state doesn't rewrite history
        snapshot = _deep_copy_state(self.state)
        self.state = snapshot

    def summary(self) -> str:
        """Return a short, human-readable description of how the run ended."""
        outcome = self.result
        return str(outcome)
73
+
74
+
75
def _deep_copy_state(state: StateT) -> StateT:
    """Return `state` unchanged if it is `None`, otherwise the result of its `deep_copy()` method."""
    return state if state is None else state.deep_copy()
80
+
81
+
82
# One entry in the history of a graph run: a `NodeEvent` for each executed node, or the final `EndEvent`.
HistoryStep = Union[NodeEvent[StateT, RunEndT], EndEvent[StateT, RunEndT]]
@@ -0,0 +1,43 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "pydantic-graph"
7
+ version = "0.0.1"
8
+ description = "Graph and state machine library"
9
+ authors = [
10
+ { name = "Samuel Colvin", email = "samuel@pydantic.dev" },
11
+ ]
12
+ license = "MIT"
13
+ readme = "README.md"
14
+ classifiers = [
15
+ "Development Status :: 4 - Beta",
16
+ "Programming Language :: Python",
17
+ "Programming Language :: Python :: 3",
18
+ "Programming Language :: Python :: 3 :: Only",
19
+ "Programming Language :: Python :: 3.9",
20
+ "Programming Language :: Python :: 3.10",
21
+ "Programming Language :: Python :: 3.11",
22
+ "Programming Language :: Python :: 3.12",
23
+ "Programming Language :: Python :: 3.13",
24
+ "Intended Audience :: Developers",
25
+ "Intended Audience :: Information Technology",
26
+ "Intended Audience :: System Administrators",
27
+ "License :: OSI Approved :: MIT License",
28
+ "Operating System :: Unix",
29
+ "Operating System :: POSIX :: Linux",
30
+ "Environment :: Console",
31
+ "Environment :: MacOS X",
32
+ "Topic :: Software Development :: Libraries :: Python Modules",
33
+ "Topic :: Internet",
34
+ ]
35
+ requires-python = ">=3.9"
36
+ dependencies = [
37
+ "httpx>=0.27.2",
38
+ "logfire-api>=1.2.0",
39
+ "pydantic>=2.10",
40
+ ]
41
+
42
+ [tool.hatch.build.targets.wheel]
43
+ packages = ["pydantic_graph"]