vellum-ai 0.10.7__py3-none-any.whl → 0.10.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/client/types/logical_operator.py +2 -0
  3. vellum/workflows/descriptors/utils.py +27 -0
  4. vellum/workflows/events/__init__.py +0 -2
  5. vellum/workflows/events/tests/test_event.py +2 -1
  6. vellum/workflows/events/types.py +36 -30
  7. vellum/workflows/events/workflow.py +14 -7
  8. vellum/workflows/nodes/bases/base.py +100 -38
  9. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +1 -0
  10. vellum/workflows/nodes/core/templating_node/node.py +5 -0
  11. vellum/workflows/nodes/core/try_node/node.py +22 -4
  12. vellum/workflows/nodes/core/try_node/tests/test_node.py +15 -0
  13. vellum/workflows/nodes/displayable/api_node/node.py +1 -1
  14. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +1 -2
  15. vellum/workflows/nodes/displayable/code_execution_node/node.py +1 -2
  16. vellum/workflows/nodes/displayable/code_execution_node/utils.py +13 -2
  17. vellum/workflows/nodes/displayable/inline_prompt_node/node.py +10 -3
  18. vellum/workflows/nodes/displayable/prompt_deployment_node/node.py +6 -1
  19. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +1 -2
  20. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +1 -2
  21. vellum/workflows/runner/runner.py +141 -32
  22. vellum/workflows/state/base.py +55 -21
  23. vellum/workflows/state/context.py +26 -3
  24. vellum/workflows/types/__init__.py +5 -0
  25. vellum/workflows/types/core.py +1 -1
  26. vellum/workflows/workflows/base.py +51 -17
  27. vellum/workflows/workflows/event_filters.py +61 -0
  28. {vellum_ai-0.10.7.dist-info → vellum_ai-0.10.9.dist-info}/METADATA +1 -1
  29. {vellum_ai-0.10.7.dist-info → vellum_ai-0.10.9.dist-info}/RECORD +40 -38
  30. vellum_cli/__init__.py +23 -4
  31. vellum_cli/pull.py +28 -13
  32. vellum_cli/tests/test_pull.py +45 -2
  33. vellum_ee/workflows/display/nodes/base_node_display.py +1 -1
  34. vellum_ee/workflows/display/nodes/vellum/__init__.py +6 -4
  35. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +17 -2
  36. vellum_ee/workflows/display/nodes/vellum/error_node.py +49 -0
  37. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +203 -0
  38. vellum/workflows/events/utils.py +0 -5
  39. {vellum_ai-0.10.7.dist-info → vellum_ai-0.10.9.dist-info}/LICENSE +0 -0
  40. {vellum_ai-0.10.7.dist-info → vellum_ai-0.10.9.dist-info}/WHEEL +0 -0
  41. {vellum_ai-0.10.7.dist-info → vellum_ai-0.10.9.dist-info}/entry_points.txt +0 -0
@@ -67,9 +67,8 @@ class CodeExecutionNode(BaseNode[StateType], Generic[StateType, _OutputType], me
67
67
 
68
68
  filepath: str - The path to the script to execute.
69
69
  code_inputs: EntityInputsInterface - The inputs for the custom script.
70
- output_type: VellumVariableType = "STRING" - The type of the output from the custom script.
71
70
  runtime: CodeExecutionRuntime = "PYTHON_3_12" - The runtime to use for the custom script.
72
- packages: Optional[Sequence[CodeExecutionPackageRequest]] = None - The packages to use for the custom script.
71
+ packages: Optional[Sequence[CodeExecutionPackage]] = None - The packages to use for the custom script.
73
72
  request_options: Optional[RequestOptions] = None - The request options to use for the custom script.
74
73
  """
75
74
 
@@ -2,9 +2,20 @@ import os
2
2
  from typing import Union
3
3
 
4
4
 
5
+ def get_project_root() -> str:
6
+ current_dir = os.getcwd()
7
+ while current_dir != '/':
8
+ if ".git" in os.listdir(current_dir):
9
+ return current_dir
10
+ current_dir = os.path.dirname(current_dir)
11
+ raise FileNotFoundError("Project root not found.")
12
+
5
13
  def read_file_from_path(filepath: str) -> Union[str, None]:
6
- if not os.path.exists(filepath):
14
+ project_root = get_project_root()
15
+ relative_filepath = os.path.join(project_root, filepath)
16
+
17
+ if not os.path.exists(relative_filepath):
7
18
  return None
8
19
 
9
- with open(filepath) as file:
20
+ with open(relative_filepath, 'r') as file:
10
21
  return file.read()
@@ -9,16 +9,23 @@ from vellum.workflows.types.generics import StateType
9
9
 
10
10
  class InlinePromptNode(BaseInlinePromptNode[StateType]):
11
11
  """
12
- Used to execute an Inline Prompt and surface a string output for convenience.
12
+ Used to execute a Prompt defined inline.
13
13
 
14
14
  prompt_inputs: EntityInputsInterface - The inputs for the Prompt
15
15
  ml_model: str - Either the ML Model's UUID or its name.
16
- blocks: List[PromptBlockRequest] - The blocks that make up the Prompt
16
+ blocks: List[PromptBlock] - The blocks that make up the Prompt
17
+ functions: Optional[List[FunctionDefinition]] - The functions to include in the Prompt
17
18
  parameters: PromptParameters - The parameters for the Prompt
18
- expand_meta: Optional[AdHocExpandMetaRequest] - Set of expandable execution fields to include in the response
19
+ expand_meta: Optional[AdHocExpandMeta] - Expandable execution fields to include in the response
20
+ request_options: Optional[RequestOptions] - The request options to use for the Prompt Execution
19
21
  """
20
22
 
21
23
  class Outputs(BaseInlinePromptNode.Outputs):
24
+ """
25
+ The outputs of the InlinePromptNode.
26
+
27
+ text: str - The result of the Prompt Execution
28
+ """
22
29
  text: str
23
30
 
24
31
  def run(self) -> Iterator[BaseOutput]:
@@ -14,7 +14,7 @@ class PromptDeploymentNode(BasePromptDeploymentNode[StateType]):
14
14
  prompt_inputs: EntityInputsInterface - The inputs for the Prompt
15
15
  deployment: Union[UUID, str] - Either the Prompt Deployment's UUID or its name.
16
16
  release_tag: str - The release tag to use for the Prompt Execution
17
- external_id: Optional[str] - The external ID to use for the Prompt Execution
17
+ external_id: Optional[str] - Optionally include a unique identifier for tracking purposes. Must be unique within a given Prompt Deployment.
18
18
  expand_meta: Optional[PromptDeploymentExpandMetaRequest] - Expandable execution fields to include in the response
19
19
  raw_overrides: Optional[RawPromptExecutionOverridesRequest] - The raw overrides to use for the Prompt Execution
20
20
  expand_raw: Optional[Sequence[str]] - Expandable raw fields to include in the response
@@ -23,6 +23,11 @@ class PromptDeploymentNode(BasePromptDeploymentNode[StateType]):
23
23
  """
24
24
 
25
25
  class Outputs(BasePromptDeploymentNode.Outputs):
26
+ """
27
+ The outputs of the PromptDeploymentNode.
28
+
29
+ text: str - The result of the Prompt Execution
30
+ """
26
31
  text: str
27
32
 
28
33
  def run(self) -> Iterator[BaseOutput]:
@@ -12,7 +12,6 @@ from vellum import (
12
12
  WorkflowRequestStringInputRequest,
13
13
  )
14
14
  from vellum.core import RequestOptions
15
-
16
15
  from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
17
16
  from vellum.workflows.errors import VellumErrorCode
18
17
  from vellum.workflows.exceptions import NodeException
@@ -28,7 +27,7 @@ class SubworkflowDeploymentNode(BaseSubworkflowNode[StateType], Generic[StateTyp
28
27
  subworkflow_inputs: EntityInputsInterface - The inputs for the Subworkflow
29
28
  deployment: Union[UUID, str] - Either the Workflow Deployment's UUID or its name.
30
29
  release_tag: str = LATEST_RELEASE_TAG - The release tag to use for the Workflow Execution
31
- external_id: Optional[str] = OMIT - The external ID to use for the Workflow Execution
30
+ external_id: Optional[str] = OMIT - Optionally include a unique identifier for tracking purposes. Must be unique within a given Workflow Deployment.
32
31
  expand_meta: Optional[WorkflowExpandMetaRequest] = OMIT - Expandable execution fields to include in the respownse
33
32
  metadata: Optional[Dict[str, Optional[Any]]] = OMIT - The metadata to use for the Workflow Execution
34
33
  request_options: Optional[RequestOptions] = None - The request options to use for the Workflow Execution
@@ -8,7 +8,6 @@ from vellum import (
8
8
  PromptOutput,
9
9
  StringVellumValue,
10
10
  )
11
-
12
11
  from vellum.workflows.constants import OMIT
13
12
  from vellum.workflows.inputs import BaseInputs
14
13
  from vellum.workflows.nodes import PromptDeploymentNode
@@ -65,7 +64,7 @@ def test_text_prompt_deployment_node__basic(vellum_client):
65
64
  assert text_output.name == "text"
66
65
  assert text_output.value == "Hello, world!"
67
66
 
68
- # AND we should have made the expected call to Vellum search
67
+ # AND we should have made the expected call to stream the prompt execution
69
68
  vellum_client.execute_prompt_stream.assert_called_once_with(
70
69
  expand_meta=OMIT,
71
70
  expand_raw=OMIT,
@@ -3,7 +3,7 @@ from copy import deepcopy
3
3
  import logging
4
4
  from queue import Empty, Queue
5
5
  from threading import Event as ThreadingEvent, Thread
6
- from uuid import UUID, uuid4
6
+ from uuid import UUID
7
7
  from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Type, Union
8
8
 
9
9
  from vellum.workflows.constants import UNDEF
@@ -28,8 +28,7 @@ from vellum.workflows.events.node import (
28
28
  NodeExecutionRejectedBody,
29
29
  NodeExecutionStreamingBody,
30
30
  )
31
- from vellum.workflows.events.types import BaseEvent
32
- from vellum.workflows.events.utils import is_terminal_event
31
+ from vellum.workflows.events.types import BaseEvent, ParentContext, WorkflowParentContext
33
32
  from vellum.workflows.events.workflow import (
34
33
  WorkflowExecutionFulfilledBody,
35
34
  WorkflowExecutionInitiatedBody,
@@ -71,9 +70,12 @@ class WorkflowRunner(Generic[StateType]):
71
70
  entrypoint_nodes: Optional[RunFromNodeArg] = None,
72
71
  external_inputs: Optional[ExternalInputsArg] = None,
73
72
  cancel_signal: Optional[ThreadingEvent] = None,
73
+ parent_context: Optional[ParentContext] = None,
74
74
  ):
75
75
  if state and external_inputs:
76
- raise ValueError("Can only run a Workflow providing one of state or external inputs, not both")
76
+ raise ValueError(
77
+ "Can only run a Workflow providing one of state or external inputs, not both"
78
+ )
77
79
 
78
80
  self.workflow = workflow
79
81
  if entrypoint_nodes:
@@ -99,7 +101,9 @@ class WorkflowRunner(Generic[StateType]):
99
101
  if issubclass(ei.inputs_class.__parent_class__, BaseNode)
100
102
  ]
101
103
  else:
102
- normalized_inputs = deepcopy(inputs) if inputs else self.workflow.get_default_inputs()
104
+ normalized_inputs = (
105
+ deepcopy(inputs) if inputs else self.workflow.get_default_inputs()
106
+ )
103
107
  if state:
104
108
  self._initial_state = deepcopy(state)
105
109
  self._initial_state.meta.workflow_inputs = normalized_inputs
@@ -115,12 +119,14 @@ class WorkflowRunner(Generic[StateType]):
115
119
 
116
120
  self._active_nodes_by_execution_id: Dict[UUID, BaseNode[StateType]] = {}
117
121
  self._cancel_signal = cancel_signal
122
+ self._parent_context = parent_context
118
123
 
119
124
  setattr(
120
125
  self._initial_state,
121
126
  "__snapshot_callback__",
122
127
  lambda s: self._snapshot_state(s),
123
128
  )
129
+ self.workflow.context._register_event_queue(self._workflow_event_queue)
124
130
 
125
131
  def _snapshot_state(self, state: StateType) -> StateType:
126
132
  self.workflow._store.append_state_snapshot(state)
@@ -143,6 +149,12 @@ class WorkflowRunner(Generic[StateType]):
143
149
  node_definition=node.__class__,
144
150
  inputs=node._inputs,
145
151
  ),
152
+ parent=WorkflowParentContext(
153
+ span_id=span_id,
154
+ workflow_definition=self.workflow.__class__,
155
+ parent=self._parent_context,
156
+ type="WORKFLOW",
157
+ ),
146
158
  ),
147
159
  )
148
160
  )
@@ -178,7 +190,11 @@ class WorkflowRunner(Generic[StateType]):
178
190
  instance=None,
179
191
  outputs_class=node.Outputs,
180
192
  )
181
- node.state.meta.node_outputs[output_descriptor] = streaming_output_queues[output.name]
193
+ node.state.meta.node_outputs[output_descriptor] = (
194
+ streaming_output_queues[output.name]
195
+ )
196
+ initiated_output: BaseOutput = BaseOutput(name=output.name)
197
+ initiated_ports = initiated_output > ports
182
198
  self._work_item_event_queue.put(
183
199
  WorkItemEvent(
184
200
  node=node,
@@ -187,8 +203,13 @@ class WorkflowRunner(Generic[StateType]):
187
203
  span_id=span_id,
188
204
  body=NodeExecutionStreamingBody(
189
205
  node_definition=node.__class__,
190
- output=BaseOutput(name=output.name),
191
- invoked_ports=invoked_ports,
206
+ output=initiated_output,
207
+ invoked_ports=initiated_ports,
208
+ ),
209
+ parent=WorkflowParentContext(
210
+ span_id=span_id,
211
+ workflow_definition=self.workflow.__class__,
212
+ parent=self._parent_context,
192
213
  ),
193
214
  ),
194
215
  )
@@ -214,6 +235,11 @@ class WorkflowRunner(Generic[StateType]):
214
235
  output=output,
215
236
  invoked_ports=invoked_ports,
216
237
  ),
238
+ parent=WorkflowParentContext(
239
+ span_id=span_id,
240
+ workflow_definition=self.workflow.__class__,
241
+ parent=self._parent_context,
242
+ ),
217
243
  ),
218
244
  )
219
245
  )
@@ -233,6 +259,11 @@ class WorkflowRunner(Generic[StateType]):
233
259
  output=output,
234
260
  invoked_ports=invoked_ports,
235
261
  ),
262
+ parent=WorkflowParentContext(
263
+ span_id=span_id,
264
+ workflow_definition=self.workflow.__class__,
265
+ parent=self._parent_context,
266
+ ),
236
267
  ),
237
268
  )
238
269
  )
@@ -246,7 +277,9 @@ class WorkflowRunner(Generic[StateType]):
246
277
  node.state.meta.node_outputs[descriptor] = output_value
247
278
 
248
279
  invoked_ports = ports(outputs, node.state)
249
- node.state.meta.node_execution_cache.fulfill_node_execution(node.__class__, span_id)
280
+ node.state.meta.node_execution_cache.fulfill_node_execution(
281
+ node.__class__, span_id
282
+ )
250
283
 
251
284
  self._work_item_event_queue.put(
252
285
  WorkItemEvent(
@@ -259,6 +292,11 @@ class WorkflowRunner(Generic[StateType]):
259
292
  outputs=outputs,
260
293
  invoked_ports=invoked_ports,
261
294
  ),
295
+ parent=WorkflowParentContext(
296
+ span_id=span_id,
297
+ workflow_definition=self.workflow.__class__,
298
+ parent=self._parent_context,
299
+ ),
262
300
  ),
263
301
  )
264
302
  )
@@ -273,11 +311,18 @@ class WorkflowRunner(Generic[StateType]):
273
311
  node_definition=node.__class__,
274
312
  error=e.error,
275
313
  ),
314
+ parent=WorkflowParentContext(
315
+ span_id=span_id,
316
+ workflow_definition=self.workflow.__class__,
317
+ parent=self._parent_context,
318
+ ),
276
319
  ),
277
320
  )
278
321
  )
279
322
  except Exception as e:
280
- logger.exception(f"An unexpected error occurred while running node {node.__class__.__name__}")
323
+ logger.exception(
324
+ f"An unexpected error occurred while running node {node.__class__.__name__}"
325
+ )
281
326
 
282
327
  self._work_item_event_queue.put(
283
328
  WorkItemEvent(
@@ -292,13 +337,20 @@ class WorkflowRunner(Generic[StateType]):
292
337
  code=VellumErrorCode.INTERNAL_ERROR,
293
338
  ),
294
339
  ),
340
+ parent=WorkflowParentContext(
341
+ span_id=span_id,
342
+ workflow_definition=self.workflow.__class__,
343
+ parent=self._parent_context,
344
+ ),
295
345
  ),
296
346
  )
297
347
  )
298
348
 
299
349
  logger.debug(f"Finished running node: {node.__class__.__name__}")
300
350
 
301
- def _handle_invoked_ports(self, state: StateType, ports: Optional[Iterable[Port]]) -> None:
351
+ def _handle_invoked_ports(
352
+ self, state: StateType, ports: Optional[Iterable[Port]]
353
+ ) -> None:
302
354
  if not ports:
303
355
  return
304
356
 
@@ -313,7 +365,10 @@ class WorkflowRunner(Generic[StateType]):
313
365
  self._run_node_if_ready(next_state, edge.to_node, edge)
314
366
 
315
367
  def _run_node_if_ready(
316
- self, state: StateType, node_class: Type[BaseNode], invoked_by: Optional[Edge] = None
368
+ self,
369
+ state: StateType,
370
+ node_class: Type[BaseNode],
371
+ invoked_by: Optional[Edge] = None,
317
372
  ) -> None:
318
373
  with state.__lock__:
319
374
  for descriptor in node_class.ExternalInputs:
@@ -325,18 +380,27 @@ class WorkflowRunner(Generic[StateType]):
325
380
  return
326
381
 
327
382
  all_deps = self._dependencies[node_class]
328
- if not node_class.Trigger.should_initiate(state, all_deps, invoked_by):
383
+ node_span_id = state.meta.node_execution_cache.queue_node_execution(
384
+ node_class, all_deps, invoked_by
385
+ )
386
+ if not node_class.Trigger.should_initiate(state, all_deps, node_span_id):
329
387
  return
330
388
 
331
389
  node = node_class(state=state, context=self.workflow.context)
332
- node_span_id = uuid4()
333
- state.meta.node_execution_cache.initiate_node_execution(node_class, node_span_id)
390
+ state.meta.node_execution_cache.initiate_node_execution(
391
+ node_class, node_span_id
392
+ )
334
393
  self._active_nodes_by_execution_id[node_span_id] = node
335
394
 
336
- worker_thread = Thread(target=self._run_work_item, kwargs={"node": node, "span_id": node_span_id})
395
+ worker_thread = Thread(
396
+ target=self._run_work_item,
397
+ kwargs={"node": node, "span_id": node_span_id},
398
+ )
337
399
  worker_thread.start()
338
400
 
339
- def _handle_work_item_event(self, work_item_event: WorkItemEvent[StateType]) -> Optional[VellumError]:
401
+ def _handle_work_item_event(
402
+ self, work_item_event: WorkItemEvent[StateType]
403
+ ) -> Optional[VellumError]:
340
404
  node = work_item_event.node
341
405
  event = work_item_event.event
342
406
 
@@ -352,7 +416,10 @@ class WorkflowRunner(Generic[StateType]):
352
416
  node_output_descriptor = workflow_output_descriptor.instance
353
417
  if not isinstance(node_output_descriptor, OutputReference):
354
418
  continue
355
- if node_output_descriptor.outputs_class != event.node_definition.Outputs:
419
+ if (
420
+ node_output_descriptor.outputs_class
421
+ != event.node_definition.Outputs
422
+ ):
356
423
  continue
357
424
  if node_output_descriptor.name != event.output.name:
358
425
  continue
@@ -387,9 +454,12 @@ class WorkflowRunner(Generic[StateType]):
387
454
  workflow_definition=self.workflow.__class__,
388
455
  inputs=self._initial_state.meta.workflow_inputs,
389
456
  ),
457
+ parent=self._parent_context,
390
458
  )
391
459
 
392
- def _stream_workflow_event(self, output: BaseOutput) -> WorkflowExecutionStreamingEvent:
460
+ def _stream_workflow_event(
461
+ self, output: BaseOutput
462
+ ) -> WorkflowExecutionStreamingEvent:
393
463
  return WorkflowExecutionStreamingEvent(
394
464
  trace_id=self._initial_state.meta.trace_id,
395
465
  span_id=self._initial_state.meta.span_id,
@@ -397,9 +467,12 @@ class WorkflowRunner(Generic[StateType]):
397
467
  workflow_definition=self.workflow.__class__,
398
468
  output=output,
399
469
  ),
470
+ parent=self._parent_context,
400
471
  )
401
472
 
402
- def _fulfill_workflow_event(self, outputs: OutputsType) -> WorkflowExecutionFulfilledEvent:
473
+ def _fulfill_workflow_event(
474
+ self, outputs: OutputsType
475
+ ) -> WorkflowExecutionFulfilledEvent:
403
476
  return WorkflowExecutionFulfilledEvent(
404
477
  trace_id=self._initial_state.meta.trace_id,
405
478
  span_id=self._initial_state.meta.span_id,
@@ -407,9 +480,12 @@ class WorkflowRunner(Generic[StateType]):
407
480
  workflow_definition=self.workflow.__class__,
408
481
  outputs=outputs,
409
482
  ),
483
+ parent=self._parent_context,
410
484
  )
411
485
 
412
- def _reject_workflow_event(self, error: VellumError) -> WorkflowExecutionRejectedEvent:
486
+ def _reject_workflow_event(
487
+ self, error: VellumError
488
+ ) -> WorkflowExecutionRejectedEvent:
413
489
  return WorkflowExecutionRejectedEvent(
414
490
  trace_id=self._initial_state.meta.trace_id,
415
491
  span_id=self._initial_state.meta.span_id,
@@ -417,6 +493,7 @@ class WorkflowRunner(Generic[StateType]):
417
493
  workflow_definition=self.workflow.__class__,
418
494
  error=error,
419
495
  ),
496
+ parent=self._parent_context,
420
497
  )
421
498
 
422
499
  def _resume_workflow_event(self) -> WorkflowExecutionResumedEvent:
@@ -428,7 +505,9 @@ class WorkflowRunner(Generic[StateType]):
428
505
  ),
429
506
  )
430
507
 
431
- def _pause_workflow_event(self, external_inputs: Iterable[ExternalInputReference]) -> WorkflowExecutionPausedEvent:
508
+ def _pause_workflow_event(
509
+ self, external_inputs: Iterable[ExternalInputReference]
510
+ ) -> WorkflowExecutionPausedEvent:
432
511
  return WorkflowExecutionPausedEvent(
433
512
  trace_id=self._initial_state.meta.trace_id,
434
513
  span_id=self._initial_state.meta.span_id,
@@ -436,6 +515,7 @@ class WorkflowRunner(Generic[StateType]):
436
515
  workflow_definition=self.workflow.__class__,
437
516
  external_inputs=external_inputs,
438
517
  ),
518
+ parent=self._parent_context,
439
519
  )
440
520
 
441
521
  def _stream(self) -> None:
@@ -444,7 +524,10 @@ class WorkflowRunner(Generic[StateType]):
444
524
  if not self._entrypoints:
445
525
  self._workflow_event_queue.put(
446
526
  self._reject_workflow_event(
447
- VellumError(message="No entrypoints defined", code=VellumErrorCode.INVALID_WORKFLOW)
527
+ VellumError(
528
+ message="No entrypoints defined",
529
+ code=VellumErrorCode.INVALID_WORKFLOW,
530
+ )
448
531
  )
449
532
  )
450
533
  return
@@ -463,7 +546,9 @@ class WorkflowRunner(Generic[StateType]):
463
546
  logger.exception(err_message)
464
547
  self._workflow_event_queue.put(
465
548
  self._reject_workflow_event(
466
- VellumError(code=VellumErrorCode.INTERNAL_ERROR, message=err_message),
549
+ VellumError(
550
+ code=VellumErrorCode.INTERNAL_ERROR, message=err_message
551
+ ),
467
552
  )
468
553
  )
469
554
  return
@@ -519,7 +604,11 @@ class WorkflowRunner(Generic[StateType]):
519
604
  if isinstance(value, BaseDescriptor):
520
605
  setattr(fulfilled_outputs, descriptor.name, value.resolve(final_state))
521
606
  elif isinstance(descriptor.instance, BaseDescriptor):
522
- setattr(fulfilled_outputs, descriptor.name, descriptor.instance.resolve(final_state))
607
+ setattr(
608
+ fulfilled_outputs,
609
+ descriptor.name,
610
+ descriptor.instance.resolve(final_state),
611
+ )
523
612
 
524
613
  self._workflow_event_queue.put(self._fulfill_workflow_event(fulfilled_outputs))
525
614
 
@@ -544,24 +633,41 @@ class WorkflowRunner(Generic[StateType]):
544
633
  self._cancel_signal.wait()
545
634
  self._workflow_event_queue.put(
546
635
  self._reject_workflow_event(
547
- VellumError(code=VellumErrorCode.WORKFLOW_CANCELLED, message="Workflow run cancelled")
636
+ VellumError(
637
+ code=VellumErrorCode.WORKFLOW_CANCELLED,
638
+ message="Workflow run cancelled",
639
+ )
548
640
  )
549
641
  )
550
642
 
643
+ def _is_terminal_event(self, event: WorkflowEvent) -> bool:
644
+ if (
645
+ event.name == "workflow.execution.fulfilled"
646
+ or event.name == "workflow.execution.rejected"
647
+ or event.name == "workflow.execution.paused"
648
+ ):
649
+ return event.workflow_definition == self.workflow.__class__
650
+ return False
651
+
551
652
  def stream(self) -> WorkflowEventStream:
552
653
  background_thread = Thread(
553
- target=self._run_background_thread, name=f"{self.workflow.__class__.__name__}.background_thread"
654
+ target=self._run_background_thread,
655
+ name=f"{self.workflow.__class__.__name__}.background_thread",
554
656
  )
555
657
  background_thread.start()
556
658
 
557
659
  if self._cancel_signal:
558
660
  cancel_thread = Thread(
559
- target=self._run_cancel_thread, name=f"{self.workflow.__class__.__name__}.cancel_thread"
661
+ target=self._run_cancel_thread,
662
+ name=f"{self.workflow.__class__.__name__}.cancel_thread",
560
663
  )
561
664
  cancel_thread.start()
562
665
 
563
666
  event: WorkflowEvent
564
- if self._initial_state.meta.is_terminated or self._initial_state.meta.is_terminated is None:
667
+ if (
668
+ self._initial_state.meta.is_terminated
669
+ or self._initial_state.meta.is_terminated is None
670
+ ):
565
671
  event = self._initiate_workflow_event()
566
672
  else:
567
673
  event = self._resume_workflow_event()
@@ -570,7 +676,10 @@ class WorkflowRunner(Generic[StateType]):
570
676
  self._initial_state.meta.is_terminated = False
571
677
 
572
678
  # The extra level of indirection prevents the runner from waiting on the caller to consume the event stream
573
- stream_thread = Thread(target=self._stream, name=f"{self.workflow.__class__.__name__}.stream_thread")
679
+ stream_thread = Thread(
680
+ target=self._stream,
681
+ name=f"{self.workflow.__class__.__name__}.stream_thread",
682
+ )
574
683
  stream_thread.start()
575
684
 
576
685
  while stream_thread.is_alive():
@@ -581,7 +690,7 @@ class WorkflowRunner(Generic[StateType]):
581
690
 
582
691
  yield self._emit_event(event)
583
692
 
584
- if is_terminal_event(event):
693
+ if self._is_terminal_event(event):
585
694
  break
586
695
 
587
696
  try:
@@ -590,7 +699,7 @@ class WorkflowRunner(Generic[StateType]):
590
699
  except Empty:
591
700
  pass
592
701
 
593
- if not is_terminal_event(event):
702
+ if not self._is_terminal_event(event):
594
703
  yield self._reject_workflow_event(
595
704
  VellumError(
596
705
  code=VellumErrorCode.INTERNAL_ERROR,
@@ -5,15 +5,15 @@ from datetime import datetime
5
5
  from queue import Queue
6
6
  from threading import Lock
7
7
  from uuid import UUID, uuid4
8
- from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional, Sequence, Set, Tuple, Type, cast
8
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, cast
9
9
  from typing_extensions import dataclass_transform
10
10
 
11
11
  from pydantic import GetCoreSchemaHandler, field_serializer
12
12
  from pydantic_core import core_schema
13
13
 
14
14
  from vellum.core.pydantic_utilities import UniversalBaseModel
15
-
16
15
  from vellum.workflows.constants import UNDEF
16
+ from vellum.workflows.edges.edge import Edge
17
17
  from vellum.workflows.inputs.base import BaseInputs
18
18
  from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
19
19
  from vellum.workflows.types.generics import StateType
@@ -71,58 +71,92 @@ def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> An
71
71
 
72
72
 
73
73
  class NodeExecutionCache:
74
- _node_execution_ids: Dict[Type["BaseNode"], Stack[UUID]]
74
+ _node_executions_fulfilled: Dict[Type["BaseNode"], Stack[UUID]]
75
75
  _node_executions_initiated: Dict[Type["BaseNode"], Set[UUID]]
76
- _dependencies_invoked: Dict[str, Set[str]]
76
+ _node_executions_queued: Dict[Type["BaseNode"], List[UUID]]
77
+ _dependencies_invoked: Dict[UUID, Set[Type["BaseNode"]]]
77
78
 
78
79
  def __init__(
79
80
  self,
80
81
  dependencies_invoked: Optional[Dict[str, Sequence[str]]] = None,
81
- node_execution_ids: Optional[Dict[str, Sequence[str]]] = None,
82
+ node_executions_fulfilled: Optional[Dict[str, Sequence[str]]] = None,
82
83
  node_executions_initiated: Optional[Dict[str, Sequence[str]]] = None,
84
+ node_executions_queued: Optional[Dict[str, Sequence[str]]] = None,
83
85
  ) -> None:
84
86
  self._dependencies_invoked = defaultdict(set)
85
- self._node_execution_ids = defaultdict(Stack[UUID])
87
+ self._node_executions_fulfilled = defaultdict(Stack[UUID])
86
88
  self._node_executions_initiated = defaultdict(set)
89
+ self._node_executions_queued = defaultdict(list)
87
90
 
88
- for node, dependencies in (dependencies_invoked or {}).items():
89
- self._dependencies_invoked[node].update(dependencies)
91
+ for execution_id, dependencies in (dependencies_invoked or {}).items():
92
+ self._dependencies_invoked[UUID(execution_id)] = {get_class_by_qualname(dep) for dep in dependencies}
90
93
 
91
- for node, execution_ids in (node_execution_ids or {}).items():
94
+ for node, execution_ids in (node_executions_fulfilled or {}).items():
92
95
  node_class = get_class_by_qualname(node)
93
- self._node_execution_ids[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
96
+ self._node_executions_fulfilled[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
94
97
 
95
98
  for node, execution_ids in (node_executions_initiated or {}).items():
96
99
  node_class = get_class_by_qualname(node)
97
100
  self._node_executions_initiated[node_class].update({UUID(execution_id) for execution_id in execution_ids})
98
101
 
99
- @property
100
- def dependencies_invoked(self) -> Dict[str, Set[str]]:
101
- return self._dependencies_invoked
102
+ for node, execution_ids in (node_executions_queued or {}).items():
103
+ node_class = get_class_by_qualname(node)
104
+ self._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
102
105
 
103
- def is_node_initiated(self, node: Type["BaseNode"]) -> bool:
104
- return node in self._node_executions_initiated and len(self._node_executions_initiated[node]) > 0
106
+ def _invoke_dependency(
107
+ self,
108
+ execution_id: UUID,
109
+ node: Type["BaseNode"],
110
+ dependency: Type["BaseNode"],
111
+ dependencies: Set["Type[BaseNode]"],
112
+ ) -> None:
113
+ self._dependencies_invoked[execution_id].add(dependency)
114
+ if all(dep in self._dependencies_invoked[execution_id] for dep in dependencies):
115
+ self._node_executions_queued[node].remove(execution_id)
116
+
117
+ def queue_node_execution(
118
+ self, node: Type["BaseNode"], dependencies: Set["Type[BaseNode]"], invoked_by: Optional[Edge] = None
119
+ ) -> UUID:
120
+ execution_id = uuid4()
121
+ if not invoked_by:
122
+ return execution_id
123
+
124
+ source_node = invoked_by.from_port.node_class
125
+ for queued_node_execution_id in self._node_executions_queued[node]:
126
+ if source_node not in self._dependencies_invoked[queued_node_execution_id]:
127
+ self._invoke_dependency(queued_node_execution_id, node, source_node, dependencies)
128
+ return queued_node_execution_id
129
+
130
+ self._node_executions_queued[node].append(execution_id)
131
+ self._invoke_dependency(execution_id, node, source_node, dependencies)
132
+ return execution_id
133
+
134
+ def is_node_execution_initiated(self, node: Type["BaseNode"], execution_id: UUID) -> bool:
135
+ return execution_id in self._node_executions_initiated[node]
105
136
 
106
137
  def initiate_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
107
138
  self._node_executions_initiated[node].add(execution_id)
108
139
 
109
140
  def fulfill_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
110
- self._node_executions_initiated[node].remove(execution_id)
111
- self._node_execution_ids[node].push(execution_id)
141
+ self._node_executions_fulfilled[node].push(execution_id)
112
142
 
113
143
  def get_execution_count(self, node: Type["BaseNode"]) -> int:
114
- return self._node_execution_ids[node].size()
144
+ return self._node_executions_fulfilled[node].size()
115
145
 
116
146
  def dump(self) -> Dict[str, Any]:
117
147
  return {
118
148
  "dependencies_invoked": {
119
- node: list(dependencies) for node, dependencies in self._dependencies_invoked.items()
149
+ str(execution_id): [str(dep) for dep in dependencies]
150
+ for execution_id, dependencies in self._dependencies_invoked.items()
120
151
  },
121
152
  "node_executions_initiated": {
122
153
  str(node): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
123
154
  },
124
- "node_execution_ids": {
125
- str(node): execution_ids.dump() for node, execution_ids in self._node_execution_ids.items()
155
+ "node_executions_fulfilled": {
156
+ str(node): execution_ids.dump() for node, execution_ids in self._node_executions_fulfilled.items()
157
+ },
158
+ "node_executions_queued": {
159
+ str(node): execution_ids for node, execution_ids in self._node_executions_queued.items()
126
160
  },
127
161
  }
128
162