pydantic-graph 1.13.0__tar.gz → 1.31.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/.gitignore +2 -0
  2. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/PKG-INFO +1 -1
  3. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/graph.py +71 -29
  4. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/graph_builder.py +5 -3
  5. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/id_types.py +0 -3
  6. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/mermaid.py +1 -1
  7. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/LICENSE +0 -0
  8. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/README.md +0 -0
  9. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/__init__.py +0 -0
  10. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/_utils.py +0 -0
  11. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/__init__.py +0 -0
  12. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/decision.py +0 -0
  13. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/join.py +0 -0
  14. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/node.py +0 -0
  15. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/node_types.py +0 -0
  16. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/parent_forks.py +0 -0
  17. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/paths.py +0 -0
  18. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/step.py +0 -0
  19. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/beta/util.py +0 -0
  20. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/exceptions.py +0 -0
  21. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/graph.py +0 -0
  22. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/mermaid.py +0 -0
  23. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/nodes.py +0 -0
  24. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/persistence/__init__.py +0 -0
  25. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/persistence/_utils.py +0 -0
  26. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/persistence/file.py +0 -0
  27. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/persistence/in_mem.py +0 -0
  28. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pydantic_graph/py.typed +0 -0
  29. {pydantic_graph-1.13.0 → pydantic_graph-1.31.0}/pyproject.toml +0 -0
--- pydantic_graph-1.13.0/.gitignore
+++ pydantic_graph-1.31.0/.gitignore
@@ -21,3 +21,5 @@ node_modules/
 /test_tmp/
 .mcp.json
 .claude/
+/.cursor/
+/.devcontainer/
--- pydantic_graph-1.13.0/PKG-INFO
+++ pydantic_graph-1.31.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-graph
-Version: 1.13.0
+Version: 1.31.0
 Summary: Graph and state machine library
 Project-URL: Homepage, https://ai.pydantic.dev/graph/tree/main/pydantic_graph
 Project-URL: Source, https://github.com/pydantic/pydantic-ai
--- pydantic_graph-1.13.0/pydantic_graph/beta/graph.py
+++ pydantic_graph-1.31.0/pydantic_graph/beta/graph.py
@@ -8,13 +8,12 @@ the graph-based workflow system.
 from __future__ import annotations as _annotations
 
 import sys
-import uuid
-from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence
+from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Sequence
 from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
 
-from anyio import CancelScope, create_memory_object_stream, create_task_group
+from anyio import BrokenResourceError, CancelScope, create_memory_object_stream, create_task_group
 from anyio.abc import TaskGroup
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 from typing_extensions import TypeVar, assert_never
@@ -22,7 +21,7 @@ from typing_extensions import TypeVar, assert_never
 from pydantic_graph import exceptions
 from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span
 from pydantic_graph.beta.decision import Decision
-from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID
+from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, JoinID, NodeID, NodeRunID, TaskID
 from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext
 from pydantic_graph.beta.node import (
     EndNode,
@@ -44,9 +43,9 @@ from pydantic_graph.beta.util import unpack_type_expression
 from pydantic_graph.nodes import BaseNode, End
 
 if sys.version_info < (3, 11):
-    from exceptiongroup import ExceptionGroup as ExceptionGroup  # pragma: lax no cover
+    from exceptiongroup import BaseExceptionGroup as BaseExceptionGroup  # pragma: lax no cover
 else:
-    ExceptionGroup = ExceptionGroup  # pragma: lax no cover
+    BaseExceptionGroup = BaseExceptionGroup  # pragma: lax no cover
 
 if TYPE_CHECKING:
     from pydantic_graph.beta.mermaid import StateDiagramDirection
@@ -306,14 +305,13 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]):
 
 
 @dataclass
-class GraphTask:
-    """A single task representing the execution of a node in the graph.
+class GraphTaskRequest:
+    """A request to run a task representing the execution of a node in the graph.
 
-    GraphTask encapsulates all the information needed to execute a specific
+    GraphTaskRequest encapsulates all the information needed to execute a specific
     node, including its inputs and the fork context it's executing within.
     """
 
-    # With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself
     node_id: NodeID
     """The ID of the node to execute."""
 
@@ -326,9 +324,26 @@ class GraphTask:
     Used by the GraphRun to decide when to proceed through joins.
     """
 
-    task_id: TaskID = field(default_factory=lambda: TaskID(str(uuid.uuid4())), repr=False)
+
+@dataclass
+class GraphTask(GraphTaskRequest):
+    """A task representing the execution of a node in the graph.
+
+    GraphTask encapsulates all the information needed to execute a specific
+    node, including its inputs and the fork context it's executing within,
+    and has a unique ID to identify the task within the graph run.
+    """
+
+    task_id: TaskID = field(repr=False)
     """Unique identifier for this task."""
 
+    @staticmethod
+    def from_request(request: GraphTaskRequest, get_task_id: Callable[[], TaskID]) -> GraphTask:
+        # Don't call the get_task_id callable, this is already a task
+        if isinstance(request, GraphTask):
+            return request
+        return GraphTask(request.node_id, request.inputs, request.fork_stack, get_task_id())
+
 
 class GraphRun(Generic[StateT, DepsT, OutputT]):
     """A single execution instance of a graph.
@@ -378,12 +393,20 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
         self._next: EndMarker[OutputT] | Sequence[GraphTask] | None = None
         """The next item to be processed."""
 
-        run_id = GraphRunID(str(uuid.uuid4()))
-        initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),)
-        self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack)
+        self._next_task_id = 0
+        self._next_node_run_id = 0
+        initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, self._get_next_node_run_id(), 0),)
+        self._first_task = GraphTask(
+            node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack, task_id=self._get_next_task_id()
+        )
         self._iterator_task_group = create_task_group()
         self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](
-            self.graph, self.state, self.deps, self._iterator_task_group
+            self.graph,
+            self.state,
+            self.deps,
+            self._iterator_task_group,
+            self._get_next_node_run_id,
+            self._get_next_task_id,
         )
         self._iterator = self._iterator_instance.iter_graph(self._first_task)
 
@@ -449,7 +472,7 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
         return self._next
 
     async def next(
-        self, value: EndMarker[OutputT] | Sequence[GraphTask] | None = None
+        self, value: EndMarker[OutputT] | Sequence[GraphTaskRequest] | None = None
     ) -> EndMarker[OutputT] | Sequence[GraphTask]:
         """Advance the graph execution by one step.
 
@@ -467,7 +490,10 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
             # if `next` is called before the `first_node` has run.
             await anext(self)
         if value is not None:
-            self._next = value
+            if isinstance(value, EndMarker):
+                self._next = value
+            else:
+                self._next = [GraphTask.from_request(gtr, self._get_next_task_id) for gtr in value]
         return await anext(self)
 
     @property
@@ -490,6 +516,16 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
             return self._next.value
         return None
 
+    def _get_next_task_id(self) -> TaskID:
+        next_id = TaskID(f'task:{self._next_task_id}')
+        self._next_task_id += 1
+        return next_id
+
+    def _get_next_node_run_id(self) -> NodeRunID:
+        next_id = NodeRunID(f'task:{self._next_node_run_id}')
+        self._next_node_run_id += 1
+        return next_id
+
 
 @dataclass
 class _GraphTaskAsyncIterable:
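The two helpers above, together with the constructor change earlier in this file, replace the old uuid4-based identifiers with per-run monotonic counters, so task and node-run IDs are deterministic and ordered within a single run. A small standalone sketch of the same counter pattern; the names below are illustrative, not the library's:

```python
from typing import NewType

TaskID = NewType('TaskID', str)


class _IDCounter:
    """Mint sequential, deterministic IDs with a fixed prefix."""

    def __init__(self, prefix: str) -> None:
        self._prefix = prefix
        self._next = 0

    def __call__(self) -> TaskID:
        next_id = TaskID(f'{self._prefix}:{self._next}')
        self._next += 1
        return next_id


get_next_task_id = _IDCounter('task')
print(get_next_task_id(), get_next_task_id())  # task:0 task:1
```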
@@ -510,6 +546,8 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
     state: StateT
     deps: DepsT
     task_group: TaskGroup
+    get_next_node_run_id: Callable[[], NodeRunID]
+    get_next_task_id: Callable[[], TaskID]
 
     cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
     active_tasks: dict[TaskID, GraphTask] = field(init=False)
@@ -522,8 +560,9 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
         self.active_tasks = {}
         self.active_reducers = {}
         self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
+        self._next_node_run_id = 1
 
-    async def iter_graph(  # noqa C901
+    async def iter_graph(  # noqa: C901
         self, first_task: GraphTask
     ) -> AsyncGenerator[EndMarker[OutputT] | Sequence[GraphTask], EndMarker[OutputT] | Sequence[GraphTask]]:
         async with self.iter_stream_sender:
@@ -709,12 +748,15 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
             with CancelScope() as scope:
                 self.cancel_scopes[t_.task_id] = scope
                 result = await self._run_task(t_)
-                if isinstance(result, _GraphTaskAsyncIterable):
-                    async for new_tasks in result.iterable:
-                        await self.iter_stream_sender.send(_GraphTaskResult(t_, new_tasks, False))
-                    await self.iter_stream_sender.send(_GraphTaskResult(t_, []))
-                else:
-                    await self.iter_stream_sender.send(_GraphTaskResult(t_, result))
+                try:
+                    if isinstance(result, _GraphTaskAsyncIterable):
+                        async for new_tasks in result.iterable:
+                            await self.iter_stream_sender.send(_GraphTaskResult(t_, new_tasks, False))
+                        await self.iter_stream_sender.send(_GraphTaskResult(t_, []))
+                    else:
+                        await self.iter_stream_sender.send(_GraphTaskResult(t_, result))
+                except BrokenResourceError:
+                    pass  # pragma: no cover # This can happen in difficult-to-reproduce circumstances when cancelling an asyncio task
 
     async def _run_task(
         self,
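The new try/except covers a shutdown race: anyio raises BrokenResourceError from send() when the receiving end of a memory object stream has already been closed, which can happen while the run is being cancelled. A hedged standalone illustration of that behaviour (assumes anyio 4+, and is not pydantic-graph code):

```python
import anyio
from anyio import BrokenResourceError, create_memory_object_stream


async def main() -> None:
    send_stream, receive_stream = create_memory_object_stream[int]()
    receive_stream.close()  # simulate the consumer going away mid-run
    try:
        await send_stream.send(1)
    except BrokenResourceError:
        print('receiver already closed; result dropped')  # mirrors the `pass` above


anyio.run(main)
```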
@@ -782,12 +824,12 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
         fork_stack: ForkStack,
     ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]:
         if isinstance(next_node, StepNode):
-            return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)]
+            return [GraphTask(next_node.step.id, next_node.inputs, fork_stack, self.get_next_task_id())]
         elif isinstance(next_node, JoinNode):
             return JoinItem(next_node.join.id, next_node.inputs, fork_stack)
         elif isinstance(next_node, BaseNode):
             node_step = NodeStep(next_node.__class__)
-            return [GraphTask(node_step.id, next_node, fork_stack)]
+            return [GraphTask(node_step.id, next_node, fork_stack, self.get_next_task_id())]
         elif isinstance(next_node, End):
             return EndMarker(next_node.data)
         else:
@@ -821,7 +863,7 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
                 'These markers should be removed from paths during graph building'
             )
         if isinstance(item, DestinationMarker):
-            return [GraphTask(item.destination_id, inputs, fork_stack)]
+            return [GraphTask(item.destination_id, inputs, fork_stack, self.get_next_task_id())]
         elif isinstance(item, TransformMarker):
             inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs))
             return self._handle_path(path.next_path, inputs, fork_stack)
@@ -853,7 +895,7 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
         )  # this should have already been ensured during graph building
 
         new_tasks: list[GraphTask] = []
-        node_run_id = NodeRunID(str(uuid.uuid4()))
+        node_run_id = self.get_next_node_run_id()
         if node.is_map:
             # If the map specifies a downstream join id, eagerly create a join state for it
             if (join_id := node.downstream_join_id) is not None:
@@ -931,7 +973,7 @@ def _unwrap_exception_groups():
     else:
         try:
             yield
-        except ExceptionGroup as e:
+        except BaseExceptionGroup as e:
             exception = e.exceptions[0]
             if exception.__cause__ is None:
                 # bizarrely, this prevents recursion errors when formatting the exception for logfire
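Catching BaseExceptionGroup instead of ExceptionGroup matters because a group whose members include BaseException subclasses (for example a cancellation) is a BaseExceptionGroup and would slip past `except ExceptionGroup`. A short illustration, assuming Python 3.11+ or the exceptiongroup backport per the import change earlier in this diff:

```python
def first_member(group: BaseExceptionGroup) -> BaseException:
    # Same idea as _unwrap_exception_groups: surface the first wrapped exception.
    return group.exceptions[0]


try:
    # KeyboardInterrupt is a BaseException, so this group is *not* an ExceptionGroup.
    raise BaseExceptionGroup('wrapper', [KeyboardInterrupt()])
except BaseExceptionGroup as e:
    print(type(first_member(e)).__name__)  # KeyboardInterrupt
```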
--- pydantic_graph-1.13.0/pydantic_graph/beta/graph_builder.py
+++ pydantic_graph-1.31.0/pydantic_graph/beta/graph_builder.py
@@ -284,6 +284,8 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
         async def wrapper(ctx: StepContext[StateT, DepsT, InputT]):
             return call(ctx)
 
+        node_id = node_id or get_callable_name(call)
+
         return self.step(call=wrapper, node_id=node_id, label=label)
 
     @overload
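The added line appears to fix default naming for wrapped synchronous steps: without it, a missing node_id would end up derived from the inner `wrapper` coroutine rather than from the user's callable. A simplified sketch of that pattern; `get_callable_name` below is a stand-in for the pydantic_graph helper, and `register_sync_step` is hypothetical:

```python
from typing import Any, Callable


def get_callable_name(call: Callable[..., Any]) -> str:
    # Stand-in helper: fall back to repr() for callables without a __name__.
    return getattr(call, '__name__', repr(call))


def register_sync_step(call: Callable[[Any], Any], node_id: str | None = None) -> str:
    async def wrapper(ctx: Any) -> Any:
        return call(ctx)

    # Derive the ID from the user's callable, not from `wrapper`.
    return node_id or get_callable_name(call)


def fetch_data(ctx: Any) -> int:
    return 42


print(register_sync_step(fetch_data))  # fetch_data, not wrapper
```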
@@ -318,7 +320,7 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
         preferred_parent_fork: Literal['farthest', 'closest'] = 'farthest',
     ) -> Join[StateT, DepsT, InputT, OutputT]:
         if initial_factory is UNSET:
-            initial_factory = lambda: initial  # pyright: ignore[reportAssignmentType]  # noqa E731
+            initial_factory = lambda: initial  # pyright: ignore[reportAssignmentType]  # noqa: E731
 
         return Join[StateT, DepsT, InputT, OutputT](
             id=JoinID(NodeID(node_id or generate_placeholder_node_id(get_callable_name(reducer)))),
@@ -329,7 +331,7 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
         )
 
     # Edge building
-    def add(self, *edges: EdgePath[StateT, DepsT]) -> None:  # noqa C901
+    def add(self, *edges: EdgePath[StateT, DepsT]) -> None:  # noqa: C901
         """Add one or more edge paths to the graph.
 
         This method processes edge paths and automatically creates any necessary
@@ -674,7 +676,7 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
         )
 
 
-def _validate_graph_structure(  # noqa C901
+def _validate_graph_structure(  # noqa: C901
     nodes: dict[NodeID, AnyNode],
     edges_by_source: dict[NodeID, list[Path]],
 ) -> None:
--- pydantic_graph-1.13.0/pydantic_graph/beta/id_types.py
+++ pydantic_graph-1.31.0/pydantic_graph/beta/id_types.py
@@ -24,9 +24,6 @@ JoinID = NodeID
 ForkID = NodeID
 """Alias for NodeId when referring to fork nodes."""
 
-GraphRunID = NewType('GraphRunID', str)
-"""Unique identifier for a complete graph execution run."""
-
 TaskID = NewType('TaskID', str)
 """Unique identifier for a task within the graph execution."""
 
--- pydantic_graph-1.13.0/pydantic_graph/beta/mermaid.py
+++ pydantic_graph-1.31.0/pydantic_graph/beta/mermaid.py
@@ -49,7 +49,7 @@ class MermaidEdge:
     label: str | None
 
 
-def build_mermaid_graph(  # noqa C901
+def build_mermaid_graph(  # noqa: C901
     graph_nodes: dict[NodeID, AnyNode], graph_edges_by_source: dict[NodeID, list[Path]]
 ) -> MermaidGraph:
     """Build a mermaid graph."""