pydantic-graph 1.3.0__py3-none-any.whl → 1.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_graph/_utils.py +39 -0
- pydantic_graph/beta/__init__.py +25 -0
- pydantic_graph/beta/decision.py +276 -0
- pydantic_graph/beta/graph.py +939 -0
- pydantic_graph/beta/graph_builder.py +1053 -0
- pydantic_graph/beta/id_types.py +79 -0
- pydantic_graph/beta/join.py +249 -0
- pydantic_graph/beta/mermaid.py +208 -0
- pydantic_graph/beta/node.py +95 -0
- pydantic_graph/beta/node_types.py +90 -0
- pydantic_graph/beta/parent_forks.py +232 -0
- pydantic_graph/beta/paths.py +421 -0
- pydantic_graph/beta/step.py +253 -0
- pydantic_graph/beta/util.py +90 -0
- pydantic_graph/exceptions.py +22 -0
- pydantic_graph/graph.py +12 -4
- pydantic_graph/nodes.py +0 -2
- pydantic_graph/persistence/in_mem.py +1 -1
- {pydantic_graph-1.3.0.dist-info → pydantic_graph-1.12.0.dist-info}/METADATA +1 -1
- pydantic_graph-1.12.0.dist-info/RECORD +28 -0
- pydantic_graph-1.3.0.dist-info/RECORD +0 -15
- {pydantic_graph-1.3.0.dist-info → pydantic_graph-1.12.0.dist-info}/WHEEL +0 -0
- {pydantic_graph-1.3.0.dist-info → pydantic_graph-1.12.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,421 @@
|
|
|
1
|
+
"""Path and edge definition for graph navigation.
|
|
2
|
+
|
|
3
|
+
This module provides the building blocks for defining paths through a graph,
|
|
4
|
+
including transformations, maps, broadcasts, and routing to destinations.
|
|
5
|
+
Paths enable complex data flow patterns in graph execution.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import inspect
|
|
11
|
+
from collections.abc import AsyncIterable, Callable, Iterable, Sequence
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from typing import TYPE_CHECKING, Any, Generic, get_origin
|
|
14
|
+
|
|
15
|
+
from typing_extensions import Protocol, Self, TypeAliasType, TypeVar
|
|
16
|
+
|
|
17
|
+
from pydantic_graph import BaseNode
|
|
18
|
+
from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID, generate_placeholder_node_id
|
|
19
|
+
from pydantic_graph.beta.step import NodeStep, StepContext
|
|
20
|
+
from pydantic_graph.exceptions import GraphBuildingError
|
|
21
|
+
|
|
22
|
+
StateT = TypeVar('StateT', infer_variance=True)
|
|
23
|
+
DepsT = TypeVar('DepsT', infer_variance=True)
|
|
24
|
+
OutputT = TypeVar('OutputT', infer_variance=True)
|
|
25
|
+
InputT = TypeVar('InputT', infer_variance=True)
|
|
26
|
+
T = TypeVar('T')
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from pydantic_graph.beta.node_types import AnyDestinationNode, DestinationNode, SourceNode
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TransformFunction(Protocol[StateT, DepsT, InputT, OutputT]):
    """Structural type for synchronous transforms applied along a path.

    A transform is a plain (non-async) callable that receives a step context
    and returns its result directly. It has the same shape as a StepFunction,
    except that it must be sync instead of async.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        InputT: The type of the input data
        OutputT: The type of the output data
    """

    def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> OutputT:
        """Apply the transform to the given context.

        Args:
            ctx: The step context containing state, dependencies, and inputs

        Returns:
            The transformed value, returned directly (not an awaitable)

        Raises:
            NotImplementedError: Always, in this protocol stub
        """
        raise NotImplementedError
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass
class TransformMarker:
    """Path item recording a data transformation at this point in a path.

    The marker simply carries the transform callable that is meant to modify
    data as it flows through the graph path.
    """

    transform: TransformFunction[Any, Any, Any, Any]
    """The callable that performs the transformation."""
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass
class MapMarker:
    """Path item indicating that iterable data should be mapped across parallel paths.

    Map markers take iterable input and create a parallel execution path
    for each item in the iterable.
    """

    fork_id: ForkID
    """Unique identifier of the fork created by this map operation."""
    downstream_join_id: JoinID | None
    """Optional identifier of a downstream join node to jump to when the mapped iterable is empty."""
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@dataclass
class BroadcastMarker:
    """Path item indicating that data should be broadcast to several parallel paths.

    Every path listed in `paths` is sent the same input data.
    """

    paths: Sequence[Path]
    """The parallel paths that receive the broadcast data."""

    fork_id: ForkID
    """Unique identifier of the fork created by this broadcast operation."""
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@dataclass
class LabelMarker:
    """Path item attaching a human-readable label to a path segment.

    Labels exist for debugging, visualization, and documentation purposes,
    giving path segments meaningful names.
    """

    label: str
    """The human-readable text for this path segment."""
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass
class DestinationMarker:
    """Path item naming the destination node that a path routes to.

    It specifies where data should be sent at the end of a path execution.
    """

    destination_id: NodeID
    """Unique identifier of the destination node."""
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
PathItem = TypeAliasType(
    'PathItem',
    TransformMarker | MapMarker | BroadcastMarker | LabelMarker | DestinationMarker,
)
"""Type alias covering every kind of marker that may appear in a path's item sequence."""
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@dataclass
class Path:
    """An ordered list of path items describing how data flows through the graph.

    A path captures the route data takes, including transformations, forks,
    and routing decisions.
    """

    items: list[PathItem]
    """The path items, in order, that make up this path."""

    @property
    def last_fork(self) -> BroadcastMarker | MapMarker | None:
        """Find the fork marker closest to the end of this path.

        Returns:
            The last BroadcastMarker or MapMarker in `items`, or None when the
            path contains no fork at all
        """
        # Scan from the tail so the first match is the most recent fork.
        return next(
            (entry for entry in reversed(self.items) if isinstance(entry, (BroadcastMarker, MapMarker))),
            None,
        )

    @property
    def next_path(self) -> Path:
        """Build the path that remains after dropping the first item.

        Returns:
            A new Path holding every item except the first one (an empty
            path stays empty)
        """
        remaining = self.items[1:]
        return Path(remaining)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
@dataclass
class PathBuilder(Generic[StateT, DepsT, OutputT]):
    """Fluent builder that accumulates path items through method chaining.

    Intermediate operations (`transform`, `map`, `label`) return a new
    PathBuilder with one more item; terminal operations (`to`, `broadcast`)
    return a finished Path. The builder itself is never mutated.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        OutputT: The type of the current data in the path
    """

    working_items: Sequence[PathItem]
    """The path items accumulated so far."""

    def to(
        self,
        destination: DestinationNode[StateT, DepsT, OutputT],
        /,
        *extra_destinations: DestinationNode[StateT, DepsT, OutputT],
        fork_id: str | None = None,
    ) -> Path:
        """Finish the path by routing to one or more destination nodes.

        Args:
            destination: The primary destination node
            *extra_destinations: Additional destination nodes (creates a broadcast)
            fork_id: Optional ID for the fork created when multiple destinations are specified

        Returns:
            A complete Path ending at the specified destination(s)
        """
        if not extra_destinations:
            # Single target: the path ends directly at that node.
            terminal = DestinationMarker(destination.id)
        else:
            # Several targets: fan out through a broadcast with one single-item path per node.
            targets = (destination, *extra_destinations)
            branch_paths = [Path(items=[DestinationMarker(target.id)]) for target in targets]
            terminal = BroadcastMarker(
                paths=branch_paths,
                fork_id=ForkID(NodeID(fork_id or generate_placeholder_node_id('broadcast'))),
            )
        return Path(items=[*self.working_items, terminal])

    def broadcast(self, forks: Sequence[Path], /, *, fork_id: str | None = None) -> Path:
        """Finish the path with a fork that sends data to several parallel paths.

        Args:
            forks: The sequence of paths to run in parallel
            fork_id: Optional ID for the fork, defaults to a generated value

        Returns:
            A complete Path that forks to the specified parallel paths
        """
        resolved_fork_id = ForkID(NodeID(fork_id or generate_placeholder_node_id('broadcast')))
        marker = BroadcastMarker(paths=forks, fork_id=resolved_fork_id)
        return Path(items=[*self.working_items, marker])

    def transform(self, func: TransformFunction[StateT, DepsT, OutputT, T], /) -> PathBuilder[StateT, DepsT, T]:
        """Append a transformation step to the path.

        Args:
            func: The function that will transform the data

        Returns:
            A new PathBuilder with the transformation added
        """
        extended = [*self.working_items, TransformMarker(func)]
        return PathBuilder[StateT, DepsT, T](working_items=extended)

    def map(
        self: PathBuilder[StateT, DepsT, Iterable[T]] | PathBuilder[StateT, DepsT, AsyncIterable[T]],
        *,
        fork_id: str | None = None,
        downstream_join_id: str | None = None,
    ) -> PathBuilder[StateT, DepsT, T]:
        """Append a map marker so each item of the iterable gets its own parallel path.

        The `self` annotation restricts this method to builders whose current
        output type is an iterable or async iterable.

        Args:
            fork_id: Optional ID for the fork, defaults to a generated value
            downstream_join_id: Optional ID of a downstream join node which is involved when mapping empty iterables

        Returns:
            A new PathBuilder that operates on individual items from the iterable
        """
        resolved_fork_id = ForkID(NodeID(fork_id or generate_placeholder_node_id('map')))
        join_id = None if downstream_join_id is None else JoinID(downstream_join_id)
        marker = MapMarker(fork_id=resolved_fork_id, downstream_join_id=join_id)
        return PathBuilder[StateT, DepsT, T](working_items=[*self.working_items, marker])

    def label(self, label: str, /) -> PathBuilder[StateT, DepsT, OutputT]:
        """Append a human-readable label at this point in the path.

        Args:
            label: The label to add for documentation/debugging purposes

        Returns:
            A new PathBuilder with the label added
        """
        extended = [*self.working_items, LabelMarker(label)]
        return PathBuilder[StateT, DepsT, OutputT](working_items=extended)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
@dataclass(init=False)
class EdgePath(Generic[StateT, DepsT]):
    """A finished edge: source nodes, the path data follows, and the destination nodes.

    The sources are stored privately and exposed through the read-only
    `sources` property.
    """

    _sources: Sequence[SourceNode[StateT, DepsT, Any]]
    """The source nodes that feed data into this edge."""
    path: Path
    """The path that data follows through the graph."""
    destinations: list[AnyDestinationNode]
    """The destination nodes that DestinationMarker items in the path can refer to."""

    def __init__(
        self, sources: Sequence[SourceNode[StateT, DepsT, Any]], path: Path, destinations: list[AnyDestinationNode]
    ):
        """Store the sources, path, and destinations of the edge.

        Args:
            sources: The source nodes that provide data to this edge
            path: The path that data follows
            destinations: The destination nodes referenced by the path
        """
        self._sources, self.path, self.destinations = sources, path, destinations

    @property
    def sources(self) -> Sequence[SourceNode[StateT, DepsT, Any]]:
        """The source nodes that provide data to this edge."""
        return self._sources
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
class EdgePathBuilder(Generic[StateT, DepsT, OutputT]):
    """Fluent builder pairing a set of source nodes with a PathBuilder.

    Intermediate operations return a new EdgePathBuilder with the same
    sources; terminal operations return a complete EdgePath. It cannot use
    dataclass due to type variance issues.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        OutputT: The type of the current data in the path
    """

    def __init__(
        self, sources: Sequence[SourceNode[StateT, DepsT, Any]], path_builder: PathBuilder[StateT, DepsT, OutputT]
    ):
        """Initialize an edge path builder.

        Args:
            sources: The source nodes for this edge path
            path_builder: The path builder for defining the data flow
        """
        self.sources = sources
        self._path_builder = path_builder

    def to(
        self,
        destination: DestinationNode[StateT, DepsT, OutputT] | type[BaseNode[StateT, DepsT, Any]],
        /,
        *extra_destinations: DestinationNode[StateT, DepsT, OutputT] | type[BaseNode[StateT, DepsT, Any]],
        fork_id: str | None = None,
    ) -> EdgePath[StateT, DepsT]:
        """Complete the edge path by routing to destination nodes.

        Args:
            destination: Either a destination node or a `BaseNode` class; classes are wrapped in a `NodeStep`
            *extra_destinations: Additional destination nodes or `BaseNode` classes (creates a broadcast)
            fork_id: Optional ID for the fork created when multiple destinations are specified

        Returns:
            A complete EdgePath connecting sources to destinations
        """
        resolved: list[Any] = []
        for target in (destination, *extra_destinations):
            # `type[BaseNode[StateT, DepsT, Any]]` could actually be a `typing._GenericAlias` like
            # `pydantic_ai._agent_graph.UserPromptNode[~DepsT, ~OutputT]`, so unwrap to the real class first.
            target = get_origin(target) or target
            resolved.append(NodeStep(target) if inspect.isclass(target) else target)

        first, *rest = resolved
        return EdgePath(
            sources=self.sources,
            path=self._path_builder.to(first, *rest, fork_id=fork_id),
            destinations=resolved,
        )

    def broadcast(
        self, get_forks: Callable[[Self], Sequence[EdgePath[StateT, DepsT]]], /, *, fork_id: str | None = None
    ) -> EdgePath[StateT, DepsT]:
        """Broadcast this EdgePathBuilder into multiple destinations.

        Args:
            get_forks: The callback that will return a sequence of EdgePaths to broadcast to.
            fork_id: Optional node ID to use for the resulting broadcast fork.

        Returns:
            A completed EdgePath with the specified destinations.

        Raises:
            GraphBuildingError: If `get_forks` returns no branches.
        """
        forked_edges = get_forks(self)
        branch_paths = [Path(edge.path.items) for edge in forked_edges]
        if not branch_paths:
            raise GraphBuildingError(f'The call to {get_forks} returned no branches, but must return at least one.')

        combined_path = self._path_builder.broadcast(branch_paths, fork_id=fork_id)

        # Flatten the destinations of every branch, keeping branch order.
        collected: list[Any] = []
        for edge in forked_edges:
            collected.extend(edge.destinations)

        return EdgePath(sources=self.sources, path=combined_path, destinations=collected)

    def map(
        self: EdgePathBuilder[StateT, DepsT, Iterable[T]] | EdgePathBuilder[StateT, DepsT, AsyncIterable[T]],
        *,
        fork_id: str | None = None,
        downstream_join_id: JoinID | None = None,
    ) -> EdgePathBuilder[StateT, DepsT, T]:
        """Spread iterable data across parallel execution paths.

        Args:
            fork_id: Optional ID for the fork, defaults to a generated value
            downstream_join_id: Optional ID of a downstream join node which is involved when mapping empty iterables

        Returns:
            A new EdgePathBuilder that operates on individual items from the iterable

        Raises:
            NotImplementedError: If this builder has more than one source node.
        """
        if len(self.sources) > 1:
            # The current implementation mishandles this because you get one copy of each edge
            # from the MapMarker to its destination for each source, resulting in unintentional multiple execution.
            # I suspect this is fixable without a major refactor, though it's not clear to me what the ideal behavior
            # would be. But for now, it's definitely easiest to just raise an error for this.
            raise NotImplementedError(
                'Map is not currently supported with multiple source nodes.'
                ' You can work around this by just creating a separate edge for each source.'
            )
        mapped_builder = self._path_builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id)
        return EdgePathBuilder(sources=self.sources, path_builder=mapped_builder)

    def transform(self, func: TransformFunction[StateT, DepsT, OutputT, T], /) -> EdgePathBuilder[StateT, DepsT, T]:
        """Add a transformation step to the edge path.

        Args:
            func: The function that will transform the data

        Returns:
            A new EdgePathBuilder with the transformation added
        """
        transformed_builder = self._path_builder.transform(func)
        return EdgePathBuilder(sources=self.sources, path_builder=transformed_builder)

    def label(self, label: str) -> EdgePathBuilder[StateT, DepsT, OutputT]:
        """Add a human-readable label to this point in the edge path.

        Args:
            label: The label to add for documentation/debugging purposes

        Returns:
            A new EdgePathBuilder with the label added
        """
        labelled_builder = self._path_builder.label(label)
        return EdgePathBuilder(sources=self.sources, path_builder=labelled_builder)
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
"""Step-based graph execution components.
|
|
2
|
+
|
|
3
|
+
This module provides the core abstractions for step-based graph execution,
|
|
4
|
+
including step contexts, step functions, and step nodes that bridge between
|
|
5
|
+
the v1 and v2 graph execution systems.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from collections.abc import AsyncIterator, Awaitable
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import Any, Generic, Protocol, cast, get_origin, overload
|
|
13
|
+
|
|
14
|
+
from typing_extensions import TypeVar
|
|
15
|
+
|
|
16
|
+
from pydantic_graph.beta.id_types import NodeID
|
|
17
|
+
from pydantic_graph.nodes import BaseNode, End, GraphRunContext
|
|
18
|
+
|
|
19
|
+
StateT = TypeVar('StateT', infer_variance=True)
|
|
20
|
+
DepsT = TypeVar('DepsT', infer_variance=True)
|
|
21
|
+
InputT = TypeVar('InputT', infer_variance=True)
|
|
22
|
+
OutputT = TypeVar('OutputT', infer_variance=True)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(init=False)
class StepContext(Generic[StateT, DepsT, InputT]):
    """Read-only bundle of state, dependencies, and inputs handed to a step function.

    The values are stored in private fields and exposed through properties.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        InputT: The type of the input data
    """

    _state: StateT
    """The current graph state."""
    _deps: DepsT
    """The graph run dependencies."""
    _inputs: InputT
    """The input data for this step."""

    def __init__(self, *, state: StateT, deps: DepsT, inputs: InputT):
        """Create a step context.

        Args:
            state: The current graph state
            deps: The graph run dependencies
            inputs: The input data for this step
        """
        self._state, self._deps, self._inputs = state, deps, inputs

    @property
    def state(self) -> StateT:
        """The current graph state."""
        return self._state

    @property
    def deps(self) -> DepsT:
        """The graph run dependencies."""
        return self._deps

    @property
    def inputs(self) -> InputT:
        """The input data for this step.

        This must be a property to ensure correct variance behavior
        """
        return self._inputs
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class StepFunction(Protocol[StateT, DepsT, InputT, OutputT]):
    """Structural type for asynchronous step functions run by the graph.

    A step function is a callable that receives a step context and returns an
    awaitable producing the step's result.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        InputT: The type of the input data
        OutputT: The type of the output data
    """

    def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> Awaitable[OutputT]:
        """Run the step function against the given context.

        Args:
            ctx: The step context containing state, dependencies, and inputs

        Returns:
            An awaitable that resolves to the step's output

        Raises:
            NotImplementedError: Always, in this protocol stub
        """
        raise NotImplementedError
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class StreamFunction(Protocol[StateT, DepsT, InputT, OutputT]):
    """Structural type for stream functions run by the graph.

    A stream function is a callable that receives a step context and returns
    an async iterator of output values.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        InputT: The type of the input data
        OutputT: The type of the output data
    """

    def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> AsyncIterator[OutputT]:
        """Run the stream function against the given context.

        Args:
            ctx: The step context containing state, dependencies, and inputs

        Returns:
            An async iterator yielding the streamed output
        """
        raise NotImplementedError
        # Unreachable, but the presence of `yield` makes this stub a generator function.
        yield
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
AnyStepFunction = StepFunction[Any, Any, Any, Any]
"""Type alias for a step function whose four type parameters are all `Any`."""
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@dataclass(init=False)
class Step(Generic[StateT, DepsT, InputT, OutputT]):
    """A unit of graph execution wrapping a step function plus its metadata.

    Besides the function itself, a step carries a node ID and an optional
    human-readable label.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        InputT: The type of the input data
        OutputT: The type of the output data
    """

    id: NodeID
    """Unique identifier for this step."""
    _call: StepFunction[StateT, DepsT, InputT, OutputT]
    """The step function to execute."""
    label: str | None
    """Optional human-readable label for this step."""

    def __init__(self, *, id: NodeID, call: StepFunction[StateT, DepsT, InputT, OutputT], label: str | None = None):
        """Create a step.

        Args:
            id: Unique identifier for this step
            call: The step function to execute
            label: Optional human-readable label for this step
        """
        self.id, self._call, self.label = id, call, label

    @property
    def call(self) -> StepFunction[StateT, DepsT, InputT, OutputT]:
        """The step function to execute. This needs to be a property for proper variance inference."""
        return self._call

    @overload
    def as_node(self, inputs: None = None) -> StepNode[StateT, DepsT]: ...

    @overload
    def as_node(self, inputs: InputT) -> StepNode[StateT, DepsT]: ...

    def as_node(self, inputs: InputT | None = None) -> StepNode[StateT, DepsT]:
        """Bind inputs to this step and wrap the pair in a node.

        Args:
            inputs: The input data to bind to this step, or None

        Returns:
            A [`StepNode`][pydantic_graph.beta.step.StepNode] with this step and the bound inputs
        """
        bound_node = StepNode(self, inputs)
        return bound_node
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
@dataclass
class StepNode(BaseNode[StateT, DepsT, Any]):
    """A BaseNode that pairs a step with the inputs bound to it.

    StepNode bridges the v1 and v2 graph execution systems by presenting a
    [`Step`][pydantic_graph.beta.step.Step] with bound inputs through the
    BaseNode interface. It only signals a transition to a v2-style step;
    calling `run` on it always raises.
    """

    step: Step[StateT, DepsT, Any, Any]
    """The step to execute."""

    inputs: Any
    """The inputs bound to this step."""

    async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[Any]:
        """Reject direct execution of the step node.

        Args:
            ctx: The graph execution context (unused)

        Raises:
            NotImplementedError: Always raised as StepNode is not meant to be run directly
        """
        raise NotImplementedError(
            '`StepNode` is not meant to be run directly, it is meant to be used in `BaseNode` subclasses to indicate a transition to v2-style steps.'
        )
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
# Note: we should make this into a frozen dataclass if https://github.com/python/mypy/issues/17623 gets resolved
# Right now, it cannot be because that breaks variance inference in Python 3.13 due to __replace__
class NodeStep(Step[StateT, DepsT, Any, BaseNode[StateT, DepsT, Any] | End[Any]]):
    """A step whose function runs a BaseNode instance of a given type.

    NodeStep lets v1-style BaseNode classes act as steps in the v2 graph
    execution system. When called, it checks that the step input is an
    instance of `node_type` and runs it with a graph context built from the
    step context.
    """

    node_type: type[BaseNode[StateT, DepsT, Any]]
    """The BaseNode type this step executes."""

    def __init__(
        self,
        node_type: type[BaseNode[StateT, DepsT, Any]],
        *,
        id: NodeID | None = None,
        label: str | None = None,
    ):
        """Initialize a node step.

        Args:
            node_type: The BaseNode class this step will execute
            id: Optional unique identifier, defaults to the node's get_node_id()
            label: Optional human-readable label for this step
        """
        step_id = id or NodeID(node_type.get_node_id())
        super().__init__(id=step_id, call=self._call_node, label=label)
        # `type[BaseNode[StateT, DepsT, Any]]` could actually be a `typing._GenericAlias` like
        # `pydantic_ai._agent_graph.UserPromptNode[~DepsT, ~OutputT]`, so unwrap it to the actual class.
        self.node_type = get_origin(node_type) or node_type

    async def _call_node(self, ctx: StepContext[StateT, DepsT, Any]) -> BaseNode[StateT, DepsT, Any] | End[Any]:
        """Run the node instance supplied as the step input.

        Args:
            ctx: The step context containing the node instance to run

        Returns:
            The result of running the node, either another BaseNode or End

        Raises:
            ValueError: If the input node is not of the expected type
        """
        candidate = ctx.inputs
        if isinstance(candidate, self.node_type):
            typed_node = cast(BaseNode[StateT, DepsT, Any], candidate)
            run_context = GraphRunContext(state=ctx.state, deps=ctx.deps)
            return await typed_node.run(run_context)
        raise ValueError(f'Node {candidate} is not of type {self.node_type}')  # pragma: no cover
|