vellum-ai 0.10.8__py3-none-any.whl → 0.10.9__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/types/logical_operator.py +2 -0
- vellum/workflows/descriptors/utils.py +27 -0
- vellum/workflows/events/__init__.py +0 -2
- vellum/workflows/events/tests/test_event.py +2 -1
- vellum/workflows/events/types.py +35 -29
- vellum/workflows/events/workflow.py +14 -7
- vellum/workflows/nodes/bases/base.py +100 -38
- vellum/workflows/nodes/core/try_node/node.py +22 -4
- vellum/workflows/nodes/core/try_node/tests/test_node.py +15 -0
- vellum/workflows/runner/runner.py +109 -42
- vellum/workflows/state/base.py +55 -21
- vellum/workflows/state/context.py +26 -3
- vellum/workflows/types/core.py +1 -1
- vellum/workflows/workflows/base.py +51 -17
- vellum/workflows/workflows/event_filters.py +61 -0
- {vellum_ai-0.10.8.dist-info → vellum_ai-0.10.9.dist-info}/METADATA +1 -1
- {vellum_ai-0.10.8.dist-info → vellum_ai-0.10.9.dist-info}/RECORD +24 -22
- vellum_ee/workflows/display/nodes/vellum/__init__.py +6 -4
- vellum_ee/workflows/display/nodes/vellum/error_node.py +49 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +203 -0
- vellum/workflows/events/utils.py +0 -5
- {vellum_ai-0.10.8.dist-info → vellum_ai-0.10.9.dist-info}/LICENSE +0 -0
- {vellum_ai-0.10.8.dist-info → vellum_ai-0.10.9.dist-info}/WHEEL +0 -0
- {vellum_ai-0.10.8.dist-info → vellum_ai-0.10.9.dist-info}/entry_points.txt +0 -0
@@ -3,7 +3,7 @@ from copy import deepcopy
|
|
3
3
|
import logging
|
4
4
|
from queue import Empty, Queue
|
5
5
|
from threading import Event as ThreadingEvent, Thread
|
6
|
-
from uuid import UUID
|
6
|
+
from uuid import UUID
|
7
7
|
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Type, Union
|
8
8
|
|
9
9
|
from vellum.workflows.constants import UNDEF
|
@@ -29,7 +29,6 @@ from vellum.workflows.events.node import (
|
|
29
29
|
NodeExecutionStreamingBody,
|
30
30
|
)
|
31
31
|
from vellum.workflows.events.types import BaseEvent, ParentContext, WorkflowParentContext
|
32
|
-
from vellum.workflows.events.utils import is_terminal_event
|
33
32
|
from vellum.workflows.events.workflow import (
|
34
33
|
WorkflowExecutionFulfilledBody,
|
35
34
|
WorkflowExecutionInitiatedBody,
|
@@ -74,7 +73,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
74
73
|
parent_context: Optional[ParentContext] = None,
|
75
74
|
):
|
76
75
|
if state and external_inputs:
|
77
|
-
raise ValueError(
|
76
|
+
raise ValueError(
|
77
|
+
"Can only run a Workflow providing one of state or external inputs, not both"
|
78
|
+
)
|
78
79
|
|
79
80
|
self.workflow = workflow
|
80
81
|
if entrypoint_nodes:
|
@@ -100,7 +101,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
100
101
|
if issubclass(ei.inputs_class.__parent_class__, BaseNode)
|
101
102
|
]
|
102
103
|
else:
|
103
|
-
normalized_inputs =
|
104
|
+
normalized_inputs = (
|
105
|
+
deepcopy(inputs) if inputs else self.workflow.get_default_inputs()
|
106
|
+
)
|
104
107
|
if state:
|
105
108
|
self._initial_state = deepcopy(state)
|
106
109
|
self._initial_state.meta.workflow_inputs = normalized_inputs
|
@@ -123,6 +126,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
123
126
|
"__snapshot_callback__",
|
124
127
|
lambda s: self._snapshot_state(s),
|
125
128
|
)
|
129
|
+
self.workflow.context._register_event_queue(self._workflow_event_queue)
|
126
130
|
|
127
131
|
def _snapshot_state(self, state: StateType) -> StateType:
|
128
132
|
self.workflow._store.append_state_snapshot(state)
|
@@ -148,8 +152,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
148
152
|
parent=WorkflowParentContext(
|
149
153
|
span_id=span_id,
|
150
154
|
workflow_definition=self.workflow.__class__,
|
151
|
-
parent=self._parent_context
|
152
|
-
|
155
|
+
parent=self._parent_context,
|
156
|
+
type="WORKFLOW",
|
157
|
+
),
|
153
158
|
),
|
154
159
|
)
|
155
160
|
)
|
@@ -185,7 +190,11 @@ class WorkflowRunner(Generic[StateType]):
|
|
185
190
|
instance=None,
|
186
191
|
outputs_class=node.Outputs,
|
187
192
|
)
|
188
|
-
node.state.meta.node_outputs[output_descriptor] =
|
193
|
+
node.state.meta.node_outputs[output_descriptor] = (
|
194
|
+
streaming_output_queues[output.name]
|
195
|
+
)
|
196
|
+
initiated_output: BaseOutput = BaseOutput(name=output.name)
|
197
|
+
initiated_ports = initiated_output > ports
|
189
198
|
self._work_item_event_queue.put(
|
190
199
|
WorkItemEvent(
|
191
200
|
node=node,
|
@@ -194,14 +203,14 @@ class WorkflowRunner(Generic[StateType]):
|
|
194
203
|
span_id=span_id,
|
195
204
|
body=NodeExecutionStreamingBody(
|
196
205
|
node_definition=node.__class__,
|
197
|
-
output=
|
198
|
-
invoked_ports=
|
206
|
+
output=initiated_output,
|
207
|
+
invoked_ports=initiated_ports,
|
199
208
|
),
|
200
209
|
parent=WorkflowParentContext(
|
201
210
|
span_id=span_id,
|
202
211
|
workflow_definition=self.workflow.__class__,
|
203
|
-
parent=self._parent_context
|
204
|
-
)
|
212
|
+
parent=self._parent_context,
|
213
|
+
),
|
205
214
|
),
|
206
215
|
)
|
207
216
|
)
|
@@ -230,7 +239,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
230
239
|
span_id=span_id,
|
231
240
|
workflow_definition=self.workflow.__class__,
|
232
241
|
parent=self._parent_context,
|
233
|
-
)
|
242
|
+
),
|
234
243
|
),
|
235
244
|
)
|
236
245
|
)
|
@@ -254,7 +263,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
254
263
|
span_id=span_id,
|
255
264
|
workflow_definition=self.workflow.__class__,
|
256
265
|
parent=self._parent_context,
|
257
|
-
)
|
266
|
+
),
|
258
267
|
),
|
259
268
|
)
|
260
269
|
)
|
@@ -268,7 +277,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
268
277
|
node.state.meta.node_outputs[descriptor] = output_value
|
269
278
|
|
270
279
|
invoked_ports = ports(outputs, node.state)
|
271
|
-
node.state.meta.node_execution_cache.fulfill_node_execution(
|
280
|
+
node.state.meta.node_execution_cache.fulfill_node_execution(
|
281
|
+
node.__class__, span_id
|
282
|
+
)
|
272
283
|
|
273
284
|
self._work_item_event_queue.put(
|
274
285
|
WorkItemEvent(
|
@@ -285,7 +296,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
285
296
|
span_id=span_id,
|
286
297
|
workflow_definition=self.workflow.__class__,
|
287
298
|
parent=self._parent_context,
|
288
|
-
)
|
299
|
+
),
|
289
300
|
),
|
290
301
|
)
|
291
302
|
)
|
@@ -304,12 +315,14 @@ class WorkflowRunner(Generic[StateType]):
|
|
304
315
|
span_id=span_id,
|
305
316
|
workflow_definition=self.workflow.__class__,
|
306
317
|
parent=self._parent_context,
|
307
|
-
)
|
318
|
+
),
|
308
319
|
),
|
309
320
|
)
|
310
321
|
)
|
311
322
|
except Exception as e:
|
312
|
-
logger.exception(
|
323
|
+
logger.exception(
|
324
|
+
f"An unexpected error occurred while running node {node.__class__.__name__}"
|
325
|
+
)
|
313
326
|
|
314
327
|
self._work_item_event_queue.put(
|
315
328
|
WorkItemEvent(
|
@@ -327,15 +340,17 @@ class WorkflowRunner(Generic[StateType]):
|
|
327
340
|
parent=WorkflowParentContext(
|
328
341
|
span_id=span_id,
|
329
342
|
workflow_definition=self.workflow.__class__,
|
330
|
-
parent=self._parent_context
|
331
|
-
)
|
343
|
+
parent=self._parent_context,
|
344
|
+
),
|
332
345
|
),
|
333
346
|
)
|
334
347
|
)
|
335
348
|
|
336
349
|
logger.debug(f"Finished running node: {node.__class__.__name__}")
|
337
350
|
|
338
|
-
def _handle_invoked_ports(
|
351
|
+
def _handle_invoked_ports(
|
352
|
+
self, state: StateType, ports: Optional[Iterable[Port]]
|
353
|
+
) -> None:
|
339
354
|
if not ports:
|
340
355
|
return
|
341
356
|
|
@@ -350,7 +365,10 @@ class WorkflowRunner(Generic[StateType]):
|
|
350
365
|
self._run_node_if_ready(next_state, edge.to_node, edge)
|
351
366
|
|
352
367
|
def _run_node_if_ready(
|
353
|
-
self,
|
368
|
+
self,
|
369
|
+
state: StateType,
|
370
|
+
node_class: Type[BaseNode],
|
371
|
+
invoked_by: Optional[Edge] = None,
|
354
372
|
) -> None:
|
355
373
|
with state.__lock__:
|
356
374
|
for descriptor in node_class.ExternalInputs:
|
@@ -362,18 +380,27 @@ class WorkflowRunner(Generic[StateType]):
|
|
362
380
|
return
|
363
381
|
|
364
382
|
all_deps = self._dependencies[node_class]
|
365
|
-
|
383
|
+
node_span_id = state.meta.node_execution_cache.queue_node_execution(
|
384
|
+
node_class, all_deps, invoked_by
|
385
|
+
)
|
386
|
+
if not node_class.Trigger.should_initiate(state, all_deps, node_span_id):
|
366
387
|
return
|
367
388
|
|
368
389
|
node = node_class(state=state, context=self.workflow.context)
|
369
|
-
|
370
|
-
|
390
|
+
state.meta.node_execution_cache.initiate_node_execution(
|
391
|
+
node_class, node_span_id
|
392
|
+
)
|
371
393
|
self._active_nodes_by_execution_id[node_span_id] = node
|
372
394
|
|
373
|
-
worker_thread = Thread(
|
395
|
+
worker_thread = Thread(
|
396
|
+
target=self._run_work_item,
|
397
|
+
kwargs={"node": node, "span_id": node_span_id},
|
398
|
+
)
|
374
399
|
worker_thread.start()
|
375
400
|
|
376
|
-
def _handle_work_item_event(
|
401
|
+
def _handle_work_item_event(
|
402
|
+
self, work_item_event: WorkItemEvent[StateType]
|
403
|
+
) -> Optional[VellumError]:
|
377
404
|
node = work_item_event.node
|
378
405
|
event = work_item_event.event
|
379
406
|
|
@@ -389,7 +416,10 @@ class WorkflowRunner(Generic[StateType]):
|
|
389
416
|
node_output_descriptor = workflow_output_descriptor.instance
|
390
417
|
if not isinstance(node_output_descriptor, OutputReference):
|
391
418
|
continue
|
392
|
-
if
|
419
|
+
if (
|
420
|
+
node_output_descriptor.outputs_class
|
421
|
+
!= event.node_definition.Outputs
|
422
|
+
):
|
393
423
|
continue
|
394
424
|
if node_output_descriptor.name != event.output.name:
|
395
425
|
continue
|
@@ -427,7 +457,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
427
457
|
parent=self._parent_context,
|
428
458
|
)
|
429
459
|
|
430
|
-
def _stream_workflow_event(
|
460
|
+
def _stream_workflow_event(
|
461
|
+
self, output: BaseOutput
|
462
|
+
) -> WorkflowExecutionStreamingEvent:
|
431
463
|
return WorkflowExecutionStreamingEvent(
|
432
464
|
trace_id=self._initial_state.meta.trace_id,
|
433
465
|
span_id=self._initial_state.meta.span_id,
|
@@ -435,10 +467,12 @@ class WorkflowRunner(Generic[StateType]):
|
|
435
467
|
workflow_definition=self.workflow.__class__,
|
436
468
|
output=output,
|
437
469
|
),
|
438
|
-
parent=self._parent_context
|
470
|
+
parent=self._parent_context,
|
439
471
|
)
|
440
472
|
|
441
|
-
def _fulfill_workflow_event(
|
473
|
+
def _fulfill_workflow_event(
|
474
|
+
self, outputs: OutputsType
|
475
|
+
) -> WorkflowExecutionFulfilledEvent:
|
442
476
|
return WorkflowExecutionFulfilledEvent(
|
443
477
|
trace_id=self._initial_state.meta.trace_id,
|
444
478
|
span_id=self._initial_state.meta.span_id,
|
@@ -449,7 +483,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
449
483
|
parent=self._parent_context,
|
450
484
|
)
|
451
485
|
|
452
|
-
def _reject_workflow_event(
|
486
|
+
def _reject_workflow_event(
|
487
|
+
self, error: VellumError
|
488
|
+
) -> WorkflowExecutionRejectedEvent:
|
453
489
|
return WorkflowExecutionRejectedEvent(
|
454
490
|
trace_id=self._initial_state.meta.trace_id,
|
455
491
|
span_id=self._initial_state.meta.span_id,
|
@@ -469,7 +505,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
469
505
|
),
|
470
506
|
)
|
471
507
|
|
472
|
-
def _pause_workflow_event(
|
508
|
+
def _pause_workflow_event(
|
509
|
+
self, external_inputs: Iterable[ExternalInputReference]
|
510
|
+
) -> WorkflowExecutionPausedEvent:
|
473
511
|
return WorkflowExecutionPausedEvent(
|
474
512
|
trace_id=self._initial_state.meta.trace_id,
|
475
513
|
span_id=self._initial_state.meta.span_id,
|
@@ -486,7 +524,10 @@ class WorkflowRunner(Generic[StateType]):
|
|
486
524
|
if not self._entrypoints:
|
487
525
|
self._workflow_event_queue.put(
|
488
526
|
self._reject_workflow_event(
|
489
|
-
VellumError(
|
527
|
+
VellumError(
|
528
|
+
message="No entrypoints defined",
|
529
|
+
code=VellumErrorCode.INVALID_WORKFLOW,
|
530
|
+
)
|
490
531
|
)
|
491
532
|
)
|
492
533
|
return
|
@@ -505,7 +546,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
505
546
|
logger.exception(err_message)
|
506
547
|
self._workflow_event_queue.put(
|
507
548
|
self._reject_workflow_event(
|
508
|
-
VellumError(
|
549
|
+
VellumError(
|
550
|
+
code=VellumErrorCode.INTERNAL_ERROR, message=err_message
|
551
|
+
),
|
509
552
|
)
|
510
553
|
)
|
511
554
|
return
|
@@ -561,7 +604,11 @@ class WorkflowRunner(Generic[StateType]):
|
|
561
604
|
if isinstance(value, BaseDescriptor):
|
562
605
|
setattr(fulfilled_outputs, descriptor.name, value.resolve(final_state))
|
563
606
|
elif isinstance(descriptor.instance, BaseDescriptor):
|
564
|
-
setattr(
|
607
|
+
setattr(
|
608
|
+
fulfilled_outputs,
|
609
|
+
descriptor.name,
|
610
|
+
descriptor.instance.resolve(final_state),
|
611
|
+
)
|
565
612
|
|
566
613
|
self._workflow_event_queue.put(self._fulfill_workflow_event(fulfilled_outputs))
|
567
614
|
|
@@ -586,24 +633,41 @@ class WorkflowRunner(Generic[StateType]):
|
|
586
633
|
self._cancel_signal.wait()
|
587
634
|
self._workflow_event_queue.put(
|
588
635
|
self._reject_workflow_event(
|
589
|
-
VellumError(
|
636
|
+
VellumError(
|
637
|
+
code=VellumErrorCode.WORKFLOW_CANCELLED,
|
638
|
+
message="Workflow run cancelled",
|
639
|
+
)
|
590
640
|
)
|
591
641
|
)
|
592
642
|
|
643
|
+
def _is_terminal_event(self, event: WorkflowEvent) -> bool:
|
644
|
+
if (
|
645
|
+
event.name == "workflow.execution.fulfilled"
|
646
|
+
or event.name == "workflow.execution.rejected"
|
647
|
+
or event.name == "workflow.execution.paused"
|
648
|
+
):
|
649
|
+
return event.workflow_definition == self.workflow.__class__
|
650
|
+
return False
|
651
|
+
|
593
652
|
def stream(self) -> WorkflowEventStream:
|
594
653
|
background_thread = Thread(
|
595
|
-
target=self._run_background_thread,
|
654
|
+
target=self._run_background_thread,
|
655
|
+
name=f"{self.workflow.__class__.__name__}.background_thread",
|
596
656
|
)
|
597
657
|
background_thread.start()
|
598
658
|
|
599
659
|
if self._cancel_signal:
|
600
660
|
cancel_thread = Thread(
|
601
|
-
target=self._run_cancel_thread,
|
661
|
+
target=self._run_cancel_thread,
|
662
|
+
name=f"{self.workflow.__class__.__name__}.cancel_thread",
|
602
663
|
)
|
603
664
|
cancel_thread.start()
|
604
665
|
|
605
666
|
event: WorkflowEvent
|
606
|
-
if
|
667
|
+
if (
|
668
|
+
self._initial_state.meta.is_terminated
|
669
|
+
or self._initial_state.meta.is_terminated is None
|
670
|
+
):
|
607
671
|
event = self._initiate_workflow_event()
|
608
672
|
else:
|
609
673
|
event = self._resume_workflow_event()
|
@@ -612,7 +676,10 @@ class WorkflowRunner(Generic[StateType]):
|
|
612
676
|
self._initial_state.meta.is_terminated = False
|
613
677
|
|
614
678
|
# The extra level of indirection prevents the runner from waiting on the caller to consume the event stream
|
615
|
-
stream_thread = Thread(
|
679
|
+
stream_thread = Thread(
|
680
|
+
target=self._stream,
|
681
|
+
name=f"{self.workflow.__class__.__name__}.stream_thread",
|
682
|
+
)
|
616
683
|
stream_thread.start()
|
617
684
|
|
618
685
|
while stream_thread.is_alive():
|
@@ -623,7 +690,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
623
690
|
|
624
691
|
yield self._emit_event(event)
|
625
692
|
|
626
|
-
if
|
693
|
+
if self._is_terminal_event(event):
|
627
694
|
break
|
628
695
|
|
629
696
|
try:
|
@@ -632,7 +699,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
632
699
|
except Empty:
|
633
700
|
pass
|
634
701
|
|
635
|
-
if not
|
702
|
+
if not self._is_terminal_event(event):
|
636
703
|
yield self._reject_workflow_event(
|
637
704
|
VellumError(
|
638
705
|
code=VellumErrorCode.INTERNAL_ERROR,
|
vellum/workflows/state/base.py
CHANGED
@@ -5,15 +5,15 @@ from datetime import datetime
|
|
5
5
|
from queue import Queue
|
6
6
|
from threading import Lock
|
7
7
|
from uuid import UUID, uuid4
|
8
|
-
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional, Sequence, Set, Tuple, Type, cast
|
8
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, cast
|
9
9
|
from typing_extensions import dataclass_transform
|
10
10
|
|
11
11
|
from pydantic import GetCoreSchemaHandler, field_serializer
|
12
12
|
from pydantic_core import core_schema
|
13
13
|
|
14
14
|
from vellum.core.pydantic_utilities import UniversalBaseModel
|
15
|
-
|
16
15
|
from vellum.workflows.constants import UNDEF
|
16
|
+
from vellum.workflows.edges.edge import Edge
|
17
17
|
from vellum.workflows.inputs.base import BaseInputs
|
18
18
|
from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
|
19
19
|
from vellum.workflows.types.generics import StateType
|
@@ -71,58 +71,92 @@ def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> An
|
|
71
71
|
|
72
72
|
|
73
73
|
class NodeExecutionCache:
|
74
|
-
|
74
|
+
_node_executions_fulfilled: Dict[Type["BaseNode"], Stack[UUID]]
|
75
75
|
_node_executions_initiated: Dict[Type["BaseNode"], Set[UUID]]
|
76
|
-
|
76
|
+
_node_executions_queued: Dict[Type["BaseNode"], List[UUID]]
|
77
|
+
_dependencies_invoked: Dict[UUID, Set[Type["BaseNode"]]]
|
77
78
|
|
78
79
|
def __init__(
|
79
80
|
self,
|
80
81
|
dependencies_invoked: Optional[Dict[str, Sequence[str]]] = None,
|
81
|
-
|
82
|
+
node_executions_fulfilled: Optional[Dict[str, Sequence[str]]] = None,
|
82
83
|
node_executions_initiated: Optional[Dict[str, Sequence[str]]] = None,
|
84
|
+
node_executions_queued: Optional[Dict[str, Sequence[str]]] = None,
|
83
85
|
) -> None:
|
84
86
|
self._dependencies_invoked = defaultdict(set)
|
85
|
-
self.
|
87
|
+
self._node_executions_fulfilled = defaultdict(Stack[UUID])
|
86
88
|
self._node_executions_initiated = defaultdict(set)
|
89
|
+
self._node_executions_queued = defaultdict(list)
|
87
90
|
|
88
|
-
for
|
89
|
-
self._dependencies_invoked[
|
91
|
+
for execution_id, dependencies in (dependencies_invoked or {}).items():
|
92
|
+
self._dependencies_invoked[UUID(execution_id)] = {get_class_by_qualname(dep) for dep in dependencies}
|
90
93
|
|
91
|
-
for node, execution_ids in (
|
94
|
+
for node, execution_ids in (node_executions_fulfilled or {}).items():
|
92
95
|
node_class = get_class_by_qualname(node)
|
93
|
-
self.
|
96
|
+
self._node_executions_fulfilled[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
|
94
97
|
|
95
98
|
for node, execution_ids in (node_executions_initiated or {}).items():
|
96
99
|
node_class = get_class_by_qualname(node)
|
97
100
|
self._node_executions_initiated[node_class].update({UUID(execution_id) for execution_id in execution_ids})
|
98
101
|
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
+
for node, execution_ids in (node_executions_queued or {}).items():
|
103
|
+
node_class = get_class_by_qualname(node)
|
104
|
+
self._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
|
102
105
|
|
103
|
-
def
|
104
|
-
|
106
|
+
def _invoke_dependency(
|
107
|
+
self,
|
108
|
+
execution_id: UUID,
|
109
|
+
node: Type["BaseNode"],
|
110
|
+
dependency: Type["BaseNode"],
|
111
|
+
dependencies: Set["Type[BaseNode]"],
|
112
|
+
) -> None:
|
113
|
+
self._dependencies_invoked[execution_id].add(dependency)
|
114
|
+
if all(dep in self._dependencies_invoked[execution_id] for dep in dependencies):
|
115
|
+
self._node_executions_queued[node].remove(execution_id)
|
116
|
+
|
117
|
+
def queue_node_execution(
|
118
|
+
self, node: Type["BaseNode"], dependencies: Set["Type[BaseNode]"], invoked_by: Optional[Edge] = None
|
119
|
+
) -> UUID:
|
120
|
+
execution_id = uuid4()
|
121
|
+
if not invoked_by:
|
122
|
+
return execution_id
|
123
|
+
|
124
|
+
source_node = invoked_by.from_port.node_class
|
125
|
+
for queued_node_execution_id in self._node_executions_queued[node]:
|
126
|
+
if source_node not in self._dependencies_invoked[queued_node_execution_id]:
|
127
|
+
self._invoke_dependency(queued_node_execution_id, node, source_node, dependencies)
|
128
|
+
return queued_node_execution_id
|
129
|
+
|
130
|
+
self._node_executions_queued[node].append(execution_id)
|
131
|
+
self._invoke_dependency(execution_id, node, source_node, dependencies)
|
132
|
+
return execution_id
|
133
|
+
|
134
|
+
def is_node_execution_initiated(self, node: Type["BaseNode"], execution_id: UUID) -> bool:
|
135
|
+
return execution_id in self._node_executions_initiated[node]
|
105
136
|
|
106
137
|
def initiate_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
|
107
138
|
self._node_executions_initiated[node].add(execution_id)
|
108
139
|
|
109
140
|
def fulfill_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
|
110
|
-
self.
|
111
|
-
self._node_execution_ids[node].push(execution_id)
|
141
|
+
self._node_executions_fulfilled[node].push(execution_id)
|
112
142
|
|
113
143
|
def get_execution_count(self, node: Type["BaseNode"]) -> int:
|
114
|
-
return self.
|
144
|
+
return self._node_executions_fulfilled[node].size()
|
115
145
|
|
116
146
|
def dump(self) -> Dict[str, Any]:
|
117
147
|
return {
|
118
148
|
"dependencies_invoked": {
|
119
|
-
|
149
|
+
str(execution_id): [str(dep) for dep in dependencies]
|
150
|
+
for execution_id, dependencies in self._dependencies_invoked.items()
|
120
151
|
},
|
121
152
|
"node_executions_initiated": {
|
122
153
|
str(node): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
|
123
154
|
},
|
124
|
-
"
|
125
|
-
str(node): execution_ids.dump() for node, execution_ids in self.
|
155
|
+
"node_executions_fulfilled": {
|
156
|
+
str(node): execution_ids.dump() for node, execution_ids in self._node_executions_fulfilled.items()
|
157
|
+
},
|
158
|
+
"node_executions_queued": {
|
159
|
+
str(node): execution_ids for node, execution_ids in self._node_executions_queued.items()
|
126
160
|
},
|
127
161
|
}
|
128
162
|
|
@@ -1,14 +1,24 @@
|
|
1
1
|
from functools import cached_property
|
2
|
-
from
|
2
|
+
from queue import Queue
|
3
|
+
from typing import TYPE_CHECKING, Optional
|
3
4
|
|
4
5
|
from vellum import Vellum
|
5
|
-
|
6
|
+
from vellum.workflows.events.types import ParentContext
|
6
7
|
from vellum.workflows.vellum_client import create_vellum_client
|
7
8
|
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from vellum.workflows.events.workflow import WorkflowEvent
|
11
|
+
|
8
12
|
|
9
13
|
class WorkflowContext:
|
10
|
-
def __init__(
|
14
|
+
def __init__(
|
15
|
+
self,
|
16
|
+
_vellum_client: Optional[Vellum] = None,
|
17
|
+
_parent_context: Optional[ParentContext] = None,
|
18
|
+
):
|
11
19
|
self._vellum_client = _vellum_client
|
20
|
+
self._parent_context = _parent_context
|
21
|
+
self._event_queue: Optional[Queue["WorkflowEvent"]] = None
|
12
22
|
|
13
23
|
@cached_property
|
14
24
|
def vellum_client(self) -> Vellum:
|
@@ -16,3 +26,16 @@ class WorkflowContext:
|
|
16
26
|
return self._vellum_client
|
17
27
|
|
18
28
|
return create_vellum_client()
|
29
|
+
|
30
|
+
@cached_property
|
31
|
+
def parent_context(self) -> Optional[ParentContext]:
|
32
|
+
if self._parent_context:
|
33
|
+
return self._parent_context
|
34
|
+
return None
|
35
|
+
|
36
|
+
def _emit_subworkflow_event(self, event: "WorkflowEvent") -> None:
|
37
|
+
if self._event_queue:
|
38
|
+
self._event_queue.put(event)
|
39
|
+
|
40
|
+
def _register_event_queue(self, event_queue: Queue["WorkflowEvent"]) -> None:
|
41
|
+
self._event_queue = event_queue
|