vellum-ai 0.10.8__py3-none-any.whl → 0.10.9__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -3,7 +3,7 @@ from copy import deepcopy
3
3
  import logging
4
4
  from queue import Empty, Queue
5
5
  from threading import Event as ThreadingEvent, Thread
6
- from uuid import UUID, uuid4
6
+ from uuid import UUID
7
7
  from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Type, Union
8
8
 
9
9
  from vellum.workflows.constants import UNDEF
@@ -29,7 +29,6 @@ from vellum.workflows.events.node import (
29
29
  NodeExecutionStreamingBody,
30
30
  )
31
31
  from vellum.workflows.events.types import BaseEvent, ParentContext, WorkflowParentContext
32
- from vellum.workflows.events.utils import is_terminal_event
33
32
  from vellum.workflows.events.workflow import (
34
33
  WorkflowExecutionFulfilledBody,
35
34
  WorkflowExecutionInitiatedBody,
@@ -74,7 +73,9 @@ class WorkflowRunner(Generic[StateType]):
74
73
  parent_context: Optional[ParentContext] = None,
75
74
  ):
76
75
  if state and external_inputs:
77
- raise ValueError("Can only run a Workflow providing one of state or external inputs, not both")
76
+ raise ValueError(
77
+ "Can only run a Workflow providing one of state or external inputs, not both"
78
+ )
78
79
 
79
80
  self.workflow = workflow
80
81
  if entrypoint_nodes:
@@ -100,7 +101,9 @@ class WorkflowRunner(Generic[StateType]):
100
101
  if issubclass(ei.inputs_class.__parent_class__, BaseNode)
101
102
  ]
102
103
  else:
103
- normalized_inputs = deepcopy(inputs) if inputs else self.workflow.get_default_inputs()
104
+ normalized_inputs = (
105
+ deepcopy(inputs) if inputs else self.workflow.get_default_inputs()
106
+ )
104
107
  if state:
105
108
  self._initial_state = deepcopy(state)
106
109
  self._initial_state.meta.workflow_inputs = normalized_inputs
@@ -123,6 +126,7 @@ class WorkflowRunner(Generic[StateType]):
123
126
  "__snapshot_callback__",
124
127
  lambda s: self._snapshot_state(s),
125
128
  )
129
+ self.workflow.context._register_event_queue(self._workflow_event_queue)
126
130
 
127
131
  def _snapshot_state(self, state: StateType) -> StateType:
128
132
  self.workflow._store.append_state_snapshot(state)
@@ -148,8 +152,9 @@ class WorkflowRunner(Generic[StateType]):
148
152
  parent=WorkflowParentContext(
149
153
  span_id=span_id,
150
154
  workflow_definition=self.workflow.__class__,
151
- parent=self._parent_context
152
- )
155
+ parent=self._parent_context,
156
+ type="WORKFLOW",
157
+ ),
153
158
  ),
154
159
  )
155
160
  )
@@ -185,7 +190,11 @@ class WorkflowRunner(Generic[StateType]):
185
190
  instance=None,
186
191
  outputs_class=node.Outputs,
187
192
  )
188
- node.state.meta.node_outputs[output_descriptor] = streaming_output_queues[output.name]
193
+ node.state.meta.node_outputs[output_descriptor] = (
194
+ streaming_output_queues[output.name]
195
+ )
196
+ initiated_output: BaseOutput = BaseOutput(name=output.name)
197
+ initiated_ports = initiated_output > ports
189
198
  self._work_item_event_queue.put(
190
199
  WorkItemEvent(
191
200
  node=node,
@@ -194,14 +203,14 @@ class WorkflowRunner(Generic[StateType]):
194
203
  span_id=span_id,
195
204
  body=NodeExecutionStreamingBody(
196
205
  node_definition=node.__class__,
197
- output=BaseOutput(name=output.name),
198
- invoked_ports=invoked_ports,
206
+ output=initiated_output,
207
+ invoked_ports=initiated_ports,
199
208
  ),
200
209
  parent=WorkflowParentContext(
201
210
  span_id=span_id,
202
211
  workflow_definition=self.workflow.__class__,
203
- parent=self._parent_context
204
- )
212
+ parent=self._parent_context,
213
+ ),
205
214
  ),
206
215
  )
207
216
  )
@@ -230,7 +239,7 @@ class WorkflowRunner(Generic[StateType]):
230
239
  span_id=span_id,
231
240
  workflow_definition=self.workflow.__class__,
232
241
  parent=self._parent_context,
233
- )
242
+ ),
234
243
  ),
235
244
  )
236
245
  )
@@ -254,7 +263,7 @@ class WorkflowRunner(Generic[StateType]):
254
263
  span_id=span_id,
255
264
  workflow_definition=self.workflow.__class__,
256
265
  parent=self._parent_context,
257
- )
266
+ ),
258
267
  ),
259
268
  )
260
269
  )
@@ -268,7 +277,9 @@ class WorkflowRunner(Generic[StateType]):
268
277
  node.state.meta.node_outputs[descriptor] = output_value
269
278
 
270
279
  invoked_ports = ports(outputs, node.state)
271
- node.state.meta.node_execution_cache.fulfill_node_execution(node.__class__, span_id)
280
+ node.state.meta.node_execution_cache.fulfill_node_execution(
281
+ node.__class__, span_id
282
+ )
272
283
 
273
284
  self._work_item_event_queue.put(
274
285
  WorkItemEvent(
@@ -285,7 +296,7 @@ class WorkflowRunner(Generic[StateType]):
285
296
  span_id=span_id,
286
297
  workflow_definition=self.workflow.__class__,
287
298
  parent=self._parent_context,
288
- )
299
+ ),
289
300
  ),
290
301
  )
291
302
  )
@@ -304,12 +315,14 @@ class WorkflowRunner(Generic[StateType]):
304
315
  span_id=span_id,
305
316
  workflow_definition=self.workflow.__class__,
306
317
  parent=self._parent_context,
307
- )
318
+ ),
308
319
  ),
309
320
  )
310
321
  )
311
322
  except Exception as e:
312
- logger.exception(f"An unexpected error occurred while running node {node.__class__.__name__}")
323
+ logger.exception(
324
+ f"An unexpected error occurred while running node {node.__class__.__name__}"
325
+ )
313
326
 
314
327
  self._work_item_event_queue.put(
315
328
  WorkItemEvent(
@@ -327,15 +340,17 @@ class WorkflowRunner(Generic[StateType]):
327
340
  parent=WorkflowParentContext(
328
341
  span_id=span_id,
329
342
  workflow_definition=self.workflow.__class__,
330
- parent=self._parent_context
331
- )
343
+ parent=self._parent_context,
344
+ ),
332
345
  ),
333
346
  )
334
347
  )
335
348
 
336
349
  logger.debug(f"Finished running node: {node.__class__.__name__}")
337
350
 
338
- def _handle_invoked_ports(self, state: StateType, ports: Optional[Iterable[Port]]) -> None:
351
+ def _handle_invoked_ports(
352
+ self, state: StateType, ports: Optional[Iterable[Port]]
353
+ ) -> None:
339
354
  if not ports:
340
355
  return
341
356
 
@@ -350,7 +365,10 @@ class WorkflowRunner(Generic[StateType]):
350
365
  self._run_node_if_ready(next_state, edge.to_node, edge)
351
366
 
352
367
  def _run_node_if_ready(
353
- self, state: StateType, node_class: Type[BaseNode], invoked_by: Optional[Edge] = None
368
+ self,
369
+ state: StateType,
370
+ node_class: Type[BaseNode],
371
+ invoked_by: Optional[Edge] = None,
354
372
  ) -> None:
355
373
  with state.__lock__:
356
374
  for descriptor in node_class.ExternalInputs:
@@ -362,18 +380,27 @@ class WorkflowRunner(Generic[StateType]):
362
380
  return
363
381
 
364
382
  all_deps = self._dependencies[node_class]
365
- if not node_class.Trigger.should_initiate(state, all_deps, invoked_by):
383
+ node_span_id = state.meta.node_execution_cache.queue_node_execution(
384
+ node_class, all_deps, invoked_by
385
+ )
386
+ if not node_class.Trigger.should_initiate(state, all_deps, node_span_id):
366
387
  return
367
388
 
368
389
  node = node_class(state=state, context=self.workflow.context)
369
- node_span_id = uuid4()
370
- state.meta.node_execution_cache.initiate_node_execution(node_class, node_span_id)
390
+ state.meta.node_execution_cache.initiate_node_execution(
391
+ node_class, node_span_id
392
+ )
371
393
  self._active_nodes_by_execution_id[node_span_id] = node
372
394
 
373
- worker_thread = Thread(target=self._run_work_item, kwargs={"node": node, "span_id": node_span_id})
395
+ worker_thread = Thread(
396
+ target=self._run_work_item,
397
+ kwargs={"node": node, "span_id": node_span_id},
398
+ )
374
399
  worker_thread.start()
375
400
 
376
- def _handle_work_item_event(self, work_item_event: WorkItemEvent[StateType]) -> Optional[VellumError]:
401
+ def _handle_work_item_event(
402
+ self, work_item_event: WorkItemEvent[StateType]
403
+ ) -> Optional[VellumError]:
377
404
  node = work_item_event.node
378
405
  event = work_item_event.event
379
406
 
@@ -389,7 +416,10 @@ class WorkflowRunner(Generic[StateType]):
389
416
  node_output_descriptor = workflow_output_descriptor.instance
390
417
  if not isinstance(node_output_descriptor, OutputReference):
391
418
  continue
392
- if node_output_descriptor.outputs_class != event.node_definition.Outputs:
419
+ if (
420
+ node_output_descriptor.outputs_class
421
+ != event.node_definition.Outputs
422
+ ):
393
423
  continue
394
424
  if node_output_descriptor.name != event.output.name:
395
425
  continue
@@ -427,7 +457,9 @@ class WorkflowRunner(Generic[StateType]):
427
457
  parent=self._parent_context,
428
458
  )
429
459
 
430
- def _stream_workflow_event(self, output: BaseOutput) -> WorkflowExecutionStreamingEvent:
460
+ def _stream_workflow_event(
461
+ self, output: BaseOutput
462
+ ) -> WorkflowExecutionStreamingEvent:
431
463
  return WorkflowExecutionStreamingEvent(
432
464
  trace_id=self._initial_state.meta.trace_id,
433
465
  span_id=self._initial_state.meta.span_id,
@@ -435,10 +467,12 @@ class WorkflowRunner(Generic[StateType]):
435
467
  workflow_definition=self.workflow.__class__,
436
468
  output=output,
437
469
  ),
438
- parent=self._parent_context
470
+ parent=self._parent_context,
439
471
  )
440
472
 
441
- def _fulfill_workflow_event(self, outputs: OutputsType) -> WorkflowExecutionFulfilledEvent:
473
+ def _fulfill_workflow_event(
474
+ self, outputs: OutputsType
475
+ ) -> WorkflowExecutionFulfilledEvent:
442
476
  return WorkflowExecutionFulfilledEvent(
443
477
  trace_id=self._initial_state.meta.trace_id,
444
478
  span_id=self._initial_state.meta.span_id,
@@ -449,7 +483,9 @@ class WorkflowRunner(Generic[StateType]):
449
483
  parent=self._parent_context,
450
484
  )
451
485
 
452
- def _reject_workflow_event(self, error: VellumError) -> WorkflowExecutionRejectedEvent:
486
+ def _reject_workflow_event(
487
+ self, error: VellumError
488
+ ) -> WorkflowExecutionRejectedEvent:
453
489
  return WorkflowExecutionRejectedEvent(
454
490
  trace_id=self._initial_state.meta.trace_id,
455
491
  span_id=self._initial_state.meta.span_id,
@@ -469,7 +505,9 @@ class WorkflowRunner(Generic[StateType]):
469
505
  ),
470
506
  )
471
507
 
472
- def _pause_workflow_event(self, external_inputs: Iterable[ExternalInputReference]) -> WorkflowExecutionPausedEvent:
508
+ def _pause_workflow_event(
509
+ self, external_inputs: Iterable[ExternalInputReference]
510
+ ) -> WorkflowExecutionPausedEvent:
473
511
  return WorkflowExecutionPausedEvent(
474
512
  trace_id=self._initial_state.meta.trace_id,
475
513
  span_id=self._initial_state.meta.span_id,
@@ -486,7 +524,10 @@ class WorkflowRunner(Generic[StateType]):
486
524
  if not self._entrypoints:
487
525
  self._workflow_event_queue.put(
488
526
  self._reject_workflow_event(
489
- VellumError(message="No entrypoints defined", code=VellumErrorCode.INVALID_WORKFLOW)
527
+ VellumError(
528
+ message="No entrypoints defined",
529
+ code=VellumErrorCode.INVALID_WORKFLOW,
530
+ )
490
531
  )
491
532
  )
492
533
  return
@@ -505,7 +546,9 @@ class WorkflowRunner(Generic[StateType]):
505
546
  logger.exception(err_message)
506
547
  self._workflow_event_queue.put(
507
548
  self._reject_workflow_event(
508
- VellumError(code=VellumErrorCode.INTERNAL_ERROR, message=err_message),
549
+ VellumError(
550
+ code=VellumErrorCode.INTERNAL_ERROR, message=err_message
551
+ ),
509
552
  )
510
553
  )
511
554
  return
@@ -561,7 +604,11 @@ class WorkflowRunner(Generic[StateType]):
561
604
  if isinstance(value, BaseDescriptor):
562
605
  setattr(fulfilled_outputs, descriptor.name, value.resolve(final_state))
563
606
  elif isinstance(descriptor.instance, BaseDescriptor):
564
- setattr(fulfilled_outputs, descriptor.name, descriptor.instance.resolve(final_state))
607
+ setattr(
608
+ fulfilled_outputs,
609
+ descriptor.name,
610
+ descriptor.instance.resolve(final_state),
611
+ )
565
612
 
566
613
  self._workflow_event_queue.put(self._fulfill_workflow_event(fulfilled_outputs))
567
614
 
@@ -586,24 +633,41 @@ class WorkflowRunner(Generic[StateType]):
586
633
  self._cancel_signal.wait()
587
634
  self._workflow_event_queue.put(
588
635
  self._reject_workflow_event(
589
- VellumError(code=VellumErrorCode.WORKFLOW_CANCELLED, message="Workflow run cancelled")
636
+ VellumError(
637
+ code=VellumErrorCode.WORKFLOW_CANCELLED,
638
+ message="Workflow run cancelled",
639
+ )
590
640
  )
591
641
  )
592
642
 
643
+ def _is_terminal_event(self, event: WorkflowEvent) -> bool:
644
+ if (
645
+ event.name == "workflow.execution.fulfilled"
646
+ or event.name == "workflow.execution.rejected"
647
+ or event.name == "workflow.execution.paused"
648
+ ):
649
+ return event.workflow_definition == self.workflow.__class__
650
+ return False
651
+
593
652
  def stream(self) -> WorkflowEventStream:
594
653
  background_thread = Thread(
595
- target=self._run_background_thread, name=f"{self.workflow.__class__.__name__}.background_thread"
654
+ target=self._run_background_thread,
655
+ name=f"{self.workflow.__class__.__name__}.background_thread",
596
656
  )
597
657
  background_thread.start()
598
658
 
599
659
  if self._cancel_signal:
600
660
  cancel_thread = Thread(
601
- target=self._run_cancel_thread, name=f"{self.workflow.__class__.__name__}.cancel_thread"
661
+ target=self._run_cancel_thread,
662
+ name=f"{self.workflow.__class__.__name__}.cancel_thread",
602
663
  )
603
664
  cancel_thread.start()
604
665
 
605
666
  event: WorkflowEvent
606
- if self._initial_state.meta.is_terminated or self._initial_state.meta.is_terminated is None:
667
+ if (
668
+ self._initial_state.meta.is_terminated
669
+ or self._initial_state.meta.is_terminated is None
670
+ ):
607
671
  event = self._initiate_workflow_event()
608
672
  else:
609
673
  event = self._resume_workflow_event()
@@ -612,7 +676,10 @@ class WorkflowRunner(Generic[StateType]):
612
676
  self._initial_state.meta.is_terminated = False
613
677
 
614
678
  # The extra level of indirection prevents the runner from waiting on the caller to consume the event stream
615
- stream_thread = Thread(target=self._stream, name=f"{self.workflow.__class__.__name__}.stream_thread")
679
+ stream_thread = Thread(
680
+ target=self._stream,
681
+ name=f"{self.workflow.__class__.__name__}.stream_thread",
682
+ )
616
683
  stream_thread.start()
617
684
 
618
685
  while stream_thread.is_alive():
@@ -623,7 +690,7 @@ class WorkflowRunner(Generic[StateType]):
623
690
 
624
691
  yield self._emit_event(event)
625
692
 
626
- if is_terminal_event(event):
693
+ if self._is_terminal_event(event):
627
694
  break
628
695
 
629
696
  try:
@@ -632,7 +699,7 @@ class WorkflowRunner(Generic[StateType]):
632
699
  except Empty:
633
700
  pass
634
701
 
635
- if not is_terminal_event(event):
702
+ if not self._is_terminal_event(event):
636
703
  yield self._reject_workflow_event(
637
704
  VellumError(
638
705
  code=VellumErrorCode.INTERNAL_ERROR,
@@ -5,15 +5,15 @@ from datetime import datetime
5
5
  from queue import Queue
6
6
  from threading import Lock
7
7
  from uuid import UUID, uuid4
8
- from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional, Sequence, Set, Tuple, Type, cast
8
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, cast
9
9
  from typing_extensions import dataclass_transform
10
10
 
11
11
  from pydantic import GetCoreSchemaHandler, field_serializer
12
12
  from pydantic_core import core_schema
13
13
 
14
14
  from vellum.core.pydantic_utilities import UniversalBaseModel
15
-
16
15
  from vellum.workflows.constants import UNDEF
16
+ from vellum.workflows.edges.edge import Edge
17
17
  from vellum.workflows.inputs.base import BaseInputs
18
18
  from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
19
19
  from vellum.workflows.types.generics import StateType
@@ -71,58 +71,92 @@ def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> An
71
71
 
72
72
 
73
73
  class NodeExecutionCache:
74
- _node_execution_ids: Dict[Type["BaseNode"], Stack[UUID]]
74
+ _node_executions_fulfilled: Dict[Type["BaseNode"], Stack[UUID]]
75
75
  _node_executions_initiated: Dict[Type["BaseNode"], Set[UUID]]
76
- _dependencies_invoked: Dict[str, Set[str]]
76
+ _node_executions_queued: Dict[Type["BaseNode"], List[UUID]]
77
+ _dependencies_invoked: Dict[UUID, Set[Type["BaseNode"]]]
77
78
 
78
79
  def __init__(
79
80
  self,
80
81
  dependencies_invoked: Optional[Dict[str, Sequence[str]]] = None,
81
- node_execution_ids: Optional[Dict[str, Sequence[str]]] = None,
82
+ node_executions_fulfilled: Optional[Dict[str, Sequence[str]]] = None,
82
83
  node_executions_initiated: Optional[Dict[str, Sequence[str]]] = None,
84
+ node_executions_queued: Optional[Dict[str, Sequence[str]]] = None,
83
85
  ) -> None:
84
86
  self._dependencies_invoked = defaultdict(set)
85
- self._node_execution_ids = defaultdict(Stack[UUID])
87
+ self._node_executions_fulfilled = defaultdict(Stack[UUID])
86
88
  self._node_executions_initiated = defaultdict(set)
89
+ self._node_executions_queued = defaultdict(list)
87
90
 
88
- for node, dependencies in (dependencies_invoked or {}).items():
89
- self._dependencies_invoked[node].update(dependencies)
91
+ for execution_id, dependencies in (dependencies_invoked or {}).items():
92
+ self._dependencies_invoked[UUID(execution_id)] = {get_class_by_qualname(dep) for dep in dependencies}
90
93
 
91
- for node, execution_ids in (node_execution_ids or {}).items():
94
+ for node, execution_ids in (node_executions_fulfilled or {}).items():
92
95
  node_class = get_class_by_qualname(node)
93
- self._node_execution_ids[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
96
+ self._node_executions_fulfilled[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
94
97
 
95
98
  for node, execution_ids in (node_executions_initiated or {}).items():
96
99
  node_class = get_class_by_qualname(node)
97
100
  self._node_executions_initiated[node_class].update({UUID(execution_id) for execution_id in execution_ids})
98
101
 
99
- @property
100
- def dependencies_invoked(self) -> Dict[str, Set[str]]:
101
- return self._dependencies_invoked
102
+ for node, execution_ids in (node_executions_queued or {}).items():
103
+ node_class = get_class_by_qualname(node)
104
+ self._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
102
105
 
103
- def is_node_initiated(self, node: Type["BaseNode"]) -> bool:
104
- return node in self._node_executions_initiated and len(self._node_executions_initiated[node]) > 0
106
+ def _invoke_dependency(
107
+ self,
108
+ execution_id: UUID,
109
+ node: Type["BaseNode"],
110
+ dependency: Type["BaseNode"],
111
+ dependencies: Set["Type[BaseNode]"],
112
+ ) -> None:
113
+ self._dependencies_invoked[execution_id].add(dependency)
114
+ if all(dep in self._dependencies_invoked[execution_id] for dep in dependencies):
115
+ self._node_executions_queued[node].remove(execution_id)
116
+
117
+ def queue_node_execution(
118
+ self, node: Type["BaseNode"], dependencies: Set["Type[BaseNode]"], invoked_by: Optional[Edge] = None
119
+ ) -> UUID:
120
+ execution_id = uuid4()
121
+ if not invoked_by:
122
+ return execution_id
123
+
124
+ source_node = invoked_by.from_port.node_class
125
+ for queued_node_execution_id in self._node_executions_queued[node]:
126
+ if source_node not in self._dependencies_invoked[queued_node_execution_id]:
127
+ self._invoke_dependency(queued_node_execution_id, node, source_node, dependencies)
128
+ return queued_node_execution_id
129
+
130
+ self._node_executions_queued[node].append(execution_id)
131
+ self._invoke_dependency(execution_id, node, source_node, dependencies)
132
+ return execution_id
133
+
134
+ def is_node_execution_initiated(self, node: Type["BaseNode"], execution_id: UUID) -> bool:
135
+ return execution_id in self._node_executions_initiated[node]
105
136
 
106
137
  def initiate_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
107
138
  self._node_executions_initiated[node].add(execution_id)
108
139
 
109
140
  def fulfill_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
110
- self._node_executions_initiated[node].remove(execution_id)
111
- self._node_execution_ids[node].push(execution_id)
141
+ self._node_executions_fulfilled[node].push(execution_id)
112
142
 
113
143
  def get_execution_count(self, node: Type["BaseNode"]) -> int:
114
- return self._node_execution_ids[node].size()
144
+ return self._node_executions_fulfilled[node].size()
115
145
 
116
146
  def dump(self) -> Dict[str, Any]:
117
147
  return {
118
148
  "dependencies_invoked": {
119
- node: list(dependencies) for node, dependencies in self._dependencies_invoked.items()
149
+ str(execution_id): [str(dep) for dep in dependencies]
150
+ for execution_id, dependencies in self._dependencies_invoked.items()
120
151
  },
121
152
  "node_executions_initiated": {
122
153
  str(node): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
123
154
  },
124
- "node_execution_ids": {
125
- str(node): execution_ids.dump() for node, execution_ids in self._node_execution_ids.items()
155
+ "node_executions_fulfilled": {
156
+ str(node): execution_ids.dump() for node, execution_ids in self._node_executions_fulfilled.items()
157
+ },
158
+ "node_executions_queued": {
159
+ str(node): execution_ids for node, execution_ids in self._node_executions_queued.items()
126
160
  },
127
161
  }
128
162
 
@@ -1,14 +1,24 @@
1
1
  from functools import cached_property
2
- from typing import Optional
2
+ from queue import Queue
3
+ from typing import TYPE_CHECKING, Optional
3
4
 
4
5
  from vellum import Vellum
5
-
6
+ from vellum.workflows.events.types import ParentContext
6
7
  from vellum.workflows.vellum_client import create_vellum_client
7
8
 
9
+ if TYPE_CHECKING:
10
+ from vellum.workflows.events.workflow import WorkflowEvent
11
+
8
12
 
9
13
  class WorkflowContext:
10
- def __init__(self, _vellum_client: Optional[Vellum] = None):
14
+ def __init__(
15
+ self,
16
+ _vellum_client: Optional[Vellum] = None,
17
+ _parent_context: Optional[ParentContext] = None,
18
+ ):
11
19
  self._vellum_client = _vellum_client
20
+ self._parent_context = _parent_context
21
+ self._event_queue: Optional[Queue["WorkflowEvent"]] = None
12
22
 
13
23
  @cached_property
14
24
  def vellum_client(self) -> Vellum:
@@ -16,3 +26,16 @@ class WorkflowContext:
16
26
  return self._vellum_client
17
27
 
18
28
  return create_vellum_client()
29
+
30
+ @cached_property
31
+ def parent_context(self) -> Optional[ParentContext]:
32
+ if self._parent_context:
33
+ return self._parent_context
34
+ return None
35
+
36
+ def _emit_subworkflow_event(self, event: "WorkflowEvent") -> None:
37
+ if self._event_queue:
38
+ self._event_queue.put(event)
39
+
40
+ def _register_event_queue(self, event_queue: Queue["WorkflowEvent"]) -> None:
41
+ self._event_queue = event_queue
@@ -83,7 +83,7 @@ EntityInputsInterface = Dict[
83
83
  class MergeBehavior(Enum):
84
84
  AWAIT_ALL = "AWAIT_ALL"
85
85
  AWAIT_ANY = "AWAIT_ANY"
86
-
86
+ AWAIT_ATTRIBUTES = "AWAIT_ATTRIBUTES"
87
87
 
88
88
  class ConditionType(Enum):
89
89
  IF = "IF"