pydantic-graph 0.2.2__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_graph/_utils.py +73 -6
- pydantic_graph/beta/__init__.py +25 -0
- pydantic_graph/beta/decision.py +276 -0
- pydantic_graph/beta/graph.py +978 -0
- pydantic_graph/beta/graph_builder.py +1053 -0
- pydantic_graph/beta/id_types.py +76 -0
- pydantic_graph/beta/join.py +249 -0
- pydantic_graph/beta/mermaid.py +208 -0
- pydantic_graph/beta/node.py +95 -0
- pydantic_graph/beta/node_types.py +90 -0
- pydantic_graph/beta/parent_forks.py +232 -0
- pydantic_graph/beta/paths.py +421 -0
- pydantic_graph/beta/step.py +253 -0
- pydantic_graph/beta/util.py +90 -0
- pydantic_graph/exceptions.py +22 -0
- pydantic_graph/graph.py +27 -71
- pydantic_graph/mermaid.py +2 -2
- pydantic_graph/nodes.py +13 -16
- pydantic_graph/persistence/__init__.py +4 -4
- pydantic_graph/persistence/_utils.py +1 -1
- pydantic_graph/persistence/file.py +13 -14
- pydantic_graph/persistence/in_mem.py +4 -4
- {pydantic_graph-0.2.2.dist-info → pydantic_graph-1.24.0.dist-info}/METADATA +10 -10
- pydantic_graph-1.24.0.dist-info/RECORD +28 -0
- pydantic_graph-1.24.0.dist-info/licenses/LICENSE +21 -0
- pydantic_graph-0.2.2.dist-info/RECORD +0 -14
- {pydantic_graph-0.2.2.dist-info → pydantic_graph-1.24.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Utility types and functions for type manipulation and introspection.
|
|
2
|
+
|
|
3
|
+
This module provides helper classes and functions for working with Python's type system,
|
|
4
|
+
including workarounds for type checker limitations and utilities for runtime type inspection.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Any, Generic, cast, get_args, get_origin
|
|
9
|
+
|
|
10
|
+
from typing_extensions import TypeAliasType, TypeVar
|
|
11
|
+
|
|
12
|
+
T = TypeVar('T', infer_variance=True)
|
|
13
|
+
"""Generic type variable with inferred variance."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TypeExpression(Generic[T]):
|
|
17
|
+
"""A workaround for type checker limitations when using complex type expressions.
|
|
18
|
+
|
|
19
|
+
This class serves as a wrapper for types that cannot normally be used in positions
|
|
20
|
+
requiring `type[T]`, such as `Any`, `Union[...]`, or `Literal[...]`. It provides a
|
|
21
|
+
way to pass these complex type expressions to functions expecting concrete types.
|
|
22
|
+
|
|
23
|
+
Example:
|
|
24
|
+
Instead of `output_type=Union[str, int]` (which may cause type errors),
|
|
25
|
+
use `output_type=TypeExpression[Union[str, int]]`.
|
|
26
|
+
|
|
27
|
+
Note:
|
|
28
|
+
This is a workaround for the lack of TypeForm in the Python type system.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
pass
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
TypeOrTypeExpression = TypeAliasType('TypeOrTypeExpression', type[TypeExpression[T]] | type[T], type_params=(T,))
|
|
35
|
+
"""Type alias allowing both direct types and TypeExpression wrappers.
|
|
36
|
+
|
|
37
|
+
This alias enables functions to accept either regular types (when compatible with type checkers)
|
|
38
|
+
or TypeExpression wrappers for complex type expressions. The correct type should be inferred
|
|
39
|
+
automatically in either case.
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def unpack_type_expression(type_: TypeOrTypeExpression[T]) -> type[T]:
|
|
44
|
+
"""Extract the actual type from a TypeExpression wrapper or return the type directly.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
type_: Either a direct type or a TypeExpression wrapper.
|
|
48
|
+
|
|
49
|
+
Returns:
|
|
50
|
+
The unwrapped type, ready for use in runtime type operations.
|
|
51
|
+
"""
|
|
52
|
+
if get_origin(type_) is TypeExpression:
|
|
53
|
+
return get_args(type_)[0]
|
|
54
|
+
return cast(type[T], type_)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
|
|
58
|
+
class Some(Generic[T]):
|
|
59
|
+
"""Container for explicitly present values in Maybe type pattern.
|
|
60
|
+
|
|
61
|
+
This class represents a value that is definitely present, as opposed to None.
|
|
62
|
+
It's part of the Maybe pattern, similar to Option/Maybe in functional programming,
|
|
63
|
+
allowing distinction between "no value" (None) and "value is None" (Some(None)).
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
value: T
|
|
67
|
+
"""The wrapped value."""
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
Maybe = TypeAliasType('Maybe', Some[T] | None, type_params=(T,))
|
|
71
|
+
"""Optional-like type that distinguishes between absence and None values.
|
|
72
|
+
|
|
73
|
+
Unlike Optional[T], Maybe[T] can differentiate between:
|
|
74
|
+
- No value present: represented as None
|
|
75
|
+
- Value is None: represented as Some(None)
|
|
76
|
+
|
|
77
|
+
This is particularly useful when None is a valid value in your domain.
|
|
78
|
+
"""
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def get_callable_name(callable_: Any) -> str:
|
|
82
|
+
"""Extract a human-readable name from a callable object.
|
|
83
|
+
|
|
84
|
+
Args:
|
|
85
|
+
callable_: Any callable object (function, method, class, etc.).
|
|
86
|
+
|
|
87
|
+
Returns:
|
|
88
|
+
The callable's __name__ attribute if available, otherwise its string representation.
|
|
89
|
+
"""
|
|
90
|
+
return getattr(callable_, '__name__', str(callable_))
|
pydantic_graph/exceptions.py
CHANGED
|
@@ -15,6 +15,28 @@ class GraphSetupError(TypeError):
|
|
|
15
15
|
super().__init__(message)
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
class GraphBuildingError(ValueError):
|
|
19
|
+
"""An error raised during graph-building."""
|
|
20
|
+
|
|
21
|
+
message: str
|
|
22
|
+
"""The error message."""
|
|
23
|
+
|
|
24
|
+
def __init__(self, message: str):
|
|
25
|
+
self.message = message
|
|
26
|
+
super().__init__(message)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class GraphValidationError(ValueError):
|
|
30
|
+
"""An error raised during graph validation."""
|
|
31
|
+
|
|
32
|
+
message: str
|
|
33
|
+
"""The error message."""
|
|
34
|
+
|
|
35
|
+
def __init__(self, message: str):
|
|
36
|
+
self.message = message
|
|
37
|
+
super().__init__(message)
|
|
38
|
+
|
|
39
|
+
|
|
18
40
|
class GraphRuntimeError(RuntimeError):
|
|
19
41
|
"""Error caused by an issue during graph execution."""
|
|
20
42
|
|
pydantic_graph/graph.py
CHANGED
|
@@ -6,35 +6,20 @@ from collections.abc import AsyncIterator, Sequence
|
|
|
6
6
|
from contextlib import AbstractContextManager, ExitStack, asynccontextmanager
|
|
7
7
|
from dataclasses import dataclass, field
|
|
8
8
|
from functools import cached_property
|
|
9
|
+
from pathlib import Path
|
|
9
10
|
from typing import Any, Generic, cast, overload
|
|
10
11
|
|
|
11
|
-
import logfire_api
|
|
12
12
|
import typing_extensions
|
|
13
|
-
from opentelemetry.trace import Span
|
|
14
|
-
from typing_extensions import deprecated
|
|
15
13
|
from typing_inspection import typing_objects
|
|
16
14
|
|
|
17
15
|
from . import _utils, exceptions, mermaid
|
|
18
|
-
from ._utils import AbstractSpan, get_traceparent
|
|
16
|
+
from ._utils import AbstractSpan, get_traceparent, logfire_span
|
|
19
17
|
from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT, StateT
|
|
20
18
|
from .persistence import BaseStatePersistence
|
|
21
19
|
from .persistence.in_mem import SimpleStatePersistence
|
|
22
20
|
|
|
23
|
-
# while waiting for https://github.com/pydantic/logfire/issues/745
|
|
24
|
-
try:
|
|
25
|
-
import logfire._internal.stack_info
|
|
26
|
-
except ImportError:
|
|
27
|
-
pass
|
|
28
|
-
else:
|
|
29
|
-
from pathlib import Path
|
|
30
|
-
|
|
31
|
-
logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),) # pyright: ignore[reportPrivateImportUsage]
|
|
32
|
-
|
|
33
|
-
|
|
34
21
|
__all__ = 'Graph', 'GraphRun', 'GraphRunResult'
|
|
35
22
|
|
|
36
|
-
_logfire = logfire_api.Logfire(otel_scope='pydantic-graph')
|
|
37
|
-
|
|
38
23
|
|
|
39
24
|
@dataclass(init=False)
|
|
40
25
|
class Graph(Generic[StateT, DepsT, RunEndT]):
|
|
@@ -46,7 +31,7 @@ class Graph(Generic[StateT, DepsT, RunEndT]):
|
|
|
46
31
|
Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never
|
|
47
32
|
42 at the end.
|
|
48
33
|
|
|
49
|
-
```py {title="never_42.py" noqa="I001"
|
|
34
|
+
```py {title="never_42.py" noqa="I001"}
|
|
50
35
|
from __future__ import annotations
|
|
51
36
|
|
|
52
37
|
from dataclasses import dataclass
|
|
@@ -143,7 +128,7 @@ class Graph(Generic[StateT, DepsT, RunEndT]):
|
|
|
143
128
|
|
|
144
129
|
Here's an example of running the graph from [above][pydantic_graph.graph.Graph]:
|
|
145
130
|
|
|
146
|
-
```py {title="run_never_42.py" noqa="I001"
|
|
131
|
+
```py {title="run_never_42.py" noqa="I001" requires="never_42.py"}
|
|
147
132
|
from never_42 import Increment, MyState, never_42_graph
|
|
148
133
|
|
|
149
134
|
async def main():
|
|
@@ -197,7 +182,7 @@ class Graph(Generic[StateT, DepsT, RunEndT]):
|
|
|
197
182
|
Returns:
|
|
198
183
|
The result type from ending the run and the history of the run.
|
|
199
184
|
"""
|
|
200
|
-
if infer_name and self.name is None:
|
|
185
|
+
if infer_name and self.name is None: # pragma: no branch
|
|
201
186
|
self._infer_name(inspect.currentframe())
|
|
202
187
|
|
|
203
188
|
return _utils.get_event_loop().run_until_complete(
|
|
@@ -212,7 +197,7 @@ class Graph(Generic[StateT, DepsT, RunEndT]):
|
|
|
212
197
|
state: StateT = None,
|
|
213
198
|
deps: DepsT = None,
|
|
214
199
|
persistence: BaseStatePersistence[StateT, RunEndT] | None = None,
|
|
215
|
-
span: AbstractContextManager[
|
|
200
|
+
span: AbstractContextManager[AbstractSpan] | None = None,
|
|
216
201
|
infer_name: bool = True,
|
|
217
202
|
) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]:
|
|
218
203
|
"""A contextmanager which can be used to iterate over the graph's nodes as they are executed.
|
|
@@ -253,8 +238,12 @@ class Graph(Generic[StateT, DepsT, RunEndT]):
|
|
|
253
238
|
with ExitStack() as stack:
|
|
254
239
|
entered_span: AbstractSpan | None = None
|
|
255
240
|
if span is None:
|
|
256
|
-
if self.auto_instrument:
|
|
257
|
-
|
|
241
|
+
if self.auto_instrument: # pragma: no branch
|
|
242
|
+
# Separate variable because we actually don't want logfire's f-string magic here,
|
|
243
|
+
# we want the span_name to be preformatted for other backends
|
|
244
|
+
# as requested in https://github.com/pydantic/pydantic-ai/issues/3173.
|
|
245
|
+
span_name = f'run graph {self.name}'
|
|
246
|
+
entered_span = stack.enter_context(logfire_span(span_name, graph=self))
|
|
258
247
|
else:
|
|
259
248
|
entered_span = stack.enter_context(span)
|
|
260
249
|
traceparent = None if entered_span is None else get_traceparent(entered_span)
|
|
@@ -302,8 +291,8 @@ class Graph(Generic[StateT, DepsT, RunEndT]):
|
|
|
302
291
|
|
|
303
292
|
snapshot.node.set_snapshot_id(snapshot.id)
|
|
304
293
|
|
|
305
|
-
if self.auto_instrument and span is None:
|
|
306
|
-
span =
|
|
294
|
+
if self.auto_instrument and span is None: # pragma: no branch
|
|
295
|
+
span = logfire_span('run graph {graph.name}', graph=self)
|
|
307
296
|
|
|
308
297
|
with ExitStack() as stack:
|
|
309
298
|
entered_span = None if span is None else stack.enter_context(span)
|
|
@@ -343,43 +332,6 @@ class Graph(Generic[StateT, DepsT, RunEndT]):
|
|
|
343
332
|
persistence.set_graph_types(self)
|
|
344
333
|
await persistence.snapshot_node(state, node)
|
|
345
334
|
|
|
346
|
-
@deprecated('`next` is deprecated, use `async with graph.iter(...) as run: run.next()` instead')
|
|
347
|
-
async def next(
|
|
348
|
-
self,
|
|
349
|
-
node: BaseNode[StateT, DepsT, RunEndT],
|
|
350
|
-
persistence: BaseStatePersistence[StateT, RunEndT],
|
|
351
|
-
*,
|
|
352
|
-
state: StateT = None,
|
|
353
|
-
deps: DepsT = None,
|
|
354
|
-
infer_name: bool = True,
|
|
355
|
-
) -> BaseNode[StateT, DepsT, Any] | End[RunEndT]:
|
|
356
|
-
"""Run a node in the graph and return the next node to run.
|
|
357
|
-
|
|
358
|
-
Args:
|
|
359
|
-
node: The node to run.
|
|
360
|
-
persistence: State persistence interface, defaults to
|
|
361
|
-
[`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`.
|
|
362
|
-
state: The current state of the graph.
|
|
363
|
-
deps: The dependencies of the graph.
|
|
364
|
-
infer_name: Whether to infer the graph name from the calling frame.
|
|
365
|
-
|
|
366
|
-
Returns:
|
|
367
|
-
The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished.
|
|
368
|
-
"""
|
|
369
|
-
if infer_name and self.name is None:
|
|
370
|
-
self._infer_name(inspect.currentframe())
|
|
371
|
-
|
|
372
|
-
persistence.set_graph_types(self)
|
|
373
|
-
run = GraphRun[StateT, DepsT, RunEndT](
|
|
374
|
-
graph=self,
|
|
375
|
-
start_node=node,
|
|
376
|
-
persistence=persistence,
|
|
377
|
-
state=state,
|
|
378
|
-
deps=deps,
|
|
379
|
-
traceparent=None,
|
|
380
|
-
)
|
|
381
|
-
return await run.next(node)
|
|
382
|
-
|
|
383
335
|
def mermaid_code(
|
|
384
336
|
self,
|
|
385
337
|
*,
|
|
@@ -411,7 +363,7 @@ class Graph(Generic[StateT, DepsT, RunEndT]):
|
|
|
411
363
|
|
|
412
364
|
Here's an example of generating a diagram for the graph from [above][pydantic_graph.graph.Graph]:
|
|
413
365
|
|
|
414
|
-
```py {title="mermaid_never_42.py"
|
|
366
|
+
```py {title="mermaid_never_42.py" requires="never_42.py"}
|
|
415
367
|
from never_42 import Increment, never_42_graph
|
|
416
368
|
|
|
417
369
|
print(never_42_graph.mermaid_code(start_node=Increment))
|
|
@@ -532,7 +484,7 @@ class Graph(Generic[StateT, DepsT, RunEndT]):
|
|
|
532
484
|
# break the inner (bases) loop
|
|
533
485
|
break
|
|
534
486
|
|
|
535
|
-
if not _utils.is_set(state_type):
|
|
487
|
+
if not _utils.is_set(state_type): # pragma: no branch
|
|
536
488
|
# state defaults to None, so use that if we can't infer it
|
|
537
489
|
state_type = None
|
|
538
490
|
if not _utils.is_set(run_end_type):
|
|
@@ -585,9 +537,9 @@ class Graph(Generic[StateT, DepsT, RunEndT]):
|
|
|
585
537
|
if item is self:
|
|
586
538
|
self.name = name
|
|
587
539
|
return
|
|
588
|
-
if parent_frame.f_locals != parent_frame.f_globals:
|
|
540
|
+
if parent_frame.f_locals != parent_frame.f_globals: # pragma: no branch
|
|
589
541
|
# if we couldn't find the agent in locals and globals are a different dict, try globals
|
|
590
|
-
for name, item in parent_frame.f_globals.items():
|
|
542
|
+
for name, item in parent_frame.f_globals.items(): # pragma: no branch
|
|
591
543
|
if item is self:
|
|
592
544
|
self.name = name
|
|
593
545
|
return
|
|
@@ -601,7 +553,7 @@ class GraphRun(Generic[StateT, DepsT, RunEndT]):
|
|
|
601
553
|
through nodes as they run, either by `async for` iteration or by repeatedly calling `.next(...)`.
|
|
602
554
|
|
|
603
555
|
Here's an example of iterating over the graph from [above][pydantic_graph.graph.Graph]:
|
|
604
|
-
```py {title="iter_never_42.py" noqa="I001"
|
|
556
|
+
```py {title="iter_never_42.py" noqa="I001" requires="never_42.py"}
|
|
605
557
|
from copy import deepcopy
|
|
606
558
|
from never_42 import Increment, MyState, never_42_graph
|
|
607
559
|
|
|
@@ -717,7 +669,7 @@ class GraphRun(Generic[StateT, DepsT, RunEndT]):
|
|
|
717
669
|
under dynamic conditions. The graph run should stop when you return an [`End`][pydantic_graph.nodes.End] node.
|
|
718
670
|
|
|
719
671
|
Here's an example of using `next` to drive the graph from [above][pydantic_graph.graph.Graph]:
|
|
720
|
-
```py {title="next_never_42.py" noqa="I001"
|
|
672
|
+
```py {title="next_never_42.py" noqa="I001" requires="never_42.py"}
|
|
721
673
|
from copy import deepcopy
|
|
722
674
|
from pydantic_graph import End
|
|
723
675
|
from never_42 import Increment, MyState, never_42_graph
|
|
@@ -775,11 +727,15 @@ class GraphRun(Generic[StateT, DepsT, RunEndT]):
|
|
|
775
727
|
raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.')
|
|
776
728
|
|
|
777
729
|
with ExitStack() as stack:
|
|
778
|
-
if self.graph.auto_instrument:
|
|
779
|
-
|
|
730
|
+
if self.graph.auto_instrument: # pragma: no branch
|
|
731
|
+
# Separate variable because we actually don't want logfire's f-string magic here,
|
|
732
|
+
# we want the span_name to be preformatted for other backends
|
|
733
|
+
# as requested in https://github.com/pydantic/pydantic-ai/issues/3173.
|
|
734
|
+
span_name = f'run node {node_id}'
|
|
735
|
+
stack.enter_context(logfire_span(span_name, node_id=node_id, node=node))
|
|
780
736
|
|
|
781
737
|
async with self.persistence.record_run(node_snapshot_id):
|
|
782
|
-
ctx = GraphRunContext(self.state, self.deps)
|
|
738
|
+
ctx = GraphRunContext(state=self.state, deps=self.deps)
|
|
783
739
|
self._next_node = await node.run(ctx)
|
|
784
740
|
|
|
785
741
|
if isinstance(self._next_node, End):
|
pydantic_graph/mermaid.py
CHANGED
|
@@ -5,11 +5,11 @@ import re
|
|
|
5
5
|
from collections.abc import Iterable, Sequence
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
from textwrap import indent
|
|
8
|
-
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
|
8
|
+
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias
|
|
9
9
|
|
|
10
10
|
import httpx
|
|
11
11
|
from annotated_types import Ge, Le
|
|
12
|
-
from typing_extensions import
|
|
12
|
+
from typing_extensions import TypedDict, Unpack
|
|
13
13
|
|
|
14
14
|
from .nodes import BaseNode
|
|
15
15
|
|
pydantic_graph/nodes.py
CHANGED
|
@@ -4,10 +4,10 @@ import copy
|
|
|
4
4
|
from abc import ABC, abstractmethod
|
|
5
5
|
from dataclasses import dataclass, is_dataclass
|
|
6
6
|
from functools import cache
|
|
7
|
-
from typing import Any, ClassVar, Generic, get_type_hints
|
|
7
|
+
from typing import Any, ClassVar, Generic, get_origin, get_type_hints
|
|
8
8
|
from uuid import uuid4
|
|
9
9
|
|
|
10
|
-
from typing_extensions import Never, Self, TypeVar
|
|
10
|
+
from typing_extensions import Never, Self, TypeVar
|
|
11
11
|
|
|
12
12
|
from . import _utils, exceptions
|
|
13
13
|
|
|
@@ -24,12 +24,10 @@ DepsT = TypeVar('DepsT', default=None, contravariant=True)
|
|
|
24
24
|
"""Type variable for the dependencies of a graph and node."""
|
|
25
25
|
|
|
26
26
|
|
|
27
|
-
@dataclass
|
|
27
|
+
@dataclass(kw_only=True)
|
|
28
28
|
class GraphRunContext(Generic[StateT, DepsT]):
|
|
29
29
|
"""Context for a graph."""
|
|
30
30
|
|
|
31
|
-
# TODO: Can we get rid of this struct and just pass both these things around..?
|
|
32
|
-
|
|
33
31
|
state: StateT
|
|
34
32
|
"""The state of the graph."""
|
|
35
33
|
deps: DepsT
|
|
@@ -94,8 +92,8 @@ class BaseNode(ABC, Generic[StateT, DepsT, NodeRunEndT]):
|
|
|
94
92
|
docstring = cls.__doc__
|
|
95
93
|
# dataclasses get an automatic docstring which is just their signature, we don't want that
|
|
96
94
|
if docstring and is_dataclass(cls) and docstring.startswith(f'{cls.__name__}('):
|
|
97
|
-
docstring = None
|
|
98
|
-
if docstring:
|
|
95
|
+
docstring = None # pragma: no cover
|
|
96
|
+
if docstring: # pragma: no branch
|
|
99
97
|
# remove indentation from docstring
|
|
100
98
|
import inspect
|
|
101
99
|
|
|
@@ -121,7 +119,6 @@ class BaseNode(ABC, Generic[StateT, DepsT, NodeRunEndT]):
|
|
|
121
119
|
if return_type_origin is End:
|
|
122
120
|
end_edge = edge
|
|
123
121
|
elif return_type_origin is BaseNode:
|
|
124
|
-
# TODO: Should we disallow this?
|
|
125
122
|
returns_base_node = True
|
|
126
123
|
elif issubclass(return_type_origin, BaseNode):
|
|
127
124
|
next_node_edges[return_type.get_node_id()] = edge
|
|
@@ -129,12 +126,12 @@ class BaseNode(ABC, Generic[StateT, DepsT, NodeRunEndT]):
|
|
|
129
126
|
raise exceptions.GraphSetupError(f'Invalid return type: {return_type}')
|
|
130
127
|
|
|
131
128
|
return NodeDef(
|
|
132
|
-
cls,
|
|
133
|
-
cls.get_node_id(),
|
|
134
|
-
cls.get_note(),
|
|
135
|
-
next_node_edges,
|
|
136
|
-
end_edge,
|
|
137
|
-
returns_base_node,
|
|
129
|
+
node=cls,
|
|
130
|
+
node_id=cls.get_node_id(),
|
|
131
|
+
note=cls.get_note(),
|
|
132
|
+
next_node_edges=next_node_edges,
|
|
133
|
+
end_edge=end_edge,
|
|
134
|
+
returns_base_node=returns_base_node,
|
|
138
135
|
)
|
|
139
136
|
|
|
140
137
|
def deep_copy(self) -> Self:
|
|
@@ -174,7 +171,7 @@ def generate_snapshot_id(node_id: str) -> str:
|
|
|
174
171
|
return f'{node_id}:{uuid4().hex}'
|
|
175
172
|
|
|
176
173
|
|
|
177
|
-
@dataclass
|
|
174
|
+
@dataclass(frozen=True)
|
|
178
175
|
class Edge:
|
|
179
176
|
"""Annotation to apply a label to an edge in a graph."""
|
|
180
177
|
|
|
@@ -182,7 +179,7 @@ class Edge:
|
|
|
182
179
|
"""Label for the edge."""
|
|
183
180
|
|
|
184
181
|
|
|
185
|
-
@dataclass
|
|
182
|
+
@dataclass(kw_only=True)
|
|
186
183
|
class NodeDef(Generic[StateT, DepsT, NodeRunEndT]):
|
|
187
184
|
"""Definition of a node.
|
|
188
185
|
|
|
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
|
|
4
4
|
from contextlib import AbstractAsyncContextManager
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
6
|
from datetime import datetime
|
|
7
|
-
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal
|
|
7
|
+
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal
|
|
8
8
|
|
|
9
9
|
import pydantic
|
|
10
10
|
from typing_extensions import TypeVar
|
|
@@ -41,7 +41,7 @@ SnapshotStatus = Literal['created', 'pending', 'running', 'success', 'error']
|
|
|
41
41
|
"""
|
|
42
42
|
|
|
43
43
|
|
|
44
|
-
@dataclass
|
|
44
|
+
@dataclass(kw_only=True)
|
|
45
45
|
class NodeSnapshot(Generic[StateT, RunEndT]):
|
|
46
46
|
"""History step describing the execution of a node in a graph."""
|
|
47
47
|
|
|
@@ -66,7 +66,7 @@ class NodeSnapshot(Generic[StateT, RunEndT]):
|
|
|
66
66
|
self.id = self.node.get_snapshot_id()
|
|
67
67
|
|
|
68
68
|
|
|
69
|
-
@dataclass
|
|
69
|
+
@dataclass(kw_only=True)
|
|
70
70
|
class EndSnapshot(Generic[StateT, RunEndT]):
|
|
71
71
|
"""History step describing the end of a graph run."""
|
|
72
72
|
|
|
@@ -95,7 +95,7 @@ class EndSnapshot(Generic[StateT, RunEndT]):
|
|
|
95
95
|
return self.result
|
|
96
96
|
|
|
97
97
|
|
|
98
|
-
Snapshot =
|
|
98
|
+
Snapshot = NodeSnapshot[StateT, RunEndT] | EndSnapshot[StateT, RunEndT]
|
|
99
99
|
"""A step in the history of a graph run.
|
|
100
100
|
|
|
101
101
|
[`Graph.run`][pydantic_graph.graph.Graph.run] returns a list of these steps describing the execution of the graph,
|
|
@@ -31,7 +31,7 @@ class CustomNodeSchema:
|
|
|
31
31
|
nodes_type = nodes[0]
|
|
32
32
|
else:
|
|
33
33
|
nodes_annotated = [Annotated[node, pydantic.Tag(node.get_node_id())] for node in nodes]
|
|
34
|
-
nodes_type = Annotated[Union[tuple(nodes_annotated)], pydantic.Discriminator(self._node_discriminator)]
|
|
34
|
+
nodes_type = Annotated[Union[tuple(nodes_annotated)], pydantic.Discriminator(self._node_discriminator)] # noqa: UP007
|
|
35
35
|
|
|
36
36
|
schema = handler(nodes_type)
|
|
37
37
|
schema['serialization'] = core_schema.wrap_serializer_function_ser_schema(
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
-
import asyncio
|
|
4
3
|
import secrets
|
|
5
4
|
from collections.abc import AsyncIterator
|
|
6
5
|
from contextlib import AsyncExitStack, asynccontextmanager
|
|
@@ -9,6 +8,7 @@ from pathlib import Path
|
|
|
9
8
|
from time import perf_counter
|
|
10
9
|
from typing import Any
|
|
11
10
|
|
|
11
|
+
import anyio
|
|
12
12
|
import pydantic
|
|
13
13
|
|
|
14
14
|
from .. import _utils as _graph_utils, exceptions
|
|
@@ -59,7 +59,7 @@ class FileStatePersistence(BaseStatePersistence[StateT, RunEndT]):
|
|
|
59
59
|
) -> None:
|
|
60
60
|
async with self._lock():
|
|
61
61
|
snapshots = await self.load_all()
|
|
62
|
-
if not any(s.id == snapshot_id for s in snapshots):
|
|
62
|
+
if not any(s.id == snapshot_id for s in snapshots): # pragma: no branch
|
|
63
63
|
await self._append_save(NodeSnapshot(state=state, node=next_node), lock=False)
|
|
64
64
|
|
|
65
65
|
async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
|
|
@@ -155,24 +155,23 @@ class FileStatePersistence(BaseStatePersistence[StateT, RunEndT]):
|
|
|
155
155
|
"""
|
|
156
156
|
lock_file = self.json_file.parent / f'{self.json_file.name}.pydantic-graph-persistence-lock'
|
|
157
157
|
lock_id = secrets.token_urlsafe().encode()
|
|
158
|
-
|
|
158
|
+
|
|
159
|
+
with anyio.fail_after(timeout):
|
|
160
|
+
while not await _file_append_check(lock_file, lock_id):
|
|
161
|
+
await anyio.sleep(0.01)
|
|
162
|
+
|
|
159
163
|
try:
|
|
160
164
|
yield
|
|
161
165
|
finally:
|
|
162
166
|
await _graph_utils.run_in_executor(lock_file.unlink, missing_ok=True)
|
|
163
167
|
|
|
164
168
|
|
|
165
|
-
async def
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
await asyncio.sleep(0.01)
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
def _file_append_check(file: Path, content: bytes) -> bool:
|
|
172
|
-
if file.exists():
|
|
169
|
+
async def _file_append_check(file: Path, content: bytes) -> bool:
|
|
170
|
+
path = anyio.Path(file)
|
|
171
|
+
if await path.exists():
|
|
173
172
|
return False
|
|
174
173
|
|
|
175
|
-
with
|
|
176
|
-
f.write(content + b'\n')
|
|
174
|
+
async with await anyio.open_file(path, mode='ab') as f:
|
|
175
|
+
await f.write(content + b'\n')
|
|
177
176
|
|
|
178
|
-
return
|
|
177
|
+
return (await path.read_bytes()).startswith(content)
|
|
@@ -45,7 +45,7 @@ class SimpleStatePersistence(BaseStatePersistence[StateT, RunEndT]):
|
|
|
45
45
|
self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
|
|
46
46
|
) -> None:
|
|
47
47
|
if self.last_snapshot and self.last_snapshot.id == snapshot_id:
|
|
48
|
-
return
|
|
48
|
+
return # pragma: no cover
|
|
49
49
|
else:
|
|
50
50
|
await self.snapshot_node(state, next_node)
|
|
51
51
|
|
|
@@ -65,7 +65,7 @@ class SimpleStatePersistence(BaseStatePersistence[StateT, RunEndT]):
|
|
|
65
65
|
start = perf_counter()
|
|
66
66
|
try:
|
|
67
67
|
yield
|
|
68
|
-
except Exception:
|
|
68
|
+
except Exception: # pragma: no cover
|
|
69
69
|
self.last_snapshot.duration = perf_counter() - start
|
|
70
70
|
self.last_snapshot.status = 'error'
|
|
71
71
|
raise
|
|
@@ -76,7 +76,7 @@ class SimpleStatePersistence(BaseStatePersistence[StateT, RunEndT]):
|
|
|
76
76
|
async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
|
|
77
77
|
if isinstance(self.last_snapshot, NodeSnapshot) and self.last_snapshot.status == 'created':
|
|
78
78
|
self.last_snapshot.status = 'pending'
|
|
79
|
-
return self.last_snapshot
|
|
79
|
+
return copy.deepcopy(self.last_snapshot)
|
|
80
80
|
|
|
81
81
|
async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
|
|
82
82
|
raise NotImplementedError('load is not supported for SimpleStatePersistence')
|
|
@@ -143,7 +143,7 @@ class FullStatePersistence(BaseStatePersistence[StateT, RunEndT]):
|
|
|
143
143
|
async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
|
|
144
144
|
if snapshot := next((s for s in self.history if isinstance(s, NodeSnapshot) and s.status == 'created'), None):
|
|
145
145
|
snapshot.status = 'pending'
|
|
146
|
-
return snapshot
|
|
146
|
+
return copy.deepcopy(snapshot)
|
|
147
147
|
|
|
148
148
|
async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
|
|
149
149
|
return self.history
|
|
@@ -1,14 +1,15 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pydantic-graph
|
|
3
|
-
Version:
|
|
3
|
+
Version: 1.24.0
|
|
4
4
|
Summary: Graph and state machine library
|
|
5
5
|
Project-URL: Homepage, https://ai.pydantic.dev/graph/tree/main/pydantic_graph
|
|
6
6
|
Project-URL: Source, https://github.com/pydantic/pydantic-ai
|
|
7
7
|
Project-URL: Documentation, https://ai.pydantic.dev/graph
|
|
8
8
|
Project-URL: Changelog, https://github.com/pydantic/pydantic-ai/releases
|
|
9
|
-
Author-email: Samuel Colvin <samuel@pydantic.dev>
|
|
9
|
+
Author-email: Samuel Colvin <samuel@pydantic.dev>, Marcelo Trylesinski <marcelotryle@gmail.com>, David Montague <david@pydantic.dev>, Alex Hall <alex@pydantic.dev>, Douwe Maan <douwe@pydantic.dev>
|
|
10
10
|
License-Expression: MIT
|
|
11
|
-
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Classifier: Development Status :: 5 - Production/Stable
|
|
12
13
|
Classifier: Environment :: Console
|
|
13
14
|
Classifier: Environment :: MacOS X
|
|
14
15
|
Classifier: Intended Audience :: Developers
|
|
@@ -20,16 +21,15 @@ Classifier: Operating System :: Unix
|
|
|
20
21
|
Classifier: Programming Language :: Python
|
|
21
22
|
Classifier: Programming Language :: Python :: 3
|
|
22
23
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
23
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
24
24
|
Classifier: Programming Language :: Python :: 3.10
|
|
25
25
|
Classifier: Programming Language :: Python :: 3.11
|
|
26
26
|
Classifier: Programming Language :: Python :: 3.12
|
|
27
27
|
Classifier: Programming Language :: Python :: 3.13
|
|
28
28
|
Classifier: Topic :: Internet
|
|
29
29
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
30
|
-
Requires-Python: >=3.
|
|
30
|
+
Requires-Python: >=3.10
|
|
31
31
|
Requires-Dist: httpx>=0.27
|
|
32
|
-
Requires-Dist: logfire-api>=
|
|
32
|
+
Requires-Dist: logfire-api>=3.14.1
|
|
33
33
|
Requires-Dist: pydantic>=2.10
|
|
34
34
|
Requires-Dist: typing-inspection>=0.4.0
|
|
35
35
|
Description-Content-Type: text/markdown
|
|
@@ -44,10 +44,10 @@ Description-Content-Type: text/markdown
|
|
|
44
44
|
|
|
45
45
|
Graph and finite state machine library.
|
|
46
46
|
|
|
47
|
-
This library is developed as part of [
|
|
48
|
-
on `pydantic-ai` or related packages and can be considered as a pure graph-based state machine library. You may find it useful whether or not you're using
|
|
47
|
+
This library is developed as part of [Pydantic AI](https://ai.pydantic.dev), however it has no dependency
|
|
48
|
+
on `pydantic-ai` or related packages and can be considered as a pure graph-based state machine library. You may find it useful whether or not you're using Pydantic AI or even building with GenAI.
|
|
49
49
|
|
|
50
|
-
As with
|
|
50
|
+
As with Pydantic AI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax.
|
|
51
51
|
|
|
52
52
|
`pydantic-graph` allows you to define graphs using standard Python syntax. In particular, edges are defined using the return type hint of nodes.
|
|
53
53
|
|
|
@@ -55,7 +55,7 @@ Full documentation is available at [ai.pydantic.dev/graph](https://ai.pydantic.d
|
|
|
55
55
|
|
|
56
56
|
Here's a basic example:
|
|
57
57
|
|
|
58
|
-
```python {noqa="I001"
|
|
58
|
+
```python {noqa="I001"}
|
|
59
59
|
from __future__ import annotations
|
|
60
60
|
|
|
61
61
|
from dataclasses import dataclass
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
pydantic_graph/__init__.py,sha256=qkrSmWLpnNhD7mLBWV70iS46vy2vFiU2zUModG31wPQ,593
|
|
2
|
+
pydantic_graph/_utils.py,sha256=f0B1VIRxfAslj5UIfBVRzXhIJkDXyF8P3B9V22Q7o7U,6782
|
|
3
|
+
pydantic_graph/exceptions.py,sha256=aeaBf2H18dV7YCNZKZmiXiI6Fyys2qQdunZwd7TSCPk,1648
|
|
4
|
+
pydantic_graph/graph.py,sha256=rEm_5PzRs-5k6Y0mmaF5SGhF0wPA2JSclNAEZtBUZpA,33942
|
|
5
|
+
pydantic_graph/mermaid.py,sha256=u8xM8eEAOWV0TkqEAPJJ9jL2XEZnJ_H7yNGhulg7SL4,10045
|
|
6
|
+
pydantic_graph/nodes.py,sha256=CkY3lrC6jqZtzwhSRjFzmM69TdFFFrr58XSDU4THKHA,7450
|
|
7
|
+
pydantic_graph/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
|
+
pydantic_graph/beta/__init__.py,sha256=VVmbEFaCSXYHwXqS4pANg4B3cn_c86tT62tW_EXcuyw,751
|
|
9
|
+
pydantic_graph/beta/decision.py,sha256=x-Ta549b-j5hyBPUWFdwRQDRaJqnBHF1pfBP9L8I3vI,11239
|
|
10
|
+
pydantic_graph/beta/graph.py,sha256=-T-HbVyBC3qgg_-dXURnCbI6K_mqj25jDVh_RMlVsS8,42811
|
|
11
|
+
pydantic_graph/beta/graph_builder.py,sha256=2sD7TR8oGg4Gatrms0jE17NXzQN7drzUvaJKs5BvILU,43329
|
|
12
|
+
pydantic_graph/beta/id_types.py,sha256=FZ3rYSubF6g_Ocv0faL3yJsy1lNN9AGZl9f_izvORUg,2814
|
|
13
|
+
pydantic_graph/beta/join.py,sha256=rzCumDX_YgaU_a5bisfbjbbOuI3IwSZsCZs9TC0T9E4,8002
|
|
14
|
+
pydantic_graph/beta/mermaid.py,sha256=vpB9laZeTaS-P6BJplyN7DLFz0ppRVafGjBfqRiTh-s,7128
|
|
15
|
+
pydantic_graph/beta/node.py,sha256=cTEGKiT3Lutg-PWxBbZDihpnBTVoPMSyCbfB50fjKeY,3071
|
|
16
|
+
pydantic_graph/beta/node_types.py,sha256=Ha1QPbAHqmJ1ARb359b8LOJK-jZDMO_ZyHkYv9Zbglw,3399
|
|
17
|
+
pydantic_graph/beta/parent_forks.py,sha256=lMCT3slwDuZtiLZImqXuW-i0ZftONCWGr7RTpCAe9dY,9691
|
|
18
|
+
pydantic_graph/beta/paths.py,sha256=LkFvgnyNa1lHdFkN83F7Dgsdg9Q2y0zYrLyqprQiQiY,16068
|
|
19
|
+
pydantic_graph/beta/step.py,sha256=n0JstmxM6Z2rCc2EPUrSAp4MS4IjM2mZsE0ymeekzxU,8683
|
|
20
|
+
pydantic_graph/beta/util.py,sha256=F7IkSC0U-tU1yOxncslyOrZ5HlrZIdafBJARsPetIHQ,3153
|
|
21
|
+
pydantic_graph/persistence/__init__.py,sha256=NLBGvUWhem23EdMHHxtX0XgTS2vyixmuWtWmZKj_U58,8968
|
|
22
|
+
pydantic_graph/persistence/_utils.py,sha256=6ySxCc1lFz7bbLUwDLkoZWNqi8VNLBVU4xxJbKI23fQ,2264
|
|
23
|
+
pydantic_graph/persistence/file.py,sha256=XZy295cGc86HfUl_KuB-e7cECZW3bubiEdyJMVQ1OD0,6906
|
|
24
|
+
pydantic_graph/persistence/in_mem.py,sha256=MmahaVpdzmDB30Dm3ZfSCZBqgmx6vH4HXdBaWwVF0K0,6799
|
|
25
|
+
pydantic_graph-1.24.0.dist-info/METADATA,sha256=MuevkYyRJr96QAs8gcaDetAButYgPY35d89YQ7Ruluw,3895
|
|
26
|
+
pydantic_graph-1.24.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
27
|
+
pydantic_graph-1.24.0.dist-info/licenses/LICENSE,sha256=vA6Jc482lEyBBuGUfD1pYx-cM7jxvLYOxPidZ30t_PQ,1100
|
|
28
|
+
pydantic_graph-1.24.0.dist-info/RECORD,,
|