vellum-ai 0.14.39__py3-none-any.whl → 0.14.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/client/reference.md +138 -1
  3. vellum/client/resources/ad_hoc/client.py +311 -1
  4. vellum/client/resources/deployments/client.py +2 -2
  5. vellum/workflows/nodes/bases/tests/test_base_node.py +24 -0
  6. vellum/workflows/nodes/core/try_node/node.py +1 -2
  7. vellum/workflows/nodes/experimental/tool_calling_node/__init__.py +3 -0
  8. vellum/workflows/nodes/experimental/tool_calling_node/node.py +125 -0
  9. vellum/workflows/nodes/experimental/tool_calling_node/utils.py +128 -0
  10. vellum/workflows/nodes/utils.py +4 -2
  11. vellum/workflows/outputs/base.py +3 -2
  12. vellum/workflows/references/output.py +20 -0
  13. vellum/workflows/state/base.py +36 -14
  14. vellum/workflows/state/tests/test_state.py +5 -2
  15. vellum/workflows/types/stack.py +11 -0
  16. vellum/workflows/workflows/base.py +5 -0
  17. vellum/workflows/workflows/tests/test_base_workflow.py +96 -9
  18. {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.41.dist-info}/METADATA +1 -1
  19. {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.41.dist-info}/RECORD +84 -80
  20. vellum_cli/push.py +0 -2
  21. vellum_ee/workflows/display/base.py +14 -1
  22. vellum_ee/workflows/display/nodes/base_node_display.py +91 -19
  23. vellum_ee/workflows/display/nodes/get_node_display_class.py +9 -15
  24. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +54 -0
  25. vellum_ee/workflows/display/nodes/vellum/api_node.py +2 -2
  26. vellum_ee/workflows/display/nodes/vellum/base_adornment_node.py +4 -4
  27. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +2 -2
  28. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +2 -2
  29. vellum_ee/workflows/display/nodes/vellum/error_node.py +2 -2
  30. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +2 -2
  31. vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +2 -2
  32. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +2 -2
  33. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +2 -2
  34. vellum_ee/workflows/display/nodes/vellum/merge_node.py +2 -2
  35. vellum_ee/workflows/display/nodes/vellum/note_node.py +2 -2
  36. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +2 -4
  37. vellum_ee/workflows/display/nodes/vellum/retry_node.py +1 -2
  38. vellum_ee/workflows/display/nodes/vellum/search_node.py +2 -2
  39. vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +2 -2
  40. vellum_ee/workflows/display/nodes/vellum/templating_node.py +2 -2
  41. vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +1 -2
  42. vellum_ee/workflows/display/nodes/vellum/tests/test_error_node.py +1 -2
  43. vellum_ee/workflows/display/nodes/vellum/tests/test_note_node.py +1 -2
  44. vellum_ee/workflows/display/nodes/vellum/tests/test_prompt_node.py +55 -3
  45. vellum_ee/workflows/display/nodes/vellum/tests/test_retry_node.py +1 -2
  46. vellum_ee/workflows/display/nodes/vellum/tests/test_templating_node.py +1 -2
  47. vellum_ee/workflows/display/nodes/vellum/tests/test_try_node.py +1 -2
  48. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +4 -4
  49. vellum_ee/workflows/display/nodes/vellum/try_node.py +1 -2
  50. vellum_ee/workflows/display/nodes/vellum/utils.py +7 -1
  51. vellum_ee/workflows/display/tests/{test_vellum_workflow_display.py → test_base_workflow_display.py} +10 -22
  52. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +4 -6
  53. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +7 -16
  54. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +2 -6
  55. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -2
  56. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -10
  57. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +4 -5
  58. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +1 -4
  59. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +1 -4
  60. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +2 -5
  61. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +7 -5
  62. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +1 -4
  63. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -4
  64. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +1 -2
  65. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -4
  66. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -4
  67. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +7 -5
  68. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -4
  69. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -4
  70. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +1 -4
  71. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +2 -5
  72. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -7
  73. vellum_ee/workflows/display/types.py +5 -4
  74. vellum_ee/workflows/display/utils/exceptions.py +7 -0
  75. vellum_ee/workflows/display/utils/registry.py +37 -0
  76. vellum_ee/workflows/display/utils/vellum.py +2 -1
  77. vellum_ee/workflows/display/workflows/base_workflow_display.py +277 -47
  78. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +34 -21
  79. vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +58 -20
  80. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +4 -257
  81. vellum_ee/workflows/tests/local_workflow/display/workflow.py +2 -2
  82. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +0 -40
  83. {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.41.dist-info}/LICENSE +0 -0
  84. {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.41.dist-info}/WHEEL +0 -0
  85. {vellum_ai-0.14.39.dist-info → vellum_ai-0.14.41.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,125 @@
1
+ from collections.abc import Callable
2
+ from typing import Any, ClassVar, List, Optional
3
+
4
+ from vellum import ChatMessage, PromptBlock
5
+ from vellum.workflows.context import execution_context, get_parent_context
6
+ from vellum.workflows.errors.types import WorkflowErrorCode
7
+ from vellum.workflows.exceptions import NodeException
8
+ from vellum.workflows.graph.graph import Graph
9
+ from vellum.workflows.inputs.base import BaseInputs
10
+ from vellum.workflows.nodes.bases import BaseNode
11
+ from vellum.workflows.nodes.experimental.tool_calling_node.utils import (
12
+ ToolRouterNode,
13
+ create_function_node,
14
+ create_tool_router_node,
15
+ )
16
+ from vellum.workflows.outputs.base import BaseOutputs
17
+ from vellum.workflows.state.base import BaseState
18
+ from vellum.workflows.state.context import WorkflowContext
19
+ from vellum.workflows.types.core import EntityInputsInterface
20
+ from vellum.workflows.workflows.base import BaseWorkflow
21
+
22
+
23
class ToolCallingNode(BaseNode):
    """
    A Node that dynamically invokes the provided functions to the underlying Prompt

    Attributes:
        ml_model: str - The model to use for tool calling (e.g., "gpt-4o-mini")
        blocks: List[PromptBlock] - The prompt blocks to use (same format as InlinePromptNode)
        functions: List[Callable] - The Python callables exposed to the model; each callable's
            ``__name__`` becomes the name of a port on the generated router node
        prompt_inputs: Optional[EntityInputsInterface] - Mapping of input variable names to values
        max_tool_calls: int - Declared but not yet enforced (see TODO below)
    """

    ml_model: ClassVar[str] = "gpt-4o-mini"
    blocks: ClassVar[List[PromptBlock]] = []
    functions: ClassVar[List[Callable[..., Any]]] = []
    prompt_inputs: ClassVar[Optional[EntityInputsInterface]] = None
    # TODO: https://linear.app/vellum/issue/APO-342/support-tool-call-max-retries
    max_tool_calls: ClassVar[int] = 1

    class Outputs(BaseOutputs):
        """
        The outputs of the ToolCallingNode.

        text: The final text response after tool calling
        chat_history: The complete chat history including tool calls
        """

        text: str = ""
        chat_history: List[ChatMessage] = []

    def run(self) -> Outputs:
        """
        Run the tool calling workflow.

        This dynamically builds a graph with router and function nodes, then executes it as a
        subworkflow whose parent state is this node's state.

        Raises:
            NodeException: if the subworkflow pauses, is rejected (carrying the subworkflow's
                error code), or ends with an unexpected event.
        """

        self._build_graph()

        with execution_context(parent_context=get_parent_context()):

            class ToolCallingState(BaseState):
                chat_history: List[ChatMessage] = []

            class ToolCallingWorkflow(BaseWorkflow[BaseInputs, ToolCallingState]):
                graph = self._graph

                class Outputs(BaseWorkflow.Outputs):
                    text: str = ToolRouterNode.Outputs.text
                    chat_history: List[ChatMessage] = ToolCallingState.chat_history

            subworkflow = ToolCallingWorkflow(
                parent_state=self.state,
                context=WorkflowContext.create_from(self._context),
            )

            terminal_event = subworkflow.run()

            if terminal_event.name == "workflow.execution.paused":
                raise NodeException(
                    code=WorkflowErrorCode.INVALID_OUTPUTS,
                    message="Subworkflow unexpectedly paused",
                )
            elif terminal_event.name == "workflow.execution.fulfilled":
                node_outputs = self.Outputs()

                # Copy each workflow output onto the node output of the same name
                for output_descriptor, output_value in terminal_event.outputs:
                    setattr(node_outputs, output_descriptor.name, output_value)

                return node_outputs
            elif terminal_event.name == "workflow.execution.rejected":
                # Surface the subworkflow's own error code instead of a generic exception
                raise NodeException(
                    message=terminal_event.error.message,
                    code=terminal_event.error.code,
                )

            raise NodeException(
                code=WorkflowErrorCode.INTERNAL_ERROR,
                message=f"Unexpected workflow event: {terminal_event.name}",
            )

    def _build_graph(self) -> None:
        """
        Build the router/function-node graph for this run.

        Sets ``self.tool_router_node``, ``self._function_nodes`` (keyed by function ``__name__``)
        and ``self._graph``.
        """
        # Pass a copy of ``blocks``: create_tool_router_node may add a chat-history block to the list
        # it receives, and ``self.blocks`` is a class-level list (shared by every subclass that does
        # not override it), so passing it directly would make it grow on every run.
        self.tool_router_node = create_tool_router_node(
            ml_model=self.ml_model,
            blocks=list(self.blocks),
            functions=self.functions,
            prompt_inputs=self.prompt_inputs,
        )

        self._function_nodes = {
            function.__name__: create_function_node(function=function) for function in self.functions
        }

        graph_set = set()

        # Each function forms a loop: router port -> function node -> back to the router
        for function_name, FunctionNodeClass in self._function_nodes.items():
            router_port = getattr(self.tool_router_node.Ports, function_name)
            graph_set.add(router_port >> FunctionNodeClass >> self.tool_router_node)

        # The router's default port is added with no outgoing edge
        default_port = getattr(self.tool_router_node.Ports, "default")
        graph_set.add(default_port)

        self._graph = Graph.from_set(graph_set)
@@ -0,0 +1,128 @@
1
+ from collections.abc import Callable
2
+ import json
3
+ from typing import Any, Iterator, List, Optional, Type, cast
4
+
5
+ from vellum import ChatMessage, FunctionDefinition, PromptBlock
6
+ from vellum.client.types.function_call_chat_message_content import FunctionCallChatMessageContent
7
+ from vellum.client.types.function_call_chat_message_content_value import FunctionCallChatMessageContentValue
8
+ from vellum.client.types.variable_prompt_block import VariablePromptBlock
9
+ from vellum.workflows.nodes.bases import BaseNode
10
+ from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
11
+ from vellum.workflows.outputs.base import BaseOutput
12
+ from vellum.workflows.ports.port import Port
13
+ from vellum.workflows.references.lazy import LazyReference
14
+ from vellum.workflows.types.core import EntityInputsInterface
15
+
16
+
17
class FunctionNode(BaseNode):
    """Node that executes a specific function.

    Base class for the per-function node classes generated by ``create_function_node``, which sets
    ``function`` and overrides ``run`` to call it.
    """

    # The function this node invokes.
    # NOTE(review): annotated as FunctionDefinition, but create_function_node assigns the raw Python
    # callable here; the annotation likely should be Callable[..., Any] — confirm and update.
    function: FunctionDefinition
21
+
22
+
23
class ToolRouterNode(InlinePromptNode):
    """
    Prompt node that decides which function, if any, the model wants to call next.

    Before prompting, it exposes ``state.chat_history`` as the ``chat_history`` prompt input. After
    prompting, it appends the first result (a text reply or a function call) to ``state.chat_history``
    as an ASSISTANT message. All prompt outputs are re-yielded unchanged.
    """

    def run(self) -> Iterator[BaseOutput]:
        # prompt_inputs may be None (ToolCallingNode's default), so fall back to an empty mapping
        self.prompt_inputs = {**(self.prompt_inputs or {}), "chat_history": self.state.chat_history}  # type: ignore
        for output in super().run():
            # A truthy results list is non-empty, so indexing [0] is safe
            if output.name == "results" and output.value:
                first_result = cast(List[Any], output.value)[0]
                if first_result.type == "STRING":
                    self.state.chat_history.append(ChatMessage(role="ASSISTANT", text=first_result.value))
                elif first_result.type == "FUNCTION_CALL" and first_result.value is not None:
                    function_call = first_result.value
                    self.state.chat_history.append(
                        ChatMessage(
                            role="ASSISTANT",
                            content=FunctionCallChatMessageContent(
                                value=FunctionCallChatMessageContentValue(
                                    name=function_call.name,
                                    arguments=function_call.arguments,
                                    id=function_call.id,
                                ),
                            ),
                        )
                    )
            yield output
49
+
50
+
51
def create_tool_router_node(
    ml_model: str,
    blocks: List[PromptBlock],
    functions: List[Callable[..., Any]],
    prompt_inputs: Optional[EntityInputsInterface],
) -> Type[ToolRouterNode]:
    """
    Create a ToolRouterNode subclass configured with the given model, prompt blocks, functions and inputs.

    Ports on the generated class:
        - one port per function, named after the function's ``__name__``. It fires when the first prompt
          result is a FUNCTION_CALL with that name.
        - ``default``, an else-port taken when no function port matches.

    A ``chat_history`` variable block is added to the prompt blocks unless one is already present.
    The caller's ``blocks`` list itself is never modified.
    """
    Ports = type("Ports", (), {})
    for function in functions:
        function_name = function.__name__
        port_condition = LazyReference(
            # Bind function_name as a default argument. The lambda runs later (lazily), so a plain closure
            # over the loop variable would make every port compare against the last function's name.
            lambda function_name=function_name: (
                ToolRouterNode.Outputs.results[0]["type"].equals("FUNCTION_CALL")
                & ToolRouterNode.Outputs.results[0]["value"]["name"].equals(function_name)
            )
        )
        setattr(Ports, function_name, Port.on_if(port_condition))

    setattr(Ports, "default", Port.on_else())

    has_chat_history_block = any(
        isinstance(block, VariablePromptBlock) and block.input_variable == "chat_history" for block in blocks
    )
    if not has_chat_history_block:
        # Build a new list instead of appending: callers may pass a shared, class-level list
        blocks = [
            *blocks,
            VariablePromptBlock(
                block_type="VARIABLE",
                input_variable="chat_history",
                state=None,
                cache_config=None,
            ),
        ]

    node = type(
        "ToolRouterNode",
        (ToolRouterNode,),
        {
            "ml_model": ml_model,
            "blocks": blocks,
            "functions": functions,
            "prompt_inputs": prompt_inputs,
            "Ports": Ports,
            "__module__": __name__,
        },
    )

    return node
95
+
96
+
97
def create_function_node(function: Callable[..., Any]) -> Type[FunctionNode]:
    """
    Create a FunctionNode class for a given function.

    The generated node's ``run`` does the following:
        1. Reads ``ToolRouterNode.Outputs.text`` from the state's node outputs and parses it as JSON.
        2. Calls ``function`` with the parsed ``arguments`` as keyword arguments.
        3. Appends the result to ``state.chat_history`` as a FUNCTION message.
    """

    def execute_function(self) -> BaseNode.Outputs:
        # Assumed to hold the JSON-encoded function call emitted by the router — TODO confirm format
        raw_function_call = self.state.meta.node_outputs.get(ToolRouterNode.Outputs.text)
        arguments = json.loads(raw_function_call)["arguments"]

        result = function(**arguments)

        # The message text must be a string: pass strings through unchanged and JSON-encode anything else
        # (dicts, numbers, None, ...) rather than letting ChatMessage validation reject it.
        result_text = result if isinstance(result, str) else json.dumps(result)
        self.state.chat_history.append(ChatMessage(role="FUNCTION", text=result_text))

        return self.Outputs()

    node = type(
        f"FunctionNode_{function.__name__}",
        (FunctionNode,),
        {
            "function": function,
            "run": execute_function,
            "__module__": __name__,
        },
    )

    return node
@@ -57,10 +57,12 @@ def create_adornment(
57
57
  class Subworkflow(BaseWorkflow):
58
58
  graph = inner_cls
59
59
 
60
- # mypy is wrong here, this works and is defined
61
- class Outputs(inner_cls.Outputs): # type: ignore[name-defined]
60
+ class Outputs(BaseWorkflow.Outputs):
62
61
  pass
63
62
 
63
+ for output_reference in inner_cls.Outputs:
64
+ setattr(Subworkflow.Outputs, output_reference.name, output_reference)
65
+
64
66
  dynamic_module = f"{inner_cls.__module__}.{inner_cls.__name__}.{ADORNMENT_MODULE_NAME}"
65
67
  # This dynamic module allows calls to `type_hints` to work
66
68
  sys.modules[dynamic_module] = ModuleType(dynamic_module)
@@ -147,8 +147,9 @@ class _BaseOutputsMeta(type):
147
147
  instance = vars(cls).get(name, undefined)
148
148
  if instance is undefined:
149
149
  for base in cls.__mro__[1:]:
150
- if hasattr(base, name):
151
- instance = getattr(base, name)
150
+ inherited_output_reference = getattr(base, name, undefined)
151
+ if isinstance(inherited_output_reference, OutputReference):
152
+ instance = inherited_output_reference.instance
152
153
  break
153
154
 
154
155
  types = infer_types(cls, name)
@@ -1,4 +1,6 @@
1
+ from functools import cached_property
1
2
  from queue import Queue
3
+ from uuid import UUID, uuid4
2
4
  from typing import TYPE_CHECKING, Any, Generator, Generic, Optional, Tuple, Type, TypeVar, cast
3
5
 
4
6
  from pydantic import GetCoreSchemaHandler
@@ -31,6 +33,24 @@ class OutputReference(BaseDescriptor[_OutputType], Generic[_OutputType]):
31
33
  def outputs_class(self) -> Type["BaseOutputs"]:
32
34
  return self._outputs_class
33
35
 
36
@cached_property
def id(self) -> UUID:
    """
    Identifier for this output reference; state serialization uses ``str(id)`` as the output's key.

    The id is read from ``__output_ids__[self.name]`` on the node class stored as ``_node_class`` on
    this reference's outputs class. A fresh random ``uuid4()`` is returned instead when any of these hold:
        - there is no node class;
        - ``__output_ids__`` is not a dict;
        - the entry is missing or is not a ``UUID``.

    Being a cached_property, a fallback id is stable only for the lifetime of this object.
    NOTE(review): outputs that hit the fallback therefore won't round-trip through serialized state.
    """
    node_class = getattr(self._outputs_class, "_node_class", None)
    if not node_class:
        return uuid4()

    output_ids = getattr(node_class, "__output_ids__", {})
    if not isinstance(output_ids, dict):
        return uuid4()

    output_id = output_ids.get(self.name)
    if not isinstance(output_id, UUID):
        return uuid4()

    return output_id
53
+
34
54
  def resolve(self, state: "BaseState") -> _OutputType:
35
55
  node_output = state.meta.node_outputs.get(self, undefined)
36
56
  if isinstance(node_output, Queue):
@@ -7,13 +7,14 @@ import logging
7
7
  from queue import Queue
8
8
  from threading import Lock
9
9
  from uuid import UUID, uuid4
10
- from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, cast
10
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast
11
11
  from typing_extensions import dataclass_transform
12
12
 
13
13
  from pydantic import GetCoreSchemaHandler, ValidationInfo, field_serializer, field_validator
14
14
  from pydantic_core import core_schema
15
15
 
16
16
  from vellum.core.pydantic_utilities import UniversalBaseModel
17
+ from vellum.utils.uuid import is_valid_uuid
17
18
  from vellum.workflows.constants import undefined
18
19
  from vellum.workflows.edges.edge import Edge
19
20
  from vellum.workflows.inputs.base import BaseInputs
@@ -109,18 +110,30 @@ class NodeExecutionCache:
109
110
  self._node_executions_queued = defaultdict(list)
110
111
 
111
112
  @classmethod
112
- def deserialize(cls, raw_data: dict, nodes: Dict[str, Type["BaseNode"]]):
113
+ def deserialize(cls, raw_data: dict, nodes: Dict[Union[str, UUID], Type["BaseNode"]]):
113
114
  cache = cls()
114
115
 
116
+ def get_node_class(node_id: Any) -> Optional[Type["BaseNode"]]:
117
+ if not isinstance(node_id, str):
118
+ return None
119
+
120
+ if is_valid_uuid(node_id):
121
+ return nodes.get(UUID(node_id))
122
+
123
+ return nodes.get(node_id)
124
+
115
125
  dependencies_invoked = raw_data.get("dependencies_invoked")
116
126
  if isinstance(dependencies_invoked, dict):
117
127
  for execution_id, dependencies in dependencies_invoked.items():
118
- cache._dependencies_invoked[UUID(execution_id)] = {nodes[dep] for dep in dependencies if dep in nodes}
128
+ dependency_classes = {get_node_class(dep) for dep in dependencies}
129
+ cache._dependencies_invoked[UUID(execution_id)] = {
130
+ dep_class for dep_class in dependency_classes if dep_class is not None
131
+ }
119
132
 
120
133
  node_executions_fulfilled = raw_data.get("node_executions_fulfilled")
121
134
  if isinstance(node_executions_fulfilled, dict):
122
135
  for node, execution_ids in node_executions_fulfilled.items():
123
- node_class = nodes.get(node)
136
+ node_class = get_node_class(node)
124
137
  if not node_class:
125
138
  continue
126
139
 
@@ -131,7 +144,7 @@ class NodeExecutionCache:
131
144
  node_executions_initiated = raw_data.get("node_executions_initiated")
132
145
  if isinstance(node_executions_initiated, dict):
133
146
  for node, execution_ids in node_executions_initiated.items():
134
- node_class = nodes.get(node)
147
+ node_class = get_node_class(node)
135
148
  if not node_class:
136
149
  continue
137
150
 
@@ -142,7 +155,7 @@ class NodeExecutionCache:
142
155
  node_executions_queued = raw_data.get("node_executions_queued")
143
156
  if isinstance(node_executions_queued, dict):
144
157
  for node, execution_ids in node_executions_queued.items():
145
- node_class = nodes.get(node)
158
+ node_class = get_node_class(node)
146
159
  if not node_class:
147
160
  continue
148
161
 
@@ -193,17 +206,18 @@ class NodeExecutionCache:
193
206
  def dump(self) -> Dict[str, Any]:
194
207
  return {
195
208
  "dependencies_invoked": {
196
- str(execution_id): [str(dep) for dep in dependencies]
209
+ str(execution_id): [str(dep.__id__) for dep in dependencies]
197
210
  for execution_id, dependencies in self._dependencies_invoked.items()
198
211
  },
199
212
  "node_executions_initiated": {
200
- str(node): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
213
+ str(node.__id__): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
201
214
  },
202
215
  "node_executions_fulfilled": {
203
- str(node): execution_ids.dump() for node, execution_ids in self._node_executions_fulfilled.items()
216
+ str(node.__id__): execution_ids.dump()
217
+ for node, execution_ids in self._node_executions_fulfilled.items()
204
218
  },
205
219
  "node_executions_queued": {
206
- str(node): execution_ids for node, execution_ids in self._node_executions_queued.items()
220
+ str(node.__id__): execution_ids for node, execution_ids in self._node_executions_queued.items()
207
221
  },
208
222
  }
209
223
 
@@ -279,7 +293,7 @@ class StateMeta(UniversalBaseModel):
279
293
 
280
294
  @field_serializer("node_outputs")
281
295
  def serialize_node_outputs(self, node_outputs: Dict[OutputReference, Any], _info: Any) -> Dict[str, Any]:
282
- return {str(descriptor): value for descriptor, value in node_outputs.items()}
296
+ return {str(descriptor.id): value for descriptor, value in node_outputs.items()}
283
297
 
284
298
  @field_validator("node_outputs", mode="before")
285
299
  @classmethod
@@ -290,15 +304,22 @@ class StateMeta(UniversalBaseModel):
290
304
  return node_outputs
291
305
 
292
306
  raw_workflow_nodes = workflow_definition.get_nodes()
293
- workflow_node_outputs = {}
307
+ workflow_node_outputs: Dict[Union[str, UUID], OutputReference] = {}
294
308
  for node in raw_workflow_nodes:
295
309
  for output in node.Outputs:
296
310
  workflow_node_outputs[str(output)] = output
311
+ output_id = node.__output_ids__.get(output.name)
312
+ if output_id:
313
+ workflow_node_outputs[output_id] = output
297
314
 
298
315
  node_output_keys = list(node_outputs.keys())
299
316
  deserialized_node_outputs = {}
300
317
  for node_output_key in node_output_keys:
301
- output_reference = workflow_node_outputs.get(node_output_key)
318
+ if is_valid_uuid(node_output_key):
319
+ output_reference = workflow_node_outputs.get(UUID(node_output_key))
320
+ else:
321
+ output_reference = workflow_node_outputs.get(node_output_key)
322
+
302
323
  if not output_reference:
303
324
  continue
304
325
 
@@ -316,10 +337,11 @@ class StateMeta(UniversalBaseModel):
316
337
  if not workflow_definition:
317
338
  return node_execution_cache
318
339
 
319
- nodes_cache: Dict[str, Type["BaseNode"]] = {}
340
+ nodes_cache: Dict[Union[str, UUID], Type["BaseNode"]] = {}
320
341
  raw_workflow_nodes = workflow_definition.get_nodes()
321
342
  for node in raw_workflow_nodes:
322
343
  nodes_cache[str(node)] = node
344
+ nodes_cache[node.__id__] = node
323
345
 
324
346
  return NodeExecutionCache.deserialize(node_execution_cache, nodes_cache)
325
347
 
@@ -47,6 +47,9 @@ class MockNode(BaseNode):
47
47
  baz: str
48
48
 
49
49
 
50
# Expected key for MockNode.Outputs.baz in serialized state: node_outputs are keyed by
# str(OutputReference.id) rather than by the output's qualified name.
# NOTE(review): presumably the deterministic MockNode.__output_ids__["baz"] — confirm.
MOCK_NODE_OUTPUT_ID = "e4dc3136-0c27-4bda-b3ab-ea355d5219d6"
51
+
52
+
50
53
  def test_state_snapshot__node_attribute_edit():
51
54
  # GIVEN an initial state instance
52
55
  state = MockState(foo="bar")
@@ -144,7 +147,7 @@ def test_state_json_serialization__with_node_output_updates():
144
147
  json_state = json.loads(json.dumps(state, cls=DefaultStateEncoder))
145
148
 
146
149
  # THEN the state is serialized correctly
147
- assert json_state["meta"]["node_outputs"] == {"MockNode.Outputs.baz": "hello"}
150
+ assert json_state["meta"]["node_outputs"] == {MOCK_NODE_OUTPUT_ID: "hello"}
148
151
 
149
152
 
150
153
  def test_state_deepcopy__with_external_input_updates():
@@ -185,7 +188,7 @@ def test_state_json_serialization__with_queue():
185
188
  json_state = json.loads(json.dumps(state, cls=DefaultStateEncoder))
186
189
 
187
190
  # THEN the state is serialized correctly with the queue turned into a list
188
- assert json_state["meta"]["node_outputs"] == {"MockNode.Outputs.baz": ["test1", "test2"]}
191
+ assert json_state["meta"]["node_outputs"] == {MOCK_NODE_OUTPUT_ID: ["test1", "test2"]}
189
192
 
190
193
 
191
194
  def test_state_snapshot__deepcopy_fails__logs_error(mock_deepcopy, mock_logger):
@@ -37,3 +37,14 @@ class Stack(Generic[_T]):
37
37
 
38
38
  def dump(self) -> List[_T]:
39
39
  return [item for item in self._items][::-1]
40
+
41
@classmethod
def from_list(cls, items: List[_T]) -> "Stack[_T]":
    """Alternate constructor: create an empty stack and ``extend`` it with ``items`` in list order."""
    new_stack = cls()
    new_stack.extend(items)
    return new_stack
46
+
47
def __eq__(self, other: object) -> bool:
    """
    Two stacks are equal when their underlying item storage compares equal.

    Non-Stack operands return ``NotImplemented`` so Python can try the reflected comparison; for
    ordinary objects this still ends in ``False`` via the identity fallback.

    Note: defining ``__eq__`` without ``__hash__`` makes Stack instances unhashable.
    """
    if not isinstance(other, Stack):
        return NotImplemented
    return self._items == other._items
@@ -80,6 +80,11 @@ class _BaseWorkflowMeta(type):
80
80
  def __new__(mcs, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
81
81
  if "graph" not in dct:
82
82
  dct["graph"] = set()
83
+ for base in bases:
84
+ base_graph = getattr(base, "graph", None)
85
+ if base_graph:
86
+ dct["graph"] = base_graph
87
+ break
83
88
 
84
89
  if "Outputs" in dct:
85
90
  outputs_class = dct["Outputs"]
@@ -1,5 +1,5 @@
1
1
  import pytest
2
- from uuid import uuid4
2
+ from uuid import UUID, uuid4
3
3
 
4
4
  from vellum.workflows.edges.edge import Edge
5
5
  from vellum.workflows.inputs.base import BaseInputs
@@ -7,6 +7,7 @@ from vellum.workflows.nodes.bases.base import BaseNode
7
7
  from vellum.workflows.nodes.core.inline_subworkflow_node.node import InlineSubworkflowNode
8
8
  from vellum.workflows.outputs.base import BaseOutputs
9
9
  from vellum.workflows.state.base import BaseState
10
+ from vellum.workflows.types.stack import Stack
10
11
  from vellum.workflows.workflows.base import BaseWorkflow
11
12
 
12
13
 
@@ -326,7 +327,6 @@ def test_workflow__unsupported_graph_item():
326
327
 
327
328
 
328
329
  def test_base_workflow__deserialize_state():
329
-
330
330
  # GIVEN a state definition
331
331
  class State(BaseState):
332
332
  bar: str
@@ -340,6 +340,14 @@ def test_base_workflow__deserialize_state():
340
340
  class Outputs(BaseNode.Outputs):
341
341
  foo: str
342
342
 
343
+ # AND a node id for node A
344
+ node_a_id = uuid4()
345
+ NodeA.__id__ = node_a_id
346
+
347
+ # AND an output id for NodeA.Outputs.foo
348
+ node_a_output_id = uuid4()
349
+ NodeA.__output_ids__["foo"] = node_a_output_id
350
+
343
351
  # AND a workflow that uses all three
344
352
  class TestWorkflow(BaseWorkflow[Inputs, State]):
345
353
  graph = NodeA
@@ -356,20 +364,20 @@ def test_base_workflow__deserialize_state():
356
364
  "updated_ts": "2025-04-14T19:22:18.504902",
357
365
  "external_inputs": {},
358
366
  "node_outputs": {
359
- "test_base_workflow__deserialize_state.<locals>.NodeA.Outputs.foo": "My node A output foo"
367
+ str(node_a_output_id): "My node A output foo",
360
368
  },
361
369
  "node_execution_cache": {
362
370
  "dependencies_invoked": {
363
- last_span_id: ["test_base_workflow__deserialize_state.<locals>.NodeA"],
371
+ last_span_id: [str(node_a_id)],
364
372
  },
365
373
  "node_executions_initiated": {
366
- "test_base_workflow__deserialize_state.<locals>.NodeA": [last_span_id],
374
+ str(node_a_id): [last_span_id],
367
375
  },
368
376
  "node_executions_fulfilled": {
369
- "test_base_workflow__deserialize_state.<locals>.NodeA": [last_span_id],
377
+ str(node_a_id): [last_span_id],
370
378
  },
371
379
  "node_executions_queued": {
372
- "test_base_workflow__deserialize_state.<locals>.NodeA": [],
380
+ str(node_a_id): [],
373
381
  },
374
382
  },
375
383
  "parent": None,
@@ -384,8 +392,87 @@ def test_base_workflow__deserialize_state():
384
392
  assert isinstance(state.meta.workflow_inputs, Inputs)
385
393
  assert state.meta.workflow_inputs.baz == "My input baz"
386
394
 
387
- # AND the node execution cache should deserialize
388
- assert state.meta.node_execution_cache
395
+ # AND the node execution cache should deserialize correctly
396
+ assert state.meta.node_execution_cache._node_executions_initiated == {NodeA: {UUID(last_span_id)}}
397
+ assert state.meta.node_execution_cache._node_executions_fulfilled == {NodeA: Stack.from_list([UUID(last_span_id)])}
398
+ assert state.meta.node_execution_cache._node_executions_queued == {NodeA: []}
399
+ assert state.meta.node_execution_cache._dependencies_invoked == {UUID(last_span_id): {NodeA}}
400
+
401
+
402
def test_base_workflow__deserialize_legacy_node_execution_cache():
    """A node execution cache keyed by legacy node names (rather than node UUIDs) still deserializes."""

    # GIVEN a state definition
    class State(BaseState):
        bar: str

    # AND a node
    class NodeA(BaseNode):
        class Outputs(BaseNode.Outputs):
            foo: str

    # AND a workflow that uses both
    class TestWorkflow(BaseWorkflow[BaseInputs, State]):
        graph = NodeA

    # WHEN we deserialize the state that had a legacy node execution cache format
    # (every node key is the node's full dotted name instead of str(node.__id__))
    last_span_id = str(uuid4())
    # The legacy key embeds this module path and test name, so it must be updated if either is renamed
    node_a_full_name = "vellum.workflows.workflows.tests.test_base_workflow.test_base_workflow__deserialize_legacy_node_execution_cache.<locals>.NodeA"  # noqa: E501
    state = TestWorkflow.deserialize_state(
        {
            "meta": {
                "node_execution_cache": {
                    "dependencies_invoked": {
                        last_span_id: [node_a_full_name],
                    },
                    "node_executions_initiated": {
                        node_a_full_name: [last_span_id],
                    },
                    "node_executions_fulfilled": {
                        node_a_full_name: [last_span_id],
                    },
                    "node_executions_queued": {
                        node_a_full_name: [],
                    },
                },
            },
        },
    )

    # THEN the node execution cache should deserialize correctly, with every legacy key resolved to NodeA
    assert state.meta.node_execution_cache._node_executions_initiated == {NodeA: {UUID(last_span_id)}}
    assert state.meta.node_execution_cache._node_executions_fulfilled == {NodeA: Stack.from_list([UUID(last_span_id)])}
    assert state.meta.node_execution_cache._node_executions_queued == {NodeA: []}
    assert state.meta.node_execution_cache._dependencies_invoked == {UUID(last_span_id): {NodeA}}
446
+
447
+
448
def test_base_workflow__deserialize_legacy_node_outputs():
    """Node outputs keyed by the legacy str(OutputReference) name (rather than the output id) still deserialize."""

    # GIVEN a state definition
    class State(BaseState):
        bar: str

    # AND a node
    class NodeA(BaseNode):
        class Outputs(BaseNode.Outputs):
            foo: str

    # AND a workflow that uses both
    class TestWorkflow(BaseWorkflow[BaseInputs, State]):
        graph = NodeA

    # WHEN we deserialize the state that had a legacy node outputs format
    # (the key is the qualified output name, str(NodeA.Outputs.foo), instead of the output's UUID)
    state = TestWorkflow.deserialize_state(
        {
            "meta": {
                "node_outputs": {
                    "test_base_workflow__deserialize_legacy_node_outputs.<locals>.NodeA.Outputs.foo": "My node A output foo",  # noqa: E501
                },
            },
        },
    )

    # THEN the node outputs should deserialize correctly, keyed by the NodeA.Outputs.foo reference
    assert state.meta.node_outputs == {NodeA.Outputs.foo: "My node A output foo"}
389
476
 
390
477
 
391
478
  def test_base_workflow__deserialize_state_with_optional_inputs():
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vellum-ai
3
- Version: 0.14.39
3
+ Version: 0.14.41
4
4
  Summary:
5
5
  License: MIT
6
6
  Requires-Python: >=3.9,<4.0