pydantic-graph 1.13.0__tar.gz → 1.31.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/.gitignore +2 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/PKG-INFO +1 -1
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/graph.py +71 -29
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/graph_builder.py +5 -3
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/id_types.py +0 -3
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/mermaid.py +1 -1
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/LICENSE +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/README.md +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/__init__.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/_utils.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/__init__.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/decision.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/join.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/node.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/node_types.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/parent_forks.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/paths.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/step.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/util.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/exceptions.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/graph.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/mermaid.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/nodes.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/persistence/__init__.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/persistence/_utils.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/persistence/file.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/persistence/in_mem.py +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/py.typed +0 -0
- {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pyproject.toml +0 -0
|
@@ -8,13 +8,12 @@ the graph-based workflow system.
|
|
|
8
8
|
from __future__ import annotations as _annotations
|
|
9
9
|
|
|
10
10
|
import sys
|
|
11
|
-
import
|
|
12
|
-
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence
|
|
11
|
+
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Sequence
|
|
13
12
|
from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
|
|
14
13
|
from dataclasses import dataclass, field
|
|
15
14
|
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
|
|
16
15
|
|
|
17
|
-
from anyio import CancelScope, create_memory_object_stream, create_task_group
|
|
16
|
+
from anyio import BrokenResourceError, CancelScope, create_memory_object_stream, create_task_group
|
|
18
17
|
from anyio.abc import TaskGroup
|
|
19
18
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
20
19
|
from typing_extensions import TypeVar, assert_never
|
|
@@ -22,7 +21,7 @@ from typing_extensions import TypeVar, assert_never
|
|
|
22
21
|
from pydantic_graph import exceptions
|
|
23
22
|
from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span
|
|
24
23
|
from pydantic_graph.beta.decision import Decision
|
|
25
|
-
from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem,
|
|
24
|
+
from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, JoinID, NodeID, NodeRunID, TaskID
|
|
26
25
|
from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext
|
|
27
26
|
from pydantic_graph.beta.node import (
|
|
28
27
|
EndNode,
|
|
@@ -44,9 +43,9 @@ from pydantic_graph.beta.util import unpack_type_expression
|
|
|
44
43
|
from pydantic_graph.nodes import BaseNode, End
|
|
45
44
|
|
|
46
45
|
if sys.version_info < (3, 11):
|
|
47
|
-
from exceptiongroup import
|
|
46
|
+
from exceptiongroup import BaseExceptionGroup as BaseExceptionGroup # pragma: lax no cover
|
|
48
47
|
else:
|
|
49
|
-
|
|
48
|
+
BaseExceptionGroup = BaseExceptionGroup # pragma: lax no cover
|
|
50
49
|
|
|
51
50
|
if TYPE_CHECKING:
|
|
52
51
|
from pydantic_graph.beta.mermaid import StateDiagramDirection
|
|
@@ -306,14 +305,13 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]):
|
|
|
306
305
|
|
|
307
306
|
|
|
308
307
|
@dataclass
|
|
309
|
-
class
|
|
310
|
-
"""A
|
|
308
|
+
class GraphTaskRequest:
|
|
309
|
+
"""A request to run a task representing the execution of a node in the graph.
|
|
311
310
|
|
|
312
|
-
|
|
311
|
+
GraphTaskRequest encapsulates all the information needed to execute a specific
|
|
313
312
|
node, including its inputs and the fork context it's executing within.
|
|
314
313
|
"""
|
|
315
314
|
|
|
316
|
-
# With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself
|
|
317
315
|
node_id: NodeID
|
|
318
316
|
"""The ID of the node to execute."""
|
|
319
317
|
|
|
@@ -326,9 +324,26 @@ class GraphTask:
|
|
|
326
324
|
Used by the GraphRun to decide when to proceed through joins.
|
|
327
325
|
"""
|
|
328
326
|
|
|
329
|
-
|
|
327
|
+
|
|
328
|
+
@dataclass
|
|
329
|
+
class GraphTask(GraphTaskRequest):
|
|
330
|
+
"""A task representing the execution of a node in the graph.
|
|
331
|
+
|
|
332
|
+
GraphTask encapsulates all the information needed to execute a specific
|
|
333
|
+
node, including its inputs and the fork context it's executing within,
|
|
334
|
+
and has a unique ID to identify the task within the graph run.
|
|
335
|
+
"""
|
|
336
|
+
|
|
337
|
+
task_id: TaskID = field(repr=False)
|
|
330
338
|
"""Unique identifier for this task."""
|
|
331
339
|
|
|
340
|
+
@staticmethod
|
|
341
|
+
def from_request(request: GraphTaskRequest, get_task_id: Callable[[], TaskID]) -> GraphTask:
|
|
342
|
+
# Don't call the get_task_id callable, this is already a task
|
|
343
|
+
if isinstance(request, GraphTask):
|
|
344
|
+
return request
|
|
345
|
+
return GraphTask(request.node_id, request.inputs, request.fork_stack, get_task_id())
|
|
346
|
+
|
|
332
347
|
|
|
333
348
|
class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
334
349
|
"""A single execution instance of a graph.
|
|
@@ -378,12 +393,20 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
|
378
393
|
self._next: EndMarker[OutputT] | Sequence[GraphTask] | None = None
|
|
379
394
|
"""The next item to be processed."""
|
|
380
395
|
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
396
|
+
self._next_task_id = 0
|
|
397
|
+
self._next_node_run_id = 0
|
|
398
|
+
initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, self._get_next_node_run_id(), 0),)
|
|
399
|
+
self._first_task = GraphTask(
|
|
400
|
+
node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack, task_id=self._get_next_task_id()
|
|
401
|
+
)
|
|
384
402
|
self._iterator_task_group = create_task_group()
|
|
385
403
|
self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](
|
|
386
|
-
self.graph,
|
|
404
|
+
self.graph,
|
|
405
|
+
self.state,
|
|
406
|
+
self.deps,
|
|
407
|
+
self._iterator_task_group,
|
|
408
|
+
self._get_next_node_run_id,
|
|
409
|
+
self._get_next_task_id,
|
|
387
410
|
)
|
|
388
411
|
self._iterator = self._iterator_instance.iter_graph(self._first_task)
|
|
389
412
|
|
|
@@ -449,7 +472,7 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
|
449
472
|
return self._next
|
|
450
473
|
|
|
451
474
|
async def next(
|
|
452
|
-
self, value: EndMarker[OutputT] | Sequence[
|
|
475
|
+
self, value: EndMarker[OutputT] | Sequence[GraphTaskRequest] | None = None
|
|
453
476
|
) -> EndMarker[OutputT] | Sequence[GraphTask]:
|
|
454
477
|
"""Advance the graph execution by one step.
|
|
455
478
|
|
|
@@ -467,7 +490,10 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
|
467
490
|
# if `next` is called before the `first_node` has run.
|
|
468
491
|
await anext(self)
|
|
469
492
|
if value is not None:
|
|
470
|
-
|
|
493
|
+
if isinstance(value, EndMarker):
|
|
494
|
+
self._next = value
|
|
495
|
+
else:
|
|
496
|
+
self._next = [GraphTask.from_request(gtr, self._get_next_task_id) for gtr in value]
|
|
471
497
|
return await anext(self)
|
|
472
498
|
|
|
473
499
|
@property
|
|
@@ -490,6 +516,16 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
|
490
516
|
return self._next.value
|
|
491
517
|
return None
|
|
492
518
|
|
|
519
|
+
def _get_next_task_id(self) -> TaskID:
|
|
520
|
+
next_id = TaskID(f'task:{self._next_task_id}')
|
|
521
|
+
self._next_task_id += 1
|
|
522
|
+
return next_id
|
|
523
|
+
|
|
524
|
+
def _get_next_node_run_id(self) -> NodeRunID:
|
|
525
|
+
next_id = NodeRunID(f'task:{self._next_node_run_id}')
|
|
526
|
+
self._next_node_run_id += 1
|
|
527
|
+
return next_id
|
|
528
|
+
|
|
493
529
|
|
|
494
530
|
@dataclass
|
|
495
531
|
class _GraphTaskAsyncIterable:
|
|
@@ -510,6 +546,8 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
510
546
|
state: StateT
|
|
511
547
|
deps: DepsT
|
|
512
548
|
task_group: TaskGroup
|
|
549
|
+
get_next_node_run_id: Callable[[], NodeRunID]
|
|
550
|
+
get_next_task_id: Callable[[], TaskID]
|
|
513
551
|
|
|
514
552
|
cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
|
|
515
553
|
active_tasks: dict[TaskID, GraphTask] = field(init=False)
|
|
@@ -522,8 +560,9 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
522
560
|
self.active_tasks = {}
|
|
523
561
|
self.active_reducers = {}
|
|
524
562
|
self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
|
|
563
|
+
self._next_node_run_id = 1
|
|
525
564
|
|
|
526
|
-
async def iter_graph( # noqa C901
|
|
565
|
+
async def iter_graph( # noqa: C901
|
|
527
566
|
self, first_task: GraphTask
|
|
528
567
|
) -> AsyncGenerator[EndMarker[OutputT] | Sequence[GraphTask], EndMarker[OutputT] | Sequence[GraphTask]]:
|
|
529
568
|
async with self.iter_stream_sender:
|
|
@@ -709,12 +748,15 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
709
748
|
with CancelScope() as scope:
|
|
710
749
|
self.cancel_scopes[t_.task_id] = scope
|
|
711
750
|
result = await self._run_task(t_)
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
751
|
+
try:
|
|
752
|
+
if isinstance(result, _GraphTaskAsyncIterable):
|
|
753
|
+
async for new_tasks in result.iterable:
|
|
754
|
+
await self.iter_stream_sender.send(_GraphTaskResult(t_, new_tasks, False))
|
|
755
|
+
await self.iter_stream_sender.send(_GraphTaskResult(t_, []))
|
|
756
|
+
else:
|
|
757
|
+
await self.iter_stream_sender.send(_GraphTaskResult(t_, result))
|
|
758
|
+
except BrokenResourceError:
|
|
759
|
+
pass # pragma: no cover # This can happen in difficult-to-reproduce circumstances when cancelling an asyncio task
|
|
718
760
|
|
|
719
761
|
async def _run_task(
|
|
720
762
|
self,
|
|
@@ -782,12 +824,12 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
782
824
|
fork_stack: ForkStack,
|
|
783
825
|
) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]:
|
|
784
826
|
if isinstance(next_node, StepNode):
|
|
785
|
-
return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)]
|
|
827
|
+
return [GraphTask(next_node.step.id, next_node.inputs, fork_stack, self.get_next_task_id())]
|
|
786
828
|
elif isinstance(next_node, JoinNode):
|
|
787
829
|
return JoinItem(next_node.join.id, next_node.inputs, fork_stack)
|
|
788
830
|
elif isinstance(next_node, BaseNode):
|
|
789
831
|
node_step = NodeStep(next_node.__class__)
|
|
790
|
-
return [GraphTask(node_step.id, next_node, fork_stack)]
|
|
832
|
+
return [GraphTask(node_step.id, next_node, fork_stack, self.get_next_task_id())]
|
|
791
833
|
elif isinstance(next_node, End):
|
|
792
834
|
return EndMarker(next_node.data)
|
|
793
835
|
else:
|
|
@@ -821,7 +863,7 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
821
863
|
'These markers should be removed from paths during graph building'
|
|
822
864
|
)
|
|
823
865
|
if isinstance(item, DestinationMarker):
|
|
824
|
-
return [GraphTask(item.destination_id, inputs, fork_stack)]
|
|
866
|
+
return [GraphTask(item.destination_id, inputs, fork_stack, self.get_next_task_id())]
|
|
825
867
|
elif isinstance(item, TransformMarker):
|
|
826
868
|
inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs))
|
|
827
869
|
return self._handle_path(path.next_path, inputs, fork_stack)
|
|
@@ -853,7 +895,7 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
853
895
|
) # this should have already been ensured during graph building
|
|
854
896
|
|
|
855
897
|
new_tasks: list[GraphTask] = []
|
|
856
|
-
node_run_id =
|
|
898
|
+
node_run_id = self.get_next_node_run_id()
|
|
857
899
|
if node.is_map:
|
|
858
900
|
# If the map specifies a downstream join id, eagerly create a join state for it
|
|
859
901
|
if (join_id := node.downstream_join_id) is not None:
|
|
@@ -931,7 +973,7 @@ def _unwrap_exception_groups():
|
|
|
931
973
|
else:
|
|
932
974
|
try:
|
|
933
975
|
yield
|
|
934
|
-
except
|
|
976
|
+
except BaseExceptionGroup as e:
|
|
935
977
|
exception = e.exceptions[0]
|
|
936
978
|
if exception.__cause__ is None:
|
|
937
979
|
# bizarrely, this prevents recursion errors when formatting the exception for logfire
|
|
@@ -284,6 +284,8 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
|
|
|
284
284
|
async def wrapper(ctx: StepContext[StateT, DepsT, InputT]):
|
|
285
285
|
return call(ctx)
|
|
286
286
|
|
|
287
|
+
node_id = node_id or get_callable_name(call)
|
|
288
|
+
|
|
287
289
|
return self.step(call=wrapper, node_id=node_id, label=label)
|
|
288
290
|
|
|
289
291
|
@overload
|
|
@@ -318,7 +320,7 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
|
|
|
318
320
|
preferred_parent_fork: Literal['farthest', 'closest'] = 'farthest',
|
|
319
321
|
) -> Join[StateT, DepsT, InputT, OutputT]:
|
|
320
322
|
if initial_factory is UNSET:
|
|
321
|
-
initial_factory = lambda: initial # pyright: ignore[reportAssignmentType] # noqa E731
|
|
323
|
+
initial_factory = lambda: initial # pyright: ignore[reportAssignmentType] # noqa: E731
|
|
322
324
|
|
|
323
325
|
return Join[StateT, DepsT, InputT, OutputT](
|
|
324
326
|
id=JoinID(NodeID(node_id or generate_placeholder_node_id(get_callable_name(reducer)))),
|
|
@@ -329,7 +331,7 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
|
|
|
329
331
|
)
|
|
330
332
|
|
|
331
333
|
# Edge building
|
|
332
|
-
def add(self, *edges: EdgePath[StateT, DepsT]) -> None: # noqa C901
|
|
334
|
+
def add(self, *edges: EdgePath[StateT, DepsT]) -> None: # noqa: C901
|
|
333
335
|
"""Add one or more edge paths to the graph.
|
|
334
336
|
|
|
335
337
|
This method processes edge paths and automatically creates any necessary
|
|
@@ -674,7 +676,7 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
|
|
|
674
676
|
)
|
|
675
677
|
|
|
676
678
|
|
|
677
|
-
def _validate_graph_structure( # noqa C901
|
|
679
|
+
def _validate_graph_structure( # noqa: C901
|
|
678
680
|
nodes: dict[NodeID, AnyNode],
|
|
679
681
|
edges_by_source: dict[NodeID, list[Path]],
|
|
680
682
|
) -> None:
|
|
@@ -24,9 +24,6 @@ JoinID = NodeID
|
|
|
24
24
|
ForkID = NodeID
|
|
25
25
|
"""Alias for NodeId when referring to fork nodes."""
|
|
26
26
|
|
|
27
|
-
GraphRunID = NewType('GraphRunID', str)
|
|
28
|
-
"""Unique identifier for a complete graph execution run."""
|
|
29
|
-
|
|
30
27
|
TaskID = NewType('TaskID', str)
|
|
31
28
|
"""Unique identifier for a task within the graph execution."""
|
|
32
29
|
|
|
@@ -49,7 +49,7 @@ class MermaidEdge:
|
|
|
49
49
|
label: str | None
|
|
50
50
|
|
|
51
51
|
|
|
52
|
-
def build_mermaid_graph( # noqa C901
|
|
52
|
+
def build_mermaid_graph( # noqa: C901
|
|
53
53
|
graph_nodes: dict[NodeID, AnyNode], graph_edges_by_source: dict[NodeID, list[Path]]
|
|
54
54
|
) -> MermaidGraph:
|
|
55
55
|
"""Build a mermaid graph."""
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|