pydantic-graph 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_graph/beta/graph.py +57 -18
- pydantic_graph/beta/id_types.py +0 -3
- {pydantic_graph-1.18.0.dist-info → pydantic_graph-1.20.0.dist-info}/METADATA +1 -1
- {pydantic_graph-1.18.0.dist-info → pydantic_graph-1.20.0.dist-info}/RECORD +6 -6
- {pydantic_graph-1.18.0.dist-info → pydantic_graph-1.20.0.dist-info}/WHEEL +0 -0
- {pydantic_graph-1.18.0.dist-info → pydantic_graph-1.20.0.dist-info}/licenses/LICENSE +0 -0
pydantic_graph/beta/graph.py
CHANGED
|
@@ -8,8 +8,7 @@ the graph-based workflow system.
|
|
|
8
8
|
from __future__ import annotations as _annotations
|
|
9
9
|
|
|
10
10
|
import sys
|
|
11
|
-
import
|
|
12
|
-
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence
|
|
11
|
+
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Sequence
|
|
13
12
|
from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
|
|
14
13
|
from dataclasses import dataclass, field
|
|
15
14
|
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
|
|
@@ -22,7 +21,7 @@ from typing_extensions import TypeVar, assert_never
|
|
|
22
21
|
from pydantic_graph import exceptions
|
|
23
22
|
from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span
|
|
24
23
|
from pydantic_graph.beta.decision import Decision
|
|
25
|
-
from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem,
|
|
24
|
+
from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, JoinID, NodeID, NodeRunID, TaskID
|
|
26
25
|
from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext
|
|
27
26
|
from pydantic_graph.beta.node import (
|
|
28
27
|
EndNode,
|
|
@@ -306,14 +305,13 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]):
|
|
|
306
305
|
|
|
307
306
|
|
|
308
307
|
@dataclass
|
|
309
|
-
class
|
|
310
|
-
"""A
|
|
308
|
+
class GraphTaskRequest:
|
|
309
|
+
"""A request to run a task representing the execution of a node in the graph.
|
|
311
310
|
|
|
312
|
-
|
|
311
|
+
GraphTaskRequest encapsulates all the information needed to execute a specific
|
|
313
312
|
node, including its inputs and the fork context it's executing within.
|
|
314
313
|
"""
|
|
315
314
|
|
|
316
|
-
# With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself
|
|
317
315
|
node_id: NodeID
|
|
318
316
|
"""The ID of the node to execute."""
|
|
319
317
|
|
|
@@ -326,9 +324,26 @@ class GraphTask:
|
|
|
326
324
|
Used by the GraphRun to decide when to proceed through joins.
|
|
327
325
|
"""
|
|
328
326
|
|
|
329
|
-
|
|
327
|
+
|
|
328
|
+
@dataclass
|
|
329
|
+
class GraphTask(GraphTaskRequest):
|
|
330
|
+
"""A task representing the execution of a node in the graph.
|
|
331
|
+
|
|
332
|
+
GraphTask encapsulates all the information needed to execute a specific
|
|
333
|
+
node, including its inputs and the fork context it's executing within,
|
|
334
|
+
and has a unique ID to identify the task within the graph run.
|
|
335
|
+
"""
|
|
336
|
+
|
|
337
|
+
task_id: TaskID = field(repr=False)
|
|
330
338
|
"""Unique identifier for this task."""
|
|
331
339
|
|
|
340
|
+
@staticmethod
|
|
341
|
+
def from_request(request: GraphTaskRequest, get_task_id: Callable[[], TaskID]) -> GraphTask:
|
|
342
|
+
# Don't call the get_task_id callable, this is already a task
|
|
343
|
+
if isinstance(request, GraphTask):
|
|
344
|
+
return request
|
|
345
|
+
return GraphTask(request.node_id, request.inputs, request.fork_stack, get_task_id())
|
|
346
|
+
|
|
332
347
|
|
|
333
348
|
class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
334
349
|
"""A single execution instance of a graph.
|
|
@@ -378,12 +393,20 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
|
378
393
|
self._next: EndMarker[OutputT] | Sequence[GraphTask] | None = None
|
|
379
394
|
"""The next item to be processed."""
|
|
380
395
|
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
396
|
+
self._next_task_id = 0
|
|
397
|
+
self._next_node_run_id = 0
|
|
398
|
+
initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, self._get_next_node_run_id(), 0),)
|
|
399
|
+
self._first_task = GraphTask(
|
|
400
|
+
node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack, task_id=self._get_next_task_id()
|
|
401
|
+
)
|
|
384
402
|
self._iterator_task_group = create_task_group()
|
|
385
403
|
self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](
|
|
386
|
-
self.graph,
|
|
404
|
+
self.graph,
|
|
405
|
+
self.state,
|
|
406
|
+
self.deps,
|
|
407
|
+
self._iterator_task_group,
|
|
408
|
+
self._get_next_node_run_id,
|
|
409
|
+
self._get_next_task_id,
|
|
387
410
|
)
|
|
388
411
|
self._iterator = self._iterator_instance.iter_graph(self._first_task)
|
|
389
412
|
|
|
@@ -449,7 +472,7 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
|
449
472
|
return self._next
|
|
450
473
|
|
|
451
474
|
async def next(
|
|
452
|
-
self, value: EndMarker[OutputT] | Sequence[
|
|
475
|
+
self, value: EndMarker[OutputT] | Sequence[GraphTaskRequest] | None = None
|
|
453
476
|
) -> EndMarker[OutputT] | Sequence[GraphTask]:
|
|
454
477
|
"""Advance the graph execution by one step.
|
|
455
478
|
|
|
@@ -467,7 +490,10 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
|
467
490
|
# if `next` is called before the `first_node` has run.
|
|
468
491
|
await anext(self)
|
|
469
492
|
if value is not None:
|
|
470
|
-
|
|
493
|
+
if isinstance(value, EndMarker):
|
|
494
|
+
self._next = value
|
|
495
|
+
else:
|
|
496
|
+
self._next = [GraphTask.from_request(gtr, self._get_next_task_id) for gtr in value]
|
|
471
497
|
return await anext(self)
|
|
472
498
|
|
|
473
499
|
@property
|
|
@@ -490,6 +516,16 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
|
490
516
|
return self._next.value
|
|
491
517
|
return None
|
|
492
518
|
|
|
519
|
+
def _get_next_task_id(self) -> TaskID:
|
|
520
|
+
next_id = TaskID(f'task:{self._next_task_id}')
|
|
521
|
+
self._next_task_id += 1
|
|
522
|
+
return next_id
|
|
523
|
+
|
|
524
|
+
def _get_next_node_run_id(self) -> NodeRunID:
|
|
525
|
+
next_id = NodeRunID(f'task:{self._next_node_run_id}')
|
|
526
|
+
self._next_node_run_id += 1
|
|
527
|
+
return next_id
|
|
528
|
+
|
|
493
529
|
|
|
494
530
|
@dataclass
|
|
495
531
|
class _GraphTaskAsyncIterable:
|
|
@@ -510,6 +546,8 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
510
546
|
state: StateT
|
|
511
547
|
deps: DepsT
|
|
512
548
|
task_group: TaskGroup
|
|
549
|
+
get_next_node_run_id: Callable[[], NodeRunID]
|
|
550
|
+
get_next_task_id: Callable[[], TaskID]
|
|
513
551
|
|
|
514
552
|
cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
|
|
515
553
|
active_tasks: dict[TaskID, GraphTask] = field(init=False)
|
|
@@ -522,6 +560,7 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
522
560
|
self.active_tasks = {}
|
|
523
561
|
self.active_reducers = {}
|
|
524
562
|
self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
|
|
563
|
+
self._next_node_run_id = 1
|
|
525
564
|
|
|
526
565
|
async def iter_graph( # noqa C901
|
|
527
566
|
self, first_task: GraphTask
|
|
@@ -782,12 +821,12 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
782
821
|
fork_stack: ForkStack,
|
|
783
822
|
) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]:
|
|
784
823
|
if isinstance(next_node, StepNode):
|
|
785
|
-
return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)]
|
|
824
|
+
return [GraphTask(next_node.step.id, next_node.inputs, fork_stack, self.get_next_task_id())]
|
|
786
825
|
elif isinstance(next_node, JoinNode):
|
|
787
826
|
return JoinItem(next_node.join.id, next_node.inputs, fork_stack)
|
|
788
827
|
elif isinstance(next_node, BaseNode):
|
|
789
828
|
node_step = NodeStep(next_node.__class__)
|
|
790
|
-
return [GraphTask(node_step.id, next_node, fork_stack)]
|
|
829
|
+
return [GraphTask(node_step.id, next_node, fork_stack, self.get_next_task_id())]
|
|
791
830
|
elif isinstance(next_node, End):
|
|
792
831
|
return EndMarker(next_node.data)
|
|
793
832
|
else:
|
|
@@ -821,7 +860,7 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
821
860
|
'These markers should be removed from paths during graph building'
|
|
822
861
|
)
|
|
823
862
|
if isinstance(item, DestinationMarker):
|
|
824
|
-
return [GraphTask(item.destination_id, inputs, fork_stack)]
|
|
863
|
+
return [GraphTask(item.destination_id, inputs, fork_stack, self.get_next_task_id())]
|
|
825
864
|
elif isinstance(item, TransformMarker):
|
|
826
865
|
inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs))
|
|
827
866
|
return self._handle_path(path.next_path, inputs, fork_stack)
|
|
@@ -853,7 +892,7 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
853
892
|
) # this should have already been ensured during graph building
|
|
854
893
|
|
|
855
894
|
new_tasks: list[GraphTask] = []
|
|
856
|
-
node_run_id =
|
|
895
|
+
node_run_id = self.get_next_node_run_id()
|
|
857
896
|
if node.is_map:
|
|
858
897
|
# If the map specifies a downstream join id, eagerly create a join state for it
|
|
859
898
|
if (join_id := node.downstream_join_id) is not None:
|
pydantic_graph/beta/id_types.py
CHANGED
|
@@ -24,9 +24,6 @@ JoinID = NodeID
|
|
|
24
24
|
ForkID = NodeID
|
|
25
25
|
"""Alias for NodeId when referring to fork nodes."""
|
|
26
26
|
|
|
27
|
-
GraphRunID = NewType('GraphRunID', str)
|
|
28
|
-
"""Unique identifier for a complete graph execution run."""
|
|
29
|
-
|
|
30
27
|
TaskID = NewType('TaskID', str)
|
|
31
28
|
"""Unique identifier for a task within the graph execution."""
|
|
32
29
|
|
|
@@ -7,9 +7,9 @@ pydantic_graph/nodes.py,sha256=CkY3lrC6jqZtzwhSRjFzmM69TdFFFrr58XSDU4THKHA,7450
|
|
|
7
7
|
pydantic_graph/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
8
|
pydantic_graph/beta/__init__.py,sha256=VVmbEFaCSXYHwXqS4pANg4B3cn_c86tT62tW_EXcuyw,751
|
|
9
9
|
pydantic_graph/beta/decision.py,sha256=x-Ta549b-j5hyBPUWFdwRQDRaJqnBHF1pfBP9L8I3vI,11239
|
|
10
|
-
pydantic_graph/beta/graph.py,sha256=
|
|
10
|
+
pydantic_graph/beta/graph.py,sha256=CBCqWTJ1u7o0k2Rud193IJq2NffydEOJ5UMl0A_tYJs,42790
|
|
11
11
|
pydantic_graph/beta/graph_builder.py,sha256=dCw4LePreagujGNtTdCVZfRVWkCs35MpoPEtAncLo5U,43326
|
|
12
|
-
pydantic_graph/beta/id_types.py,sha256=
|
|
12
|
+
pydantic_graph/beta/id_types.py,sha256=FZ3rYSubF6g_Ocv0faL3yJsy1lNN9AGZl9f_izvORUg,2814
|
|
13
13
|
pydantic_graph/beta/join.py,sha256=rzCumDX_YgaU_a5bisfbjbbOuI3IwSZsCZs9TC0T9E4,8002
|
|
14
14
|
pydantic_graph/beta/mermaid.py,sha256=Bj8a3CODPcojwT7BnrYqLBKTp0AbA1T3XsTmK2St3v4,7127
|
|
15
15
|
pydantic_graph/beta/node.py,sha256=cTEGKiT3Lutg-PWxBbZDihpnBTVoPMSyCbfB50fjKeY,3071
|
|
@@ -22,7 +22,7 @@ pydantic_graph/persistence/__init__.py,sha256=NLBGvUWhem23EdMHHxtX0XgTS2vyixmuWt
|
|
|
22
22
|
pydantic_graph/persistence/_utils.py,sha256=6ySxCc1lFz7bbLUwDLkoZWNqi8VNLBVU4xxJbKI23fQ,2264
|
|
23
23
|
pydantic_graph/persistence/file.py,sha256=XZy295cGc86HfUl_KuB-e7cECZW3bubiEdyJMVQ1OD0,6906
|
|
24
24
|
pydantic_graph/persistence/in_mem.py,sha256=MmahaVpdzmDB30Dm3ZfSCZBqgmx6vH4HXdBaWwVF0K0,6799
|
|
25
|
-
pydantic_graph-1.
|
|
26
|
-
pydantic_graph-1.
|
|
27
|
-
pydantic_graph-1.
|
|
28
|
-
pydantic_graph-1.
|
|
25
|
+
pydantic_graph-1.20.0.dist-info/METADATA,sha256=Iob0hAz2T1R6EGWDsSesCp03KNFZh6E9RZgS24pweSA,3895
|
|
26
|
+
pydantic_graph-1.20.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
27
|
+
pydantic_graph-1.20.0.dist-info/licenses/LICENSE,sha256=vA6Jc482lEyBBuGUfD1pYx-cM7jxvLYOxPidZ30t_PQ,1100
|
|
28
|
+
pydantic_graph-1.20.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|