pydantic-graph 1.2.1__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,421 @@
1
+ """Path and edge definition for graph navigation.
2
+
3
+ This module provides the building blocks for defining paths through a graph,
4
+ including transformations, maps, broadcasts, and routing to destinations.
5
+ Paths enable complex data flow patterns in graph execution.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import inspect
11
+ from collections.abc import AsyncIterable, Callable, Iterable, Sequence
12
+ from dataclasses import dataclass
13
+ from typing import TYPE_CHECKING, Any, Generic, get_origin
14
+
15
+ from typing_extensions import Protocol, Self, TypeAliasType, TypeVar
16
+
17
+ from pydantic_graph import BaseNode
18
+ from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID, generate_placeholder_node_id
19
+ from pydantic_graph.beta.step import NodeStep, StepContext
20
+ from pydantic_graph.exceptions import GraphBuildingError
21
+
22
+ StateT = TypeVar('StateT', infer_variance=True)
23
+ DepsT = TypeVar('DepsT', infer_variance=True)
24
+ OutputT = TypeVar('OutputT', infer_variance=True)
25
+ InputT = TypeVar('InputT', infer_variance=True)
26
+ T = TypeVar('T')
27
+
28
+ if TYPE_CHECKING:
29
+ from pydantic_graph.beta.node_types import AnyDestinationNode, DestinationNode, SourceNode
30
+
31
+
32
class TransformFunction(Protocol[StateT, DepsT, InputT, OutputT]):
    """Protocol for synchronous transform functions that can be executed in the graph.

    Transform functions are sync callables that receive a step context and return
    a result. This protocol enables serialization and deserialization of step
    calls similar to how evaluators work.

    This is very similar to a StepFunction, but must be sync instead of async.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        InputT: The type of the input data
        OutputT: The type of the output data
    """

    def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> OutputT:
        """Execute the transform function with the given context.

        Args:
            ctx: The step context containing state, dependencies, and inputs

        Returns:
            The transform's output, returned directly (unlike StepFunction, this is
            sync and does not return an awaitable)
        """
        raise NotImplementedError
58
+
59
+
60
@dataclass
class TransformMarker:
    """A marker indicating a data transformation step in a path.

    Transform markers wrap sync step functions that modify data as it flows
    through the graph path.
    """

    transform: TransformFunction[Any, Any, Any, Any]
    """The sync step function that performs the transformation."""
70
+
71
+
72
@dataclass
class MapMarker:
    """A marker indicating that iterable data should be mapped across parallel paths.

    Map markers take iterable input and create a parallel execution path
    for each item in the iterable.
    """

    fork_id: ForkID
    """Unique identifier for the fork created by this map operation."""
    downstream_join_id: JoinID | None
    """Optional identifier of a downstream join node that should be jumped to if mapping an empty iterable."""
84
+
85
+
86
@dataclass
class BroadcastMarker:
    """A marker indicating that data should be broadcast to multiple parallel paths.

    Broadcast markers create multiple parallel execution paths, sending the
    same input data to each path.
    """

    paths: Sequence[Path]
    """The parallel paths that will each receive the broadcast data."""

    fork_id: ForkID
    """Unique identifier for the fork created by this broadcast operation."""
99
+
100
+
101
@dataclass
class LabelMarker:
    """A marker providing a human-readable label for a path segment.

    Label markers carry no execution semantics; they exist for debugging,
    visualization, and documentation purposes to give path segments
    meaningful names.
    """

    label: str
    """The human-readable label for this path segment."""
111
+
112
+
113
@dataclass
class DestinationMarker:
    """A marker indicating the target destination node for a path.

    Destination markers specify where data should be routed at the end
    of a path execution.
    """

    destination_id: NodeID
    """The unique identifier of the destination node."""
123
+
124
+
125
PathItem = TypeAliasType('PathItem', TransformMarker | MapMarker | BroadcastMarker | LabelMarker | DestinationMarker)
"""Type alias for any marker that can appear as an item in a path sequence."""
127
+
128
+
129
@dataclass
class Path:
    """An ordered sequence of markers describing data flow through the graph.

    A path records every transformation, fork, label, and routing decision
    that data passes through during graph execution.
    """

    items: list[PathItem]
    """The ordered markers that make up this path."""

    @property
    def last_fork(self) -> BroadcastMarker | MapMarker | None:
        """Return the most recent fork-producing marker in this path.

        Returns:
            The last BroadcastMarker or MapMarker in the path, or None if the
            path contains no forks
        """
        forks = (item for item in reversed(self.items) if isinstance(item, BroadcastMarker | MapMarker))
        return next(forks, None)

    @property
    def next_path(self) -> Path:
        """Return a copy of this path with the leading item dropped.

        Returns:
            A new Path containing every item except the first
        """
        remaining = self.items[1:]
        return Path(remaining)
160
+
161
+
162
@dataclass
class PathBuilder(Generic[StateT, DepsT, OutputT]):
    """Fluent builder for assembling a Path one marker at a time.

    Every chained call returns a fresh builder (or a finished Path), so a
    builder instance is never mutated and can safely be reused as a branch
    point for several paths.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        OutputT: The type of the data currently flowing through the path
    """

    working_items: Sequence[PathItem]
    """The markers accumulated so far."""

    def to(
        self,
        destination: DestinationNode[StateT, DepsT, OutputT],
        /,
        *extra_destinations: DestinationNode[StateT, DepsT, OutputT],
        fork_id: str | None = None,
    ) -> Path:
        """Finish the path by routing to one or more destination nodes.

        Args:
            destination: The primary destination node
            *extra_destinations: Additional destinations; supplying any turns the tail into a broadcast
            fork_id: Optional ID for the fork created when multiple destinations are specified

        Returns:
            A complete Path ending at the specified destination(s)
        """
        if not extra_destinations:
            tail: PathItem = DestinationMarker(destination.id)
        else:
            # Multiple destinations are expressed as a broadcast over one
            # single-destination sub-path per target.
            targets = (destination, *extra_destinations)
            tail = BroadcastMarker(
                paths=[Path(items=[DestinationMarker(t.id)]) for t in targets],
                fork_id=ForkID(NodeID(fork_id or generate_placeholder_node_id('broadcast'))),
            )
        return Path(items=[*self.working_items, tail])

    def broadcast(self, forks: Sequence[Path], /, *, fork_id: str | None = None) -> Path:
        """Finish the path by fanning the same data out to several parallel paths.

        Args:
            forks: The paths to execute in parallel
            fork_id: Optional ID for the fork; a placeholder ID is generated when omitted

        Returns:
            A complete Path that forks into the specified parallel paths
        """
        resolved_fork_id = ForkID(NodeID(fork_id or generate_placeholder_node_id('broadcast')))
        marker = BroadcastMarker(paths=forks, fork_id=resolved_fork_id)
        return Path(items=[*self.working_items, marker])

    def transform(self, func: TransformFunction[StateT, DepsT, OutputT, T], /) -> PathBuilder[StateT, DepsT, T]:
        """Append a data-transformation step to the path.

        Args:
            func: The sync function that will transform the data

        Returns:
            A new PathBuilder whose current data type is the transform's output type
        """
        return PathBuilder[StateT, DepsT, T](working_items=[*self.working_items, TransformMarker(func)])

    def map(
        self: PathBuilder[StateT, DepsT, Iterable[T]] | PathBuilder[StateT, DepsT, AsyncIterable[T]],
        *,
        fork_id: str | None = None,
        downstream_join_id: str | None = None,
    ) -> PathBuilder[StateT, DepsT, T]:
        """Fan each item of the (possibly async) iterable data out onto its own parallel path.

        This method can only be called when the current output type is iterable.

        Args:
            fork_id: Optional ID for the fork; a placeholder ID is generated when omitted
            downstream_join_id: Optional ID of a downstream join node which is involved when mapping empty iterables

        Returns:
            A new PathBuilder that operates on individual items from the iterable
        """
        join_id = None if downstream_join_id is None else JoinID(downstream_join_id)
        marker = MapMarker(
            fork_id=ForkID(NodeID(fork_id or generate_placeholder_node_id('map'))),
            downstream_join_id=join_id,
        )
        return PathBuilder[StateT, DepsT, T](working_items=[*self.working_items, marker])

    def label(self, label: str, /) -> PathBuilder[StateT, DepsT, OutputT]:
        """Append a human-readable label at this point in the path.

        Args:
            label: The label to add, for documentation/debugging purposes

        Returns:
            A new PathBuilder with the label marker appended
        """
        return PathBuilder[StateT, DepsT, OutputT](working_items=[*self.working_items, LabelMarker(label)])
266
+
267
+
268
@dataclass(init=False)
class EdgePath(Generic[StateT, DepsT]):
    """A complete edge connecting source nodes to destinations via a path.

    EdgePath represents a complete connection in the graph, specifying the
    source nodes, the path that data follows, and the destination nodes.
    """

    _sources: Sequence[SourceNode[StateT, DepsT, Any]]
    """The source nodes that provide data to this edge."""
    path: Path
    """The path that data follows through the graph."""
    destinations: list[AnyDestinationNode]
    """The destination nodes that can be referenced by DestinationMarker in the path."""

    def __init__(
        self, sources: Sequence[SourceNode[StateT, DepsT, Any]], path: Path, destinations: list[AnyDestinationNode]
    ):
        """Initialize the edge path.

        Uses a manual `__init__` (hence `init=False` on the dataclass) so that
        `sources` can be stored privately and exposed read-only via the
        `sources` property below.

        Args:
            sources: The source nodes that provide data to this edge
            path: The path that data follows through the graph
            destinations: The destination nodes referenced by the path
        """
        self._sources = sources
        self.path = path
        self.destinations = destinations

    @property
    def sources(self) -> Sequence[SourceNode[StateT, DepsT, Any]]:
        """Read-only view of the source nodes that provide data to this edge."""
        return self._sources
293
+
294
+
295
class EdgePathBuilder(Generic[StateT, DepsT, OutputT]):
    """Fluent builder that pairs source nodes with a path under construction.

    Combines a set of source nodes with a PathBuilder so a complete EdgePath
    can be produced via method chaining. Implemented as a plain class rather
    than a dataclass because of type-variance issues.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        OutputT: The type of the data currently flowing through the path
    """

    def __init__(
        self, sources: Sequence[SourceNode[StateT, DepsT, Any]], path_builder: PathBuilder[StateT, DepsT, OutputT]
    ):
        """Initialize an edge path builder.

        Args:
            sources: The source nodes for this edge path
            path_builder: The path builder describing the data flow so far
        """
        self.sources = sources
        self._path_builder = path_builder

    def to(
        self,
        destination: DestinationNode[StateT, DepsT, OutputT] | type[BaseNode[StateT, DepsT, Any]],
        /,
        *extra_destinations: DestinationNode[StateT, DepsT, OutputT] | type[BaseNode[StateT, DepsT, Any]],
        fork_id: str | None = None,
    ) -> EdgePath[StateT, DepsT]:
        """Complete the edge path by routing to one or more destination nodes.

        Args:
            destination: A destination node, or a BaseNode subclass to be wrapped in a NodeStep
            *extra_destinations: Additional destination nodes (creates a broadcast)
            fork_id: Optional ID for the fork created when multiple destinations are specified

        Returns:
            A complete EdgePath connecting the sources to the destination(s)
        """
        # A `type[BaseNode[StateT, DepsT, Any]]` could actually be a `typing._GenericAlias`
        # like `pydantic_ai._agent_graph.UserPromptNode[~DepsT, ~OutputT]`, so unwrap with
        # get_origin to reach the actual class.
        unwrapped = [get_origin(target) or target for target in (destination, *extra_destinations)]
        # Bare BaseNode classes are adapted into NodeStep instances; other targets pass through.
        resolved = [NodeStep(target) if inspect.isclass(target) else target for target in unwrapped]
        first, *rest = resolved
        return EdgePath(
            sources=self.sources,
            path=self._path_builder.to(first, *rest, fork_id=fork_id),
            destinations=resolved,
        )

    def broadcast(
        self, get_forks: Callable[[Self], Sequence[EdgePath[StateT, DepsT]]], /, *, fork_id: str | None = None
    ) -> EdgePath[StateT, DepsT]:
        """Fan this EdgePathBuilder out into multiple destination branches.

        Args:
            get_forks: Callback that, given this builder, returns the EdgePaths to broadcast to.
            fork_id: Optional node ID to use for the resulting broadcast fork.

        Returns:
            A completed EdgePath covering all of the returned branches.
        """
        branches = get_forks(self)
        branch_paths = [Path(branch.path.items) for branch in branches]
        if not branch_paths:
            raise GraphBuildingError(f'The call to {get_forks} returned no branches, but must return at least one.')
        merged_destinations = [dest for branch in branches for dest in branch.destinations]
        return EdgePath(
            sources=self.sources,
            path=self._path_builder.broadcast(branch_paths, fork_id=fork_id),
            destinations=merged_destinations,
        )

    def map(
        self: EdgePathBuilder[StateT, DepsT, Iterable[T]] | EdgePathBuilder[StateT, DepsT, AsyncIterable[T]],
        *,
        fork_id: str | None = None,
        downstream_join_id: JoinID | None = None,
    ) -> EdgePathBuilder[StateT, DepsT, T]:
        """Fan each item of the iterable data out onto its own parallel execution path.

        Args:
            fork_id: Optional ID for the fork, defaults to a generated value
            downstream_join_id: Optional ID of a downstream join node which is involved when mapping empty iterables

        Returns:
            A new EdgePathBuilder that operates on individual items from the iterable
        """
        if len(self.sources) > 1:
            # With multiple sources, the current implementation would emit one copy of each
            # edge from the MapMarker to its destination per source, causing unintentional
            # multiple execution. Until that is redesigned, reject the combination outright.
            raise NotImplementedError(
                'Map is not currently supported with multiple source nodes.'
                ' You can work around this by just creating a separate edge for each source.'
            )
        mapped = self._path_builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id)
        return EdgePathBuilder(sources=self.sources, path_builder=mapped)

    def transform(self, func: TransformFunction[StateT, DepsT, OutputT, T], /) -> EdgePathBuilder[StateT, DepsT, T]:
        """Append a data-transformation step to the edge path.

        Args:
            func: The sync function that will transform the data

        Returns:
            A new EdgePathBuilder with the transformation appended
        """
        return EdgePathBuilder(sources=self.sources, path_builder=self._path_builder.transform(func))

    def label(self, label: str) -> EdgePathBuilder[StateT, DepsT, OutputT]:
        """Append a human-readable label at this point in the edge path.

        Args:
            label: The label to add, for documentation/debugging purposes

        Returns:
            A new EdgePathBuilder with the label appended
        """
        return EdgePathBuilder(sources=self.sources, path_builder=self._path_builder.label(label))
@@ -0,0 +1,253 @@
1
+ """Step-based graph execution components.
2
+
3
+ This module provides the core abstractions for step-based graph execution,
4
+ including step contexts, step functions, and step nodes that bridge between
5
+ the v1 and v2 graph execution systems.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from collections.abc import AsyncIterator, Awaitable
11
+ from dataclasses import dataclass
12
+ from typing import Any, Generic, Protocol, cast, get_origin, overload
13
+
14
+ from typing_extensions import TypeVar
15
+
16
+ from pydantic_graph.beta.id_types import NodeID
17
+ from pydantic_graph.nodes import BaseNode, End, GraphRunContext
18
+
19
+ StateT = TypeVar('StateT', infer_variance=True)
20
+ DepsT = TypeVar('DepsT', infer_variance=True)
21
+ InputT = TypeVar('InputT', infer_variance=True)
22
+ OutputT = TypeVar('OutputT', infer_variance=True)
23
+
24
+
25
@dataclass(init=False)
class StepContext(Generic[StateT, DepsT, InputT]):
    """Context information passed to step functions during graph execution.

    The step context provides access to the current graph state, dependencies, and input data for a step.
    All three are stored privately and exposed through read-only properties.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        InputT: The type of the input data
    """

    _state: StateT
    """The current graph state."""
    _deps: DepsT
    """The graph run dependencies."""
    _inputs: InputT
    """The input data for this step."""

    def __init__(self, *, state: StateT, deps: DepsT, inputs: InputT):
        """Initialize the step context with keyword-only state, deps, and inputs."""
        self._state = state
        self._deps = deps
        self._inputs = inputs

    @property
    def state(self) -> StateT:
        """The current graph state."""
        return self._state

    @property
    def deps(self) -> DepsT:
        """The graph run dependencies."""
        return self._deps

    @property
    def inputs(self) -> InputT:
        """The input data for this step.

        This must be a property to ensure correct variance behavior
        """
        return self._inputs
64
+
65
+
66
class StepFunction(Protocol[StateT, DepsT, InputT, OutputT]):
    """Protocol for step functions that can be executed in the graph.

    Step functions are async callables that receive a step context and return a result.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        InputT: The type of the input data
        OutputT: The type of the output data
    """

    def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> Awaitable[OutputT]:
        """Execute the step function with the given context.

        Args:
            ctx: The step context containing state, dependencies, and inputs

        Returns:
            An awaitable that resolves to the step's output
        """
        raise NotImplementedError
88
+
89
+
90
class StreamFunction(Protocol[StateT, DepsT, InputT, OutputT]):
    """Protocol for stream functions that can be executed in the graph.

    Stream functions are async callables that receive a step context and return an async iterator.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        InputT: The type of the input data
        OutputT: The type of the output data
    """

    def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> AsyncIterator[OutputT]:
        """Execute the stream function with the given context.

        Args:
            ctx: The step context containing state, dependencies, and inputs

        Returns:
            An async iterator yielding the streamed output
        """
        raise NotImplementedError
        # Unreachable `yield` deliberately turns this stub into a generator function,
        # so its shape matches implementations that are written with `yield`.
        yield
113
+
114
+
115
AnyStepFunction = StepFunction[Any, Any, Any, Any]
"""Type alias for a step function with any state, deps, input, and output types."""
117
+
118
+
119
@dataclass(init=False)
class Step(Generic[StateT, DepsT, InputT, OutputT]):
    """A step in the graph execution that wraps a step function.

    Steps represent individual units of execution in the graph, encapsulating
    a step function along with metadata like ID and label.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        InputT: The type of the input data
        OutputT: The type of the output data
    """

    id: NodeID
    """Unique identifier for this step."""
    _call: StepFunction[StateT, DepsT, InputT, OutputT]
    """The step function to execute."""
    label: str | None
    """Optional human-readable label for this step."""

    def __init__(self, *, id: NodeID, call: StepFunction[StateT, DepsT, InputT, OutputT], label: str | None = None):
        """Initialize a step with its ID, step function, and optional label."""
        self.id = id
        self._call = call
        self.label = label

    @property
    def call(self) -> StepFunction[StateT, DepsT, InputT, OutputT]:
        """The step function to execute. This needs to be a property for proper variance inference."""
        return self._call

    @overload
    def as_node(self, inputs: None = None) -> StepNode[StateT, DepsT]: ...

    @overload
    def as_node(self, inputs: InputT) -> StepNode[StateT, DepsT]: ...

    def as_node(self, inputs: InputT | None = None) -> StepNode[StateT, DepsT]:
        """Create a step node with bound inputs.

        Args:
            inputs: The input data to bind to this step, or None

        Returns:
            A [`StepNode`][pydantic_graph.beta.step.StepNode] with this step and the bound inputs
        """
        return StepNode(self, inputs)
166
+
167
+
168
@dataclass
class StepNode(BaseNode[StateT, DepsT, Any]):
    """A base node that represents a step with bound inputs.

    StepNode bridges between the v1 and v2 graph execution systems by wrapping
    a [`Step`][pydantic_graph.beta.step.Step] with bound inputs in a BaseNode interface.
    It is not meant to be run directly but rather used to indicate transitions
    to v2-style steps.
    """

    step: Step[StateT, DepsT, Any, Any]
    """The step to execute."""

    inputs: Any
    """The inputs bound to this step."""

    async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[Any]:
        """Attempt to run the step node.

        Args:
            ctx: The graph execution context

        Returns:
            The result of step execution

        Raises:
            NotImplementedError: Always raised, as StepNode is a marker and not meant to be run directly
        """
        raise NotImplementedError(
            '`StepNode` is not meant to be run directly, it is meant to be used in `BaseNode` subclasses to indicate a transition to v2-style steps.'
        )
199
+
200
+
201
# Note: we should make this into a frozen dataclass if https://github.com/python/mypy/issues/17623 gets resolved
# Right now, it cannot be because that breaks variance inference in Python 3.13 due to __replace__
class NodeStep(Step[StateT, DepsT, Any, BaseNode[StateT, DepsT, Any] | End[Any]]):
    """A step that executes a v1-style BaseNode class.

    Adapts a BaseNode subclass so it can participate in the v2 step-based
    execution system: the step validates that its input is an instance of the
    wrapped node type, then runs the node with a v1 GraphRunContext.
    """

    node_type: type[BaseNode[StateT, DepsT, Any]]
    """The BaseNode type this step executes."""

    def __init__(
        self,
        node_type: type[BaseNode[StateT, DepsT, Any]],
        *,
        id: NodeID | None = None,
        label: str | None = None,
    ):
        """Initialize a node step.

        Args:
            node_type: The BaseNode class this step will execute
            id: Optional unique identifier, defaults to the node's get_node_id()
            label: Optional human-readable label for this step
        """
        step_id = id or NodeID(node_type.get_node_id())
        super().__init__(id=step_id, call=self._call_node, label=label)
        # `type[BaseNode[StateT, DepsT, Any]]` could actually be a `typing._GenericAlias` like
        # `pydantic_ai._agent_graph.UserPromptNode[~DepsT, ~OutputT]`, so unwrap with get_origin
        # to reach the actual class.
        self.node_type = get_origin(node_type) or node_type

    async def _call_node(self, ctx: StepContext[StateT, DepsT, Any]) -> BaseNode[StateT, DepsT, Any] | End[Any]:
        """Run the wrapped node using the values carried by the step context.

        Args:
            ctx: The step context whose inputs hold the node instance to run

        Returns:
            The node's result: either the next BaseNode to run or an End

        Raises:
            ValueError: If the input is not an instance of the expected node type
        """
        candidate = ctx.inputs
        if not isinstance(candidate, self.node_type):
            raise ValueError(f'Node {candidate} is not of type {self.node_type}')  # pragma: no cover
        node = cast(BaseNode[StateT, DepsT, Any], candidate)
        run_ctx = GraphRunContext(state=ctx.state, deps=ctx.deps)
        return await node.run(run_ctx)