vellum-ai 0.10.8__py3-none-any.whl → 0.10.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,7 +3,7 @@ from copy import deepcopy
3
3
  import logging
4
4
  from queue import Empty, Queue
5
5
  from threading import Event as ThreadingEvent, Thread
6
- from uuid import UUID, uuid4
6
+ from uuid import UUID
7
7
  from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Type, Union
8
8
 
9
9
  from vellum.workflows.constants import UNDEF
@@ -29,7 +29,6 @@ from vellum.workflows.events.node import (
29
29
  NodeExecutionStreamingBody,
30
30
  )
31
31
  from vellum.workflows.events.types import BaseEvent, ParentContext, WorkflowParentContext
32
- from vellum.workflows.events.utils import is_terminal_event
33
32
  from vellum.workflows.events.workflow import (
34
33
  WorkflowExecutionFulfilledBody,
35
34
  WorkflowExecutionInitiatedBody,
@@ -74,7 +73,9 @@ class WorkflowRunner(Generic[StateType]):
74
73
  parent_context: Optional[ParentContext] = None,
75
74
  ):
76
75
  if state and external_inputs:
77
- raise ValueError("Can only run a Workflow providing one of state or external inputs, not both")
76
+ raise ValueError(
77
+ "Can only run a Workflow providing one of state or external inputs, not both"
78
+ )
78
79
 
79
80
  self.workflow = workflow
80
81
  if entrypoint_nodes:
@@ -100,7 +101,9 @@ class WorkflowRunner(Generic[StateType]):
100
101
  if issubclass(ei.inputs_class.__parent_class__, BaseNode)
101
102
  ]
102
103
  else:
103
- normalized_inputs = deepcopy(inputs) if inputs else self.workflow.get_default_inputs()
104
+ normalized_inputs = (
105
+ deepcopy(inputs) if inputs else self.workflow.get_default_inputs()
106
+ )
104
107
  if state:
105
108
  self._initial_state = deepcopy(state)
106
109
  self._initial_state.meta.workflow_inputs = normalized_inputs
@@ -123,6 +126,7 @@ class WorkflowRunner(Generic[StateType]):
123
126
  "__snapshot_callback__",
124
127
  lambda s: self._snapshot_state(s),
125
128
  )
129
+ self.workflow.context._register_event_queue(self._workflow_event_queue)
126
130
 
127
131
  def _snapshot_state(self, state: StateType) -> StateType:
128
132
  self.workflow._store.append_state_snapshot(state)
@@ -148,8 +152,9 @@ class WorkflowRunner(Generic[StateType]):
148
152
  parent=WorkflowParentContext(
149
153
  span_id=span_id,
150
154
  workflow_definition=self.workflow.__class__,
151
- parent=self._parent_context
152
- )
155
+ parent=self._parent_context,
156
+ type="WORKFLOW",
157
+ ),
153
158
  ),
154
159
  )
155
160
  )
@@ -185,7 +190,11 @@ class WorkflowRunner(Generic[StateType]):
185
190
  instance=None,
186
191
  outputs_class=node.Outputs,
187
192
  )
188
- node.state.meta.node_outputs[output_descriptor] = streaming_output_queues[output.name]
193
+ node.state.meta.node_outputs[output_descriptor] = (
194
+ streaming_output_queues[output.name]
195
+ )
196
+ initiated_output: BaseOutput = BaseOutput(name=output.name)
197
+ initiated_ports = initiated_output > ports
189
198
  self._work_item_event_queue.put(
190
199
  WorkItemEvent(
191
200
  node=node,
@@ -194,14 +203,14 @@ class WorkflowRunner(Generic[StateType]):
194
203
  span_id=span_id,
195
204
  body=NodeExecutionStreamingBody(
196
205
  node_definition=node.__class__,
197
- output=BaseOutput(name=output.name),
198
- invoked_ports=invoked_ports,
206
+ output=initiated_output,
207
+ invoked_ports=initiated_ports,
199
208
  ),
200
209
  parent=WorkflowParentContext(
201
210
  span_id=span_id,
202
211
  workflow_definition=self.workflow.__class__,
203
- parent=self._parent_context
204
- )
212
+ parent=self._parent_context,
213
+ ),
205
214
  ),
206
215
  )
207
216
  )
@@ -230,7 +239,7 @@ class WorkflowRunner(Generic[StateType]):
230
239
  span_id=span_id,
231
240
  workflow_definition=self.workflow.__class__,
232
241
  parent=self._parent_context,
233
- )
242
+ ),
234
243
  ),
235
244
  )
236
245
  )
@@ -254,7 +263,7 @@ class WorkflowRunner(Generic[StateType]):
254
263
  span_id=span_id,
255
264
  workflow_definition=self.workflow.__class__,
256
265
  parent=self._parent_context,
257
- )
266
+ ),
258
267
  ),
259
268
  )
260
269
  )
@@ -268,7 +277,9 @@ class WorkflowRunner(Generic[StateType]):
268
277
  node.state.meta.node_outputs[descriptor] = output_value
269
278
 
270
279
  invoked_ports = ports(outputs, node.state)
271
- node.state.meta.node_execution_cache.fulfill_node_execution(node.__class__, span_id)
280
+ node.state.meta.node_execution_cache.fulfill_node_execution(
281
+ node.__class__, span_id
282
+ )
272
283
 
273
284
  self._work_item_event_queue.put(
274
285
  WorkItemEvent(
@@ -285,7 +296,7 @@ class WorkflowRunner(Generic[StateType]):
285
296
  span_id=span_id,
286
297
  workflow_definition=self.workflow.__class__,
287
298
  parent=self._parent_context,
288
- )
299
+ ),
289
300
  ),
290
301
  )
291
302
  )
@@ -304,12 +315,14 @@ class WorkflowRunner(Generic[StateType]):
304
315
  span_id=span_id,
305
316
  workflow_definition=self.workflow.__class__,
306
317
  parent=self._parent_context,
307
- )
318
+ ),
308
319
  ),
309
320
  )
310
321
  )
311
322
  except Exception as e:
312
- logger.exception(f"An unexpected error occurred while running node {node.__class__.__name__}")
323
+ logger.exception(
324
+ f"An unexpected error occurred while running node {node.__class__.__name__}"
325
+ )
313
326
 
314
327
  self._work_item_event_queue.put(
315
328
  WorkItemEvent(
@@ -327,15 +340,17 @@ class WorkflowRunner(Generic[StateType]):
327
340
  parent=WorkflowParentContext(
328
341
  span_id=span_id,
329
342
  workflow_definition=self.workflow.__class__,
330
- parent=self._parent_context
331
- )
343
+ parent=self._parent_context,
344
+ ),
332
345
  ),
333
346
  )
334
347
  )
335
348
 
336
349
  logger.debug(f"Finished running node: {node.__class__.__name__}")
337
350
 
338
- def _handle_invoked_ports(self, state: StateType, ports: Optional[Iterable[Port]]) -> None:
351
+ def _handle_invoked_ports(
352
+ self, state: StateType, ports: Optional[Iterable[Port]]
353
+ ) -> None:
339
354
  if not ports:
340
355
  return
341
356
 
@@ -350,7 +365,10 @@ class WorkflowRunner(Generic[StateType]):
350
365
  self._run_node_if_ready(next_state, edge.to_node, edge)
351
366
 
352
367
  def _run_node_if_ready(
353
- self, state: StateType, node_class: Type[BaseNode], invoked_by: Optional[Edge] = None
368
+ self,
369
+ state: StateType,
370
+ node_class: Type[BaseNode],
371
+ invoked_by: Optional[Edge] = None,
354
372
  ) -> None:
355
373
  with state.__lock__:
356
374
  for descriptor in node_class.ExternalInputs:
@@ -362,18 +380,27 @@ class WorkflowRunner(Generic[StateType]):
362
380
  return
363
381
 
364
382
  all_deps = self._dependencies[node_class]
365
- if not node_class.Trigger.should_initiate(state, all_deps, invoked_by):
383
+ node_span_id = state.meta.node_execution_cache.queue_node_execution(
384
+ node_class, all_deps, invoked_by
385
+ )
386
+ if not node_class.Trigger.should_initiate(state, all_deps, node_span_id):
366
387
  return
367
388
 
368
389
  node = node_class(state=state, context=self.workflow.context)
369
- node_span_id = uuid4()
370
- state.meta.node_execution_cache.initiate_node_execution(node_class, node_span_id)
390
+ state.meta.node_execution_cache.initiate_node_execution(
391
+ node_class, node_span_id
392
+ )
371
393
  self._active_nodes_by_execution_id[node_span_id] = node
372
394
 
373
- worker_thread = Thread(target=self._run_work_item, kwargs={"node": node, "span_id": node_span_id})
395
+ worker_thread = Thread(
396
+ target=self._run_work_item,
397
+ kwargs={"node": node, "span_id": node_span_id},
398
+ )
374
399
  worker_thread.start()
375
400
 
376
- def _handle_work_item_event(self, work_item_event: WorkItemEvent[StateType]) -> Optional[VellumError]:
401
+ def _handle_work_item_event(
402
+ self, work_item_event: WorkItemEvent[StateType]
403
+ ) -> Optional[VellumError]:
377
404
  node = work_item_event.node
378
405
  event = work_item_event.event
379
406
 
@@ -389,7 +416,10 @@ class WorkflowRunner(Generic[StateType]):
389
416
  node_output_descriptor = workflow_output_descriptor.instance
390
417
  if not isinstance(node_output_descriptor, OutputReference):
391
418
  continue
392
- if node_output_descriptor.outputs_class != event.node_definition.Outputs:
419
+ if (
420
+ node_output_descriptor.outputs_class
421
+ != event.node_definition.Outputs
422
+ ):
393
423
  continue
394
424
  if node_output_descriptor.name != event.output.name:
395
425
  continue
@@ -427,7 +457,9 @@ class WorkflowRunner(Generic[StateType]):
427
457
  parent=self._parent_context,
428
458
  )
429
459
 
430
- def _stream_workflow_event(self, output: BaseOutput) -> WorkflowExecutionStreamingEvent:
460
+ def _stream_workflow_event(
461
+ self, output: BaseOutput
462
+ ) -> WorkflowExecutionStreamingEvent:
431
463
  return WorkflowExecutionStreamingEvent(
432
464
  trace_id=self._initial_state.meta.trace_id,
433
465
  span_id=self._initial_state.meta.span_id,
@@ -435,10 +467,12 @@ class WorkflowRunner(Generic[StateType]):
435
467
  workflow_definition=self.workflow.__class__,
436
468
  output=output,
437
469
  ),
438
- parent=self._parent_context
470
+ parent=self._parent_context,
439
471
  )
440
472
 
441
- def _fulfill_workflow_event(self, outputs: OutputsType) -> WorkflowExecutionFulfilledEvent:
473
+ def _fulfill_workflow_event(
474
+ self, outputs: OutputsType
475
+ ) -> WorkflowExecutionFulfilledEvent:
442
476
  return WorkflowExecutionFulfilledEvent(
443
477
  trace_id=self._initial_state.meta.trace_id,
444
478
  span_id=self._initial_state.meta.span_id,
@@ -449,7 +483,9 @@ class WorkflowRunner(Generic[StateType]):
449
483
  parent=self._parent_context,
450
484
  )
451
485
 
452
- def _reject_workflow_event(self, error: VellumError) -> WorkflowExecutionRejectedEvent:
486
+ def _reject_workflow_event(
487
+ self, error: VellumError
488
+ ) -> WorkflowExecutionRejectedEvent:
453
489
  return WorkflowExecutionRejectedEvent(
454
490
  trace_id=self._initial_state.meta.trace_id,
455
491
  span_id=self._initial_state.meta.span_id,
@@ -469,7 +505,9 @@ class WorkflowRunner(Generic[StateType]):
469
505
  ),
470
506
  )
471
507
 
472
- def _pause_workflow_event(self, external_inputs: Iterable[ExternalInputReference]) -> WorkflowExecutionPausedEvent:
508
+ def _pause_workflow_event(
509
+ self, external_inputs: Iterable[ExternalInputReference]
510
+ ) -> WorkflowExecutionPausedEvent:
473
511
  return WorkflowExecutionPausedEvent(
474
512
  trace_id=self._initial_state.meta.trace_id,
475
513
  span_id=self._initial_state.meta.span_id,
@@ -486,7 +524,10 @@ class WorkflowRunner(Generic[StateType]):
486
524
  if not self._entrypoints:
487
525
  self._workflow_event_queue.put(
488
526
  self._reject_workflow_event(
489
- VellumError(message="No entrypoints defined", code=VellumErrorCode.INVALID_WORKFLOW)
527
+ VellumError(
528
+ message="No entrypoints defined",
529
+ code=VellumErrorCode.INVALID_WORKFLOW,
530
+ )
490
531
  )
491
532
  )
492
533
  return
@@ -505,7 +546,9 @@ class WorkflowRunner(Generic[StateType]):
505
546
  logger.exception(err_message)
506
547
  self._workflow_event_queue.put(
507
548
  self._reject_workflow_event(
508
- VellumError(code=VellumErrorCode.INTERNAL_ERROR, message=err_message),
549
+ VellumError(
550
+ code=VellumErrorCode.INTERNAL_ERROR, message=err_message
551
+ ),
509
552
  )
510
553
  )
511
554
  return
@@ -561,7 +604,11 @@ class WorkflowRunner(Generic[StateType]):
561
604
  if isinstance(value, BaseDescriptor):
562
605
  setattr(fulfilled_outputs, descriptor.name, value.resolve(final_state))
563
606
  elif isinstance(descriptor.instance, BaseDescriptor):
564
- setattr(fulfilled_outputs, descriptor.name, descriptor.instance.resolve(final_state))
607
+ setattr(
608
+ fulfilled_outputs,
609
+ descriptor.name,
610
+ descriptor.instance.resolve(final_state),
611
+ )
565
612
 
566
613
  self._workflow_event_queue.put(self._fulfill_workflow_event(fulfilled_outputs))
567
614
 
@@ -586,24 +633,41 @@ class WorkflowRunner(Generic[StateType]):
586
633
  self._cancel_signal.wait()
587
634
  self._workflow_event_queue.put(
588
635
  self._reject_workflow_event(
589
- VellumError(code=VellumErrorCode.WORKFLOW_CANCELLED, message="Workflow run cancelled")
636
+ VellumError(
637
+ code=VellumErrorCode.WORKFLOW_CANCELLED,
638
+ message="Workflow run cancelled",
639
+ )
590
640
  )
591
641
  )
592
642
 
643
+ def _is_terminal_event(self, event: WorkflowEvent) -> bool:
644
+ if (
645
+ event.name == "workflow.execution.fulfilled"
646
+ or event.name == "workflow.execution.rejected"
647
+ or event.name == "workflow.execution.paused"
648
+ ):
649
+ return event.workflow_definition == self.workflow.__class__
650
+ return False
651
+
593
652
  def stream(self) -> WorkflowEventStream:
594
653
  background_thread = Thread(
595
- target=self._run_background_thread, name=f"{self.workflow.__class__.__name__}.background_thread"
654
+ target=self._run_background_thread,
655
+ name=f"{self.workflow.__class__.__name__}.background_thread",
596
656
  )
597
657
  background_thread.start()
598
658
 
599
659
  if self._cancel_signal:
600
660
  cancel_thread = Thread(
601
- target=self._run_cancel_thread, name=f"{self.workflow.__class__.__name__}.cancel_thread"
661
+ target=self._run_cancel_thread,
662
+ name=f"{self.workflow.__class__.__name__}.cancel_thread",
602
663
  )
603
664
  cancel_thread.start()
604
665
 
605
666
  event: WorkflowEvent
606
- if self._initial_state.meta.is_terminated or self._initial_state.meta.is_terminated is None:
667
+ if (
668
+ self._initial_state.meta.is_terminated
669
+ or self._initial_state.meta.is_terminated is None
670
+ ):
607
671
  event = self._initiate_workflow_event()
608
672
  else:
609
673
  event = self._resume_workflow_event()
@@ -612,7 +676,10 @@ class WorkflowRunner(Generic[StateType]):
612
676
  self._initial_state.meta.is_terminated = False
613
677
 
614
678
  # The extra level of indirection prevents the runner from waiting on the caller to consume the event stream
615
- stream_thread = Thread(target=self._stream, name=f"{self.workflow.__class__.__name__}.stream_thread")
679
+ stream_thread = Thread(
680
+ target=self._stream,
681
+ name=f"{self.workflow.__class__.__name__}.stream_thread",
682
+ )
616
683
  stream_thread.start()
617
684
 
618
685
  while stream_thread.is_alive():
@@ -623,7 +690,7 @@ class WorkflowRunner(Generic[StateType]):
623
690
 
624
691
  yield self._emit_event(event)
625
692
 
626
- if is_terminal_event(event):
693
+ if self._is_terminal_event(event):
627
694
  break
628
695
 
629
696
  try:
@@ -632,7 +699,7 @@ class WorkflowRunner(Generic[StateType]):
632
699
  except Empty:
633
700
  pass
634
701
 
635
- if not is_terminal_event(event):
702
+ if not self._is_terminal_event(event):
636
703
  yield self._reject_workflow_event(
637
704
  VellumError(
638
705
  code=VellumErrorCode.INTERNAL_ERROR,
@@ -5,15 +5,15 @@ from datetime import datetime
5
5
  from queue import Queue
6
6
  from threading import Lock
7
7
  from uuid import UUID, uuid4
8
- from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional, Sequence, Set, Tuple, Type, cast
8
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, cast
9
9
  from typing_extensions import dataclass_transform
10
10
 
11
11
  from pydantic import GetCoreSchemaHandler, field_serializer
12
12
  from pydantic_core import core_schema
13
13
 
14
14
  from vellum.core.pydantic_utilities import UniversalBaseModel
15
-
16
15
  from vellum.workflows.constants import UNDEF
16
+ from vellum.workflows.edges.edge import Edge
17
17
  from vellum.workflows.inputs.base import BaseInputs
18
18
  from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
19
19
  from vellum.workflows.types.generics import StateType
@@ -71,58 +71,92 @@ def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> An
71
71
 
72
72
 
73
73
  class NodeExecutionCache:
74
- _node_execution_ids: Dict[Type["BaseNode"], Stack[UUID]]
74
+ _node_executions_fulfilled: Dict[Type["BaseNode"], Stack[UUID]]
75
75
  _node_executions_initiated: Dict[Type["BaseNode"], Set[UUID]]
76
- _dependencies_invoked: Dict[str, Set[str]]
76
+ _node_executions_queued: Dict[Type["BaseNode"], List[UUID]]
77
+ _dependencies_invoked: Dict[UUID, Set[Type["BaseNode"]]]
77
78
 
78
79
  def __init__(
79
80
  self,
80
81
  dependencies_invoked: Optional[Dict[str, Sequence[str]]] = None,
81
- node_execution_ids: Optional[Dict[str, Sequence[str]]] = None,
82
+ node_executions_fulfilled: Optional[Dict[str, Sequence[str]]] = None,
82
83
  node_executions_initiated: Optional[Dict[str, Sequence[str]]] = None,
84
+ node_executions_queued: Optional[Dict[str, Sequence[str]]] = None,
83
85
  ) -> None:
84
86
  self._dependencies_invoked = defaultdict(set)
85
- self._node_execution_ids = defaultdict(Stack[UUID])
87
+ self._node_executions_fulfilled = defaultdict(Stack[UUID])
86
88
  self._node_executions_initiated = defaultdict(set)
89
+ self._node_executions_queued = defaultdict(list)
87
90
 
88
- for node, dependencies in (dependencies_invoked or {}).items():
89
- self._dependencies_invoked[node].update(dependencies)
91
+ for execution_id, dependencies in (dependencies_invoked or {}).items():
92
+ self._dependencies_invoked[UUID(execution_id)] = {get_class_by_qualname(dep) for dep in dependencies}
90
93
 
91
- for node, execution_ids in (node_execution_ids or {}).items():
94
+ for node, execution_ids in (node_executions_fulfilled or {}).items():
92
95
  node_class = get_class_by_qualname(node)
93
- self._node_execution_ids[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
96
+ self._node_executions_fulfilled[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
94
97
 
95
98
  for node, execution_ids in (node_executions_initiated or {}).items():
96
99
  node_class = get_class_by_qualname(node)
97
100
  self._node_executions_initiated[node_class].update({UUID(execution_id) for execution_id in execution_ids})
98
101
 
99
- @property
100
- def dependencies_invoked(self) -> Dict[str, Set[str]]:
101
- return self._dependencies_invoked
102
+ for node, execution_ids in (node_executions_queued or {}).items():
103
+ node_class = get_class_by_qualname(node)
104
+ self._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
102
105
 
103
- def is_node_initiated(self, node: Type["BaseNode"]) -> bool:
104
- return node in self._node_executions_initiated and len(self._node_executions_initiated[node]) > 0
106
+ def _invoke_dependency(
107
+ self,
108
+ execution_id: UUID,
109
+ node: Type["BaseNode"],
110
+ dependency: Type["BaseNode"],
111
+ dependencies: Set["Type[BaseNode]"],
112
+ ) -> None:
113
+ self._dependencies_invoked[execution_id].add(dependency)
114
+ if all(dep in self._dependencies_invoked[execution_id] for dep in dependencies):
115
+ self._node_executions_queued[node].remove(execution_id)
116
+
117
+ def queue_node_execution(
118
+ self, node: Type["BaseNode"], dependencies: Set["Type[BaseNode]"], invoked_by: Optional[Edge] = None
119
+ ) -> UUID:
120
+ execution_id = uuid4()
121
+ if not invoked_by:
122
+ return execution_id
123
+
124
+ source_node = invoked_by.from_port.node_class
125
+ for queued_node_execution_id in self._node_executions_queued[node]:
126
+ if source_node not in self._dependencies_invoked[queued_node_execution_id]:
127
+ self._invoke_dependency(queued_node_execution_id, node, source_node, dependencies)
128
+ return queued_node_execution_id
129
+
130
+ self._node_executions_queued[node].append(execution_id)
131
+ self._invoke_dependency(execution_id, node, source_node, dependencies)
132
+ return execution_id
133
+
134
+ def is_node_execution_initiated(self, node: Type["BaseNode"], execution_id: UUID) -> bool:
135
+ return execution_id in self._node_executions_initiated[node]
105
136
 
106
137
  def initiate_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
107
138
  self._node_executions_initiated[node].add(execution_id)
108
139
 
109
140
  def fulfill_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
110
- self._node_executions_initiated[node].remove(execution_id)
111
- self._node_execution_ids[node].push(execution_id)
141
+ self._node_executions_fulfilled[node].push(execution_id)
112
142
 
113
143
  def get_execution_count(self, node: Type["BaseNode"]) -> int:
114
- return self._node_execution_ids[node].size()
144
+ return self._node_executions_fulfilled[node].size()
115
145
 
116
146
  def dump(self) -> Dict[str, Any]:
117
147
  return {
118
148
  "dependencies_invoked": {
119
- node: list(dependencies) for node, dependencies in self._dependencies_invoked.items()
149
+ str(execution_id): [str(dep) for dep in dependencies]
150
+ for execution_id, dependencies in self._dependencies_invoked.items()
120
151
  },
121
152
  "node_executions_initiated": {
122
153
  str(node): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
123
154
  },
124
- "node_execution_ids": {
125
- str(node): execution_ids.dump() for node, execution_ids in self._node_execution_ids.items()
155
+ "node_executions_fulfilled": {
156
+ str(node): execution_ids.dump() for node, execution_ids in self._node_executions_fulfilled.items()
157
+ },
158
+ "node_executions_queued": {
159
+ str(node): execution_ids for node, execution_ids in self._node_executions_queued.items()
126
160
  },
127
161
  }
128
162
 
@@ -1,14 +1,24 @@
1
1
  from functools import cached_property
2
- from typing import Optional
2
+ from queue import Queue
3
+ from typing import TYPE_CHECKING, Optional
3
4
 
4
5
  from vellum import Vellum
5
-
6
+ from vellum.workflows.events.types import ParentContext
6
7
  from vellum.workflows.vellum_client import create_vellum_client
7
8
 
9
+ if TYPE_CHECKING:
10
+ from vellum.workflows.events.workflow import WorkflowEvent
11
+
8
12
 
9
13
  class WorkflowContext:
10
- def __init__(self, _vellum_client: Optional[Vellum] = None):
14
+ def __init__(
15
+ self,
16
+ _vellum_client: Optional[Vellum] = None,
17
+ _parent_context: Optional[ParentContext] = None,
18
+ ):
11
19
  self._vellum_client = _vellum_client
20
+ self._parent_context = _parent_context
21
+ self._event_queue: Optional[Queue["WorkflowEvent"]] = None
12
22
 
13
23
  @cached_property
14
24
  def vellum_client(self) -> Vellum:
@@ -16,3 +26,16 @@ class WorkflowContext:
16
26
  return self._vellum_client
17
27
 
18
28
  return create_vellum_client()
29
+
30
+ @cached_property
31
+ def parent_context(self) -> Optional[ParentContext]:
32
+ if self._parent_context:
33
+ return self._parent_context
34
+ return None
35
+
36
+ def _emit_subworkflow_event(self, event: "WorkflowEvent") -> None:
37
+ if self._event_queue:
38
+ self._event_queue.put(event)
39
+
40
+ def _register_event_queue(self, event_queue: Queue["WorkflowEvent"]) -> None:
41
+ self._event_queue = event_queue
@@ -83,7 +83,7 @@ EntityInputsInterface = Dict[
83
83
  class MergeBehavior(Enum):
84
84
  AWAIT_ALL = "AWAIT_ALL"
85
85
  AWAIT_ANY = "AWAIT_ANY"
86
-
86
+ AWAIT_ATTRIBUTES = "AWAIT_ATTRIBUTES"
87
87
 
88
88
  class ConditionType(Enum):
89
89
  IF = "IF"