vellum-ai 0.10.7__py3-none-any.whl → 0.10.9__py3-none-any.whl

Files changed (41)
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/client/types/logical_operator.py +2 -0
  3. vellum/workflows/descriptors/utils.py +27 -0
  4. vellum/workflows/events/__init__.py +0 -2
  5. vellum/workflows/events/tests/test_event.py +2 -1
  6. vellum/workflows/events/types.py +36 -30
  7. vellum/workflows/events/workflow.py +14 -7
  8. vellum/workflows/nodes/bases/base.py +100 -38
  9. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +1 -0
  10. vellum/workflows/nodes/core/templating_node/node.py +5 -0
  11. vellum/workflows/nodes/core/try_node/node.py +22 -4
  12. vellum/workflows/nodes/core/try_node/tests/test_node.py +15 -0
  13. vellum/workflows/nodes/displayable/api_node/node.py +1 -1
  14. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +1 -2
  15. vellum/workflows/nodes/displayable/code_execution_node/node.py +1 -2
  16. vellum/workflows/nodes/displayable/code_execution_node/utils.py +13 -2
  17. vellum/workflows/nodes/displayable/inline_prompt_node/node.py +10 -3
  18. vellum/workflows/nodes/displayable/prompt_deployment_node/node.py +6 -1
  19. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +1 -2
  20. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +1 -2
  21. vellum/workflows/runner/runner.py +141 -32
  22. vellum/workflows/state/base.py +55 -21
  23. vellum/workflows/state/context.py +26 -3
  24. vellum/workflows/types/__init__.py +5 -0
  25. vellum/workflows/types/core.py +1 -1
  26. vellum/workflows/workflows/base.py +51 -17
  27. vellum/workflows/workflows/event_filters.py +61 -0
  28. {vellum_ai-0.10.7.dist-info → vellum_ai-0.10.9.dist-info}/METADATA +1 -1
  29. {vellum_ai-0.10.7.dist-info → vellum_ai-0.10.9.dist-info}/RECORD +40 -38
  30. vellum_cli/__init__.py +23 -4
  31. vellum_cli/pull.py +28 -13
  32. vellum_cli/tests/test_pull.py +45 -2
  33. vellum_ee/workflows/display/nodes/base_node_display.py +1 -1
  34. vellum_ee/workflows/display/nodes/vellum/__init__.py +6 -4
  35. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +17 -2
  36. vellum_ee/workflows/display/nodes/vellum/error_node.py +49 -0
  37. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +203 -0
  38. vellum/workflows/events/utils.py +0 -5
  39. {vellum_ai-0.10.7.dist-info → vellum_ai-0.10.9.dist-info}/LICENSE +0 -0
  40. {vellum_ai-0.10.7.dist-info → vellum_ai-0.10.9.dist-info}/WHEEL +0 -0
  41. {vellum_ai-0.10.7.dist-info → vellum_ai-0.10.9.dist-info}/entry_points.txt +0 -0
@@ -67,9 +67,8 @@ class CodeExecutionNode(BaseNode[StateType], Generic[StateType, _OutputType], me
 
      filepath: str - The path to the script to execute.
      code_inputs: EntityInputsInterface - The inputs for the custom script.
-     output_type: VellumVariableType = "STRING" - The type of the output from the custom script.
      runtime: CodeExecutionRuntime = "PYTHON_3_12" - The runtime to use for the custom script.
-     packages: Optional[Sequence[CodeExecutionPackageRequest]] = None - The packages to use for the custom script.
+     packages: Optional[Sequence[CodeExecutionPackage]] = None - The packages to use for the custom script.
      request_options: Optional[RequestOptions] = None - The request options to use for the custom script.
      """
 
@@ -2,9 +2,20 @@ import os
  from typing import Union
 
 
+ def get_project_root() -> str:
+     current_dir = os.getcwd()
+     while current_dir != '/':
+         if ".git" in os.listdir(current_dir):
+             return current_dir
+         current_dir = os.path.dirname(current_dir)
+     raise FileNotFoundError("Project root not found.")
+
  def read_file_from_path(filepath: str) -> Union[str, None]:
-     if not os.path.exists(filepath):
+     project_root = get_project_root()
+     relative_filepath = os.path.join(project_root, filepath)
+
+     if not os.path.exists(relative_filepath):
          return None
 
-     with open(filepath) as file:
+     with open(relative_filepath, 'r') as file:
          return file.read()
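The change above means `read_file_from_path` now resolves script paths against the repository root (found by walking up from the working directory to the nearest `.git` folder) instead of the process working directory. A minimal sketch of the resulting behavior, assuming the code runs inside a git checkout; the script path itself is illustrative:

```python
from vellum.workflows.nodes.displayable.code_execution_node.utils import (
    get_project_root,
    read_file_from_path,
)

# Walks upward from os.getcwd() until a directory containing .git is found,
# raising FileNotFoundError if it reaches the filesystem root first.
root = get_project_root()

# The filepath is joined onto that root, so this works regardless of which
# subdirectory the workflow was launched from.
script = read_file_from_path("workflow/scripts/transform.py")  # illustrative path

if script is None:
    # read_file_from_path still returns None (rather than raising) when the file is missing.
    raise FileNotFoundError("script not found under the project root")
```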
@@ -9,16 +9,23 @@ from vellum.workflows.types.generics import StateType
 
  class InlinePromptNode(BaseInlinePromptNode[StateType]):
      """
-     Used to execute an Inline Prompt and surface a string output for convenience.
+     Used to execute a Prompt defined inline.
 
      prompt_inputs: EntityInputsInterface - The inputs for the Prompt
      ml_model: str - Either the ML Model's UUID or its name.
-     blocks: List[PromptBlockRequest] - The blocks that make up the Prompt
+     blocks: List[PromptBlock] - The blocks that make up the Prompt
+     functions: Optional[List[FunctionDefinition]] - The functions to include in the Prompt
      parameters: PromptParameters - The parameters for the Prompt
-     expand_meta: Optional[AdHocExpandMetaRequest] - Set of expandable execution fields to include in the response
+     expand_meta: Optional[AdHocExpandMeta] - Expandable execution fields to include in the response
+     request_options: Optional[RequestOptions] - The request options to use for the Prompt Execution
      """
 
      class Outputs(BaseInlinePromptNode.Outputs):
+         """
+         The outputs of the InlinePromptNode.
+
+         text: str - The result of the Prompt Execution
+         """
          text: str
 
      def run(self) -> Iterator[BaseOutput]:
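The reworked docstring reflects the node's current attribute names (`blocks`, `functions`, `expand_meta`, `request_options`). A rough sketch of subclassing the node follows; it assumes the SDK's `JinjaPromptBlock` type is accepted directly in `blocks` and that `prompt_inputs` can be a plain dict, neither of which is shown in this diff, and the model name and template are placeholders:

```python
from vellum import JinjaPromptBlock
from vellum.workflows.nodes import InlinePromptNode


class SummarizeNode(InlinePromptNode):
    # ml_model may be either a model name or its UUID, per the docstring above.
    ml_model = "gpt-4o-mini"  # placeholder model name
    blocks = [
        # Assumption: a bare Jinja block is accepted at the top level of `blocks`.
        JinjaPromptBlock(block_type="JINJA", template="Summarize: {{ text }}"),
    ]
    prompt_inputs = {"text": "Vellum workflows SDK diff notes."}
```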
@@ -14,7 +14,7 @@ class PromptDeploymentNode(BasePromptDeploymentNode[StateType]):
      prompt_inputs: EntityInputsInterface - The inputs for the Prompt
      deployment: Union[UUID, str] - Either the Prompt Deployment's UUID or its name.
      release_tag: str - The release tag to use for the Prompt Execution
-     external_id: Optional[str] - The external ID to use for the Prompt Execution
+     external_id: Optional[str] - Optionally include a unique identifier for tracking purposes. Must be unique within a given Prompt Deployment.
      expand_meta: Optional[PromptDeploymentExpandMetaRequest] - Expandable execution fields to include in the response
      raw_overrides: Optional[RawPromptExecutionOverridesRequest] - The raw overrides to use for the Prompt Execution
      expand_raw: Optional[Sequence[str]] - Expandable raw fields to include in the response
@@ -23,6 +23,11 @@ class PromptDeploymentNode(BasePromptDeploymentNode[StateType]):
      """
 
      class Outputs(BasePromptDeploymentNode.Outputs):
+         """
+         The outputs of the PromptDeploymentNode.
+
+         text: str - The result of the Prompt Execution
+         """
          text: str
 
      def run(self) -> Iterator[BaseOutput]:
@@ -12,7 +12,6 @@ from vellum import (
      WorkflowRequestStringInputRequest,
  )
  from vellum.core import RequestOptions
- 
  from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
  from vellum.workflows.errors import VellumErrorCode
  from vellum.workflows.exceptions import NodeException
@@ -28,7 +27,7 @@ class SubworkflowDeploymentNode(BaseSubworkflowNode[StateType], Generic[StateTyp
      subworkflow_inputs: EntityInputsInterface - The inputs for the Subworkflow
      deployment: Union[UUID, str] - Either the Workflow Deployment's UUID or its name.
      release_tag: str = LATEST_RELEASE_TAG - The release tag to use for the Workflow Execution
-     external_id: Optional[str] = OMIT - The external ID to use for the Workflow Execution
+     external_id: Optional[str] = OMIT - Optionally include a unique identifier for tracking purposes. Must be unique within a given Workflow Deployment.
      expand_meta: Optional[WorkflowExpandMetaRequest] = OMIT - Expandable execution fields to include in the response
      metadata: Optional[Dict[str, Optional[Any]]] = OMIT - The metadata to use for the Workflow Execution
      request_options: Optional[RequestOptions] = None - The request options to use for the Workflow Execution
@@ -8,7 +8,6 @@ from vellum import (
      PromptOutput,
      StringVellumValue,
  )
- 
  from vellum.workflows.constants import OMIT
  from vellum.workflows.inputs import BaseInputs
  from vellum.workflows.nodes import PromptDeploymentNode
@@ -65,7 +64,7 @@ def test_text_prompt_deployment_node__basic(vellum_client):
      assert text_output.name == "text"
      assert text_output.value == "Hello, world!"
 
-     # AND we should have made the expected call to Vellum search
+     # AND we should have made the expected call to stream the prompt execution
      vellum_client.execute_prompt_stream.assert_called_once_with(
          expand_meta=OMIT,
          expand_raw=OMIT,
@@ -3,7 +3,7 @@ from copy import deepcopy
  import logging
  from queue import Empty, Queue
  from threading import Event as ThreadingEvent, Thread
- from uuid import UUID, uuid4
+ from uuid import UUID
  from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Type, Union
 
  from vellum.workflows.constants import UNDEF
@@ -28,8 +28,7 @@ from vellum.workflows.events.node import (
      NodeExecutionRejectedBody,
      NodeExecutionStreamingBody,
  )
- from vellum.workflows.events.types import BaseEvent
- from vellum.workflows.events.utils import is_terminal_event
+ from vellum.workflows.events.types import BaseEvent, ParentContext, WorkflowParentContext
  from vellum.workflows.events.workflow import (
      WorkflowExecutionFulfilledBody,
      WorkflowExecutionInitiatedBody,
@@ -71,9 +70,12 @@ class WorkflowRunner(Generic[StateType]):
          entrypoint_nodes: Optional[RunFromNodeArg] = None,
          external_inputs: Optional[ExternalInputsArg] = None,
          cancel_signal: Optional[ThreadingEvent] = None,
+         parent_context: Optional[ParentContext] = None,
      ):
          if state and external_inputs:
-             raise ValueError("Can only run a Workflow providing one of state or external inputs, not both")
+             raise ValueError(
+                 "Can only run a Workflow providing one of state or external inputs, not both"
+             )
 
          self.workflow = workflow
          if entrypoint_nodes:
@@ -99,7 +101,9 @@ class WorkflowRunner(Generic[StateType]):
                  if issubclass(ei.inputs_class.__parent_class__, BaseNode)
              ]
          else:
-             normalized_inputs = deepcopy(inputs) if inputs else self.workflow.get_default_inputs()
+             normalized_inputs = (
+                 deepcopy(inputs) if inputs else self.workflow.get_default_inputs()
+             )
              if state:
                  self._initial_state = deepcopy(state)
                  self._initial_state.meta.workflow_inputs = normalized_inputs
@@ -115,12 +119,14 @@ class WorkflowRunner(Generic[StateType]):
 
          self._active_nodes_by_execution_id: Dict[UUID, BaseNode[StateType]] = {}
          self._cancel_signal = cancel_signal
+         self._parent_context = parent_context
 
          setattr(
              self._initial_state,
              "__snapshot_callback__",
              lambda s: self._snapshot_state(s),
          )
+         self.workflow.context._register_event_queue(self._workflow_event_queue)
 
      def _snapshot_state(self, state: StateType) -> StateType:
          self.workflow._store.append_state_snapshot(state)
@@ -143,6 +149,12 @@ class WorkflowRunner(Generic[StateType]):
                              node_definition=node.__class__,
                              inputs=node._inputs,
                          ),
+                         parent=WorkflowParentContext(
+                             span_id=span_id,
+                             workflow_definition=self.workflow.__class__,
+                             parent=self._parent_context,
+                             type="WORKFLOW",
+                         ),
                      ),
                  )
              )
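As of this release, node and workflow events emitted by the runner carry a `parent` of type `WorkflowParentContext`, linking each event to the workflow execution that produced it and, through the new `parent_context` constructor argument, to any enclosing execution. A small illustrative helper, not part of the SDK, showing how a consumer could walk that chain using only the fields visible in this diff (`span_id`, `workflow_definition`, `parent`):

```python
def workflow_lineage(event) -> list:
    """Collect the workflow definitions on an event's parent chain, innermost first."""
    lineage = []
    context = getattr(event, "parent", None)
    while context is not None:
        definition = getattr(context, "workflow_definition", None)
        if definition is not None:
            lineage.append(definition.__name__)
        context = getattr(context, "parent", None)
    return lineage
```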
@@ -178,7 +190,11 @@ class WorkflowRunner(Generic[StateType]):
                      instance=None,
                      outputs_class=node.Outputs,
                  )
-                 node.state.meta.node_outputs[output_descriptor] = streaming_output_queues[output.name]
+                 node.state.meta.node_outputs[output_descriptor] = (
+                     streaming_output_queues[output.name]
+                 )
+                 initiated_output: BaseOutput = BaseOutput(name=output.name)
+                 initiated_ports = initiated_output > ports
                  self._work_item_event_queue.put(
                      WorkItemEvent(
                          node=node,
@@ -187,8 +203,13 @@ class WorkflowRunner(Generic[StateType]):
                              span_id=span_id,
                              body=NodeExecutionStreamingBody(
                                  node_definition=node.__class__,
-                                 output=BaseOutput(name=output.name),
-                                 invoked_ports=invoked_ports,
+                                 output=initiated_output,
+                                 invoked_ports=initiated_ports,
+                             ),
+                             parent=WorkflowParentContext(
+                                 span_id=span_id,
+                                 workflow_definition=self.workflow.__class__,
+                                 parent=self._parent_context,
                              ),
                          ),
                      )
@@ -214,6 +235,11 @@ class WorkflowRunner(Generic[StateType]):
                              output=output,
                              invoked_ports=invoked_ports,
                          ),
+                         parent=WorkflowParentContext(
+                             span_id=span_id,
+                             workflow_definition=self.workflow.__class__,
+                             parent=self._parent_context,
+                         ),
                      ),
                  )
              )
@@ -233,6 +259,11 @@ class WorkflowRunner(Generic[StateType]):
                              output=output,
                              invoked_ports=invoked_ports,
                          ),
+                         parent=WorkflowParentContext(
+                             span_id=span_id,
+                             workflow_definition=self.workflow.__class__,
+                             parent=self._parent_context,
+                         ),
                      ),
                  )
              )
@@ -246,7 +277,9 @@ class WorkflowRunner(Generic[StateType]):
                  node.state.meta.node_outputs[descriptor] = output_value
 
          invoked_ports = ports(outputs, node.state)
-         node.state.meta.node_execution_cache.fulfill_node_execution(node.__class__, span_id)
+         node.state.meta.node_execution_cache.fulfill_node_execution(
+             node.__class__, span_id
+         )
 
          self._work_item_event_queue.put(
              WorkItemEvent(
@@ -259,6 +292,11 @@ class WorkflowRunner(Generic[StateType]):
                              outputs=outputs,
                              invoked_ports=invoked_ports,
                          ),
+                         parent=WorkflowParentContext(
+                             span_id=span_id,
+                             workflow_definition=self.workflow.__class__,
+                             parent=self._parent_context,
+                         ),
                      ),
                  )
              )
@@ -273,11 +311,18 @@ class WorkflowRunner(Generic[StateType]):
                              node_definition=node.__class__,
                              error=e.error,
                          ),
+                         parent=WorkflowParentContext(
+                             span_id=span_id,
+                             workflow_definition=self.workflow.__class__,
+                             parent=self._parent_context,
+                         ),
                      ),
                  )
              )
          except Exception as e:
-             logger.exception(f"An unexpected error occurred while running node {node.__class__.__name__}")
+             logger.exception(
+                 f"An unexpected error occurred while running node {node.__class__.__name__}"
+             )
 
              self._work_item_event_queue.put(
                  WorkItemEvent(
@@ -292,13 +337,20 @@ class WorkflowRunner(Generic[StateType]):
                                  code=VellumErrorCode.INTERNAL_ERROR,
                              ),
                          ),
+                         parent=WorkflowParentContext(
+                             span_id=span_id,
+                             workflow_definition=self.workflow.__class__,
+                             parent=self._parent_context,
+                         ),
                      ),
                  )
              )
 
          logger.debug(f"Finished running node: {node.__class__.__name__}")
 
-     def _handle_invoked_ports(self, state: StateType, ports: Optional[Iterable[Port]]) -> None:
+     def _handle_invoked_ports(
+         self, state: StateType, ports: Optional[Iterable[Port]]
+     ) -> None:
          if not ports:
              return
 
@@ -313,7 +365,10 @@ class WorkflowRunner(Generic[StateType]):
              self._run_node_if_ready(next_state, edge.to_node, edge)
 
      def _run_node_if_ready(
-         self, state: StateType, node_class: Type[BaseNode], invoked_by: Optional[Edge] = None
+         self,
+         state: StateType,
+         node_class: Type[BaseNode],
+         invoked_by: Optional[Edge] = None,
      ) -> None:
          with state.__lock__:
              for descriptor in node_class.ExternalInputs:
@@ -325,18 +380,27 @@ class WorkflowRunner(Generic[StateType]):
                  return
 
              all_deps = self._dependencies[node_class]
-             if not node_class.Trigger.should_initiate(state, all_deps, invoked_by):
+             node_span_id = state.meta.node_execution_cache.queue_node_execution(
+                 node_class, all_deps, invoked_by
+             )
+             if not node_class.Trigger.should_initiate(state, all_deps, node_span_id):
                  return
 
              node = node_class(state=state, context=self.workflow.context)
-             node_span_id = uuid4()
-             state.meta.node_execution_cache.initiate_node_execution(node_class, node_span_id)
+             state.meta.node_execution_cache.initiate_node_execution(
+                 node_class, node_span_id
+             )
              self._active_nodes_by_execution_id[node_span_id] = node
 
-             worker_thread = Thread(target=self._run_work_item, kwargs={"node": node, "span_id": node_span_id})
+             worker_thread = Thread(
+                 target=self._run_work_item,
+                 kwargs={"node": node, "span_id": node_span_id},
+             )
              worker_thread.start()
 
-     def _handle_work_item_event(self, work_item_event: WorkItemEvent[StateType]) -> Optional[VellumError]:
+     def _handle_work_item_event(
+         self, work_item_event: WorkItemEvent[StateType]
+     ) -> Optional[VellumError]:
          node = work_item_event.node
          event = work_item_event.event
 
@@ -352,7 +416,10 @@ class WorkflowRunner(Generic[StateType]):
                  node_output_descriptor = workflow_output_descriptor.instance
                  if not isinstance(node_output_descriptor, OutputReference):
                      continue
-                 if node_output_descriptor.outputs_class != event.node_definition.Outputs:
+                 if (
+                     node_output_descriptor.outputs_class
+                     != event.node_definition.Outputs
+                 ):
                      continue
                  if node_output_descriptor.name != event.output.name:
                      continue
@@ -387,9 +454,12 @@ class WorkflowRunner(Generic[StateType]):
                  workflow_definition=self.workflow.__class__,
                  inputs=self._initial_state.meta.workflow_inputs,
              ),
+             parent=self._parent_context,
          )
 
-     def _stream_workflow_event(self, output: BaseOutput) -> WorkflowExecutionStreamingEvent:
+     def _stream_workflow_event(
+         self, output: BaseOutput
+     ) -> WorkflowExecutionStreamingEvent:
          return WorkflowExecutionStreamingEvent(
              trace_id=self._initial_state.meta.trace_id,
              span_id=self._initial_state.meta.span_id,
@@ -397,9 +467,12 @@ class WorkflowRunner(Generic[StateType]):
                  workflow_definition=self.workflow.__class__,
                  output=output,
              ),
+             parent=self._parent_context,
          )
 
-     def _fulfill_workflow_event(self, outputs: OutputsType) -> WorkflowExecutionFulfilledEvent:
+     def _fulfill_workflow_event(
+         self, outputs: OutputsType
+     ) -> WorkflowExecutionFulfilledEvent:
          return WorkflowExecutionFulfilledEvent(
              trace_id=self._initial_state.meta.trace_id,
              span_id=self._initial_state.meta.span_id,
@@ -407,9 +480,12 @@ class WorkflowRunner(Generic[StateType]):
                  workflow_definition=self.workflow.__class__,
                  outputs=outputs,
              ),
+             parent=self._parent_context,
          )
 
-     def _reject_workflow_event(self, error: VellumError) -> WorkflowExecutionRejectedEvent:
+     def _reject_workflow_event(
+         self, error: VellumError
+     ) -> WorkflowExecutionRejectedEvent:
          return WorkflowExecutionRejectedEvent(
              trace_id=self._initial_state.meta.trace_id,
              span_id=self._initial_state.meta.span_id,
@@ -417,6 +493,7 @@ class WorkflowRunner(Generic[StateType]):
                  workflow_definition=self.workflow.__class__,
                  error=error,
              ),
+             parent=self._parent_context,
          )
 
      def _resume_workflow_event(self) -> WorkflowExecutionResumedEvent:
@@ -428,7 +505,9 @@ class WorkflowRunner(Generic[StateType]):
              ),
          )
 
-     def _pause_workflow_event(self, external_inputs: Iterable[ExternalInputReference]) -> WorkflowExecutionPausedEvent:
+     def _pause_workflow_event(
+         self, external_inputs: Iterable[ExternalInputReference]
+     ) -> WorkflowExecutionPausedEvent:
          return WorkflowExecutionPausedEvent(
              trace_id=self._initial_state.meta.trace_id,
              span_id=self._initial_state.meta.span_id,
@@ -436,6 +515,7 @@ class WorkflowRunner(Generic[StateType]):
                  workflow_definition=self.workflow.__class__,
                  external_inputs=external_inputs,
              ),
+             parent=self._parent_context,
          )
 
      def _stream(self) -> None:
@@ -444,7 +524,10 @@ class WorkflowRunner(Generic[StateType]):
          if not self._entrypoints:
              self._workflow_event_queue.put(
                  self._reject_workflow_event(
-                     VellumError(message="No entrypoints defined", code=VellumErrorCode.INVALID_WORKFLOW)
+                     VellumError(
+                         message="No entrypoints defined",
+                         code=VellumErrorCode.INVALID_WORKFLOW,
+                     )
                  )
              )
              return
@@ -463,7 +546,9 @@ class WorkflowRunner(Generic[StateType]):
              logger.exception(err_message)
              self._workflow_event_queue.put(
                  self._reject_workflow_event(
-                     VellumError(code=VellumErrorCode.INTERNAL_ERROR, message=err_message),
+                     VellumError(
+                         code=VellumErrorCode.INTERNAL_ERROR, message=err_message
+                     ),
                  )
              )
              return
@@ -519,7 +604,11 @@ class WorkflowRunner(Generic[StateType]):
              if isinstance(value, BaseDescriptor):
                  setattr(fulfilled_outputs, descriptor.name, value.resolve(final_state))
              elif isinstance(descriptor.instance, BaseDescriptor):
-                 setattr(fulfilled_outputs, descriptor.name, descriptor.instance.resolve(final_state))
+                 setattr(
+                     fulfilled_outputs,
+                     descriptor.name,
+                     descriptor.instance.resolve(final_state),
+                 )
 
          self._workflow_event_queue.put(self._fulfill_workflow_event(fulfilled_outputs))
 
@@ -544,24 +633,41 @@ class WorkflowRunner(Generic[StateType]):
          self._cancel_signal.wait()
          self._workflow_event_queue.put(
              self._reject_workflow_event(
-                 VellumError(code=VellumErrorCode.WORKFLOW_CANCELLED, message="Workflow run cancelled")
+                 VellumError(
+                     code=VellumErrorCode.WORKFLOW_CANCELLED,
+                     message="Workflow run cancelled",
+                 )
              )
          )
 
+     def _is_terminal_event(self, event: WorkflowEvent) -> bool:
+         if (
+             event.name == "workflow.execution.fulfilled"
+             or event.name == "workflow.execution.rejected"
+             or event.name == "workflow.execution.paused"
+         ):
+             return event.workflow_definition == self.workflow.__class__
+         return False
+
      def stream(self) -> WorkflowEventStream:
          background_thread = Thread(
-             target=self._run_background_thread, name=f"{self.workflow.__class__.__name__}.background_thread"
+             target=self._run_background_thread,
+             name=f"{self.workflow.__class__.__name__}.background_thread",
          )
          background_thread.start()
 
          if self._cancel_signal:
              cancel_thread = Thread(
-                 target=self._run_cancel_thread, name=f"{self.workflow.__class__.__name__}.cancel_thread"
+                 target=self._run_cancel_thread,
+                 name=f"{self.workflow.__class__.__name__}.cancel_thread",
             )
              cancel_thread.start()
 
          event: WorkflowEvent
-         if self._initial_state.meta.is_terminated or self._initial_state.meta.is_terminated is None:
+         if (
+             self._initial_state.meta.is_terminated
+             or self._initial_state.meta.is_terminated is None
+         ):
              event = self._initiate_workflow_event()
          else:
              event = self._resume_workflow_event()
@@ -570,7 +676,10 @@ class WorkflowRunner(Generic[StateType]):
          self._initial_state.meta.is_terminated = False
 
          # The extra level of indirection prevents the runner from waiting on the caller to consume the event stream
-         stream_thread = Thread(target=self._stream, name=f"{self.workflow.__class__.__name__}.stream_thread")
+         stream_thread = Thread(
+             target=self._stream,
+             name=f"{self.workflow.__class__.__name__}.stream_thread",
+         )
          stream_thread.start()
 
          while stream_thread.is_alive():
@@ -581,7 +690,7 @@ class WorkflowRunner(Generic[StateType]):
 
              yield self._emit_event(event)
 
-             if is_terminal_event(event):
+             if self._is_terminal_event(event):
                  break
 
              try:
@@ -590,7 +699,7 @@ class WorkflowRunner(Generic[StateType]):
              except Empty:
                  pass
 
-         if not is_terminal_event(event):
+         if not self._is_terminal_event(event):
              yield self._reject_workflow_event(
                  VellumError(
                      code=VellumErrorCode.INTERNAL_ERROR,
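With subworkflow events now forwarded onto the same queue (via `_register_event_queue`), the runner can no longer treat any terminal event as the end of the stream; the new `_is_terminal_event` method only stops on terminal events emitted by this workflow itself. A consumer-side sketch of the same idea, using only attributes visible in this diff (`name`, `workflow_definition`); `MyWorkflow` and the `workflow` instance are placeholders, and how the instance is constructed is outside the scope of this diff:

```python
TERMINAL_NAMES = {
    "workflow.execution.fulfilled",
    "workflow.execution.rejected",
    "workflow.execution.paused",
}

# workflow: an instance of MyWorkflow, a BaseWorkflow subclass (placeholder).
for event in workflow.stream():
    print(event.name)
    # Nested subworkflow events may share the stream, so only stop when the
    # top-level workflow itself reaches a terminal state.
    if event.name in TERMINAL_NAMES and event.workflow_definition is MyWorkflow:
        break
```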
@@ -5,15 +5,15 @@ from datetime import datetime
  from queue import Queue
  from threading import Lock
  from uuid import UUID, uuid4
- from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional, Sequence, Set, Tuple, Type, cast
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, cast
  from typing_extensions import dataclass_transform
 
  from pydantic import GetCoreSchemaHandler, field_serializer
  from pydantic_core import core_schema
 
  from vellum.core.pydantic_utilities import UniversalBaseModel
- 
  from vellum.workflows.constants import UNDEF
+ from vellum.workflows.edges.edge import Edge
  from vellum.workflows.inputs.base import BaseInputs
  from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
  from vellum.workflows.types.generics import StateType
@@ -71,58 +71,92 @@ def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> An
 
 
  class NodeExecutionCache:
-     _node_execution_ids: Dict[Type["BaseNode"], Stack[UUID]]
+     _node_executions_fulfilled: Dict[Type["BaseNode"], Stack[UUID]]
      _node_executions_initiated: Dict[Type["BaseNode"], Set[UUID]]
-     _dependencies_invoked: Dict[str, Set[str]]
+     _node_executions_queued: Dict[Type["BaseNode"], List[UUID]]
+     _dependencies_invoked: Dict[UUID, Set[Type["BaseNode"]]]
 
      def __init__(
          self,
          dependencies_invoked: Optional[Dict[str, Sequence[str]]] = None,
-         node_execution_ids: Optional[Dict[str, Sequence[str]]] = None,
+         node_executions_fulfilled: Optional[Dict[str, Sequence[str]]] = None,
          node_executions_initiated: Optional[Dict[str, Sequence[str]]] = None,
+         node_executions_queued: Optional[Dict[str, Sequence[str]]] = None,
      ) -> None:
          self._dependencies_invoked = defaultdict(set)
-         self._node_execution_ids = defaultdict(Stack[UUID])
+         self._node_executions_fulfilled = defaultdict(Stack[UUID])
          self._node_executions_initiated = defaultdict(set)
+         self._node_executions_queued = defaultdict(list)
 
-         for node, dependencies in (dependencies_invoked or {}).items():
-             self._dependencies_invoked[node].update(dependencies)
+         for execution_id, dependencies in (dependencies_invoked or {}).items():
+             self._dependencies_invoked[UUID(execution_id)] = {get_class_by_qualname(dep) for dep in dependencies}
 
-         for node, execution_ids in (node_execution_ids or {}).items():
+         for node, execution_ids in (node_executions_fulfilled or {}).items():
              node_class = get_class_by_qualname(node)
-             self._node_execution_ids[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
+             self._node_executions_fulfilled[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
 
          for node, execution_ids in (node_executions_initiated or {}).items():
              node_class = get_class_by_qualname(node)
              self._node_executions_initiated[node_class].update({UUID(execution_id) for execution_id in execution_ids})
 
-     @property
-     def dependencies_invoked(self) -> Dict[str, Set[str]]:
-         return self._dependencies_invoked
+         for node, execution_ids in (node_executions_queued or {}).items():
+             node_class = get_class_by_qualname(node)
+             self._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
 
-     def is_node_initiated(self, node: Type["BaseNode"]) -> bool:
-         return node in self._node_executions_initiated and len(self._node_executions_initiated[node]) > 0
+     def _invoke_dependency(
+         self,
+         execution_id: UUID,
+         node: Type["BaseNode"],
+         dependency: Type["BaseNode"],
+         dependencies: Set["Type[BaseNode]"],
+     ) -> None:
+         self._dependencies_invoked[execution_id].add(dependency)
+         if all(dep in self._dependencies_invoked[execution_id] for dep in dependencies):
+             self._node_executions_queued[node].remove(execution_id)
+
+     def queue_node_execution(
+         self, node: Type["BaseNode"], dependencies: Set["Type[BaseNode]"], invoked_by: Optional[Edge] = None
+     ) -> UUID:
+         execution_id = uuid4()
+         if not invoked_by:
+             return execution_id
+
+         source_node = invoked_by.from_port.node_class
+         for queued_node_execution_id in self._node_executions_queued[node]:
+             if source_node not in self._dependencies_invoked[queued_node_execution_id]:
+                 self._invoke_dependency(queued_node_execution_id, node, source_node, dependencies)
+                 return queued_node_execution_id
+
+         self._node_executions_queued[node].append(execution_id)
+         self._invoke_dependency(execution_id, node, source_node, dependencies)
+         return execution_id
+
+     def is_node_execution_initiated(self, node: Type["BaseNode"], execution_id: UUID) -> bool:
+         return execution_id in self._node_executions_initiated[node]
 
      def initiate_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
          self._node_executions_initiated[node].add(execution_id)
 
      def fulfill_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
-         self._node_executions_initiated[node].remove(execution_id)
-         self._node_execution_ids[node].push(execution_id)
+         self._node_executions_fulfilled[node].push(execution_id)
 
      def get_execution_count(self, node: Type["BaseNode"]) -> int:
-         return self._node_execution_ids[node].size()
+         return self._node_executions_fulfilled[node].size()
 
      def dump(self) -> Dict[str, Any]:
          return {
              "dependencies_invoked": {
-                 node: list(dependencies) for node, dependencies in self._dependencies_invoked.items()
+                 str(execution_id): [str(dep) for dep in dependencies]
+                 for execution_id, dependencies in self._dependencies_invoked.items()
              },
              "node_executions_initiated": {
                  str(node): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
              },
-             "node_execution_ids": {
-                 str(node): execution_ids.dump() for node, execution_ids in self._node_execution_ids.items()
+             "node_executions_fulfilled": {
+                 str(node): execution_ids.dump() for node, execution_ids in self._node_executions_fulfilled.items()
+             },
+             "node_executions_queued": {
+                 str(node): execution_ids for node, execution_ids in self._node_executions_queued.items()
              },
          }
 
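The reworked `NodeExecutionCache` above queues one pending execution per target node and marks off each upstream dependency as its edge fires, so a merge node with several dependencies resolves to a single execution id rather than one per incoming edge. A rough illustration of the lifecycle using only the methods shown in this diff; the node classes and `fake_edge` helper are stand-ins for the `BaseNode` subclasses and `Edge` objects the runner passes internally, faking just the attributes the cache actually touches:

```python
from types import SimpleNamespace
from vellum.workflows.state.base import NodeExecutionCache

# Illustrative stand-ins for upstream nodes and a downstream merge node.
class UpstreamA: ...
class UpstreamB: ...
class MergeNode: ...

def fake_edge(from_node):
    # The cache only reads invoked_by.from_port.node_class.
    return SimpleNamespace(from_port=SimpleNamespace(node_class=from_node))

cache = NodeExecutionCache()
deps = {UpstreamA, UpstreamB}

# The first incoming edge queues a new execution for MergeNode...
span_id = cache.queue_node_execution(MergeNode, deps, fake_edge(UpstreamA))
# ...and the second edge resolves to that same queued execution id,
# which also removes it from the queue because all dependencies are satisfied.
assert cache.queue_node_execution(MergeNode, deps, fake_edge(UpstreamB)) == span_id

cache.initiate_node_execution(MergeNode, span_id)
cache.fulfill_node_execution(MergeNode, span_id)
assert cache.get_execution_count(MergeNode) == 1
```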