vellum-ai 0.10.8__py3-none-any.whl → 0.10.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/types/logical_operator.py +2 -0
- vellum/workflows/descriptors/utils.py +27 -0
- vellum/workflows/events/__init__.py +0 -2
- vellum/workflows/events/tests/test_event.py +2 -1
- vellum/workflows/events/types.py +35 -29
- vellum/workflows/events/workflow.py +14 -7
- vellum/workflows/nodes/bases/base.py +100 -38
- vellum/workflows/nodes/core/try_node/node.py +22 -4
- vellum/workflows/nodes/core/try_node/tests/test_node.py +15 -0
- vellum/workflows/runner/runner.py +109 -42
- vellum/workflows/state/base.py +55 -21
- vellum/workflows/state/context.py +26 -3
- vellum/workflows/types/core.py +1 -1
- vellum/workflows/workflows/base.py +51 -17
- vellum/workflows/workflows/event_filters.py +61 -0
- {vellum_ai-0.10.8.dist-info → vellum_ai-0.10.9.dist-info}/METADATA +1 -1
- {vellum_ai-0.10.8.dist-info → vellum_ai-0.10.9.dist-info}/RECORD +24 -22
- vellum_ee/workflows/display/nodes/vellum/__init__.py +6 -4
- vellum_ee/workflows/display/nodes/vellum/error_node.py +49 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +203 -0
- vellum/workflows/events/utils.py +0 -5
- {vellum_ai-0.10.8.dist-info → vellum_ai-0.10.9.dist-info}/LICENSE +0 -0
- {vellum_ai-0.10.8.dist-info → vellum_ai-0.10.9.dist-info}/WHEEL +0 -0
- {vellum_ai-0.10.8.dist-info → vellum_ai-0.10.9.dist-info}/entry_points.txt +0 -0
@@ -3,7 +3,7 @@ from copy import deepcopy
|
|
3
3
|
import logging
|
4
4
|
from queue import Empty, Queue
|
5
5
|
from threading import Event as ThreadingEvent, Thread
|
6
|
-
from uuid import UUID
|
6
|
+
from uuid import UUID
|
7
7
|
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Type, Union
|
8
8
|
|
9
9
|
from vellum.workflows.constants import UNDEF
|
@@ -29,7 +29,6 @@ from vellum.workflows.events.node import (
|
|
29
29
|
NodeExecutionStreamingBody,
|
30
30
|
)
|
31
31
|
from vellum.workflows.events.types import BaseEvent, ParentContext, WorkflowParentContext
|
32
|
-
from vellum.workflows.events.utils import is_terminal_event
|
33
32
|
from vellum.workflows.events.workflow import (
|
34
33
|
WorkflowExecutionFulfilledBody,
|
35
34
|
WorkflowExecutionInitiatedBody,
|
@@ -74,7 +73,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
74
73
|
parent_context: Optional[ParentContext] = None,
|
75
74
|
):
|
76
75
|
if state and external_inputs:
|
77
|
-
raise ValueError(
|
76
|
+
raise ValueError(
|
77
|
+
"Can only run a Workflow providing one of state or external inputs, not both"
|
78
|
+
)
|
78
79
|
|
79
80
|
self.workflow = workflow
|
80
81
|
if entrypoint_nodes:
|
@@ -100,7 +101,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
100
101
|
if issubclass(ei.inputs_class.__parent_class__, BaseNode)
|
101
102
|
]
|
102
103
|
else:
|
103
|
-
normalized_inputs =
|
104
|
+
normalized_inputs = (
|
105
|
+
deepcopy(inputs) if inputs else self.workflow.get_default_inputs()
|
106
|
+
)
|
104
107
|
if state:
|
105
108
|
self._initial_state = deepcopy(state)
|
106
109
|
self._initial_state.meta.workflow_inputs = normalized_inputs
|
@@ -123,6 +126,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
123
126
|
"__snapshot_callback__",
|
124
127
|
lambda s: self._snapshot_state(s),
|
125
128
|
)
|
129
|
+
self.workflow.context._register_event_queue(self._workflow_event_queue)
|
126
130
|
|
127
131
|
def _snapshot_state(self, state: StateType) -> StateType:
|
128
132
|
self.workflow._store.append_state_snapshot(state)
|
@@ -148,8 +152,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
148
152
|
parent=WorkflowParentContext(
|
149
153
|
span_id=span_id,
|
150
154
|
workflow_definition=self.workflow.__class__,
|
151
|
-
parent=self._parent_context
|
152
|
-
|
155
|
+
parent=self._parent_context,
|
156
|
+
type="WORKFLOW",
|
157
|
+
),
|
153
158
|
),
|
154
159
|
)
|
155
160
|
)
|
@@ -185,7 +190,11 @@ class WorkflowRunner(Generic[StateType]):
|
|
185
190
|
instance=None,
|
186
191
|
outputs_class=node.Outputs,
|
187
192
|
)
|
188
|
-
node.state.meta.node_outputs[output_descriptor] =
|
193
|
+
node.state.meta.node_outputs[output_descriptor] = (
|
194
|
+
streaming_output_queues[output.name]
|
195
|
+
)
|
196
|
+
initiated_output: BaseOutput = BaseOutput(name=output.name)
|
197
|
+
initiated_ports = initiated_output > ports
|
189
198
|
self._work_item_event_queue.put(
|
190
199
|
WorkItemEvent(
|
191
200
|
node=node,
|
@@ -194,14 +203,14 @@ class WorkflowRunner(Generic[StateType]):
|
|
194
203
|
span_id=span_id,
|
195
204
|
body=NodeExecutionStreamingBody(
|
196
205
|
node_definition=node.__class__,
|
197
|
-
output=
|
198
|
-
invoked_ports=
|
206
|
+
output=initiated_output,
|
207
|
+
invoked_ports=initiated_ports,
|
199
208
|
),
|
200
209
|
parent=WorkflowParentContext(
|
201
210
|
span_id=span_id,
|
202
211
|
workflow_definition=self.workflow.__class__,
|
203
|
-
parent=self._parent_context
|
204
|
-
)
|
212
|
+
parent=self._parent_context,
|
213
|
+
),
|
205
214
|
),
|
206
215
|
)
|
207
216
|
)
|
@@ -230,7 +239,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
230
239
|
span_id=span_id,
|
231
240
|
workflow_definition=self.workflow.__class__,
|
232
241
|
parent=self._parent_context,
|
233
|
-
)
|
242
|
+
),
|
234
243
|
),
|
235
244
|
)
|
236
245
|
)
|
@@ -254,7 +263,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
254
263
|
span_id=span_id,
|
255
264
|
workflow_definition=self.workflow.__class__,
|
256
265
|
parent=self._parent_context,
|
257
|
-
)
|
266
|
+
),
|
258
267
|
),
|
259
268
|
)
|
260
269
|
)
|
@@ -268,7 +277,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
268
277
|
node.state.meta.node_outputs[descriptor] = output_value
|
269
278
|
|
270
279
|
invoked_ports = ports(outputs, node.state)
|
271
|
-
node.state.meta.node_execution_cache.fulfill_node_execution(
|
280
|
+
node.state.meta.node_execution_cache.fulfill_node_execution(
|
281
|
+
node.__class__, span_id
|
282
|
+
)
|
272
283
|
|
273
284
|
self._work_item_event_queue.put(
|
274
285
|
WorkItemEvent(
|
@@ -285,7 +296,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
285
296
|
span_id=span_id,
|
286
297
|
workflow_definition=self.workflow.__class__,
|
287
298
|
parent=self._parent_context,
|
288
|
-
)
|
299
|
+
),
|
289
300
|
),
|
290
301
|
)
|
291
302
|
)
|
@@ -304,12 +315,14 @@ class WorkflowRunner(Generic[StateType]):
|
|
304
315
|
span_id=span_id,
|
305
316
|
workflow_definition=self.workflow.__class__,
|
306
317
|
parent=self._parent_context,
|
307
|
-
)
|
318
|
+
),
|
308
319
|
),
|
309
320
|
)
|
310
321
|
)
|
311
322
|
except Exception as e:
|
312
|
-
logger.exception(
|
323
|
+
logger.exception(
|
324
|
+
f"An unexpected error occurred while running node {node.__class__.__name__}"
|
325
|
+
)
|
313
326
|
|
314
327
|
self._work_item_event_queue.put(
|
315
328
|
WorkItemEvent(
|
@@ -327,15 +340,17 @@ class WorkflowRunner(Generic[StateType]):
|
|
327
340
|
parent=WorkflowParentContext(
|
328
341
|
span_id=span_id,
|
329
342
|
workflow_definition=self.workflow.__class__,
|
330
|
-
parent=self._parent_context
|
331
|
-
)
|
343
|
+
parent=self._parent_context,
|
344
|
+
),
|
332
345
|
),
|
333
346
|
)
|
334
347
|
)
|
335
348
|
|
336
349
|
logger.debug(f"Finished running node: {node.__class__.__name__}")
|
337
350
|
|
338
|
-
def _handle_invoked_ports(
|
351
|
+
def _handle_invoked_ports(
|
352
|
+
self, state: StateType, ports: Optional[Iterable[Port]]
|
353
|
+
) -> None:
|
339
354
|
if not ports:
|
340
355
|
return
|
341
356
|
|
@@ -350,7 +365,10 @@ class WorkflowRunner(Generic[StateType]):
|
|
350
365
|
self._run_node_if_ready(next_state, edge.to_node, edge)
|
351
366
|
|
352
367
|
def _run_node_if_ready(
|
353
|
-
self,
|
368
|
+
self,
|
369
|
+
state: StateType,
|
370
|
+
node_class: Type[BaseNode],
|
371
|
+
invoked_by: Optional[Edge] = None,
|
354
372
|
) -> None:
|
355
373
|
with state.__lock__:
|
356
374
|
for descriptor in node_class.ExternalInputs:
|
@@ -362,18 +380,27 @@ class WorkflowRunner(Generic[StateType]):
|
|
362
380
|
return
|
363
381
|
|
364
382
|
all_deps = self._dependencies[node_class]
|
365
|
-
|
383
|
+
node_span_id = state.meta.node_execution_cache.queue_node_execution(
|
384
|
+
node_class, all_deps, invoked_by
|
385
|
+
)
|
386
|
+
if not node_class.Trigger.should_initiate(state, all_deps, node_span_id):
|
366
387
|
return
|
367
388
|
|
368
389
|
node = node_class(state=state, context=self.workflow.context)
|
369
|
-
|
370
|
-
|
390
|
+
state.meta.node_execution_cache.initiate_node_execution(
|
391
|
+
node_class, node_span_id
|
392
|
+
)
|
371
393
|
self._active_nodes_by_execution_id[node_span_id] = node
|
372
394
|
|
373
|
-
worker_thread = Thread(
|
395
|
+
worker_thread = Thread(
|
396
|
+
target=self._run_work_item,
|
397
|
+
kwargs={"node": node, "span_id": node_span_id},
|
398
|
+
)
|
374
399
|
worker_thread.start()
|
375
400
|
|
376
|
-
def _handle_work_item_event(
|
401
|
+
def _handle_work_item_event(
|
402
|
+
self, work_item_event: WorkItemEvent[StateType]
|
403
|
+
) -> Optional[VellumError]:
|
377
404
|
node = work_item_event.node
|
378
405
|
event = work_item_event.event
|
379
406
|
|
@@ -389,7 +416,10 @@ class WorkflowRunner(Generic[StateType]):
|
|
389
416
|
node_output_descriptor = workflow_output_descriptor.instance
|
390
417
|
if not isinstance(node_output_descriptor, OutputReference):
|
391
418
|
continue
|
392
|
-
if
|
419
|
+
if (
|
420
|
+
node_output_descriptor.outputs_class
|
421
|
+
!= event.node_definition.Outputs
|
422
|
+
):
|
393
423
|
continue
|
394
424
|
if node_output_descriptor.name != event.output.name:
|
395
425
|
continue
|
@@ -427,7 +457,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
427
457
|
parent=self._parent_context,
|
428
458
|
)
|
429
459
|
|
430
|
-
def _stream_workflow_event(
|
460
|
+
def _stream_workflow_event(
|
461
|
+
self, output: BaseOutput
|
462
|
+
) -> WorkflowExecutionStreamingEvent:
|
431
463
|
return WorkflowExecutionStreamingEvent(
|
432
464
|
trace_id=self._initial_state.meta.trace_id,
|
433
465
|
span_id=self._initial_state.meta.span_id,
|
@@ -435,10 +467,12 @@ class WorkflowRunner(Generic[StateType]):
|
|
435
467
|
workflow_definition=self.workflow.__class__,
|
436
468
|
output=output,
|
437
469
|
),
|
438
|
-
parent=self._parent_context
|
470
|
+
parent=self._parent_context,
|
439
471
|
)
|
440
472
|
|
441
|
-
def _fulfill_workflow_event(
|
473
|
+
def _fulfill_workflow_event(
|
474
|
+
self, outputs: OutputsType
|
475
|
+
) -> WorkflowExecutionFulfilledEvent:
|
442
476
|
return WorkflowExecutionFulfilledEvent(
|
443
477
|
trace_id=self._initial_state.meta.trace_id,
|
444
478
|
span_id=self._initial_state.meta.span_id,
|
@@ -449,7 +483,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
449
483
|
parent=self._parent_context,
|
450
484
|
)
|
451
485
|
|
452
|
-
def _reject_workflow_event(
|
486
|
+
def _reject_workflow_event(
|
487
|
+
self, error: VellumError
|
488
|
+
) -> WorkflowExecutionRejectedEvent:
|
453
489
|
return WorkflowExecutionRejectedEvent(
|
454
490
|
trace_id=self._initial_state.meta.trace_id,
|
455
491
|
span_id=self._initial_state.meta.span_id,
|
@@ -469,7 +505,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
469
505
|
),
|
470
506
|
)
|
471
507
|
|
472
|
-
def _pause_workflow_event(
|
508
|
+
def _pause_workflow_event(
|
509
|
+
self, external_inputs: Iterable[ExternalInputReference]
|
510
|
+
) -> WorkflowExecutionPausedEvent:
|
473
511
|
return WorkflowExecutionPausedEvent(
|
474
512
|
trace_id=self._initial_state.meta.trace_id,
|
475
513
|
span_id=self._initial_state.meta.span_id,
|
@@ -486,7 +524,10 @@ class WorkflowRunner(Generic[StateType]):
|
|
486
524
|
if not self._entrypoints:
|
487
525
|
self._workflow_event_queue.put(
|
488
526
|
self._reject_workflow_event(
|
489
|
-
VellumError(
|
527
|
+
VellumError(
|
528
|
+
message="No entrypoints defined",
|
529
|
+
code=VellumErrorCode.INVALID_WORKFLOW,
|
530
|
+
)
|
490
531
|
)
|
491
532
|
)
|
492
533
|
return
|
@@ -505,7 +546,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
505
546
|
logger.exception(err_message)
|
506
547
|
self._workflow_event_queue.put(
|
507
548
|
self._reject_workflow_event(
|
508
|
-
VellumError(
|
549
|
+
VellumError(
|
550
|
+
code=VellumErrorCode.INTERNAL_ERROR, message=err_message
|
551
|
+
),
|
509
552
|
)
|
510
553
|
)
|
511
554
|
return
|
@@ -561,7 +604,11 @@ class WorkflowRunner(Generic[StateType]):
|
|
561
604
|
if isinstance(value, BaseDescriptor):
|
562
605
|
setattr(fulfilled_outputs, descriptor.name, value.resolve(final_state))
|
563
606
|
elif isinstance(descriptor.instance, BaseDescriptor):
|
564
|
-
setattr(
|
607
|
+
setattr(
|
608
|
+
fulfilled_outputs,
|
609
|
+
descriptor.name,
|
610
|
+
descriptor.instance.resolve(final_state),
|
611
|
+
)
|
565
612
|
|
566
613
|
self._workflow_event_queue.put(self._fulfill_workflow_event(fulfilled_outputs))
|
567
614
|
|
@@ -586,24 +633,41 @@ class WorkflowRunner(Generic[StateType]):
|
|
586
633
|
self._cancel_signal.wait()
|
587
634
|
self._workflow_event_queue.put(
|
588
635
|
self._reject_workflow_event(
|
589
|
-
VellumError(
|
636
|
+
VellumError(
|
637
|
+
code=VellumErrorCode.WORKFLOW_CANCELLED,
|
638
|
+
message="Workflow run cancelled",
|
639
|
+
)
|
590
640
|
)
|
591
641
|
)
|
592
642
|
|
643
|
+
def _is_terminal_event(self, event: WorkflowEvent) -> bool:
|
644
|
+
if (
|
645
|
+
event.name == "workflow.execution.fulfilled"
|
646
|
+
or event.name == "workflow.execution.rejected"
|
647
|
+
or event.name == "workflow.execution.paused"
|
648
|
+
):
|
649
|
+
return event.workflow_definition == self.workflow.__class__
|
650
|
+
return False
|
651
|
+
|
593
652
|
def stream(self) -> WorkflowEventStream:
|
594
653
|
background_thread = Thread(
|
595
|
-
target=self._run_background_thread,
|
654
|
+
target=self._run_background_thread,
|
655
|
+
name=f"{self.workflow.__class__.__name__}.background_thread",
|
596
656
|
)
|
597
657
|
background_thread.start()
|
598
658
|
|
599
659
|
if self._cancel_signal:
|
600
660
|
cancel_thread = Thread(
|
601
|
-
target=self._run_cancel_thread,
|
661
|
+
target=self._run_cancel_thread,
|
662
|
+
name=f"{self.workflow.__class__.__name__}.cancel_thread",
|
602
663
|
)
|
603
664
|
cancel_thread.start()
|
604
665
|
|
605
666
|
event: WorkflowEvent
|
606
|
-
if
|
667
|
+
if (
|
668
|
+
self._initial_state.meta.is_terminated
|
669
|
+
or self._initial_state.meta.is_terminated is None
|
670
|
+
):
|
607
671
|
event = self._initiate_workflow_event()
|
608
672
|
else:
|
609
673
|
event = self._resume_workflow_event()
|
@@ -612,7 +676,10 @@ class WorkflowRunner(Generic[StateType]):
|
|
612
676
|
self._initial_state.meta.is_terminated = False
|
613
677
|
|
614
678
|
# The extra level of indirection prevents the runner from waiting on the caller to consume the event stream
|
615
|
-
stream_thread = Thread(
|
679
|
+
stream_thread = Thread(
|
680
|
+
target=self._stream,
|
681
|
+
name=f"{self.workflow.__class__.__name__}.stream_thread",
|
682
|
+
)
|
616
683
|
stream_thread.start()
|
617
684
|
|
618
685
|
while stream_thread.is_alive():
|
@@ -623,7 +690,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
623
690
|
|
624
691
|
yield self._emit_event(event)
|
625
692
|
|
626
|
-
if
|
693
|
+
if self._is_terminal_event(event):
|
627
694
|
break
|
628
695
|
|
629
696
|
try:
|
@@ -632,7 +699,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
632
699
|
except Empty:
|
633
700
|
pass
|
634
701
|
|
635
|
-
if not
|
702
|
+
if not self._is_terminal_event(event):
|
636
703
|
yield self._reject_workflow_event(
|
637
704
|
VellumError(
|
638
705
|
code=VellumErrorCode.INTERNAL_ERROR,
|
vellum/workflows/state/base.py
CHANGED
@@ -5,15 +5,15 @@ from datetime import datetime
|
|
5
5
|
from queue import Queue
|
6
6
|
from threading import Lock
|
7
7
|
from uuid import UUID, uuid4
|
8
|
-
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional, Sequence, Set, Tuple, Type, cast
|
8
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, cast
|
9
9
|
from typing_extensions import dataclass_transform
|
10
10
|
|
11
11
|
from pydantic import GetCoreSchemaHandler, field_serializer
|
12
12
|
from pydantic_core import core_schema
|
13
13
|
|
14
14
|
from vellum.core.pydantic_utilities import UniversalBaseModel
|
15
|
-
|
16
15
|
from vellum.workflows.constants import UNDEF
|
16
|
+
from vellum.workflows.edges.edge import Edge
|
17
17
|
from vellum.workflows.inputs.base import BaseInputs
|
18
18
|
from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
|
19
19
|
from vellum.workflows.types.generics import StateType
|
@@ -71,58 +71,92 @@ def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> An
|
|
71
71
|
|
72
72
|
|
73
73
|
class NodeExecutionCache:
|
74
|
-
|
74
|
+
_node_executions_fulfilled: Dict[Type["BaseNode"], Stack[UUID]]
|
75
75
|
_node_executions_initiated: Dict[Type["BaseNode"], Set[UUID]]
|
76
|
-
|
76
|
+
_node_executions_queued: Dict[Type["BaseNode"], List[UUID]]
|
77
|
+
_dependencies_invoked: Dict[UUID, Set[Type["BaseNode"]]]
|
77
78
|
|
78
79
|
def __init__(
|
79
80
|
self,
|
80
81
|
dependencies_invoked: Optional[Dict[str, Sequence[str]]] = None,
|
81
|
-
|
82
|
+
node_executions_fulfilled: Optional[Dict[str, Sequence[str]]] = None,
|
82
83
|
node_executions_initiated: Optional[Dict[str, Sequence[str]]] = None,
|
84
|
+
node_executions_queued: Optional[Dict[str, Sequence[str]]] = None,
|
83
85
|
) -> None:
|
84
86
|
self._dependencies_invoked = defaultdict(set)
|
85
|
-
self.
|
87
|
+
self._node_executions_fulfilled = defaultdict(Stack[UUID])
|
86
88
|
self._node_executions_initiated = defaultdict(set)
|
89
|
+
self._node_executions_queued = defaultdict(list)
|
87
90
|
|
88
|
-
for
|
89
|
-
self._dependencies_invoked[
|
91
|
+
for execution_id, dependencies in (dependencies_invoked or {}).items():
|
92
|
+
self._dependencies_invoked[UUID(execution_id)] = {get_class_by_qualname(dep) for dep in dependencies}
|
90
93
|
|
91
|
-
for node, execution_ids in (
|
94
|
+
for node, execution_ids in (node_executions_fulfilled or {}).items():
|
92
95
|
node_class = get_class_by_qualname(node)
|
93
|
-
self.
|
96
|
+
self._node_executions_fulfilled[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
|
94
97
|
|
95
98
|
for node, execution_ids in (node_executions_initiated or {}).items():
|
96
99
|
node_class = get_class_by_qualname(node)
|
97
100
|
self._node_executions_initiated[node_class].update({UUID(execution_id) for execution_id in execution_ids})
|
98
101
|
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
+
for node, execution_ids in (node_executions_queued or {}).items():
|
103
|
+
node_class = get_class_by_qualname(node)
|
104
|
+
self._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
|
102
105
|
|
103
|
-
def
|
104
|
-
|
106
|
+
def _invoke_dependency(
|
107
|
+
self,
|
108
|
+
execution_id: UUID,
|
109
|
+
node: Type["BaseNode"],
|
110
|
+
dependency: Type["BaseNode"],
|
111
|
+
dependencies: Set["Type[BaseNode]"],
|
112
|
+
) -> None:
|
113
|
+
self._dependencies_invoked[execution_id].add(dependency)
|
114
|
+
if all(dep in self._dependencies_invoked[execution_id] for dep in dependencies):
|
115
|
+
self._node_executions_queued[node].remove(execution_id)
|
116
|
+
|
117
|
+
def queue_node_execution(
|
118
|
+
self, node: Type["BaseNode"], dependencies: Set["Type[BaseNode]"], invoked_by: Optional[Edge] = None
|
119
|
+
) -> UUID:
|
120
|
+
execution_id = uuid4()
|
121
|
+
if not invoked_by:
|
122
|
+
return execution_id
|
123
|
+
|
124
|
+
source_node = invoked_by.from_port.node_class
|
125
|
+
for queued_node_execution_id in self._node_executions_queued[node]:
|
126
|
+
if source_node not in self._dependencies_invoked[queued_node_execution_id]:
|
127
|
+
self._invoke_dependency(queued_node_execution_id, node, source_node, dependencies)
|
128
|
+
return queued_node_execution_id
|
129
|
+
|
130
|
+
self._node_executions_queued[node].append(execution_id)
|
131
|
+
self._invoke_dependency(execution_id, node, source_node, dependencies)
|
132
|
+
return execution_id
|
133
|
+
|
134
|
+
def is_node_execution_initiated(self, node: Type["BaseNode"], execution_id: UUID) -> bool:
|
135
|
+
return execution_id in self._node_executions_initiated[node]
|
105
136
|
|
106
137
|
def initiate_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
|
107
138
|
self._node_executions_initiated[node].add(execution_id)
|
108
139
|
|
109
140
|
def fulfill_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
|
110
|
-
self.
|
111
|
-
self._node_execution_ids[node].push(execution_id)
|
141
|
+
self._node_executions_fulfilled[node].push(execution_id)
|
112
142
|
|
113
143
|
def get_execution_count(self, node: Type["BaseNode"]) -> int:
|
114
|
-
return self.
|
144
|
+
return self._node_executions_fulfilled[node].size()
|
115
145
|
|
116
146
|
def dump(self) -> Dict[str, Any]:
|
117
147
|
return {
|
118
148
|
"dependencies_invoked": {
|
119
|
-
|
149
|
+
str(execution_id): [str(dep) for dep in dependencies]
|
150
|
+
for execution_id, dependencies in self._dependencies_invoked.items()
|
120
151
|
},
|
121
152
|
"node_executions_initiated": {
|
122
153
|
str(node): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
|
123
154
|
},
|
124
|
-
"
|
125
|
-
str(node): execution_ids.dump() for node, execution_ids in self.
|
155
|
+
"node_executions_fulfilled": {
|
156
|
+
str(node): execution_ids.dump() for node, execution_ids in self._node_executions_fulfilled.items()
|
157
|
+
},
|
158
|
+
"node_executions_queued": {
|
159
|
+
str(node): execution_ids for node, execution_ids in self._node_executions_queued.items()
|
126
160
|
},
|
127
161
|
}
|
128
162
|
|
@@ -1,14 +1,24 @@
|
|
1
1
|
from functools import cached_property
|
2
|
-
from
|
2
|
+
from queue import Queue
|
3
|
+
from typing import TYPE_CHECKING, Optional
|
3
4
|
|
4
5
|
from vellum import Vellum
|
5
|
-
|
6
|
+
from vellum.workflows.events.types import ParentContext
|
6
7
|
from vellum.workflows.vellum_client import create_vellum_client
|
7
8
|
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from vellum.workflows.events.workflow import WorkflowEvent
|
11
|
+
|
8
12
|
|
9
13
|
class WorkflowContext:
|
10
|
-
def __init__(
|
14
|
+
def __init__(
|
15
|
+
self,
|
16
|
+
_vellum_client: Optional[Vellum] = None,
|
17
|
+
_parent_context: Optional[ParentContext] = None,
|
18
|
+
):
|
11
19
|
self._vellum_client = _vellum_client
|
20
|
+
self._parent_context = _parent_context
|
21
|
+
self._event_queue: Optional[Queue["WorkflowEvent"]] = None
|
12
22
|
|
13
23
|
@cached_property
|
14
24
|
def vellum_client(self) -> Vellum:
|
@@ -16,3 +26,16 @@ class WorkflowContext:
|
|
16
26
|
return self._vellum_client
|
17
27
|
|
18
28
|
return create_vellum_client()
|
29
|
+
|
30
|
+
@cached_property
|
31
|
+
def parent_context(self) -> Optional[ParentContext]:
|
32
|
+
if self._parent_context:
|
33
|
+
return self._parent_context
|
34
|
+
return None
|
35
|
+
|
36
|
+
def _emit_subworkflow_event(self, event: "WorkflowEvent") -> None:
|
37
|
+
if self._event_queue:
|
38
|
+
self._event_queue.put(event)
|
39
|
+
|
40
|
+
def _register_event_queue(self, event_queue: Queue["WorkflowEvent"]) -> None:
|
41
|
+
self._event_queue = event_queue
|